Commit 31dfff7d (unverified)
Authored Mar 27, 2025 by Yineng Zhang; committed by GitHub on Mar 27, 2025

use default for torch.ops (#4835)

Parent: 10a9ab7b
Showing 7 changed files with 51 additions and 47 deletions (+51 -47)
sgl-kernel/python/sgl_kernel/allreduce.py     +15 -15
sgl-kernel/python/sgl_kernel/attention.py      +1  -1
sgl-kernel/python/sgl_kernel/elementwise.py   +10  -8
sgl-kernel/python/sgl_kernel/gemm.py          +14 -12
sgl-kernel/python/sgl_kernel/moe.py            +2  -2
sgl-kernel/python/sgl_kernel/sampling.py       +5  -5
sgl-kernel/python/sgl_kernel/speculative.py    +4  -4
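
Every hunk in this commit makes the same mechanical change: each Python wrapper now calls the explicit .default overload of its custom op (torch.ops.sgl_kernel.<op>.default(...)) instead of the OpOverloadPacket (torch.ops.sgl_kernel.<op>(...)). The sketch below illustrates the two call forms on a self-contained example; the demo_kernel namespace and its scale op are hypothetical stand-ins for illustration only, not part of sgl-kernel.

    import torch

    # Hypothetical op, registered the way a custom kernel surfaces under
    # torch.ops.<namespace>.<op>; a stand-in only, not an sgl-kernel op.
    lib = torch.library.Library("demo_kernel", "DEF")
    lib.define("scale(Tensor x, float alpha) -> Tensor")


    def _scale_impl(x: torch.Tensor, alpha: float) -> torch.Tensor:
        # Reference implementation for all backends in this sketch.
        return x * alpha


    lib.impl("scale", _scale_impl, "CompositeExplicitAutograd")

    x = torch.randn(4)

    # Packet form (the style being replaced): the overload is resolved from
    # the arguments on every call.
    y_packet = torch.ops.demo_kernel.scale(x, 2.0)

    # Overload form (the style adopted here): bind the one registered schema,
    # named "default", directly.
    y_default = torch.ops.demo_kernel.scale.default(x, 2.0)

    assert torch.equal(y_packet, y_default)

Both forms run the same kernel; the .default form simply names one concrete schema rather than dispatching through the packet.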
sgl-kernel/python/sgl_kernel/allreduce.py

@@ -12,49 +12,49 @@ if torch.version.hip is not None:
         rank: int,
         full_nvlink: bool,
     ) -> int:
-        return torch.ops.sgl_kernel.init_custom_ar(
+        return torch.ops.sgl_kernel.init_custom_ar.default(
             meta, rank_data, handles, offsets, rank, full_nvlink
         )
 
     def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-        torch.ops.sgl_kernel.all_reduce_reg(fa, inp, out)
+        torch.ops.sgl_kernel.all_reduce_reg.default(fa, inp, out)
 
     def all_reduce_unreg(
         fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
     ) -> None:
-        torch.ops.sgl_kernel.all_reduce_unreg(fa, inp, reg_buffer, out)
+        torch.ops.sgl_kernel.all_reduce_unreg.default(fa, inp, reg_buffer, out)
 
     def dispose(fa: int) -> None:
-        torch.ops.sgl_kernel.dispose(fa)
+        torch.ops.sgl_kernel.dispose.default(fa)
 
     def meta_size() -> int:
-        return torch.ops.sgl_kernel.meta_size()
+        return torch.ops.sgl_kernel.meta_size.default()
 
     def register_buffer(
         fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
     ) -> None:
-        return torch.ops.sgl_kernel.register_buffer(fa, t, handles, offsets)
+        return torch.ops.sgl_kernel.register_buffer.default(fa, t, handles, offsets)
 
     def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
-        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
+        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)
 
     def register_graph_buffers(
         fa: int, handles: List[str], offsets: List[List[int]]
     ) -> None:
-        torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
+        torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
 
     def allocate_meta_buffer(size: int) -> torch.Tensor:
-        return torch.ops.sgl_kernel.allocate_meta_buffer(size)
+        return torch.ops.sgl_kernel.allocate_meta_buffer.default(size)
 
     def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
-        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle(inp)
+        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)
 
 else:
     # TRTLLM custom allreduce
     def init_custom_reduce(
         rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
     ):
-        return torch.ops.sgl_kernel.init_custom_ar(
+        return torch.ops.sgl_kernel.init_custom_ar.default(
             rank_id,
             num_devices,
             rank_data,

@@ -65,13 +65,13 @@ else:
         )
 
     def custom_dispose(fa):
-        torch.ops.sgl_kernel.dispose(fa)
+        torch.ops.sgl_kernel.dispose.default(fa)
 
     def custom_reduce(fa, inp, out):
-        torch.ops.sgl_kernel.all_reduce(fa, inp, out)
+        torch.ops.sgl_kernel.all_reduce.default(fa, inp, out)
 
     def get_graph_buffer_ipc_meta(fa):
-        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
+        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)
 
     def register_graph_buffers(fa, handles, offsets):
-        torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
+        torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
sgl-kernel/python/sgl_kernel/attention.py

@@ -2,6 +2,6 @@ import torch
 
 
 def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
-    torch.ops.sgl_kernel.lightning_attention_decode(
+    torch.ops.sgl_kernel.lightning_attention_decode.default(
         q, k, v, past_kv, slope, output, new_kv
     )
sgl-kernel/python/sgl_kernel/elementwise.py

@@ -14,14 +14,14 @@ def rmsnorm(
 ) -> torch.Tensor:
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.rmsnorm(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream())
     return out
 
 
 def fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
-    torch.ops.sgl_kernel.fused_add_rmsnorm(input, residual, weight, eps)
+    torch.ops.sgl_kernel.fused_add_rmsnorm.default(input, residual, weight, eps)
 
 
 def gemma_rmsnorm(

@@ -32,14 +32,16 @@ def gemma_rmsnorm(
 ) -> torch.Tensor:
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.gemma_rmsnorm.default(
+        out, input, weight, eps, get_cuda_stream()
+    )
     return out
 
 
 def gemma_fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
-    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm(
+    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
         input, residual, weight, eps, get_cuda_stream()
     )

@@ -65,7 +67,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             device=input.device,
             dtype=input.dtype,
         )
-    torch.ops.sgl_kernel.silu_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.silu_and_mul.default(out, input, get_cuda_stream())
     return out

@@ -80,7 +82,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             device=input.device,
             dtype=input.dtype,
         )
-    torch.ops.sgl_kernel.gelu_tanh_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, get_cuda_stream())
     return out

@@ -95,7 +97,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             device=input.device,
             dtype=input.dtype,
         )
-    torch.ops.sgl_kernel.gelu_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_and_mul.default(out, input, get_cuda_stream())
     return out

@@ -139,7 +141,7 @@ def apply_rope_with_cos_sin_cache_inplace(
     if cos_sin_cache.dtype != torch.float32:
         raise ValueError("cos_sin_cache should be float32")
 
-    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
+    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
         q=query.view(query.shape[0], -1, head_size),
         k=key.view(key.shape[0], -1, head_size),
         q_rope=query.view(query.shape[0], -1, head_size),
sgl-kernel/python/sgl_kernel/gemm.py

@@ -7,11 +7,11 @@ from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
 def awq_dequantize(
     qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
 ) -> torch.ByteTensor:
-    return torch.ops.sgl_kernel.awq_dequantize(qweight, scales, qzeros)
+    return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)
 
 
 def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
-    return torch.ops.sgl_kernel.int8_scaled_mm(
+    return torch.ops.sgl_kernel.int8_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,

@@ -22,7 +22,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
 def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
-    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm(
+    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,

@@ -32,7 +32,7 @@ def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
 def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
-    return torch.ops.sgl_kernel.fp8_scaled_mm(
+    return torch.ops.sgl_kernel.fp8_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,

@@ -51,7 +51,7 @@ def _bmm_fp8_internal(
     B_scale: torch.Tensor,
 ) -> None:
     cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.bmm_fp8(
+    torch.ops.sgl_kernel.bmm_fp8.default(
         A,
         B,
         D,

@@ -91,7 +91,7 @@ def sgl_per_token_group_quant_fp8(
     fp8_min: float,
     fp8_max: float,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8(
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
         input, output_q, output_s, group_size, eps, fp8_min, fp8_max
     )

@@ -105,7 +105,7 @@ def sgl_per_token_group_quant_int8(
     int8_min: float,
     int8_max: float,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8(
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
         input, output_q, output_s, group_size, eps, int8_min, int8_max
     )

@@ -116,7 +116,9 @@ def sgl_per_tensor_quant_fp8(
     output_s: torch.Tensor,
     is_static: bool,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)
+    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
+        input, output_q, output_s, is_static
+    )
 
 
 def cublas_grouped_gemm(

@@ -129,7 +131,7 @@ def cublas_grouped_gemm(
         len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
     ), "Inputs/weights/outputs should not be empty!"
     cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.cublas_grouped_gemm(
+    torch.ops.sgl_kernel.cublas_grouped_gemm.default(
         inputs,
         weights,
         outputs,

@@ -144,7 +146,7 @@ def sgl_per_token_quant_fp8(
     output_q: torch.Tensor,
     output_s: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s)
+    torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)
 
 
 def cutlass_scaled_fp4_mm(

@@ -158,7 +160,7 @@ def cutlass_scaled_fp4_mm(
     assert a.ndim == 2 and b.ndim == 2
     m, n = a.shape[0], b.shape[0]
     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-    torch.ops.sgl_kernels.cutlass_scaled_fp4_mm(
+    torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default(
         out, a, b, block_scale_a, block_scale_b, alpha
     )
     return out

@@ -210,7 +212,7 @@ def scaled_fp4_quant(
         (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
     )
-    torch.ops.sgl_kernels.scaled_fp4_quant(
+    torch.ops.sgl_kernel.scaled_fp4_quant.default(
         output, input, output_scale, input_global_scale
     )
     output_scale = output_scale.view(torch.float8_e4m3fn)
sgl-kernel/python/sgl_kernel/moe.py

@@ -11,7 +11,7 @@ def moe_align_block_size(
     token_cnts_buffer,
     cumsum_buffer,
 ):
-    torch.ops.sgl_kernel.moe_align_block_size(
+    torch.ops.sgl_kernel.moe_align_block_size.default(
         topk_ids,
         num_experts,
         block_size,

@@ -29,6 +29,6 @@ def topk_softmax(
     token_expert_indices: torch.Tensor,
     gating_output: float,
 ) -> None:
-    torch.ops.sgl_kernel.topk_softmax(
+    torch.ops.sgl_kernel.topk_softmax.default(
         topk_weights, topk_ids, token_expert_indices, gating_output
     )
sgl-kernel/python/sgl_kernel/sampling.py

@@ -12,7 +12,7 @@ def _top_k_renorm_probs_internal(
     probs = probs.float()
     maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
     renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernel.top_k_renorm_probs(
+    torch.ops.sgl_kernel.top_k_renorm_probs.default(
         probs,
         renorm_probs,
         maybe_top_k_arr,

@@ -40,7 +40,7 @@ def _top_p_renorm_probs_internal(
     probs = probs.float()
     maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
     renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernel.top_p_renorm_probs(
+    torch.ops.sgl_kernel.top_p_renorm_probs.default(
         probs,
         renorm_probs,
         maybe_top_p_arr,

@@ -75,7 +75,7 @@ def _top_p_sampling_from_probs_internal(
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
     success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
-    torch.ops.sgl_kernel.top_p_sampling_from_probs(
+    torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
         probs,
         uniform_samples,
         samples,

@@ -121,7 +121,7 @@ def _top_k_top_p_sampling_from_probs_internal(
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
     success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
-    torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs(
+    torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
         probs,
         uniform_samples,
         samples,

@@ -179,7 +179,7 @@ def _min_p_sampling_from_probs_internal(
         maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
-    torch.ops.sgl_kernel.min_p_sampling_from_probs(
+    torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
         probs,
         uniform_samples,
         samples,
sgl-kernel/python/sgl_kernel/speculative.py

@@ -17,7 +17,7 @@ def tree_speculative_sampling_target_only(
     threshold_acc: float = 1.0,
     deterministic: bool = True,
 ) -> None:
-    torch.ops.sgl_kernel.tree_speculative_sampling_target_only(
+    torch.ops.sgl_kernel.tree_speculative_sampling_target_only.default(
         predicts,
         accept_index,
         accept_token_num,

@@ -45,7 +45,7 @@ def verify_tree_greedy(
     retrive_next_sibling: torch.Tensor,
     target_predict: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernel.verify_tree_greedy(
+    torch.ops.sgl_kernel.verify_tree_greedy.default(
         predicts,
         accept_index,
         accept_token_num,

@@ -71,7 +71,7 @@ def build_tree_kernel_efficient(
     depth: int,
     draft_token_num: int,
 ) -> None:
-    torch.ops.sgl_kernel.build_tree_kernel_efficient(
+    torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
         parent_list,
         selected_index,
         verified_seq_len,

@@ -92,7 +92,7 @@ def segment_packbits(
     output_indptr: torch.Tensor,
     y: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernel.segment_packbits(
+    torch.ops.sgl_kernel.segment_packbits.default(
         x,
         input_indptr,
         output_indptr,
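
For reference, the packet-versus-overload distinction used throughout this commit can be inspected on any registered op. The sketch below uses the built-in aten namespace purely as an illustration; it is not part of this change.

    import torch

    packet = torch.ops.aten.sin            # OpOverloadPacket: a bundle of overloads
    overload = torch.ops.aten.sin.default  # OpOverload: one concrete schema

    print(type(packet).__name__)    # OpOverloadPacket
    print(type(overload).__name__)  # OpOverload
    print(packet.overloads())       # e.g. ['default', 'out']

    x = torch.randn(3)
    # Both call forms dispatch to the same kernel and return identical results.
    assert torch.equal(packet(x), overload(x))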