sglang · Commit 110e0066 (Unverified)
Authored Mar 03, 2025 by Lianmin Zheng; committed by GitHub on Mar 03, 2025.

Reorganize python source files in sgl-kernel with multiple files (#4027)
parent 6b45a21d

Showing 11 changed files with 705 additions and 781 deletions:
sgl-kernel/src/sgl-kernel/__init__.py         +33   −101
sgl-kernel/src/sgl-kernel/ops/__init__.py     +0    −677
sgl-kernel/src/sgl-kernel/ops/activation.py   +153  −0
sgl-kernel/src/sgl-kernel/ops/allreduce.py    +78   −0
sgl-kernel/src/sgl-kernel/ops/attention.py    +8    −0
sgl-kernel/src/sgl-kernel/ops/gemm.py         +111  −0
sgl-kernel/src/sgl-kernel/ops/moe.py          +24   −0
sgl-kernel/src/sgl-kernel/ops/sampling.py     +211  −0
sgl-kernel/src/sgl-kernel/ops/speculative.py  +84   −0
sgl-kernel/src/sgl-kernel/ops/utils.py        +2    −2
sgl-kernel/tests/test_trt_allreduce.py        +1    −1
sgl-kernel/src/sgl-kernel/__init__.py  (view file @ 110e0066)

@@ -9,105 +9,37 @@ if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
         mode=ctypes.RTLD_GLOBAL,
     )
+from sgl_kernel.ops.activation import (
+    apply_rope_with_cos_sin_cache_inplace,
+    fused_add_rmsnorm,
+    gelu_and_mul,
+    gelu_tanh_and_mul,
+    gemma_fused_add_rmsnorm,
+    gemma_rmsnorm,
+    rmsnorm,
+    silu_and_mul,
+)
+from sgl_kernel.ops.allreduce import *
+from sgl_kernel.ops.attention import lightning_attention_decode
+from sgl_kernel.ops.gemm import (
+    bmm_fp8,
+    cublas_grouped_gemm,
+    fp8_blockwise_scaled_mm,
+    fp8_scaled_mm,
+    int8_scaled_mm,
+    sgl_per_token_group_quant_fp8,
+)
+from sgl_kernel.ops.moe import moe_align_block_size
+from sgl_kernel.ops.sampling import (
+    min_p_sampling_from_probs,
+    top_k_renorm_prob,
+    top_k_top_p_sampling_from_probs,
+    top_p_renorm_prob,
+    top_p_sampling_from_probs,
+)
+from sgl_kernel.ops.speculative import (
+    build_tree_kernel,
+    build_tree_kernel_efficient,
+    tree_speculative_sampling_target_only,
+)
+from sgl_kernel.version import __version__
-if torch.version.cuda:
-    from sgl_kernel.ops import (
-        apply_rope_with_cos_sin_cache_inplace,
-        bmm_fp8,
-        build_tree_kernel,
-        build_tree_kernel_efficient,
-        cublas_grouped_gemm,
-        custom_dispose,
-        custom_reduce,
-        fp8_blockwise_scaled_mm,
-        fp8_scaled_mm,
-        fused_add_rmsnorm,
-        gelu_and_mul,
-        gelu_tanh_and_mul,
-        gemma_fused_add_rmsnorm,
-        gemma_rmsnorm,
-        get_graph_buffer_ipc_meta,
-        init_custom_reduce,
-        int8_scaled_mm,
-        lightning_attention_decode,
-        min_p_sampling_from_probs,
-        moe_align_block_size,
-        register_graph_buffers,
-        rmsnorm,
-        sampling_scaling_penalties,
-        sgl_per_token_group_quant_fp8,
-        silu_and_mul,
-        top_k_renorm_prob,
-        top_k_top_p_sampling_from_probs,
-        top_p_renorm_prob,
-        tree_speculative_sampling_target_only,
-    )
-else:
-    assert torch.version.hip
-    from sgl_kernel.ops import (
-        all_reduce_reg,
-        all_reduce_unreg,
-        allocate_meta_buffer,
-        apply_rope_with_cos_sin_cache_inplace,
-        bmm_fp8,
-        dispose,
-        fp8_scaled_mm,
-        fused_add_rmsnorm,
-        gelu_and_mul,
-        gelu_tanh_and_mul,
-        gemma_fused_add_rmsnorm,
-        gemma_rmsnorm,
-        get_graph_buffer_ipc_meta,
-        get_meta_buffer_ipc_handle,
-        init_custom_ar,
-        int8_scaled_mm,
-        lightning_attention_decode,
-        meta_size,
-        min_p_sampling_from_probs,
-        moe_align_block_size,
-        register_buffer,
-        register_graph_buffers,
-        rmsnorm,
-        sampling_scaling_penalties,
-        silu_and_mul,
-        top_k_renorm_prob,
-        top_k_top_p_sampling_from_probs,
-        top_p_renorm_prob,
-    )
-
-__all__ = [
-    "__version__",
-    "apply_rope_with_cos_sin_cache_inplace",
-    "bmm_fp8",
-    "cublas_grouped_gemm",
-    "custom_dispose",
-    "custom_reduce",
-    "build_tree_kernel_efficient",
-    "build_tree_kernel",
-    "fp8_blockwise_scaled_mm",
-    "fp8_scaled_mm",
-    "fused_add_rmsnorm",
-    "gelu_and_mul",
-    "gelu_tanh_and_mul",
-    "gemma_fused_add_rmsnorm",
-    "gemma_rmsnorm",
-    "get_graph_buffer_ipc_meta",
-    "init_custom_reduce",
-    "int8_scaled_mm",
-    "lightning_attention_decode",
-    "min_p_sampling_from_probs",
-    "moe_align_block_size",
-    "register_graph_buffers",
-    "rmsnorm",
-    "sampling_scaling_penalties",
-    "sgl_per_token_group_quant_fp8",
-    "silu_and_mul",
-    "top_k_renorm_prob",
-    "top_k_top_p_sampling_from_probs",
-    "top_p_renorm_prob",
-    "tree_speculative_sampling_target_only",
-]
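Note the effect of the hunk above: the top-level package now re-exports every kernel from domain-specific submodules instead of branching on the CUDA/HIP build inside one monolithic ops/__init__.py. A minimal sketch of what this means for callers, assuming a CUDA build of sgl-kernel:

import sgl_kernel.ops.activation as activation
from sgl_kernel import rmsnorm  # re-exported by sgl_kernel/__init__.py

# Both spellings resolve to the same function object after this commit.
assert rmsnorm is activation.rmsnorm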
sgl-kernel/src/sgl-kernel/ops/__init__.py  (deleted, 100644 → 0; view file @ 6b45a21d)

This diff is collapsed.
sgl-kernel/src/sgl-kernel/ops/activation.py  (new file, 0 → 100644; view file @ 110e0066)

from typing import Optional

import sgl_kernel.ops._kernels
import torch
from sgl_kernel.ops.utils import get_cuda_stream

# These implementations extensively draw from and build upon the FlashInfer project
# https://github.com/flashinfer-ai/flashinfer
# Kudos to @yzh119


def rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if out is None:
        out = torch.empty_like(input)
    torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, get_cuda_stream())
    return out


def fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps)


def gemma_rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if out is None:
        out = torch.empty_like(input)
    torch.ops.sgl_kernels.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
    return out


def gemma_fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    torch.ops.sgl_kernels.gemma_fused_add_rmsnorm(
        input, residual, weight, eps, get_cuda_stream()
    )


def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
    assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
    assert (
        input.shape[:-1] == output.shape[:-1]
    ), f"{input.shape[:-1]} != {output.shape[:-1]}"
    assert (
        input.shape[-1] == 2 * output.shape[-1]
    ), f"{input.shape[-1]} != {2 * output.shape[-1]}"


def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    torch.ops.sgl_kernels.silu_and_mul(out, input, get_cuda_stream())
    return out


def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, get_cuda_stream())
    return out


def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    torch.ops.sgl_kernels.gelu_and_mul(out, input, get_cuda_stream())
    return out


def apply_rope_with_cos_sin_cache_inplace(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
) -> None:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.
    This is designed to be compatible with the SGL/vLLM implementation.
    The result is applied in place to the input tensors.

    Parameters
    ----------
    positions : torch.Tensor
        Position indices, shape: ``(nnz)``.
    query : torch.Tensor
        Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
    key : torch.Tensor
        Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
    cos_sin_cache : torch.Tensor
        Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
        Cosine is the first half and Sine is the second half on rotary_dim.
    is_neox : bool
        Whether to use Neox style RoPE, default: ``True``.

        * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.
        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

    Note
    ----
    The rotary dimension is determined by the cosine cache and sine cache.
    """
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")

    positions = positions.int()
    torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
        q=query.view(query.shape[0], -1, head_size),
        k=key.view(key.shape[0], -1, head_size),
        q_rope=query.view(query.shape[0], -1, head_size),
        k_rope=key.view(key.shape[0], -1, head_size),
        cos_sin_cache=cos_sin_cache,
        pos_ids=positions,
        interleave=(not is_neox),
        cuda_stream=get_cuda_stream(),
    )
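A quick usage sketch for the in-place RoPE wrapper above; the sizes are invented for illustration and a CUDA device is assumed:

import torch
from sgl_kernel.ops.activation import apply_rope_with_cos_sin_cache_inplace

nnz, num_heads, head_size, max_seq_len = 4, 8, 64, 256  # illustrative sizes
rotary_dim = head_size

positions = torch.arange(nnz, device="cuda", dtype=torch.int32)
query = torch.randn(nnz, num_heads * head_size, device="cuda", dtype=torch.float16)
key = torch.randn(nnz, num_heads * head_size, device="cuda", dtype=torch.float16)

# First half cosines, second half sines along rotary_dim, per the docstring.
t = torch.arange(max_seq_len, device="cuda", dtype=torch.float32)
inv_freq = 1.0 / (
    10000.0 ** (torch.arange(0, rotary_dim, 2, device="cuda", dtype=torch.float32) / rotary_dim)
)
freqs = torch.outer(t, inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)  # (max_seq_len, rotary_dim)

apply_rope_with_cos_sin_cache_inplace(positions, query, key, head_size, cos_sin_cache)
# query and key have been rotated in place; nothing is returned.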
sgl-kernel/src/sgl-kernel/ops/allreduce.py  (new file, 0 → 100644; view file @ 110e0066)

from typing import List, Tuple

import sgl_kernel.ops._kernels
import torch

if torch.version.hip is not None:
    # ROCM custom allreduce
    def init_custom_ar(
        meta: torch.Tensor,
        rank_data: torch.Tensor,
        handles: List[str],
        offsets: List[int],
        rank: int,
        full_nvlink: bool,
    ) -> int:
        return torch.ops.sgl_kernels.init_custom_ar(
            meta, rank_data, handles, offsets, rank, full_nvlink
        )

    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
        torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out)

    def all_reduce_unreg(
        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
    ) -> None:
        torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out)

    def dispose(fa: int) -> None:
        torch.ops.sgl_kernels.dispose(fa)

    def meta_size() -> int:
        return torch.ops.sgl_kernels.meta_size()

    def register_buffer(
        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
    ) -> None:
        return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(
        fa: int, handles: List[str], offsets: List[List[int]]
    ) -> None:
        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)

    def allocate_meta_buffer(size: int) -> torch.Tensor:
        return torch.ops.sgl_kernels.allocate_meta_buffer(size)

    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
        return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp)

else:
    # TRTLLM custom allreduce
    def init_custom_reduce(
        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
    ):
        return torch.ops.sgl_kernels.init_custom_ar(
            rank_id,
            num_devices,
            rank_data,
            buffers,
            tmp_buffers,
            barrier_in,
            barrier_out,
        )

    def custom_dispose(fa):
        torch.ops.sgl_kernels.dispose(fa)

    def custom_reduce(fa, inp, out):
        torch.ops.sgl_kernels.all_reduce(fa, inp, out)

    def get_graph_buffer_ipc_meta(fa):
        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(fa, handles, offsets):
        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
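The module selects its wrapper set once at import time based on the torch build; a small probe illustrating the dispatch (illustrative only, requires a built sgl-kernel extension):

import torch
import sgl_kernel.ops.allreduce as allreduce_ops

# HIP builds bind the ROCm wrappers (init_custom_ar, all_reduce_reg, ...);
# other builds bind the TRT-LLM-style wrappers (init_custom_reduce, custom_reduce, ...).
if torch.version.hip is not None:
    assert hasattr(allreduce_ops, "all_reduce_reg")
else:
    assert hasattr(allreduce_ops, "custom_reduce")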
sgl-kernel/src/sgl-kernel/ops/attention.py  (new file, 0 → 100644; view file @ 110e0066)

import sgl_kernel.ops._kernels
import torch


def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
    torch.ops.sgl_kernels.lightning_attention_decode(
        q, k, v, past_kv, slope, output, new_kv
    )
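The wrapper passes everything straight through, so the tensor contract lives in the C++ kernel. A sketch with assumed shapes (b=batch, h=heads, d=qk head dim, e=v head dim); these follow the usual linear-attention decode layout and are not documented by this file:

import torch
from sgl_kernel.ops.attention import lightning_attention_decode

b, h, d, e = 2, 8, 64, 64  # assumed sizes, illustration only
q = torch.randn(b, h, 1, d, device="cuda", dtype=torch.bfloat16)
k = torch.randn(b, h, 1, d, device="cuda", dtype=torch.bfloat16)
v = torch.randn(b, h, 1, e, device="cuda", dtype=torch.bfloat16)
past_kv = torch.randn(b, h, d, e, device="cuda", dtype=torch.float32)  # running KV state
slope = torch.rand(h, 1, 1, device="cuda", dtype=torch.float32)        # per-head decay
output = torch.empty(b, h, 1, e, device="cuda", dtype=torch.bfloat16)
new_kv = torch.empty_like(past_kv)

lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
# output and new_kv are written in place.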
sgl-kernel/src/sgl-kernel/ops/gemm.py  (new file, 0 → 100644; view file @ 110e0066)

from typing import List, Optional

import sgl_kernel.ops._kernels
import torch
from sgl_kernel.ops.utils import _get_cache_buf, get_cuda_stream


def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    return torch.ops.sgl_kernels.int8_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )


def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
    return torch.ops.sgl_kernels.fp8_blockwise_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
    )


def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    return torch.ops.sgl_kernels.fp8_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )


def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    cublas_handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernels.bmm_fp8(
        A,
        B,
        D,
        A_scale,
        B_scale,
        workspace_buffer,
        cublas_handle,
        get_cuda_stream(),
    )


def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if out is None:
        out = torch.empty(
            (A.shape[0], A.shape[1], B.shape[2]),
            device=A.device,
            dtype=dtype,
        )
    workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
    _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
    return out


def sgl_per_token_group_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
) -> None:
    torch.ops.sgl_kernels.sgl_per_token_group_quant_fp8(
        input, output_q, output_s, group_size, eps, fp8_min, fp8_max
    )


def cublas_grouped_gemm(
    inputs: List[torch.Tensor],
    weights: List[torch.Tensor],
    outputs: List[torch.Tensor],
    out_dtype: torch.dtype,
) -> None:
    assert (
        len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
    ), "Inputs/weights/outputs should not be empty!"
    cublas_handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernels.cublas_grouped_gemm(
        inputs,
        weights,
        outputs,
        out_dtype,
        cublas_handle,
        get_cuda_stream(),
    )
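For orientation, a hedged sketch of bmm_fp8: the output shape is (A.shape[0], A.shape[1], B.shape[2]) as the wrapper shows, but the FP8 dtypes, per-tensor scales, and any layout requirements of the underlying cuBLAS kernel are assumptions here:

import torch
from sgl_kernel.ops.gemm import bmm_fp8

# Batched C[b] = A[b] @ B[b] with assumed e4m3 inputs and per-tensor scales.
A = torch.randn(4, 16, 32, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(4, 32, 64, device="cuda").to(torch.float8_e4m3fn)
A_scale = torch.tensor(1.0, device="cuda")  # assumed per-tensor scale
B_scale = torch.tensor(1.0, device="cuda")

C = bmm_fp8(A, B, A_scale, B_scale, dtype=torch.bfloat16)  # -> (4, 16, 64)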
sgl-kernel/src/sgl-kernel/ops/moe.py  (new file, 0 → 100644; view file @ 110e0066)

import sgl_kernel.ops._kernels
import torch


def moe_align_block_size(
    topk_ids,
    num_experts,
    block_size,
    sorted_token_ids,
    experts_ids,
    num_tokens_post_pad,
    token_cnts_buffer,
    cumsum_buffer,
):
    torch.ops.sgl_kernels.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
        token_cnts_buffer,
        cumsum_buffer,
    )
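The wrapper takes preallocated output buffers; the sizing below mirrors the common vLLM convention for moe_align_block_size and should be treated as an assumption rather than a contract of this file:

import torch
from sgl_kernel.ops.moe import moe_align_block_size

num_experts, block_size = 8, 16
topk_ids = torch.randint(0, num_experts, (32, 2), dtype=torch.int32, device="cuda")

# Worst case: every expert's token count is padded up to a block boundary.
max_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_token_ids = torch.empty(max_tokens_padded, dtype=torch.int32, device="cuda")
experts_ids = torch.empty(
    (max_tokens_padded + block_size - 1) // block_size, dtype=torch.int32, device="cuda"
)
num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")
token_cnts_buffer = torch.zeros(
    (num_experts + 1) * num_experts, dtype=torch.int32, device="cuda"
)
cumsum_buffer = torch.zeros(num_experts + 1, dtype=torch.int32, device="cuda")

moe_align_block_size(
    topk_ids, num_experts, block_size, sorted_token_ids,
    experts_ids, num_tokens_post_pad, token_cnts_buffer, cumsum_buffer,
)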
sgl-kernel/src/sgl-kernel/ops/sampling.py  (new file, 0 → 100644; view file @ 110e0066)

from typing import Optional, Tuple, Union

import sgl_kernel.ops._kernels
import torch
from sgl_kernel.ops.utils import _to_tensor_scalar_tuple, get_cuda_stream


def _top_k_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    probs = probs.float()
    maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
    renorm_probs = torch.empty_like(probs)
    torch.ops.sgl_kernels.top_k_renorm_probs_wrapper(
        probs,
        renorm_probs,
        maybe_top_k_arr,
        top_k_val,
        get_cuda_stream(),
    )
    return renorm_probs


def top_k_renorm_probs(
    probs: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))


top_k_renorm_prob = top_k_renorm_probs


def _top_p_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
) -> torch.Tensor:
    probs = probs.float()
    maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
    renorm_probs = torch.empty_like(probs)
    torch.ops.sgl_kernels.top_p_renorm_probs(
        probs,
        renorm_probs,
        maybe_top_p_arr,
        top_p_val,
        get_cuda_stream(),
    )
    return renorm_probs


def top_p_renorm_probs(
    probs: torch.Tensor,
    top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
    return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))


top_p_renorm_prob = top_p_renorm_probs


def _top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
        torch.ops.sgl_kernels.top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            get_cuda_stream(),
        )
        return samples, success


def top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if check_nan:
        if torch.any(torch.isnan(probs)):
            raise ValueError("Input probs contains NaN.")
    return _top_p_sampling_from_probs_internal(
        probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic
    )


def _top_k_top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
        torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            maybe_top_k_arr,
            top_k_val,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            get_cuda_stream(),
        )
        return samples, success


def top_k_top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_k: Union[torch.Tensor, int],
    top_p: Union[torch.Tensor, float],
    filter_apply_order: str = "top_k_first",
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if filter_apply_order == "top_k_first":
        renorm_probs = top_k_renorm_probs(probs, top_k)
        return top_p_sampling_from_probs(
            renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan
        )
    elif filter_apply_order == "joint":
        if check_nan:
            if torch.any(torch.isnan(probs)):
                raise ValueError("Input probs contains NaN.")
        return _top_k_top_p_sampling_from_probs_internal(
            probs,
            uniform_samples,
            *_to_tensor_scalar_tuple(top_k),
            *_to_tensor_scalar_tuple(top_p),
            deterministic,
        )
    else:
        raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")


def _min_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_min_p_arr: Optional[torch.Tensor],
    min_p_val: float,
    deterministic: bool,
) -> torch.Tensor:
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_min_p_arr = (
            maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        torch.ops.sgl_kernels.min_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            maybe_min_p_arr,
            min_p_val,
            deterministic,
            get_cuda_stream(),
        )
        return samples


def min_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    min_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> torch.Tensor:
    if uniform_samples.dim() == 2:
        # Take the first row (round) of uniform_samples
        uniform_samples = uniform_samples[0]
    if check_nan:
        if torch.any(torch.isnan(probs)):
            raise ValueError("Input probs contains NaN.")
    return _min_p_sampling_from_probs_internal(
        probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
    )
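A short usage sketch for the fused sampler above. Following the flashinfer convention these wrappers inherit, uniform_samples is shaped (max_rounds, batch_size) for the rejection-sampling rounds; that shape is an assumption here:

import torch
from sgl_kernel.ops.sampling import top_k_top_p_sampling_from_probs

batch, vocab, max_rounds = 4, 128, 32
probs = torch.softmax(torch.randn(batch, vocab, device="cuda"), dim=-1)
uniform_samples = torch.rand(max_rounds, batch, device="cuda")

samples, success = top_k_top_p_sampling_from_probs(
    probs, uniform_samples, top_k=20, top_p=0.9, filter_apply_order="joint"
)
# samples: int32 token ids per row; success: per-row flag from the kernel.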
sgl-kernel/src/sgl-kernel/ops/speculative.py  (new file, 0 → 100644; view file @ 110e0066)

import sgl_kernel.ops._kernels
import torch
from sgl_kernel.ops.utils import get_cuda_stream


def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    deterministic: bool = True,
) -> None:
    torch.ops.sgl_kernels.tree_speculative_sampling_target_only(
        predicts,
        accept_index,
        accept_token_num,
        candidates,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        uniform_samples,
        target_probs,
        draft_probs,
        deterministic,
        get_cuda_stream(),
    )


def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    torch.ops.sgl_kernels.build_tree_kernel_efficient(
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        depth,
        draft_token_num,
    )


def build_tree_kernel(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    torch.ops.sgl_kernels.build_tree_kernel(
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        topk,
        depth,
        draft_token_num,
    )
sgl-kernel/src/sgl-kernel/ops/utils.py  (view file @ 110e0066)

@@ -18,8 +18,8 @@ from typing import Dict, Tuple
 import torch
 
-def _get_cuda_stream(device: torch.device) -> int:
-    return torch.cuda.current_stream(device).cuda_stream
+def get_cuda_stream() -> int:
+    return torch.cuda.current_stream().cuda_stream
 
 _cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
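The change above drops the explicit device argument so call sites can grab the current stream directly. A small check of the behavior (CUDA assumed):

import torch
from sgl_kernel.ops.utils import get_cuda_stream

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    # get_cuda_stream() returns the raw cudaStream_t of the current stream as an
    # integer, which the wrappers above forward to the C++ kernels.
    assert get_cuda_stream() == s.cuda_stream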
sgl-kernel/tests/test_trt_allreduce.py  (view file @ 110e0066)

@@ -7,9 +7,9 @@ import unittest
 from typing import Any, List, Optional
 
 import ray
+import sgl_kernel.ops.allreduce as custom_ops
 import torch
 import torch.distributed as dist
-from sgl_kernel import ops as custom_ops
 from torch.distributed import ProcessGroup
 from vllm import _custom_ops as vllm_ops