Commit d29c39ca authored by chenzk
Browse files

vllm kvprune wo:v1.1.0

parent f81ce56b
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import torch
import torch.nn.functional as F
from torch import nn
class SiluAndMul(nn.Module):
    """Gated SiLU activation applied across the last axis of the input.

    The last dimension of the input is divided into two chunks (equal
    halves when its size is even). SiLU is applied to the first chunk and
    the result is multiplied element-wise by the second chunk.
    """

    def __init__(self):
        super().__init__()

    # The torch.compile decorator is currently disabled for this method.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``silu(first_chunk) * second_chunk`` of ``x``.

        Args:
            x: Tensor whose last dimension holds the two chunks laid out
                back to back.

        Returns:
            The element-wise product of the SiLU-activated first chunk and
            the second chunk.
        """
        gate, up = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * up
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import torch
from torch import nn
class Sampler(nn.Module):
    """Chooses one index per row of logits, greedily or by sampling.

    Rows whose temperature is exactly 0 take the plain argmax. Every other
    row is divided by its temperature and perturbed with the negated log of
    Exponential(1) noise before the argmax (the Gumbel-max trick), which
    draws an index from the temperature-scaled softmax distribution.
    """

    def __init__(self):
        super().__init__()

    # The torch.compile decorator is currently disabled for this method.
    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
        """Return the selected index along the last axis for every row.

        Args:
            logits: Scores to select from; converted to float32 internally.
            temperatures: One temperature per row of ``logits`` (flattened
                with ``view(-1)``); a value of 0 selects greedy decoding.

        Returns:
            Tensor of argmax indices over the last dimension of ``logits``.
        """
        flat_temps = temperatures.view(-1)
        scores = logits.float()
        stochastic_rows = flat_temps != 0.0

        # All rows are greedy: no noise is drawn and nothing is copied.
        if not stochastic_rows.any():
            return scores.argmax(dim=-1)

        row_temps = flat_temps[stochastic_rows].unsqueeze(-1)  # [B_sample, 1]
        tempered = scores[stochastic_rows].div(row_temps)  # temperature scaling
        # Clamp the noise away from zero so that the log stays finite.
        log_noise = (
            torch.empty_like(tempered).exponential_(1).clamp_min_(1e-10).log()
        )

        # Write into a copy: ``.float()`` can hand back the input tensor
        # itself when it is already float32, and the caller's logits must
        # not be modified by the masked assignment.
        scores = scores.clone()
        scores[stochastic_rows] = tempered - log_noise
        return scores.argmax(dim=-1)
import logging
from compactor_vllm.models.llama3 import LlamaForCausalLM
from compactor_vllm.models.qwen3 import Qwen3ForCausalLM
logger = logging.getLogger(__name__)

# Lookup table from a short model-family key to the causal-LM class that
# implements it. These two entries are always registered.
MODEL_REGISTRY = dict(
    llama=LlamaForCausalLM,
    qwen3=Qwen3ForCausalLM,
)

# The Qwen3 MoE model is registered on a best-effort basis: if importing its
# module fails for any reason, the entry is left out and the failure is only
# reported at debug level, so the rest of the registry stays usable.
try:
    from compactor_vllm.models.qwen3_moe import Qwen3MoeForCausalLM
except Exception as exc:
    logger.debug("Skipping qwen3_moe registration due to import error: %s", exc)
else:
    MODEL_REGISTRY.update(qwen3_moe=Qwen3MoeForCausalLM)
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment