Commit aef3c487 authored by wangmin6's avatar wangmin6 Committed by zhangzbb
Browse files

[Feature]添加PCP功能,只支持mla架构,CPLB待验证

parent c1819454
...@@ -299,6 +299,13 @@ class ParallelConfig: ...@@ -299,6 +299,13 @@ class ParallelConfig:
should only be set by API server scale-out. should only be set by API server scale-out.
""" """
enable_lightly_cp: bool = False
"""Use lightly context parallel."""
enable_lightly_cplb: bool = False
"""Use lightly context parallel load balancing."""
@field_validator("disable_nccl_for_dp_synchronization", mode="wrap") @field_validator("disable_nccl_for_dp_synchronization", mode="wrap")
@classmethod @classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
......
...@@ -1061,6 +1061,12 @@ class VllmConfig: ...@@ -1061,6 +1061,12 @@ class VllmConfig:
# Handle the KV connector configs # Handle the KV connector configs
self._post_init_kv_transfer_config() self._post_init_kv_transfer_config()
if self.parallel_config.enable_lightly_cp and not self.model_config.enforce_eager:
raise ValueError(
"Lightly context parallel currently only supports the eager mode!!!"
)
def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
# remove the sizes that not multiple of tp_size when # remove the sizes that not multiple of tp_size when
# enable sequence parallelism # enable sequence parallelism
......
...@@ -585,6 +585,9 @@ class EngineArgs: ...@@ -585,6 +585,9 @@ class EngineArgs:
kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
tokens_only: bool = False tokens_only: bool = False
enable_lightly_cp: bool = ParallelConfig.enable_lightly_cp
enable_lightly_cplb: bool = ParallelConfig.enable_lightly_cplb
def __post_init__(self): def __post_init__(self):
# support `EngineArgs(compilation_config={...})` # support `EngineArgs(compilation_config={...})`
# without having to manually construct a # without having to manually construct a
...@@ -902,6 +905,15 @@ class EngineArgs: ...@@ -902,6 +905,15 @@ class EngineArgs:
"--worker-extension-cls", **parallel_kwargs["worker_extension_cls"] "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
) )
parallel_group.add_argument(
"--enable-lightly-cp",
**parallel_kwargs["enable_lightly_cp"],
)
parallel_group.add_argument(
"--enable-lightly-cplb",
**parallel_kwargs["enable_lightly_cplb"],
)
# KV cache arguments # KV cache arguments
cache_kwargs = get_kwargs(CacheConfig) cache_kwargs = get_kwargs(CacheConfig)
cache_group = parser.add_argument_group( cache_group = parser.add_argument_group(
...@@ -1661,6 +1673,8 @@ class EngineArgs: ...@@ -1661,6 +1673,8 @@ class EngineArgs:
cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size, cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
_api_process_count=self._api_process_count, _api_process_count=self._api_process_count,
_api_process_rank=self._api_process_rank, _api_process_rank=self._api_process_rank,
enable_lightly_cp=self.enable_lightly_cp,
enable_lightly_cplb=self.enable_lightly_cplb,
) )
speculative_config = self.create_speculative_config( speculative_config = self.create_speculative_config(
......
...@@ -324,6 +324,9 @@ if TYPE_CHECKING: ...@@ -324,6 +324,9 @@ if TYPE_CHECKING:
USE_LIGHTOP_TOPK: bool = False USE_LIGHTOP_TOPK: bool = False
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX: bool = False USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX: bool = False
VLLM_DISABLE_DSA: bool = False VLLM_DISABLE_DSA: bool = False
VLLM_LIGHTLY_CP_THRESHOULD: int = 2048
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
"XDG_CACHE_HOME", "XDG_CACHE_HOME",
...@@ -2005,6 +2008,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -2005,6 +2008,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DISABLE_DSA": "VLLM_DISABLE_DSA":
lambda: (os.environ.get("VLLM_DISABLE_DSA", "False").lower() in lambda: (os.environ.get("VLLM_DISABLE_DSA", "False").lower() in
("true", "1")), ("true", "1")),
# MLA_CP open threshold
"VLLM_LIGHTLY_CP_THRESHOULD":
lambda: int(os.getenv("VLLM_LIGHTLY_CP_THRESHOULD", "2048")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -240,6 +240,11 @@ class ForwardContext: ...@@ -240,6 +240,11 @@ class ForwardContext:
additional_kwargs: dict[str, Any] = field(default_factory=dict) additional_kwargs: dict[str, Any] = field(default_factory=dict)
scatter_indexes_tensor: torch.Tensor | None = None
gather_indexes_tensor: torch.Tensor | None = None
enable_lightly_cp: bool = False
enable_lightly_cplb : bool = False
def __post_init__(self): def __post_init__(self):
assert self.cudagraph_runtime_mode.valid_runtime_modes(), ( assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}" f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
...@@ -273,6 +278,10 @@ def create_forward_context( ...@@ -273,6 +278,10 @@ def create_forward_context(
slot_mapping: dict[str, torch.Tensor] | None = None, slot_mapping: dict[str, torch.Tensor] | None = None,
additional_kwargs: dict[str, Any] | None = None, additional_kwargs: dict[str, Any] | None = None,
skip_compiled: bool = False, skip_compiled: bool = False,
scatter_indexes_tensor: torch.Tensor | None = None,
gather_indexes_tensor: torch.Tensor | None = None,
enable_lightly_cp: bool = False,
enable_lightly_cplb: bool = False
): ):
if vllm_config.compilation_config.fast_moe_cold_start: if vllm_config.compilation_config.fast_moe_cold_start:
if vllm_config.speculative_config is None: if vllm_config.speculative_config is None:
...@@ -298,6 +307,10 @@ def create_forward_context( ...@@ -298,6 +307,10 @@ def create_forward_context(
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices, ubatch_slices=ubatch_slices,
skip_compiled=skip_compiled, skip_compiled=skip_compiled,
scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor,
enable_lightly_cp=enable_lightly_cp,
enable_lightly_cplb=enable_lightly_cplb,
additional_kwargs=additional_kwargs or {}, additional_kwargs=additional_kwargs or {},
) )
...@@ -329,6 +342,10 @@ def set_forward_context( ...@@ -329,6 +342,10 @@ def set_forward_context(
ubatch_slices: UBatchSlices | None = None, ubatch_slices: UBatchSlices | None = None,
slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
skip_compiled: bool = False, skip_compiled: bool = False,
scatter_indexes_tensor: torch.Tensor | None = None,
gather_indexes_tensor: torch.Tensor | None = None,
enable_lightly_cp: bool = False,
enable_lightly_cplb: bool = False,
): ):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
can be attention metadata, etc. can be attention metadata, etc.
...@@ -390,6 +407,10 @@ def set_forward_context( ...@@ -390,6 +407,10 @@ def set_forward_context(
slot_mapping, slot_mapping,
additional_kwargs, additional_kwargs,
skip_compiled, skip_compiled,
scatter_indexes_tensor,
gather_indexes_tensor,
enable_lightly_cp,
enable_lightly_cplb
) )
try: try:
......
...@@ -7,8 +7,12 @@ import torch ...@@ -7,8 +7,12 @@ import torch
from vllm.attention.layer import MLAAttention from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig from vllm.config import CacheConfig
import vllm.envs as envs import vllm.envs as envs
from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.distributed import (
tensor_model_parallel_all_gather,
)
@dataclass @dataclass
...@@ -184,8 +188,26 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -184,8 +188,26 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if llama_4_scaling is not None: if llama_4_scaling is not None:
q *= llama_4_scaling q *= llama_4_scaling
enable_lightly_cp = get_forward_context().enable_lightly_cp
enable_lightly_cplb = get_forward_context().enable_lightly_cplb
# if not use_fused_rms_rope_concat: # if not use_fused_rms_rope_concat:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
if enable_lightly_cp:
kv_c_normed = tensor_model_parallel_all_gather(
kv_c_normed.contiguous(), 0
)
k_pe = tensor_model_parallel_all_gather(
k_pe.contiguous(), 0
)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if enable_lightly_cplb and gather_indexes_tensor is not None:
# Reorder kv after pcp allgather.
kv_c_normed = torch.index_select(kv_c_normed, 0, gather_indexes_tensor)
k_pe = torch.index_select(k_pe, 0, gather_indexes_tensor)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q, q,
kv_c_normed, kv_c_normed,
...@@ -221,6 +243,20 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -221,6 +243,20 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires rotary_emb to " "VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires rotary_emb to "
"expose 'cos_sin_cache'." "expose 'cos_sin_cache'."
) )
if enable_lightly_cp:
kv_c = tensor_model_parallel_all_gather(
kv_c.contiguous(), 0
)
k_pe = tensor_model_parallel_all_gather(
k_pe.contiguous(), 0
)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if enable_lightly_cplb and gather_indexes_tensor is not None:
# Reorder kv after pcp allgather.
kv_c = torch.index_select(kv_c, 0, gather_indexes_tensor)
k_pe = torch.index_select(k_pe, 0, gather_indexes_tensor)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:], q[..., self.qk_nope_head_dim:],
kv_c, kv_c,
......
...@@ -90,7 +90,7 @@ def sparse_attn_indexer( ...@@ -90,7 +90,7 @@ def sparse_attn_indexer(
) )
attn_metadata = attn_metadata[layer_name] attn_metadata = attn_metadata[layer_name]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping slot_mapping = attn_metadata.slot_mapping[:attn_metadata.num_kv_actual_tokens]
has_decode = attn_metadata.num_decodes > 0 has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
......
...@@ -11,6 +11,7 @@ import torch ...@@ -11,6 +11,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.forward_context import get_forward_context
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -36,6 +37,9 @@ from .deepseek_v2 import ( ...@@ -36,6 +37,9 @@ from .deepseek_v2 import (
DeepseekV2MoE, DeepseekV2MoE,
get_spec_layer_idx_from_weight_name, get_spec_layer_idx_from_weight_name,
) )
from vllm.distributed import (tensor_model_parallel_all_gather,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from .utils import maybe_prefix from .utils import maybe_prefix
from .interfaces import SupportsPP from .interfaces import SupportsPP
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -177,6 +181,9 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -177,6 +181,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -191,7 +198,28 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -191,7 +198,28 @@ class DeepSeekMultiTokenPredictor(nn.Module):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = spec_step_idx % self.num_mtp_layers current_step_idx = spec_step_idx % self.num_mtp_layers
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
scatter_indexes_tensor = get_forward_context().scatter_indexes_tensor
if scatter_indexes_tensor is None:
inputs_embeds_per_rank = torch.chunk(inputs_embeds, chunks=self.tp_size, dim=0)
inputs_embeds = inputs_embeds_per_rank[self.tp_rank].contiguous()
previous_hidden_states_per_rank = torch.chunk(previous_hidden_states, chunks=self.tp_size, dim=0)
previous_hidden_states = previous_hidden_states_per_rank[self.tp_rank].contiguous()
if positions is not None:
positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0)
positions = positions_per_rank[self.tp_rank].contiguous()
else:
scatter_indexes_tensor = torch.where(scatter_indexes_tensor == -1, 0, scatter_indexes_tensor)
inputs_embeds = torch.index_select(inputs_embeds, 0, scatter_indexes_tensor)
previous_hidden_states = torch.index_select(previous_hidden_states, 0, scatter_indexes_tensor)
if positions is not None:
positions = torch.index_select(positions, 0, scatter_indexes_tensor)
hidden_states = self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids, input_ids,
positions, positions,
previous_hidden_states, previous_hidden_states,
...@@ -199,6 +227,14 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -199,6 +227,14 @@ class DeepSeekMultiTokenPredictor(nn.Module):
current_step_idx, current_step_idx,
) )
if enable_lightly_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if gather_indexes_tensor is not None:
hidden_states = torch.index_select(hidden_states, 0, gather_indexes_tensor)
return hidden_states
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -46,6 +46,8 @@ from vllm.distributed import ( ...@@ -46,6 +46,8 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
) )
from vllm.forward_context import get_forward_context
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
...@@ -181,6 +183,44 @@ class DeepseekAttention(nn.Module): ...@@ -181,6 +183,44 @@ class DeepseekAttention(nn.Module):
return output return output
def iqis_all_gather(
iqis: tuple[torch.Tensor, torch.Tensor],
tp_size: int | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
assert iqis is not None
iq_tensor, is_tensor = iqis
assert isinstance(iq_tensor, torch.Tensor)
assert isinstance(is_tensor, torch.Tensor)
assert iq_tensor.dtype == torch.int8, f"iq_tensor dtype is {iq_tensor.dtype}"
assert is_tensor.dtype == torch.float32, f"is_tensor dtype is {is_tensor.dtype}"
assert iq_tensor.dim() == 2
assert is_tensor.dim() == 2
m_local, n = iq_tensor.shape
assert is_tensor.shape[0] == m_local, f"{is_tensor.shape[0]} != {iq_tensor.shape[0]}"
assert is_tensor.shape[1] == 1, f"is_tensor dim 1 ={is_tensor.shape[1]}"
iq_int8_2d = iq_tensor.view(torch.int8)
is_int8_2d = is_tensor.view(torch.int8)
combined_2d = torch.cat([iq_int8_2d, is_int8_2d], dim=1) # [m_local, n + 4]
if not combined_2d.is_contiguous():
combined_2d = combined_2d.contiguous()
combined_gathered = tensor_model_parallel_all_gather(combined_2d, dim=0)
split_idx = n
iq_gathered_int8 = combined_gathered[:, :split_idx].contiguous()
is_gathered_int8 = combined_gathered[:, split_idx:].contiguous()
iq_gathered = iq_gathered_int8.view(torch.int8)
assert iq_gathered.shape[0] == m_local * tp_size, f"iq_gathered dim0= {iq_gathered.shape[0]}, expected {m_local * tp_size}"
# is_gathered_int8 should be [m_local*tp_size, 4]
assert is_gathered_int8.shape[0] == m_local * tp_size, f"is_gathered_int8 dim0= {is_gathered_int8.shape[0]}, expected {m_local * tp_size}"
assert is_gathered_int8.shape[1] == 4, f"is_gathered_int8 dim1= {is_gathered_int8.shape[1]}"
is_gathered = is_gathered_int8.view(torch.float32)
return (iq_gathered, is_gathered)
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
def __init__( def __init__(
self, self,
...@@ -211,10 +251,85 @@ class DeepseekV2MLP(nn.Module): ...@@ -211,10 +251,85 @@ class DeepseekV2MLP(nn.Module):
hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
reduce_results=reduce_results, #reduce_results=reduce_results,
reduce_results=False,
disable_tp=is_sequence_parallel, disable_tp=is_sequence_parallel,
prefix=f"{prefix}.down_proj", prefix=f"{prefix}.down_proj",
) )
self.tp_size = get_tensor_model_parallel_world_size()
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self,
x,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
):
enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
if iqis is not None and iqis[0] is not None and iqis[1] is not None:
iqis = iqis_all_gather(iqis, tp_size=self.tp_size)
else:
x = tensor_model_parallel_all_gather(x.contiguous(), 0)
if envs.USE_FUSED_RMS_QUANT:
gate_up, _ = self.gate_up_proj(x, iqis=iqis)
if envs.USE_FUSED_SILU_MUL_QUANT:
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
xq, xs = lm_fuse_silu_mul_quant(gate_up)
x, _ = self.down_proj(gate_up, iqis=(xq, xs))
else:
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
else:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
if enable_lightly_cp:
x = tensor_model_parallel_reduce_scatter(x.contiguous(), dim=0)
return x
elif self.tp_size > 1:
x = tensor_model_parallel_all_reduce(x)
return x
class DeepseekV2SharedMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
reduce_results: bool = True,
is_sequence_parallel=False,
prefix: str = "",
) -> None:
super().__init__()
# If is_sequence_parallel, the input and output tensors are sharded
# across the ranks within the tp_group. In this case the weights are
# replicated and no collective ops are needed.
# Otherwise we use standard TP with an allreduce at the end.
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.down_proj"
)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError( raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now." f"Unsupported activation: {hidden_act}. Only silu is supported for now."
...@@ -311,7 +426,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -311,7 +426,7 @@ class DeepseekV2MoE(nn.Module):
else: else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2MLP( self.shared_experts = DeepseekV2SharedMLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
...@@ -357,6 +472,11 @@ class DeepseekV2MoE(nn.Module): ...@@ -357,6 +472,11 @@ class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
hidden_states = tensor_model_parallel_all_gather(
hidden_states.contiguous(), 0
)
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
...@@ -428,7 +548,12 @@ class DeepseekV2MoE(nn.Module): ...@@ -428,7 +548,12 @@ class DeepseekV2MoE(nn.Module):
assert shared_output is not None assert shared_output is not None
final_hidden_states += shared_output final_hidden_states += shared_output
if self.is_sequence_parallel: if enable_lightly_cp:
final_hidden_states = tensor_model_parallel_reduce_scatter(
final_hidden_states.contiguous(), 0
)
return final_hidden_states
elif self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0 final_hidden_states, 0
) )
...@@ -759,6 +884,16 @@ class Indexer(nn.Module): ...@@ -759,6 +884,16 @@ class Indexer(nn.Module):
# `k_pe` is [num_tokens, 1, rope_dim] (MQA). # `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1) k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
k = tensor_model_parallel_all_gather(
k.contiguous(), 0
)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
enable_lightly_cplb = get_forward_context().enable_lightly_cplb
if enable_lightly_cplb and gather_indexes_tensor is not None:
k = torch.index_select(k, 0, gather_indexes_tensor)
# we only quant q here since k quant is fused with cache insertion # we only quant q here since k quant is fused with cache insertion
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938": if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
q = q.view(-1, self.head_dim) q = q.view(-1, self.head_dim)
...@@ -825,7 +960,8 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -825,7 +960,8 @@ class DeepseekV2MLAAttention(nn.Module):
self.num_heads = num_heads self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0 assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size self.num_local_heads = num_heads // tp_size if not \
vllm_config.parallel_config.enable_lightly_cp else self.num_heads
self.scaling = self.qk_head_dim**-0.5 self.scaling = self.qk_head_dim**-0.5
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
...@@ -859,6 +995,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -859,6 +995,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.q_b_proj", prefix=f"{prefix}.q_b_proj",
disable_tp=vllm_config.parallel_config.enable_lightly_cp
) )
else: else:
self.q_proj = ColumnParallelLinear( self.q_proj = ColumnParallelLinear(
...@@ -867,6 +1004,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -867,6 +1004,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.q_proj", prefix=f"{prefix}.q_proj",
disable_tp=vllm_config.parallel_config.enable_lightly_cp,
) )
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear( self.kv_b_proj = ColumnParallelLinear(
...@@ -875,6 +1013,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -875,6 +1013,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj", prefix=f"{prefix}.kv_b_proj",
disable_tp=vllm_config.parallel_config.enable_lightly_cp,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim, self.num_heads * self.v_head_dim,
...@@ -882,6 +1021,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -882,6 +1021,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
disable_tp=vllm_config.parallel_config.enable_lightly_cp,
) )
if config.rope_parameters["rope_type"] != "default": if config.rope_parameters["rope_type"] != "default":
...@@ -1118,6 +1258,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1118,6 +1258,7 @@ class DeepseekV2DecoderLayer(nn.Module):
residual *= 1.0 / self.routed_scaling_factor residual *= 1.0 / self.routed_scaling_factor
# Fully Connected # Fully Connected
enable_lightly_cp = get_forward_context().enable_lightly_cp
update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False
assert self.post_attention_layernorm.has_weight is True assert self.post_attention_layernorm.has_weight is True
_i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states, _i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states,
...@@ -1126,9 +1267,10 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1126,9 +1267,10 @@ class DeepseekV2DecoderLayer(nn.Module):
update_input=update_hs update_input=update_hs
) )
new_resi = residual new_resi = residual
hidden_states = self.mlp(hidden_states, if enable_lightly_cp and isinstance(self.mlp, DeepseekV2MoE):
iqis=(_i_q, _i_s) hidden_states = self.mlp(hidden_states)
) else:
hidden_states = self.mlp(hidden_states, iqis=(_i_q, _i_s))
if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow # Fix FP16 overflow
...@@ -1225,6 +1367,9 @@ class DeepseekV2Model(nn.Module): ...@@ -1225,6 +1367,9 @@ class DeepseekV2Model(nn.Module):
self.config = config self.config = config
self.device = current_platform.device_type self.device = current_platform.device_type
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
#添加判断,默认开启DSA #添加判断,默认开启DSA
force_disable_dsa = envs.VLLM_DISABLE_DSA force_disable_dsa = envs.VLLM_DISABLE_DSA
...@@ -1287,6 +1432,30 @@ class DeepseekV2Model(nn.Module): ...@@ -1287,6 +1432,30 @@ class DeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
scatter_indexes_tensor = get_forward_context().scatter_indexes_tensor
if scatter_indexes_tensor is None:
hidden_states_per_rank = torch.chunk(hidden_states, chunks=self.tp_size, dim=0)
hidden_states = hidden_states_per_rank[self.tp_rank].contiguous()
if residual is not None:
residual_per_rank = torch.chunk(residual, chunks=self.tp_size, dim=0)
residual = residual_per_rank[self.tp_rank].contiguous()
if positions is not None:
positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0)
positions = positions_per_rank[self.tp_rank].contiguous()
else:
scatter_indexes_tensor = torch.where(scatter_indexes_tensor == -1, 0, scatter_indexes_tensor)
hidden_states = torch.index_select(hidden_states, 0, scatter_indexes_tensor)
if residual is not None:
residual = torch.index_select(residual, 0, scatter_indexes_tensor)
if positions is not None:
positions = torch.index_select(positions, 0, scatter_indexes_tensor)
# Compute llama 4 scaling once per forward pass if enabled # Compute llama 4 scaling once per forward pass if enabled
llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None) llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
llama_4_scaling: torch.Tensor | None llama_4_scaling: torch.Tensor | None
...@@ -1307,11 +1476,21 @@ class DeepseekV2Model(nn.Module): ...@@ -1307,11 +1476,21 @@ class DeepseekV2Model(nn.Module):
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
if enable_lightly_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
residual = tensor_model_parallel_all_gather(residual.contiguous(), dim=0)
return IntermediateTensors( return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual} {"hidden_states": hidden_states, "residual": residual}
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
if enable_lightly_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if gather_indexes_tensor is not None:
hidden_states = torch.index_select(hidden_states, 0, gather_indexes_tensor)
return hidden_states return hidden_states
......
...@@ -282,6 +282,35 @@ class AttentionMetadata: ...@@ -282,6 +282,35 @@ class AttentionMetadata:
T = TypeVar("T", bound=AttentionMetadata) T = TypeVar("T", bound=AttentionMetadata)
@dataclass
class CpCommonAttentionMetadata:
# sp related metadata
query_start_loc: torch.Tensor
query_start_loc_cpu: torch.Tensor
seq_lens: torch.Tensor
_seq_lens_cpu: torch.Tensor
num_actual_tokens: int
num_kv_actual_tokens: int
max_query_len: int
max_seq_len: int
num_reqs: int
req_ids: list[str]
block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor
_num_computed_tokens_cpu: torch.Tensor
dcp_local_seq_lens: torch.Tensor | None = None
dcp_local_seq_lens_cpu: torch.Tensor | None = None
def batch_size(self) -> int:
return self.seq_lens.shape[0]
@property
def seq_lens_cpu(self) -> torch.Tensor:
if self._seq_lens_cpu is None:
self._seq_lens_cpu = self.seq_lens.to("cpu")
return self._seq_lens_cpu
@dataclass @dataclass
class CommonAttentionMetadata: class CommonAttentionMetadata:
...@@ -312,6 +341,14 @@ class CommonAttentionMetadata: ...@@ -312,6 +341,14 @@ class CommonAttentionMetadata:
block_table_tensor: torch.Tensor block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
num_kv_actual_tokens: int | None = None
seq_indexes_list: list[int] | None = None
scatter_indexes_tensor: torch.Tensor | None = None
gather_indexes_tensor: torch.Tensor | None = None
cp_common_metadata: CpCommonAttentionMetadata | None = None
enable_lightly_cp: bool = False
causal: bool = True causal: bool = True
# Needed by FastPrefillAttentionBuilder # Needed by FastPrefillAttentionBuilder
...@@ -396,6 +433,7 @@ class CommonAttentionMetadata: ...@@ -396,6 +433,7 @@ class CommonAttentionMetadata:
else None, else None,
num_reqs=num_actual_reqs, num_reqs=num_actual_reqs,
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
num_kv_actual_tokens=num_actual_tokens,
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
max_seq_len=self.max_seq_len, max_seq_len=self.max_seq_len,
block_table_tensor=self.block_table_tensor[:num_actual_reqs], block_table_tensor=self.block_table_tensor[:num_actual_reqs],
......
...@@ -138,6 +138,7 @@ class FlashMLASparseMetadata(AttentionMetadata): ...@@ -138,6 +138,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
max_seq_len: int max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding. num_actual_tokens: int # Number of tokens excluding padding.
num_kv_actual_tokens: int
query_start_loc: torch.Tensor query_start_loc: torch.Tensor
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
...@@ -693,6 +694,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad ...@@ -693,6 +694,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
max_query_len=cm.max_query_len, max_query_len=cm.max_query_len,
max_seq_len=cm.max_seq_len, max_seq_len=cm.max_seq_len,
num_actual_tokens=cm.num_actual_tokens, num_actual_tokens=cm.num_actual_tokens,
num_kv_actual_tokens=cm.num_kv_actual_tokens,
query_start_loc=cm.query_start_loc, query_start_loc=cm.query_start_loc,
slot_mapping=cm.slot_mapping, slot_mapping=cm.slot_mapping,
block_table=cm.block_table_tensor, block_table=cm.block_table_tensor,
...@@ -1024,12 +1026,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ...@@ -1024,12 +1026,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
return output.fill_(0) return output.fill_(0)
num_actual_toks = attn_metadata.num_actual_tokens num_actual_toks = attn_metadata.num_actual_tokens
num_kv_actual_toks = attn_metadata.num_kv_actual_tokens
# Inputs and outputs may be padded for CUDA graphs # Inputs and outputs may be padded for CUDA graphs
q = q[:num_actual_toks, ...] q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...] k_c_normed = k_c_normed[:num_kv_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...] k_pe = k_pe[:num_kv_actual_toks, ...]
assert self.topk_indices_buffer is not None assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[:num_actual_toks] topk_indices = self.topk_indices_buffer[:num_actual_toks]
......
...@@ -105,6 +105,7 @@ class DeepseekV32IndexerMetadata: ...@@ -105,6 +105,7 @@ class DeepseekV32IndexerMetadata:
max_seq_len: int max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding. num_actual_tokens: int # Number of tokens excluding padding.
num_kv_actual_tokens: int
query_start_loc: torch.Tensor query_start_loc: torch.Tensor
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
# The dimension of the attention heads # The dimension of the attention heads
...@@ -438,6 +439,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -438,6 +439,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
max_query_len=common_attn_metadata.max_query_len, max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len, max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens, num_actual_tokens=common_attn_metadata.num_actual_tokens,
num_kv_actual_tokens=common_attn_metadata.num_kv_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc, query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping, slot_mapping=common_attn_metadata.slot_mapping,
head_dim=128, head_dim=128,
......
...@@ -14,7 +14,7 @@ from vllm.config import ( ...@@ -14,7 +14,7 @@ from vllm.config import (
VllmConfig, VllmConfig,
get_layers_from_vllm_config, get_layers_from_vllm_config,
) )
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group, get_tensor_model_parallel_rank
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
...@@ -29,6 +29,7 @@ from vllm.utils.platform_utils import is_pin_memory_available ...@@ -29,6 +29,7 @@ from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
CpCommonAttentionMetadata,
) )
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.tree_attn import ( from vllm.v1.attention.backends.tree_attn import (
...@@ -48,6 +49,7 @@ from vllm.v1.spec_decode.utils import ( ...@@ -48,6 +49,7 @@ from vllm.v1.spec_decode.utils import (
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.utils.math_utils import cdiv, round_up
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -76,7 +78,9 @@ class SpecDecodeBaseProposer: ...@@ -76,7 +78,9 @@ class SpecDecodeBaseProposer:
self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
# The drafter can get longer sequences than the target model. # The drafter can get longer sequences than the target model.
max_batch_size = vllm_config.scheduler_config.max_num_seqs max_batch_size = vllm_config.scheduler_config.max_num_seqs if not \
vllm_config.parallel_config.enable_lightly_cplb \
else vllm_config.scheduler_config.max_num_seqs * 2
self.max_num_tokens = ( self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
) )
...@@ -219,6 +223,28 @@ class SpecDecodeBaseProposer: ...@@ -219,6 +223,28 @@ class SpecDecodeBaseProposer:
1, len(self.tree_choices) + 1, device=device, dtype=torch.int32 1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
).repeat(max_batch_size, 1) ).repeat(max_batch_size, 1)
self.scatter_indexes_tensor = None
self.gather_indexes_tensor = None
self.enable_lightly_cp = vllm_config.parallel_config.enable_lightly_cp
self.enable_lightly_cplb = self.enable_lightly_cp and vllm_config.parallel_config.enable_lightly_cplb
if self.enable_lightly_cp:
self.query_start_loc = CpuGpuBuffer(
max_batch_size + 1,
dtype=torch.int32,
pin_memory=is_pin_memory_available(),
device=device,
with_numpy=True,
)
self.seq_lens = CpuGpuBuffer(
max_batch_size,
dtype=torch.int32,
pin_memory=is_pin_memory_available(),
device=device,
with_numpy=True,
)
def _get_positions(self, num_tokens: int): def _get_positions(self, num_tokens: int):
if self.uses_mrope: if self.uses_mrope:
return self.mrope_positions[:, :num_tokens] return self.mrope_positions[:, :num_tokens]
...@@ -270,6 +296,10 @@ class SpecDecodeBaseProposer: ...@@ -270,6 +296,10 @@ class SpecDecodeBaseProposer:
self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode) self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)
def _pad_for_mla_cp(self, num_scheduled_tokens: int) -> int:
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
return round_up(num_scheduled_tokens, tp_size)
def propose( def propose(
self, self,
# [num_tokens] # [num_tokens]
...@@ -309,6 +339,31 @@ class SpecDecodeBaseProposer: ...@@ -309,6 +339,31 @@ class SpecDecodeBaseProposer:
) )
) )
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
enable_lightly_cp = self.enable_lightly_cp and num_tokens > self.runner.lightly_cp_threshould
if enable_lightly_cp:
num_tokens_dp_padded = self._pad_for_mla_cp(num_tokens_dp_padded)
common_attn_metadata = self._prepare_cp_metadata(
num_reqs_padded=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
num_tokens=num_tokens,
block_table_gid_0=common_attn_metadata.block_table_tensor,
slot_mapping_gid_0=common_attn_metadata.slot_mapping,
query_start_loc=common_attn_metadata.query_start_loc,
query_start_loc_cpu=common_attn_metadata.query_start_loc_cpu,
seq_lens=common_attn_metadata.seq_lens,
seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
)
self.scatter_indexes_tensor = common_attn_metadata.scatter_indexes_tensor
self.gather_indexes_tensor = common_attn_metadata.gather_indexes_tensor
assert self.runner is not None assert self.runner is not None
if self.attn_metadata_builder is None: if self.attn_metadata_builder is None:
...@@ -339,10 +394,6 @@ class SpecDecodeBaseProposer: ...@@ -339,10 +394,6 @@ class SpecDecodeBaseProposer:
assert draft_indexer_metadata is not None assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata per_layer_attn_metadata[layer_name] = draft_indexer_metadata
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_dp_padded num_tokens_dp_padded
) )
...@@ -387,6 +438,10 @@ class SpecDecodeBaseProposer: ...@@ -387,6 +438,10 @@ class SpecDecodeBaseProposer:
slot_mapping=self._get_slot_mapping( slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping num_input_tokens, common_attn_metadata.slot_mapping
), ),
scatter_indexes_tensor=self.scatter_indexes_tensor,
gather_indexes_tensor=self.gather_indexes_tensor,
enable_lightly_cp=self.enable_lightly_cp and num_tokens > self.runner.lightly_cp_threshould,
enable_lightly_cplb=self.enable_lightly_cplb
): ):
ret_hidden_states = self.model(**model_kwargs) ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple(): if not self.model_returns_tuple():
...@@ -463,6 +518,9 @@ class SpecDecodeBaseProposer: ...@@ -463,6 +518,9 @@ class SpecDecodeBaseProposer:
if batch_size_across_dp is not None: if batch_size_across_dp is not None:
batch_size_across_dp[self.dp_rank] = input_batch_size batch_size_across_dp[self.dp_rank] = input_batch_size
if enable_lightly_cp:
common_attn_metadata = common_attn_metadata.cp_common_metadata
common_attn_metadata.num_actual_tokens = batch_size common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1 common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1] common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
...@@ -802,6 +860,7 @@ class SpecDecodeBaseProposer: ...@@ -802,6 +860,7 @@ class SpecDecodeBaseProposer:
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu, _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs, num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens, num_actual_tokens=total_num_tokens,
num_kv_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(), max_query_len=new_query_len_per_req.max().item(),
max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
block_table_tensor=common_attn_metadata.block_table_tensor, block_table_tensor=common_attn_metadata.block_table_tensor,
...@@ -989,6 +1048,104 @@ class SpecDecodeBaseProposer: ...@@ -989,6 +1048,104 @@ class SpecDecodeBaseProposer:
total_num_drafts = self.cu_drafts_per_level[level + 1] total_num_drafts = self.cu_drafts_per_level[level + 1]
return draft_token_ids_list return draft_token_ids_list
def _prepare_cp_metadata(
self,
num_reqs_padded,
max_query_len,
max_seq_len,
num_tokens,
block_table_gid_0,
slot_mapping_gid_0,
query_start_loc,
query_start_loc_cpu,
seq_lens,
seq_lens_cpu,
num_computed_tokens_cpu,
):
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
tp_rank = get_tensor_model_parallel_rank()
cp_common_metadata = CpCommonAttentionMetadata(
query_start_loc=query_start_loc.clone(),
query_start_loc_cpu=query_start_loc_cpu.clone(),
seq_lens=seq_lens.clone(),
_seq_lens_cpu=seq_lens_cpu.clone(),
max_query_len=max_query_len,
max_seq_len=max_seq_len,
num_reqs=num_reqs_padded,
req_ids=self.runner.input_batch.req_ids,
num_actual_tokens=num_tokens,
num_kv_actual_tokens=num_tokens,
block_table_tensor=block_table_gid_0,
slot_mapping=slot_mapping_gid_0,
_num_computed_tokens_cpu=num_computed_tokens_cpu
)
q_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
kv_lens_cpu = seq_lens_cpu
total_q_len = num_tokens
total_kv_len = num_tokens
(
total_q_len,
q_lens_cpu,
seq_count,
kv_lens_cpu,
local_req_ids,
scatter_indexes_tensor,
gather_indexes_tensor,
seq_indexes_list,
) = self.runner._distribute_tokens_to_cp_ranks(
total_q_len,
q_lens_cpu,
kv_lens_cpu,
tp_rank,
tp_size,
self.runner.input_batch.req_ids,
)
num_reqs = seq_count
cu_num_tokens = np.cumsum(q_lens_cpu)
self.query_start_loc.np[0] = 0
self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1])
self.query_start_loc.copy_to_gpu()
q_acc_lens = self.query_start_loc.gpu[: num_reqs + 1]
q_acc_lens_cpu = self.query_start_loc.cpu[: num_reqs + 1]
max_q_len = max(q_acc_lens_cpu)
self.seq_lens.np[:num_reqs] = kv_lens_cpu
self.seq_lens.np[num_reqs:].fill(0)
self.seq_lens.copy_to_gpu()
kv_lens = self.seq_lens.gpu[:num_reqs]
kv_lens_cpu = self.seq_lens.cpu[:num_reqs]
max_kv_len = max(kv_lens_cpu)
num_computed_tokens_cpu = kv_lens_cpu - q_acc_lens_cpu[1:]
blk_table_tensor = block_table_gid_0[seq_indexes_list]
cm_base = CommonAttentionMetadata(
query_start_loc=q_acc_lens,
query_start_loc_cpu=q_acc_lens_cpu,
seq_lens=kv_lens,
_seq_lens_cpu=kv_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=total_q_len,
max_query_len=max_q_len,
max_seq_len=max_kv_len,
block_table_tensor=blk_table_tensor,
slot_mapping=slot_mapping_gid_0,
causal=True,
num_kv_actual_tokens=total_kv_len,
seq_indexes_list=seq_indexes_list,
cp_common_metadata=cp_common_metadata,
scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor,
)
return cm_base
def prepare_inputs( def prepare_inputs(
self, self,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
......
...@@ -234,6 +234,10 @@ class BlockTable: ...@@ -234,6 +234,10 @@ class BlockTable:
"""Returns the device tensor of the block table.""" """Returns the device tensor of the block table."""
return self.block_table.gpu[:num_reqs] return self.block_table.gpu[:num_reqs]
def get_device_tensor_range(self, start_req: int, end_req: int) -> torch.Tensor:
"""Returns the device tensor of the block table."""
return self.block_table.gpu[start_req:end_req]
def get_cpu_tensor(self) -> torch.Tensor: def get_cpu_tensor(self) -> torch.Tensor:
"""Returns the CPU tensor of the block table.""" """Returns the CPU tensor of the block table."""
return self.block_table.cpu return self.block_table.cpu
......
...@@ -42,8 +42,13 @@ from vllm.distributed.parallel_state import ( ...@@ -42,8 +42,13 @@ from vllm.distributed.parallel_state import (
get_tp_group, get_tp_group,
graph_capture, graph_capture,
is_global_first_rank, is_global_first_rank,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
prepare_communication_buffer_for_model, prepare_communication_buffer_for_model,
) )
from vllm.distributed import (
tensor_model_parallel_all_gather
)
from vllm.forward_context import ( from vllm.forward_context import (
BatchDescriptor, BatchDescriptor,
set_forward_context, set_forward_context,
...@@ -104,6 +109,7 @@ from vllm.v1.attention.backend import ( ...@@ -104,6 +109,7 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionType, AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
CpCommonAttentionMetadata,
MultipleOf, MultipleOf,
) )
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
...@@ -372,10 +378,20 @@ class GPUModelRunner( ...@@ -372,10 +378,20 @@ class GPUModelRunner(
# Always set to false after the first forward pass # Always set to false after the first forward pass
self.calculate_kv_scales = self.cache_config.calculate_kv_scales self.calculate_kv_scales = self.cache_config.calculate_kv_scales
self.tp_size = self.parallel_config.tensor_parallel_size
self.dcp_world_size = self.parallel_config.decode_context_parallel_size self.dcp_world_size = self.parallel_config.decode_context_parallel_size
self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs #self.max_num_reqs = scheduler_config.max_num_seqs
self.enable_lightly_cp = self.parallel_config.enable_lightly_cp
self.enable_lightly_cplb = self.enable_lightly_cp and self.parallel_config.enable_lightly_cplb
self.max_num_reqs = (
scheduler_config.max_num_seqs
if not self.enable_lightly_cplb
else scheduler_config.max_num_seqs * 2
)
self.lightly_cp_threshould = envs.VLLM_LIGHTLY_CP_THRESHOULD
# Broadcast PP output for external_launcher (torchrun) # Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks # to make sure we are synced across pp ranks
...@@ -1490,6 +1506,243 @@ class GPUModelRunner( ...@@ -1490,6 +1506,243 @@ class GPUModelRunner(
return encoder_seq_lens, encoder_seq_lens_cpu return encoder_seq_lens, encoder_seq_lens_cpu
def _distribute_tokens_to_cp_ranks(
self,
total_q_len: int,
q_lens_cpu: np.ndarray,
kv_lens_cpu: np.ndarray,
tp_rank: int,
tp_size: int,
req_ids: list[str],
):
q_lens = []
seq_count = 0
seq_indexes = []
kv_lens = []
local_req_ids = []
local_scatter_indexes_tensor = None
gather_indexes_tensor = None
if self.enable_lightly_cplb:
rank_tokens = 0
rank_pad_tokens = 0
accu_q_start = 0
scatter_indexes: list[int] = []
num_requests = len(q_lens_cpu)
for i in range(num_requests):
req_q_len = q_lens_cpu[i]
req_pad_q_len = round_up(q_lens_cpu[i], 2 * tp_size)
kv_len = kv_lens_cpu[i]
chunk_q_len = req_pad_q_len // (2 * tp_size)
q_1_start = tp_rank * chunk_q_len
q_1_end = (tp_rank + 1) * chunk_q_len
q_2_start = req_pad_q_len - (tp_rank + 1) * chunk_q_len
q_2_end = req_pad_q_len - tp_rank * chunk_q_len
q_len_1 = (
chunk_q_len
if q_1_end <= req_q_len
else max(0, req_q_len - q_1_start)
)
q_len_2 = (
chunk_q_len
if q_2_end <= req_q_len
else max(0, req_q_len - q_2_start)
)
kv_len_1 = kv_len - req_q_len + min(req_q_len, q_1_end)
kv_len_2 = kv_len - req_q_len + min(req_q_len, q_2_end)
scatter_index1 = range(
accu_q_start + q_1_start, accu_q_start + q_1_start + q_len_1
)
scatter_index2 = range(
accu_q_start + q_2_start, accu_q_start + q_2_start + q_len_2
)
accu_q_start += req_q_len
if q_len_1 > 0:
q_lens.append(q_len_1)
kv_lens.append(kv_len_1)
seq_indexes.append(i)
local_req_ids.append(req_ids[i])
scatter_indexes.extend(scatter_index1)
seq_count += 1
rank_tokens += q_len_1
if q_len_2 > 0:
q_lens.append(q_len_2)
kv_lens.append(kv_len_2)
seq_indexes.append(i)
local_req_ids.append(req_ids[i])
scatter_indexes.extend(scatter_index2)
seq_count += 1
rank_tokens += q_len_2
rank_pad_tokens += chunk_q_len * 2
if len(scatter_indexes) < rank_pad_tokens:
scatter_indexes.extend([-1] * (rank_pad_tokens - len(scatter_indexes)))
local_scatter_indexes_tensor = torch.tensor(
scatter_indexes, dtype=torch.int64, device=self.device
)
global_scatter_indexes_tensor = tensor_model_parallel_all_gather(
local_scatter_indexes_tensor.contiguous(), dim=0
)
non_neg_mask = global_scatter_indexes_tensor != -1
non_neg_values = global_scatter_indexes_tensor[non_neg_mask]
non_neg_positions = torch.where(non_neg_mask)[0]
sorted_indices = torch.argsort(non_neg_values)
gather_indexes_tensor = non_neg_positions[sorted_indices]
if isinstance(rank_tokens, torch.Tensor):
rank_tokens = rank_tokens.item()
else:
tokens_per_rank = (total_q_len + tp_size - 1) // tp_size
start_token = tp_rank * tokens_per_rank
end_token = min((tp_rank + 1) * tokens_per_rank, total_q_len)
current_seq = 0
current_pos = 0
rank_tokens = min(tokens_per_rank, end_token - start_token)
while start_token < end_token and current_seq < len(q_lens_cpu):
q_len = q_lens_cpu[current_seq]
q_start = current_pos
q_end = current_pos + q_len
kv_len = kv_lens_cpu[current_seq]
# Find overlap between this sequence and rank's token range
overlap_start = max(start_token, q_start)
overlap_end = min(end_token, q_end)
if overlap_start < overlap_end:
# This sequence contributes tokens to this rank
token_count = overlap_end - overlap_start
q_lens.append(token_count)
start_token = overlap_end
seq_count += 1
seq_indexes.append(current_seq)
local_req_ids.append(req_ids[current_seq])
if q_end <= end_token:
kv_lens.append(kv_len)
else:
kv_lens.append(kv_len - (q_end - end_token))
current_pos = q_end
current_seq += 1
return (
rank_tokens,
np.array(q_lens, dtype=np.int32),
seq_count,
np.array(kv_lens, dtype=np.int32),
np.array(local_req_ids, dtype=str),
local_scatter_indexes_tensor,
gather_indexes_tensor,
seq_indexes,
)
def _prepare_cp_metadata(
self,
num_reqs_padded,
max_query_len,
max_seq_len,
num_tokens,
block_table_gid_0,
slot_mapping_gid_0,
num_computed_tokens_cpu
):
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
tp_rank = get_tensor_model_parallel_rank()
cp_common_metadata = CpCommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1].clone(),
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1].clone(),
seq_lens=self.seq_lens.gpu[:num_reqs_padded].clone(),
_seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded].clone(),
max_query_len=max_query_len,
max_seq_len=max_seq_len,
num_reqs=num_reqs_padded,
req_ids=self.input_batch.req_ids,
num_actual_tokens=num_tokens,
num_kv_actual_tokens=num_tokens,
block_table_tensor=block_table_gid_0,
slot_mapping=slot_mapping_gid_0,
_num_computed_tokens_cpu=num_computed_tokens_cpu
)
query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1]
q_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
kv_lens_cpu = self.seq_lens.cpu[:num_reqs_padded]
total_q_len = num_tokens
total_kv_len = num_tokens
(
total_q_len,
q_lens_cpu,
seq_count,
kv_lens_cpu,
local_req_ids,
scatter_indexes_tensor,
gather_indexes_tensor,
seq_indexes_list,
) = self._distribute_tokens_to_cp_ranks(
total_q_len,
q_lens_cpu,
kv_lens_cpu,
tp_rank,
tp_size,
self.input_batch.req_ids,
)
num_reqs = seq_count
cu_num_tokens = np.cumsum(q_lens_cpu)
self.query_start_loc.np[0] = 0
self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1])
self.query_start_loc.copy_to_gpu()
q_acc_lens = self.query_start_loc.gpu[: num_reqs + 1]
q_acc_lens_cpu = self.query_start_loc.cpu[: num_reqs + 1]
max_q_len = max(q_acc_lens_cpu)
self.seq_lens.np[:num_reqs] = kv_lens_cpu
self.seq_lens.np[num_reqs:].fill(0)
self.seq_lens.copy_to_gpu()
kv_lens = self.seq_lens.gpu[:num_reqs]
kv_lens_cpu = self.seq_lens.cpu[:num_reqs]
max_kv_len = max(kv_lens_cpu)
num_computed_tokens_cpu = kv_lens_cpu - q_acc_lens_cpu[1:]
blk_table_tensor = block_table_gid_0[seq_indexes_list]
cm_base = CommonAttentionMetadata(
query_start_loc=q_acc_lens,
query_start_loc_cpu=q_acc_lens_cpu,
seq_lens=kv_lens,
_seq_lens_cpu=kv_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=total_q_len,
max_query_len=max_q_len,
max_seq_len=max_kv_len,
block_table_tensor=blk_table_tensor,
slot_mapping=slot_mapping_gid_0,
causal=True,
num_kv_actual_tokens=total_kv_len,
seq_indexes_list=seq_indexes_list,
cp_common_metadata=cp_common_metadata,
scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor,
enable_lightly_cp=True
)
return cm_base
def _prepare_inputs( def _prepare_inputs(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
...@@ -1723,13 +1976,20 @@ class GPUModelRunner( ...@@ -1723,13 +1976,20 @@ class GPUModelRunner(
num_scheduled_tokens: dict[str, int] | None = None, num_scheduled_tokens: dict[str, int] | None = None,
cascade_attn_prefix_lens: list[list[int]] | None = None, cascade_attn_prefix_lens: list[list[int]] | None = None,
slot_mappings: dict[int, torch.Tensor] | None = None, slot_mappings: dict[int, torch.Tensor] | None = None,
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]: ) -> tuple[
PerLayerAttnMetadata,
CommonAttentionMetadata | None,
torch.Tensor | None,
torch.Tensor | None,
]:
""" """
:return: tuple[attn_metadata, spec_decode_common_attn_metadata] :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
""" """
# Attention metadata is not needed for attention free models # Attention metadata is not needed for attention free models
if len(self.kv_cache_config.kv_cache_groups) == 0: if len(self.kv_cache_config.kv_cache_groups) == 0:
return {}, None return {}, None, None, None
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
num_tokens_padded = num_tokens_padded or num_tokens num_tokens_padded = num_tokens_padded or num_tokens
num_reqs_padded = num_reqs_padded or num_reqs num_reqs_padded = num_reqs_padded or num_reqs
...@@ -1777,9 +2037,14 @@ class GPUModelRunner( ...@@ -1777,9 +2037,14 @@ class GPUModelRunner(
assert slot_mappings is not None assert slot_mappings is not None
block_table_gid_0 = _get_block_table(0) block_table_gid_0 = _get_block_table(0)
slot_mapping_gid_0 = slot_mappings[0] slot_mapping_gid_0 = slot_mappings[0]
scatter_indexes_tensor = None
gather_indexes_tensor = None
if self.model_config.enable_return_routed_experts: if self.model_config.enable_return_routed_experts:
self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy() self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
enable_lightly_cp = self.enable_lightly_cp and num_tokens > self.lightly_cp_threshould
if not enable_lightly_cp:
cm_base = CommonAttentionMetadata( cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
...@@ -1790,12 +2055,27 @@ class GPUModelRunner( ...@@ -1790,12 +2055,27 @@ class GPUModelRunner(
], ],
num_reqs=num_reqs_padded, num_reqs=num_reqs_padded,
num_actual_tokens=num_tokens_padded, num_actual_tokens=num_tokens_padded,
num_kv_actual_tokens=num_tokens_padded,
max_query_len=max_query_len, max_query_len=max_query_len,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
block_table_tensor=block_table_gid_0, block_table_tensor=block_table_gid_0,
slot_mapping=slot_mapping_gid_0, slot_mapping=slot_mapping_gid_0,
causal=True, causal=True,
) )
else:
cm_base = self._prepare_cp_metadata(
num_reqs_padded,
max_query_len,
max_seq_len,
num_tokens,
block_table_gid_0,
slot_mapping_gid_0,
self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs_padded
],
)
scatter_indexes_tensor = cm_base.scatter_indexes_tensor
gather_indexes_tensor = cm_base.gather_indexes_tensor
if self.dcp_world_size > 1: if self.dcp_world_size > 1:
self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
...@@ -1906,12 +2186,23 @@ class GPUModelRunner( ...@@ -1906,12 +2186,23 @@ class GPUModelRunner(
cm.block_table_tensor = _get_block_table(kv_cache_gid) cm.block_table_tensor = _get_block_table(kv_cache_gid)
cm.slot_mapping = slot_mappings[kv_cache_gid] cm.slot_mapping = slot_mappings[kv_cache_gid]
if enable_lightly_cp and cm.seq_indexes_list is not None:
cm.block_table_tensor = cm.block_table_tensor[cm.seq_indexes_list]
if self.speculative_config and spec_decode_common_attn_metadata is None and hasattr(self, "drafter"): if self.speculative_config and spec_decode_common_attn_metadata is None and hasattr(self, "drafter"):
if isinstance(self.drafter, EagleProposer): if isinstance(self.drafter, EagleProposer):
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
if enable_lightly_cp:
spec_decode_common_attn_metadata = cm.cp_common_metadata
else:
spec_decode_common_attn_metadata = cm spec_decode_common_attn_metadata = cm
#spec_decode_common_attn_metadata = cm
else:
if enable_lightly_cp:
spec_decode_common_attn_metadata = cm.cp_common_metadata
else: else:
spec_decode_common_attn_metadata = cm spec_decode_common_attn_metadata = cm
#spec_decode_common_attn_metadata = cm
for attn_gid in range(len(self.attn_groups[kv_cache_gid])): for attn_gid in range(len(self.attn_groups[kv_cache_gid])):
if ubatch_slices is not None: if ubatch_slices is not None:
...@@ -1941,8 +2232,10 @@ class GPUModelRunner( ...@@ -1941,8 +2232,10 @@ class GPUModelRunner(
for _metadata in attn_metadata.values(): for _metadata in attn_metadata.values():
_metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined] _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined]
if spec_decode_common_attn_metadata is not None and ( if (
num_reqs != num_reqs_padded or num_tokens != num_tokens_padded (not self.enable_lightly_cp)
and spec_decode_common_attn_metadata is not None
and (num_reqs != num_reqs_padded or num_tokens != num_tokens_padded)
): ):
# Currently the drafter still only uses piecewise cudagraphs (and modifies # Currently the drafter still only uses piecewise cudagraphs (and modifies
# the attention metadata in directly), and therefore does not want to use # the attention metadata in directly), and therefore does not want to use
...@@ -1951,7 +2244,12 @@ class GPUModelRunner( ...@@ -1951,7 +2244,12 @@ class GPUModelRunner(
spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs) spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs)
) )
return attn_metadata, spec_decode_common_attn_metadata return (
attn_metadata,
spec_decode_common_attn_metadata,
scatter_indexes_tensor,
gather_indexes_tensor
)
def _compute_cascade_attn_prefix_lens( def _compute_cascade_attn_prefix_lens(
self, self,
...@@ -2803,9 +3101,20 @@ class GPUModelRunner( ...@@ -2803,9 +3101,20 @@ class GPUModelRunner(
return model_runner_output return model_runner_output
def _pad_for_mla_cp(self, num_scheduled_tokens: int) -> int:
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
# if num_scheduled_tokens <= tp_size * tp_size:
# return num_scheduled_tokens
# else:
# return round_up(num_scheduled_tokens, tp_size)
return round_up(num_scheduled_tokens, tp_size)
def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int: def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
# Pad tokens to multiple of tensor_parallel_size when # Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP # enabled collective fusion for SP
if self.enable_lightly_cp and num_scheduled_tokens > self.lightly_cp_threshould:
return self._pad_for_mla_cp(num_scheduled_tokens)
tp_size = self.vllm_config.parallel_config.tensor_parallel_size tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config.enable_sp and tp_size > 1: if self.compilation_config.pass_config.enable_sp and tp_size > 1:
return round_up(num_scheduled_tokens, tp_size) return round_up(num_scheduled_tokens, tp_size)
...@@ -3502,6 +3811,8 @@ class GPUModelRunner( ...@@ -3502,6 +3811,8 @@ class GPUModelRunner(
) )
num_tokens_padded = batch_desc.num_tokens num_tokens_padded = batch_desc.num_tokens
if self.enable_lightly_cp and num_tokens_unpadded > self.lightly_cp_threshould:
num_tokens_padded = self._pad_for_mla_cp(num_tokens_unpadded)
num_reqs_padded = ( num_reqs_padded = (
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
) )
...@@ -3558,8 +3869,12 @@ class GPUModelRunner( ...@@ -3558,8 +3869,12 @@ class GPUModelRunner(
ubatch_slices=ubatch_slices_padded, ubatch_slices=ubatch_slices_padded,
) )
attn_metadata, spec_decode_common_attn_metadata = ( (
self._build_attention_metadata( attn_metadata,
spec_decode_common_attn_metadata,
scatter_indexes_tensor,
gather_indexes_tensor,
) = self._build_attention_metadata(
num_tokens=num_tokens_unpadded, num_tokens=num_tokens_unpadded,
num_tokens_padded=num_tokens_padded if pad_attn else None, num_tokens_padded=num_tokens_padded if pad_attn else None,
num_reqs=num_reqs, num_reqs=num_reqs,
...@@ -3572,7 +3887,6 @@ class GPUModelRunner( ...@@ -3572,7 +3887,6 @@ class GPUModelRunner(
cascade_attn_prefix_lens=cascade_attn_prefix_lens, cascade_attn_prefix_lens=cascade_attn_prefix_lens,
slot_mappings=slot_mappings_by_group, slot_mappings=slot_mappings_by_group,
) )
)
( (
input_ids, input_ids,
...@@ -3614,6 +3928,10 @@ class GPUModelRunner( ...@@ -3614,6 +3928,10 @@ class GPUModelRunner(
ubatch_slices=ubatch_slices_padded, ubatch_slices=ubatch_slices_padded,
slot_mapping=slot_mappings, slot_mapping=slot_mappings,
skip_compiled=has_encoder_input, skip_compiled=has_encoder_input,
scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor,
enable_lightly_cp=self.enable_lightly_cp and num_tokens_unpadded > self.lightly_cp_threshould,
enable_lightly_cplb=self.enable_lightly_cplb
), ),
record_function_or_nullcontext("gpu_model_runner: forward"), record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output( self.maybe_get_kv_connector_output(
...@@ -4105,6 +4423,15 @@ class GPUModelRunner( ...@@ -4105,6 +4423,15 @@ class GPUModelRunner(
spec_decode_metadata, spec_decode_metadata,
valid_sampled_tokens_count, valid_sampled_tokens_count,
) )
#total_num_tokens = common_attn_metadata.num_actual_tokens
if (
self.enable_lightly_cp
and common_attn_metadata.cp_common_metadata is not None
):
total_num_tokens = (
common_attn_metadata.cp_common_metadata.num_actual_tokens
)
else:
total_num_tokens = common_attn_metadata.num_actual_tokens total_num_tokens = common_attn_metadata.num_actual_tokens
# When padding the batch, token_indices is just a range # When padding the batch, token_indices is just a range
target_token_ids = self.input_ids.gpu[:total_num_tokens] target_token_ids = self.input_ids.gpu[:total_num_tokens]
...@@ -4759,7 +5086,7 @@ class GPUModelRunner( ...@@ -4759,7 +5086,7 @@ class GPUModelRunner(
self.query_start_loc.copy_to_gpu() self.query_start_loc.copy_to_gpu()
pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
attn_metadata, _ = self._build_attention_metadata( attn_metadata, _, _, _ = self._build_attention_metadata(
num_tokens=num_tokens_unpadded, num_tokens=num_tokens_unpadded,
num_reqs=num_reqs_padded, num_reqs=num_reqs_padded,
max_query_len=max_query_len, max_query_len=max_query_len,
...@@ -4830,6 +5157,8 @@ class GPUModelRunner( ...@@ -4830,6 +5157,8 @@ class GPUModelRunner(
batch_descriptor=batch_desc, batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices_padded, ubatch_slices=ubatch_slices_padded,
slot_mapping=slot_mappings, slot_mapping=slot_mappings,
enable_lightly_cp=self.enable_lightly_cp and num_tokens_unpadded > self.lightly_cp_threshould,
enable_lightly_cplb=self.enable_lightly_cplb
), ),
): ):
outputs = self.model( outputs = self.model(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment