Unverified Commit 20315697 authored by Lianmin Zheng, committed by GitHub

move all get_stream in sgl_kernel to c++ to reduce the launch overhead (#12521)

parent c9db7911
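The change follows a single pattern across every kernel touched below: the trailing `int64_t cuda_stream` argument, which the Python wrappers previously filled in per call via `get_cuda_stream()`, is dropped from the op schema, and the C++ entry point asks PyTorch for the current stream itself. A minimal before/after sketch of that pattern (the `my_kernel` name is hypothetical, not an op from this diff):

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

// Before: the Python wrapper passed torch.cuda.current_stream().cuda_stream
// across the binding boundary as a plain integer on every call.
void my_kernel_old(at::Tensor x, int64_t cuda_stream) {
  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  // ... launch work on `stream` ...
}

// After: the op queries the current CUDA stream on the C++ side, so the
// registered schema is shorter and the per-call Python lookup disappears.
void my_kernel_new(at::Tensor x) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // ... launch work on `stream` ...
}
```

Accordingly, `int cuda_stream` disappears from every `m.def(...)` schema string, the C++ implementations switch to `at::cuda::getCurrentCUDAStream()`, and the Python wrappers stop importing and calling `get_cuda_stream` (the helper itself is deleted from `sgl_kernel.utils` at the end of this diff).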
@@ -4,32 +4,20 @@ from typing import List, Optional, Tuple
 import torch

-from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
+from sglang.srt.utils import is_hip, is_hpu, is_npu

 logger = logging.getLogger(__name__)

-use_vllm_custom_allreduce = get_bool_env_var(
-    "USE_VLLM_CUSTOM_ALLREDUCE", default="false"
-)
-
 if not is_hpu():
-    # ROCm does not use vllm custom allreduce
-    if use_vllm_custom_allreduce and not is_hip():
-        try:
-            import vllm._C  # noqa: F401
-        except ImportError as e:
-            logger.warning("Failed to import from vllm._C with %r", e)
-    else:
-        try:
-            import sgl_kernel
-        except ImportError as e:
-            logger.warning("Failed to import from custom_ar with %r", e)
+    try:
+        import sgl_kernel
+    except ImportError as e:
+        logger.warning("Failed to import from custom_ar with %r", e)

 if not is_hip() and not is_npu():
-    if use_vllm_custom_allreduce:
-        custom_op = torch.ops._C_custom_ar
-    else:
-        custom_op = sgl_kernel.allreduce
+    custom_op = sgl_kernel.allreduce

     # custom allreduce
     def init_custom_ar(
...
@@ -19,7 +19,6 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging

 from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
-from sglang.srt.layers.dp_attention import get_tensor_model_parallel_world_size

 logger = logging.get_logger(__name__)
@@ -297,8 +296,10 @@ class FalconH1Config(PretrainedConfig):
     @property
     def mamba2_cache_params(self):
+        from sglang.srt.layers.dp_attention import get_attention_tp_size
+
         shape = Mamba2StateShape.create(
-            tp_world_size=get_tensor_model_parallel_world_size(),
+            tp_world_size=get_attention_tp_size(),
             intermediate_size=self.mamba_intermediate,
             n_groups=self.mamba_n_groups,
             num_heads=self.mamba_n_heads,
...
@@ -20,7 +20,6 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging

 from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
-from sglang.srt.layers.dp_attention import get_attention_tp_size

 logger = logging.get_logger(__name__)
@@ -273,6 +272,8 @@ class NemotronHConfig(PretrainedConfig):
     @property
     def mamba2_cache_params(self) -> Mamba2CacheParams:
+        from sglang.srt.layers.dp_attention import get_attention_tp_size
+
         shape = Mamba2StateShape.create(
             tp_world_size=get_attention_tp_size(),
             intermediate_size=self.mamba_num_heads * self.mamba_head_dim,
...
@@ -21,7 +21,6 @@ from transformers.modeling_rope_utils import rope_config_validation
 from transformers.utils import logging

 from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
-from sglang.srt.layers.dp_attention import get_attention_tp_size

 logger = logging.get_logger(__name__)
@@ -277,6 +276,8 @@ class Qwen3NextConfig(PretrainedConfig):
     @property
     def mamba2_cache_params(self) -> Mamba2CacheParams:
+        from sglang.srt.layers.dp_attention import get_attention_tp_size
+
         shape = Mamba2StateShape.create(
             tp_world_size=get_attention_tp_size(),
             intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,
...
@@ -21,24 +21,19 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as
 from sglang.srt.environ import envs
 from sglang.srt.utils import is_cuda, is_hip, log_info_on_rank0

-logger = logging.getLogger(__name__)
-
-_is_cuda = is_cuda()
-_is_hip = is_hip()
-
 try:
-    if ops.use_vllm_custom_allreduce and not _is_hip:
-        # Use vLLM custom allreduce
-        ops.meta_size()
-    else:
-        # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
-        import sgl_kernel  # noqa: F401
+    # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
+    import sgl_kernel  # noqa: F401

     custom_ar = True
-except Exception:
+except ImportError:
     # For CPUs
     custom_ar = False

+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
 logger = logging.getLogger(__name__)
...
@@ -229,7 +229,6 @@ class Envs:
     SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK = EnvBool(False)

     # vLLM dependencies (TODO: they have been deprecated, we can remove them safely)
-    USE_VLLM_CUSTOM_ALLREDUCE = EnvBool(False)
     USE_VLLM_CUTLASS_W8A8_FP8_KERNEL = EnvBool(False)
     USE_TRITON_W8A8_FP8_KERNEL = EnvBool(False)
...
@@ -303,6 +303,7 @@ def xpu_has_xmx_support():
         return False


+@lru_cache(maxsize=1)
 def is_flashinfer_available():
     """
     Check whether flashinfer is available.
...
@@ -52,8 +52,8 @@ make build
 ```cpp
 // We need def with schema here for torch.compile
 m.def(
-    "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
-    "cublas_handle, int cuda_stream) -> ()");
+    "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, "
+    "int cublas_handle) -> ()");
 m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
 ```
...
@@ -90,13 +90,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
-      "Tensor pos_ids, bool interleave, bool enable_pdl, int cuda_stream, "
+      "Tensor pos_ids, bool interleave, bool enable_pdl, "
       "Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
   m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);

   m.def(
-      "downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, int "
-      "mult, int offset, int cuda_stream) -> ()");
+      "downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, "
+      "int mult, int offset) -> ()");
   m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8);

   m.def("copy_to_gpu_no_ce(Tensor input, Tensor! output) -> ()");
@@ -303,13 +303,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
       "Tensor uniform_samples, Tensor uniform_samples_for_final_sampling, Tensor target_probs, Tensor draft_probs, "
       "float threshold_single, float threshold_acc, "
-      "bool deterministic, int cuda_stream) -> ()");
+      "bool deterministic) -> ()");
   m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);

   m.def(
       "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
       "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
-      "Tensor target_predict, int cuda_stream) -> ()");
+      "Tensor target_predict) -> ()");
   m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);

   m.def(
@@ -403,8 +403,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
    * From FlashInfer
    */
   m.def(
-      "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
-      "cublas_handle, int cuda_stream) -> ()",
+      "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, "
+      "int cublas_handle) -> ()",
       {at::Tag::needs_fixed_stride_order});
   m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
...
@@ -106,7 +106,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   m.def(
       "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
       "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
-      "Tensor target_predict, int cuda_stream) -> ()");
+      "Tensor target_predict) -> ()");
   m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);

   m.def(
...
@@ -150,14 +150,13 @@ void downcast_fp8(
     at::Tensor& v_scale,
     at::Tensor& loc,
     int64_t mult,
-    int64_t offset,
-    int64_t cuda_stream) {
+    int64_t offset) {
   CHECK_INPUT(k);
   CHECK_INPUT(v);
   CHECK_INPUT(k_out);
   CHECK_INPUT(v_out);

-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (k.scalar_type()) {
     case at::ScalarType::BFloat16:
       downcast_fp8_impl<__nv_bfloat16>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
...
@@ -28,7 +28,6 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor pos_ids,
     bool interleave,
     bool enable_pdl,
-    int64_t cuda_stream,
     const std::optional<at::Tensor>& v,
     const std::optional<at::Tensor>& k_buffer,
     const std::optional<at::Tensor>& v_buffer,
@@ -88,7 +87,7 @@ void apply_rope_pos_ids_cos_sin_cache(
   size_t k_rope_stride_n = k_rope.stride(0);
   size_t k_rope_stride_h = k_rope.stride(1);

-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
     // TODO temporarily only use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` when save_kv_cache
     // to avoid changing original code path; but this branch is feature-complete and should switch to this later
...
@@ -27,8 +27,7 @@ void bmm_fp8(
     at::Tensor A_scale,
     at::Tensor B_scale,
     at::Tensor workspace_buffer,
-    int64_t cublas_handle,
-    int64_t cuda_stream) {
+    int64_t cublas_handle) {
   TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor");
   TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor");
   TORCH_CHECK(D.is_cuda(), "D must be a CUDA tensor");
@@ -51,7 +50,7 @@ void bmm_fp8(
   auto n = B.size(2);

   auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
-  auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  auto stream = at::cuda::getCurrentCUDAStream();
   auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
       workspace_buffer.data_ptr(),
...
@@ -328,8 +328,7 @@ void verify_tree_greedy(
     at::Tensor retrive_index,
     at::Tensor retrive_next_token,
     at::Tensor retrive_next_sibling,
-    at::Tensor target_predict,
-    int64_t cuda_stream = 0) {
+    at::Tensor target_predict) {
   CHECK_INPUT(candidates);
   CHECK_INPUT(retrive_index);
   CHECK_INPUT(retrive_next_token);
@@ -389,7 +388,7 @@ void verify_tree_greedy(
     throw std::runtime_error("Expected 'target_predict' to be of type long (torch.int64).");
   }

-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   dim3 grid(batch_size);
   dim3 block(1);
...
@@ -42,8 +42,7 @@ void tree_speculative_sampling_target_only(
     at::Tensor draft_probs,
     double threshold_single,
     double threshold_acc,
-    bool deterministic = true,
-    int64_t cuda_stream = 0) {
+    bool deterministic = true) {
   CHECK_INPUT(candidates);
   CHECK_INPUT(retrive_index);
   CHECK_INPUT(retrive_next_token);
@@ -124,7 +123,7 @@ void tree_speculative_sampling_target_only(
   CHECK_GE(threshold_acc, 0);
   CHECK_GE(1, threshold_acc);

-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int32_t, int64_t>(
       static_cast<int32_t*>(predicts.data_ptr()),
       static_cast<int32_t*>(accept_index.data_ptr()),
...
@@ -152,7 +152,6 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor pos_ids,
     bool interleave,
     bool enable_pdl,
-    int64_t cuda_stream,
     const std::optional<at::Tensor>& v,
     const std::optional<at::Tensor>& k_buffer,
     const std::optional<at::Tensor>& v_buffer,
@@ -167,8 +166,7 @@ void downcast_fp8(
     at::Tensor& v_scale,
     at::Tensor& loc,
     int64_t mult,
-    int64_t offset,
-    int64_t cuda_stream);
+    int64_t offset);

 void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);

 void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
@@ -253,8 +251,7 @@ void bmm_fp8(
     at::Tensor A_scale,
     at::Tensor B_scale,
     at::Tensor workspace_buffer,
-    int64_t cublas_handle,
-    int64_t cuda_stream);
+    int64_t cublas_handle);

 void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b);

 void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
@@ -471,8 +468,7 @@ void tree_speculative_sampling_target_only(
     at::Tensor draft_probs,
     double threshold_single = 1,
     double threshold_acc = 1,
-    bool deterministic = true,
-    int64_t cuda_stream = 0);
+    bool deterministic = true);

 void verify_tree_greedy(
     at::Tensor predicts,  // mutable
@@ -482,8 +478,7 @@ void verify_tree_greedy(
     at::Tensor retrive_index,
     at::Tensor retrive_next_token,
     at::Tensor retrive_next_sibling,
-    at::Tensor target_predict,
-    int64_t cuda_stream = 0);
+    at::Tensor target_predict);

 void reconstruct_indices_from_tree_mask(
     at::Tensor tree_mask,
...
@@ -2,7 +2,7 @@ from dataclasses import dataclass
 from typing import List, Optional

 import torch
-from sgl_kernel.utils import get_cuda_stream, is_arch_support_pdl
+from sgl_kernel.utils import is_arch_support_pdl

 # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
@@ -263,6 +263,10 @@ class FusedSetKVBufferArg:
     cache_loc: torch.Tensor


+def _view_3d(x, head_size):
+    return x.view(x.shape[0], -1, head_size)
+
+
 def apply_rope_with_cos_sin_cache_inplace(
     positions: torch.Tensor,
     query: torch.Tensor,
@@ -317,31 +321,27 @@ def apply_rope_with_cos_sin_cache_inplace(
         assert a.v_scale is None, "v_scale is not yet supported"
         assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}"

-    def _view_3d(x):
-        return x.view(x.shape[0], -1, head_size)
-
     torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
-        _view_3d(query),
-        _view_3d(key),
-        _view_3d(query),
-        _view_3d(key),
+        _view_3d(query, head_size),
+        _view_3d(key, head_size),
+        _view_3d(query, head_size),
+        _view_3d(key, head_size),
         cos_sin_cache,
         positions.long(),
         (not is_neox),
         enable_pdl,
-        get_cuda_stream(),
         (
-            _view_3d(fused_set_kv_buffer_arg.value)
+            _view_3d(fused_set_kv_buffer_arg.value, head_size)
             if fused_set_kv_buffer_arg is not None
             else None
         ),
         (
-            _view_3d(fused_set_kv_buffer_arg.k_buffer)
+            _view_3d(fused_set_kv_buffer_arg.k_buffer, head_size)
             if fused_set_kv_buffer_arg is not None
             else None
         ),
         (
-            _view_3d(fused_set_kv_buffer_arg.v_buffer)
+            _view_3d(fused_set_kv_buffer_arg.v_buffer, head_size)
             if fused_set_kv_buffer_arg is not None
             else None
         ),
@@ -365,7 +365,7 @@ def downcast_fp8(
     offset: int = 0,
 ) -> None:
     torch.ops.sgl_kernel.downcast_fp8(
-        k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, get_cuda_stream()
+        k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset
     )
...
@@ -2,7 +2,7 @@ from typing import Optional, Tuple
 import torch
 from sgl_kernel.scalar_type import ScalarType
-from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
+from sgl_kernel.utils import _get_cache_buf


 def awq_dequantize(
@@ -60,7 +60,6 @@ def _bmm_fp8_internal(
         B_scale,
         workspace_buffer,
         cublas_handle,
-        get_cuda_stream(),
     )
...
 import torch
-from sgl_kernel.utils import get_cuda_stream


 def tree_speculative_sampling_target_only(
@@ -33,7 +32,6 @@ def tree_speculative_sampling_target_only(
         threshold_single,
         threshold_acc,
         deterministic,
-        get_cuda_stream(),
     )
@@ -56,7 +54,6 @@ def verify_tree_greedy(
         retrive_next_token,
         retrive_next_sibling,
         target_predict,
-        get_cuda_stream(),
     )
...
@@ -18,11 +18,6 @@ from typing import Dict, Tuple
 import torch


-def get_cuda_stream() -> int:
-    return torch.cuda.current_stream().cuda_stream
-
-
 _cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
...