Commit fd894e48 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-ds-wm-1222' into 'v0.9.2-dev-ds'

[fix]1.解决ep sequence parallel优化引入的mtp报错;2.解决共享专家无法和combine overlap问题

See merge request dcutoolkit/deeplearing/vllm!312
parents 8813afd8 dc01fce4
......@@ -52,7 +52,7 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
LayerBlockType, LazyLoader, common_broadcastable_dtype,
cuda_device_count_stateless, get_cpu_memory,
get_open_port, is_torch_equal_or_newer, random_uuid,
resolve_obj_by_qualname)
resolve_obj_by_qualname, round_up)
from vllm.utils import SUPPORT_TC
# yapf: enable
......@@ -4778,12 +4778,23 @@ class VllmConfig:
if size <= max_num_tokens
]
# add for ep sp
dp_size = self.parallel_config.data_parallel_size
tp_size = self.parallel_config.tensor_parallel_size
ep_sp = self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1
# add for spec decode
if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
mtp_batch_size_capture_list = list(map(lambda x: x * (1 + self.speculative_config.num_lookahead_slots),
batch_size_capture_list))
batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list))
batch_size_capture_list = [i for i in batch_size_capture_list if i == 1 or i % (1 + self.speculative_config.num_lookahead_slots) == 0]
if ep_sp:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
else:
if ep_sp:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list)
......
......@@ -809,6 +809,21 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_ids, apply_router_weight_on_input)
return output
_alt_stream: torch.cuda.Stream | None = None
def alt_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
"""
global _alt_stream
# TODO: validate this works properly on ROCm platform.
if _alt_stream is None:
_alt_stream = torch.cuda.Stream()
return _alt_stream
@final
class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
......@@ -835,6 +850,10 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self.fused_experts = fused_experts
self.shared_experts = shared_experts
if self.shared_experts is not None:
self.alt_stream = alt_stream()
self.alt_event = torch.cuda.Event()
# assert prepare_finalize.activation_format == \
# fused_experts.activation_formats[0], (
# f"{prepare_finalize.__class__.__name__}."
......@@ -863,6 +882,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input: bool = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
**_
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
......@@ -978,14 +998,27 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
)
shared_output = None
hook = self.prepare_finalize.finalize_async(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
self.alt_event.record()
current_stream = torch.cuda.current_stream()
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
if hook is not None:
hook()
shared_output = self.shared_experts(hidden_states)
with torch.cuda.stream(self.alt_stream):
self.alt_stream.wait_event(self.alt_event)
hook = None
if self.prepare_finalize.activation_format == \
FusedMoEActivationFormat.BatchedExperts:
self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
else:
hook = self.prepare_finalize.finalize_async(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
if hook is not None:
hook()
self.alt_event.record()
current_stream.wait_event(self.alt_event)
if self.shared_experts is not None:
return (shared_output, output)
......
......@@ -882,8 +882,13 @@ class DeepseekV2DecoderLayer(nn.Module):
if self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = hidden_states.tensor_split(self.tp_size)[self.tp_rank]
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
ori_bs = hidden_states.shape[0]
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states.contiguous(), [0, 0, 0, pad_size], value=0).contiguous()
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :].contiguous()
hidden_states = self.mlp(hidden_states)
......@@ -891,6 +896,7 @@ class DeepseekV2DecoderLayer(nn.Module):
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = hidden_states[:ori_bs, :]
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......
......@@ -23,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata, MLACommonDe
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
from vllm.utils import round_up
logger = init_logger(__name__)
......@@ -186,12 +185,6 @@ class EagleProposer:
else:
num_input_tokens = num_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.vllm_config.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
......
......@@ -47,7 +47,8 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
check_use_alibi, get_dtype_size,
is_pin_memory_available, round_up)
is_pin_memory_available, round_up,
round_down)
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
......@@ -325,6 +326,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {}
self.ep_sp = False
self.dp_size = self.parallel_config.data_parallel_size
self.tp_size = self.parallel_config.tensor_parallel_size
self.enable_expert_parallel = self.parallel_config.enable_expert_parallel
if self.enable_expert_parallel and self.dp_size > 1 and self.tp_size > 1:
self.ep_sp = True
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
......@@ -1314,28 +1322,33 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
spec_decode_metadata,
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_input_tokens)
else:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
else:
num_input_tokens = num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
......@@ -2036,10 +2049,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
) -> tuple[torch.Tensor, torch.Tensor]:
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_tokens = round_up(num_tokens, tp_size)
if self.ep_sp:
if num_tokens < self.tp_size:
num_tokens = self.tp_size
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
......@@ -2052,14 +2064,23 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
num_actual_tokens = num_tokens
if not is_profile and self.speculative_config is not None \
and self.speculative_config.num_lookahead_slots > 0 \
and num_tokens >= (1 + self.speculative_config.num_lookahead_slots):
min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_tokens // min_tokens_per_req
if self.ep_sp:
num_actual_tokens = round_down(num_tokens, 1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_actual_tokens // min_tokens_per_req
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
if not self.ep_sp:
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
else:
num_scheduled_tokens_list[-1] += num_tokens % min_tokens_per_req
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
......@@ -2083,7 +2104,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
seq_lens=seq_lens,
# seq_lens_tensor=seq_lens_tensor,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
num_actual_tokens=num_actual_tokens,
max_query_len=num_tokens,
num_speculative_tokens=num_speculative_tokens,
)
......@@ -3097,28 +3118,33 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
spec_decode_metadata,
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_input_tokens)
else:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
else:
num_input_tokens = num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
......
......@@ -7,7 +7,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer
from vllm.utils import round_up
class V1ZeroEagleProposer(EagleProposer):
......@@ -109,12 +108,6 @@ class V1ZeroEagleProposer(EagleProposer):
else:
num_input_tokens = num_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.vllm_config.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
......
......@@ -406,28 +406,33 @@ class V1ZeroModelRunner(GPUModelRunner):
spec_decode_metadata,
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
else:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
if self.ep_sp:
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_input_tokens)
else:
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
else:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment