change / sglang · Commits · 20315697

Commit 20315697 (unverified), authored Nov 02, 2025 by Lianmin Zheng, committed by GitHub on Nov 02, 2025

move all get_stream in sgl_kernel to c++ to reduce the launch overhead (#12521)
Parent: c9db7911

Showing 20 changed files with 60 additions and 93 deletions (+60, -93)
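All of the per-file changes below follow one pattern: instead of the Python wrapper calling torch.cuda.current_stream().cuda_stream and threading the handle through the op schema as an `int cuda_stream`, each C++ launcher now queries the current stream itself via at::cuda::getCurrentCUDAStream(). The following is a minimal sketch of that before/after pattern; `my_op` and `launch_my_op` are illustrative names, not sgl-kernel APIs.

```cpp
// Sketch only; the real ops (bmm_fp8, downcast_fp8, ...) are in the diffs below.
#include <ATen/cuda/CUDAContext.h>
#include <torch/library.h>

// Before: the schema carried an `int cuda_stream` that Python had to fetch on
// every call (torch.cuda.current_stream().cuda_stream), adding launch overhead.
//   m.def("my_op(Tensor x, int cuda_stream) -> ()");
//   void launch_my_op(at::Tensor x, int64_t cuda_stream) {
//     cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
//     ...
//   }

// After: the stream is resolved once, in C++, at launch time.
void launch_my_op(at::Tensor x) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // ... enqueue kernels on `stream` using `x` ...
  (void)stream;
  (void)x;
}

TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  // The schema no longer mentions the stream, so the Python wrapper passes one
  // argument fewer and skips a Python-side stream query on every call.
  m.def("my_op(Tensor x) -> ()");
  m.impl("my_op", torch::kCUDA, &launch_my_op);
}
```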
python/sglang/srt/_custom_ops.py (+7, -19)
python/sglang/srt/configs/falcon_h1.py (+3, -2)
python/sglang/srt/configs/nemotron_h.py (+2, -1)
python/sglang/srt/configs/qwen3_next.py (+2, -1)
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py (+8, -13)
python/sglang/srt/environ.py (+0, -1)
python/sglang/srt/utils/common.py (+1, -0)
sgl-kernel/README.md (+2, -2)
sgl-kernel/csrc/common_extension.cc (+7, -7)
sgl-kernel/csrc/common_extension_rocm.cc (+1, -1)
sgl-kernel/csrc/elementwise/cast.cu (+2, -3)
sgl-kernel/csrc/elementwise/rope.cu (+1, -2)
sgl-kernel/csrc/gemm/bmm_fp8.cu (+2, -3)
sgl-kernel/csrc/speculative/eagle_utils.cu (+2, -3)
sgl-kernel/csrc/speculative/speculative_sampling.cu (+2, -3)
sgl-kernel/include/sgl_kernel_ops.h (+4, -9)
sgl-kernel/python/sgl_kernel/elementwise.py (+13, -13)
sgl-kernel/python/sgl_kernel/gemm.py (+1, -2)
sgl-kernel/python/sgl_kernel/speculative.py (+0, -3)
sgl-kernel/python/sgl_kernel/utils.py (+0, -5)
python/sglang/srt/_custom_ops.py — View file @ 20315697

@@ -4,32 +4,20 @@ from typing import List, Optional, Tuple
 
 import torch
 
-from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
+from sglang.srt.utils import is_hip, is_hpu, is_npu
 
 logger = logging.getLogger(__name__)
 
-use_vllm_custom_allreduce = get_bool_env_var("USE_VLLM_CUSTOM_ALLREDUCE", default="false")
-
 if not is_hpu():
-    # ROCm does not use vllm custom allreduce
-    if use_vllm_custom_allreduce and not is_hip():
-        try:
-            import vllm._C  # noqa: F401
-        except ImportError as e:
-            logger.warning("Failed to import from vllm._C with %r", e)
-    else:
-        try:
-            import sgl_kernel
-        except ImportError as e:
-            logger.warning("Failed to import from custom_ar with %r", e)
+    try:
+        import sgl_kernel
+    except ImportError as e:
+        logger.warning("Failed to import from custom_ar with %r", e)
 
 if not is_hip() and not is_npu():
-    if use_vllm_custom_allreduce:
-        custom_op = torch.ops._C_custom_ar
-    else:
-        custom_op = sgl_kernel.allreduce
+    custom_op = sgl_kernel.allreduce
 
 # custom allreduce
 def init_custom_ar(
...
python/sglang/srt/configs/falcon_h1.py — View file @ 20315697

@@ -19,7 +19,6 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
 from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
-from sglang.srt.layers.dp_attention import get_tensor_model_parallel_world_size
 
 logger = logging.get_logger(__name__)
...
@@ -297,8 +296,10 @@ class FalconH1Config(PretrainedConfig):
     @property
     def mamba2_cache_params(self):
+        from sglang.srt.layers.dp_attention import get_attention_tp_size
+
         shape = Mamba2StateShape.create(
-            tp_world_size=get_tensor_model_parallel_world_size(),
+            tp_world_size=get_attention_tp_size(),
             intermediate_size=self.mamba_intermediate,
             n_groups=self.mamba_n_groups,
             num_heads=self.mamba_n_heads,
...
python/sglang/srt/configs/nemotron_h.py — View file @ 20315697

@@ -20,7 +20,6 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
 from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
-from sglang.srt.layers.dp_attention import get_attention_tp_size
 
 logger = logging.get_logger(__name__)
...
@@ -273,6 +272,8 @@ class NemotronHConfig(PretrainedConfig):
     @property
     def mamba2_cache_params(self) -> Mamba2CacheParams:
+        from sglang.srt.layers.dp_attention import get_attention_tp_size
+
         shape = Mamba2StateShape.create(
             tp_world_size=get_attention_tp_size(),
             intermediate_size=self.mamba_num_heads * self.mamba_head_dim,
...
python/sglang/srt/configs/qwen3_next.py — View file @ 20315697

@@ -21,7 +21,6 @@ from transformers.modeling_rope_utils import rope_config_validation
 from transformers.utils import logging
 
 from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
-from sglang.srt.layers.dp_attention import get_attention_tp_size
 
 logger = logging.get_logger(__name__)
...
@@ -277,6 +276,8 @@ class Qwen3NextConfig(PretrainedConfig):
     @property
     def mamba2_cache_params(self) -> Mamba2CacheParams:
+        from sglang.srt.layers.dp_attention import get_attention_tp_size
+
         shape = Mamba2StateShape.create(
             tp_world_size=get_attention_tp_size(),
             intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,
...
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py — View file @ 20315697

@@ -21,24 +21,19 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as
 from sglang.srt.environ import envs
 from sglang.srt.utils import is_cuda, is_hip, log_info_on_rank0
 
-logger = logging.getLogger(__name__)
-
-_is_cuda = is_cuda()
-_is_hip = is_hip()
-
 try:
-    if ops.use_vllm_custom_allreduce and not _is_hip:
-        # Use vLLM custom allreduce
-        ops.meta_size()
-    else:
-        # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
-        import sgl_kernel  # noqa: F401
+    # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
+    import sgl_kernel  # noqa: F401
 
     custom_ar = True
-except Exception:
+except ImportError:
     # For CPUs
     custom_ar = False
 
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+logger = logging.getLogger(__name__)
...
python/sglang/srt/environ.py — View file @ 20315697

@@ -229,7 +229,6 @@ class Envs:
     SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK = EnvBool(False)
 
     # vLLM dependencies (TODO: they have been deprecated, we can remove them safely)
-    USE_VLLM_CUSTOM_ALLREDUCE = EnvBool(False)
     USE_VLLM_CUTLASS_W8A8_FP8_KERNEL = EnvBool(False)
     USE_TRITON_W8A8_FP8_KERNEL = EnvBool(False)
...
python/sglang/srt/utils/common.py — View file @ 20315697

@@ -303,6 +303,7 @@ def xpu_has_xmx_support():
     return False
 
 
+@lru_cache(maxsize=1)
 def is_flashinfer_available():
     """
     Check whether flashinfer is available.
...
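The `@lru_cache(maxsize=1)` added above memoizes the availability probe so it runs only once per process. As a rough C++ analogue of the same "probe once, reuse the result" idea (hypothetical names, not code from this repo), a function-local static gives the same behavior:

```cpp
// Hypothetical analogue of the @lru_cache(maxsize=1) memoization above;
// the real check (is_flashinfer_available) lives in Python.
static bool probe_flashinfer() {
  // An expensive import / capability check would go here.
  return true;
}

bool is_flashinfer_available_cached() {
  // Function-local statics initialize exactly once (thread-safe since C++11),
  // so repeated calls reuse the first result.
  static const bool available = probe_flashinfer();
  return available;
}
```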
sgl-kernel/README.md — View file @ 20315697

@@ -52,8 +52,8 @@ make build
 ```cpp
 // We need def with schema here for torch.compile
 m.def(
-    "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
-    "cublas_handle, int cuda_stream) -> ()");
+    "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, "
+    "int cublas_handle) -> ()");
 m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
 ```
...
sgl-kernel/csrc/common_extension.cc — View file @ 20315697

@@ -90,13 +90,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
-      "Tensor pos_ids, bool interleave, bool enable_pdl, int cuda_stream, "
+      "Tensor pos_ids, bool interleave, bool enable_pdl, "
       "Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
   m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
 
   m.def(
-      "downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, int "
-      "mult, int offset, int cuda_stream) -> ()");
+      "downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, "
+      "int mult, int offset) -> ()");
   m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8);
 
   m.def("copy_to_gpu_no_ce(Tensor input, Tensor! output) -> ()");
...
@@ -303,13 +303,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
       "Tensor uniform_samples, Tensor uniform_samples_for_final_sampling, Tensor target_probs, Tensor draft_probs, "
       "float threshold_single, float threshold_acc, "
-      "bool deterministic, int cuda_stream) -> ()");
+      "bool deterministic) -> ()");
   m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
 
   m.def(
       "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
       "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
-      "Tensor target_predict, int cuda_stream) -> ()");
+      "Tensor target_predict) -> ()");
   m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
 
   m.def(
...
@@ -403,8 +403,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
    * From FlashInfer
    */
   m.def(
-      "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
-      "cublas_handle, int cuda_stream) -> ()",
+      "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, "
+      "int cublas_handle) -> ()",
       {at::Tag::needs_fixed_stride_order});
   m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
...
sgl-kernel/csrc/common_extension_rocm.cc — View file @ 20315697

@@ -106,7 +106,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   m.def(
       "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
       "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
-      "Tensor target_predict, int cuda_stream) -> ()");
+      "Tensor target_predict) -> ()");
   m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
 
   m.def(
...
sgl-kernel/csrc/elementwise/cast.cu — View file @ 20315697

@@ -150,14 +150,13 @@ void downcast_fp8(
     at::Tensor& v_scale,
     at::Tensor& loc,
     int64_t mult,
-    int64_t offset,
-    int64_t cuda_stream) {
+    int64_t offset) {
   CHECK_INPUT(k);
   CHECK_INPUT(v);
   CHECK_INPUT(k_out);
   CHECK_INPUT(v_out);
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (k.scalar_type()) {
     case at::ScalarType::BFloat16:
       downcast_fp8_impl<__nv_bfloat16>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
...
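Dropping the explicit stream argument keeps these kernels on whatever stream PyTorch currently considers active, because at::cuda::getCurrentCUDAStream() reflects the caller's stream context (for example a `with torch.cuda.stream(s):` block on the Python side). A small sketch of that behavior, using standard PyTorch C++ APIs and made-up function names:

```cpp
// Sketch (illustrative names): the launcher picks up whichever stream the
// caller has made current, so no handle needs to cross the Python/C++ boundary.
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

void launch_on_current_stream() {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // ... enqueue work on `stream` ...
  (void)stream;
}

void caller_with_side_stream() {
  // Equivalent to `with torch.cuda.stream(s):` in Python: the guard swaps the
  // current stream, and launch_on_current_stream() sees the side stream.
  c10::cuda::CUDAStream s = c10::cuda::getStreamFromPool();
  c10::cuda::CUDAStreamGuard guard(s);
  launch_on_current_stream();
}  // the guard restores the previous stream here
```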
sgl-kernel/csrc/elementwise/rope.cu — View file @ 20315697

@@ -28,7 +28,6 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor pos_ids,
     bool interleave,
     bool enable_pdl,
-    int64_t cuda_stream,
     const std::optional<at::Tensor>& v,
     const std::optional<at::Tensor>& k_buffer,
     const std::optional<at::Tensor>& v_buffer,
...
@@ -88,7 +87,7 @@ void apply_rope_pos_ids_cos_sin_cache(
   size_t k_rope_stride_n = k_rope.stride(0);
   size_t k_rope_stride_h = k_rope.stride(1);
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
     // TODO temporarily only use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` when save_kv_cache
     // to avoid changing original code path; but this branch is feature-complete and should switch to this later
...
sgl-kernel/csrc/gemm/bmm_fp8.cu — View file @ 20315697

@@ -27,8 +27,7 @@ void bmm_fp8(
     at::Tensor A_scale,
     at::Tensor B_scale,
     at::Tensor workspace_buffer,
-    int64_t cublas_handle,
-    int64_t cuda_stream) {
+    int64_t cublas_handle) {
   TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor");
   TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor");
   TORCH_CHECK(D.is_cuda(), "D must be a CUDA tensor");
...
@@ -51,7 +50,7 @@ void bmm_fp8(
   auto n = B.size(2);
 
   auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
-  auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  auto stream = at::cuda::getCurrentCUDAStream();
   auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
       workspace_buffer.data_ptr(),
...
sgl-kernel/csrc/speculative/eagle_utils.cu — View file @ 20315697

@@ -328,8 +328,7 @@ void verify_tree_greedy(
     at::Tensor retrive_index,
     at::Tensor retrive_next_token,
     at::Tensor retrive_next_sibling,
-    at::Tensor target_predict,
-    int64_t cuda_stream = 0) {
+    at::Tensor target_predict) {
   CHECK_INPUT(candidates);
   CHECK_INPUT(retrive_index);
   CHECK_INPUT(retrive_next_token);
...
@@ -389,7 +388,7 @@ void verify_tree_greedy(
     throw std::runtime_error("Expected 'target_predict' to be of type long (torch.int64).");
   }
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   dim3 grid(batch_size);
   dim3 block(1);
...
sgl-kernel/csrc/speculative/speculative_sampling.cu — View file @ 20315697

@@ -42,8 +42,7 @@ void tree_speculative_sampling_target_only(
     at::Tensor draft_probs,
     double threshold_single,
     double threshold_acc,
-    bool deterministic = true,
-    int64_t cuda_stream = 0) {
+    bool deterministic = true) {
   CHECK_INPUT(candidates);
   CHECK_INPUT(retrive_index);
   CHECK_INPUT(retrive_next_token);
...
@@ -124,7 +123,7 @@ void tree_speculative_sampling_target_only(
   CHECK_GE(threshold_acc, 0);
   CHECK_GE(1, threshold_acc);
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int32_t, int64_t>(
       static_cast<int32_t*>(predicts.data_ptr()),
       static_cast<int32_t*>(accept_index.data_ptr()),
...
sgl-kernel/include/sgl_kernel_ops.h — View file @ 20315697

@@ -152,7 +152,6 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor pos_ids,
     bool interleave,
     bool enable_pdl,
-    int64_t cuda_stream,
     const std::optional<at::Tensor>& v,
     const std::optional<at::Tensor>& k_buffer,
     const std::optional<at::Tensor>& v_buffer,
...
@@ -167,8 +166,7 @@ void downcast_fp8(
     at::Tensor& v_scale,
     at::Tensor& loc,
     int64_t mult,
-    int64_t offset,
-    int64_t cuda_stream);
+    int64_t offset);
 void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
 void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
...
@@ -253,8 +251,7 @@ void bmm_fp8(
     at::Tensor A_scale,
     at::Tensor B_scale,
     at::Tensor workspace_buffer,
-    int64_t cublas_handle,
-    int64_t cuda_stream);
+    int64_t cublas_handle);
 void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b);
 void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
...
@@ -471,8 +468,7 @@ void tree_speculative_sampling_target_only(
     at::Tensor draft_probs,
     double threshold_single = 1,
     double threshold_acc = 1,
-    bool deterministic = true,
-    int64_t cuda_stream = 0);
+    bool deterministic = true);
 void verify_tree_greedy(
     at::Tensor predicts,  // mutable
...
@@ -482,8 +478,7 @@ void verify_tree_greedy(
     at::Tensor retrive_index,
     at::Tensor retrive_next_token,
     at::Tensor retrive_next_sibling,
-    at::Tensor target_predict,
-    int64_t cuda_stream = 0);
+    at::Tensor target_predict);
 void reconstruct_indices_from_tree_mask(
     at::Tensor tree_mask,
...
sgl-kernel/python/sgl_kernel/elementwise.py — View file @ 20315697

@@ -2,7 +2,7 @@ from dataclasses import dataclass
 from typing import List, Optional
 
 import torch
-from sgl_kernel.utils import get_cuda_stream, is_arch_support_pdl
+from sgl_kernel.utils import is_arch_support_pdl
 
 # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
...
@@ -263,6 +263,10 @@ class FusedSetKVBufferArg:
     cache_loc: torch.Tensor
 
 
+def _view_3d(x, head_size):
+    return x.view(x.shape[0], -1, head_size)
+
+
 def apply_rope_with_cos_sin_cache_inplace(
     positions: torch.Tensor,
     query: torch.Tensor,
...
@@ -317,31 +321,27 @@ def apply_rope_with_cos_sin_cache_inplace(
         assert a.v_scale is None, "v_scale is not yet supported"
         assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}"
 
-    def _view_3d(x):
-        return x.view(x.shape[0], -1, head_size)
-
     torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
-        _view_3d(query),
-        _view_3d(key),
-        _view_3d(query),
-        _view_3d(key),
+        _view_3d(query, head_size),
+        _view_3d(key, head_size),
+        _view_3d(query, head_size),
+        _view_3d(key, head_size),
         cos_sin_cache,
         positions.long(),
         (not is_neox),
         enable_pdl,
-        get_cuda_stream(),
         (
-            _view_3d(fused_set_kv_buffer_arg.value)
+            _view_3d(fused_set_kv_buffer_arg.value, head_size)
             if fused_set_kv_buffer_arg is not None
             else None
         ),
         (
-            _view_3d(fused_set_kv_buffer_arg.k_buffer)
+            _view_3d(fused_set_kv_buffer_arg.k_buffer, head_size)
             if fused_set_kv_buffer_arg is not None
             else None
        ),
         (
-            _view_3d(fused_set_kv_buffer_arg.v_buffer)
+            _view_3d(fused_set_kv_buffer_arg.v_buffer, head_size)
             if fused_set_kv_buffer_arg is not None
             else None
         ),
...
@@ -365,7 +365,7 @@ def downcast_fp8(
     offset: int = 0,
 ) -> None:
     torch.ops.sgl_kernel.downcast_fp8(
-        k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, get_cuda_stream()
+        k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset
     )
...
sgl-kernel/python/sgl_kernel/gemm.py — View file @ 20315697

@@ -2,7 +2,7 @@ from typing import Optional, Tuple
 
 import torch
 from sgl_kernel.scalar_type import ScalarType
-from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
+from sgl_kernel.utils import _get_cache_buf
 
 
 def awq_dequantize(
...
@@ -60,7 +60,6 @@ def _bmm_fp8_internal(
         B_scale,
         workspace_buffer,
         cublas_handle,
-        get_cuda_stream(),
     )
...
sgl-kernel/python/sgl_kernel/speculative.py — View file @ 20315697

 import torch
-from sgl_kernel.utils import get_cuda_stream
 
 
 def tree_speculative_sampling_target_only(
...
@@ -33,7 +32,6 @@ def tree_speculative_sampling_target_only(
         threshold_single,
         threshold_acc,
         deterministic,
-        get_cuda_stream(),
     )
...
@@ -56,7 +54,6 @@ def verify_tree_greedy(
         retrive_next_token,
         retrive_next_sibling,
         target_predict,
-        get_cuda_stream(),
     )
...
sgl-kernel/python/sgl_kernel/utils.py — View file @ 20315697

@@ -18,11 +18,6 @@ from typing import Dict, Tuple
 
 import torch
 
 
-def get_cuda_stream() -> int:
-    return torch.cuda.current_stream().cuda_stream
-
-
 _cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
...