sglang · Commit 110e0066 (Unverified)
Authored Mar 03, 2025 by Lianmin Zheng; committed by GitHub on Mar 03, 2025
Reorganize python source files in sgl-kernel with multiple files (#4027)
Parent: 6b45a21d
Showing 11 changed files with 705 additions and 781 deletions.
sgl-kernel/src/sgl-kernel/__init__.py          +33  -101
sgl-kernel/src/sgl-kernel/ops/__init__.py       +0  -677
sgl-kernel/src/sgl-kernel/ops/activation.py   +153    -0
sgl-kernel/src/sgl-kernel/ops/allreduce.py     +78    -0
sgl-kernel/src/sgl-kernel/ops/attention.py      +8    -0
sgl-kernel/src/sgl-kernel/ops/gemm.py         +111    -0
sgl-kernel/src/sgl-kernel/ops/moe.py           +24    -0
sgl-kernel/src/sgl-kernel/ops/sampling.py     +211    -0
sgl-kernel/src/sgl-kernel/ops/speculative.py   +84    -0
sgl-kernel/src/sgl-kernel/ops/utils.py          +2    -2
sgl-kernel/tests/test_trt_allreduce.py          +1    -1
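Taken together, the paths above give the following layout for the reorganized package (reconstructed from this file list; the per-module summaries are editorial, based on the diffs shown below):

sgl-kernel/
├── src/sgl-kernel/
│   ├── __init__.py          # re-exports the public API from the sgl_kernel.ops.* submodules
│   └── ops/
│       ├── __init__.py      # deleted: the old monolithic ops module
│       ├── activation.py    # RMSNorm variants, SiLU/GELU-and-mul, RoPE
│       ├── allreduce.py     # ROCm and TRT-LLM custom all-reduce wrappers
│       ├── attention.py     # lightning_attention_decode
│       ├── gemm.py          # int8/fp8 scaled matmuls, bmm_fp8, grouped GEMM
│       ├── moe.py           # moe_align_block_size
│       ├── sampling.py      # top-k / top-p / min-p sampling and renormalization
│       ├── speculative.py   # tree-based speculative sampling kernels
│       └── utils.py         # CUDA stream and cache-buffer helpers
└── tests/
    └── test_trt_allreduce.py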
sgl-kernel/src/sgl-kernel/__init__.py
@@ -9,105 +9,37 @@ if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
        mode=ctypes.RTLD_GLOBAL,
    )

# New in this commit: re-export the public API from the split submodules.
from sgl_kernel.ops.activation import (
    apply_rope_with_cos_sin_cache_inplace,
    fused_add_rmsnorm,
    gelu_and_mul,
    gelu_tanh_and_mul,
    gemma_fused_add_rmsnorm,
    gemma_rmsnorm,
    rmsnorm,
    silu_and_mul,
)
from sgl_kernel.ops.allreduce import *
from sgl_kernel.ops.attention import lightning_attention_decode
from sgl_kernel.ops.gemm import (
    bmm_fp8,
    cublas_grouped_gemm,
    fp8_blockwise_scaled_mm,
    fp8_scaled_mm,
    int8_scaled_mm,
    sgl_per_token_group_quant_fp8,
)
from sgl_kernel.ops.moe import moe_align_block_size
from sgl_kernel.ops.sampling import (
    min_p_sampling_from_probs,
    top_k_renorm_prob,
    top_k_top_p_sampling_from_probs,
    top_p_renorm_prob,
    top_p_sampling_from_probs,
)
from sgl_kernel.ops.speculative import (
    build_tree_kernel,
    build_tree_kernel_efficient,
    tree_speculative_sampling_target_only,
)
from sgl_kernel.version import __version__

# Removed in this commit: the old per-backend imports from the monolithic sgl_kernel.ops.
if torch.version.cuda:
    from sgl_kernel.ops import (
        apply_rope_with_cos_sin_cache_inplace,
        bmm_fp8,
        build_tree_kernel,
        build_tree_kernel_efficient,
        cublas_grouped_gemm,
        custom_dispose,
        custom_reduce,
        fp8_blockwise_scaled_mm,
        fp8_scaled_mm,
        fused_add_rmsnorm,
        gelu_and_mul,
        gelu_tanh_and_mul,
        gemma_fused_add_rmsnorm,
        gemma_rmsnorm,
        get_graph_buffer_ipc_meta,
        init_custom_reduce,
        int8_scaled_mm,
        lightning_attention_decode,
        min_p_sampling_from_probs,
        moe_align_block_size,
        register_graph_buffers,
        rmsnorm,
        sampling_scaling_penalties,
        sgl_per_token_group_quant_fp8,
        silu_and_mul,
        top_k_renorm_prob,
        top_k_top_p_sampling_from_probs,
        top_p_renorm_prob,
        tree_speculative_sampling_target_only,
    )
else:
    assert torch.version.hip
    from sgl_kernel.ops import (
        all_reduce_reg,
        all_reduce_unreg,
        allocate_meta_buffer,
        apply_rope_with_cos_sin_cache_inplace,
        bmm_fp8,
        dispose,
        fp8_scaled_mm,
        fused_add_rmsnorm,
        gelu_and_mul,
        gelu_tanh_and_mul,
        gemma_fused_add_rmsnorm,
        gemma_rmsnorm,
        get_graph_buffer_ipc_meta,
        get_meta_buffer_ipc_handle,
        init_custom_ar,
        int8_scaled_mm,
        lightning_attention_decode,
        meta_size,
        min_p_sampling_from_probs,
        moe_align_block_size,
        register_buffer,
        register_graph_buffers,
        rmsnorm,
        sampling_scaling_penalties,
        silu_and_mul,
        top_k_renorm_prob,
        top_k_top_p_sampling_from_probs,
        top_p_renorm_prob,
    )

__all__ = [
    "__version__",
    "apply_rope_with_cos_sin_cache_inplace",
    "bmm_fp8",
    "cublas_grouped_gemm",
    "custom_dispose",
    "custom_reduce",
    "build_tree_kernel_efficient",
    "build_tree_kernel",
    "fp8_blockwise_scaled_mm",
    "fp8_scaled_mm",
    "fused_add_rmsnorm",
    "gelu_and_mul",
    "gelu_tanh_and_mul",
    "gemma_fused_add_rmsnorm",
    "gemma_rmsnorm",
    "get_graph_buffer_ipc_meta",
    "init_custom_reduce",
    "int8_scaled_mm",
    "lightning_attention_decode",
    "min_p_sampling_from_probs",
    "moe_align_block_size",
    "register_graph_buffers",
    "rmsnorm",
    "sampling_scaling_penalties",
    "sgl_per_token_group_quant_fp8",
    "silu_and_mul",
    "top_k_renorm_prob",
    "top_k_top_p_sampling_from_probs",
    "top_p_renorm_prob",
    "tree_speculative_sampling_target_only",
]
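The reorganization keeps the public surface flat: kernels are still imported from the top-level sgl_kernel package, which now just re-exports them from the submodules. A minimal usage sketch (assuming a CUDA build of sgl-kernel is installed; the shapes are illustrative):

import torch
from sgl_kernel import rmsnorm  # re-exported from sgl_kernel.ops.activation

x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
w = torch.ones(4096, device="cuda", dtype=torch.float16)
y = rmsnorm(x, w, eps=1e-6)  # fused RMSNorm; same shape and dtype as x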
sgl-kernel/src/sgl-kernel/ops/__init__.py
deleted 100644 → 0
import os
from typing import List, Optional, Tuple, Union

import sgl_kernel.ops._kernels
import torch
from sgl_kernel.ops.utils import (
    _get_cache_buf,
    _get_cuda_stream,
    _to_tensor_scalar_tuple,
)


def apply_rope_with_cos_sin_cache_inplace(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
) -> None:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.
    This is designed to be compatible with the SGL/vLLM implementation.
    The result is applied in place to the input tensors.

    Parameters
    ----------
    positions : torch.Tensor
        Position indices, shape: ``(nnz)``.
    query : torch.Tensor
        Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
    key : torch.Tensor
        Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
    cos_sin_cache : torch.Tensor
        Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
        Cosine is the first half and Sine is the second half on rotary_dim.
    is_neox : bool
        Whether to use Neox style RoPE, default: ``True``.

        * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.
        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

    Note
    ----
    The rotary dimension is determined by the cosine cache and sine cache.
    """
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")

    with query.device as device:
        positions = positions.int()
        torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
            q=query.view(query.shape[0], -1, head_size),
            k=key.view(key.shape[0], -1, head_size),
            q_rope=query.view(query.shape[0], -1, head_size),
            k_rope=key.view(key.shape[0], -1, head_size),
            cos_sin_cache=cos_sin_cache,
            pos_ids=positions,
            interleave=(not is_neox),
            cuda_stream=_get_cuda_stream(device),
        )


if torch.version.hip is not None:

    def init_custom_ar(
        meta: torch.Tensor,
        rank_data: torch.Tensor,
        handles: List[str],
        offsets: List[int],
        rank: int,
        full_nvlink: bool,
    ) -> int:
        return torch.ops.sgl_kernels.init_custom_ar(
            meta, rank_data, handles, offsets, rank, full_nvlink
        )

    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
        torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out)

    def all_reduce_unreg(
        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
    ) -> None:
        torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out)

    def dispose(fa: int) -> None:
        torch.ops.sgl_kernels.dispose(fa)

    def meta_size() -> int:
        return torch.ops.sgl_kernels.meta_size()

    def register_buffer(
        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
    ) -> None:
        return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(
        fa: int, handles: List[str], offsets: List[List[int]]
    ) -> None:
        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)

    def allocate_meta_buffer(size: int) -> torch.Tensor:
        return torch.ops.sgl_kernels.allocate_meta_buffer(size)

    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
        return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp)

else:
    # trt_reduce
    def init_custom_reduce(
        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
    ):
        return torch.ops.sgl_kernels.init_custom_ar(
            rank_id,
            num_devices,
            rank_data,
            buffers,
            tmp_buffers,
            barrier_in,
            barrier_out,
        )

    def custom_dispose(fa):
        torch.ops.sgl_kernels.dispose(fa)

    def custom_reduce(fa, inp, out):
        torch.ops.sgl_kernels.all_reduce(fa, inp, out)

    def get_graph_buffer_ipc_meta(fa):
        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(fa, handles, offsets):
        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)


def moe_align_block_size(
    topk_ids,
    num_experts,
    block_size,
    sorted_token_ids,
    experts_ids,
    num_tokens_post_pad,
    token_cnts_buffer,
    cumsum_buffer,
):
    torch.ops.sgl_kernels.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
        token_cnts_buffer,
        cumsum_buffer,
    )


def sampling_scaling_penalties(logits, scaling_penalties):
    return torch.ops.sgl_kernels.sampling_scaling_penalties(logits, scaling_penalties)


def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    return torch.ops.sgl_kernels.int8_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )


def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
    return torch.ops.sgl_kernels.fp8_blockwise_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
    )


def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    return torch.ops.sgl_kernels.fp8_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )


def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
    torch.ops.sgl_kernels.lightning_attention_decode(
        q, k, v, past_kv, slope, output, new_kv
    )


# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
# Kudos to @yzh119
def rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    with input.device as device:
        if out is None:
            out = torch.empty_like(input)
        torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
        return out


def fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    with input.device as device:
        torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps)


def gemma_rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    with input.device as device:
        if out is None:
            out = torch.empty_like(input)
        torch.ops.sgl_kernels.gemma_rmsnorm(
            out, input, weight, eps, _get_cuda_stream(device)
        )
        return out


def gemma_fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    with input.device as device:
        torch.ops.sgl_kernels.gemma_fused_add_rmsnorm(
            input, residual, weight, eps, _get_cuda_stream(device)
        )


def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
    assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
    assert (
        input.shape[:-1] == output.shape[:-1]
    ), f"{input.shape[:-1]} != {output.shape[:-1]}"
    assert (
        input.shape[-1] == 2 * output.shape[-1]
    ), f"{input.shape[-1]} != {2 * output.shape[-1]}"


def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    with input.device as device:
        torch.ops.sgl_kernels.silu_and_mul(out, input, _get_cuda_stream(device))
    return out


def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    with input.device as device:
        torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, _get_cuda_stream(device))
    return out


def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    with input.device as device:
        torch.ops.sgl_kernels.gelu_and_mul(out, input, _get_cuda_stream(device))
    return out


def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    with A.device as device:
        cublas_handle = torch.cuda.current_blas_handle()
        torch.ops.sgl_kernels.bmm_fp8(
            A,
            B,
            D,
            A_scale,
            B_scale,
            workspace_buffer,
            cublas_handle,
            _get_cuda_stream(device),
        )


def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if out is None:
        out = torch.empty(
            (A.shape[0], A.shape[1], B.shape[2]),
            device=A.device,
            dtype=dtype,
        )
    workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
    _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
    return out


def _top_k_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    with probs.device as device:
        probs = probs.float()
        maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
        renorm_probs = torch.empty_like(probs)
        torch.ops.sgl_kernels.top_k_renorm_probs_wrapper(
            probs,
            renorm_probs,
            maybe_top_k_arr,
            top_k_val,
            _get_cuda_stream(device),
        )
        return renorm_probs


def top_k_renorm_probs(
    probs: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))


top_k_renorm_prob = top_k_renorm_probs


def _top_p_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
) -> torch.Tensor:
    with probs.device as device:
        probs = probs.float()
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        renorm_probs = torch.empty_like(probs)
        torch.ops.sgl_kernels.top_p_renorm_probs(
            probs,
            renorm_probs,
            maybe_top_p_arr,
            top_p_val,
            _get_cuda_stream(device),
        )
        return renorm_probs


def top_p_renorm_probs(
    probs: torch.Tensor,
    top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
    return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))


top_p_renorm_prob = top_p_renorm_probs


def _top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
        torch.ops.sgl_kernels.top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            _get_cuda_stream(device),
        )
        return samples, success


def top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if check_nan:
        if torch.any(torch.isnan(probs)):
            raise ValueError("Input probs contains NaN.")
    return _top_p_sampling_from_probs_internal(
        probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic
    )


def _top_k_top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
        torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            maybe_top_k_arr,
            top_k_val,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            _get_cuda_stream(device),
        )
        return samples, success


def top_k_top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_k: Union[torch.Tensor, int],
    top_p: Union[torch.Tensor, float],
    filter_apply_order: str = "top_k_first",
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if filter_apply_order == "top_k_first":
        renorm_probs = top_k_renorm_probs(probs, top_k)
        return top_p_sampling_from_probs(
            renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan
        )
    elif filter_apply_order == "joint":
        if check_nan:
            if torch.any(torch.isnan(probs)):
                raise ValueError("Input probs contains NaN.")
        return _top_k_top_p_sampling_from_probs_internal(
            probs,
            uniform_samples,
            *_to_tensor_scalar_tuple(top_k),
            *_to_tensor_scalar_tuple(top_p),
            deterministic,
        )
    else:
        raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")


def _min_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_min_p_arr: Optional[torch.Tensor],
    min_p_val: float,
    deterministic: bool,
) -> torch.Tensor:
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_min_p_arr = (
            maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        torch.ops.sgl_kernels.min_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            maybe_min_p_arr,
            min_p_val,
            deterministic,
            _get_cuda_stream(device),
        )
        return samples


def min_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    min_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> torch.Tensor:
    if uniform_samples.dim() == 2:
        # Take the first row (round) of uniform_samples
        uniform_samples = uniform_samples[0]
    if check_nan:
        if torch.any(torch.isnan(probs)):
            raise ValueError("Input probs contains NaN.")
    return _min_p_sampling_from_probs_internal(
        probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
    )


def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    deterministic: bool = True,
) -> None:
    with predicts.device as device:
        torch.ops.sgl_kernels.tree_speculative_sampling_target_only(
            predicts,
            accept_index,
            accept_token_num,
            candidates,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            uniform_samples,
            target_probs,
            draft_probs,
            deterministic,
            _get_cuda_stream(device),
        )


def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    with parent_list.device as device:
        torch.ops.sgl_kernels.build_tree_kernel_efficient(
            parent_list,
            selected_index,
            verified_seq_len,
            tree_mask,
            positions,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            topk,
            depth,
            draft_token_num,
        )


def build_tree_kernel(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    with parent_list.device as device:
        torch.ops.sgl_kernels.build_tree_kernel(
            parent_list,
            selected_index,
            verified_seq_len,
            tree_mask,
            positions,
            retrive_index,
            topk,
            depth,
            draft_token_num,
        )


def sgl_per_token_group_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
) -> None:
    torch.ops.sgl_kernels.sgl_per_token_group_quant_fp8(
        input, output_q, output_s, group_size, eps, fp8_min, fp8_max
    )


def cublas_grouped_gemm(
    inputs: List[torch.Tensor],
    weights: List[torch.Tensor],
    outputs: List[torch.Tensor],
    out_dtype: torch.dtype,
) -> None:
    with inputs[0].device as device:
        assert (
            len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
        ), "Inputs/weights/outputs should not be empty!"
        cublas_handle = torch.cuda.current_blas_handle()
        torch.ops.sgl_kernels.cublas_grouped_gemm(
            inputs,
            weights,
            outputs,
            out_dtype,
            cublas_handle,
            _get_cuda_stream(device),
        )
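For orientation, here is a hedged usage sketch of apply_rope_with_cos_sin_cache_inplace, following the shapes given in its docstring above; the base-10000 frequency schedule and all sizes are illustrative assumptions, not taken from the diff:

import torch
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace

nnz, num_heads, head_size = 4, 8, 128
rotary_dim, max_seq_len = 128, 2048

# (max_seq_len, rotary_dim): cos in the first half, sin in the second; float32 required
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1).cuda()

query = torch.randn(nnz, num_heads * head_size, device="cuda", dtype=torch.float16)
key = torch.randn(nnz, num_heads * head_size, device="cuda", dtype=torch.float16)
positions = torch.arange(nnz, device="cuda")

# Rotates query and key in place; is_neox=True means the half-split (non-interleaved) layout.
apply_rope_with_cos_sin_cache_inplace(
    positions, query, key, head_size, cos_sin_cache, is_neox=True
)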
sgl-kernel/src/sgl-kernel/ops/activation.py
0 → 100644
from typing import Optional

import sgl_kernel.ops._kernels
import torch
from sgl_kernel.ops.utils import get_cuda_stream

# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
# Kudos to @yzh119


def rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if out is None:
        out = torch.empty_like(input)
    torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, get_cuda_stream())
    return out


def fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps)


def gemma_rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if out is None:
        out = torch.empty_like(input)
    torch.ops.sgl_kernels.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
    return out


def gemma_fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    torch.ops.sgl_kernels.gemma_fused_add_rmsnorm(
        input, residual, weight, eps, get_cuda_stream()
    )


def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
    assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
    assert (
        input.shape[:-1] == output.shape[:-1]
    ), f"{input.shape[:-1]} != {output.shape[:-1]}"
    assert (
        input.shape[-1] == 2 * output.shape[-1]
    ), f"{input.shape[-1]} != {2 * output.shape[-1]}"


def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    torch.ops.sgl_kernels.silu_and_mul(out, input, get_cuda_stream())
    return out


def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, get_cuda_stream())
    return out


def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    torch.ops.sgl_kernels.gelu_and_mul(out, input, get_cuda_stream())
    return out
def apply_rope_with_cos_sin_cache_inplace(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
) -> None:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.
    This is designed to be compatible with the SGL/vLLM implementation.
    The result is applied in place to the input tensors.

    Parameters
    ----------
    positions : torch.Tensor
        Position indices, shape: ``(nnz)``.
    query : torch.Tensor
        Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
    key : torch.Tensor
        Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
    cos_sin_cache : torch.Tensor
        Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
        Cosine is the first half and Sine is the second half on rotary_dim.
    is_neox : bool
        Whether to use Neox style RoPE, default: ``True``.

        * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.
        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

    Note
    ----
    The rotary dimension is determined by the cosine cache and sine cache.
    """
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")

    positions = positions.int()
    torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
        q=query.view(query.shape[0], -1, head_size),
        k=key.view(key.shape[0], -1, head_size),
        q_rope=query.view(query.shape[0], -1, head_size),
        k_rope=key.view(key.shape[0], -1, head_size),
        cos_sin_cache=cos_sin_cache,
        pos_ids=positions,
        interleave=(not is_neox),
        cuda_stream=get_cuda_stream(),
    )
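The fused activations in this file share one contract: the input packs two projections along its last dimension, the output has half that width, and the last dimension must span a multiple of 16 bytes. A small sketch under those assumptions (the gate-first/up-second split is the usual FlashInfer/vLLM convention, not spelled out in this file):

import torch
from sgl_kernel import gelu_tanh_and_mul, silu_and_mul

d = 11008
x = torch.randn(4, 2 * d, device="cuda", dtype=torch.float16)  # [tokens, 2*d]

a = silu_and_mul(x)       # assumed: SiLU(x[..., :d]) * x[..., d:], shape (4, d)
b = gelu_tanh_and_mul(x)  # tanh-approximated GELU variant, same shape

out = torch.empty(4, d, device="cuda", dtype=torch.float16)
silu_and_mul(x, out=out)  # optional preallocated output, validated by _check_shape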
sgl-kernel/src/sgl-kernel/ops/allreduce.py
0 → 100644
from typing import List, Tuple

import sgl_kernel.ops._kernels
import torch

if torch.version.hip is not None:
    # ROCM custom allreduce
    def init_custom_ar(
        meta: torch.Tensor,
        rank_data: torch.Tensor,
        handles: List[str],
        offsets: List[int],
        rank: int,
        full_nvlink: bool,
    ) -> int:
        return torch.ops.sgl_kernels.init_custom_ar(
            meta, rank_data, handles, offsets, rank, full_nvlink
        )

    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
        torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out)

    def all_reduce_unreg(
        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
    ) -> None:
        torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out)

    def dispose(fa: int) -> None:
        torch.ops.sgl_kernels.dispose(fa)

    def meta_size() -> int:
        return torch.ops.sgl_kernels.meta_size()

    def register_buffer(
        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
    ) -> None:
        return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(
        fa: int, handles: List[str], offsets: List[List[int]]
    ) -> None:
        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)

    def allocate_meta_buffer(size: int) -> torch.Tensor:
        return torch.ops.sgl_kernels.allocate_meta_buffer(size)

    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
        return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp)

else:
    # TRTLLM custom allreduce
    def init_custom_reduce(
        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
    ):
        return torch.ops.sgl_kernels.init_custom_ar(
            rank_id,
            num_devices,
            rank_data,
            buffers,
            tmp_buffers,
            barrier_in,
            barrier_out,
        )

    def custom_dispose(fa):
        torch.ops.sgl_kernels.dispose(fa)

    def custom_reduce(fa, inp, out):
        torch.ops.sgl_kernels.all_reduce(fa, inp, out)

    def get_graph_buffer_ipc_meta(fa):
        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(fa, handles, offsets):
        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
sgl-kernel/src/sgl-kernel/ops/attention.py
0 → 100644
import sgl_kernel.ops._kernels
import torch


def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
    torch.ops.sgl_kernels.lightning_attention_decode(
        q, k, v, past_kv, slope, output, new_kv
    )
sgl-kernel/src/sgl-kernel/ops/gemm.py
0 → 100644
from typing import List, Optional

import sgl_kernel.ops._kernels
import torch
from sgl_kernel.ops.utils import _get_cache_buf, get_cuda_stream


def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    return torch.ops.sgl_kernels.int8_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )


def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
    return torch.ops.sgl_kernels.fp8_blockwise_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
    )


def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    return torch.ops.sgl_kernels.fp8_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )


def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    cublas_handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernels.bmm_fp8(
        A,
        B,
        D,
        A_scale,
        B_scale,
        workspace_buffer,
        cublas_handle,
        get_cuda_stream(),
    )


def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if out is None:
        out = torch.empty(
            (A.shape[0], A.shape[1], B.shape[2]),
            device=A.device,
            dtype=dtype,
        )
    workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
    _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
    return out


def sgl_per_token_group_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
) -> None:
    torch.ops.sgl_kernels.sgl_per_token_group_quant_fp8(
        input, output_q, output_s, group_size, eps, fp8_min, fp8_max
    )


def cublas_grouped_gemm(
    inputs: List[torch.Tensor],
    weights: List[torch.Tensor],
    outputs: List[torch.Tensor],
    out_dtype: torch.dtype,
) -> None:
    assert (
        len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
    ), "Inputs/weights/outputs should not be empty!"
    cublas_handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernels.cublas_grouped_gemm(
        inputs,
        weights,
        outputs,
        out_dtype,
        cublas_handle,
        get_cuda_stream(),
    )
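A hedged sketch of the batched FP8 matmul above. The wrapper mirrors FlashInfer's bmm_fp8, so the e4m3 inputs, column-major second operand, and scalar scale tensors below are assumptions carried over from that convention rather than facts stated in this diff:

import torch
from sgl_kernel import bmm_fp8

Bn, M, K, N = 2, 16, 64, 32
a = torch.randn(Bn, M, K, device="cuda").to(torch.float8_e4m3fn)
# Build the second operand column-major, as FlashInfer's bmm_fp8 expects (assumption).
b = torch.randn(Bn, N, K, device="cuda").to(torch.float8_e4m3fn).transpose(-2, -1)
a_scale = torch.tensor(1.0, device="cuda")
b_scale = torch.tensor(1.0, device="cuda")

d = bmm_fp8(a, b, a_scale, b_scale, dtype=torch.bfloat16)  # shape (2, 16, 32)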
sgl-kernel/src/sgl-kernel/ops/moe.py
0 → 100644
import sgl_kernel.ops._kernels
import torch


def moe_align_block_size(
    topk_ids,
    num_experts,
    block_size,
    sorted_token_ids,
    experts_ids,
    num_tokens_post_pad,
    token_cnts_buffer,
    cumsum_buffer,
):
    torch.ops.sgl_kernels.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
        token_cnts_buffer,
        cumsum_buffer,
    )
sgl-kernel/src/sgl-kernel/ops/sampling.py
0 → 100644
from typing import Optional, Tuple, Union

import sgl_kernel.ops._kernels
import torch
from sgl_kernel.ops.utils import _to_tensor_scalar_tuple, get_cuda_stream


def _top_k_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    probs = probs.float()
    maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
    renorm_probs = torch.empty_like(probs)
    torch.ops.sgl_kernels.top_k_renorm_probs_wrapper(
        probs,
        renorm_probs,
        maybe_top_k_arr,
        top_k_val,
        get_cuda_stream(),
    )
    return renorm_probs


def top_k_renorm_probs(
    probs: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))


top_k_renorm_prob = top_k_renorm_probs


def _top_p_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
) -> torch.Tensor:
    probs = probs.float()
    maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
    renorm_probs = torch.empty_like(probs)
    torch.ops.sgl_kernels.top_p_renorm_probs(
        probs,
        renorm_probs,
        maybe_top_p_arr,
        top_p_val,
        get_cuda_stream(),
    )
    return renorm_probs


def top_p_renorm_probs(
    probs: torch.Tensor,
    top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
    return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))


top_p_renorm_prob = top_p_renorm_probs


def _top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
        torch.ops.sgl_kernels.top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            get_cuda_stream(),
        )
        return samples, success


def top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if check_nan:
        if torch.any(torch.isnan(probs)):
            raise ValueError("Input probs contains NaN.")
    return _top_p_sampling_from_probs_internal(
        probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic
    )


def _top_k_top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
        torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            maybe_top_k_arr,
            top_k_val,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            get_cuda_stream(),
        )
        return samples, success


def top_k_top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_k: Union[torch.Tensor, int],
    top_p: Union[torch.Tensor, float],
    filter_apply_order: str = "top_k_first",
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if filter_apply_order == "top_k_first":
        renorm_probs = top_k_renorm_probs(probs, top_k)
        return top_p_sampling_from_probs(
            renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan
        )
    elif filter_apply_order == "joint":
        if check_nan:
            if torch.any(torch.isnan(probs)):
                raise ValueError("Input probs contains NaN.")
        return _top_k_top_p_sampling_from_probs_internal(
            probs,
            uniform_samples,
            *_to_tensor_scalar_tuple(top_k),
            *_to_tensor_scalar_tuple(top_p),
            deterministic,
        )
    else:
        raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")


def _min_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_min_p_arr: Optional[torch.Tensor],
    min_p_val: float,
    deterministic: bool,
) -> torch.Tensor:
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_min_p_arr = (
            maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        torch.ops.sgl_kernels.min_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            maybe_min_p_arr,
            min_p_val,
            deterministic,
            get_cuda_stream(),
        )
        return samples


def min_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    min_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> torch.Tensor:
    if uniform_samples.dim() == 2:
        # Take the first row (round) of uniform_samples
        uniform_samples = uniform_samples[0]
    if check_nan:
        if torch.any(torch.isnan(probs)):
            raise ValueError("Input probs contains NaN.")
    return _min_p_sampling_from_probs_internal(
        probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
    )
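These samplers take normalized probabilities plus pre-drawn uniform noise and return a (samples, success) pair; the first dimension of uniform_samples is the number of rejection-sampling rounds, following the FlashInfer convention these wrappers derive from (an assumption; the diff itself does not document the shape). A minimal sketch:

import torch
from sgl_kernel import top_k_top_p_sampling_from_probs

batch, vocab, rounds = 4, 32000, 32
probs = torch.softmax(torch.randn(batch, vocab, device="cuda"), dim=-1)
uniform_samples = torch.rand(rounds, batch, device="cuda")

samples, success = top_k_top_p_sampling_from_probs(
    probs,
    uniform_samples,
    top_k=20,  # scalar or per-row tensor, normalized via _to_tensor_scalar_tuple
    top_p=0.9,
    filter_apply_order="top_k_first",  # or "joint"
)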
sgl-kernel/src/sgl-kernel/ops/speculative.py
0 → 100644
import sgl_kernel.ops._kernels
import torch
from sgl_kernel.ops.utils import get_cuda_stream


def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    deterministic: bool = True,
) -> None:
    torch.ops.sgl_kernels.tree_speculative_sampling_target_only(
        predicts,
        accept_index,
        accept_token_num,
        candidates,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        uniform_samples,
        target_probs,
        draft_probs,
        deterministic,
        get_cuda_stream(),
    )


def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    torch.ops.sgl_kernels.build_tree_kernel_efficient(
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        depth,
        draft_token_num,
    )


def build_tree_kernel(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    torch.ops.sgl_kernels.build_tree_kernel(
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        topk,
        depth,
        draft_token_num,
    )
sgl-kernel/src/sgl-kernel/ops/utils.py
@@ -18,8 +18,8 @@ from typing import Dict, Tuple
import torch


def _get_cuda_stream(device: torch.device) -> int:  # removed by this commit
    return torch.cuda.current_stream(device).cuda_stream


def get_cuda_stream() -> int:  # added by this commit
    return torch.cuda.current_stream().cuda_stream


_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
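The only substantive change in utils.py is dropping the explicit device argument: get_cuda_stream() reads the stream of the caller's current CUDA device, which lets most of the new submodules drop their "with tensor.device as device:" blocks around kernel launches. In effect:

import torch

# old helper: _get_cuda_stream(x.device) == torch.cuda.current_stream(x.device).cuda_stream
# new helper: rely on whatever CUDA device is current for the caller
stream_ptr: int = torch.cuda.current_stream().cuda_stream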
sgl-kernel/tests/test_trt_allreduce.py
@@ -7,9 +7,9 @@ import unittest
from typing import Any, List, Optional

import ray
import sgl_kernel.ops.allreduce as custom_ops  # new import after the module split
import torch
import torch.distributed as dist
from sgl_kernel import ops as custom_ops  # old import, removed by this commit
from torch.distributed import ProcessGroup
from vllm import _custom_ops as vllm_ops