"tests/vscode:/vscode.git/clone" did not exist on "c038b71fc92ffb7cd03a88e3a45f33c1d2ce4504"
Unverified Commit 3c9740d2 authored by Zhaoyi Li's avatar Zhaoyi Li Committed by GitHub
Browse files

update variable naming and comments for rocm (#5299)

parent 2eb55770
......@@ -11,8 +11,8 @@ from vllm import _custom_ops as ops
from sglang.srt.utils import is_hip
is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
def vllm_scaled_fp8_quant(
......
......@@ -8,8 +8,8 @@ from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_
from sglang.srt.utils import is_hip
is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
@triton.jit
......
......@@ -9,8 +9,8 @@ from vllm import _custom_ops as ops
from sglang.srt.utils import is_hip
is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
def vllm_per_token_quant_fp8(
......
......@@ -61,11 +61,15 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
/*
* From csrc/speculative
*/
m.def(
"verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
......
......@@ -7,8 +7,8 @@ from sgl_kernel import sgl_per_tensor_quant_fp8
from sglang.srt.utils import is_hip
is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
def sglang_scaled_fp8_quant(
......
......@@ -7,8 +7,8 @@ from sgl_kernel import sgl_per_token_quant_fp8
from sglang.srt.utils import is_hip
is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
def torch_per_token_quant_fp8(tensor, inv_scale):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment