Commit f137e58c authored by zhuwenwen
Browse files

Use List[int] for block_shape annotations and drop the None default from num_rejected_tokens

parent 1b78ef29
...@@ -659,7 +659,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -659,7 +659,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
per_channel_quant: bool, per_channel_quant: bool,
block_shape: Optional[list[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False) -> None: use_nn_moe: Optional[bool]=False) -> None:
assert topk_weights is not None or not mul_routed_weight assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1 assert topk_weights is None or topk_weights.stride(1) == 1
...@@ -1328,7 +1328,7 @@ def flashinfer_fused_moe_blockscale_fp8( ...@@ -1328,7 +1328,7 @@ def flashinfer_fused_moe_blockscale_fp8(
intermediate_size: int, intermediate_size: int,
expert_offset: int, expert_offset: int,
local_num_experts: int, local_num_experts: int,
block_shape: list[int], block_shape: List[int],
routed_scaling: float = 1.0) -> torch.Tensor: routed_scaling: float = 1.0) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
assert top_k <= global_num_experts assert top_k <= global_num_experts
...@@ -1381,7 +1381,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake( ...@@ -1381,7 +1381,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
intermediate_size: int, intermediate_size: int,
expert_offset: int, expert_offset: int,
local_num_experts: int, local_num_experts: int,
block_shape: list[int], block_shape: List[int],
routed_scaling: float = 1.0) -> torch.Tensor: routed_scaling: float = 1.0) -> torch.Tensor:
return torch.empty_like(x) return torch.empty_like(x)
......
...@@ -55,7 +55,7 @@ class CommonAttentionMetadata: ...@@ -55,7 +55,7 @@ class CommonAttentionMetadata:
"""Total number of tokens in batch""" """Total number of tokens in batch"""
max_query_len: int max_query_len: int
"""Longest query in batch""" """Longest query in batch"""
num_rejected_tokens: list[int] = None num_rejected_tokens: list[int]
"""(batch_size,), record the rejected tokens number in cpu and gpu""" """(batch_size,), record the rejected tokens number in cpu and gpu"""
block_table_tensor: torch.Tensor block_table_tensor: torch.Tensor
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment