Commit 35c61e91 authored by 王敏's avatar 王敏
Browse files

[Feature]添加PCP功能,目前暂时只支持mla架构,后续逐步适配

parent 6efaf21a
...@@ -299,6 +299,13 @@ class ParallelConfig: ...@@ -299,6 +299,13 @@ class ParallelConfig:
should only be set by API server scale-out. should only be set by API server scale-out.
""" """
enable_lightly_cp: bool = False
"""Use lightly context parallel."""
enable_lightly_cplb: bool = False
"""Use lightly context parallel load balancing."""
@field_validator("disable_nccl_for_dp_synchronization", mode="wrap") @field_validator("disable_nccl_for_dp_synchronization", mode="wrap")
@classmethod @classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
......
...@@ -1061,6 +1061,12 @@ class VllmConfig: ...@@ -1061,6 +1061,12 @@ class VllmConfig:
# Handle the KV connector configs # Handle the KV connector configs
self._post_init_kv_transfer_config() self._post_init_kv_transfer_config()
if self.parallel_config.enable_lightly_cp and not self.model_config.enforce_eager:
raise ValueError(
"Lightly context parallel currently only supports the eager mode!!!"
)
def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
# remove the sizes that not multiple of tp_size when # remove the sizes that not multiple of tp_size when
# enable sequence parallelism # enable sequence parallelism
......
...@@ -582,6 +582,9 @@ class EngineArgs: ...@@ -582,6 +582,9 @@ class EngineArgs:
kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
tokens_only: bool = False tokens_only: bool = False
enable_lightly_cp: bool = ParallelConfig.enable_lightly_cp
enable_lightly_cplb: bool = ParallelConfig.enable_lightly_cplb
def __post_init__(self): def __post_init__(self):
# support `EngineArgs(compilation_config={...})` # support `EngineArgs(compilation_config={...})`
# without having to manually construct a # without having to manually construct a
...@@ -899,6 +902,15 @@ class EngineArgs: ...@@ -899,6 +902,15 @@ class EngineArgs:
"--worker-extension-cls", **parallel_kwargs["worker_extension_cls"] "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
) )
parallel_group.add_argument(
"--enable-lightly-cp",
**parallel_kwargs["enable_lightly_cp"],
)
parallel_group.add_argument(
"--enable-lightly-cplb",
**parallel_kwargs["enable_lightly_cplb"],
)
# KV cache arguments # KV cache arguments
cache_kwargs = get_kwargs(CacheConfig) cache_kwargs = get_kwargs(CacheConfig)
cache_group = parser.add_argument_group( cache_group = parser.add_argument_group(
...@@ -1630,6 +1642,8 @@ class EngineArgs: ...@@ -1630,6 +1642,8 @@ class EngineArgs:
cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size, cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
_api_process_count=self._api_process_count, _api_process_count=self._api_process_count,
_api_process_rank=self._api_process_rank, _api_process_rank=self._api_process_rank,
enable_lightly_cp=self.enable_lightly_cp,
enable_lightly_cplb=self.enable_lightly_cplb,
) )
speculative_config = self.create_speculative_config( speculative_config = self.create_speculative_config(
......
...@@ -324,6 +324,9 @@ if TYPE_CHECKING: ...@@ -324,6 +324,9 @@ if TYPE_CHECKING:
USE_LIGHTOP_TOPK: bool = False USE_LIGHTOP_TOPK: bool = False
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX: bool = False USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX: bool = False
VLLM_DISABLE_DSA: bool = False VLLM_DISABLE_DSA: bool = False
VLLM_LIGHTLY_CP_THRESHOULD: int = 2048
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
"XDG_CACHE_HOME", "XDG_CACHE_HOME",
...@@ -2004,7 +2007,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -2004,7 +2007,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set to 1/True, disable the DSA. # If set to 1/True, disable the DSA.
"VLLM_DISABLE_DSA": "VLLM_DISABLE_DSA":
lambda: (os.environ.get("VLLM_DISABLE_DSA", "False").lower() in lambda: (os.environ.get("VLLM_DISABLE_DSA", "False").lower() in
("true", "1")), ("true", "1")),
# Threshold for enabling MLA_CP
"VLLM_LIGHTLY_CP_THRESHOULD":
lambda: int(os.getenv("VLLM_LIGHTLY_CP_THRESHOULD", "2048")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -240,6 +240,11 @@ class ForwardContext: ...@@ -240,6 +240,11 @@ class ForwardContext:
additional_kwargs: dict[str, Any] = field(default_factory=dict) additional_kwargs: dict[str, Any] = field(default_factory=dict)
scatter_indexes_tensor: torch.Tensor | None = None
gather_indexes_tensor: torch.Tensor | None = None
enable_lightly_cp: bool = False
enable_lightly_cplb : bool = False
def __post_init__(self): def __post_init__(self):
assert self.cudagraph_runtime_mode.valid_runtime_modes(), ( assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}" f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
...@@ -273,6 +278,10 @@ def create_forward_context( ...@@ -273,6 +278,10 @@ def create_forward_context(
slot_mapping: dict[str, torch.Tensor] | None = None, slot_mapping: dict[str, torch.Tensor] | None = None,
additional_kwargs: dict[str, Any] | None = None, additional_kwargs: dict[str, Any] | None = None,
skip_compiled: bool = False, skip_compiled: bool = False,
scatter_indexes_tensor: torch.Tensor | None = None,
gather_indexes_tensor: torch.Tensor | None = None,
enable_lightly_cp: bool = False,
enable_lightly_cplb: bool = False
): ):
if vllm_config.compilation_config.fast_moe_cold_start: if vllm_config.compilation_config.fast_moe_cold_start:
if vllm_config.speculative_config is None: if vllm_config.speculative_config is None:
...@@ -298,6 +307,10 @@ def create_forward_context( ...@@ -298,6 +307,10 @@ def create_forward_context(
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices, ubatch_slices=ubatch_slices,
skip_compiled=skip_compiled, skip_compiled=skip_compiled,
scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor,
enable_lightly_cp=enable_lightly_cp,
enable_lightly_cplb=enable_lightly_cplb,
additional_kwargs=additional_kwargs or {}, additional_kwargs=additional_kwargs or {},
) )
...@@ -329,6 +342,10 @@ def set_forward_context( ...@@ -329,6 +342,10 @@ def set_forward_context(
ubatch_slices: UBatchSlices | None = None, ubatch_slices: UBatchSlices | None = None,
slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
skip_compiled: bool = False, skip_compiled: bool = False,
scatter_indexes_tensor: torch.Tensor | None = None,
gather_indexes_tensor: torch.Tensor | None = None,
enable_lightly_cp: bool = False,
enable_lightly_cplb: bool = False,
): ):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
can be attention metadata, etc. can be attention metadata, etc.
...@@ -389,6 +406,10 @@ def set_forward_context( ...@@ -389,6 +406,10 @@ def set_forward_context(
slot_mapping, slot_mapping,
additional_kwargs, additional_kwargs,
skip_compiled, skip_compiled,
scatter_indexes_tensor,
gather_indexes_tensor,
enable_lightly_cp,
enable_lightly_cplb
) )
try: try:
......
...@@ -7,8 +7,12 @@ import torch ...@@ -7,8 +7,12 @@ import torch
from vllm.attention.layer import MLAAttention from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig from vllm.config import CacheConfig
import vllm.envs as envs import vllm.envs as envs
from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.distributed import (
tensor_model_parallel_all_gather,
)
@dataclass @dataclass
...@@ -184,8 +188,26 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -184,8 +188,26 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if llama_4_scaling is not None: if llama_4_scaling is not None:
q *= llama_4_scaling q *= llama_4_scaling
enable_lightly_cp = get_forward_context().enable_lightly_cp
enable_lightly_cplb = get_forward_context().enable_lightly_cplb
# if not use_fused_rms_rope_concat: # if not use_fused_rms_rope_concat:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
if enable_lightly_cp:
kv_c_normed = tensor_model_parallel_all_gather(
kv_c_normed.contiguous(), 0
)
k_pe = tensor_model_parallel_all_gather(
k_pe.contiguous(), 0
)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if enable_lightly_cplb and gather_indexes_tensor is not None:
# Reorder kv after pcp allgather.
kv_c_normed = torch.index_select(kv_c_normed, 0, gather_indexes_tensor)
k_pe = torch.index_select(k_pe, 0, gather_indexes_tensor)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q, q,
kv_c_normed, kv_c_normed,
...@@ -221,6 +243,20 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -221,6 +243,20 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires rotary_emb to " "VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires rotary_emb to "
"expose 'cos_sin_cache'." "expose 'cos_sin_cache'."
) )
if enable_lightly_cp:
kv_c = tensor_model_parallel_all_gather(
kv_c.contiguous(), 0
)
k_pe = tensor_model_parallel_all_gather(
k_pe.contiguous(), 0
)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if enable_lightly_cplb and gather_indexes_tensor is not None:
# Reorder kv after pcp allgather.
kv_c = torch.index_select(kv_c, 0, gather_indexes_tensor)
k_pe = torch.index_select(k_pe, 0, gather_indexes_tensor)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:], q[..., self.qk_nope_head_dim:],
kv_c, kv_c,
......
...@@ -90,7 +90,7 @@ def sparse_attn_indexer( ...@@ -90,7 +90,7 @@ def sparse_attn_indexer(
) )
attn_metadata = attn_metadata[layer_name] attn_metadata = attn_metadata[layer_name]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping slot_mapping = attn_metadata.slot_mapping[:attn_metadata.num_kv_actual_tokens]
has_decode = attn_metadata.num_decodes > 0 has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
......
...@@ -11,6 +11,7 @@ import torch ...@@ -11,6 +11,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.forward_context import get_forward_context
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -36,6 +37,9 @@ from .deepseek_v2 import ( ...@@ -36,6 +37,9 @@ from .deepseek_v2 import (
DeepseekV2MoE, DeepseekV2MoE,
get_spec_layer_idx_from_weight_name, get_spec_layer_idx_from_weight_name,
) )
from vllm.distributed import (tensor_model_parallel_all_gather,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from .utils import maybe_prefix from .utils import maybe_prefix
from .interfaces import SupportsPP from .interfaces import SupportsPP
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -177,6 +181,9 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -177,6 +181,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -191,7 +198,28 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -191,7 +198,28 @@ class DeepSeekMultiTokenPredictor(nn.Module):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = spec_step_idx % self.num_mtp_layers current_step_idx = spec_step_idx % self.num_mtp_layers
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
scatter_indexes_tensor = get_forward_context().scatter_indexes_tensor
if scatter_indexes_tensor is None:
inputs_embeds_per_rank = torch.chunk(inputs_embeds, chunks=self.tp_size, dim=0)
inputs_embeds = inputs_embeds_per_rank[self.tp_rank].contiguous()
previous_hidden_states_per_rank = torch.chunk(previous_hidden_states, chunks=self.tp_size, dim=0)
previous_hidden_states = previous_hidden_states_per_rank[self.tp_rank].contiguous()
if positions is not None:
positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0)
positions = positions_per_rank[self.tp_rank].contiguous()
else:
scatter_indexes_tensor = torch.where(scatter_indexes_tensor == -1, 0, scatter_indexes_tensor)
inputs_embeds = torch.index_select(inputs_embeds, 0, scatter_indexes_tensor)
previous_hidden_states = torch.index_select(previous_hidden_states, 0, scatter_indexes_tensor)
if positions is not None:
positions = torch.index_select(positions, 0, scatter_indexes_tensor)
hidden_states = self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids, input_ids,
positions, positions,
previous_hidden_states, previous_hidden_states,
...@@ -199,6 +227,14 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -199,6 +227,14 @@ class DeepSeekMultiTokenPredictor(nn.Module):
current_step_idx, current_step_idx,
) )
if enable_lightly_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if gather_indexes_tensor is not None:
hidden_states = torch.index_select(hidden_states, 0, gather_indexes_tensor)
return hidden_states
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -46,6 +46,8 @@ from vllm.distributed import ( ...@@ -46,6 +46,8 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
) )
from vllm.forward_context import get_forward_context
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
...@@ -181,6 +183,44 @@ class DeepseekAttention(nn.Module): ...@@ -181,6 +183,44 @@ class DeepseekAttention(nn.Module):
return output return output
def iqis_all_gather(
    iqis: tuple[torch.Tensor, torch.Tensor],
    tp_size: int | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """All-gather an (int8 values, fp32 scales) pair along dim 0 in one collective.

    The fp32 scale column is reinterpreted as 4 int8 bytes and concatenated
    to the int8 values, so a single ``tensor_model_parallel_all_gather`` call
    moves both tensors. The gathered result is split back and the scale bytes
    are reinterpreted as fp32.

    Args:
        iqis: ``(iq_tensor, is_tensor)`` where ``iq_tensor`` is a 2-D int8
            tensor of shape ``[m_local, n]`` and ``is_tensor`` is a 2-D
            float32 tensor of shape ``[m_local, 1]``.
        tp_size: Expected number of ranks taking part in the gather. When
            given, the gathered row count is checked against
            ``m_local * tp_size``. When ``None`` (the default) that check is
            skipped; previously the default raised ``TypeError`` on the
            ``m_local * tp_size`` multiplication.

    Returns:
        ``(iq_gathered, is_gathered)``: the int8 values with ``n`` columns and
        the float32 scales with one column, both with the gathered row count.

    Raises:
        AssertionError: If an input or the gathered result violates the
            dtype/shape expectations above.
    """
    assert iqis is not None
    iq_tensor, is_tensor = iqis
    assert isinstance(iq_tensor, torch.Tensor)
    assert isinstance(is_tensor, torch.Tensor)
    assert iq_tensor.dtype == torch.int8, f"iq_tensor dtype is {iq_tensor.dtype}"
    assert is_tensor.dtype == torch.float32, f"is_tensor dtype is {is_tensor.dtype}"
    assert iq_tensor.dim() == 2
    assert is_tensor.dim() == 2

    m_local, n = iq_tensor.shape
    assert is_tensor.shape[0] == m_local, f"{is_tensor.shape[0]} != {iq_tensor.shape[0]}"
    assert is_tensor.shape[1] == 1, f"is_tensor dim 1 ={is_tensor.shape[1]}"

    # iq_tensor is already int8 (asserted above), so only the fp32 scales need
    # a byte-level reinterpretation: [m_local, 1] fp32 -> [m_local, 4] int8.
    is_int8_2d = is_tensor.view(torch.int8)
    # torch.cat materializes a fresh contiguous 2-D tensor: [m_local, n + 4].
    combined_2d = torch.cat([iq_tensor, is_int8_2d], dim=1)

    combined_gathered = tensor_model_parallel_all_gather(combined_2d, dim=0)

    # Column slices are strided views; make them contiguous before the dtype
    # reinterpretation below.
    iq_gathered = combined_gathered[:, :n].contiguous()
    is_gathered_int8 = combined_gathered[:, n:].contiguous()

    if tp_size is not None:
        expected_rows = m_local * tp_size
        assert iq_gathered.shape[0] == expected_rows, f"iq_gathered dim0= {iq_gathered.shape[0]}, expected {expected_rows}"
        # is_gathered_int8 should be [m_local*tp_size, 4]
        assert is_gathered_int8.shape[0] == expected_rows, f"is_gathered_int8 dim0= {is_gathered_int8.shape[0]}, expected {expected_rows}"
    assert is_gathered_int8.shape[1] == 4, f"is_gathered_int8 dim1= {is_gathered_int8.shape[1]}"

    is_gathered = is_gathered_int8.view(torch.float32)

    return (iq_gathered, is_gathered)
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
def __init__( def __init__(
self, self,
...@@ -211,10 +251,85 @@ class DeepseekV2MLP(nn.Module): ...@@ -211,10 +251,85 @@ class DeepseekV2MLP(nn.Module):
hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
reduce_results=reduce_results, #reduce_results=reduce_results,
reduce_results=False,
disable_tp=is_sequence_parallel, disable_tp=is_sequence_parallel,
prefix=f"{prefix}.down_proj", prefix=f"{prefix}.down_proj",
) )
self.tp_size = get_tensor_model_parallel_world_size()
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self,
            x,
            *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
            ):
    """Run the gate/up -> SiLU-and-mul -> down MLP and apply the TP collective here.

    Args:
        x: Input activations passed to ``gate_up_proj``.
        iqis: Optional pair forwarded to ``gate_up_proj`` on the fused path
            (presumably pre-quantized int8 values and fp32 scales, as
            expected by ``iqis_all_gather`` -- TODO confirm with caller).

    Returns:
        The ``down_proj`` output after reduce-scatter along dim 0 (lightly CP),
        after all-reduce (``tp_size > 1``), or unchanged otherwise.
    """
    enable_lightly_cp = get_forward_context().enable_lightly_cp
    if enable_lightly_cp:
        # Under lightly CP the input is all-gathered along dim 0 before the
        # projections. If a complete iqis pair is present, that pair is
        # gathered instead of x.
        # NOTE(review): in this branch x itself is NOT gathered. If
        # USE_FUSED_RMS_QUANT is off, the projections below run on the
        # un-gathered x while the reduce-scatter at the end still applies --
        # confirm iqis is only fully populated when the fused path is on.
        if iqis is not None and iqis[0] is not None and iqis[1] is not None:
            iqis = iqis_all_gather(iqis, tp_size=self.tp_size)
        else:
            x = tensor_model_parallel_all_gather(x.contiguous(), 0)
    if envs.USE_FUSED_RMS_QUANT:
        gate_up, _ = self.gate_up_proj(x, iqis=iqis)
        if envs.USE_FUSED_SILU_MUL_QUANT:
            from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
            xq, xs = lm_fuse_silu_mul_quant(gate_up)
            # NOTE(review): gate_up (not an activated tensor) is passed
            # positionally; presumably down_proj consumes iqis=(xq, xs)
            # instead -- confirm.
            x, _ = self.down_proj(gate_up, iqis=(xq, xs))
        else:
            x = self.act_fn(gate_up)
            x, _ = self.down_proj(x)
    else:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
    # The collective is issued explicitly in this method.
    if enable_lightly_cp:
        x = tensor_model_parallel_reduce_scatter(x.contiguous(), dim=0)
        return x
    elif self.tp_size > 1:
        x = tensor_model_parallel_all_reduce(x)
    return x
class DeepseekV2SharedMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
reduce_results: bool = True,
is_sequence_parallel=False,
prefix: str = "",
) -> None:
super().__init__()
# If is_sequence_parallel, the input and output tensors are sharded
# across the ranks within the tp_group. In this case the weights are
# replicated and no collective ops are needed.
# Otherwise we use standard TP with an allreduce at the end.
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.down_proj"
)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError( raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now." f"Unsupported activation: {hidden_act}. Only silu is supported for now."
...@@ -311,7 +426,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -311,7 +426,7 @@ class DeepseekV2MoE(nn.Module):
else: else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2MLP( self.shared_experts = DeepseekV2SharedMLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
...@@ -357,6 +472,11 @@ class DeepseekV2MoE(nn.Module): ...@@ -357,6 +472,11 @@ class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
hidden_states = tensor_model_parallel_all_gather(
hidden_states.contiguous(), 0
)
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
...@@ -428,7 +548,12 @@ class DeepseekV2MoE(nn.Module): ...@@ -428,7 +548,12 @@ class DeepseekV2MoE(nn.Module):
assert shared_output is not None assert shared_output is not None
final_hidden_states += shared_output final_hidden_states += shared_output
if self.is_sequence_parallel: if enable_lightly_cp:
final_hidden_states = tensor_model_parallel_reduce_scatter(
final_hidden_states.contiguous(), 0
)
return final_hidden_states
elif self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0 final_hidden_states, 0
) )
...@@ -759,6 +884,16 @@ class Indexer(nn.Module): ...@@ -759,6 +884,16 @@ class Indexer(nn.Module):
# `k_pe` is [num_tokens, 1, rope_dim] (MQA). # `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1) k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
k = tensor_model_parallel_all_gather(
k.contiguous(), 0
)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
enable_lightly_cplb = get_forward_context().enable_lightly_cplb
if enable_lightly_cplb and gather_indexes_tensor is not None:
k = torch.index_select(k, 0, gather_indexes_tensor)
# we only quant q here since k quant is fused with cache insertion # we only quant q here since k quant is fused with cache insertion
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938": if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
q = q.view(-1, self.head_dim) q = q.view(-1, self.head_dim)
...@@ -825,7 +960,8 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -825,7 +960,8 @@ class DeepseekV2MLAAttention(nn.Module):
self.num_heads = num_heads self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0 assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size self.num_local_heads = num_heads // tp_size if not \
vllm_config.parallel_config.enable_lightly_cp else self.num_heads
self.scaling = self.qk_head_dim**-0.5 self.scaling = self.qk_head_dim**-0.5
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
...@@ -859,6 +995,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -859,6 +995,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.q_b_proj", prefix=f"{prefix}.q_b_proj",
disable_tp=vllm_config.parallel_config.enable_lightly_cp
) )
else: else:
self.q_proj = ColumnParallelLinear( self.q_proj = ColumnParallelLinear(
...@@ -867,6 +1004,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -867,6 +1004,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.q_proj", prefix=f"{prefix}.q_proj",
disable_tp=vllm_config.parallel_config.enable_lightly_cp,
) )
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear( self.kv_b_proj = ColumnParallelLinear(
...@@ -875,6 +1013,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -875,6 +1013,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj", prefix=f"{prefix}.kv_b_proj",
disable_tp=vllm_config.parallel_config.enable_lightly_cp,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim, self.num_heads * self.v_head_dim,
...@@ -882,6 +1021,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -882,6 +1021,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
disable_tp=vllm_config.parallel_config.enable_lightly_cp,
) )
if config.rope_parameters["rope_type"] != "default": if config.rope_parameters["rope_type"] != "default":
...@@ -1118,6 +1258,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1118,6 +1258,7 @@ class DeepseekV2DecoderLayer(nn.Module):
residual *= 1.0 / self.routed_scaling_factor residual *= 1.0 / self.routed_scaling_factor
# Fully Connected # Fully Connected
enable_lightly_cp = get_forward_context().enable_lightly_cp
update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False
assert self.post_attention_layernorm.has_weight is True assert self.post_attention_layernorm.has_weight is True
_i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states, _i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states,
...@@ -1126,9 +1267,10 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1126,9 +1267,10 @@ class DeepseekV2DecoderLayer(nn.Module):
update_input=update_hs update_input=update_hs
) )
new_resi = residual new_resi = residual
hidden_states = self.mlp(hidden_states, if enable_lightly_cp and isinstance(self.mlp, DeepseekV2MoE):
iqis=(_i_q, _i_s) hidden_states = self.mlp(hidden_states)
) else:
hidden_states = self.mlp(hidden_states, iqis=(_i_q, _i_s))
if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow # Fix FP16 overflow
...@@ -1225,6 +1367,9 @@ class DeepseekV2Model(nn.Module): ...@@ -1225,6 +1367,9 @@ class DeepseekV2Model(nn.Module):
self.config = config self.config = config
self.device = current_platform.device_type self.device = current_platform.device_type
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
# Add a check; DSA is enabled by default # Add a check; DSA is enabled by default
force_disable_dsa = envs.VLLM_DISABLE_DSA force_disable_dsa = envs.VLLM_DISABLE_DSA
...@@ -1287,6 +1432,30 @@ class DeepseekV2Model(nn.Module): ...@@ -1287,6 +1432,30 @@ class DeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
scatter_indexes_tensor = get_forward_context().scatter_indexes_tensor
if scatter_indexes_tensor is None:
hidden_states_per_rank = torch.chunk(hidden_states, chunks=self.tp_size, dim=0)
hidden_states = hidden_states_per_rank[self.tp_rank].contiguous()
if residual is not None:
residual_per_rank = torch.chunk(residual, chunks=self.tp_size, dim=0)
residual = residual_per_rank[self.tp_rank].contiguous()
if positions is not None:
positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0)
positions = positions_per_rank[self.tp_rank].contiguous()
else:
scatter_indexes_tensor = torch.where(scatter_indexes_tensor == -1, 0, scatter_indexes_tensor)
hidden_states = torch.index_select(hidden_states, 0, scatter_indexes_tensor)
if residual is not None:
residual = torch.index_select(residual, 0, scatter_indexes_tensor)
if positions is not None:
positions = torch.index_select(positions, 0, scatter_indexes_tensor)
# Compute llama 4 scaling once per forward pass if enabled # Compute llama 4 scaling once per forward pass if enabled
llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None) llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
llama_4_scaling: torch.Tensor | None llama_4_scaling: torch.Tensor | None
...@@ -1307,11 +1476,21 @@ class DeepseekV2Model(nn.Module): ...@@ -1307,11 +1476,21 @@ class DeepseekV2Model(nn.Module):
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
if enable_lightly_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
residual = tensor_model_parallel_all_gather(residual.contiguous(), dim=0)
return IntermediateTensors( return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual} {"hidden_states": hidden_states, "residual": residual}
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
if enable_lightly_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if gather_indexes_tensor is not None:
hidden_states = torch.index_select(hidden_states, 0, gather_indexes_tensor)
return hidden_states return hidden_states
......
...@@ -282,6 +282,35 @@ class AttentionMetadata: ...@@ -282,6 +282,35 @@ class AttentionMetadata:
T = TypeVar("T", bound=AttentionMetadata) T = TypeVar("T", bound=AttentionMetadata)
@dataclass
class CpCommonAttentionMetadata:
    """Container of per-batch attention metadata for the context-parallel path.

    NOTE(review): field meanings are presumed to mirror the same-named fields
    of ``CommonAttentionMetadata`` -- confirm against the code that builds
    this object. All fields up to ``_num_computed_tokens_cpu`` are required
    and positional, so their order is part of the constructor interface.
    """
    # sp related metadata
    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor
    seq_lens: torch.Tensor
    # May be None: the ``seq_lens_cpu`` property then fills it lazily from
    # ``seq_lens``.
    _seq_lens_cpu: torch.Tensor

    num_actual_tokens: int
    num_kv_actual_tokens: int
    max_query_len: int
    max_seq_len: int
    num_reqs: int
    req_ids: list[str]

    block_table_tensor: torch.Tensor
    slot_mapping: torch.Tensor

    _num_computed_tokens_cpu: torch.Tensor

    # Optional; default to None when not supplied.
    dcp_local_seq_lens: torch.Tensor | None = None
    dcp_local_seq_lens_cpu: torch.Tensor | None = None

    def batch_size(self) -> int:
        """Return the size of ``seq_lens`` along dim 0."""
        return self.seq_lens.shape[0]

    @property
    def seq_lens_cpu(self) -> torch.Tensor:
        """Return the CPU copy of ``seq_lens``, creating and caching it if unset."""
        if self._seq_lens_cpu is None:
            self._seq_lens_cpu = self.seq_lens.to("cpu")
        return self._seq_lens_cpu
