Commit 62f05dde authored by 王敏's avatar 王敏
Browse files

[feat]1.优化ep sequence parallel,区分主模型和mtp逻辑;2.ep sequence parallel添加cudagraph...

[feat]1.优化ep sequence parallel,区分主模型和mtp逻辑;2.ep sequence parallel添加cudagraph padding到tp_size;3.修复共享专家和deepep combine overlap
parent 2e1f5a46
...@@ -52,7 +52,7 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, ...@@ -52,7 +52,7 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
LayerBlockType, LazyLoader, common_broadcastable_dtype, LayerBlockType, LazyLoader, common_broadcastable_dtype,
cuda_device_count_stateless, get_cpu_memory, cuda_device_count_stateless, get_cpu_memory,
get_open_port, is_torch_equal_or_newer, random_uuid, get_open_port, is_torch_equal_or_newer, random_uuid,
resolve_obj_by_qualname) resolve_obj_by_qualname, round_up)
from vllm.utils import SUPPORT_TC from vllm.utils import SUPPORT_TC
# yapf: enable # yapf: enable
...@@ -4782,6 +4782,11 @@ class VllmConfig: ...@@ -4782,6 +4782,11 @@ class VllmConfig:
size for size in batch_size_capture_list size for size in batch_size_capture_list
if size <= max_num_tokens if size <= max_num_tokens
] ]
# add for ep sp
dp_size = self.parallel_config.data_parallel_size
tp_size = self.parallel_config.tensor_parallel_size
ep_sp = self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1
# add for spec decode # add for spec decode
if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0: if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
...@@ -4789,6 +4794,12 @@ class VllmConfig: ...@@ -4789,6 +4794,12 @@ class VllmConfig:
batch_size_capture_list)) batch_size_capture_list))
batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list)) batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list))
batch_size_capture_list = [i for i in batch_size_capture_list if i == 1 or i % (1 + self.speculative_config.num_lookahead_slots) == 0] batch_size_capture_list = [i for i in batch_size_capture_list if i == 1 or i % (1 + self.speculative_config.num_lookahead_slots) == 0]
if ep_sp:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
else:
if ep_sp:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
self.compilation_config.init_with_cudagraph_sizes( self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list) batch_size_capture_list)
......
...@@ -256,7 +256,6 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -256,7 +256,6 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes, num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True, low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank, num_qps_per_rank=num_qps_per_rank,
allow_mnnvl=envs.VLLM_ALLOW_MNNVL,
) )
def get_handle(self, kwargs): def get_handle(self, kwargs):
......
...@@ -809,6 +809,21 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -809,6 +809,21 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_ids, apply_router_weight_on_input) topk_ids, apply_router_weight_on_input)
return output return output
_alt_stream: torch.cuda.Stream | None = None
def alt_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
"""
global _alt_stream
# TODO: validate this works properly on ROCm platform.
if _alt_stream is None:
_alt_stream = torch.cuda.Stream()
return _alt_stream
@final @final
class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
...@@ -835,6 +850,10 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -835,6 +850,10 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self.fused_experts = fused_experts self.fused_experts = fused_experts
self.shared_experts = shared_experts self.shared_experts = shared_experts
if self.shared_experts is not None:
self.alt_stream = alt_stream()
self.alt_event = torch.cuda.Event()
# assert prepare_finalize.activation_format == \ # assert prepare_finalize.activation_format == \
# fused_experts.activation_formats[0], ( # fused_experts.activation_formats[0], (
# f"{prepare_finalize.__class__.__name__}." # f"{prepare_finalize.__class__.__name__}."
...@@ -979,14 +998,27 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -979,14 +998,27 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
) )
shared_output = None shared_output = None
hook = self.prepare_finalize.finalize_async(output, fused_out, topk_weights, self.alt_event.record()
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
current_stream = torch.cuda.current_stream()
if self.shared_experts is not None: if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
if hook is not None: with torch.cuda.stream(self.alt_stream):
hook() self.alt_stream.wait_event(self.alt_event)
hook = None
if self.prepare_finalize.activation_format == \
FusedMoEActivationFormat.BatchedExperts:
self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
else:
hook = self.prepare_finalize.finalize_async(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
if hook is not None:
hook()
self.alt_event.record()
current_stream.wait_event(self.alt_event)
if self.shared_experts is not None: if self.shared_experts is not None:
return (shared_output, output) return (shared_output, output)
......
...@@ -27,8 +27,9 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( ...@@ -27,8 +27,9 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
from vllm.utils import round_up from vllm.utils import round_up
try: try:
from lightop import m_grouped_w8a8_gemm_nt_masked, m_grouped_w8a8_gemm_nt_contig_asm, fuse_silu_mul_quant_ep, fuse_silu_mul_quant from lightop import m_grouped_w8a8_gemm_nt_masked, fuse_silu_mul_quant_ep, fuse_silu_mul_quant
from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
from lightop import m_grouped_w8a8_gemm_nt_contig_asm
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
......
...@@ -181,7 +181,6 @@ class DeepseekV2MoE(nn.Module): ...@@ -181,7 +181,6 @@ class DeepseekV2MoE(nn.Module):
self.n_local_physical_experts) self.n_local_physical_experts)
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
self.use_mori_ep = parallel_config.enable_expert_parallel and dp_size > 1 and envs.VLLM_ALL2ALL_BACKEND == 'mori'
self.enable_expert_parallel = parallel_config.enable_expert_parallel self.enable_expert_parallel = parallel_config.enable_expert_parallel
self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \ self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
...@@ -1098,7 +1097,13 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1098,7 +1097,13 @@ class DeepseekV2DecoderLayer(nn.Module):
if self.is_mtp_layer: if self.is_mtp_layer:
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1: DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = hidden_states.tensor_split(self.tp_size)[self.tp_rank] ori_bs = hidden_states.shape[0]
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states.contiguous(), [0, 0, 0, pad_size], value=0).contiguous()
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :].contiguous()
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
...@@ -1106,6 +1111,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1106,6 +1111,7 @@ class DeepseekV2DecoderLayer(nn.Module):
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1: DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0) hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = hidden_states[:ori_bs, :]
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16: DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......
...@@ -188,12 +188,6 @@ class EagleProposer: ...@@ -188,12 +188,6 @@ class EagleProposer:
else: else:
num_input_tokens = num_tokens num_input_tokens = num_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.vllm_config.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
......
...@@ -47,7 +47,7 @@ from vllm.sequence import IntermediateTensors ...@@ -47,7 +47,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
check_use_alibi, get_dtype_size, check_use_alibi, get_dtype_size,
is_pin_memory_available, round_up) is_pin_memory_available, round_up, round_down)
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
...@@ -332,6 +332,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -332,6 +332,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.draft_probs : Optional[DraftProbs] = None self.draft_probs : Optional[DraftProbs] = None
self.ep_sp = False
self.dp_size = self.parallel_config.data_parallel_size
self.tp_size = self.parallel_config.tensor_parallel_size
self.enable_expert_parallel = self.parallel_config.enable_expert_parallel
if self.enable_expert_parallel and self.dp_size > 1 and self.tp_size > 1:
self.ep_sp = True
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
""" """
Update the order of requests in the batch based on the attention Update the order of requests in the batch based on the attention
...@@ -1268,7 +1275,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1268,7 +1275,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# TODO(tms) : There are many cases where padding is enabled for # TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations. # prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND == 'naive': if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND != 'naive':
# Early exit. # Early exit.
return 0, None return 0, None
...@@ -1345,28 +1352,33 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1345,28 +1352,33 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
spec_decode_metadata, spec_decode_metadata,
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
# Use piecewise CUDA graphs. if self.ep_sp:
# Add padding to the batch size. num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
num_input_tokens = self.vllm_config.pad_for_cudagraph( if (self.use_cuda_graph
num_scheduled_tokens) and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_input_tokens)
else: else:
# Eager mode. if (self.use_cuda_graph
# Pad tokens to multiple of tensor_parallel_size when and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# enabled collective fusion for SP # Use piecewise CUDA graphs.
tp_size = self.vllm_config.parallel_config.tensor_parallel_size # Add padding to the batch size.
if self.compilation_config.pass_config. \ num_input_tokens = self.vllm_config.pad_for_cudagraph(
enable_sequence_parallelism and tp_size > 1: num_scheduled_tokens)
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else: else:
num_input_tokens = num_scheduled_tokens # Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. # enabled collective fusion for SP
dp_size = self.vllm_config.parallel_config.data_parallel_size tp_size = self.vllm_config.parallel_config.tensor_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size if self.compilation_config.pass_config. \
if self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1: enable_sequence_parallelism and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size) num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
# Padding for DP # Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
...@@ -2103,12 +2115,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2103,12 +2115,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
skip_eplb: bool = False, skip_eplb: bool = False,
is_profile: bool = False, is_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size if self.ep_sp:
tp_size = self.vllm_config.parallel_config.tensor_parallel_size if num_tokens < self.tp_size:
if self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1: num_tokens = self.tp_size
num_tokens = round_up(num_tokens, tp_size)
# Padding for DP # Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
...@@ -2117,10 +2128,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2117,10 +2128,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Set num_scheduled_tokens based on num_tokens and max_num_seqs # Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively # for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total. # has num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens assert num_tokens <= self.scheduler_config.max_num_batched_tokens
max_num_reqs = self.scheduler_config.max_num_seqs max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs) num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs min_tokens_per_req = num_tokens // num_reqs
num_actual_tokens = num_tokens
if not is_profile and self.speculative_config is not None \ if not is_profile and self.speculative_config is not None \
and self.speculative_config.num_lookahead_slots > 0 \ and self.speculative_config.num_lookahead_slots > 0 \
...@@ -2128,8 +2141,17 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2128,8 +2141,17 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots) min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_tokens // min_tokens_per_req num_reqs = num_tokens // min_tokens_per_req
if self.ep_sp:
num_actual_tokens = round_down(num_tokens, 1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_actual_tokens // min_tokens_per_req
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
if not self.ep_sp:
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
else:
num_scheduled_tokens_list[-1] += num_tokens % min_tokens_per_req
assert sum(num_scheduled_tokens_list) == num_tokens assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list, num_scheduled_tokens = np.array(num_scheduled_tokens_list,
...@@ -2153,7 +2175,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2153,7 +2175,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
seq_lens=seq_lens, seq_lens=seq_lens,
# seq_lens_tensor=seq_lens_tensor, # seq_lens_tensor=seq_lens_tensor,
num_reqs=num_reqs, num_reqs=num_reqs,
num_actual_tokens=num_tokens, num_actual_tokens=num_actual_tokens,
max_query_len=num_tokens, max_query_len=num_tokens,
num_speculative_tokens=num_speculative_tokens, num_speculative_tokens=num_speculative_tokens,
) )
...@@ -3186,28 +3208,33 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3186,28 +3208,33 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
spec_decode_metadata, spec_decode_metadata,
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
else:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size if self.ep_sp:
tp_size = self.vllm_config.parallel_config.tensor_parallel_size num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1: if (self.use_cuda_graph
num_input_tokens = round_up(num_input_tokens, tp_size) and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_input_tokens)
else:
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
else:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
# Padding for DP # Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
......
...@@ -112,12 +112,6 @@ class V1ZeroEagleProposer(EagleProposer): ...@@ -112,12 +112,6 @@ class V1ZeroEagleProposer(EagleProposer):
else: else:
num_input_tokens = num_tokens num_input_tokens = num_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.vllm_config.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
......
...@@ -424,28 +424,33 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -424,28 +424,33 @@ class V1ZeroModelRunner(GPUModelRunner):
spec_decode_metadata, spec_decode_metadata,
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
# Use piecewise CUDA graphs. if self.ep_sp:
# Add padding to the batch size. num_input_tokens = round_up(num_scheduled_tokens, tp_size)
num_input_tokens = self.vllm_config.pad_for_cudagraph( if (self.use_cuda_graph
num_scheduled_tokens) and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_input_tokens)
else: else:
# Eager mode. if (self.use_cuda_graph
# Pad tokens to multiple of tensor_parallel_size when and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# enabled collective fusion for SP # Use piecewise CUDA graphs.
tp_size = self.vllm_config.parallel_config.tensor_parallel_size # Add padding to the batch size.
if self.compilation_config.pass_config. \ num_input_tokens = self.vllm_config.pad_for_cudagraph(
enable_sequence_parallelism and tp_size > 1: num_scheduled_tokens)
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else: else:
num_input_tokens = num_scheduled_tokens # Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. # enabled collective fusion for SP
dp_size = self.vllm_config.parallel_config.data_parallel_size tp_size = self.vllm_config.parallel_config.tensor_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size if self.compilation_config.pass_config. \
if self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1: enable_sequence_parallelism and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size) num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
# Padding for DP # Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment