change / sglang · Commits

Commit 31dfff7d (Unverified)
Authored Mar 27, 2025 by Yineng Zhang; committed by GitHub on Mar 27, 2025
use default for torch.ops (#4835)
Parent: 10a9ab7b

Showing 7 changed files with 51 additions and 47 deletions.
sgl-kernel/python/sgl_kernel/allreduce.py    +15 -15
sgl-kernel/python/sgl_kernel/attention.py    +1 -1
sgl-kernel/python/sgl_kernel/elementwise.py  +10 -8
sgl-kernel/python/sgl_kernel/gemm.py         +14 -12
sgl-kernel/python/sgl_kernel/moe.py          +2 -2
sgl-kernel/python/sgl_kernel/sampling.py     +5 -5
sgl-kernel/python/sgl_kernel/speculative.py  +4 -4
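Every change in this commit follows the same pattern: a wrapper that previously invoked the op packet, e.g. torch.ops.sgl_kernel.rmsnorm(...), now invokes the explicit overload torch.ops.sgl_kernel.rmsnorm.default(...). In PyTorch, torch.ops.<namespace>.<op> is an OpOverloadPacket; attribute access .default selects the concrete OpOverload registered for that schema, so the call binds a specific overload instead of resolving one on every invocation. Below is a minimal sketch of that distinction; the "demo" namespace and "scale" op are invented purely for illustration (sgl-kernel's real ops are registered from its C++ extension):

import torch

# Illustrative only: register a tiny custom op in a hypothetical "demo"
# namespace so both call forms can be compared.
lib = torch.library.Library("demo", "DEF")
lib.define("scale(Tensor x, float alpha) -> Tensor")
lib.impl("scale", lambda x, alpha: x * alpha, "CompositeExplicitAutograd")

x = torch.randn(4)

# OpOverloadPacket: the packet resolves which overload to call each time.
y_packet = torch.ops.demo.scale(x, 2.0)

# OpOverload: bind the default overload directly, as this commit does
# for every sgl_kernel wrapper.
y_overload = torch.ops.demo.scale.default(x, 2.0)

assert torch.equal(y_packet, y_overload)

Because only the call target inside each wrapper changes, the public signatures of the sgl_kernel Python API stay the same; the diff in gemm.py additionally fixes two call sites that referenced a non-existent torch.ops.sgl_kernels namespace.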
sgl-kernel/python/sgl_kernel/allreduce.py

@@ -12,49 +12,49 @@ if torch.version.hip is not None:
         rank: int,
         full_nvlink: bool,
     ) -> int:
-        return torch.ops.sgl_kernel.init_custom_ar(
+        return torch.ops.sgl_kernel.init_custom_ar.default(
             meta, rank_data, handles, offsets, rank, full_nvlink
         )

     def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-        torch.ops.sgl_kernel.all_reduce_reg(fa, inp, out)
+        torch.ops.sgl_kernel.all_reduce_reg.default(fa, inp, out)

     def all_reduce_unreg(
         fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
     ) -> None:
-        torch.ops.sgl_kernel.all_reduce_unreg(fa, inp, reg_buffer, out)
+        torch.ops.sgl_kernel.all_reduce_unreg.default(fa, inp, reg_buffer, out)

     def dispose(fa: int) -> None:
-        torch.ops.sgl_kernel.dispose(fa)
+        torch.ops.sgl_kernel.dispose.default(fa)

     def meta_size() -> int:
-        return torch.ops.sgl_kernel.meta_size()
+        return torch.ops.sgl_kernel.meta_size.default()

     def register_buffer(
         fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
     ) -> None:
-        return torch.ops.sgl_kernel.register_buffer(fa, t, handles, offsets)
+        return torch.ops.sgl_kernel.register_buffer.default(fa, t, handles, offsets)

     def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
-        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
+        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)

     def register_graph_buffers(
         fa: int, handles: List[str], offsets: List[List[int]]
     ) -> None:
-        torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
+        torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)

     def allocate_meta_buffer(size: int) -> torch.Tensor:
-        return torch.ops.sgl_kernel.allocate_meta_buffer(size)
+        return torch.ops.sgl_kernel.allocate_meta_buffer.default(size)

     def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
-        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle(inp)
+        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)

 else:
     # TRTLLM custom allreduce
     def init_custom_reduce(
         rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
     ):
-        return torch.ops.sgl_kernel.init_custom_ar(
+        return torch.ops.sgl_kernel.init_custom_ar.default(
             rank_id,
             num_devices,
             rank_data,
@@ -65,13 +65,13 @@ else:
         )

     def custom_dispose(fa):
-        torch.ops.sgl_kernel.dispose(fa)
+        torch.ops.sgl_kernel.dispose.default(fa)

     def custom_reduce(fa, inp, out):
-        torch.ops.sgl_kernel.all_reduce(fa, inp, out)
+        torch.ops.sgl_kernel.all_reduce.default(fa, inp, out)

     def get_graph_buffer_ipc_meta(fa):
-        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
+        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)

     def register_graph_buffers(fa, handles, offsets):
-        torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
+        torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
sgl-kernel/python/sgl_kernel/attention.py

@@ -2,6 +2,6 @@ import torch
 def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
-    torch.ops.sgl_kernel.lightning_attention_decode(
+    torch.ops.sgl_kernel.lightning_attention_decode.default(
         q, k, v, past_kv, slope, output, new_kv
     )
sgl-kernel/python/sgl_kernel/elementwise.py

@@ -14,14 +14,14 @@ def rmsnorm(
 ) -> torch.Tensor:
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.rmsnorm(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream())
     return out

 def fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
-    torch.ops.sgl_kernel.fused_add_rmsnorm(input, residual, weight, eps)
+    torch.ops.sgl_kernel.fused_add_rmsnorm.default(input, residual, weight, eps)

 def gemma_rmsnorm(
@@ -32,14 +32,16 @@ def gemma_rmsnorm(
 ) -> torch.Tensor:
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.gemma_rmsnorm.default(
+        out, input, weight, eps, get_cuda_stream()
+    )
     return out

 def gemma_fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
-    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm(
+    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
         input, residual, weight, eps, get_cuda_stream()
     )
@@ -65,7 +67,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             device=input.device,
             dtype=input.dtype,
         )
-    torch.ops.sgl_kernel.silu_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.silu_and_mul.default(out, input, get_cuda_stream())
     return out
@@ -80,7 +82,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             device=input.device,
             dtype=input.dtype,
         )
-    torch.ops.sgl_kernel.gelu_tanh_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, get_cuda_stream())
     return out
@@ -95,7 +97,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             device=input.device,
             dtype=input.dtype,
         )
-    torch.ops.sgl_kernel.gelu_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_and_mul.default(out, input, get_cuda_stream())
     return out
@@ -139,7 +141,7 @@ def apply_rope_with_cos_sin_cache_inplace(
     if cos_sin_cache.dtype != torch.float32:
         raise ValueError("cos_sin_cache should be float32")

-    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
+    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
         q=query.view(query.shape[0], -1, head_size),
         k=key.view(key.shape[0], -1, head_size),
         q_rope=query.view(query.shape[0], -1, head_size),
sgl-kernel/python/sgl_kernel/gemm.py

@@ -7,11 +7,11 @@ from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
 def awq_dequantize(
     qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
 ) -> torch.ByteTensor:
-    return torch.ops.sgl_kernel.awq_dequantize(qweight, scales, qzeros)
+    return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)

 def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
-    return torch.ops.sgl_kernel.int8_scaled_mm(
+    return torch.ops.sgl_kernel.int8_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,
@@ -22,7 +22,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
 def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
-    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm(
+    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,
@@ -32,7 +32,7 @@ def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
 def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
-    return torch.ops.sgl_kernel.fp8_scaled_mm(
+    return torch.ops.sgl_kernel.fp8_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,
@@ -51,7 +51,7 @@ def _bmm_fp8_internal(
     B_scale: torch.Tensor,
 ) -> None:
     cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.bmm_fp8(
+    torch.ops.sgl_kernel.bmm_fp8.default(
         A,
         B,
         D,
@@ -91,7 +91,7 @@ def sgl_per_token_group_quant_fp8(
     fp8_min: float,
     fp8_max: float,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8(
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
         input, output_q, output_s, group_size, eps, fp8_min, fp8_max
     )
@@ -105,7 +105,7 @@ def sgl_per_token_group_quant_int8(
     int8_min: float,
     int8_max: float,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8(
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
         input, output_q, output_s, group_size, eps, int8_min, int8_max
     )
@@ -116,7 +116,9 @@ def sgl_per_tensor_quant_fp8(
     output_s: torch.Tensor,
     is_static: bool,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)
+    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
+        input, output_q, output_s, is_static
+    )

 def cublas_grouped_gemm(
@@ -129,7 +131,7 @@ def cublas_grouped_gemm(
         len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
     ), "Inputs/weights/outputs should not be empty!"
     cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.cublas_grouped_gemm(
+    torch.ops.sgl_kernel.cublas_grouped_gemm.default(
         inputs,
         weights,
         outputs,
@@ -144,7 +146,7 @@ def sgl_per_token_quant_fp8(
     output_q: torch.Tensor,
     output_s: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s)
+    torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)

 def cutlass_scaled_fp4_mm(
@@ -158,7 +160,7 @@ def cutlass_scaled_fp4_mm(
     assert a.ndim == 2 and b.ndim == 2
     m, n = a.shape[0], b.shape[0]
     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-    torch.ops.sgl_kernels.cutlass_scaled_fp4_mm(
+    torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default(
         out, a, b, block_scale_a, block_scale_b, alpha
     )
     return out
@@ -210,7 +212,7 @@ def scaled_fp4_quant(
         (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
     )
-    torch.ops.sgl_kernels.scaled_fp4_quant(
+    torch.ops.sgl_kernel.scaled_fp4_quant.default(
         output, input, output_scale, input_global_scale
     )
     output_scale = output_scale.view(torch.float8_e4m3fn)
sgl-kernel/python/sgl_kernel/moe.py

@@ -11,7 +11,7 @@ def moe_align_block_size(
     token_cnts_buffer,
     cumsum_buffer,
 ):
-    torch.ops.sgl_kernel.moe_align_block_size(
+    torch.ops.sgl_kernel.moe_align_block_size.default(
         topk_ids,
         num_experts,
         block_size,
@@ -29,6 +29,6 @@ def topk_softmax(
     token_expert_indices: torch.Tensor,
     gating_output: float,
 ) -> None:
-    torch.ops.sgl_kernel.topk_softmax(
+    torch.ops.sgl_kernel.topk_softmax.default(
         topk_weights, topk_ids, token_expert_indices, gating_output
     )
sgl-kernel/python/sgl_kernel/sampling.py

@@ -12,7 +12,7 @@ def _top_k_renorm_probs_internal(
     probs = probs.float()
     maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
     renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernel.top_k_renorm_probs(
+    torch.ops.sgl_kernel.top_k_renorm_probs.default(
         probs,
         renorm_probs,
         maybe_top_k_arr,
@@ -40,7 +40,7 @@ def _top_p_renorm_probs_internal(
     probs = probs.float()
     maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
     renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernel.top_p_renorm_probs(
+    torch.ops.sgl_kernel.top_p_renorm_probs.default(
         probs,
         renorm_probs,
         maybe_top_p_arr,
@@ -75,7 +75,7 @@ def _top_p_sampling_from_probs_internal(
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
     success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
-    torch.ops.sgl_kernel.top_p_sampling_from_probs(
+    torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
         probs,
         uniform_samples,
         samples,
@@ -121,7 +121,7 @@ def _top_k_top_p_sampling_from_probs_internal(
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
     success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
-    torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs(
+    torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
         probs,
         uniform_samples,
         samples,
@@ -179,7 +179,7 @@ def _min_p_sampling_from_probs_internal(
         maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
-    torch.ops.sgl_kernel.min_p_sampling_from_probs(
+    torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
         probs,
         uniform_samples,
         samples,
sgl-kernel/python/sgl_kernel/speculative.py

@@ -17,7 +17,7 @@ def tree_speculative_sampling_target_only(
     threshold_acc: float = 1.0,
     deterministic: bool = True,
 ) -> None:
-    torch.ops.sgl_kernel.tree_speculative_sampling_target_only(
+    torch.ops.sgl_kernel.tree_speculative_sampling_target_only.default(
         predicts,
         accept_index,
         accept_token_num,
@@ -45,7 +45,7 @@ def verify_tree_greedy(
     retrive_next_sibling: torch.Tensor,
     target_predict: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernel.verify_tree_greedy(
+    torch.ops.sgl_kernel.verify_tree_greedy.default(
         predicts,
         accept_index,
         accept_token_num,
@@ -71,7 +71,7 @@ def build_tree_kernel_efficient(
     depth: int,
     draft_token_num: int,
 ) -> None:
-    torch.ops.sgl_kernel.build_tree_kernel_efficient(
+    torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
         parent_list,
         selected_index,
         verified_seq_len,
@@ -92,7 +92,7 @@ def segment_packbits(
     output_indptr: torch.Tensor,
     y: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernel.segment_packbits(
+    torch.ops.sgl_kernel.segment_packbits.default(
         x,
         input_indptr,
         output_indptr,
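Callers of these wrappers are unaffected by the change, since only the internal dispatch target moved from the op packet to its default overload. A hedged usage sketch follows; it assumes sgl-kernel is installed with a CUDA device available and that the rmsnorm wrapper's full signature is rmsnorm(input, weight, eps=1e-6, out=None), which the elementwise.py hunk above implies but does not show in full:

import torch
from sgl_kernel.elementwise import rmsnorm  # import path per the file layout above

# Hypothetical shapes; the wrapper allocates `out` itself when none is passed.
hidden = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
weight = torch.ones(4096, device="cuda", dtype=torch.float16)

# Internally this now calls torch.ops.sgl_kernel.rmsnorm.default(...).
normed = rmsnorm(hidden, weight)
print(normed.shape)  # torch.Size([8, 4096])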