Unverified commit 3c9740d2, authored by Zhaoyi Li, committed by GitHub
Browse files

update variable naming and comments for rocm (#5299)

parent 2eb55770
...@@ -11,8 +11,8 @@ from vllm import _custom_ops as ops ...@@ -11,8 +11,8 @@ from vllm import _custom_ops as ops
from sglang.srt.utils import is_hip from sglang.srt.utils import is_hip
is_hip_ = is_hip() _is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
def vllm_scaled_fp8_quant( def vllm_scaled_fp8_quant(
......
...@@ -8,8 +8,8 @@ from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_ ...@@ -8,8 +8,8 @@ from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_
from sglang.srt.utils import is_hip from sglang.srt.utils import is_hip
is_hip_ = is_hip() _is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
@triton.jit @triton.jit
......
...@@ -9,8 +9,8 @@ from vllm import _custom_ops as ops ...@@ -9,8 +9,8 @@ from vllm import _custom_ops as ops
from sglang.srt.utils import is_hip from sglang.srt.utils import is_hip
is_hip_ = is_hip() _is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
def vllm_per_token_quant_fp8( def vllm_per_token_quant_fp8(
......
...@@ -61,11 +61,15 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { ...@@ -61,11 +61,15 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()"); "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
m.def( m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()"); "token_expert_indices, Tensor gating_output) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax); m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
/*
* From csrc/speculative
*/
m.def( m.def(
"verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, " "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, " "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
......
...@@ -7,8 +7,8 @@ from sgl_kernel import sgl_per_tensor_quant_fp8 ...@@ -7,8 +7,8 @@ from sgl_kernel import sgl_per_tensor_quant_fp8
from sglang.srt.utils import is_hip from sglang.srt.utils import is_hip
is_hip_ = is_hip() _is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
def sglang_scaled_fp8_quant( def sglang_scaled_fp8_quant(
......
...@@ -7,8 +7,8 @@ from sgl_kernel import sgl_per_token_quant_fp8 ...@@ -7,8 +7,8 @@ from sgl_kernel import sgl_per_token_quant_fp8
from sglang.srt.utils import is_hip from sglang.srt.utils import is_hip
is_hip_ = is_hip() _is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
def torch_per_token_quant_fp8(tensor, inv_scale): def torch_per_token_quant_fp8(tensor, inv_scale):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment