"tools/htmlify/to_xml.cpp" did not exist on "4c9f2715598b65312ba96b2b775a5dc1a862191d"
Unverified commit 20315697, authored by Lianmin Zheng, committed by GitHub

move all get_stream in sgl_kernel to c++ to reduce the launch overhead (#12521)

parent c9db7911
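The recurring change in this commit is the same for every kernel: the CUDA stream is no longer marshalled from Python into the op schema as an integer, but queried on the C++ side, and the Python helper `get_cuda_stream()` is removed. The sketch below is illustrative only (the wrapper name `launch_example` is hypothetical); both the removed `reinterpret_cast` form and the new `at::cuda::getCurrentCUDAStream()` call appear verbatim in the hunks that follow.

```cpp
#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream
#include <torch/extension.h>

// Hypothetical wrapper, shown only to illustrate the pattern of this commit.
//
// Before: Python passed torch.cuda.current_stream().cuda_stream as an int64
// schema argument, and the C++ side reinterpreted it:
//
//   void launch_example(at::Tensor x, int64_t cuda_stream) {
//     cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
//     ...
//   }
//
// After: the current stream is fetched directly in C++, removing one Python
// call and one schema argument from every kernel launch.
void launch_example(at::Tensor x) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // ... enqueue the kernel on `stream` ...
  (void)stream;
  (void)x;
}
```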
......@@ -4,32 +4,20 @@ from typing import List, Optional, Tuple
import torch
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
from sglang.srt.utils import is_hip, is_hpu, is_npu
logger = logging.getLogger(__name__)
use_vllm_custom_allreduce = get_bool_env_var(
"USE_VLLM_CUSTOM_ALLREDUCE", default="false"
)
if not is_hpu():
# ROCm does not use vllm custom allreduce
if use_vllm_custom_allreduce and not is_hip():
try:
import vllm._C # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)
else:
try:
import sgl_kernel
except ImportError as e:
logger.warning("Failed to import from custom_ar with %r", e)
try:
import sgl_kernel
except ImportError as e:
logger.warning("Failed to import from custom_ar with %r", e)
if not is_hip() and not is_npu():
if use_vllm_custom_allreduce:
custom_op = torch.ops._C_custom_ar
else:
custom_op = sgl_kernel.allreduce
custom_op = sgl_kernel.allreduce
# custom allreduce
def init_custom_ar(
......
......@@ -19,7 +19,6 @@ from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.layers.dp_attention import get_tensor_model_parallel_world_size
logger = logging.get_logger(__name__)
......@@ -297,8 +296,10 @@ class FalconH1Config(PretrainedConfig):
@property
def mamba2_cache_params(self):
from sglang.srt.layers.dp_attention import get_attention_tp_size
shape = Mamba2StateShape.create(
tp_world_size=get_tensor_model_parallel_world_size(),
tp_world_size=get_attention_tp_size(),
intermediate_size=self.mamba_intermediate,
n_groups=self.mamba_n_groups,
num_heads=self.mamba_n_heads,
......
......@@ -20,7 +20,6 @@ from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.layers.dp_attention import get_attention_tp_size
logger = logging.get_logger(__name__)
......@@ -273,6 +272,8 @@ class NemotronHConfig(PretrainedConfig):
@property
def mamba2_cache_params(self) -> Mamba2CacheParams:
from sglang.srt.layers.dp_attention import get_attention_tp_size
shape = Mamba2StateShape.create(
tp_world_size=get_attention_tp_size(),
intermediate_size=self.mamba_num_heads * self.mamba_head_dim,
......
......@@ -21,7 +21,6 @@ from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.layers.dp_attention import get_attention_tp_size
logger = logging.get_logger(__name__)
......@@ -277,6 +276,8 @@ class Qwen3NextConfig(PretrainedConfig):
@property
def mamba2_cache_params(self) -> Mamba2CacheParams:
from sglang.srt.layers.dp_attention import get_attention_tp_size
shape = Mamba2StateShape.create(
tp_world_size=get_attention_tp_size(),
intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,
......
......@@ -21,24 +21,19 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as
from sglang.srt.environ import envs
from sglang.srt.utils import is_cuda, is_hip, log_info_on_rank0
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_hip = is_hip()
try:
if ops.use_vllm_custom_allreduce and not _is_hip:
# Use vLLM custom allreduce
ops.meta_size()
else:
# Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
import sgl_kernel # noqa: F401
# Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
import sgl_kernel # noqa: F401
custom_ar = True
except Exception:
except ImportError:
# For CPUs
custom_ar = False
_is_cuda = is_cuda()
_is_hip = is_hip()
logger = logging.getLogger(__name__)
......
......@@ -229,7 +229,6 @@ class Envs:
SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK = EnvBool(False)
# vLLM dependencies (TODO: they have been deprecated, we can remove them safely)
USE_VLLM_CUSTOM_ALLREDUCE = EnvBool(False)
USE_VLLM_CUTLASS_W8A8_FP8_KERNEL = EnvBool(False)
USE_TRITON_W8A8_FP8_KERNEL = EnvBool(False)
......
......@@ -303,6 +303,7 @@ def xpu_has_xmx_support():
return False
@lru_cache(maxsize=1)
def is_flashinfer_available():
"""
Check whether flashinfer is available.
......
......@@ -52,8 +52,8 @@ make build
```cpp
// We need def with schema here for torch.compile
m.def(
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
"cublas_handle, int cuda_stream) -> ()");
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, "
"int cublas_handle) -> ()");
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
```
......
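For context, a minimal sketch of the updated registration convention end to end, assuming a toy op `my_op` (hypothetical, not part of sgl_kernel): the schema definition is still required for torch.compile, but the stream argument is gone and the implementation asks ATen for the current stream instead.

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <torch/library.h>

// Hypothetical op used only to illustrate the convention after this commit.
void my_op(at::Tensor input, at::Tensor output) {
  // The stream is resolved here rather than passed through the schema.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // ... launch the kernel on `stream` ...
  (void)stream;
  (void)input;
  (void)output;
}

TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  // A def with schema is still needed for torch.compile, but it no longer
  // mentions cuda_stream.
  m.def("my_op(Tensor input, Tensor! output) -> ()");
  m.impl("my_op", torch::kCUDA, &my_op);
}
```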
......@@ -90,13 +90,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def(
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
"Tensor pos_ids, bool interleave, bool enable_pdl, int cuda_stream, "
"Tensor pos_ids, bool interleave, bool enable_pdl, "
"Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
m.def(
"downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, int "
"mult, int offset, int cuda_stream) -> ()");
"downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, "
"int mult, int offset) -> ()");
m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8);
m.def("copy_to_gpu_no_ce(Tensor input, Tensor! output) -> ()");
......@@ -303,13 +303,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
"Tensor uniform_samples, Tensor uniform_samples_for_final_sampling, Tensor target_probs, Tensor draft_probs, "
"float threshold_single, float threshold_acc, "
"bool deterministic, int cuda_stream) -> ()");
"bool deterministic) -> ()");
m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
m.def(
"verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
"Tensor target_predict, int cuda_stream) -> ()");
"Tensor target_predict) -> ()");
m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
m.def(
......@@ -403,8 +403,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
* From FlashInfer
*/
m.def(
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
"cublas_handle, int cuda_stream) -> ()",
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, "
"int cublas_handle) -> ()",
{at::Tag::needs_fixed_stride_order});
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
......
......@@ -106,7 +106,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
m.def(
"verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
"Tensor target_predict, int cuda_stream) -> ()");
"Tensor target_predict) -> ()");
m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
m.def(
......
......@@ -150,14 +150,13 @@ void downcast_fp8(
at::Tensor& v_scale,
at::Tensor& loc,
int64_t mult,
int64_t offset,
int64_t cuda_stream) {
int64_t offset) {
CHECK_INPUT(k);
CHECK_INPUT(v);
CHECK_INPUT(k_out);
CHECK_INPUT(v_out);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (k.scalar_type()) {
case at::ScalarType::BFloat16:
downcast_fp8_impl<__nv_bfloat16>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream);
......
......@@ -28,7 +28,6 @@ void apply_rope_pos_ids_cos_sin_cache(
at::Tensor pos_ids,
bool interleave,
bool enable_pdl,
int64_t cuda_stream,
const std::optional<at::Tensor>& v,
const std::optional<at::Tensor>& k_buffer,
const std::optional<at::Tensor>& v_buffer,
......@@ -88,7 +87,7 @@ void apply_rope_pos_ids_cos_sin_cache(
size_t k_rope_stride_n = k_rope.stride(0);
size_t k_rope_stride_h = k_rope.stride(1);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
// TODO temporarily only use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` when save_kv_cache
// to avoid changing original code path; but this branch is feature-complete and should switch to this later
......
......@@ -27,8 +27,7 @@ void bmm_fp8(
at::Tensor A_scale,
at::Tensor B_scale,
at::Tensor workspace_buffer,
int64_t cublas_handle,
int64_t cuda_stream) {
int64_t cublas_handle) {
TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor");
TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor");
TORCH_CHECK(D.is_cuda(), "D must be a CUDA tensor");
......@@ -51,7 +50,7 @@ void bmm_fp8(
auto n = B.size(2);
auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
auto stream = at::cuda::getCurrentCUDAStream();
auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
workspace_buffer.data_ptr(),
......
......@@ -328,8 +328,7 @@ void verify_tree_greedy(
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
at::Tensor target_predict,
int64_t cuda_stream = 0) {
at::Tensor target_predict) {
CHECK_INPUT(candidates);
CHECK_INPUT(retrive_index);
CHECK_INPUT(retrive_next_token);
......@@ -389,7 +388,7 @@ void verify_tree_greedy(
throw std::runtime_error("Expected 'target_predict' to be of type long (torch.int64).");
}
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(batch_size);
dim3 block(1);
......
......@@ -42,8 +42,7 @@ void tree_speculative_sampling_target_only(
at::Tensor draft_probs,
double threshold_single,
double threshold_acc,
bool deterministic = true,
int64_t cuda_stream = 0) {
bool deterministic = true) {
CHECK_INPUT(candidates);
CHECK_INPUT(retrive_index);
CHECK_INPUT(retrive_next_token);
......@@ -124,7 +123,7 @@ void tree_speculative_sampling_target_only(
CHECK_GE(threshold_acc, 0);
CHECK_GE(1, threshold_acc);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int32_t, int64_t>(
static_cast<int32_t*>(predicts.data_ptr()),
static_cast<int32_t*>(accept_index.data_ptr()),
......
......@@ -152,7 +152,6 @@ void apply_rope_pos_ids_cos_sin_cache(
at::Tensor pos_ids,
bool interleave,
bool enable_pdl,
int64_t cuda_stream,
const std::optional<at::Tensor>& v,
const std::optional<at::Tensor>& k_buffer,
const std::optional<at::Tensor>& v_buffer,
......@@ -167,8 +166,7 @@ void downcast_fp8(
at::Tensor& v_scale,
at::Tensor& loc,
int64_t mult,
int64_t offset,
int64_t cuda_stream);
int64_t offset);
void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
......@@ -253,8 +251,7 @@ void bmm_fp8(
at::Tensor A_scale,
at::Tensor B_scale,
at::Tensor workspace_buffer,
int64_t cublas_handle,
int64_t cuda_stream);
int64_t cublas_handle);
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b);
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
......@@ -471,8 +468,7 @@ void tree_speculative_sampling_target_only(
at::Tensor draft_probs,
double threshold_single = 1,
double threshold_acc = 1,
bool deterministic = true,
int64_t cuda_stream = 0);
bool deterministic = true);
void verify_tree_greedy(
at::Tensor predicts, // mutable
......@@ -482,8 +478,7 @@ void verify_tree_greedy(
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
at::Tensor target_predict,
int64_t cuda_stream = 0);
at::Tensor target_predict);
void reconstruct_indices_from_tree_mask(
at::Tensor tree_mask,
......
......@@ -2,7 +2,7 @@ from dataclasses import dataclass
from typing import List, Optional
import torch
from sgl_kernel.utils import get_cuda_stream, is_arch_support_pdl
from sgl_kernel.utils import is_arch_support_pdl
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
......@@ -263,6 +263,10 @@ class FusedSetKVBufferArg:
cache_loc: torch.Tensor
def _view_3d(x, head_size):
return x.view(x.shape[0], -1, head_size)
def apply_rope_with_cos_sin_cache_inplace(
positions: torch.Tensor,
query: torch.Tensor,
......@@ -317,31 +321,27 @@ def apply_rope_with_cos_sin_cache_inplace(
assert a.v_scale is None, "v_scale is not yet supported"
assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}"
def _view_3d(x):
return x.view(x.shape[0], -1, head_size)
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
_view_3d(query),
_view_3d(key),
_view_3d(query),
_view_3d(key),
_view_3d(query, head_size),
_view_3d(key, head_size),
_view_3d(query, head_size),
_view_3d(key, head_size),
cos_sin_cache,
positions.long(),
(not is_neox),
enable_pdl,
get_cuda_stream(),
(
_view_3d(fused_set_kv_buffer_arg.value)
_view_3d(fused_set_kv_buffer_arg.value, head_size)
if fused_set_kv_buffer_arg is not None
else None
),
(
_view_3d(fused_set_kv_buffer_arg.k_buffer)
_view_3d(fused_set_kv_buffer_arg.k_buffer, head_size)
if fused_set_kv_buffer_arg is not None
else None
),
(
_view_3d(fused_set_kv_buffer_arg.v_buffer)
_view_3d(fused_set_kv_buffer_arg.v_buffer, head_size)
if fused_set_kv_buffer_arg is not None
else None
),
......@@ -365,7 +365,7 @@ def downcast_fp8(
offset: int = 0,
) -> None:
torch.ops.sgl_kernel.downcast_fp8(
k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, get_cuda_stream()
k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset
)
......
......@@ -2,7 +2,7 @@ from typing import Optional, Tuple
import torch
from sgl_kernel.scalar_type import ScalarType
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
from sgl_kernel.utils import _get_cache_buf
def awq_dequantize(
......@@ -60,7 +60,6 @@ def _bmm_fp8_internal(
B_scale,
workspace_buffer,
cublas_handle,
get_cuda_stream(),
)
......
import torch
from sgl_kernel.utils import get_cuda_stream
def tree_speculative_sampling_target_only(
......@@ -33,7 +32,6 @@ def tree_speculative_sampling_target_only(
threshold_single,
threshold_acc,
deterministic,
get_cuda_stream(),
)
......@@ -56,7 +54,6 @@ def verify_tree_greedy(
retrive_next_token,
retrive_next_sibling,
target_predict,
get_cuda_stream(),
)
......
......@@ -18,11 +18,6 @@ from typing import Dict, Tuple
import torch
def get_cuda_stream() -> int:
return torch.cuda.current_stream().cuda_stream
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
......