Unverified commit 110e0066 authored by Lianmin Zheng, committed by GitHub

Reorganize python source files in sgl-kernel with multiple files (#4027)

parent 6b45a21d
@@ -9,105 +9,37 @@ if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
mode=ctypes.RTLD_GLOBAL,
)
from sgl_kernel.ops.activation import (
apply_rope_with_cos_sin_cache_inplace,
fused_add_rmsnorm,
gelu_and_mul,
gelu_tanh_and_mul,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
rmsnorm,
silu_and_mul,
)
from sgl_kernel.ops.allreduce import *
from sgl_kernel.ops.attention import lightning_attention_decode
from sgl_kernel.ops.gemm import (
bmm_fp8,
cublas_grouped_gemm,
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
int8_scaled_mm,
sgl_per_token_group_quant_fp8,
)
from sgl_kernel.ops.moe import moe_align_block_size
from sgl_kernel.ops.sampling import (
min_p_sampling_from_probs,
top_k_renorm_prob,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
top_p_sampling_from_probs,
)
from sgl_kernel.ops.speculative import (
build_tree_kernel,
build_tree_kernel_efficient,
tree_speculative_sampling_target_only,
)
from sgl_kernel.version import __version__
if torch.version.cuda:
from sgl_kernel.ops import (
apply_rope_with_cos_sin_cache_inplace,
bmm_fp8,
build_tree_kernel,
build_tree_kernel_efficient,
cublas_grouped_gemm,
custom_dispose,
custom_reduce,
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
fused_add_rmsnorm,
gelu_and_mul,
gelu_tanh_and_mul,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
get_graph_buffer_ipc_meta,
init_custom_reduce,
int8_scaled_mm,
lightning_attention_decode,
min_p_sampling_from_probs,
moe_align_block_size,
register_graph_buffers,
rmsnorm,
sampling_scaling_penalties,
sgl_per_token_group_quant_fp8,
silu_and_mul,
top_k_renorm_prob,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
tree_speculative_sampling_target_only,
)
else:
assert torch.version.hip
from sgl_kernel.ops import (
all_reduce_reg,
all_reduce_unreg,
allocate_meta_buffer,
apply_rope_with_cos_sin_cache_inplace,
bmm_fp8,
dispose,
fp8_scaled_mm,
fused_add_rmsnorm,
gelu_and_mul,
gelu_tanh_and_mul,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
get_graph_buffer_ipc_meta,
get_meta_buffer_ipc_handle,
init_custom_ar,
int8_scaled_mm,
lightning_attention_decode,
meta_size,
min_p_sampling_from_probs,
moe_align_block_size,
register_buffer,
register_graph_buffers,
rmsnorm,
sampling_scaling_penalties,
silu_and_mul,
top_k_renorm_prob,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
)
__all__ = [
"__version__",
"apply_rope_with_cos_sin_cache_inplace",
"bmm_fp8",
"cublas_grouped_gemm",
"custom_dispose",
"custom_reduce",
"build_tree_kernel_efficient",
"build_tree_kernel",
"fp8_blockwise_scaled_mm",
"fp8_scaled_mm",
"fused_add_rmsnorm",
"gelu_and_mul",
"gelu_tanh_and_mul",
"gemma_fused_add_rmsnorm",
"gemma_rmsnorm",
"get_graph_buffer_ipc_meta",
"init_custom_reduce",
"int8_scaled_mm",
"lightning_attention_decode",
"min_p_sampling_from_probs",
"moe_align_block_size",
"register_graph_buffers",
"rmsnorm",
"sampling_scaling_penalties",
"sgl_per_token_group_quant_fp8",
"silu_and_mul",
"top_k_renorm_prob",
"top_k_top_p_sampling_from_probs",
"top_p_renorm_prob",
"tree_speculative_sampling_target_only",
]
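# ---- sgl_kernel.ops.activation (new module, inferred from the import block above: norm, activation, and rotary-embedding wrappers) ----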
from typing import Optional
import sgl_kernel.ops._kernels
import torch
from sgl_kernel.ops.utils import get_cuda_stream
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
# Kudos to @yzh119
def rmsnorm(
input: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if out is None:
out = torch.empty_like(input)
torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, get_cuda_stream())
return out
def fused_add_rmsnorm(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps)
def gemma_rmsnorm(
input: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if out is None:
out = torch.empty_like(input)
torch.ops.sgl_kernels.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
return out
def gemma_fused_add_rmsnorm(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
torch.ops.sgl_kernels.gemma_fused_add_rmsnorm(
input, residual, weight, eps, get_cuda_stream()
)
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
assert (
input.shape[:-1] == output.shape[:-1]
), f"{input.shape[:-1]} != {output.shape[:-1]}"
assert (
input.shape[-1] == 2 * output.shape[-1]
), f"{input.shape[-1]} != {2 * output.shape[-1]}"
def silu_and_mul(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The last dimension of the input must be a multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernels.silu_and_mul(out, input, get_cuda_stream())
return out
def gelu_tanh_and_mul(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The last dimension of the input must be a multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, get_cuda_stream())
return out
def gelu_and_mul(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The last dimension of the input must be a multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernels.gelu_and_mul(out, input, get_cuda_stream())
return out
def apply_rope_with_cos_sin_cache_inplace(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool = True,
) -> None:
r"""
Apply rotary embedding to keys and queries with precomputed cos/sin values.
This is designed to be compatible with the SGL/vLLM implementation.
The result is applied in place to the input tensors.
Parameters
----------
positions : torch.Tensor
Position indices, shape: ``(nnz)``.
query : torch.Tensor
Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
key : torch.Tensor
Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
cos_sin_cache : torch.Tensor
Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
The cosine values are the first half and the sine values are the second half along rotary_dim.
is_neox : bool
Whether to use Neox style RoPE, default: ``True``.
* If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
dimensions ``([..., head_dim//2:])``.
* If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
Note
----
The rotary dimension is determined by the cosine cache and sine cache.
"""
if cos_sin_cache.dtype != torch.float32:
raise ValueError("cos_sin_cache should be float32")
positions = positions.int()
torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
q=query.view(query.shape[0], -1, head_size),
k=key.view(key.shape[0], -1, head_size),
q_rope=query.view(query.shape[0], -1, head_size),
k_rope=key.view(key.shape[0], -1, head_size),
cos_sin_cache=cos_sin_cache,
pos_ids=positions,
interleave=(not is_neox),
cuda_stream=get_cuda_stream(),
)
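A minimal usage sketch for the wrappers above. The shapes, dtypes, and device are illustrative assumptions (not taken from this commit), and the calls require the compiled sgl_kernel extension and a CUDA GPU:

import torch
from sgl_kernel.ops.activation import (
    apply_rope_with_cos_sin_cache_inplace,
    rmsnorm,
    silu_and_mul,
)

# Hypothetical shapes/dtypes, chosen only to satisfy the 16-byte last-dim check.
x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
w = torch.ones(4096, device="cuda", dtype=torch.float16)
y = rmsnorm(x, w, eps=1e-6)  # output tensor is allocated internally when out is None

gate_up = torch.randn(8, 8192, device="cuda", dtype=torch.float16)
act = silu_and_mul(gate_up)  # last dimension is halved: result shape (8, 4096)

# RoPE applied in place; cos occupies the first half and sin the second half of
# rotary_dim (here rotary_dim == head_size), as described in the docstring above.
nnz, num_heads, head_size, max_seq = 8, 32, 128, 2048
positions = torch.arange(nnz, device="cuda")
q = torch.randn(nnz, num_heads * head_size, device="cuda", dtype=torch.float16)
k = torch.randn(nnz, num_heads * head_size, device="cuda", dtype=torch.float16)
freqs = torch.rand(max_seq, head_size // 2, device="cuda")
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)
apply_rope_with_cos_sin_cache_inplace(positions, q, k, head_size, cos_sin_cache)

# ---- sgl_kernel.ops.allreduce (new module: ROCm and TRT-LLM-style custom all-reduce wrappers) ----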
from typing import List, Tuple
import sgl_kernel.ops._kernels
import torch
if torch.version.hip is not None:
# ROCM custom allreduce
def init_custom_ar(
meta: torch.Tensor,
rank_data: torch.Tensor,
handles: List[str],
offsets: List[int],
rank: int,
full_nvlink: bool,
) -> int:
return torch.ops.sgl_kernels.init_custom_ar(
meta, rank_data, handles, offsets, rank, full_nvlink
)
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out)
def all_reduce_unreg(
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
) -> None:
torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out)
def dispose(fa: int) -> None:
torch.ops.sgl_kernels.dispose(fa)
def meta_size() -> int:
return torch.ops.sgl_kernels.meta_size()
def register_buffer(
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
) -> None:
return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(
fa: int, handles: List[str], offsets: List[List[int]]
) -> None:
torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
def allocate_meta_buffer(size: int) -> torch.Tensor:
return torch.ops.sgl_kernels.allocate_meta_buffer(size)
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp)
else:
# TRTLLM custom allreduce
def init_custom_reduce(
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
):
return torch.ops.sgl_kernels.init_custom_ar(
rank_id,
num_devices,
rank_data,
buffers,
tmp_buffers,
barrier_in,
barrier_out,
)
def custom_dispose(fa):
torch.ops.sgl_kernels.dispose(fa)
def custom_reduce(fa, inp, out):
torch.ops.sgl_kernels.all_reduce(fa, inp, out)
def get_graph_buffer_ipc_meta(fa):
return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(fa, handles, offsets):
torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
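# ---- sgl_kernel.ops.attention (new module: lightning attention decode wrapper) ----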
import sgl_kernel.ops._kernels
import torch
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
torch.ops.sgl_kernels.lightning_attention_decode(
q, k, v, past_kv, slope, output, new_kv
)
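# ---- sgl_kernel.ops.gemm (new module: int8/fp8 scaled matmul, fp8 bmm, grouped GEMM, and per-token-group fp8 quantization wrappers) ----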
from typing import List, Optional
import sgl_kernel.ops._kernels
import torch
from sgl_kernel.ops.utils import _get_cache_buf, get_cuda_stream
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
return torch.ops.sgl_kernels.int8_scaled_mm(
mat_a,
mat_b,
scales_a,
scales_b,
out_dtype,
bias,
)
def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
return torch.ops.sgl_kernels.fp8_blockwise_scaled_mm(
mat_a,
mat_b,
scales_a,
scales_b,
out_dtype,
)
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
return torch.ops.sgl_kernels.fp8_scaled_mm(
mat_a,
mat_b,
scales_a,
scales_b,
out_dtype,
bias,
)
def _bmm_fp8_internal(
workspace_buffer: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
D: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
) -> None:
cublas_handle = torch.cuda.current_blas_handle()
torch.ops.sgl_kernels.bmm_fp8(
A,
B,
D,
A_scale,
B_scale,
workspace_buffer,
cublas_handle,
get_cuda_stream(),
)
def bmm_fp8(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
dtype: torch.dtype,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if out is None:
out = torch.empty(
(A.shape[0], A.shape[1], B.shape[2]),
device=A.device,
dtype=dtype,
)
workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
_bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
return out
def sgl_per_token_group_quant_fp8(
input: torch.Tensor,
output_q: torch.Tensor,
output_s: torch.Tensor,
group_size: int,
eps: float,
fp8_min: float,
fp8_max: float,
) -> None:
torch.ops.sgl_kernels.sgl_per_token_group_quant_fp8(
input, output_q, output_s, group_size, eps, fp8_min, fp8_max
)
def cublas_grouped_gemm(
inputs: List[torch.Tensor],
weights: List[torch.Tensor],
outputs: List[torch.Tensor],
out_dtype: torch.dtype,
) -> None:
assert (
len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
), "Inputs/weights/outputs should not be empty!"
cublas_handle = torch.cuda.current_blas_handle()
torch.ops.sgl_kernels.cublas_grouped_gemm(
inputs,
weights,
outputs,
out_dtype,
cublas_handle,
get_cuda_stream(),
)
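A rough sketch of calling bmm_fp8 defined above. The e4m3 dtypes, per-tensor unit scales, and the column-major layout of the second operand are assumptions following the FlashInfer convention these wrappers draw from, not something this commit documents:

import torch
from sgl_kernel.ops.gemm import bmm_fp8

b, m, k, n = 4, 128, 256, 64
# Assumed: e4m3 inputs with float32 per-tensor scales; B laid out column-major.
A = torch.randn(b, m, k, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(b, n, k, device="cuda").to(torch.float8_e4m3fn).transpose(-2, -1)
A_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
B_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
C = bmm_fp8(A, B, A_scale, B_scale, dtype=torch.bfloat16)  # result shape (b, m, n)

# ---- sgl_kernel.ops.moe (new module: MoE block-size alignment wrapper) ----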
import sgl_kernel.ops._kernels
import torch
def moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_token_ids,
experts_ids,
num_tokens_post_pad,
token_cnts_buffer,
cumsum_buffer,
):
torch.ops.sgl_kernels.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_token_ids,
experts_ids,
num_tokens_post_pad,
token_cnts_buffer,
cumsum_buffer,
)
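# ---- sgl_kernel.ops.sampling (new module: top-k/top-p/min-p renormalization and sampling wrappers) ----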
from typing import Optional, Tuple, Union
import sgl_kernel.ops._kernels
import torch
from sgl_kernel.ops.utils import _to_tensor_scalar_tuple, get_cuda_stream
def _top_k_renorm_probs_internal(
probs: torch.Tensor,
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
) -> torch.Tensor:
probs = probs.float()
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
renorm_probs = torch.empty_like(probs)
torch.ops.sgl_kernels.top_k_renorm_probs_wrapper(
probs,
renorm_probs,
maybe_top_k_arr,
top_k_val,
get_cuda_stream(),
)
return renorm_probs
def top_k_renorm_probs(
probs: torch.Tensor,
top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))
top_k_renorm_prob = top_k_renorm_probs
def _top_p_renorm_probs_internal(
probs: torch.Tensor,
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
) -> torch.Tensor:
probs = probs.float()
maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
renorm_probs = torch.empty_like(probs)
torch.ops.sgl_kernels.top_p_renorm_probs(
probs,
renorm_probs,
maybe_top_p_arr,
top_p_val,
get_cuda_stream(),
)
return renorm_probs
def top_p_renorm_probs(
probs: torch.Tensor,
top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))
top_p_renorm_prob = top_p_renorm_probs
def _top_p_sampling_from_probs_internal(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
with probs.device as device:
probs = probs.float()
uniform_samples = uniform_samples.float()
maybe_top_p_arr = (
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
)
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
torch.ops.sgl_kernels.top_p_sampling_from_probs(
probs,
uniform_samples,
samples,
success,
maybe_top_p_arr,
top_p_val,
deterministic,
get_cuda_stream(),
)
return samples, success
def top_p_sampling_from_probs(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
top_p: Union[torch.Tensor, float],
deterministic: bool = True,
check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return _top_p_sampling_from_probs_internal(
probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic
)
def _top_k_top_p_sampling_from_probs_internal(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
with probs.device as device:
probs = probs.float()
uniform_samples = uniform_samples.float()
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
maybe_top_p_arr = (
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
)
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs(
probs,
uniform_samples,
samples,
success,
maybe_top_k_arr,
top_k_val,
maybe_top_p_arr,
top_p_val,
deterministic,
get_cuda_stream(),
)
return samples, success
def top_k_top_p_sampling_from_probs(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
top_k: Union[torch.Tensor, int],
top_p: Union[torch.Tensor, float],
filter_apply_order: str = "top_k_first",
deterministic: bool = True,
check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if filter_apply_order == "top_k_first":
renorm_probs = top_k_renorm_probs(probs, top_k)
return top_p_sampling_from_probs(
renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan
)
elif filter_apply_order == "joint":
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return _top_k_top_p_sampling_from_probs_internal(
probs,
uniform_samples,
*_to_tensor_scalar_tuple(top_k),
*_to_tensor_scalar_tuple(top_p),
deterministic,
)
else:
raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
def _min_p_sampling_from_probs_internal(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
maybe_min_p_arr: Optional[torch.Tensor],
min_p_val: float,
deterministic: bool,
) -> torch.Tensor:
with probs.device as device:
probs = probs.float()
uniform_samples = uniform_samples.float()
maybe_min_p_arr = (
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
)
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
torch.ops.sgl_kernels.min_p_sampling_from_probs(
probs,
uniform_samples,
samples,
maybe_min_p_arr,
min_p_val,
deterministic,
get_cuda_stream(),
)
return samples
def min_p_sampling_from_probs(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
min_p: Union[torch.Tensor, float],
deterministic: bool = True,
check_nan: bool = False,
) -> torch.Tensor:
if uniform_samples.dim() == 2:
# Take the first row (round) of uniform_samples
uniform_samples = uniform_samples[0]
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return _min_p_sampling_from_probs_internal(
probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
)
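A minimal sketch of the joint top-k/top-p wrapper above. The batch size, vocabulary size, top_k/top_p values, and the (num_rounds, batch) shape of uniform_samples are assumptions following the FlashInfer convention, not something this commit specifies:

import torch
from sgl_kernel.ops.sampling import top_k_top_p_sampling_from_probs

batch, vocab = 4, 32000
probs = torch.softmax(torch.randn(batch, vocab, device="cuda"), dim=-1)
# Assumed shape: one row of uniform draws per rejection-sampling round.
uniform_samples = torch.rand(32, batch, device="cuda")
samples, success = top_k_top_p_sampling_from_probs(
    probs, uniform_samples, top_k=50, top_p=0.9
)
# samples: int32 token ids, shape (batch,); success: bool flags, shape (batch,)

# ---- sgl_kernel.ops.speculative (new module: tree construction and tree speculative sampling wrappers) ----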
import sgl_kernel.ops._kernels
import torch
from sgl_kernel.ops.utils import get_cuda_stream
def tree_speculative_sampling_target_only(
predicts: torch.Tensor, # mutable
accept_index: torch.Tensor, # mutable
accept_token_num: torch.Tensor, # mutable
candidates: torch.Tensor,
retrive_index: torch.Tensor,
retrive_next_token: torch.Tensor,
retrive_next_sibling: torch.Tensor,
uniform_samples: torch.Tensor,
target_probs: torch.Tensor,
draft_probs: torch.Tensor,
deterministic: bool = True,
) -> None:
torch.ops.sgl_kernels.tree_speculative_sampling_target_only(
predicts,
accept_index,
accept_token_num,
candidates,
retrive_index,
retrive_next_token,
retrive_next_sibling,
uniform_samples,
target_probs,
draft_probs,
deterministic,
get_cuda_stream(),
)
def build_tree_kernel_efficient(
parent_list: torch.Tensor,
selected_index: torch.Tensor,
verified_seq_len: torch.Tensor,
tree_mask: torch.Tensor,
positions: torch.Tensor,
retrive_index: torch.Tensor,
retrive_next_token: torch.Tensor,
retrive_next_sibling: torch.Tensor,
topk: int,
depth: int,
draft_token_num: int,
) -> None:
torch.ops.sgl_kernels.build_tree_kernel_efficient(
parent_list,
selected_index,
verified_seq_len,
tree_mask,
positions,
retrive_index,
retrive_next_token,
retrive_next_sibling,
topk,
depth,
draft_token_num,
)
def build_tree_kernel(
parent_list: torch.Tensor,
selected_index: torch.Tensor,
verified_seq_len: torch.Tensor,
tree_mask: torch.Tensor,
positions: torch.Tensor,
retrive_index: torch.Tensor,
topk: int,
depth: int,
draft_token_num: int,
) -> None:
torch.ops.sgl_kernels.build_tree_kernel(
parent_list,
selected_index,
verified_seq_len,
tree_mask,
positions,
retrive_index,
topk,
depth,
draft_token_num,
)
@@ -18,8 +18,8 @@ from typing import Dict, Tuple
import torch
def _get_cuda_stream(device: torch.device) -> int:
return torch.cuda.current_stream(device).cuda_stream
def get_cuda_stream() -> int:
return torch.cuda.current_stream().cuda_stream
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
@@ -7,9 +7,9 @@ import unittest
from typing import Any, List, Optional
import ray
import sgl_kernel.ops.allreduce as custom_ops
import torch
import torch.distributed as dist
from sgl_kernel import ops as custom_ops
from torch.distributed import ProcessGroup
from vllm import _custom_ops as vllm_ops