sglang · Commit 31dfff7d (Unverified)
Authored Mar 27, 2025 by Yineng Zhang; committed by GitHub on Mar 27, 2025
use default for torch.ops (#4835)
Parent: 10a9ab7b
Showing 7 changed files with 51 additions and 47 deletions:
    sgl-kernel/python/sgl_kernel/allreduce.py     +15  -15
    sgl-kernel/python/sgl_kernel/attention.py      +1   -1
    sgl-kernel/python/sgl_kernel/elementwise.py   +10   -8
    sgl-kernel/python/sgl_kernel/gemm.py          +14  -12
    sgl-kernel/python/sgl_kernel/moe.py            +2   -2
    sgl-kernel/python/sgl_kernel/sampling.py       +5   -5
    sgl-kernel/python/sgl_kernel/speculative.py    +4   -4
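Every change in this commit is the same mechanical substitution: the Python wrappers stop calling the op's OpOverloadPacket (torch.ops.sgl_kernel.<op>(...)) and instead call its concrete default overload (torch.ops.sgl_kernel.<op>.default(...)). The sketch below is illustrative only and uses a hypothetical toy op, demo::add_one, rather than the real compiled sgl_kernel extension; it shows that .default resolves to a specific OpOverload of the same registered operator, so both call forms return the same result while the .default form bypasses the packet's overload resolution.

    import torch

    # Hypothetical op used only for illustration; the real ops in this commit
    # are registered by the compiled sgl_kernel extension.
    lib = torch.library.Library("demo", "DEF")
    lib.define("add_one(Tensor x) -> Tensor")
    lib.impl("add_one", lambda x: x + 1, "CompositeExplicitAutograd")

    x = torch.zeros(3)
    packet = torch.ops.demo.add_one            # OpOverloadPacket (old call style)
    overload = torch.ops.demo.add_one.default  # concrete OpOverload (new call style)

    assert torch.equal(packet(x), overload(x))
    print(type(packet).__name__, type(overload).__name__)
    # -> OpOverloadPacket OpOverload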
sgl-kernel/python/sgl_kernel/allreduce.py

@@ -12,49 +12,49 @@ if torch.version.hip is not None:
         rank: int,
         full_nvlink: bool,
     ) -> int:
-        return torch.ops.sgl_kernel.init_custom_ar(
+        return torch.ops.sgl_kernel.init_custom_ar.default(
             meta, rank_data, handles, offsets, rank, full_nvlink
         )

     def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-        torch.ops.sgl_kernel.all_reduce_reg(fa, inp, out)
+        torch.ops.sgl_kernel.all_reduce_reg.default(fa, inp, out)

     def all_reduce_unreg(
         fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
     ) -> None:
-        torch.ops.sgl_kernel.all_reduce_unreg(fa, inp, reg_buffer, out)
+        torch.ops.sgl_kernel.all_reduce_unreg.default(fa, inp, reg_buffer, out)

     def dispose(fa: int) -> None:
-        torch.ops.sgl_kernel.dispose(fa)
+        torch.ops.sgl_kernel.dispose.default(fa)

     def meta_size() -> int:
-        return torch.ops.sgl_kernel.meta_size()
+        return torch.ops.sgl_kernel.meta_size.default()

     def register_buffer(
         fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
     ) -> None:
-        return torch.ops.sgl_kernel.register_buffer(fa, t, handles, offsets)
+        return torch.ops.sgl_kernel.register_buffer.default(fa, t, handles, offsets)

     def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
-        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
+        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)

     def register_graph_buffers(
         fa: int, handles: List[str], offsets: List[List[int]]
     ) -> None:
-        torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
+        torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)

     def allocate_meta_buffer(size: int) -> torch.Tensor:
-        return torch.ops.sgl_kernel.allocate_meta_buffer(size)
+        return torch.ops.sgl_kernel.allocate_meta_buffer.default(size)

     def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
-        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle(inp)
+        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)

 else:
     # TRTLLM custom allreduce
     def init_custom_reduce(
         rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
     ):
-        return torch.ops.sgl_kernel.init_custom_ar(
+        return torch.ops.sgl_kernel.init_custom_ar.default(
             rank_id,
             num_devices,
             rank_data,

@@ -65,13 +65,13 @@ else:
         )

     def custom_dispose(fa):
-        torch.ops.sgl_kernel.dispose(fa)
+        torch.ops.sgl_kernel.dispose.default(fa)

     def custom_reduce(fa, inp, out):
-        torch.ops.sgl_kernel.all_reduce(fa, inp, out)
+        torch.ops.sgl_kernel.all_reduce.default(fa, inp, out)

     def get_graph_buffer_ipc_meta(fa):
-        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
+        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)

     def register_graph_buffers(fa, handles, offsets):
-        torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
+        torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
sgl-kernel/python/sgl_kernel/attention.py

@@ -2,6 +2,6 @@ import torch


 def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
-    torch.ops.sgl_kernel.lightning_attention_decode(
+    torch.ops.sgl_kernel.lightning_attention_decode.default(
         q, k, v, past_kv, slope, output, new_kv
     )
sgl-kernel/python/sgl_kernel/elementwise.py

@@ -14,14 +14,14 @@ def rmsnorm(
 ) -> torch.Tensor:
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.rmsnorm(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream())
     return out


 def fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
-    torch.ops.sgl_kernel.fused_add_rmsnorm(input, residual, weight, eps)
+    torch.ops.sgl_kernel.fused_add_rmsnorm.default(input, residual, weight, eps)


 def gemma_rmsnorm(

@@ -32,14 +32,16 @@ def gemma_rmsnorm(
 ) -> torch.Tensor:
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.gemma_rmsnorm.default(
+        out, input, weight, eps, get_cuda_stream()
+    )
     return out


 def gemma_fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
-    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm(
+    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
         input, residual, weight, eps, get_cuda_stream()
     )

@@ -65,7 +67,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
         device=input.device,
         dtype=input.dtype,
     )
-    torch.ops.sgl_kernel.silu_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.silu_and_mul.default(out, input, get_cuda_stream())
     return out

@@ -80,7 +82,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
         device=input.device,
         dtype=input.dtype,
     )
-    torch.ops.sgl_kernel.gelu_tanh_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, get_cuda_stream())
     return out

@@ -95,7 +97,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
         device=input.device,
         dtype=input.dtype,
     )
-    torch.ops.sgl_kernel.gelu_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_and_mul.default(out, input, get_cuda_stream())
     return out

@@ -139,7 +141,7 @@ def apply_rope_with_cos_sin_cache_inplace(
     if cos_sin_cache.dtype != torch.float32:
         raise ValueError("cos_sin_cache should be float32")

-    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
+    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
         q=query.view(query.shape[0], -1, head_size),
         k=key.view(key.shape[0], -1, head_size),
         q_rope=query.view(query.shape[0], -1, head_size),
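The elementwise wrappers keep their external behavior: when out is omitted they allocate it with torch.empty_like and return it, and only the underlying op call gains .default. A hedged usage sketch, assuming a CUDA build of sgl_kernel is importable and that the wrapper signature is rmsnorm(input, weight, eps=1e-6, out=None) (the parameter order is an assumption; this hunk only shows the body):

    import torch
    from sgl_kernel.elementwise import rmsnorm  # assumes the compiled extension is installed

    x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
    w = torch.ones(4096, device="cuda", dtype=torch.float16)

    # out is omitted, so the wrapper allocates it via torch.empty_like(x) and
    # forwards the call to torch.ops.sgl_kernel.rmsnorm.default(...).
    y = rmsnorm(x, w, eps=1e-6)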
sgl-kernel/python/sgl_kernel/gemm.py

@@ -7,11 +7,11 @@ from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
 def awq_dequantize(
     qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
 ) -> torch.ByteTensor:
-    return torch.ops.sgl_kernel.awq_dequantize(qweight, scales, qzeros)
+    return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)


 def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
-    return torch.ops.sgl_kernel.int8_scaled_mm(
+    return torch.ops.sgl_kernel.int8_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,

@@ -22,7 +22,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
 def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
-    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm(
+    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,

@@ -32,7 +32,7 @@ def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
 def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
-    return torch.ops.sgl_kernel.fp8_scaled_mm(
+    return torch.ops.sgl_kernel.fp8_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,

@@ -51,7 +51,7 @@ def _bmm_fp8_internal(
     B_scale: torch.Tensor,
 ) -> None:
     cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.bmm_fp8(
+    torch.ops.sgl_kernel.bmm_fp8.default(
         A,
         B,
         D,

@@ -91,7 +91,7 @@ def sgl_per_token_group_quant_fp8(
     fp8_min: float,
     fp8_max: float,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8(
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
         input, output_q, output_s, group_size, eps, fp8_min, fp8_max
     )

@@ -105,7 +105,7 @@ def sgl_per_token_group_quant_int8(
     int8_min: float,
     int8_max: float,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8(
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
         input, output_q, output_s, group_size, eps, int8_min, int8_max
     )

@@ -116,7 +116,9 @@ def sgl_per_tensor_quant_fp8(
     output_s: torch.Tensor,
     is_static: bool,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)
+    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
+        input, output_q, output_s, is_static
+    )


 def cublas_grouped_gemm(

@@ -129,7 +131,7 @@ def cublas_grouped_gemm(
         len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
     ), "Inputs/weights/outputs should not be empty!"
     cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.cublas_grouped_gemm(
+    torch.ops.sgl_kernel.cublas_grouped_gemm.default(
         inputs,
         weights,
         outputs,

@@ -144,7 +146,7 @@ def sgl_per_token_quant_fp8(
     output_q: torch.Tensor,
     output_s: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s)
+    torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)


 def cutlass_scaled_fp4_mm(

@@ -158,7 +160,7 @@ def cutlass_scaled_fp4_mm(
     assert a.ndim == 2 and b.ndim == 2
     m, n = a.shape[0], b.shape[0]
     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-    torch.ops.sgl_kernels.cutlass_scaled_fp4_mm(
+    torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default(
         out, a, b, block_scale_a, block_scale_b, alpha
     )
     return out

@@ -210,7 +212,7 @@ def scaled_fp4_quant(
         (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
     )
-    torch.ops.sgl_kernels.scaled_fp4_quant(
+    torch.ops.sgl_kernel.scaled_fp4_quant.default(
         output, input, output_scale, input_global_scale
     )
     output_scale = output_scale.view(torch.float8_e4m3fn)
sgl-kernel/python/sgl_kernel/moe.py

@@ -11,7 +11,7 @@ def moe_align_block_size(
     token_cnts_buffer,
     cumsum_buffer,
 ):
-    torch.ops.sgl_kernel.moe_align_block_size(
+    torch.ops.sgl_kernel.moe_align_block_size.default(
         topk_ids,
         num_experts,
         block_size,

@@ -29,6 +29,6 @@ def topk_softmax(
     token_expert_indices: torch.Tensor,
     gating_output: float,
 ) -> None:
-    torch.ops.sgl_kernel.topk_softmax(
+    torch.ops.sgl_kernel.topk_softmax.default(
         topk_weights, topk_ids, token_expert_indices, gating_output
     )
sgl-kernel/python/sgl_kernel/sampling.py

@@ -12,7 +12,7 @@ def _top_k_renorm_probs_internal(
     probs = probs.float()
     maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
     renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernel.top_k_renorm_probs(
+    torch.ops.sgl_kernel.top_k_renorm_probs.default(
         probs,
         renorm_probs,
         maybe_top_k_arr,

@@ -40,7 +40,7 @@ def _top_p_renorm_probs_internal(
     probs = probs.float()
     maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
     renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernel.top_p_renorm_probs(
+    torch.ops.sgl_kernel.top_p_renorm_probs.default(
         probs,
         renorm_probs,
         maybe_top_p_arr,

@@ -75,7 +75,7 @@ def _top_p_sampling_from_probs_internal(
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
     success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
-    torch.ops.sgl_kernel.top_p_sampling_from_probs(
+    torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
         probs,
         uniform_samples,
         samples,

@@ -121,7 +121,7 @@ def _top_k_top_p_sampling_from_probs_internal(
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
     success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
-    torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs(
+    torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
         probs,
         uniform_samples,
         samples,

@@ -179,7 +179,7 @@ def _min_p_sampling_from_probs_internal(
         maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
-    torch.ops.sgl_kernel.min_p_sampling_from_probs(
+    torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
         probs,
         uniform_samples,
         samples,
sgl-kernel/python/sgl_kernel/speculative.py

@@ -17,7 +17,7 @@ def tree_speculative_sampling_target_only(
     threshold_acc: float = 1.0,
     deterministic: bool = True,
 ) -> None:
-    torch.ops.sgl_kernel.tree_speculative_sampling_target_only(
+    torch.ops.sgl_kernel.tree_speculative_sampling_target_only.default(
         predicts,
         accept_index,
         accept_token_num,

@@ -45,7 +45,7 @@ def verify_tree_greedy(
     retrive_next_sibling: torch.Tensor,
     target_predict: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernel.verify_tree_greedy(
+    torch.ops.sgl_kernel.verify_tree_greedy.default(
         predicts,
         accept_index,
         accept_token_num,

@@ -71,7 +71,7 @@ def build_tree_kernel_efficient(
     depth: int,
     draft_token_num: int,
 ) -> None:
-    torch.ops.sgl_kernel.build_tree_kernel_efficient(
+    torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
         parent_list,
         selected_index,
         verified_seq_len,

@@ -92,7 +92,7 @@ def segment_packbits(
     output_indptr: torch.Tensor,
     y: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernel.segment_packbits(
+    torch.ops.sgl_kernel.segment_packbits.default(
         x,
         input_indptr,
         output_indptr,
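Because every wrapper now names an explicit overload, a quick hedged sanity check (not part of this commit) is to confirm that the loaded extension actually exposes a default overload for a given op before calling it:

    import torch
    import sgl_kernel  # assumes the compiled extension is importable; importing it registers the ops

    def has_default_overload(op_name: str) -> bool:
        # torch.ops.sgl_kernel.<op> is an OpOverloadPacket; overloads() lists the
        # registered overload names ("default" for an op defined without named overloads).
        packet = getattr(torch.ops.sgl_kernel, op_name)
        return "default" in packet.overloads()

    print(has_default_overload("rmsnorm"))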