@dataclass @dataclass
class CommonAttentionMetadata: class CommonAttentionMetadata:
...@@ -312,6 +341,14 @@ class CommonAttentionMetadata: ...@@ -312,6 +341,14 @@ class CommonAttentionMetadata:
block_table_tensor: torch.Tensor block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
num_kv_actual_tokens: int | None = None
seq_indexes_list: list[int] | None = None
scatter_indexes_tensor: torch.Tensor | None = None
gather_indexes_tensor: torch.Tensor | None = None
cp_common_metadata: CpCommonAttentionMetadata | None = None
enable_lightly_cp: bool = False
causal: bool = True causal: bool = True
# Needed by FastPrefillAttentionBuilder # Needed by FastPrefillAttentionBuilder
...@@ -396,6 +433,7 @@ class CommonAttentionMetadata: ...@@ -396,6 +433,7 @@ class CommonAttentionMetadata:
else None, else None,
num_reqs=num_actual_reqs, num_reqs=num_actual_reqs,
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
num_kv_actual_tokens=num_actual_tokens,
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
max_seq_len=self.max_seq_len, max_seq_len=self.max_seq_len,
block_table_tensor=self.block_table_tensor[:num_actual_reqs], block_table_tensor=self.block_table_tensor[:num_actual_reqs],
......
...@@ -138,6 +138,7 @@ class FlashMLASparseMetadata(AttentionMetadata): ...@@ -138,6 +138,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
max_seq_len: int max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding. num_actual_tokens: int # Number of tokens excluding padding.
num_kv_actual_tokens: int
query_start_loc: torch.Tensor query_start_loc: torch.Tensor
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
...@@ -693,6 +694,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad ...@@ -693,6 +694,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
max_query_len=cm.max_query_len, max_query_len=cm.max_query_len,
max_seq_len=cm.max_seq_len, max_seq_len=cm.max_seq_len,
num_actual_tokens=cm.num_actual_tokens, num_actual_tokens=cm.num_actual_tokens,
num_kv_actual_tokens=cm.num_kv_actual_tokens,
query_start_loc=cm.query_start_loc, query_start_loc=cm.query_start_loc,
slot_mapping=cm.slot_mapping, slot_mapping=cm.slot_mapping,
block_table=cm.block_table_tensor, block_table=cm.block_table_tensor,
...@@ -1024,12 +1026,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ...@@ -1024,12 +1026,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
return output.fill_(0) return output.fill_(0)
num_actual_toks = attn_metadata.num_actual_tokens num_actual_toks = attn_metadata.num_actual_tokens
num_kv_actual_toks = attn_metadata.num_kv_actual_tokens
# Inputs and outputs may be padded for CUDA graphs # Inputs and outputs may be padded for CUDA graphs
q = q[:num_actual_toks, ...] q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...] k_c_normed = k_c_normed[:num_kv_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...] k_pe = k_pe[:num_kv_actual_toks, ...]
assert self.topk_indices_buffer is not None assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[:num_actual_toks] topk_indices = self.topk_indices_buffer[:num_actual_toks]
......
...@@ -105,6 +105,7 @@ class DeepseekV32IndexerMetadata: ...@@ -105,6 +105,7 @@ class DeepseekV32IndexerMetadata:
max_seq_len: int max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding. num_actual_tokens: int # Number of tokens excluding padding.
num_kv_actual_tokens: int
query_start_loc: torch.Tensor query_start_loc: torch.Tensor
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
# The dimension of the attention heads # The dimension of the attention heads
...@@ -438,6 +439,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -438,6 +439,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
max_query_len=common_attn_metadata.max_query_len, max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len, max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens, num_actual_tokens=common_attn_metadata.num_actual_tokens,
num_kv_actual_tokens=common_attn_metadata.num_kv_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc, query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping, slot_mapping=common_attn_metadata.slot_mapping,
head_dim=128, head_dim=128,
......
...@@ -14,7 +14,7 @@ from vllm.config import ( ...@@ -14,7 +14,7 @@ from vllm.config import (
VllmConfig, VllmConfig,
get_layers_from_vllm_config, get_layers_from_vllm_config,
) )
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group, get_tensor_model_parallel_rank
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
...@@ -29,6 +29,7 @@ from vllm.utils.platform_utils import is_pin_memory_available ...@@ -29,6 +29,7 @@ from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
CpCommonAttentionMetadata,
) )
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.tree_attn import ( from vllm.v1.attention.backends.tree_attn import (
...@@ -48,6 +49,7 @@ from vllm.v1.spec_decode.utils import ( ...@@ -48,6 +49,7 @@ from vllm.v1.spec_decode.utils import (
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.utils.math_utils import cdiv, round_up
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -76,7 +78,9 @@ class SpecDecodeBaseProposer: ...@@ -76,7 +78,9 @@ class SpecDecodeBaseProposer:
self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
# The drafter can get longer sequences than the target model. # The drafter can get longer sequences than the target model.
max_batch_size = vllm_config.scheduler_config.max_num_seqs max_batch_size = vllm_config.scheduler_config.max_num_seqs if not \
vllm_config.parallel_config.enable_lightly_cplb \
else vllm_config.scheduler_config.max_num_seqs * 2
self.max_num_tokens = ( self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
) )
...@@ -219,6 +223,28 @@ class SpecDecodeBaseProposer: ...@@ -219,6 +223,28 @@ class SpecDecodeBaseProposer:
1, len(self.tree_choices) + 1, device=device, dtype=torch.int32 1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
).repeat(max_batch_size, 1) ).repeat(max_batch_size, 1)
self.scatter_indexes_tensor = None
self.gather_indexes_tensor = None
self.enable_lightly_cp = vllm_config.parallel_config.enable_lightly_cp
self.enable_lightly_cplb = self.enable_lightly_cp and vllm_config.parallel_config.enable_lightly_cplb
if self.enable_lightly_cp:
self.query_start_loc = CpuGpuBuffer(
max_batch_size + 1,
dtype=torch.int32,
pin_memory=is_pin_memory_available(),
device=device,
with_numpy=True,
)
self.seq_lens = CpuGpuBuffer(
max_batch_size,
dtype=torch.int32,
pin_memory=is_pin_memory_available(),
device=device,
with_numpy=True,
)
def _get_positions(self, num_tokens: int): def _get_positions(self, num_tokens: int):
if self.uses_mrope: if self.uses_mrope:
return self.mrope_positions[:, :num_tokens] return self.mrope_positions[:, :num_tokens]
...@@ -270,6 +296,10 @@ class SpecDecodeBaseProposer: ...@@ -270,6 +296,10 @@ class SpecDecodeBaseProposer:
self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode) self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)
def _pad_for_mla_cp(self, num_scheduled_tokens: int) -> int:
    """Round ``num_scheduled_tokens`` up to a multiple of the configured
    tensor-parallel size."""
    parallel_config = self.vllm_config.parallel_config
    return round_up(num_scheduled_tokens, parallel_config.tensor_parallel_size)
def propose( def propose(
self, self,
# [num_tokens] # [num_tokens]
...@@ -309,6 +339,31 @@ class SpecDecodeBaseProposer: ...@@ -309,6 +339,31 @@ class SpecDecodeBaseProposer:
) )
) )
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
enable_lightly_cp = self.enable_lightly_cp and num_tokens > self.runner.lightly_cp_threshould
if enable_lightly_cp:
num_tokens_dp_padded = self._pad_for_mla_cp(num_tokens_dp_padded)
common_attn_metadata = self._prepare_cp_metadata(
num_reqs_padded=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
num_tokens=num_tokens,
block_table_gid_0=common_attn_metadata.block_table_tensor,
slot_mapping_gid_0=common_attn_metadata.slot_mapping,
query_start_loc=common_attn_metadata.query_start_loc,
query_start_loc_cpu=common_attn_metadata.query_start_loc_cpu,
seq_lens=common_attn_metadata.seq_lens,
seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
)
self.scatter_indexes_tensor = common_attn_metadata.scatter_indexes_tensor
self.gather_indexes_tensor = common_attn_metadata.gather_indexes_tensor
assert self.runner is not None assert self.runner is not None
if self.attn_metadata_builder is None: if self.attn_metadata_builder is None:
...@@ -339,10 +394,6 @@ class SpecDecodeBaseProposer: ...@@ -339,10 +394,6 @@ class SpecDecodeBaseProposer:
assert draft_indexer_metadata is not None assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata per_layer_attn_metadata[layer_name] = draft_indexer_metadata
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_dp_padded num_tokens_dp_padded
) )
...@@ -387,6 +438,10 @@ class SpecDecodeBaseProposer: ...@@ -387,6 +438,10 @@ class SpecDecodeBaseProposer:
slot_mapping=self._get_slot_mapping( slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping num_input_tokens, common_attn_metadata.slot_mapping
), ),
scatter_indexes_tensor=self.scatter_indexes_tensor,
gather_indexes_tensor=self.gather_indexes_tensor,
enable_lightly_cp=self.enable_lightly_cp and num_tokens > self.runner.lightly_cp_threshould,
enable_lightly_cplb=self.enable_lightly_cplb
): ):
ret_hidden_states = self.model(**model_kwargs) ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple(): if not self.model_returns_tuple():
...@@ -463,6 +518,9 @@ class SpecDecodeBaseProposer: ...@@ -463,6 +518,9 @@ class SpecDecodeBaseProposer:
if batch_size_across_dp is not None: if batch_size_across_dp is not None:
batch_size_across_dp[self.dp_rank] = input_batch_size batch_size_across_dp[self.dp_rank] = input_batch_size
if enable_lightly_cp:
common_attn_metadata = common_attn_metadata.cp_common_metadata
common_attn_metadata.num_actual_tokens = batch_size common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1 common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1] common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
...@@ -802,6 +860,7 @@ class SpecDecodeBaseProposer: ...@@ -802,6 +860,7 @@ class SpecDecodeBaseProposer:
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu, _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs, num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens, num_actual_tokens=total_num_tokens,
num_kv_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(), max_query_len=new_query_len_per_req.max().item(),
max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
block_table_tensor=common_attn_metadata.block_table_tensor, block_table_tensor=common_attn_metadata.block_table_tensor,
...@@ -988,6 +1047,104 @@ class SpecDecodeBaseProposer: ...@@ -988,6 +1047,104 @@ class SpecDecodeBaseProposer:
level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts
total_num_drafts = self.cu_drafts_per_level[level + 1] total_num_drafts = self.cu_drafts_per_level[level + 1]
return draft_token_ids_list return draft_token_ids_list
def _prepare_cp_metadata(
    self,
    num_reqs_padded,
    max_query_len,
    max_seq_len,
    num_tokens,
    block_table_gid_0,
    slot_mapping_gid_0,
    query_start_loc,
    query_start_loc_cpu,
    seq_lens,
    seq_lens_cpu,
    num_computed_tokens_cpu,
):
    """Build this rank's ``CommonAttentionMetadata`` for a lightly-CP pass.

    Steps:
      1. Snapshot the caller-provided batch description into a
         ``CpCommonAttentionMetadata`` (tensors describing queries and
         sequence lengths are cloned).
      2. Ask the runner to distribute the batch over the tensor-parallel
         ranks via ``_distribute_tokens_to_cp_ranks``.
      3. Stage the returned query lengths (as cumulative offsets) and
         sequence lengths in ``self.query_start_loc`` / ``self.seq_lens``.
      4. Wrap the staged slices, the selected block-table rows and the
         scatter/gather indexes in a ``CommonAttentionMetadata`` that also
         carries the snapshot from step 1.

    Args:
        num_reqs_padded: stored as ``num_reqs`` of the snapshot.
        max_query_len: stored as ``max_query_len`` of the snapshot.
        max_seq_len: stored as ``max_seq_len`` of the snapshot.
        num_tokens: token count of the incoming batch; used for the
            snapshot, passed to the distribution helper, and reused as
            ``num_kv_actual_tokens`` of the returned metadata.
        block_table_gid_0: block table; indexed with the sequence indexes
            returned by the helper (presumably the table of KV-cache group
            0 -- TODO confirm).
        slot_mapping_gid_0: slot mapping, forwarded unchanged.
        query_start_loc: device cumulative query offsets (cloned).
        query_start_loc_cpu: host cumulative query offsets (cloned; also
            differenced to obtain per-request query lengths).
        seq_lens: device sequence lengths (cloned).
        seq_lens_cpu: host sequence lengths (cloned and passed to the
            helper).
        num_computed_tokens_cpu: stored in the snapshot only.

    Returns:
        The ``CommonAttentionMetadata`` describing the portion of the batch
        handled by this rank.
    """
    parallel_config = self.vllm_config.parallel_config
    world_size = parallel_config.tensor_parallel_size
    rank = get_tensor_model_parallel_rank()
    batch_req_ids = self.runner.input_batch.req_ids

    # Snapshot of the incoming batch; taken before the helper runs so the
    # clones are independent of anything it does with its arguments.
    snapshot = CpCommonAttentionMetadata(
        query_start_loc=query_start_loc.clone(),
        query_start_loc_cpu=query_start_loc_cpu.clone(),
        seq_lens=seq_lens.clone(),
        _seq_lens_cpu=seq_lens_cpu.clone(),
        num_actual_tokens=num_tokens,
        num_kv_actual_tokens=num_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        num_reqs=num_reqs_padded,
        req_ids=batch_req_ids,
        block_table_tensor=block_table_gid_0,
        slot_mapping=slot_mapping_gid_0,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
    )

    # Per-request query lengths recovered from the cumulative offsets.
    per_req_q_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

    (
        local_num_tokens,
        local_q_lens,
        local_num_reqs,
        local_kv_lens,
        _local_req_ids,  # returned by the helper but not needed here
        scatter_idx,
        gather_idx,
        local_seq_indexes,
    ) = self.runner._distribute_tokens_to_cp_ranks(
        num_tokens,
        per_req_q_lens,
        seq_lens_cpu,
        rank,
        world_size,
        batch_req_ids,
    )

    # Stage the cumulative query offsets. Entries past the live requests are
    # filled with the final offset. Presumably ``.np`` is a NumPy view of
    # ``.cpu`` -- the code writes through ``.np`` and reads ``.cpu`` below.
    offsets_end = local_num_reqs + 1
    qsl_buf = self.query_start_loc
    cumulative_q = np.cumsum(local_q_lens)
    qsl_buf.np[0] = 0
    qsl_buf.np[1:offsets_end] = cumulative_q
    qsl_buf.np[offsets_end:].fill(cumulative_q[-1])
    qsl_buf.copy_to_gpu()
    local_qsl = qsl_buf.gpu[:offsets_end]
    local_qsl_cpu = qsl_buf.cpu[:offsets_end]

    # Stage the sequence lengths; unused tail entries are zeroed.
    lens_buf = self.seq_lens
    lens_buf.np[:local_num_reqs] = local_kv_lens
    lens_buf.np[local_num_reqs:].fill(0)
    lens_buf.copy_to_gpu()
    local_seq_lens = lens_buf.gpu[:local_num_reqs]
    local_seq_lens_cpu = lens_buf.cpu[:local_num_reqs]

    return CommonAttentionMetadata(
        query_start_loc=local_qsl,
        query_start_loc_cpu=local_qsl_cpu,
        seq_lens=local_seq_lens,
        _seq_lens_cpu=local_seq_lens_cpu,
        # NOTE(review): this subtracts the *cumulative* offsets from the
        # sequence lengths, not the per-request query lengths; the two only
        # coincide for the first request. Confirm this is intended.
        _num_computed_tokens_cpu=local_seq_lens_cpu - local_qsl_cpu[1:],
        num_reqs=local_num_reqs,
        num_actual_tokens=local_num_tokens,
        # NOTE(review): this is the maximum over the cumulative offsets, not
        # over individual query lengths, and it yields an element of the
        # buffer rather than a Python int (other call sites in this file use
        # ``.max().item()``). Confirm this is intended.
        max_query_len=max(local_qsl_cpu),
        max_seq_len=max(local_seq_lens_cpu),
        block_table_tensor=block_table_gid_0[local_seq_indexes],
        slot_mapping=slot_mapping_gid_0,
        causal=True,
        # The key/value token count stays at the size of the whole batch.
        num_kv_actual_tokens=num_tokens,
        seq_indexes_list=local_seq_indexes,
        cp_common_metadata=snapshot,
        scatter_indexes_tensor=scatter_idx,
        gather_indexes_tensor=gather_idx,
    )
def prepare_inputs( def prepare_inputs(
self, self,
......
...@@ -233,6 +233,10 @@ class BlockTable: ...@@ -233,6 +233,10 @@ class BlockTable:
def get_device_tensor(self, num_reqs: int) -> torch.Tensor: def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
"""Returns the device tensor of the block table.""" """Returns the device tensor of the block table."""
return self.block_table.gpu[:num_reqs] return self.block_table.gpu[:num_reqs]
def get_device_tensor_range(self, start_req: int, end_req: int) -> torch.Tensor:
    """Return the slice ``[start_req:end_req]`` of the device block table."""
    device_table = self.block_table.gpu
    return device_table[start_req:end_req]
def get_cpu_tensor(self) -> torch.Tensor: def get_cpu_tensor(self) -> torch.Tensor:
"""Returns the CPU tensor of the block table.""" """Returns the CPU tensor of the block table."""
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment