Commit f28328d2 authored by 王敏
Browse files

[feat]pcp支持去掉torch compile后的精度验证

parent 6c6c9c0d
...@@ -1186,8 +1186,8 @@ class VllmConfig: ...@@ -1186,8 +1186,8 @@ class VllmConfig:
if ( if (
self.parallel_config.tensor_parallel_size > 1 self.parallel_config.tensor_parallel_size > 1
and (self.compilation_config.pass_config.enable_sp and (self.compilation_config.pass_config.enable_sp)
or envs.VLLM_MLA_CP) #or envs.VLLM_MLA_CP)
): ):
cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism( cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
cudagraph_capture_sizes cudagraph_capture_sizes
......
...@@ -209,7 +209,7 @@ class DuSwiftConnector(KVConnectorBase_V1): ...@@ -209,7 +209,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
request_id (str): request id for log request_id (str): request id for log
""" """
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()): if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()) or dst_kv_cache_layer.ndim == 3:
num_pages = dst_kv_cache_layer_shape[0] num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1] page_size = dst_kv_cache_layer_shape[1]
dst_kv_cache_layer = dst_kv_cache_layer.reshape( dst_kv_cache_layer = dst_kv_cache_layer.reshape(
...@@ -379,7 +379,7 @@ class DuSwiftConnector(KVConnectorBase_V1): ...@@ -379,7 +379,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
assert self.du_swift_engine is not None assert self.du_swift_engine is not None
is_mla = isinstance(attn_metadata, MLACommonMetadata) is_mla = isinstance(attn_metadata, MLACommonMetadata) or kv_layer.ndim == 3
def extract_kv_from_layer( def extract_kv_from_layer(
layer: torch.Tensor, layer: torch.Tensor,
...@@ -390,7 +390,7 @@ class DuSwiftConnector(KVConnectorBase_V1): ...@@ -390,7 +390,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
Assume the shape of the layer is (2, num_pages, page_size, xxx) Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise. if MLA is not used, and (num_pages, page_size, xxx) otherwise.
""" """
if isinstance(attn_metadata, MLACommonMetadata): if isinstance(attn_metadata, MLACommonMetadata) or layer.ndim == 3:
num_pages, page_size = layer.shape[0], layer.shape[1] num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...] ...]
......
...@@ -242,6 +242,7 @@ class ForwardContext: ...@@ -242,6 +242,7 @@ class ForwardContext:
scatter_indexes_tensor: torch.Tensor | None = None scatter_indexes_tensor: torch.Tensor | None = None
gather_indexes_tensor: torch.Tensor | None = None gather_indexes_tensor: torch.Tensor | None = None
enable_mla_cp: bool = False
def __post_init__(self): def __post_init__(self):
assert self.cudagraph_runtime_mode.valid_runtime_modes(), ( assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
...@@ -278,6 +279,7 @@ def create_forward_context( ...@@ -278,6 +279,7 @@ def create_forward_context(
skip_compiled: bool = False, skip_compiled: bool = False,
scatter_indexes_tensor: torch.Tensor | None = None, scatter_indexes_tensor: torch.Tensor | None = None,
gather_indexes_tensor: torch.Tensor | None = None, gather_indexes_tensor: torch.Tensor | None = None,
enable_mla_cp: bool = False
): ):
if vllm_config.compilation_config.fast_moe_cold_start: if vllm_config.compilation_config.fast_moe_cold_start:
if vllm_config.speculative_config is None: if vllm_config.speculative_config is None:
...@@ -305,6 +307,7 @@ def create_forward_context( ...@@ -305,6 +307,7 @@ def create_forward_context(
skip_compiled=skip_compiled, skip_compiled=skip_compiled,
scatter_indexes_tensor=scatter_indexes_tensor, scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor, gather_indexes_tensor=gather_indexes_tensor,
enable_mla_cp=enable_mla_cp,
additional_kwargs=additional_kwargs or {}, additional_kwargs=additional_kwargs or {},
) )
...@@ -338,6 +341,7 @@ def set_forward_context( ...@@ -338,6 +341,7 @@ def set_forward_context(
skip_compiled: bool = False, skip_compiled: bool = False,
scatter_indexes_tensor: torch.Tensor | None = None, scatter_indexes_tensor: torch.Tensor | None = None,
gather_indexes_tensor: torch.Tensor | None = None, gather_indexes_tensor: torch.Tensor | None = None,
enable_mla_cp: bool = False,
): ):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
can be attention metadata, etc. can be attention metadata, etc.
...@@ -400,6 +404,7 @@ def set_forward_context( ...@@ -400,6 +404,7 @@ def set_forward_context(
skip_compiled, skip_compiled,
scatter_indexes_tensor, scatter_indexes_tensor,
gather_indexes_tensor, gather_indexes_tensor,
enable_mla_cp
) )
try: try:
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from vllm.attention.layer import MLAAttention from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig from vllm.config import CacheConfig
import vllm.envs as envs import vllm.envs as envs
from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.distributed import ( from vllm.distributed import (
...@@ -187,7 +188,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -187,7 +188,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if llama_4_scaling is not None: if llama_4_scaling is not None:
q *= llama_4_scaling q *= llama_4_scaling
enable_mla_cp = envs.VLLM_MLA_CP # and not get_forward_context().draft_model enable_mla_cp = get_forward_context().enable_mla_cp #envs.VLLM_MLA_CP # and not get_forward_context().draft_model
# if not use_fused_rms_rope_concat: # if not use_fused_rms_rope_concat:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
......
...@@ -20,6 +20,10 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import indexer_k_bf16_cache_tri ...@@ -20,6 +20,10 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import indexer_k_bf16_cache_tri
from vllm.v1.worker.workspace import current_workspace_manager from vllm.v1.worker.workspace import current_workspace_manager
from lightop import op, gemmopt from lightop import op, gemmopt
from vllm.attention.utils.kv_transfer_utils import (
maybe_transfer_kv_layer,
)
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
elif current_platform.is_xpu(): elif current_platform.is_xpu():
...@@ -28,9 +32,10 @@ elif current_platform.is_xpu(): ...@@ -28,9 +32,10 @@ elif current_platform.is_xpu():
logger = init_logger(__name__) logger = init_logger(__name__)
@maybe_transfer_kv_layer
def sparse_attn_indexer( def sparse_attn_indexer(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
k_cache_prefix: str, layer_name:str,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
q_fp8: torch.Tensor, q_fp8: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
...@@ -56,7 +61,7 @@ def sparse_attn_indexer( ...@@ -56,7 +61,7 @@ def sparse_attn_indexer(
) )
return sparse_attn_indexer_fake( return sparse_attn_indexer_fake(
hidden_states, hidden_states,
k_cache_prefix, layer_name,
kv_cache, kv_cache,
q_fp8, q_fp8,
k, k,
...@@ -69,7 +74,7 @@ def sparse_attn_indexer( ...@@ -69,7 +74,7 @@ def sparse_attn_indexer(
total_seq_lens, total_seq_lens,
topk_indices_buffer, topk_indices_buffer,
) )
attn_metadata = attn_metadata[k_cache_prefix] attn_metadata = attn_metadata[layer_name]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping[:attn_metadata.num_kv_actual_tokens] slot_mapping = attn_metadata.slot_mapping[:attn_metadata.num_kv_actual_tokens]
has_decode = attn_metadata.num_decodes > 0 has_decode = attn_metadata.num_decodes > 0
...@@ -282,7 +287,7 @@ def sparse_attn_indexer( ...@@ -282,7 +287,7 @@ def sparse_attn_indexer(
def sparse_attn_indexer_fake( def sparse_attn_indexer_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
k_cache_prefix: str, layer_name: str,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
q_fp8: torch.Tensor, q_fp8: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
......
...@@ -46,6 +46,7 @@ from vllm.distributed import ( ...@@ -46,6 +46,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
) )
from vllm.forward_context import get_forward_context
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter from vllm.distributed.communication_op import tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -228,7 +229,7 @@ class DeepseekV2MLP(nn.Module): ...@@ -228,7 +229,7 @@ class DeepseekV2MLP(nn.Module):
x, x,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
): ):
enable_mla_cp = envs.VLLM_MLA_CP# and not get_forward_context().draft_model enable_mla_cp = get_forward_context().enable_mla_cp #envs.VLLM_MLA_CP# and not get_forward_context().draft_model
if enable_mla_cp: if enable_mla_cp:
x = tensor_model_parallel_all_gather( x = tensor_model_parallel_all_gather(
x.contiguous(), 0 x.contiguous(), 0
...@@ -249,6 +250,7 @@ class DeepseekV2MLP(nn.Module): ...@@ -249,6 +250,7 @@ class DeepseekV2MLP(nn.Module):
if enable_mla_cp: if enable_mla_cp:
x = tensor_model_parallel_reduce_scatter(x.contiguous(), dim=0) x = tensor_model_parallel_reduce_scatter(x.contiguous(), dim=0)
return x
elif self.tp_size > 1: elif self.tp_size > 1:
x = tensor_model_parallel_all_reduce(x) x = tensor_model_parallel_all_reduce(x)
return x return x
...@@ -430,7 +432,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -430,7 +432,7 @@ class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
enable_mla_cp = envs.VLLM_MLA_CP #and not get_forward_context().draft_model enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP #and not get_forward_context().draft_model
if enable_mla_cp: if enable_mla_cp:
hidden_states = tensor_model_parallel_all_gather( hidden_states = tensor_model_parallel_all_gather(
hidden_states.contiguous(), 0 hidden_states.contiguous(), 0
...@@ -839,7 +841,7 @@ class Indexer(nn.Module): ...@@ -839,7 +841,7 @@ class Indexer(nn.Module):
# `k_pe` is [num_tokens, 1, rope_dim] (MQA). # `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1) k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
enable_mla_cp = envs.VLLM_MLA_CP # and not get_forward_context().draft_model enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp: if enable_mla_cp:
k = tensor_model_parallel_all_gather( k = tensor_model_parallel_all_gather(
k.contiguous(), 0 k.contiguous(), 0
...@@ -1376,7 +1378,7 @@ class DeepseekV2Model(nn.Module): ...@@ -1376,7 +1378,7 @@ class DeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
enable_mla_cp = envs.VLLM_MLA_CP # and not get_forward_context().draft_model enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp: if enable_mla_cp:
hidden_states_per_rank = torch.chunk(hidden_states, chunks=self.tp_size, dim=0) hidden_states_per_rank = torch.chunk(hidden_states, chunks=self.tp_size, dim=0)
hidden_states = hidden_states_per_rank[self.tp_rank].contiguous() hidden_states = hidden_states_per_rank[self.tp_rank].contiguous()
......
...@@ -387,6 +387,7 @@ class GPUModelRunner( ...@@ -387,6 +387,7 @@ class GPUModelRunner(
if not envs.VLLM_MLA_CPLB if not envs.VLLM_MLA_CPLB
else scheduler_config.max_num_seqs * 2 else scheduler_config.max_num_seqs * 2
) )
self.mla_cp_threshould = 512
# Broadcast PP output for external_launcher (torchrun) # Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks # to make sure we are synced across pp ranks
...@@ -2027,7 +2028,7 @@ class GPUModelRunner( ...@@ -2027,7 +2028,7 @@ class GPUModelRunner(
if self.model_config.enable_return_routed_experts: if self.model_config.enable_return_routed_experts:
self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy() self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
if not envs.VLLM_MLA_CP or num_tokens <= tp_size * tp_size: if not envs.VLLM_MLA_CP or num_tokens <= self.mla_cp_threshould:
cm_base = CommonAttentionMetadata( cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
...@@ -3074,16 +3075,17 @@ class GPUModelRunner( ...@@ -3074,16 +3075,17 @@ class GPUModelRunner(
def _pad_for_mla_cp(self, num_scheduled_tokens: int) -> int: def _pad_for_mla_cp(self, num_scheduled_tokens: int) -> int:
tp_size = self.vllm_config.parallel_config.tensor_parallel_size tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if num_scheduled_tokens <= tp_size * tp_size: # if num_scheduled_tokens <= tp_size * tp_size:
return num_scheduled_tokens * tp_size # return num_scheduled_tokens
else: # else:
# return round_up(num_scheduled_tokens, tp_size)
return round_up(num_scheduled_tokens, tp_size) return round_up(num_scheduled_tokens, tp_size)
def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int: def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
# Pad tokens to multiple of tensor_parallel_size when # Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP # enabled collective fusion for SP
if envs.VLLM_MLA_CP: if envs.VLLM_MLA_CP and num_scheduled_tokens > self.mla_cp_threshould:
return self._pad_for_mla_cp(num_scheduled_tokens) return self._pad_for_mla_cp(num_scheduled_tokens)
tp_size = self.vllm_config.parallel_config.tensor_parallel_size tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config.enable_sp and tp_size > 1: if self.compilation_config.pass_config.enable_sp and tp_size > 1:
...@@ -3781,7 +3783,7 @@ class GPUModelRunner( ...@@ -3781,7 +3783,7 @@ class GPUModelRunner(
) )
num_tokens_padded = batch_desc.num_tokens num_tokens_padded = batch_desc.num_tokens
if envs.VLLM_MLA_CP: if envs.VLLM_MLA_CP and num_tokens_unpadded > self.mla_cp_threshould:
num_tokens_padded = self._pad_for_mla_cp(num_tokens_unpadded) num_tokens_padded = self._pad_for_mla_cp(num_tokens_unpadded)
num_reqs_padded = ( num_reqs_padded = (
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
...@@ -3899,6 +3901,7 @@ class GPUModelRunner( ...@@ -3899,6 +3901,7 @@ class GPUModelRunner(
skip_compiled=has_encoder_input, skip_compiled=has_encoder_input,
scatter_indexes_tensor=scatter_indexes_tensor, scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor, gather_indexes_tensor=gather_indexes_tensor,
enable_mla_cp=envs.VLLM_MLA_CP and num_tokens_unpadded > self.mla_cp_threshould,
), ),
record_function_or_nullcontext("gpu_model_runner: forward"), record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
...@@ -4918,8 +4921,8 @@ class GPUModelRunner( ...@@ -4918,8 +4921,8 @@ class GPUModelRunner(
or cudagraph_runtime_mode.valid_runtime_modes() or cudagraph_runtime_mode.valid_runtime_modes()
) )
if envs.VLLM_MLA_CP: # if envs.VLLM_MLA_CP:
num_tokens = max(self.tp_size, num_tokens) # num_tokens = max(self.tp_size, num_tokens)
# If cudagraph_mode.decode_mode() == FULL and # If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(). This means that we are using # cudagraph_mode.separate_routine(). This means that we are using
...@@ -5125,6 +5128,7 @@ class GPUModelRunner( ...@@ -5125,6 +5128,7 @@ class GPUModelRunner(
batch_descriptor=batch_desc, batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices_padded, ubatch_slices=ubatch_slices_padded,
slot_mapping=slot_mappings, slot_mapping=slot_mappings,
enable_mla_cp=envs.VLLM_MLA_CP and num_tokens_unpadded > self.mla_cp_threshould
), ),
): ):
outputs = self.model( outputs = self.model(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment