Commit f28328d2 authored by 王敏
Browse files

[feat]pcp支持去掉torch compile后的精度验证

parent 6c6c9c0d
......@@ -1186,8 +1186,8 @@ class VllmConfig:
if (
self.parallel_config.tensor_parallel_size > 1
and (self.compilation_config.pass_config.enable_sp
or envs.VLLM_MLA_CP)
and (self.compilation_config.pass_config.enable_sp)
#or envs.VLLM_MLA_CP)
):
cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
cudagraph_capture_sizes
......
......@@ -209,7 +209,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
request_id (str): request id for log
"""
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()):
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()) or dst_kv_cache_layer.ndim == 3:
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
......@@ -379,7 +379,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
assert self.du_swift_engine is not None
is_mla = isinstance(attn_metadata, MLACommonMetadata)
is_mla = isinstance(attn_metadata, MLACommonMetadata) or kv_layer.ndim == 3
def extract_kv_from_layer(
layer: torch.Tensor,
......@@ -390,7 +390,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if isinstance(attn_metadata, MLACommonMetadata):
if isinstance(attn_metadata, MLACommonMetadata) or layer.ndim == 3:
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...]
......
......@@ -242,6 +242,7 @@ class ForwardContext:
scatter_indexes_tensor: torch.Tensor | None = None
gather_indexes_tensor: torch.Tensor | None = None
enable_mla_cp: bool = False
def __post_init__(self):
assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
......@@ -278,6 +279,7 @@ def create_forward_context(
skip_compiled: bool = False,
scatter_indexes_tensor: torch.Tensor | None = None,
gather_indexes_tensor: torch.Tensor | None = None,
enable_mla_cp: bool = False
):
if vllm_config.compilation_config.fast_moe_cold_start:
if vllm_config.speculative_config is None:
......@@ -305,6 +307,7 @@ def create_forward_context(
skip_compiled=skip_compiled,
scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor,
enable_mla_cp=enable_mla_cp,
additional_kwargs=additional_kwargs or {},
)
......@@ -338,6 +341,7 @@ def set_forward_context(
skip_compiled: bool = False,
scatter_indexes_tensor: torch.Tensor | None = None,
gather_indexes_tensor: torch.Tensor | None = None,
enable_mla_cp: bool = False,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
......@@ -400,6 +404,7 @@ def set_forward_context(
skip_compiled,
scatter_indexes_tensor,
gather_indexes_tensor,
enable_mla_cp
)
try:
......
......@@ -7,6 +7,7 @@ import torch
from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig
import vllm.envs as envs
from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.distributed import (
......@@ -187,7 +188,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if llama_4_scaling is not None:
q *= llama_4_scaling
enable_mla_cp = envs.VLLM_MLA_CP # and not get_forward_context().draft_model
enable_mla_cp = get_forward_context().enable_mla_cp #envs.VLLM_MLA_CP # and not get_forward_context().draft_model
# if not use_fused_rms_rope_concat:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
......
......@@ -20,6 +20,10 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import indexer_k_bf16_cache_tri
from vllm.v1.worker.workspace import current_workspace_manager
from lightop import op, gemmopt
from vllm.attention.utils.kv_transfer_utils import (
maybe_transfer_kv_layer,
)
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
......@@ -28,9 +32,10 @@ elif current_platform.is_xpu():
logger = init_logger(__name__)
@maybe_transfer_kv_layer
def sparse_attn_indexer(
hidden_states: torch.Tensor,
k_cache_prefix: str,
layer_name:str,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
......@@ -56,7 +61,7 @@ def sparse_attn_indexer(
)
return sparse_attn_indexer_fake(
hidden_states,
k_cache_prefix,
layer_name,
kv_cache,
q_fp8,
k,
......@@ -69,7 +74,7 @@ def sparse_attn_indexer(
total_seq_lens,
topk_indices_buffer,
)
attn_metadata = attn_metadata[k_cache_prefix]
attn_metadata = attn_metadata[layer_name]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping[:attn_metadata.num_kv_actual_tokens]
has_decode = attn_metadata.num_decodes > 0
......@@ -282,7 +287,7 @@ def sparse_attn_indexer(
def sparse_attn_indexer_fake(
hidden_states: torch.Tensor,
k_cache_prefix: str,
layer_name: str,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
......
......@@ -46,6 +46,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
from vllm.forward_context import get_forward_context
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
......@@ -228,7 +229,7 @@ class DeepseekV2MLP(nn.Module):
x,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
):
enable_mla_cp = envs.VLLM_MLA_CP# and not get_forward_context().draft_model
enable_mla_cp = get_forward_context().enable_mla_cp #envs.VLLM_MLA_CP# and not get_forward_context().draft_model
if enable_mla_cp:
x = tensor_model_parallel_all_gather(
x.contiguous(), 0
......@@ -249,6 +250,7 @@ class DeepseekV2MLP(nn.Module):
if enable_mla_cp:
x = tensor_model_parallel_reduce_scatter(x.contiguous(), dim=0)
return x
elif self.tp_size > 1:
x = tensor_model_parallel_all_reduce(x)
return x
......@@ -430,7 +432,7 @@ class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
enable_mla_cp = envs.VLLM_MLA_CP #and not get_forward_context().draft_model
enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP #and not get_forward_context().draft_model
if enable_mla_cp:
hidden_states = tensor_model_parallel_all_gather(
hidden_states.contiguous(), 0
......@@ -839,7 +841,7 @@ class Indexer(nn.Module):
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
enable_mla_cp = envs.VLLM_MLA_CP # and not get_forward_context().draft_model
enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp:
k = tensor_model_parallel_all_gather(
k.contiguous(), 0
......@@ -1376,7 +1378,7 @@ class DeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
enable_mla_cp = envs.VLLM_MLA_CP # and not get_forward_context().draft_model
enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp:
hidden_states_per_rank = torch.chunk(hidden_states, chunks=self.tp_size, dim=0)
hidden_states = hidden_states_per_rank[self.tp_rank].contiguous()
......
......@@ -387,6 +387,7 @@ class GPUModelRunner(
if not envs.VLLM_MLA_CPLB
else scheduler_config.max_num_seqs * 2
)
self.mla_cp_threshould = 512
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
......@@ -2027,7 +2028,7 @@ class GPUModelRunner(
if self.model_config.enable_return_routed_experts:
self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
if not envs.VLLM_MLA_CP or num_tokens <= tp_size * tp_size:
if not envs.VLLM_MLA_CP or num_tokens <= self.mla_cp_threshould:
cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
......@@ -3074,16 +3075,17 @@ class GPUModelRunner(
def _pad_for_mla_cp(self, num_scheduled_tokens: int) -> int:
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if num_scheduled_tokens <= tp_size * tp_size:
return num_scheduled_tokens * tp_size
else:
return round_up(num_scheduled_tokens, tp_size)
# if num_scheduled_tokens <= tp_size * tp_size:
# return num_scheduled_tokens
# else:
# return round_up(num_scheduled_tokens, tp_size)
return round_up(num_scheduled_tokens, tp_size)
def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
if envs.VLLM_MLA_CP:
if envs.VLLM_MLA_CP and num_scheduled_tokens > self.mla_cp_threshould:
return self._pad_for_mla_cp(num_scheduled_tokens)
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config.enable_sp and tp_size > 1:
......@@ -3781,7 +3783,7 @@ class GPUModelRunner(
)
num_tokens_padded = batch_desc.num_tokens
if envs.VLLM_MLA_CP:
if envs.VLLM_MLA_CP and num_tokens_unpadded > self.mla_cp_threshould:
num_tokens_padded = self._pad_for_mla_cp(num_tokens_unpadded)
num_reqs_padded = (
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
......@@ -3899,6 +3901,7 @@ class GPUModelRunner(
skip_compiled=has_encoder_input,
scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor,
enable_mla_cp=envs.VLLM_MLA_CP and num_tokens_unpadded > self.mla_cp_threshould,
),
record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
......@@ -4918,8 +4921,8 @@ class GPUModelRunner(
or cudagraph_runtime_mode.valid_runtime_modes()
)
if envs.VLLM_MLA_CP:
num_tokens = max(self.tp_size, num_tokens)
# if envs.VLLM_MLA_CP:
# num_tokens = max(self.tp_size, num_tokens)
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(). This means that we are using
......@@ -5125,6 +5128,7 @@ class GPUModelRunner(
batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices_padded,
slot_mapping=slot_mappings,
enable_mla_cp=envs.VLLM_MLA_CP and num_tokens_unpadded > self.mla_cp_threshould
),
):
outputs = self.model(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment