import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn

from vllm.kvprune.utils.context import get_context
from vllm.kvprune.utils.tp_collectives import tensor_parallel_all_reduce
from vllm.kvprune.utils.tp_utils import (
    tensor_parallel_rank_for_sharding,
    tensor_parallel_world_size_for_sharding,
)


class VocabParallelEmbedding(nn.Module):
    """Embedding table sharded along the vocabulary dimension across TP ranks.

    Each rank holds ``num_embeddings // tp_size`` consecutive rows, covering
    token ids in ``[vocab_start_idx, vocab_end_idx)``. In ``forward`` ids
    outside that range are masked out locally and the per-rank partial results
    are combined with ``tensor_parallel_all_reduce``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
    ):
        """Allocate this rank's shard of the embedding table.

        Args:
            num_embeddings: Total vocabulary size; must be divisible by the
                tensor-parallel world size.
            embedding_dim: Width of each embedding vector.
        """
        super().__init__()
        self.tp_rank = tensor_parallel_rank_for_sharding()
        self.tp_size = tensor_parallel_world_size_for_sharding()
        assert num_embeddings % self.tp_size == 0, (
            f"num_embeddings ({num_embeddings}) must be divisible by the "
            f"tensor-parallel world size ({self.tp_size})"
        )
        self.num_embeddings = num_embeddings
        self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
        # Half-open range of global token ids owned by this rank.
        self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
        self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
        # Uninitialised storage; values are expected to arrive via weight_loader.
        self.weight = nn.Parameter(
            torch.empty(self.num_embeddings_per_partition, embedding_dim)
        )
        # Attach the loader to the parameter itself so code that only sees the
        # parameter can still find it.
        self.weight.weight_loader = self.weight_loader

    def weight_loader(
        self, param: nn.Parameter, loaded_weight: torch.Tensor
    ) -> None:
        """Copy this rank's row slice of ``loaded_weight`` into ``param``.

        Args:
            param: Destination parameter; its first dimension defines the
                shard size.
            loaded_weight: Full (unsharded) weight; rows
                ``[tp_rank * shard, (tp_rank + 1) * shard)`` are copied.
        """
        param_data = param.data
        shard_size = param_data.size(0)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for token ids ``x``.

        Args:
            x: Integer tensor of global token ids, of any rank.

        Returns:
            Tensor of shape ``x.shape + (embedding_dim,)``.
        """
        if self.tp_size == 1:
            return F.embedding(x, self.weight)

        # Ids owned by another rank are remapped to local row 0 so the lookup
        # stays in bounds; their rows are zeroed again right after.
        mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
        local_ids = mask * (x - self.vocab_start_idx)
        y = F.embedding(local_ids, self.weight)
        # unsqueeze(-1) broadcasts over the embedding dim for any input rank
        # (equivalent to unsqueeze(1) for the 1-D case).
        y = mask.unsqueeze(-1) * y

        # NOTE(review): the helper's contract is not visible from this file.
        # Accept either an in-place reduction (return value ignored) or a
        # helper that returns the reduced tensor — confirm against
        # tp_collectives.
        reduced = tensor_parallel_all_reduce(y)
        if isinstance(reduced, torch.Tensor):
            y = reduced
        return y


class ParallelLMHead(VocabParallelEmbedding):
    """LM head with TP vocab sharding.

    When embedded in a vLLM worker, logits must be gathered on the **tensor-
    parallel** process group (see
    :func:`~vllm.distributed.communication_op.tensor_model_parallel_gather`),
    not the default :func:`torch.distributed.gather` — otherwise shard order /
    group mismatch yields garbage logits and decoded gibberish.

    After gather, logits are truncated to ``org_vocab_size`` (HF tokenizer
    vocab), matching
    :class:`~vllm.model_executor.layers.logits_processor.LogitsProcessor`
    removal of padded vocabulary columns.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        bias: bool = False,
        *,
        org_vocab_size: int | None = None,
    ):
        """Build the sharded output projection.

        Args:
            num_embeddings: (Possibly padded) vocabulary size; must be
                divisible by the tensor-parallel world size.
            embedding_dim: Hidden size of the incoming activations.
            bias: Must be ``False``; a bias term is not supported.
            org_vocab_size: Original (unpadded) vocabulary size used to
                truncate logits; defaults to ``num_embeddings``.
        """
        assert not bias, "ParallelLMHead does not support a bias term"
        super().__init__(num_embeddings, embedding_dim)
        # Original (unpadded) vocab size for logits truncation; defaults to
        # num_embeddings.
        self.org_vocab_size = (
            int(org_vocab_size) if org_vocab_size is not None else num_embeddings
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits.

        During prefill only one row per sequence is projected, selected via
        ``context.cu_seqlens_q``.

        Args:
            x: Hidden states; the first dimension indexes tokens.

        Returns:
            Logits truncated to ``org_vocab_size`` columns, or ``None`` when
            the tensor-parallel gather returns nothing on this rank.
        """
        context = get_context()
        if context.is_prefill:
            # Presumably cu_seqlens_q holds cumulative sequence boundaries, so
            # ``cu[1:] - 1`` is the last token of each sequence — confirm
            # against the context producer.
            cu = context.cu_seqlens_q
            last_indices = (cu[1:] - 1).to(torch.long)
            n_tok = x.shape[0]
            if n_tok > 0:
                # Keep indices inside the rows actually present in ``x``.
                last_indices = last_indices.clamp(min=0, max=n_tok - 1)
            x = x[last_indices].contiguous()

        logits = F.linear(x, self.weight)
        if self.tp_size > 1:
            logits = self._gather_logits_tp(logits)
        # Drop padded vocabulary columns beyond the original vocab size.
        if logits is not None and logits.shape[-1] > self.org_vocab_size:
            logits = logits[..., : self.org_vocab_size]
        return logits

    def _gather_logits_tp(self, logits: torch.Tensor) -> torch.Tensor | None:
        """Gather per-rank logit shards along the last dimension.

        Uses vLLM's tensor-parallel gather when vLLM's model-parallel state is
        importable and initialised; otherwise falls back to
        ``torch.distributed.gather`` on the default process group.

        Only a failed import selects the fallback. An error raised by the
        collective itself propagates, because silently retrying with a
        different collective on a different group would produce the
        mismatched results the class docstring warns about.

        Args:
            logits: This rank's shard of the logits.

        Returns:
            The concatenated logits where the gather delivers them (rank 0 in
            the fallback path), ``None`` elsewhere.
        """
        try:
            from vllm.distributed.parallel_state import model_parallel_is_initialized
            from vllm.distributed.communication_op import (
                tensor_model_parallel_gather,
            )
        except ImportError:
            vllm_tp_ready = False
        else:
            vllm_tp_ready = model_parallel_is_initialized()

        if vllm_tp_ready:
            return tensor_model_parallel_gather(logits, dst=0, dim=-1)

        # Fallback: plain gather to rank 0 on the default process group.
        is_dst = self.tp_rank == 0
        all_logits = (
            [torch.empty_like(logits) for _ in range(self.tp_size)]
            if is_dst
            else None
        )
        dist.gather(logits, all_logits, 0)
        return torch.cat(all_logits, -1) if is_dst else None