Commit 20315697 (unverified), authored Nov 02, 2025 by Lianmin Zheng, committed by GitHub on Nov 02, 2025.
move all get_stream in sgl_kernel to c++ to reduce the launch overhead (#12521)
Parent: c9db7911

Showing 20 changed files with 60 additions and 93 deletions (+60, −93).
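The change repeated across these files is mechanical: every sgl_kernel op used to take the current CUDA stream as a trailing `int cuda_stream` schema argument, which Python fetched on every call, and each C++ entry point now asks ATen for the stream itself. A minimal before/after sketch (the op name `my_op` is illustrative, not from the repo):

```cpp
#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream

// Before: Python passed torch.cuda.current_stream().cuda_stream as an int64
// through the schema, and C++ cast it back:
//   void my_op(at::Tensor x, int64_t cuda_stream) {
//     cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
//     ...
//   }

// After: the op reads ATen's thread-local current stream directly, so the
// schema loses an argument and Python skips one lookup per launch.
void my_op(at::Tensor x) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // ... launch kernels on `stream` ...
  (void)stream;
  (void)x;
}
```

Because the Python caller and the C++ op run on the same thread, `getCurrentCUDAStream()` observes the same stream Python would have passed, while the per-call Python work (an attribute lookup plus one extra boxed argument) disappears — presumably the launch-overhead saving the commit title refers to.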
Changed files:

- python/sglang/srt/_custom_ops.py (+7, −19)
- python/sglang/srt/configs/falcon_h1.py (+3, −2)
- python/sglang/srt/configs/nemotron_h.py (+2, −1)
- python/sglang/srt/configs/qwen3_next.py (+2, −1)
- python/sglang/srt/distributed/device_communicators/custom_all_reduce.py (+8, −13)
- python/sglang/srt/environ.py (+0, −1)
- python/sglang/srt/utils/common.py (+1, −0)
- sgl-kernel/README.md (+2, −2)
- sgl-kernel/csrc/common_extension.cc (+7, −7)
- sgl-kernel/csrc/common_extension_rocm.cc (+1, −1)
- sgl-kernel/csrc/elementwise/cast.cu (+2, −3)
- sgl-kernel/csrc/elementwise/rope.cu (+1, −2)
- sgl-kernel/csrc/gemm/bmm_fp8.cu (+2, −3)
- sgl-kernel/csrc/speculative/eagle_utils.cu (+2, −3)
- sgl-kernel/csrc/speculative/speculative_sampling.cu (+2, −3)
- sgl-kernel/include/sgl_kernel_ops.h (+4, −9)
- sgl-kernel/python/sgl_kernel/elementwise.py (+13, −13)
- sgl-kernel/python/sgl_kernel/gemm.py (+1, −2)
- sgl-kernel/python/sgl_kernel/speculative.py (+0, −3)
- sgl-kernel/python/sgl_kernel/utils.py (+0, −5)
python/sglang/srt/_custom_ops.py

```diff
@@ -4,32 +4,20 @@ from typing import List, Optional, Tuple
 import torch
 
-from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
+from sglang.srt.utils import is_hip, is_hpu, is_npu
 
 logger = logging.getLogger(__name__)
-use_vllm_custom_allreduce = get_bool_env_var("USE_VLLM_CUSTOM_ALLREDUCE", default="false")
 
 if not is_hpu():
-    # ROCm does not use vllm custom allreduce
-    if use_vllm_custom_allreduce and not is_hip():
-        try:
-            import vllm._C  # noqa: F401
-        except ImportError as e:
-            logger.warning("Failed to import from vllm._C with %r", e)
-    else:
-        try:
-            import sgl_kernel
-        except ImportError as e:
-            logger.warning("Failed to import from custom_ar with %r", e)
+    try:
+        import sgl_kernel
+    except ImportError as e:
+        logger.warning("Failed to import from custom_ar with %r", e)
 
 if not is_hip() and not is_npu():
-    if use_vllm_custom_allreduce:
-        custom_op = torch.ops._C_custom_ar
-    else:
-        custom_op = sgl_kernel.allreduce
+    custom_op = sgl_kernel.allreduce
 
     # custom allreduce
     def init_custom_ar(
```
python/sglang/srt/configs/falcon_h1.py

```diff
@@ -19,7 +19,6 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
 from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
-from sglang.srt.layers.dp_attention import get_tensor_model_parallel_world_size
 
 logger = logging.get_logger(__name__)
@@ -297,8 +296,10 @@ class FalconH1Config(PretrainedConfig):
     @property
     def mamba2_cache_params(self):
+        from sglang.srt.layers.dp_attention import get_attention_tp_size
+
         shape = Mamba2StateShape.create(
-            tp_world_size=get_tensor_model_parallel_world_size(),
+            tp_world_size=get_attention_tp_size(),
             intermediate_size=self.mamba_intermediate,
             n_groups=self.mamba_n_groups,
             num_heads=self.mamba_n_heads,
```
python/sglang/srt/configs/nemotron_h.py

```diff
@@ -20,7 +20,6 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
 from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
-from sglang.srt.layers.dp_attention import get_attention_tp_size
 
 logger = logging.get_logger(__name__)
@@ -273,6 +272,8 @@ class NemotronHConfig(PretrainedConfig):
     @property
     def mamba2_cache_params(self) -> Mamba2CacheParams:
+        from sglang.srt.layers.dp_attention import get_attention_tp_size
+
         shape = Mamba2StateShape.create(
             tp_world_size=get_attention_tp_size(),
             intermediate_size=self.mamba_num_heads * self.mamba_head_dim,
```
python/sglang/srt/configs/qwen3_next.py

```diff
@@ -21,7 +21,6 @@ from transformers.modeling_rope_utils import rope_config_validation
 from transformers.utils import logging
 
 from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
-from sglang.srt.layers.dp_attention import get_attention_tp_size
 
 logger = logging.get_logger(__name__)
@@ -277,6 +276,8 @@ class Qwen3NextConfig(PretrainedConfig):
     @property
     def mamba2_cache_params(self) -> Mamba2CacheParams:
+        from sglang.srt.layers.dp_attention import get_attention_tp_size
+
         shape = Mamba2StateShape.create(
             tp_world_size=get_attention_tp_size(),
             intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,
```
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py

```diff
@@ -21,24 +21,19 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as
 from sglang.srt.environ import envs
 from sglang.srt.utils import is_cuda, is_hip, log_info_on_rank0
 
-logger = logging.getLogger(__name__)
-
-_is_cuda = is_cuda()
-_is_hip = is_hip()
-
 try:
-    if ops.use_vllm_custom_allreduce and not _is_hip:
-        # Use vLLM custom allreduce
-        ops.meta_size()
-    else:
-        # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
-        import sgl_kernel  # noqa: F401
+    # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
+    import sgl_kernel  # noqa: F401
 
     custom_ar = True
-except Exception:
+except ImportError:
     # For CPUs
     custom_ar = False
 
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+logger = logging.getLogger(__name__)
```
python/sglang/srt/environ.py

```diff
@@ -229,7 +229,6 @@ class Envs:
     SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK = EnvBool(False)
 
     # vLLM dependencies (TODO: they have been deprecated, we can remove them safely)
-    USE_VLLM_CUSTOM_ALLREDUCE = EnvBool(False)
     USE_VLLM_CUTLASS_W8A8_FP8_KERNEL = EnvBool(False)
     USE_TRITON_W8A8_FP8_KERNEL = EnvBool(False)
```
python/sglang/srt/utils/common.py

```diff
@@ -303,6 +303,7 @@ def xpu_has_xmx_support():
     return False
 
 
+@lru_cache(maxsize=1)
 def is_flashinfer_available():
     """
     Check whether flashinfer is available.
```
sgl-kernel/README.md

````diff
@@ -52,8 +52,8 @@ make build
 ```cpp
 // We need def with schema here for torch.compile
 m.def(
-    "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
-    "cublas_handle, int cuda_stream) -> ()");
+    "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, "
+    "int cublas_handle) -> ()");
 m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
 ```
````
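For orientation, here is the full def/impl idiom the README documents, written against the new stream-free convention; `toy_scale` is a hypothetical op used only for illustration, not part of the library:

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <torch/library.h>

// A toy in-place op that resolves its own stream instead of taking
// `int cuda_stream` in the schema.
void toy_scale(at::Tensor x, double factor) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  (void)stream;  // a real kernel would launch on `stream`
  x.mul_(factor);
}

TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  // def with an explicit schema string, as the README notes, so that
  // torch.compile can reason about the op; `Tensor!` marks mutation.
  m.def("toy_scale(Tensor! x, float factor) -> ()");
  m.impl("toy_scale", torch::kCUDA, &toy_scale);
}
```

From Python this would be reachable as `torch.ops.sgl_kernel.toy_scale(x, 2.0)`, with no stream plumbing on the caller's side.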
sgl-kernel/csrc/common_extension.cc

```diff
@@ -90,13 +90,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
-      "Tensor pos_ids, bool interleave, bool enable_pdl, int cuda_stream, "
+      "Tensor pos_ids, bool interleave, bool enable_pdl, "
       "Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
   m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
 
   m.def(
-      "downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, int "
-      "mult, int offset, int cuda_stream) -> ()");
+      "downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, "
+      "int mult, int offset) -> ()");
   m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8);
 
   m.def("copy_to_gpu_no_ce(Tensor input, Tensor! output) -> ()");
@@ -303,13 +303,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
       "Tensor uniform_samples, Tensor uniform_samples_for_final_sampling, Tensor target_probs, Tensor draft_probs, "
       "float threshold_single, float threshold_acc, "
-      "bool deterministic, int cuda_stream) -> ()");
+      "bool deterministic) -> ()");
   m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
 
   m.def(
       "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
       "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
-      "Tensor target_predict, int cuda_stream) -> ()");
+      "Tensor target_predict) -> ()");
   m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
 
   m.def(
@@ -403,8 +403,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
    * From FlashInfer
    */
   m.def(
-      "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
-      "cublas_handle, int cuda_stream) -> ()",
+      "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, "
+      "int cublas_handle) -> ()",
       {at::Tag::needs_fixed_stride_order});
   m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
```
sgl-kernel/csrc/common_extension_rocm.cc

```diff
@@ -106,7 +106,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   m.def(
       "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
       "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
-      "Tensor target_predict, int cuda_stream) -> ()");
+      "Tensor target_predict) -> ()");
   m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
 
   m.def(
```
sgl-kernel/csrc/elementwise/cast.cu

```diff
@@ -150,14 +150,13 @@ void downcast_fp8(
     at::Tensor& v_scale,
     at::Tensor& loc,
     int64_t mult,
-    int64_t offset,
-    int64_t cuda_stream) {
+    int64_t offset) {
   CHECK_INPUT(k);
   CHECK_INPUT(v);
   CHECK_INPUT(k_out);
   CHECK_INPUT(v_out);
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (k.scalar_type()) {
     case at::ScalarType::BFloat16:
       downcast_fp8_impl<__nv_bfloat16>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
```
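The resulting launcher shape, sketched with an illustrative kernel (`scale_kernel` and `scale_inplace` are not from the repo): the stream comes from ATen's thread-local state rather than a schema argument, and the kernel launch itself is unchanged.

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

__global__ void scale_kernel(float* data, float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

void scale_inplace(at::Tensor t, float factor) {
  // Same acquisition as downcast_fp8 above: no int64 round trip from Python.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  int n = static_cast<int>(t.numel());
  int threads = 256;
  int blocks = (n + threads - 1) / threads;
  // Launching on the current stream preserves ordering with other PyTorch work.
  scale_kernel<<<blocks, threads, 0, stream>>>(t.data_ptr<float>(), factor, n);
}
```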
sgl-kernel/csrc/elementwise/rope.cu

```diff
@@ -28,7 +28,6 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor pos_ids,
     bool interleave,
     bool enable_pdl,
-    int64_t cuda_stream,
     const std::optional<at::Tensor>& v,
     const std::optional<at::Tensor>& k_buffer,
     const std::optional<at::Tensor>& v_buffer,
@@ -88,7 +87,7 @@ void apply_rope_pos_ids_cos_sin_cache(
   size_t k_rope_stride_n = k_rope.stride(0);
   size_t k_rope_stride_h = k_rope.stride(1);
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
     // TODO temporarily only use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` when save_kv_cache
     // to avoid changing original code path; but this branch is feature-complete and should switch to this later
```
sgl-kernel/csrc/gemm/bmm_fp8.cu

```diff
@@ -27,8 +27,7 @@ void bmm_fp8(
     at::Tensor A_scale,
     at::Tensor B_scale,
     at::Tensor workspace_buffer,
-    int64_t cublas_handle,
-    int64_t cuda_stream) {
+    int64_t cublas_handle) {
   TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor");
   TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor");
   TORCH_CHECK(D.is_cuda(), "D must be a CUDA tensor");
@@ -51,7 +50,7 @@ void bmm_fp8(
   auto n = B.size(2);
 
   auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
-  auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  auto stream = at::cuda::getCurrentCUDAStream();
   auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
       workspace_buffer.data_ptr(),
```
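Worth noting in this file: only `cuda_stream` was dropped. The cuBLASLt handle still crosses the Python/C++ boundary as an `int64_t` and is cast back, so the two styles now sit side by side. A sketch of the surviving pattern (`bmm_fp8_like` is a hypothetical name):

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <cublasLt.h>
#include <cstdint>

void bmm_fp8_like(int64_t cublas_handle) {
  // The opaque handle is still passed through the schema as an integer...
  auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
  // ...while the stream is now resolved locally instead of being passed in.
  auto stream = at::cuda::getCurrentCUDAStream();
  (void)lt_handle;
  (void)stream;
}
```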
sgl-kernel/csrc/speculative/eagle_utils.cu

```diff
@@ -328,8 +328,7 @@ void verify_tree_greedy(
     at::Tensor retrive_index,
     at::Tensor retrive_next_token,
     at::Tensor retrive_next_sibling,
-    at::Tensor target_predict,
-    int64_t cuda_stream = 0) {
+    at::Tensor target_predict) {
   CHECK_INPUT(candidates);
   CHECK_INPUT(retrive_index);
   CHECK_INPUT(retrive_next_token);
@@ -389,7 +388,7 @@ void verify_tree_greedy(
     throw std::runtime_error("Expected 'target_predict' to be of type long (torch.int64).");
   }
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   dim3 grid(batch_size);
   dim3 block(1);
```
sgl-kernel/csrc/speculative/speculative_sampling.cu

```diff
@@ -42,8 +42,7 @@ void tree_speculative_sampling_target_only(
     at::Tensor draft_probs,
     double threshold_single,
     double threshold_acc,
-    bool deterministic = true,
-    int64_t cuda_stream = 0) {
+    bool deterministic = true) {
   CHECK_INPUT(candidates);
   CHECK_INPUT(retrive_index);
   CHECK_INPUT(retrive_next_token);
@@ -124,7 +123,7 @@ void tree_speculative_sampling_target_only(
   CHECK_GE(threshold_acc, 0);
   CHECK_GE(1, threshold_acc);
 
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int32_t, int64_t>(
       static_cast<int32_t*>(predicts.data_ptr()),
       static_cast<int32_t*>(accept_index.data_ptr()),
```
sgl-kernel/include/sgl_kernel_ops.h

```diff
@@ -152,7 +152,6 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor pos_ids,
     bool interleave,
     bool enable_pdl,
-    int64_t cuda_stream,
     const std::optional<at::Tensor>& v,
     const std::optional<at::Tensor>& k_buffer,
     const std::optional<at::Tensor>& v_buffer,
@@ -167,8 +166,7 @@ void downcast_fp8(
     at::Tensor& v_scale,
     at::Tensor& loc,
     int64_t mult,
-    int64_t offset,
-    int64_t cuda_stream);
+    int64_t offset);
 
 void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
 
 void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
@@ -253,8 +251,7 @@ void bmm_fp8(
     at::Tensor A_scale,
     at::Tensor B_scale,
     at::Tensor workspace_buffer,
-    int64_t cublas_handle,
-    int64_t cuda_stream);
+    int64_t cublas_handle);
 
 void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b);
 void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
@@ -471,8 +468,7 @@ void tree_speculative_sampling_target_only(
     at::Tensor draft_probs,
     double threshold_single = 1,
     double threshold_acc = 1,
-    bool deterministic = true,
-    int64_t cuda_stream = 0);
+    bool deterministic = true);
 
 void verify_tree_greedy(
     at::Tensor predicts,  // mutable
@@ -482,8 +478,7 @@ void verify_tree_greedy(
     at::Tensor retrive_index,
     at::Tensor retrive_next_token,
     at::Tensor retrive_next_sibling,
-    at::Tensor target_predict,
-    int64_t cuda_stream = 0);
+    at::Tensor target_predict);
 
 void reconstruct_indices_from_tree_mask(
     at::Tensor tree_mask,
```
sgl-kernel/python/sgl_kernel/elementwise.py

```diff
@@ -2,7 +2,7 @@ from dataclasses import dataclass
 from typing import List, Optional
 
 import torch
 
-from sgl_kernel.utils import get_cuda_stream, is_arch_support_pdl
+from sgl_kernel.utils import is_arch_support_pdl
 
 # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
@@ -263,6 +263,10 @@ class FusedSetKVBufferArg:
     cache_loc: torch.Tensor
 
 
+def _view_3d(x, head_size):
+    return x.view(x.shape[0], -1, head_size)
+
+
 def apply_rope_with_cos_sin_cache_inplace(
     positions: torch.Tensor,
     query: torch.Tensor,
@@ -317,31 +321,27 @@ def apply_rope_with_cos_sin_cache_inplace(
         assert a.v_scale is None, "v_scale is not yet supported"
         assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}"
 
-    def _view_3d(x):
-        return x.view(x.shape[0], -1, head_size)
-
     torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
-        _view_3d(query),
-        _view_3d(key),
-        _view_3d(query),
-        _view_3d(key),
+        _view_3d(query, head_size),
+        _view_3d(key, head_size),
+        _view_3d(query, head_size),
+        _view_3d(key, head_size),
         cos_sin_cache,
         positions.long(),
         (not is_neox),
         enable_pdl,
-        get_cuda_stream(),
         (
-            _view_3d(fused_set_kv_buffer_arg.value)
+            _view_3d(fused_set_kv_buffer_arg.value, head_size)
             if fused_set_kv_buffer_arg is not None
             else None
         ),
         (
-            _view_3d(fused_set_kv_buffer_arg.k_buffer)
+            _view_3d(fused_set_kv_buffer_arg.k_buffer, head_size)
             if fused_set_kv_buffer_arg is not None
             else None
        ),
         (
-            _view_3d(fused_set_kv_buffer_arg.v_buffer)
+            _view_3d(fused_set_kv_buffer_arg.v_buffer, head_size)
             if fused_set_kv_buffer_arg is not None
             else None
         ),
@@ -365,7 +365,7 @@ def downcast_fp8(
     offset: int = 0,
 ) -> None:
     torch.ops.sgl_kernel.downcast_fp8(
-        k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, get_cuda_stream()
+        k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset
     )
```
sgl-kernel/python/sgl_kernel/gemm.py

```diff
@@ -2,7 +2,7 @@ from typing import Optional, Tuple
 import torch
 
 from sgl_kernel.scalar_type import ScalarType
-from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
+from sgl_kernel.utils import _get_cache_buf
 
 
 def awq_dequantize(
@@ -60,7 +60,6 @@ def _bmm_fp8_internal(
         B_scale,
         workspace_buffer,
         cublas_handle,
-        get_cuda_stream(),
     )
```
sgl-kernel/python/sgl_kernel/speculative.py

```diff
 import torch
 
-from sgl_kernel.utils import get_cuda_stream
 
 
 def tree_speculative_sampling_target_only(
@@ -33,7 +32,6 @@ def tree_speculative_sampling_target_only(
         threshold_single,
         threshold_acc,
         deterministic,
-        get_cuda_stream(),
     )
@@ -56,7 +54,6 @@ def verify_tree_greedy(
         retrive_next_token,
         retrive_next_sibling,
         target_predict,
-        get_cuda_stream(),
     )
```
sgl-kernel/python/sgl_kernel/utils.py

```diff
@@ -18,11 +18,6 @@ from typing import Dict, Tuple
 import torch
 
 
-def get_cuda_stream() -> int:
-    return torch.cuda.current_stream().cuda_stream
-
-
 _cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
```
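Deleting `get_cuda_stream` is safe because the integer it returned (`torch.cuda.current_stream().cuda_stream`) is the same raw `cudaStream_t` that ATen exposes in C++, assuming the call happens on the same thread and device. Roughly:

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <cstdint>

// What the deleted Python helper used to hand across the boundary: the raw
// current-stream pointer as an int64 (same thread and device assumed).
int64_t current_stream_as_int() {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  return reinterpret_cast<int64_t>(stream);
}
```

So every C++ call site that replaced `reinterpret_cast<cudaStream_t>(cuda_stream)` with `at::cuda::getCurrentCUDAStream()` observes an identical stream.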