Unverified Commit 31dfff7d authored by Yineng Zhang, committed by GitHub

use default for torch.ops (#4835)

parent 10a9ab7b
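
Background for the change (a sketch, not part of the diff): `torch.ops.sgl_kernel.foo` is an `OpOverloadPacket`, whose `__call__` resolves the matching overload on every invocation, while `torch.ops.sgl_kernel.foo.default` is the concrete `OpOverload` for the unnamed overload, so calling it skips that per-call resolution and matches the op targets that FX / `torch.compile` graphs record. The snippet below illustrates the two call forms with a built-in aten op; the same distinction applies to any custom op registered with a single unnamed schema, which is what the wrappers below assume.

import torch

x = torch.randn(4)

# OpOverloadPacket: resolves which overload to run on each call
packet = torch.ops.aten.relu
# OpOverload: the concrete unnamed ("default") overload, no resolution step
overload = torch.ops.aten.relu.default

print(type(packet))    # <class 'torch._ops.OpOverloadPacket'>
print(type(overload))  # <class 'torch._ops.OpOverload'>
assert torch.equal(packet(x), overload(x))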
@@ -12,49 +12,49 @@ if torch.version.hip is not None:
         rank: int,
         full_nvlink: bool,
     ) -> int:
-        return torch.ops.sgl_kernel.init_custom_ar(
+        return torch.ops.sgl_kernel.init_custom_ar.default(
             meta, rank_data, handles, offsets, rank, full_nvlink
         )

     def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-        torch.ops.sgl_kernel.all_reduce_reg(fa, inp, out)
+        torch.ops.sgl_kernel.all_reduce_reg.default(fa, inp, out)

     def all_reduce_unreg(
         fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
     ) -> None:
-        torch.ops.sgl_kernel.all_reduce_unreg(fa, inp, reg_buffer, out)
+        torch.ops.sgl_kernel.all_reduce_unreg.default(fa, inp, reg_buffer, out)

     def dispose(fa: int) -> None:
-        torch.ops.sgl_kernel.dispose(fa)
+        torch.ops.sgl_kernel.dispose.default(fa)

     def meta_size() -> int:
-        return torch.ops.sgl_kernel.meta_size()
+        return torch.ops.sgl_kernel.meta_size.default()

     def register_buffer(
         fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
     ) -> None:
-        return torch.ops.sgl_kernel.register_buffer(fa, t, handles, offsets)
+        return torch.ops.sgl_kernel.register_buffer.default(fa, t, handles, offsets)

     def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
-        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
+        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)

     def register_graph_buffers(
         fa: int, handles: List[str], offsets: List[List[int]]
     ) -> None:
-        torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
+        torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)

     def allocate_meta_buffer(size: int) -> torch.Tensor:
-        return torch.ops.sgl_kernel.allocate_meta_buffer(size)
+        return torch.ops.sgl_kernel.allocate_meta_buffer.default(size)

     def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
-        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle(inp)
+        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)

 else:
     # TRTLLM custom allreduce
     def init_custom_reduce(
         rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
     ):
-        return torch.ops.sgl_kernel.init_custom_ar(
+        return torch.ops.sgl_kernel.init_custom_ar.default(
             rank_id,
             num_devices,
             rank_data,
@@ -65,13 +65,13 @@ else:
         )

     def custom_dispose(fa):
-        torch.ops.sgl_kernel.dispose(fa)
+        torch.ops.sgl_kernel.dispose.default(fa)

     def custom_reduce(fa, inp, out):
-        torch.ops.sgl_kernel.all_reduce(fa, inp, out)
+        torch.ops.sgl_kernel.all_reduce.default(fa, inp, out)

     def get_graph_buffer_ipc_meta(fa):
-        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
+        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)

     def register_graph_buffers(fa, handles, offsets):
-        torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
+        torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
@@ -2,6 +2,6 @@ import torch

 def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
-    torch.ops.sgl_kernel.lightning_attention_decode(
+    torch.ops.sgl_kernel.lightning_attention_decode.default(
         q, k, v, past_kv, slope, output, new_kv
     )
@@ -14,14 +14,14 @@ def rmsnorm(
 ) -> torch.Tensor:
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.rmsnorm(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream())
     return out

 def fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
-    torch.ops.sgl_kernel.fused_add_rmsnorm(input, residual, weight, eps)
+    torch.ops.sgl_kernel.fused_add_rmsnorm.default(input, residual, weight, eps)

 def gemma_rmsnorm(
@@ -32,14 +32,16 @@ def gemma_rmsnorm(
 ) -> torch.Tensor:
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.gemma_rmsnorm.default(
+        out, input, weight, eps, get_cuda_stream()
+    )
     return out

 def gemma_fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
-    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm(
+    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
         input, residual, weight, eps, get_cuda_stream()
     )
@@ -65,7 +67,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
         device=input.device,
         dtype=input.dtype,
     )
-    torch.ops.sgl_kernel.silu_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.silu_and_mul.default(out, input, get_cuda_stream())
     return out
@@ -80,7 +82,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
         device=input.device,
         dtype=input.dtype,
     )
-    torch.ops.sgl_kernel.gelu_tanh_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, get_cuda_stream())
     return out
@@ -95,7 +97,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
         device=input.device,
         dtype=input.dtype,
     )
-    torch.ops.sgl_kernel.gelu_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_and_mul.default(out, input, get_cuda_stream())
     return out
@@ -139,7 +141,7 @@ def apply_rope_with_cos_sin_cache_inplace(
     if cos_sin_cache.dtype != torch.float32:
         raise ValueError("cos_sin_cache should be float32")
-    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
+    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
         q=query.view(query.shape[0], -1, head_size),
         k=key.view(key.shape[0], -1, head_size),
         q_rope=query.view(query.shape[0], -1, head_size),
...
@@ -7,11 +7,11 @@ from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
 def awq_dequantize(
     qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
 ) -> torch.ByteTensor:
-    return torch.ops.sgl_kernel.awq_dequantize(qweight, scales, qzeros)
+    return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)

 def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
-    return torch.ops.sgl_kernel.int8_scaled_mm(
+    return torch.ops.sgl_kernel.int8_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,
@@ -22,7 +22,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):

 def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
-    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm(
+    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,
@@ -32,7 +32,7 @@ def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):

 def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
-    return torch.ops.sgl_kernel.fp8_scaled_mm(
+    return torch.ops.sgl_kernel.fp8_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,
@@ -51,7 +51,7 @@ def _bmm_fp8_internal(
     B_scale: torch.Tensor,
 ) -> None:
     cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.bmm_fp8(
+    torch.ops.sgl_kernel.bmm_fp8.default(
         A,
         B,
         D,
@@ -91,7 +91,7 @@ def sgl_per_token_group_quant_fp8(
     fp8_min: float,
     fp8_max: float,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8(
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
         input, output_q, output_s, group_size, eps, fp8_min, fp8_max
     )
@@ -105,7 +105,7 @@ def sgl_per_token_group_quant_int8(
     int8_min: float,
     int8_max: float,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8(
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
         input, output_q, output_s, group_size, eps, int8_min, int8_max
     )
@@ -116,7 +116,9 @@ def sgl_per_tensor_quant_fp8(
     output_s: torch.Tensor,
     is_static: bool,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)
+    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
+        input, output_q, output_s, is_static
+    )

 def cublas_grouped_gemm(
@@ -129,7 +131,7 @@ def cublas_grouped_gemm(
         len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
     ), "Inputs/weights/outputs should not be empty!"
     cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.cublas_grouped_gemm(
+    torch.ops.sgl_kernel.cublas_grouped_gemm.default(
         inputs,
         weights,
         outputs,
@@ -144,7 +146,7 @@ def sgl_per_token_quant_fp8(
     output_q: torch.Tensor,
     output_s: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s)
+    torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)

 def cutlass_scaled_fp4_mm(
@@ -158,7 +160,7 @@ def cutlass_scaled_fp4_mm(
     assert a.ndim == 2 and b.ndim == 2
     m, n = a.shape[0], b.shape[0]
     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-    torch.ops.sgl_kernels.cutlass_scaled_fp4_mm(
+    torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default(
         out, a, b, block_scale_a, block_scale_b, alpha
     )
     return out
@@ -210,7 +212,7 @@ def scaled_fp4_quant(
         (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
     )
-    torch.ops.sgl_kernels.scaled_fp4_quant(
+    torch.ops.sgl_kernel.scaled_fp4_quant.default(
         output, input, output_scale, input_global_scale
     )
     output_scale = output_scale.view(torch.float8_e4m3fn)
...
@@ -11,7 +11,7 @@ def moe_align_block_size(
     token_cnts_buffer,
     cumsum_buffer,
 ):
-    torch.ops.sgl_kernel.moe_align_block_size(
+    torch.ops.sgl_kernel.moe_align_block_size.default(
         topk_ids,
         num_experts,
         block_size,
@@ -29,6 +29,6 @@ def topk_softmax(
     token_expert_indices: torch.Tensor,
     gating_output: float,
 ) -> None:
-    torch.ops.sgl_kernel.topk_softmax(
+    torch.ops.sgl_kernel.topk_softmax.default(
         topk_weights, topk_ids, token_expert_indices, gating_output
     )
@@ -12,7 +12,7 @@ def _top_k_renorm_probs_internal(
     probs = probs.float()
     maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
     renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernel.top_k_renorm_probs(
+    torch.ops.sgl_kernel.top_k_renorm_probs.default(
         probs,
         renorm_probs,
         maybe_top_k_arr,
@@ -40,7 +40,7 @@ def _top_p_renorm_probs_internal(
     probs = probs.float()
     maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
     renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernel.top_p_renorm_probs(
+    torch.ops.sgl_kernel.top_p_renorm_probs.default(
         probs,
         renorm_probs,
         maybe_top_p_arr,
@@ -75,7 +75,7 @@ def _top_p_sampling_from_probs_internal(
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
     success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
-    torch.ops.sgl_kernel.top_p_sampling_from_probs(
+    torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
         probs,
         uniform_samples,
         samples,
@@ -121,7 +121,7 @@ def _top_k_top_p_sampling_from_probs_internal(
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
     success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
-    torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs(
+    torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
         probs,
         uniform_samples,
         samples,
@@ -179,7 +179,7 @@ def _min_p_sampling_from_probs_internal(
         maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
-    torch.ops.sgl_kernel.min_p_sampling_from_probs(
+    torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
         probs,
         uniform_samples,
         samples,
...
@@ -17,7 +17,7 @@ def tree_speculative_sampling_target_only(
     threshold_acc: float = 1.0,
     deterministic: bool = True,
 ) -> None:
-    torch.ops.sgl_kernel.tree_speculative_sampling_target_only(
+    torch.ops.sgl_kernel.tree_speculative_sampling_target_only.default(
         predicts,
         accept_index,
         accept_token_num,
@@ -45,7 +45,7 @@ def verify_tree_greedy(
     retrive_next_sibling: torch.Tensor,
     target_predict: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernel.verify_tree_greedy(
+    torch.ops.sgl_kernel.verify_tree_greedy.default(
         predicts,
         accept_index,
         accept_token_num,
@@ -71,7 +71,7 @@ def build_tree_kernel_efficient(
     depth: int,
     draft_token_num: int,
 ) -> None:
-    torch.ops.sgl_kernel.build_tree_kernel_efficient(
+    torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
         parent_list,
         selected_index,
         verified_seq_len,
@@ -92,7 +92,7 @@ def segment_packbits(
     output_indptr: torch.Tensor,
     y: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernel.segment_packbits(
+    torch.ops.sgl_kernel.segment_packbits.default(
         x,
         input_indptr,
         output_indptr,
...