Commit 5b0a1c93 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix dense run error

parent 8f3d67b5
...@@ -28,6 +28,42 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL ...@@ -28,6 +28,42 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time: defaultdict = defaultdict(list) batchsize_forward_time: defaultdict = defaultdict(list)
class BatchDescriptor(NamedTuple):
"""
Batch descriptor for cudagraph dispatching. We should keep the num of
items as minimal as possible to properly and uniquely describe the padded
batch for cudagraph.
"""
num_tokens: int
uniform_decode: bool = False
"""
False can also be used for an uniform decode batch to dispatch to the
cudagraph supporting non-uniform batches.
"""
@property
def non_uniform(self) -> "BatchDescriptor":
"""
Return a non-uniform version of current batch descriptor.
"""
return BatchDescriptor(self.num_tokens, uniform_decode=False)
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
max_num_tokens: int,
chunk_idx: int) -> list[int]:
dp_size = len(num_tokens_across_dp_cpu)
local_size = [-1] * dp_size
for i in range(dp_size):
dp_tokens = num_tokens_across_dp_cpu[i]
local_size[i] = min(max_num_tokens,
dp_tokens - (max_num_tokens * chunk_idx))
if local_size[i] <= 0:
local_size[i] = 1 # ensure lockstep even if done
return local_size
@dataclass @dataclass
class DPMetadata: class DPMetadata:
max_tokens_across_dp_cpu: torch.Tensor max_tokens_across_dp_cpu: torch.Tensor
......
...@@ -1313,10 +1313,10 @@ def inplace_fused_experts( ...@@ -1313,10 +1313,10 @@ def inplace_fused_experts(
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, is_act_and_mul, activation, is_act_and_mul,
apply_router_weight_on_input, use_fp8_w8a8, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_int4_w4a8,
use_mxfp4_w4a4, per_channel_quant, global_num_experts, use_mxfp4_w4a4, per_channel_quant, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape, w1_bias, w2_bias) a2_scale, block_shape, w1_bias, w2_bias, use_nn_moe)
def inplace_fused_experts_fake(hidden_states: torch.Tensor, def inplace_fused_experts_fake(hidden_states: torch.Tensor,
...@@ -1331,6 +1331,7 @@ def inplace_fused_experts_fake(hidden_states: torch.Tensor, ...@@ -1331,6 +1331,7 @@ def inplace_fused_experts_fake(hidden_states: torch.Tensor,
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
use_mxfp4_w4a4: bool = False, use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1343,7 +1344,8 @@ def inplace_fused_experts_fake(hidden_states: torch.Tensor, ...@@ -1343,7 +1344,8 @@ def inplace_fused_experts_fake(hidden_states: torch.Tensor,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None, w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> None: w2_bias: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False) -> None:
pass pass
...@@ -1540,6 +1542,7 @@ def outplace_fused_experts( ...@@ -1540,6 +1542,7 @@ def outplace_fused_experts(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
use_mxfp4_w4a4: bool = False, use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1553,13 +1556,14 @@ def outplace_fused_experts( ...@@ -1553,13 +1556,14 @@ def outplace_fused_experts(
block_shape: Optional[List[int]] = None, #noqa: UP006 block_shape: Optional[List[int]] = None, #noqa: UP006
w1_bias: Optional[torch.Tensor] = None, w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None, w2_bias: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
return fused_experts_impl( return fused_experts_impl(
hidden_states, w1, w2, topk_weights, topk_ids, False, activation, hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8, is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_int4_w4a8, use_mxfp4_w4a4,
per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale,
w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias) w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias, use_nn_moe)
def outplace_fused_experts_fake( def outplace_fused_experts_fake(
...@@ -1634,6 +1638,7 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1634,6 +1638,7 @@ def fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
use_mxfp4_w4a4: bool = False, use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
......
...@@ -547,6 +547,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -547,6 +547,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
use_nn_moe=use_nn_moe,
) )
def forward_cpu( def forward_cpu(
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import ast import ast
from dataclasses import replace from dataclasses import replace
from typing import Optional from typing import Optional, Any
import numpy as np import numpy as np
......
...@@ -254,7 +254,7 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -254,7 +254,7 @@ class V1ZeroModelRunner(GPUModelRunner):
True) True)
last_draft_token_ids = self.last_draft_token_ids.flatten().to(torch.int) last_draft_token_ids = self.last_draft_token_ids.flatten().to(torch.int)
input_ids[input_ids_indices_tensor] = last_draft_token_ids[update_req_indices_tensor] input_ids[input_ids_indices_tensor] = last_draft_token_ids[update_req_indices_tensor]
else:
update_req_indices = [] update_req_indices = []
input_ids_indices = [] input_ids_indices = []
token_idx = 0 token_idx = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment