Commit a0be38cb authored by yangql's avatar yangql
Browse files

updata data

parents df03e33b fd894e48
...@@ -52,7 +52,7 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, ...@@ -52,7 +52,7 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
LayerBlockType, LazyLoader, common_broadcastable_dtype, LayerBlockType, LazyLoader, common_broadcastable_dtype,
cuda_device_count_stateless, get_cpu_memory, cuda_device_count_stateless, get_cpu_memory,
get_open_port, is_torch_equal_or_newer, random_uuid, get_open_port, is_torch_equal_or_newer, random_uuid,
resolve_obj_by_qualname) resolve_obj_by_qualname, round_up)
from vllm.utils import SUPPORT_TC from vllm.utils import SUPPORT_TC
# yapf: enable # yapf: enable
...@@ -4778,6 +4778,11 @@ class VllmConfig: ...@@ -4778,6 +4778,11 @@ class VllmConfig:
if size <= max_num_tokens if size <= max_num_tokens
] ]
# add for ep sp
dp_size = self.parallel_config.data_parallel_size
tp_size = self.parallel_config.tensor_parallel_size
ep_sp = self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1
# add for spec decode # add for spec decode
if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0: if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
mtp_batch_size_capture_list = list(map(lambda x: x * (1 + self.speculative_config.num_lookahead_slots), mtp_batch_size_capture_list = list(map(lambda x: x * (1 + self.speculative_config.num_lookahead_slots),
...@@ -4785,6 +4790,12 @@ class VllmConfig: ...@@ -4785,6 +4790,12 @@ class VllmConfig:
batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list)) batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list))
batch_size_capture_list = [i for i in batch_size_capture_list if i == 1 or i % (1 + self.speculative_config.num_lookahead_slots) == 0] batch_size_capture_list = [i for i in batch_size_capture_list if i == 1 or i % (1 + self.speculative_config.num_lookahead_slots) == 0]
if ep_sp:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
else:
if ep_sp:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
self.compilation_config.init_with_cudagraph_sizes( self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list) batch_size_capture_list)
......
...@@ -52,7 +52,7 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -52,7 +52,7 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return self.ll_prepare_finalize return self.ll_prepare_finalize
else: else:
# print("attn_metadata.num_decode_tokens",attn_metadata.num_decode_tokens) # print("attn_metadata.num_decode_tokens",attn_metadata.num_decode_tokens)
return self.ht__prepare_finalize return self.ht_prepare_finalize
#return self.ht_prepare_finalize #return self.ht_prepare_finalize
@property @property
def activation_format(self) -> mk.FusedMoEActivationFormat: def activation_format(self) -> mk.FusedMoEActivationFormat:
......
...@@ -812,6 +812,21 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -812,6 +812,21 @@ class FusedMoEModularKernel(torch.nn.Module):
return output return output
_alt_stream: torch.cuda.Stream | None = None
def alt_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
"""
global _alt_stream
# TODO: validate this works properly on ROCm platform.
if _alt_stream is None:
_alt_stream = torch.cuda.Stream()
return _alt_stream
@final @final
class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
""" """
...@@ -842,6 +857,10 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -842,6 +857,10 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self.fused_experts_ll = experts_ll self.fused_experts_ll = experts_ll
self.shared_experts = shared_experts self.shared_experts = shared_experts
if self.shared_experts is not None:
self.alt_stream = alt_stream()
self.alt_event = torch.cuda.Event()
# assert prepare_finalize.activation_format == \ # assert prepare_finalize.activation_format == \
# fused_experts.activation_formats[0], ( # fused_experts.activation_formats[0], (
# f"{prepare_finalize.__class__.__name__}." # f"{prepare_finalize.__class__.__name__}."
...@@ -870,6 +889,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -870,6 +889,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
**_
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets This function computes a Mixture of Experts (MoE) layer using two sets
...@@ -1007,14 +1027,27 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -1007,14 +1027,27 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
) )
shared_output = None shared_output = None
hook = prepare_finalize.finalize_async(output, fused_out, topk_weights, self.alt_event.record()
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
current_stream = torch.cuda.current_stream()
if self.shared_experts is not None: if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
with torch.cuda.stream(self.alt_stream):
self.alt_stream.wait_event(self.alt_event)
hook = None
if prepare_finalize.activation_format == \
FusedMoEActivationFormat.BatchedExperts:
prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
else:
hook = prepare_finalize.finalize_async(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
if hook is not None: if hook is not None:
hook() hook()
self.alt_event.record()
current_stream.wait_event(self.alt_event)
if self.shared_experts is not None: if self.shared_experts is not None:
return (shared_output, output) return (shared_output, output)
......
...@@ -885,7 +885,12 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -885,7 +885,12 @@ class DeepseekV2DecoderLayer(nn.Module):
if self.is_mtp_layer: if self.is_mtp_layer:
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1: DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = hidden_states.tensor_split(self.tp_size)[self.tp_rank] ori_bs = hidden_states.shape[0]
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states.contiguous(), [0, 0, 0, pad_size], value=0).contiguous()
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :].contiguous()
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
...@@ -893,6 +898,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -893,6 +898,7 @@ class DeepseekV2DecoderLayer(nn.Module):
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1: DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0) hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = hidden_states[:ori_bs, :]
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16: DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......
...@@ -23,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata, MLACommonDe ...@@ -23,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata, MLACommonDe
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
from vllm.utils import round_up
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -186,12 +185,6 @@ class EagleProposer: ...@@ -186,12 +185,6 @@ class EagleProposer:
else: else:
num_input_tokens = num_tokens num_input_tokens = num_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.vllm_config.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
......
...@@ -47,7 +47,8 @@ from vllm.sequence import IntermediateTensors ...@@ -47,7 +47,8 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
check_use_alibi, get_dtype_size, check_use_alibi, get_dtype_size,
is_pin_memory_available, round_up) is_pin_memory_available, round_up,
round_down)
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
...@@ -325,6 +326,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -325,6 +326,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# from the KV cache of `shared_kv_cache_layers[layer_name]`. # from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {} self.shared_kv_cache_layers: dict[str, str] = {}
self.ep_sp = False
self.dp_size = self.parallel_config.data_parallel_size
self.tp_size = self.parallel_config.tensor_parallel_size
self.enable_expert_parallel = self.parallel_config.enable_expert_parallel
if self.enable_expert_parallel and self.dp_size > 1 and self.tp_size > 1:
self.ep_sp = True
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
""" """
Update the order of requests in the batch based on the attention Update the order of requests in the batch based on the attention
...@@ -1317,6 +1325,16 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1317,6 +1325,16 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_input_tokens)
else:
if (self.use_cuda_graph if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs. # Use piecewise CUDA graphs.
...@@ -1334,12 +1352,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1334,12 +1352,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
else: else:
num_input_tokens = num_scheduled_tokens num_input_tokens = num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
# Padding for DP # Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad num_input_tokens += num_pad
...@@ -2039,10 +2051,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2039,10 +2051,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size if self.ep_sp:
tp_size = self.vllm_config.parallel_config.tensor_parallel_size if num_tokens < self.tp_size:
if self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1: num_tokens = self.tp_size
num_tokens = round_up(num_tokens, tp_size)
# Padding for DP # Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
...@@ -2055,14 +2066,23 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2055,14 +2066,23 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
max_num_reqs = self.scheduler_config.max_num_seqs max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs) num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs min_tokens_per_req = num_tokens // num_reqs
num_actual_tokens = num_tokens
if not is_profile and self.speculative_config is not None \ if not is_profile and self.speculative_config is not None \
and self.speculative_config.num_lookahead_slots > 0 \ and self.speculative_config.num_lookahead_slots > 0 \
and num_tokens >= (1 + self.speculative_config.num_lookahead_slots): and num_tokens >= (1 + self.speculative_config.num_lookahead_slots):
min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots) min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_tokens // min_tokens_per_req num_reqs = num_tokens // min_tokens_per_req
if self.ep_sp:
num_actual_tokens = round_down(num_tokens, 1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_actual_tokens // min_tokens_per_req
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
if not self.ep_sp:
num_scheduled_tokens_list[-1] += num_tokens % num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs
else:
num_scheduled_tokens_list[-1] += num_tokens % min_tokens_per_req
assert sum(num_scheduled_tokens_list) == num_tokens assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list, num_scheduled_tokens = np.array(num_scheduled_tokens_list,
...@@ -2086,7 +2106,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2086,7 +2106,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
seq_lens=seq_lens, seq_lens=seq_lens,
# seq_lens_tensor=seq_lens_tensor, # seq_lens_tensor=seq_lens_tensor,
num_reqs=num_reqs, num_reqs=num_reqs,
num_actual_tokens=num_tokens, num_actual_tokens=num_actual_tokens,
max_query_len=num_tokens, max_query_len=num_tokens,
num_speculative_tokens=num_speculative_tokens, num_speculative_tokens=num_speculative_tokens,
) )
...@@ -3100,6 +3120,17 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3100,6 +3120,17 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
spec_decode_metadata, spec_decode_metadata,
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_input_tokens)
else:
if (self.use_cuda_graph if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs. # Use piecewise CUDA graphs.
...@@ -3117,12 +3148,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3117,12 +3148,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
else: else:
num_input_tokens = num_scheduled_tokens num_input_tokens = num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
# Padding for DP # Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad num_input_tokens += num_pad
......
...@@ -7,7 +7,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata ...@@ -7,7 +7,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer
from vllm.utils import round_up
class V1ZeroEagleProposer(EagleProposer): class V1ZeroEagleProposer(EagleProposer):
...@@ -109,12 +108,6 @@ class V1ZeroEagleProposer(EagleProposer): ...@@ -109,12 +108,6 @@ class V1ZeroEagleProposer(EagleProposer):
else: else:
num_input_tokens = num_tokens num_input_tokens = num_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.vllm_config.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
......
...@@ -406,6 +406,17 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -406,6 +406,17 @@ class V1ZeroModelRunner(GPUModelRunner):
spec_decode_metadata, spec_decode_metadata,
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_input_tokens)
else:
if (self.use_cuda_graph if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs. # Use piecewise CUDA graphs.
...@@ -423,12 +434,6 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -423,12 +434,6 @@ class V1ZeroModelRunner(GPUModelRunner):
else: else:
num_input_tokens = num_scheduled_tokens num_input_tokens = num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
# Padding for DP # Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad num_input_tokens += num_pad
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment