Commit f81ce56b authored by chenzk's avatar chenzk
Browse files

vllm kvprune:v1.0.1

parent 2b7160c6
import torch
import torch.nn.functional as F
from torch import nn
class SiluAndMul(nn.Module):
def __init__(self):
super().__init__()
# @torch.compile
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, y = x.chunk(2, -1)
return F.silu(x) * y
import torch
import torch.distributed as dist
import torch.nn.functional as F
from compactor_vllm.utils.context import get_context
from torch import nn
class VocabParallelEmbedding(nn.Module):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
):
super().__init__()
self.tp_rank = dist.get_rank()
self.tp_size = dist.get_world_size()
assert num_embeddings % self.tp_size == 0
self.num_embeddings = num_embeddings
self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
self.weight = nn.Parameter(
torch.empty(self.num_embeddings_per_partition, embedding_dim)
)
self.weight.weight_loader = self.weight_loader
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
shard_size = param_data.size(0)
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
param_data.copy_(loaded_weight)
def forward(self, x: torch.Tensor):
if self.tp_size > 1:
mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
x = mask * (x - self.vocab_start_idx)
y = F.embedding(x, self.weight)
if self.tp_size > 1:
y = mask.unsqueeze(1) * y
dist.all_reduce(y)
return y
class ParallelLMHead(VocabParallelEmbedding):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
):
assert not bias
super().__init__(num_embeddings, embedding_dim)
def forward(self, x: torch.Tensor):
context = get_context()
if context.is_prefill:
last_indices = context.cu_seqlens_q[1:] - 1
x = x[last_indices].contiguous()
logits = F.linear(x, self.weight)
if self.tp_size > 1:
all_logits = (
[torch.empty_like(logits) for _ in range(self.tp_size)]
if self.tp_rank == 0
else None
)
dist.gather(logits, all_logits, 0)
logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
return logits
This diff is collapsed.
import logging
from compactor_vllm.models.llama3 import LlamaForCausalLM
from compactor_vllm.models.qwen3 import Qwen3ForCausalLM
logger = logging.getLogger(__name__)
MODEL_REGISTRY = {
"llama": LlamaForCausalLM,
"qwen3": Qwen3ForCausalLM,
}
try:
from compactor_vllm.models.qwen3_moe import Qwen3MoeForCausalLM
except Exception as exc:
logger.debug("Skipping qwen3_moe registration due to import error: %s", exc)
else:
MODEL_REGISTRY["qwen3_moe"] = Qwen3MoeForCausalLM
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment