Commit d29c39ca authored by chenzk's avatar chenzk
Browse files

vllm kvprune wo:v1.1.0

parent f81ce56b
import math
from functools import lru_cache
from typing import Any
import torch
from torch import nn
def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Rotate the two halves of ``x`` using the given cosine/sine tables.

    The last dimension of ``x`` is split into a first and a second half which
    are treated as the two coordinates of a 2-D rotation.  The arithmetic is
    carried out on a float32 copy of ``x`` and the result is cast back to the
    dtype of ``x``.

    Args:
        x: Tensor whose last dimension is split in two equal chunks.
        cos: Cosine factors, broadcastable against one half of ``x``.
        sin: Sine factors, broadcastable against one half of ``x``.

    Returns:
        Tensor with the same dtype as ``x`` holding the rotated halves
        concatenated along the last dimension.
    """
    first_half, second_half = x.float().chunk(2, dim=-1)
    rotated_halves = (
        first_half * cos - second_half * sin,
        second_half * cos + first_half * sin,
    )
    return torch.cat(rotated_halves, dim=-1).to(x.dtype)
def rope_theta_from_hf_config(config: Any) -> float:
    """Return the RoPE base (``rope_theta``) described by an HF-style config.

    Lookup order:
      1. ``config.rope_parameters["rope_theta"]`` when ``rope_parameters`` is a
         dict containing that key (matches vLLM/HF configs where the value
         lives only under ``rope_parameters`` in config.json);
      2. ``config.rope_theta``;
      3. the fallback ``1_000_000.0``.

    The selected value is always converted with ``float``.
    """
    rope_params = getattr(config, "rope_parameters", None)
    nested_theta_present = isinstance(rope_params, dict) and "rope_theta" in rope_params
    if nested_theta_present:
        theta = rope_params["rope_theta"]
    else:
        theta = getattr(config, "rope_theta", 1_000_000.0)
    return float(theta)
class RotaryEmbedding(nn.Module):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
rope_scaling: tuple | None,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
inv_freq = 1.0 / (
base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)
)
if rope_scaling is not None:
(
rope_type,
factor,
low_freq_factor,
high_freq_factor,
original_max_position_embeddings,
) = rope_scaling
assert rope_type == "llama3"
old_context_len = original_max_position_embeddings
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
wavelen = 2 * math.pi / inv_freq
inv_freq_llama = torch.where(
wavelen > low_freq_wavelen, inv_freq / factor, inv_freq
)
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
smoothed_inv_freq = (
1 - smooth_factor
) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(
wavelen > low_freq_wavelen
)
inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
self.register_buffer("cos_sin_cache", cache, persistent=False)
# @torch.compile
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
cache_len = self.cos_sin_cache.shape[0]
# CUDA graph capture forbids device→CPU sync (e.g. ``.item()``) inside the
# captured region; :meth:`ModelRunner.capture_cudagraph` runs decode with
# placeholder positions. Skip the range check while capturing; eager runs
# still validate.
_capturing = (
torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()
)
if positions.numel() > 0 and not _capturing:
pmax = int(positions.max().item())
pmin = int(positions.min().item())
if pmax >= cache_len or pmin < 0:
raise ValueError(
f"RoPE positions out of range: need 0 <= pos < {cache_len}, "
f"got min={pmin}, max={pmax}. "
"Shorten the prompt or increase max_model_len (and align vLLM "
"RoPE cos_sin_cache with tie_kvprune_rope_buffers_from_vllm)."
)
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
query = apply_rotary_emb(query, cos, sin)
key = apply_rotary_emb(key, cos, sin)
return query, key
@lru_cache(1)
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    rope_scaling: tuple | None = None,
):
    """Build a :class:`RotaryEmbedding`, reusing the most recent one.

    The function is wrapped in ``lru_cache(1)``: calling it again with the same
    arguments returns the previously built module instead of a new one, and
    only the latest argument combination is remembered.  Because of the cache
    every argument must be hashable, which is why ``rope_scaling`` is a tuple.
    """
    return RotaryEmbedding(
        head_size=head_size,
        rotary_dim=rotary_dim,
        max_position_embeddings=max_position,
        base=base,
        rope_scaling=rope_scaling,
    )
import torch
from torch import nn
class Sampler(nn.Module):
    """Pick one token id per row, greedily or by temperature sampling.

    Rows whose temperature equals ``0.0`` use a plain argmax.  All other rows
    have their logits divided by the temperature and are perturbed with
    ``-log(E)`` noise, ``E`` drawn from an Exponential(1) distribution and
    clamped to at least ``1e-10``, before the argmax is taken.
    """

    def __init__(self):
        super().__init__()

    # @torch.compile
    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
        """Return the selected index along the last dimension of ``logits``.

        ``temperatures`` is flattened and must provide one value per row of
        ``logits``.  The input ``logits`` tensor is never modified.
        """
        row_temps = temperatures.view(-1)
        scores = logits.float()
        needs_sampling = ~(row_temps == 0.0)

        # Pure greedy batch: nothing to perturb.
        if not needs_sampling.any():
            return scores.argmax(dim=-1)

        # temperature scaling, rows to sample only: [B_sample, V] / [B_sample, 1]
        tempered = scores[needs_sampling].div(row_temps[needs_sampling].unsqueeze(-1))
        log_exp_noise = (
            torch.empty_like(tempered).exponential_(1).clamp_min_(1e-10).log()
        )
        perturbed = tempered - log_exp_noise

        # ``.float()`` may alias ``logits``; clone before writing rows back.
        scores = scores.clone()
        scores[needs_sampling] = perturbed
        return scores.argmax(dim=-1)
import torch
import triton
import triton.language as tl
@triton.jit
def _masked_index_select_kernel(
    X_ptr,
    IDX_ptr,
    OUT_ptr,
    N,
    stride_xn,
    stride_xh,
    stride_ob,
    stride_oh,
):
    # Gathers a single element per program instance:
    #   OUT[b, h] = X[IDX[b], h]   if 0 <= IDX[b] < N
    #   OUT[b, h] = 0              otherwise
    # Program axis 0 selects the entry of IDX (and the output row), program
    # axis 1 selects the column.  ``stride_x*`` / ``stride_o*`` are the strides
    # of X and OUT passed in by the launcher.
    b = tl.program_id(0)  # which output row (0..B-1)
    h = tl.program_id(1)  # which column of X / OUT
    idx = tl.load(IDX_ptr + b)  # source row requested for output row b
    valid = (idx >= 0) & (idx < N)
    out_ptrs = OUT_ptr + b * stride_ob + h * stride_oh
    if not valid:
        # Out-of-range index: emit zero instead of reading X.
        tl.store(out_ptrs, 0)
    else:
        x_ptrs = X_ptr + idx * stride_xn + h * stride_xh
        vals = tl.load(x_ptrs)
        tl.store(out_ptrs, vals)
def masked_index_select_triton_dim0(
    input: torch.Tensor, index: torch.Tensor
) -> torch.Tensor:
    """Gather rows of ``input`` along dim 0, zero-filling invalid indices.

    Args:
        input: 2-D tensor ``[N, H]``.
        index: 1-D tensor ``[B]`` of row indices (int32/int64 on the same
            device as ``input``).  Entries outside ``[0, N)`` yield an
            all-zero output row.

    Returns:
        A new ``[B, H]`` tensor with the dtype and device of ``input``.
    """
    assert input.ndim == 2 and index.ndim == 1
    num_rows, width = input.shape
    num_selected = index.numel()
    gathered = torch.empty(
        (num_selected, width), dtype=input.dtype, device=input.device
    )
    # One kernel instance per output element.
    launch_grid = (num_selected, width)
    _masked_index_select_kernel[launch_grid](
        input,
        index,
        gathered,
        num_rows,
        *input.stride(),
        *gathered.stride(),
    )
    return gathered
@triton.jit
def _masked_index_copy_kernel(
    DST_ptr,
    IDX_ptr,
    SRC_ptr,
    N,
    stride_dn,
    stride_dh,
    stride_sb,
    stride_sh,
):
    # Scatters a single element per program instance:
    #   DST[IDX[b], h] = SRC[b, h]   if 0 <= IDX[b] < N
    # and does nothing when IDX[b] is out of range.
    # Program axis 0 selects the source row / entry of IDX, program axis 1
    # selects the column.  ``stride_d*`` / ``stride_s*`` are the strides of DST
    # and SRC passed in by the launcher.
    b = tl.program_id(0)  # which source row
    h = tl.program_id(1)  # which column of SRC / DST
    idx = tl.load(IDX_ptr + b)  # destination row requested for source row b
    valid = (idx >= 0) & (idx < N)
    if valid:
        src_ptrs = SRC_ptr + b * stride_sb + h * stride_sh
        dst_ptrs = DST_ptr + idx * stride_dn + h * stride_dh
        tl.store(dst_ptrs, tl.load(src_ptrs))
def masked_index_copy_triton_dim0(
    dst: torch.Tensor, index: torch.Tensor, src: torch.Tensor
):
    """Masked, in-place variant of ``dst.index_copy_(0, index, src)``.

    Row ``b`` of ``src`` is written to row ``index[b]`` of ``dst``.  Rows whose
    index is negative or ``>= dst.shape[0]`` are skipped (nothing is written).

    Shapes:
        dst:   [N, H]
        src:   [B, H]
        index: [B]

    Returns ``None``; ``dst`` is modified in place.
    """
    assert dst.ndim == 2 and src.ndim == 2 and index.ndim == 1
    num_dst_rows, width = dst.shape
    num_src_rows, src_width = src.shape
    assert src_width == width and index.numel() == num_src_rows
    # One kernel instance per source element.
    launch_grid = (num_src_rows, width)
    _masked_index_copy_kernel[launch_grid](
        dst,
        index,
        src,
        num_dst_rows,
        *dst.stride(),
        *src.stride(),
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from vllm.kvprune.models.llama3 import LlamaForCausalLM
from vllm.kvprune.models.qwen3 import Qwen3ForCausalLM
logger = logging.getLogger(__name__)

# Architectures whose implementations are imported unconditionally above.
MODEL_REGISTRY = dict(
    llama=LlamaForCausalLM,
    qwen3=Qwen3ForCausalLM,
)

# ``qwen3_moe`` is registered only when its module imports cleanly; any
# import-time failure is logged and the entry is simply left out.
try:
    from vllm.kvprune.models.qwen3_moe import Qwen3MoeForCausalLM
except Exception as exc:
    logger.warning("Disabling qwen3_moe due to import error: %s", exc)
else:
    MODEL_REGISTRY.update(qwen3_moe=Qwen3MoeForCausalLM)
import os
from glob import glob
import torch
import tqdm
from safetensors import safe_open
from torch import nn
from transformers import LlamaConfig
from vllm.kvprune.compression import (
CompressionMethod,
apply_postrope_compression,
apply_prerope_compression,
)
from vllm.kvprune.layers.activation import SiluAndMul
from vllm.kvprune.layers.attention import Attention
from vllm.kvprune.layers.embed_head import ParallelLMHead, VocabParallelEmbedding
from vllm.kvprune.layers.layernorm import RMSNorm
from vllm.kvprune.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.kvprune.layers.rotary_embedding import get_rope
from vllm.kvprune.utils.context import get_context
from vllm.kvprune.utils.tp_utils import tensor_parallel_world_size_for_sharding
class LlamaAttention(nn.Module):
    """Tensor-parallel Llama self-attention with optional prefill KV compression.

    Hidden states are projected to packed q/k/v, rotated with RoPE, passed to
    ``Attention`` and projected back with ``o_proj``.  During a prefill step
    with compression enabled, the pre-RoPE and post-RoPE compression hooks are
    invoked and the scores they produce are forwarded to the attention call.

    Args:
        hidden_size: Model width.
        num_heads: Total number of query heads (before TP sharding).
        num_kv_heads: Total number of key/value heads (before TP sharding).
        max_position: Number of positions in the RoPE table.
        head_dim: Per-head width; defaults to ``hidden_size // num_heads``.
        qkv_bias: Whether the packed qkv projection has a bias.
        rope_theta: RoPE base.
        rope_scaling: ``None`` or a dict providing ``rope_type``, ``factor``,
            ``low_freq_factor``, ``high_freq_factor`` and
            ``original_max_position_embeddings``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        qkv_bias: bool = False,
        rope_theta: float = 10000,
        rope_scaling: dict | None = None,
    ) -> None:
        super().__init__()
        world_size = tensor_parallel_world_size_for_sharding()

        # Heads are split evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = self.total_num_kv_heads // world_size

        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )

        # ``get_rope`` takes the scaling description as a tuple, so flatten
        # the dict in a fixed key order.
        scaling_keys = (
            "rope_type",
            "factor",
            "low_freq_factor",
            "high_freq_factor",
            "original_max_position_embeddings",
        )
        rope_scaling_tuple = (
            None
            if rope_scaling is None
            else tuple(rope_scaling[name] for name in scaling_keys)
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=rope_theta,
            rope_scaling=rope_scaling_tuple,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            self.num_kv_heads,
        )

    def _publish_wo_weight(self, cc) -> None:
        """Expose ``o_proj.weight`` on ``cc`` when CRITICALADAKV is selected.

        The weight is stored as a float32 tensor viewed as
        ``(num_heads, head_dim, out_features)``.
        """
        if cc is not None and cc.compression_method == CompressionMethod.CRITICALADAKV:
            # Key step: inject wo_weight into the compression context.
            weight = self.o_proj.weight
            out_features, _ = weight.shape
            per_head = weight.t().contiguous().view(
                self.num_heads, self.head_dim, out_features
            )
            cc.wo_weight = per_head.to(dtype=torch.float32)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention for ``hidden_states`` at the given ``positions``."""
        context = get_context()
        packed = self.qkv_proj(hidden_states)
        q, k, v = packed.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)

        scores = (
            apply_prerope_compression(q, k, v, context)
            if context.is_prefill and context.do_compression
            else None
        )
        q, k = self.rotary_emb(positions, q, k)
        if context.is_prefill and context.do_compression:
            self._publish_wo_weight(context.compression_context)
            scores = apply_postrope_compression(q, k, v, scores, context)

        attn_out = self.attn(q, k, v, scores)
        return self.o_proj(attn_out.flatten(1, -1))
class LlamaMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_up_proj(x)))``.

    The gate and up projections are fused in a single
    ``MergedColumnParallelLinear`` with two outputs of ``intermediate_size``.
    Only ``hidden_act == "silu"`` is supported.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        mlp_bias: bool,
    ) -> None:
        super().__init__()
        fused_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, fused_output_sizes, bias=mlp_bias
        )
        self.down_proj = RowParallelLinear(
            intermediate_size, hidden_size, bias=mlp_bias
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the fused gate/up projection, the activation and ``down_proj``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
class LlamaDecoderLayer(nn.Module):
    """One Llama transformer layer: attention and MLP, each preceded by RMSNorm.

    The RoPE base is resolved from ``config.rope_parameters["rope_theta"]``
    when that dict entry exists and is not ``None``, then from
    ``config.rope_theta``, and finally falls back to ``500000.0``.  Configs
    that only define ``rope_theta`` behave exactly as before.
    """

    def __init__(
        self,
        config: LlamaConfig,
    ) -> None:
        super().__init__()

        # Some HF configs keep ``rope_theta`` only inside ``rope_parameters``;
        # reading just ``config.rope_theta`` would silently use the default.
        rope_params = getattr(config, "rope_parameters", None)
        if isinstance(rope_params, dict) and rope_params.get("rope_theta") is not None:
            rope_theta = float(rope_params["rope_theta"])
        else:
            rope_theta = getattr(config, "rope_theta", 500000.0)

        self.self_attn = LlamaAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            qkv_bias=getattr(config, "attention_bias", False),
            head_dim=getattr(config, "head_dim", None),
            rope_theta=rope_theta,
            rope_scaling=getattr(config, "rope_scaling", None),
        )
        self.mlp = LlamaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            mlp_bias=config.mlp_bias,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer; in that case the incoming
        ``hidden_states`` become the residual.  Otherwise the norm is called
        with both tensors and returns the updated pair.
        """
        if residual is None:
            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
class LlamaModel(nn.Module):
    """Token embedding, a stack of ``LlamaDecoderLayer`` and a final RMSNorm."""

    def __init__(
        self,
        config: LlamaConfig,
    ) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        decoder_layers = [
            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the final normalized hidden states for ``input_ids``."""
        hidden = self.embed_tokens(input_ids)
        carry = None  # residual stream threaded through the layers
        for decoder_layer in self.layers:
            hidden, carry = decoder_layer(positions, hidden, carry)
        normalized, _ = self.norm(hidden, carry)
        return normalized
class LlamaForCausalLM(nn.Module):
    """Llama decoder with an LM head and a safetensors weight loader."""

    # Checkpoint sub-module name -> (fused parameter name, shard id handed to
    # the fused parameter's ``weight_loader``).
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.model = LlamaModel(config)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            org_vocab_size=config.vocab_size,
        )
        if config.tie_word_embeddings:
            # Share storage between the input embedding and the LM head.
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the hidden states produced by the decoder stack."""
        return self.model(input_ids, positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project ``hidden_states`` through the LM head."""
        return self.lm_head(hidden_states)

    def load_model(
        self,
        path: str,
        *,
        use_tqdm: bool = False,
    ) -> None:
        """Load every ``*.safetensors`` shard found directly under ``path``.

        Checkpoint tensors whose name contains a key of
        ``packed_modules_mapping`` are routed to the corresponding fused
        parameter through that parameter's ``weight_loader``; every other
        tensor is loaded into the parameter of the same name, using its
        ``weight_loader`` when present and a plain copy otherwise.

        Args:
            path: Directory containing the checkpoint shards.
            use_tqdm: Show a progress bar over shards.

        Raises:
            FileNotFoundError: if ``path`` contains no ``*.safetensors`` file.
                Without this check the model would silently keep its
                uninitialised weights.
        """
        # Sorted so that shards are visited in a deterministic order.
        all_shards = sorted(glob(os.path.join(path, "*.safetensors")))
        if not all_shards:
            raise FileNotFoundError(
                f"No *.safetensors weight files found in {path!r}"
            )
        shard_iter = (
            tqdm.tqdm(all_shards, desc="Loading model") if use_tqdm else all_shards
        )
        for file in shard_iter:
            with safe_open(file, "pt", "cpu") as f:
                for weight_name in f.keys():
                    weight_tensor = f.get_tensor(weight_name)
                    # Load packed modules
                    for k, (v, shard_id) in self.packed_modules_mapping.items():
                        if k in weight_name:
                            param = self.get_parameter(weight_name.replace(k, v))
                            weight_loader = getattr(param, "weight_loader")
                            weight_loader(param, weight_tensor, shard_id)
                            break
                    else:
                        # Load other modules
                        param = self.get_parameter(weight_name)
                        weight_loader = getattr(
                            param,
                            "weight_loader",
                            lambda p, loaded_weight: p.data.copy_(loaded_weight),
                        )
                        weight_loader(param, weight_tensor)
import os
from glob import glob
import torch
import tqdm
from safetensors import safe_open
from torch import nn
from transformers import Qwen3Config
from vllm.kvprune.compression import (
CompressionMethod,
apply_postrope_compression,
apply_prerope_compression,
)
from vllm.kvprune.layers.activation import SiluAndMul
from vllm.kvprune.layers.attention import Attention
from vllm.kvprune.layers.embed_head import ParallelLMHead, VocabParallelEmbedding
from vllm.kvprune.layers.layernorm import RMSNorm
from vllm.kvprune.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.kvprune.layers.rotary_embedding import get_rope, rope_theta_from_hf_config
from vllm.kvprune.utils.context import get_context
from vllm.kvprune.utils.tp_utils import tensor_parallel_world_size_for_sharding
class Qwen3Attention(nn.Module):
    """Tensor-parallel Qwen3 self-attention with optional prefill KV compression.

    Hidden states are projected to packed q/k/v, q and k are RMS-normalised per
    head, rotated with RoPE, passed to ``Attention`` and projected back with
    ``o_proj``.  During a prefill step with compression enabled, the pre-RoPE
    and post-RoPE compression hooks are invoked and their scores are forwarded
    to the attention call.

    Args:
        hidden_size: Model width.
        num_heads: Total number of query heads (before TP sharding).
        num_kv_heads: Total number of key/value heads (before TP sharding).
        max_position: Number of positions in the RoPE table.
        head_dim: Per-head width; defaults to ``hidden_size // num_heads``.
        rms_norm_eps: Epsilon of the q/k RMSNorm layers.
        qkv_bias: Whether the packed qkv projection has a bias.
        rope_theta: RoPE base.
        rope_scaling: Optional scaling tuple forwarded unchanged to ``get_rope``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        rope_theta: float = 10000,
        rope_scaling: tuple | None = None,
    ) -> None:
        super().__init__()
        tp_size = tensor_parallel_world_size_for_sharding()
        # Heads are split evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = self.total_num_kv_heads // tp_size
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            self.num_kv_heads,
        )
        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention for ``hidden_states`` at the given ``positions``."""
        context = get_context()
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim))
        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim))
        # Reshape v *before* the pre-RoPE hook so the hook receives q, k and v
        # in the same [tokens, heads, head_dim] layout (as in LlamaAttention);
        # previously v was still the flat [tokens, kv_size] slice here.
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        scores = None
        if context.is_prefill and context.do_compression:
            scores = apply_prerope_compression(q, k, v, context)
        q, k = self.rotary_emb(positions, q, k)
        if context.is_prefill and context.do_compression:
            cc = context.compression_context
            if cc is not None and cc.compression_method == CompressionMethod.CRITICALADAKV:
                # Key step: inject wo_weight into the compression context as a
                # float32 tensor viewed as (num_heads, head_dim, out_features).
                wo_raw = self.o_proj.weight
                hidden_size, _ = wo_raw.shape
                Hq, D = self.num_heads, self.head_dim
                cc.wo_weight = (
                    wo_raw.transpose(0, 1)
                    .contiguous()
                    .view(Hq, D, hidden_size)
                    .to(dtype=torch.float32)
                )
            scores = apply_postrope_compression(q, k, v, scores, context)
        o = self.attn(q, k, v, scores)
        output = self.o_proj(o.flatten(1, -1))
        return output
class Qwen3MLP(nn.Module):
    """Bias-free gated feed-forward block: ``down_proj(act_fn(gate_up_proj(x)))``.

    The gate and up projections are fused in a single
    ``MergedColumnParallelLinear`` with two outputs of ``intermediate_size``.
    Only ``hidden_act == "silu"`` is supported.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ) -> None:
        super().__init__()
        fused_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, fused_output_sizes, bias=False
        )
        self.down_proj = RowParallelLinear(
            intermediate_size, hidden_size, bias=False
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the fused gate/up projection, the activation and ``down_proj``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
class Qwen3DecoderLayer(nn.Module):
    """One Qwen3 transformer layer: attention and MLP, each preceded by RMSNorm.

    ``head_dim`` falls back to ``hidden_size // num_attention_heads`` when the
    config does not define it.  ``config.rope_scaling`` is forwarded only when
    it is already a tuple; any other value (including a dict) is replaced by
    ``None``.
    """

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        super().__init__()
        configured_head_dim = getattr(config, "head_dim", None)
        head_dim = (
            config.hidden_size // config.num_attention_heads
            if configured_head_dim is None
            else configured_head_dim
        )
        rope_theta = rope_theta_from_hf_config(config)
        raw_scaling = getattr(config, "rope_scaling", None)
        rope_scaling_tuple: tuple | None = None
        if isinstance(raw_scaling, tuple):
            rope_scaling_tuple = raw_scaling

        self.self_attn = Qwen3Attention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "attention_bias", False),
            head_dim=head_dim,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling_tuple,
        )
        self.mlp = Qwen3MLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer; in that case the incoming
        ``hidden_states`` become the residual.
        """
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            normed, residual = self.input_layernorm(hidden_states), hidden_states
        attn_out = self.self_attn(positions, normed)
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual
class Qwen3Model(nn.Module):
    """Token embedding, a stack of ``Qwen3DecoderLayer`` and a final RMSNorm."""

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        decoder_layers = [
            Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the final normalized hidden states for ``input_ids``."""
        hidden = self.embed_tokens(input_ids)
        carry = None  # residual stream threaded through the layers
        for decoder_layer in self.layers:
            hidden, carry = decoder_layer(positions, hidden, carry)
        normalized, _ = self.norm(hidden, carry)
        return normalized
class Qwen3ForCausalLM(nn.Module):
    """Qwen3 decoder with an LM head and a safetensors weight loader."""

    # Checkpoint sub-module name -> (fused parameter name, shard id handed to
    # the fused parameter's ``weight_loader``).
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(self, config: Qwen3Config) -> None:
        super().__init__()
        self.model = Qwen3Model(config)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            org_vocab_size=config.vocab_size,
        )
        if config.tie_word_embeddings:
            # Share storage between the input embedding and the LM head.
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the hidden states produced by the decoder stack."""
        return self.model(input_ids, positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project ``hidden_states`` through the LM head."""
        return self.lm_head(hidden_states)

    def load_model(
        self,
        path: str,
        *,
        use_tqdm: bool = False,
    ) -> None:
        """Load every ``*.safetensors`` shard found directly under ``path``.

        Checkpoint tensors whose name contains a key of
        ``packed_modules_mapping`` are routed to the corresponding fused
        parameter through that parameter's ``weight_loader``; every other
        tensor is loaded into the parameter of the same name, using its
        ``weight_loader`` when present and a plain copy otherwise.

        Args:
            path: Directory containing the checkpoint shards.
            use_tqdm: Show a progress bar over shards.

        Raises:
            FileNotFoundError: if ``path`` contains no ``*.safetensors`` file.
                Without this check the model would silently keep its
                uninitialised weights.
        """
        # Sorted so that shards are visited in a deterministic order.
        all_shards = sorted(glob(os.path.join(path, "*.safetensors")))
        if not all_shards:
            raise FileNotFoundError(
                f"No *.safetensors weight files found in {path!r}"
            )
        shard_iter = (
            tqdm.tqdm(all_shards, desc="Loading model") if use_tqdm else all_shards
        )
        for file in shard_iter:
            with safe_open(file, "pt", "cpu") as f:
                for weight_name in f.keys():
                    weight_tensor = f.get_tensor(weight_name)
                    # Load packed modules
                    for k, (v, shard_id) in self.packed_modules_mapping.items():
                        if k in weight_name:
                            param = self.get_parameter(weight_name.replace(k, v))
                            weight_loader = getattr(param, "weight_loader")
                            weight_loader(param, weight_tensor, shard_id)
                            break
                    else:
                        # Load other modules
                        param = self.get_parameter(weight_name)
                        weight_loader = getattr(
                            param,
                            "weight_loader",
                            lambda p, loaded_weight: p.data.copy_(loaded_weight),
                        )
                        weight_loader(param, weight_tensor)
import os
from glob import glob
import torch
import tqdm
from safetensors import safe_open
from torch import nn
from transformers import Qwen3MoeConfig
from vllm.kvprune.compression import (
CompressionMethod,
apply_postrope_compression,
apply_prerope_compression,
)
from vllm.kvprune.layers.activation import SiluAndMul
from vllm.kvprune.layers.attention import Attention
from vllm.kvprune.layers.embed_head import ParallelLMHead, VocabParallelEmbedding
from vllm.kvprune.layers.layernorm import RMSNorm
from vllm.kvprune.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.kvprune.layers.moe import (
MergedColumnParallelTritonFusedMoeLinear,
RowParallelTritonFusedMoeLinear,
)
from vllm.kvprune.layers.rotary_embedding import get_rope, rope_theta_from_hf_config
from vllm.kvprune.triton_kernels.routing import routing
from vllm.kvprune.utils.context import get_context
from vllm.kvprune.utils.tp_utils import (
tensor_parallel_rank_for_sharding,
tensor_parallel_world_size_for_sharding,
)
class Qwen3MoeAttention(nn.Module):
    """Tensor-parallel Qwen3-MoE self-attention with optional prefill KV compression.

    Hidden states are projected to packed q/k/v, q and k are RMS-normalised per
    head, rotated with RoPE, passed to ``Attention`` and projected back with
    ``o_proj``.  During a prefill step with compression enabled, the pre-RoPE
    and post-RoPE compression hooks are invoked and their scores are forwarded
    to the attention call.

    Args:
        hidden_size: Model width.
        num_heads: Total number of query heads (before TP sharding).
        num_kv_heads: Total number of key/value heads (before TP sharding).
        max_position: Number of positions in the RoPE table.
        head_dim: Per-head width; defaults to ``hidden_size // num_heads``.
        rms_norm_eps: Epsilon of the q/k RMSNorm layers.
        qkv_bias: Whether the packed qkv projection has a bias.
        rope_theta: RoPE base.
        rope_scaling: Optional scaling tuple forwarded unchanged to ``get_rope``.
        sliding_window: Stored on the module; not otherwise used in this class.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        rope_theta: float = 10000,
        rope_scaling: tuple | None = None,
        sliding_window: int | None = None,
    ) -> None:
        super().__init__()
        tp_size = tensor_parallel_world_size_for_sharding()
        # Heads are split evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = self.total_num_kv_heads // tp_size
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.sliding_window = sliding_window
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            self.num_kv_heads,
        )
        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention for ``hidden_states`` at the given ``positions``."""
        context = get_context()
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim))
        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim))
        # Reshape v *before* the pre-RoPE hook so the hook receives q, k and v
        # in the same [tokens, heads, head_dim] layout (as in LlamaAttention);
        # previously v was still the flat [tokens, kv_size] slice here.
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        scores = None
        if context.is_prefill and context.do_compression:
            scores = apply_prerope_compression(q, k, v, context)
        q, k = self.rotary_emb(positions, q, k)
        if context.is_prefill and context.do_compression:
            cc = context.compression_context
            if cc is not None and cc.compression_method == CompressionMethod.CRITICALADAKV:
                # Key step: inject wo_weight into the compression context as a
                # float32 tensor viewed as (num_heads, head_dim, out_features).
                wo_raw = self.o_proj.weight
                hidden_size, _ = wo_raw.shape
                Hq, D = self.num_heads, self.head_dim
                cc.wo_weight = (
                    wo_raw.transpose(0, 1)
                    .contiguous()
                    .view(Hq, D, hidden_size)
                    .to(dtype=torch.float32)
                )
            scores = apply_postrope_compression(q, k, v, scores, context)
        o = self.attn(q, k, v, scores)
        output = self.o_proj(o.flatten(1, -1))
        return output
class Qwen3MoeMLP(nn.Module):
    """Dense, bias-free gated feed-forward block used by non-MoE layers.

    Computes ``down_proj(act_fn(gate_up_proj(x)))`` with the gate and up
    projections fused in one ``MergedColumnParallelLinear``.  Only
    ``hidden_act == "silu"`` is supported.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ) -> None:
        super().__init__()
        fused_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, fused_output_sizes, bias=False
        )
        self.down_proj = RowParallelLinear(
            intermediate_size, hidden_size, bias=False
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the fused gate/up projection, the activation and ``down_proj``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
class Qwen3MoeTritonSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts feed-forward block built on fused Triton linears.

    A replicated ``gate`` produces router logits; ``routing`` turns them into
    routing data plus gather/scatter indices, which drive the fused expert
    projections ``gate_up_proj`` and ``down_proj``.

    ``norm_topk_prob`` is stored on the module and ``hidden_act`` is accepted,
    but neither is read anywhere else in this class.
    """

    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        num_experts_per_tok: int,
        norm_topk_prob: bool,
        hidden_act: str,
    ) -> None:
        super().__init__()
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.norm_topk_prob = norm_topk_prob
        self.hidden_size = hidden_size
        self.moe_intermediate_size = intermediate_size

        self.gate = ReplicatedLinear(hidden_size, num_experts, bias=False)
        fused_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelTritonFusedMoeLinear(
            hidden_size, fused_output_sizes, num_experts
        )
        self.down_proj = RowParallelTritonFusedMoeLinear(
            intermediate_size, hidden_size, num_experts
        )
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to experts and return the combined expert output.

        An empty input is returned unchanged without running the router.
        """
        if hidden_states.numel() == 0:
            return hidden_states

        router_logits = self.gate(hidden_states)
        routing_data, gather_indx, scatter_indx = routing(
            router_logits,
            self.num_experts_per_tok,
            simulated_ep=1,  # single device, replicated experts
        )
        expert_in = self.gate_up_proj(
            hidden_states, routing_data=routing_data, gather_indx=gather_indx
        )
        activated = self.act_fn(expert_in)
        return self.down_proj(
            activated,
            routing_data=routing_data,
            scatter_indx=scatter_indx,
            gammas=routing_data.gate_scal,
        )
class Qwen3MoeBlock(Qwen3MoeTritonSparseMoeBlock):
    """Alias subclass of ``Qwen3MoeTritonSparseMoeBlock``; adds no behaviour."""
class Qwen3MoeRMSNorm(RMSNorm):
    """Alias subclass of ``RMSNorm``; adds no behaviour."""
class Qwen3MoeDecoderLayer(nn.Module):
    """One Qwen3-MoE layer: pre-norm attention and pre-norm MLP with residuals.

    The feed-forward part is a sparse ``Qwen3MoeBlock`` when the layer index is
    not listed in ``config.mlp_only_layers``, ``config.num_experts > 0`` and
    ``(layer_idx + 1)`` is a multiple of ``config.decoder_sparse_step``;
    otherwise a dense ``Qwen3MoeMLP`` is used.  ``config.rope_scaling`` is
    forwarded only when it is already a tuple; any other value becomes ``None``.
    """

    def __init__(
        self,
        config: Qwen3MoeConfig,
        layer_idx: int,
    ) -> None:
        super().__init__()
        configured_head_dim = getattr(config, "head_dim", None)
        head_dim = (
            config.hidden_size // config.num_attention_heads
            if configured_head_dim is None
            else configured_head_dim
        )
        rope_theta = rope_theta_from_hf_config(config)
        raw_scaling = getattr(config, "rope_scaling", None)
        rope_scaling_tuple: tuple | None = None
        if isinstance(raw_scaling, tuple):
            rope_scaling_tuple = raw_scaling

        self.self_attn = Qwen3MoeAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            head_dim=head_dim,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "attention_bias", False),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling_tuple,
            sliding_window=config.sliding_window,
        )

        use_sparse_moe = (layer_idx not in config.mlp_only_layers) and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if not use_sparse_moe:
            self.mlp = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
            )
        else:
            self.mlp = Qwen3MoeBlock(
                num_experts=config.num_experts,
                hidden_size=config.hidden_size,
                intermediate_size=config.moe_intermediate_size,
                num_experts_per_tok=config.num_experts_per_tok,
                norm_topk_prob=config.norm_topk_prob,
                hidden_act=config.hidden_act,
            )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = Qwen3MoeRMSNorm(
            config.hidden_size, eps=norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the layer output; note ``hidden_states`` comes before ``positions``."""
        # Self Attention
        attn_out = self.self_attn(positions, self.input_layernorm(hidden_states))
        after_attn = hidden_states + attn_out
        # Fully Connected
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out
class Qwen3MoeModel(nn.Module):
    """Token embedding, a stack of ``Qwen3MoeDecoderLayer`` and a final norm."""

    def __init__(
        self,
        config: Qwen3MoeConfig,
    ) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        decoder_layers = []
        for layer_idx in range(config.num_hidden_layers):
            decoder_layers.append(Qwen3MoeDecoderLayer(config, layer_idx))
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Return the final normalized hidden states for ``input_ids``."""
        hidden = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden = layer(hidden, position_ids)
        return self.norm(hidden)
class Qwen3MoeForCausalLM(nn.Module):
    """Qwen3-MoE decoder with an LM head and a safetensors weight loader."""

    # Checkpoint sub-module name -> (fused parameter name, shard id handed to
    # the fused parameter's ``weight_loader``).
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: Qwen3MoeConfig,
    ) -> None:
        super().__init__()
        self.model = Qwen3MoeModel(config)
        self.num_experts = config.num_experts
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            org_vocab_size=config.vocab_size,
        )
        if config.tie_word_embeddings:
            # Share storage between the input embedding and the LM head.
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Return the hidden states produced by the decoder stack."""
        return self.model(input_ids, position_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project ``hidden_states`` through the LM head."""
        return self.lm_head(hidden_states)

    def load_model(
        self,
        path: str,
        *,
        use_tqdm: bool = False,
    ) -> None:
        """Load every ``*.safetensors`` shard found directly under ``path``.

        Tensors are opened on ``cuda:<device>``, where ``<device>`` is the
        current CUDA device when CUDA is available and the tensor-parallel
        rank otherwise.

        Per-expert tensors (names containing ``mlp.experts``) have the
        ``.experts.<idx>.`` segment removed from their name and the expert
        index passed to the parameter's ``weight_loader``.  Names containing a
        key of ``packed_modules_mapping`` are routed to the corresponding
        fused parameter; everything else is loaded into the parameter of the
        same name, using its ``weight_loader`` when present and a non-blocking
        copy otherwise.

        Args:
            path: Directory containing the checkpoint shards.
            use_tqdm: Show a progress bar over shards.

        Raises:
            FileNotFoundError: if ``path`` contains no ``*.safetensors`` file.
                Without this check the model would silently keep its
                uninitialised weights.
        """
        # Sorted so that shards are visited in a deterministic order.
        all_shards = sorted(glob(os.path.join(path, "*.safetensors")))
        if not all_shards:
            raise FileNotFoundError(
                f"No *.safetensors weight files found in {path!r}"
            )
        rank = tensor_parallel_rank_for_sharding()
        device = torch.cuda.current_device() if torch.cuda.is_available() else rank
        shard_iter = (
            tqdm.tqdm(all_shards, desc="Loading model") if use_tqdm else all_shards
        )
        for file in shard_iter:
            with safe_open(file, "pt", f"cuda:{device}") as f:
                for weight_name in f.keys():
                    weight_tensor = f.get_tensor(weight_name)
                    is_expert = "mlp.experts" in weight_name
                    expert_idx = None
                    # Process experts params name
                    if is_expert:
                        mlp_module_name, expert_module_name = weight_name.split(
                            ".experts."
                        )
                        # Split off only the leading "<idx>." so the rest of
                        # the name can never be altered by the index text.
                        expert_idx_text, proj_name = expert_module_name.split(".", 1)
                        expert_idx = int(expert_idx_text)
                        weight_name = f"{mlp_module_name}.{proj_name}"
                    # Load packed modules
                    for k, (v, shard_id) in self.packed_modules_mapping.items():
                        if k in weight_name:
                            param = self.get_parameter(weight_name.replace(k, v))
                            weight_loader = getattr(param, "weight_loader")
                            if is_expert:
                                weight_loader(
                                    param, weight_tensor, expert_idx, shard_id
                                )
                            else:
                                weight_loader(param, weight_tensor, shard_id)
                            break
                    else:
                        # Load other modules
                        param = self.get_parameter(weight_name)
                        weight_loader = getattr(
                            param,
                            "weight_loader",
                            lambda p, lw: p.data.copy_(lw, non_blocking=True),
                        )
                        if is_expert:
                            weight_loader(param, weight_tensor, expert_idx)
                        else:
                            weight_loader(param, weight_tensor)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Triton kernel utilities (matmul_ogs, MoE, topk, …) plus KV-facing entrypoints.
For KV pruning attention/store, see also ``vllm.kvprune.attention`` and
``vllm.kvprune.kv_cache``.
"""
from vllm.kvprune.attention.sparse_varlen_kernel import causal_sparse_varlen_with_cache
from vllm.kvprune.kv_cache.store_kv_cache import (
decode_store_kv,
prefill_store_all_kv,
prefill_store_topk_kv,
)
# Public names of this package: exactly the KV-facing entrypoints imported above.
__all__ = [
    "causal_sparse_varlen_with_cache",
    "decode_store_kv",
    "prefill_store_all_kv",
    "prefill_store_topk_kv",
]
import torch
from .compaction_details._masked_compaction import _masked_compaction
from .tensor import Bitmatrix
def compaction(yv, yi, bitmask, sentinel=-1):
    """
    Launch the ``_masked_compaction`` kernel and return compacted copies of
    *yv* and *yi*.

    One kernel program is launched per row of *yi*.  Elements whose index is
    flagged in the row's bitmask are kept, the others are replaced by
    *sentinel*.

    Parameters
    ----------
    yv : torch.Tensor
        Values tensor; assumed to have the same 2-D shape as *yi*.
    yi : torch.Tensor, 2-D
        Integer indices associated with *yv*.
    bitmask : torch.Tensor or Bitmatrix
        Per-row mask of active indices.  A ``Bitmatrix`` is unwrapped to its
        underlying storage tensor.  Both ``stride(0)`` and ``stride(1)`` are
        read, so the tensor must have at least two dimensions.
    sentinel : int, default -1
        Value written into dropped positions of the returned tensors.

    Returns
    -------
    (yv_out, yi_out) : tuple of torch.Tensor
        Freshly allocated tensors shaped like *yv* and *yi* respectively.
    """
    num_rows, num_cols = yi.shape
    out_values = torch.empty_like(yv)
    out_indices = torch.empty_like(yi)
    # The kernel needs the raw word tensor, not the Bitmatrix wrapper.
    mask_words = bitmask.storage.data if isinstance(bitmask, Bitmatrix) else bitmask
    launch_grid = (num_rows,)
    _masked_compaction[launch_grid](
        # inputs
        yv,
        yi,
        mask_words,
        mask_words.stride(0),
        mask_words.stride(1),
        # outputs
        out_values,
        out_indices,
        # sentinel
        sentinel,
        # constants
        K=num_cols,
    )
    return out_values, out_indices
def compaction_torch(
    yv: torch.Tensor, yi: torch.Tensor, bitmask: torch.Tensor, sentinel=-1
):
    """
    Pure-PyTorch reference for the masked compaction.

    Each row of *bitmask* is expanded into one boolean per bit (32 bits per
    word).  Entries of *yi* whose bit is set are moved to the front of the
    row in their original relative order; the remaining slots of both
    outputs are filled with *sentinel*.

    Returns ``(yv_out, yi_out)`` as new tensors.
    """
    _rows, _cols = yi.shape  # yi must be 2-D
    # One weight per bit position of a mask word.
    bit_weights = 1 << torch.arange(32, device=yi.device, dtype=bitmask.dtype)
    # (B, words, 32) booleans collapsed into a flat per-row bit table.
    bit_table = ((bitmask.unsqueeze(-1) & bit_weights) != 0).flatten(start_dim=-2)
    # Look up, for every entry of yi, whether its bit is active.
    kept = bit_table.gather(1, yi.long())
    # Stable sort on "is dropped" (kept -> 0, dropped -> 1) pulls the kept
    # entries forward without disturbing their order.
    perm = (~kept).to(torch.int).argsort(dim=1, stable=True)
    packed_indices = yi.gather(1, perm)
    packed_values = yv.gather(1, perm)
    # Overwrite the trailing, dropped slots.
    dropped = ~kept.gather(1, perm)
    packed_indices[dropped] = sentinel
    packed_values[dropped] = sentinel
    return packed_values, packed_indices
import triton
import triton.language as tl
@triton.jit
def _masked_compaction(
    Yv, Yi, BitMask, stride_bm, stride_bn, RetYv, RetYi, sentinel, K: tl.constexpr
):
    """
    Compact one row of (Yv, Yi) according to that row's bitmask.

    Each program handles row ``program_id(0)``.  Rows of Yv/Yi/RetYv/RetYi are
    addressed as ``row * K + column``, i.e. as contiguous rows of length K.
    An index ``i`` is active when bit ``i % 32`` of mask word ``i // 32`` of the
    row is set.  Active elements are written to the front of the output row in
    their original order; every other output slot receives ``sentinel``.
    """
    pid_m = tl.program_id(0)
    yv = tl.load(Yv + pid_m * K + tl.arange(0, K))
    yi = tl.load(Yi + pid_m * K + tl.arange(0, K))
    # Split each index into (mask word, bit inside the word).
    div = yi // 32
    rem = yi % 32
    active_bits = (tl.load(BitMask + pid_m * stride_bm + div * stride_bn) >> rem) & 1
    # Exclusive prefix sum: number of active elements strictly before each slot,
    # which is the destination slot of an active element.
    exc_cumsum = tl.cumsum(active_bits, 0) - active_bits
    active_flags = active_bits.to(tl.int1)
    # Inactive elements get an extra offset of (K - 1 - position); added to the
    # prefix count this places them in the slots behind the active ones.
    rev_arange = tl.where(active_flags, 0, K - 1 - tl.arange(0, K))
    write_indx = exc_cumsum + rev_arange
    # Inactive elements carry the sentinel instead of their original payload.
    yv = tl.where(active_flags, yv, sentinel)
    yi = tl.where(active_flags, yi, sentinel)
    tl.store(RetYv + pid_m * K + write_indx, yv)
    tl.store(RetYi + pid_m * K + write_indx, yi)
# isort: off
# fmt: off
from dataclasses import dataclass
import itertools
import sys
import torch
import triton
from enum import Enum, auto
import math
# utilities
from vllm.kvprune.triton_kernels import target_info
from vllm.kvprune.triton_kernels.numerics import InFlexData, OutFlexData
from vllm.kvprune.triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
from vllm.kvprune.triton_kernels.target_info import is_cuda
# details
from .matmul_ogs_details._matmul_ogs import _matmul_ogs
from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
from .matmul_ogs_details._reduce_grouped import _reduce_grouped
from .numerics_details.mxfp import MXFP_BLOCK_SIZE
from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints, InapplicableConstraint
from .specialize import specialize
from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor
@dataclass(frozen=True)
class FnSpecs:
    """Describes a function that ``get_kernels`` splices into the kernels."""
    # Identifier; `get_kernels` builds its cache key from this name.
    name: str
    # Function passed to `specialize` as a constant (None for `default()`).
    fn: "triton.runtime.jit.JITFunction"
    # Argument names passed to `specialize` in its tuple mapping.
    fn_arg_names: tuple[str, ...]
    # Argument names forwarded to `specialize` as `do_not_specialize`.
    fn_arg_do_not_specialize: tuple[str, ...] = tuple()

    @staticmethod
    def default():
        """Return the placeholder spec: name ``"dflt"``, no function, no arguments."""
        return FnSpecs("dflt", None, tuple())
@dataclass(frozen=True)
class FusedActivation:
    """Activation function fused into the matmul or the grouped reduction."""
    # Spec of the activation function (placeholder by default).
    specs: FnSpecs = FnSpecs.default()
    # Values unpacked positionally into the kernel launch.
    fn_args: tuple[object, ...] = tuple()
    # Divisor applied to the number of output columns (N // reduction_n).
    reduction_n: int = 1
@dataclass(frozen=True)
class Epilogue:
    """Epilogue function plus the argument values it needs at launch time."""
    # Spec of the epilogue function (placeholder by default).
    specs: FnSpecs = FnSpecs.default()
    # Values unpacked into the matmul kernel launch.
    fn_arg_values_matmul: tuple[object, ...] = tuple()
    # Values unpacked into the grouped-reduction kernel launch.
    fn_arg_values_finalize: tuple[object, ...] = tuple()
    # Forwarded to `make_opt_flags`; None by default.
    effective_itemsize: float | None = None
class FnName(Enum):
    """Well-known function names; `matmul_ogs` compares an epilogue's spec name against these."""
    QUANTIZE_MXFP8 = auto()
EpilogueSpecs = FnSpecs  # TODO: remove this alias when callers are updated
# Cache of specialized kernel modules, keyed by (fused_activation.name, epilogue.name).
_kernels = dict()
def get_kernels(epilogue: FnSpecs = FnSpecs.default(), fused_activation: FnSpecs = FnSpecs.default()):
    """
    Return the module holding the kernels specialized for the given
    epilogue / fused-activation pair.

    Modules are cached in ``_kernels`` under ``(fused_activation.name,
    epilogue.name)``.  On a cache miss a fresh module is created, registered
    in ``sys.modules`` and populated with specialized versions of
    ``_matmul_ogs``, ``_p_matmul_ogs`` and ``_reduce_grouped``.
    """
    import types

    cache_key = (fused_activation.name, epilogue.name)
    cached = _kernels.get(cache_key)
    if cached is not None:
        return cached

    constants = {
        "ACTIVATION_FN": fused_activation.fn,
        "EPILOGUE_FN": epilogue.fn,
    }
    arg_tuples = {
        "activation_fn_args": fused_activation.fn_arg_names,
        "epilogue_fn_args": epilogue.fn_arg_names,
    }
    unspecialized = fused_activation.fn_arg_do_not_specialize + epilogue.fn_arg_do_not_specialize

    specialized = types.ModuleType(f"matmul_ogs_{'_'.join(cache_key)}")
    sys.modules[specialized.__name__] = specialized
    templates = (
        ("_matmul_ogs", _matmul_ogs),
        ("_p_matmul_ogs", _p_matmul_ogs),
        ("_reduce_grouped", _reduce_grouped),
    )
    for attr_name, template in templates:
        setattr(
            specialized,
            attr_name,
            specialize(template, specialized, constants, arg_tuples,
                       do_not_specialize=unspecialized),
        )
    _kernels[cache_key] = specialized
    return specialized
# -----------------------------------------------------------------------------
# Matrix Multiplication + Outer Gather/Scatter
# -----------------------------------------------------------------------------
def can_overflow_int32(tensor: torch.Tensor):
    """
    Tell whether the largest element offset of *tensor* exceeds int32 range.

    The offset is the sum over all dimensions of ``(size - 1) * stride``; the
    result is True when that sum is greater than ``2**31 - 1``.
    """
    largest_offset = sum(
        (tensor.shape[dim] - 1) * tensor.stride(dim) for dim in range(tensor.ndim)
    )
    int32_max = (1 << 31) - 1
    return largest_offset > int32_max
def should_upcast_indices(*args):
    """Return True if any non-None argument can overflow int32 offsets."""
    for candidate in args:
        if candidate is not None and can_overflow_int32(candidate):
            return True
    return False
# ---------------------
# Numerics
# ---------------------
# fmt: off
@dataclass(frozen=True)
class FlexCtx:
    """Flex-data descriptors for the two matmul operands and the output."""
    # Used by `matmul_ogs` for the activations `x` (storage reinterpret and scale).
    lhs_data: InFlexData = InFlexData()
    # Used by `matmul_ogs` for the weights `w` (storage reinterpret and scale).
    rhs_data: InFlexData = InFlexData()
    # Used by `matmul_ogs` for the output.
    out_data: OutFlexData = OutFlexData()
@dataclass
class PrecisionConfig:
    """Numerics options read by ``matmul_ogs``."""
    # Forwarded unchanged to the matmul kernel launch.
    max_num_imprecise_acc: int | None = None
    # Forwarded unchanged to the matmul kernel launch.
    allow_tf32: bool = True
    flex_ctx: FlexCtx = FlexCtx()
    acc_scale: float = 1.0
    # Forwarded to both the matmul and the grouped-reduction launch.
    flexpoint_saturate_inf: bool = False
    report_quantization_err_fn: callable = None
    # Optional scales for activations (x) and weights (w).
    act_scale: Tensor | None = None
    weight_scale: Tensor | None = None
    # NOTE: `matmul_ogs` overwrites this field when the reduction returns an mx scale.
    out_scale: Tensor | None = None
    # Output dtype; `matmul_ogs` falls back to the dtype of `x` when this is None.
    out_dtype: torch.dtype | None = None
    enforce_bitwise_invariance: bool = False
# TODO: merge in opt_flags
def get_swap_xw(precision_config, opt_flags):
    """
    Decide the ``SWAP_XW`` kernel flag.

    Always False below CUDA capability 10.0.  Otherwise the flag is the
    conjunction of: a weight scale is present, ``block_m`` is at most 64, and
    the persistent kernel is selected.
    """
    if not target_info.cuda_capability_geq(10, 0):
        return False
    has_weight_scale = precision_config.weight_scale is not None
    return has_weight_scale and opt_flags.block_m <= 64 and opt_flags.is_persistent
# ---------------------
# Allocation
# ---------------------
@dataclass
class MatmulAllocation:
    """Shapes and dtypes of the buffers ``matmul_ogs`` needs, before allocation."""
    # Device on which `apply_allocation` creates the buffers.
    device: str
    # (shape, dtype) of the final output.
    output: tuple[tuple[int, ...], torch.dtype]
    # name -> (shape, dtype) of each intermediate buffer.
    scratchpads: dict[str, tuple]
def init_allocation(x, w, precision_config, fused_activation, routing_data, gather_indx, scatter_indx, opt_flags):
    """
    Work out which buffers ``matmul_ogs`` needs, without allocating them.

    Returns a ``MatmulAllocation`` holding the (shape, dtype) of the final
    output and of the optional ``"matmul"`` / ``"mx_out_scale"`` scratchpads.
    """
    # ---- output ------
    n_cols = w.shape[-1]
    # Rows seen by the matmul: the activation rows, or one per gather index.
    n_rows = x.shape[-2]
    if gather_indx is not None:
        n_rows = gather_indx.src_indx.shape[0]
    # Rows of the final output: compressed by n_expts_act when a scatter with
    # several active experts folds rows together.
    folds_rows = routing_data.n_expts_act != 1 and scatter_indx is not None
    if folds_rows:
        out_rows = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act
    else:
        out_rows = n_rows
    n_batches = x.shape[0] if x.ndim == 3 else 1
    result_shape = (n_batches, out_rows, n_cols // fused_activation.reduction_n)
    result_dtype = precision_config.out_dtype or x.dtype
    output_spec = (result_shape, result_dtype)
    # ---- scratchpad -----
    scratch_specs = {}
    split_k = opt_flags.split_k
    unfused_scatter = scatter_indx is not None and not opt_flags.fused_scatter
    if split_k > 1 or unfused_scatter:
        # Split-K partial results are kept in float32.
        partial_dtype = torch.float32 if split_k > 1 else result_dtype
        scratch_specs["matmul"] = ((split_k, 1, n_rows, n_cols), partial_dtype)
        if precision_config.out_scale is not None:
            scale_cols = triton.cdiv(n_cols, MXFP_BLOCK_SIZE)
            scratch_specs["mx_out_scale"] = ((split_k, 1, n_rows, scale_cols), torch.uint8)
    return MatmulAllocation(x.device, output_spec, scratch_specs)
def apply_allocation(allocation: MatmulAllocation, output):
    """
    Materialize the buffers described by *allocation*.

    When *output* is None a new output tensor is created; otherwise the given
    tensor is used after asserting that its shape matches the planned one.
    Returns a dict with ``"output"`` (the output with an extra leading
    dimension) and ``"scratchpad"`` (name -> newly allocated tensor).
    """
    planned_shape, planned_dtype = allocation.output[0], allocation.output[1]
    if output is not None:
        assert output.shape == planned_shape
    else:
        output = torch.empty(planned_shape, device=allocation.device, dtype=planned_dtype)
    scratch_buffers = {}
    for name, spec in allocation.scratchpads.items():
        scratch_buffers[name] = torch.empty(spec[0], device=allocation.device, dtype=spec[1])
    return {"output": output[None, :, :], "scratchpad": scratch_buffers}
# -----------------------------------------------------------------------------
# Canonicalize
# -----------------------------------------------------------------------------
# the `matmul_ogs` kernel can operate on 2D or 3D inputs depending on the mode being used
# we can canonicalize storages to make the implementation more uniform
def _canonicalize_storage(storage, out_ndim, flex_data):
    """
    Return a ``Storage`` whose data has exactly *out_ndim* dimensions.

    Missing leading dimensions are added with size 1.  When *flex_data* is
    given, the padded data is additionally passed through
    ``flex_data.reinterpret``.  The layout of *storage* is carried over.
    """
    data = storage.data
    assert out_ndim >= data.ndim
    n_pad = out_ndim - data.ndim
    padded_shape = [1] * n_pad + list(data.shape)
    # `view` alone is not enough: it may rewrite the strides of size-1
    # dimensions, which makes a column-major tensor with shape[-2] == 1 stop
    # looking column-major.  For example
    #   > t = torch.randn(2, 5, 1).mT
    #   > t.stride(), t.view(t.shape).stride()
    #   ((5, 1, 1), (5, 5, 1))
    # so a later `stride(-2) == 1` check would fail on the view.  (Covered by
    # (m, n, k) == (1000, 700, 2) in test_matmul.py.)  The view is therefore
    # only used to obtain the stride of the new leading dimensions, while the
    # original strides are kept verbatim through `as_strided`.
    lead_stride = data.view(padded_shape).stride(0)
    padded_stride = [lead_stride] * n_pad + list(data.stride())
    canonical = data.as_strided(padded_shape, padded_stride)
    if flex_data is not None:
        canonical = flex_data.reinterpret(canonical)
    return Storage(canonical, storage.layout)
#
def reduce_grouped(x: torch.Tensor, indx: torch.Tensor, out: torch.Tensor, out_mx_scale: torch.Tensor,
                   fused_activation, epilogue,
                   x_flex: InFlexData | None = None,
                   out_flex: OutFlexData | None = None, x_mx_scale: torch.Tensor | None = None,
                   out_dtype: torch.dtype | None = None, flexpoint_saturate_inf: bool = False):
    """
    Grouped row reduction of `x` into `out` via the `_reduce_grouped` kernel.

    Arguments
    - x: input tensor. Its strides 0, 2 and 3 are passed to the kernel, so it
      must be 4-D (`matmul_ogs` passes either its 4-D scratchpad or its 4-D
      output view).
    - indx: optional Tensor[Int] of shape [num_groups, K]; `None` means no
      grouping (K = 1 and one group per row of `x`).
    - out: destination tensor; its strides 1 and 2 are passed to the kernel.
    - out_mx_scale: optional output scale tensor, forwarded to the kernel.
    - fused_activation / epilogue: select the specialized kernel and provide
      its extra launch arguments.
    - x_flex / out_flex: flex-data descriptors; default-constructed when None.
    - x_mx_scale: optional input scale tensor.
    - out_dtype: resolved to `x.dtype` when None; not used further here.
    - flexpoint_saturate_inf: forwarded to the kernel.

    Description
    One kernel program is launched per group. According to the original notes
    for this routine (the kernel body is not visible here), each group sums
    the K rows of `x` selected by `indx[g, :]`, accumulates in float32, skips
    invalid (-1) entries, and walks the N dimension tile-by-tile inside a
    single launch.

    Fast path
    When `indx` is None and `x.shape[0] == 1` nothing is launched and
    `(x.squeeze(0), None)` is returned; `out` is left untouched.

    Returns
    - `(out, out_mx_scale)` after the kernel launch (or the fast-path tuple
      described above).
    """
    if indx is None and x.shape[0] == 1:
        return x.squeeze(0), None
    if indx is not None:
        num_groups = indx.shape[0]
    else:
        num_groups = x.shape[-2]
    if x_flex is None:
        x_flex = InFlexData()
    if out_flex is None:
        out_flex = OutFlexData()
    K = 1 if indx is None else indx.shape[1]
    out_dtype = x.dtype if out_dtype is None else out_dtype
    assert x.shape[-1] % fused_activation.reduction_n == 0
    BLOCK_N = 512
    # Resolve scalar flex scales (may be None)
    x_expected_scale = None if x_flex is None else x_flex.scale
    out_expected_scale = None if out_flex is None else out_flex.expected_scale
    out_actual_scale = None if out_flex is None else out_flex.actual_scale
    out_checksum_scale = None if out_flex is None else out_flex.checksum_scale
    # Resolve MXFP output scale row stride
    stride_mxb = 0 if x_mx_scale is None else x_mx_scale.stride(0)
    stride_mxs = 0 if x_mx_scale is None else x_mx_scale.stride(1)
    stride_omxs = 0 if out_mx_scale is None else out_mx_scale.stride(0)
    # Kernel specialized for this (epilogue, activation) pair; one program per group.
    kernels = get_kernels(epilogue.specs, fused_activation.specs)
    kernels._reduce_grouped[(num_groups, )](
        x_flex.reinterpret(x), x.stride(0), x.stride(2), x.stride(3),  #
        x_expected_scale,  # scalar input scale
        out_flex.reinterpret(out), out.stride(1), out.stride(2),  #
        out_expected_scale, out_actual_scale, out_checksum_scale, indx,  #
        x.shape[0], x.shape[-1],  #
        x_mx_scale, stride_mxb, stride_mxs,  #
        out_mx_scale, stride_omxs,  #
        *fused_activation.fn_args, fused_activation.reduction_n,
        *epilogue.fn_arg_values_finalize,
        HAS_IN_MX_SCALE=x_mx_scale is not None, HAS_OUT_MX_SCALE=out_mx_scale is not None,
        FLEXPOINT_SATURATE_INF=flexpoint_saturate_inf,  #
        BLOCK_N=BLOCK_N, K=K,  #
        num_warps=1,  #
    )
    return out, out_mx_scale
# -----------------------------------------------------------------------------
# Triton Implementation
# -----------------------------------------------------------------------------
def matmul_ogs_set_idle_sms(num_idle_sms):
    """
    Register the ``idle_sms`` opt-flags constraint so that persistent kernels
    leave `num_idle_sms` SMs idle.
    """
    constraints = {"idle_sms": num_idle_sms}
    update_opt_flags_constraints(constraints)
def matmul_ogs(x, w, bias,
               routing_data: RoutingData | None = None,
               gather_indx: GatherIndx | None = None,
               scatter_indx: ScatterIndx | None = None,
               precision_config: PrecisionConfig | None = None,
               betas: torch.Tensor | None = None,
               gammas: torch.Tensor | None = None,
               out_alpha: float | None = None,
               y: torch.Tensor | None = None,
               fused_activation: FusedActivation | None = None,
               epilogue: Epilogue | None = None,
               ):
    """
    Matrix multiplication with optional outer gather/scatter:

    Y[:, :] = 0.
    for e in num_experts:
        Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :])

    Arguments
    - x: activations, 2-D or 3-D. A 3-D `x` selects batched mode, in which
      gather, scatter and routing must all be None and `w` must be 3-D with a
      matching leading dimension.
    - w: weights; its last two dimensions are (K, N) and K must equal
      `x.shape[-1]`. A plain torch tensor is wrapped (uint8 is treated as FP4).
    - bias: optional bias tensor passed to the kernel with its stride(0).
    - routing_data / gather_indx / scatter_indx: optional routing metadata.
    - precision_config: numerics options; a default `PrecisionConfig()` is
      used when None.
    - betas, gammas, out_alpha: forwarded to the matmul kernel.
    - y: optional pre-allocated output; its shape is asserted against the
      planned output shape.
    - fused_activation / epilogue: optional fused functions; placeholders are
      used when None.

    Returns the final output tensor (leading batch dimension removed for
    non-batched input).

    Side effect: when the grouped reduction returns an mx scale,
    `precision_config.out_scale` is overwritten with it.
    """
    is_input_batched = x.ndim == 3
    if is_input_batched:
        assert gather_indx is None, "gather not supported in batched mode"
        assert scatter_indx is None, "scatter not supported in batched mode"
        assert routing_data is None, "routing not supported in batched mode"
        assert w.ndim == 3 and w.shape[0] == x.shape[0]
    # canonicalize inputs
    if precision_config is None:
        precision_config = PrecisionConfig()
    if fused_activation is None:
        fused_activation = FusedActivation(FnSpecs.default(), tuple(), 1)
    if epilogue is None:
        epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False)
    if routing_data is None:
        routing_data = RoutingData(None, None, max(1, w.shape[0]), 1)
    # unpack scales
    w_scale = precision_config.weight_scale
    w_has_mx = w_scale is not None
    is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
    if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
    if not isinstance(w, Tensor):
        # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
        dtype = FP4 if w.dtype == torch.uint8 else w.dtype
        w = wrap_torch_tensor(w, dtype=dtype)
    if w_scale is not None and not isinstance(w_scale, Tensor):
        w_scale = Tensor(w_scale)
    if w_scale is not None:
        # The weight scale is always handled as raw uint8 data from here on.
        w_scale.storage.data = w_scale.data.view(torch.uint8)
        w_scale.dtype = torch.uint8
    x_scale = precision_config.act_scale
    x_has_mx = x_scale is not None
    if x_has_mx: assert x.stride(-1) == 1, "'x' must be row-major when it has data-type mxfp"
    if x_scale is not None and not isinstance(x_scale, Tensor):
        x_scale = Tensor(x_scale)
    if not isinstance(x, Tensor):
        x = Tensor(x, dtype=x.dtype)
    # determine shapes
    has_gather = gather_indx is not None
    has_scatter = scatter_indx is not None
    is_ragged = routing_data.expt_hist is not None
    M = x.shape[-2] if gather_indx is None else gather_indx.src_indx.shape[0]
    batch_size = w.shape[0] if routing_data.expt_hist is None and w.ndim == 3 else 1
    K, N = w.shape[-2:]
    assert K == x.shape[-1]
    if x.ndim == 3 and w.ndim == 3:
        assert x.shape[0] == w.shape[0]
    # compute optimization flags
    out_dtype = precision_config.out_dtype or x.dtype
    can_use_tma = x.numel() > 0 and x.storage.is_tma_compliant() and \
                  w.numel() > 0 and w.storage.is_tma_compliant() and \
                  (w_scale is None or w_scale.storage.is_tma_compliant())
    # hopper w/ mxfp4 doesn't support TMA
    can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
    can_use_fused_scatter = has_scatter and (fused_activation.specs.fn is None) and (epilogue.specs.fn is None) and (routing_data.n_expts_act == 1)
    opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
        M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize,
    )
    if not can_use_fused_scatter and opt_flags.fused_scatter:
        raise InapplicableConstraint("Fused scatter is not supported")
    if w_scale is not None and opt_flags.is_persistent and not target_info.has_native_mxfp():
        raise NotImplementedError("Must use non-persistent kernel for simulated MXFP")
    if w_scale is not None and w_scale.storage.layout.name is not None and not opt_flags.is_persistent and target_info.has_native_mxfp():
        raise NotImplementedError("Must use persistent kernel and be TMA-compliant for native MXFP")
    # fused activation
    # By default the activation runs inside the matmul kernel; when a separate
    # reduction pass is needed (split-K or non-fused scatter) it is moved there.
    matmul_fused_activation = fused_activation
    reduce_fused_activation = FusedActivation()
    if opt_flags.split_k > 1 or (scatter_indx is not None and not opt_flags.fused_scatter):
        matmul_fused_activation, reduce_fused_activation = reduce_fused_activation, matmul_fused_activation
    # allocate output/scratchpad memory
    allocation = init_allocation(x, w, precision_config, fused_activation,
                                 routing_data, gather_indx, scatter_indx, opt_flags)
    memory = apply_allocation(allocation, y)
    # early exit
    if batch_size * M * N == 0:
        ret = memory["output"].squeeze(0)
        if not is_input_batched:
            ret = ret.squeeze(0)
        return ret
    # TMA descriptors require a global memory allocation
    if opt_flags.is_persistent:
        triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device))
    # Intermediate tensors and postprocess kernels for each situation
    has_scratchpad = "matmul" in memory["scratchpad"]
    # Canonical output tensor (matmul scratchpad if present, otherwise final output tensor)
    out_matmul = memory["scratchpad"].get("matmul", memory["output"])
    out_matmul_flex = OutFlexData() if out_matmul.dtype == torch.float32 else precision_config.flex_ctx.out_data
    # Unified mx-scale pointer; when scratchpad exists, prefer its mx buffer
    out_matmul_scale = precision_config.out_scale
    if out_matmul_scale is not None:
        out_matmul_scale = out_matmul_scale.data.view(torch.uint8)
        if has_scratchpad and "mx_out_scale" in memory["scratchpad"]:
            out_matmul_scale = memory["scratchpad"]["mx_out_scale"]
    out_matmul_has_mx = out_matmul_scale is not None and out_matmul.element_size() == 1
    # matrix multiplication
    flex = precision_config.flex_ctx
    bias_stride = None if bias is None else bias.stride(0)
    num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0]
    # moe metadata
    expt_data = routing_data.expt_data
    block_m = opt_flags.block_m
    expt_hist = None if expt_data is None else expt_data.hist
    expt_hist_sum = None if expt_data is None else expt_data.token_offs_pad[block_m][-1]
    expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
    expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]
    # spmd grid
    grid_m = triton.cdiv(M, opt_flags.block_m)
    if expt_block_pid_map is not None:
        grid_m = routing_data.n_blocks(M, opt_flags.block_m)
    grid_n = triton.cdiv(N, opt_flags.block_n)
    max_grid = batch_size * grid_m * grid_n * opt_flags.split_k
    # Persistent kernels launch at most one program per non-idle SM.
    grid = min(target_info.num_sms() - opt_flags.idle_sms, max_grid) if opt_flags.is_persistent else max_grid
    # canonicalize storage
    has_gather_tma = has_gather and target_info.has_tma_gather()
    has_scatter_tma = opt_flags.fused_scatter and target_info.has_tma_gather()
    y = wrap_torch_tensor(out_matmul.view(math.prod(out_matmul.shape[:-1]), out_matmul.shape[-1]) if opt_flags.fused_scatter else out_matmul.view(math.prod(out_matmul.shape[:-2]), *out_matmul.shape[-2:]))
    x_storage = _canonicalize_storage(x.storage, 2 if has_gather_tma else 3, flex.lhs_data)
    w_storage = _canonicalize_storage(w.storage, 3, flex.rhs_data)
    y_storage = _canonicalize_storage(y.storage, 2 if has_scatter_tma else 3, flex.out_data)
    # create tma descriptor for x
    x_has_tma = opt_flags.is_persistent and (has_gather_tma or not has_gather)
    x_tma_block_size = [1, opt_flags.block_k] if has_gather_tma else [1, opt_flags.block_m, opt_flags.block_k]
    x_tma_mode = None if not x_has_tma else "ragged" if is_ragged and not has_gather_tma else "dense"
    x_tensor_or_tma = x_storage.make_tma(x_tma_block_size, x_tma_mode) if x_has_tma else x_storage.data
    # create tma descriptor for y
    y_has_tma = opt_flags.is_persistent and (has_scatter_tma or not opt_flags.fused_scatter)
    block_n = opt_flags.block_n // opt_flags.epilogue_subtile // matmul_fused_activation.reduction_n
    y_tma_block_size = [1, block_n] if has_scatter_tma else [1, opt_flags.block_m, block_n]
    y_tma_mode = None if not y_has_tma else "ragged" if is_ragged and not has_scatter_tma else "dense"
    y_tensor_or_tma = y_storage.make_tma(y_tma_block_size, y_tma_mode) if y_has_tma else y_storage.data
    # create tma descriptor for w
    w_has_tma = opt_flags.is_persistent
    w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if w_has_tma else w_storage.data
    # create tma descriptor for w_scale
    w_scale_tensor_or_tma = w_scale
    w_scale_has_tma = opt_flags.is_persistent and w_scale is not None
    w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k], "dense") if w_scale_has_tma else w_scale
    # canonicalize strides
    # Stride tuples are left-padded with zeros up to three entries.
    x_strides = [0]*(3 - x_storage.data.ndim) + list(x_storage.data.stride())
    x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None)
    x_scale_strides = (0, ) * (3 - len(x_scale_strides)) + x_scale_strides
    w_scale_strides = w_scale.stride() if w_has_mx and not w_scale_has_tma else (None, None, None)
    w_scale_strides = (0, ) * (3 - len(w_scale_strides)) + w_scale_strides
    out_matmul_scale_strides = out_matmul_scale.stride() if out_matmul_has_mx else (None, None, None, None)
    out_matmul_scale_strides = (0, ) * (3 - len(out_matmul_scale_strides)) + out_matmul_scale_strides
    # launch kernel
    kernels = get_kernels(epilogue.specs, matmul_fused_activation.specs)
    # When stride(-2) == stride(-1) == 1, it's ambiguous whether W is transposed
    # (i.e. col-wise). Since this matters when w_has_mx is True and w_transpose
    # is True the fast code path, stride(-2) == 1 takes precedence, e.g., vs.
    # w_transpose = w_storage.data.stride()[-1] != 1
    w_transpose = w_storage.data.stride()[-2] == 1
    (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)](
                   y_tensor_or_tma, y_storage.data, *out_matmul.stride(),
                   *((None, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex),
                   *out_matmul_scale_strides[-3:],
                   x_tensor_or_tma, x_storage.data, *x_strides,
                   flex.lhs_data.scale,
                   None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides,
                   w_tensor_or_tma, w_storage.data, *w_storage.data.stride(), w_transpose,
                   flex.rhs_data.scale,
                   w_scale_tensor_or_tma, *w_scale_strides,
                   bias, bias_stride,
                   x.shape[-2],
                   x.shape[-2] if routing_data.expt_hist is None else None,
                   N, K,
                   betas, gammas,
                   None if gather_indx is None else gather_indx.src_indx,
                   None if scatter_indx is None else scatter_indx.src_indx,
                   num_indx,
                   None if not opt_flags.fused_scatter else scatter_indx.dst_indx,
                   None if not opt_flags.fused_scatter else scatter_indx.dst_indx.shape[0],
                   expt_hist, expt_token_offs_raw, expt_hist_sum, expt_block_pid_map,
                   batch_size, grid_m, grid_n,
                   out_alpha,
                   *matmul_fused_activation.fn_args, matmul_fused_activation.reduction_n,
                   *epilogue.fn_arg_values_matmul,
                   routing_data.n_expts_tot, routing_data.n_expts_act,
                   precision_config.max_num_imprecise_acc,
                   precision_config.allow_tf32,
                   precision_config.flexpoint_saturate_inf,
                   flex.rhs_data.is_per_batch,
                   opt_flags.block_m,
                   opt_flags.block_n,
                   opt_flags.block_k,
                   opt_flags.group_m,
                   XCD_SWIZZLE=opt_flags.xcd_swizzle,
                   SWIZZLE_MX_VALUE=w.storage.layout.name,
                   SWIZZLE_MX_SCALE=None if w_scale is None else w_scale.storage.layout.name,
                   EPILOGUE_SUBTILE=opt_flags.epilogue_subtile,
                   SPLIT_K=opt_flags.split_k,
                   EVEN_K=K % opt_flags.block_k == 0,
                   W_CACHE_MODIFIER=opt_flags.w_cache_modifier,
                   TOKENS_PER_EXPT_FOR_ANNOTATION=routing_data.expected_tokens_per_expt,
                   num_warps=opt_flags.num_warps,
                   num_stages=opt_flags.num_stages,
                   arch=opt_flags.arch,
                   UPCAST_INDICES=should_upcast_indices(x, w, out_matmul),
                   X_TMA_MODE=x_tma_mode,
                   Y_TMA_MODE=y_tma_mode,
                   SWAP_XW=get_swap_xw(precision_config, opt_flags),
                   IS_EPILOGUE_QUANT_MXFP8=epilogue.specs.name == FnName.QUANTIZE_MXFP8.name,
                   NUM_SMS = grid if opt_flags.is_persistent else 0,
                   **opt_flags.target_kernel_kwargs)
    # Build grouped reduction inputs in a uniform way
    # With a non-fused scatter, each row of `group_indx` lists the n_expts_act
    # source rows that are reduced together.
    group_indx = None if scatter_indx is None or opt_flags.fused_scatter else scatter_indx.src_indx.view(-1, routing_data.n_expts_act)
    out_final, out_final_mx_scale = reduce_grouped(
        out_matmul,
        group_indx,
        memory["output"].squeeze(0),
        precision_config.out_scale,
        reduce_fused_activation,
        epilogue,
        x_flex=InFlexData(dtype=out_matmul_flex.dtype, scale=out_matmul_flex.expected_scale),
        out_flex=precision_config.flex_ctx.out_data,
        x_mx_scale=out_matmul_scale.squeeze(1) if out_matmul_has_mx else None,
        out_dtype=memory["output"].dtype,
        flexpoint_saturate_inf=precision_config.flexpoint_saturate_inf,
    )
    if not is_input_batched:
        out_final = out_final.squeeze(0)
    # NOTE: mutates the caller's precision_config.
    if out_final_mx_scale is not None:
        precision_config.out_scale = out_final_mx_scale
    return out_final
# -----------------------------------------------------------------------------
# Reference Implementation
# -----------------------------------------------------------------------------
def matmul_ogs_torch(x, w, bias,
routing_data: RoutingData = None,
gather_indx: GatherIndx = None,
scatter_indx: ScatterIndx = None,
precision_config: PrecisionConfig = None,
betas = None,
gammas = None,
round_x = None, round_y = None,
):
is_input_batched = x.ndim == 3
assert x.dtype.itemsize > 1
assert w.dtype.itemsize > 1
if is_input_batched:
assert gather_indx is None, "gather not supported in batched mode"
assert scatter_indx is None, "scatter not supported in batched mode"
assert routing_data is None, "routing not supported in batched mode"
assert w.ndim == 3 and w.shape[0] == x.shape[0]
if round_x is None:
round_x = lambda x, idx: x
if round_y is None:
round_y = lambda x: x
if bias is not None and bias.ndim == 1:
bias = bias.view(1, *bias.shape)
if w.ndim == 2:
w = w.view(1, *w.shape)
if x.ndim == 2:
x = x.view(1, *x.shape)
if routing_data is None:
routing_data = RoutingData(None, None, w.shape[0], 1)
n_expts_act = routing_data.n_expts_act
# memory offsets
if routing_data.n_expts_tot > 1 and not is_input_batched:
sizes = routing_data.expt_hist
off = torch.zeros(sizes.shape[0] + 1, dtype=torch.int32)
off[1:] = torch.cumsum(sizes, 0)
offs = list(itertools.pairwise(off))
else:
offs = [[0, x.shape[1]] for _ in range(w.shape[0])]
# compute
n_rows = x.shape[1] if gather_indx is None else gather_indx.dst_indx.shape[0]
y = torch.zeros((x.shape[0], n_rows, w.shape[-1]), device=x.device, dtype=x.dtype)
for i, (lo, hi) in enumerate(offs):
if gather_indx is None:
idx = torch.arange(lo, hi, device=x.device)
else:
idx = gather_indx.src_indx[lo:hi] // n_expts_act
batch = i if is_input_batched else 0
out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(),
w[i].float())
if bias is not None:
out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
if gammas is not None:
out *= gammas[lo:hi, None]
y[batch, lo:hi, :] = round_y(out)
if not is_input_batched:
y = y.view(y.shape[1], y.shape[2])
if scatter_indx is None:
return y
# accumulate output from all experts
n_rows = y.shape[0] // n_expts_act
out = torch.zeros((n_rows, y.shape[-1]), dtype=torch.float32, device=x.device)
for i, (lo, hi) in enumerate(offs):
dst_idx = scatter_indx.dst_indx[lo:hi] // n_expts_act
msk = dst_idx != -1
out[dst_idx[msk], :] += y[lo:hi, :][msk, :].float()
return out
import torch
import triton
import triton.language as tl
# -----------------------------------------------------------------------------
# Utilities
# -----------------------------------------------------------------------------
@triton.constexpr_function
def get_scaled_dot_format_string(dtype: tl.dtype):
    """Return the short format name for *dtype*; raises ``KeyError`` for dtypes not listed."""
    format_by_dtype = {
        tl.float16: "fp16",
        tl.bfloat16: "bf16",
        tl.uint8: "e2m1",
        tl.float8e4nv: "e4m3",
        tl.float8e5: "e5m2",
    }
    format_name = format_by_dtype[dtype]
    return format_name
@triton.jit
def xcd_swizzle(pid, domain_size, XCD_SWIZZLE: tl.constexpr):
    """
    Swizzle the program id based on integer XCD_SWIZZLE.
    This is useful for reordering how blocks are ordered. A scheduler may, for example,
    assign sequential blocks 0, 1, 2, 3, ..., 8, 9, 10.. to its 8 hardware units 0, 1, 2, 3, ..., 0, 1, 2.
    This pattern may not be ideal for memory access, and it may be better to swizzle so the assignment
    becomes 0, 0, 0, 0, ..., 1, 1, 1, ... In the swizzled arrangement, sequential blocks are assigned to
    the same hardware unit.
    """
    # Number of pids per group in the new arrangement
    pids_per_group = domain_size // XCD_SWIZZLE
    # Remainder of the division: the first `extra_pid_groups` groups hold one extra pid
    extra_pid_groups = domain_size % XCD_SWIZZLE
    # Compute the current group and the local pid within the group
    group = pid % XCD_SWIZZLE
    local_pid = pid // XCD_SWIZZLE
    # Calculate new pid based on the new grouping
    new_pid = group * pids_per_group + min(group, extra_pid_groups) + local_pid
    return new_pid
@triton.jit
def swizzle2d(pid, grid_m, grid_n, GROUP_M: tl.constexpr):
    """
    Map a linear program id to a (pid_m, pid_n) tile coordinate.

    Program ids are split into groups of ``GROUP_M * grid_n`` consecutive ids;
    each group covers up to GROUP_M rows of tiles, and within a group the row
    index varies fastest.
    """
    # Number of program ids in one full group.
    width = GROUP_M * grid_n
    group_id = pid // width
    # The last group may contain fewer than GROUP_M rows.
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    tl.assume(group_size >= 0)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    return pid_m, pid_n
def make_matmul_repr(base_name, order):
    """
    Build a function that names a kernel specialization.

    The returned callable takes a specialization object exposing ``signature``
    and ``constants`` mappings and returns
    ``"{base_name}_{layouts}_{dtypes}_{blocks}"`` where

    * *dtypes* are the shortened types of ``Y``, ``X``, ``W`` permuted by
      *order* and joined with ``"x"``;
    * *layouts* has one letter per stride name (``stride_y_n``, ``stride_x_k``,
      ``stride_w_n``, permuted by *order*): ``"N"`` if the name is a key of
      ``constants``, ``"T"`` otherwise;
    * *blocks* are the ``BLOCK_M``, ``BLOCK_N``, ``BLOCK_K`` and ``SPLIT_K``
      constants joined with ``"x"``.
    """

    def matmul_repr(specialization):
        sig = specialization.signature
        consts = specialization.constants

        def permute(names):
            return [names[pos] for pos in order]

        def short_dtype(type_str):
            # Tensor descriptors wrap the element type: "tensordesc<elem[...]>".
            if "tensordesc" in type_str:
                element = type_str.split("<")[1].split("[")[0]
                return short_dtype(element)
            if "u8" in type_str:
                return "mxfp4"
            # Pointer types are spelled with a leading "*".
            if type_str[0] == "*":
                return type_str[1:]
            return type_str

        dtype_part = "x".join(
            short_dtype(f"{sig[arg]}") for arg in permute(["Y", "X", "W"])
        )
        layout_part = "".join(
            "N" if stride_name in consts else "T"
            for stride_name in permute(["stride_y_n", "stride_x_k", "stride_w_n"])
        )
        block_part = "x".join(
            f"{consts[key]}" for key in ("BLOCK_M", "BLOCK_N", "BLOCK_K", "SPLIT_K")
        )
        # NOTE: a gather/scatter/persistent-TMA mode suffix was sketched here
        # at some point but is not part of the generated name.
        return f"{base_name}_{layout_part}_{dtype_part}_{block_part}"

    return matmul_repr
def matmul_launch_metadata(grid, kernel, args):
    """Assemble the launch-metadata dict (name, flops, bytes) for a matmul launch.

    Args:
        grid: Launch grid; not read by this function.
        kernel: Kernel object; ``name`` and ``num_stages`` are read.
        args: Mapping from kernel argument names to their launch values.

    Returns:
        A dict that always has a ``"name"`` entry. ``"flops<nbits>"`` and
        ``"bytes"`` are added as well, except when an expert histogram is
        present and the token count cannot be determined.
    """
    from ..proton_opts import launch_metadata_allow_sync

    M, N, K = args["M"], args["N"], args["K"]
    Y, X, W = args["YPtr"], args["XPtr"], args["WPtr"]
    tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
    hist = args["ExptHist"]
    has_hist = hist is not None

    if not has_hist:
        n_tokens = None
        n_w_bytes = W.numel() * W.element_size()
    else:
        # Row-count label for the kernel name: the annotation wins, then a
        # synced mean over the histogram, else "unknown".
        if tokens_per_expt is not None:
            n_rows = f"{tokens_per_expt}*"
        elif launch_metadata_allow_sync():
            n_rows = int(hist.float().mean())
        else:
            n_rows = "unknown"

        # Token count and weight traffic: exact when syncing is allowed,
        # estimated from the annotation otherwise.
        if launch_metadata_allow_sync():
            n_tokens = float(hist.sum())
            w_bytes_per_expt = W.numel() * W.element_size() // hist.numel()
            n_w_bytes = w_bytes_per_expt * (hist > 0).sum()
        elif tokens_per_expt is not None:
            n_tokens = tokens_per_expt * args["N_EXPTS_TOT"]
            # This may overcount (not every expert has to be used), but it is
            # better than reporting nothing.
            n_w_bytes = W.numel() * W.element_size()
        else:
            n_tokens = None
            n_w_bytes = 0

    def describe(label, value):
        # Dimensions that are None are rendered through the histogram instead.
        if value is not None:
            return f"{label} = {value}"
        return f"E_{len(hist)}({label}) = {n_rows}"

    nbits = X.dtype.itemsize * 8
    batch_repr = ""
    if "batch_size" in args and args["batch_size"] > 1:
        batch_repr = describe("B", args["batch_size"]) + ", "

    ret = {}
    ret["name"] = (
        f"{kernel.name} [{batch_repr}{describe('M', M)}, {describe('N', N)}, {describe('K', K)}] stg{kernel.num_stages}"
    )
    ep_subtile = args["EPILOGUE_SUBTILE"]
    if ep_subtile is not None and ep_subtile > 1:
        ret["name"] += f" ep/{ep_subtile}"

    if has_hist and n_tokens is None:
        # Flops and bytes cannot be computed properly; report the name only.
        return ret

    fM = n_tokens if M is None else M
    fK = n_tokens if K is None else K
    ret[f"flops{nbits}"] = 2.0 * fM * N * fK

    gindx = args.get("GatherIndx", None)
    if has_hist:
        assert n_tokens is not None
        n_expts_act = args["N_EXPTS_ACT"]
        if (gindx is not None) and launch_metadata_allow_sync():
            # Recreate the inverse of GatherIndx, then count the rows of X
            # that have at least one valid entry.
            inverse = torch.full_like(gindx, -1)
            positions = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32)
            valid = gindx != -1
            inverse[gindx[valid]] = positions[valid]
            n_read_rows = (inverse.view((-1, n_expts_act)) != -1).any(dim=1).sum()
        else:
            n_read_rows = n_tokens
        n_x_bytes = n_read_rows * X.shape[-1] * X.element_size()
        n_y_bytes = n_tokens * Y.shape[-1] * Y.element_size()
    else:
        n_x_bytes = X.numel() * X.element_size()
        n_y_bytes = Y.numel() * Y.element_size()
    ret["bytes"] = int(n_x_bytes + n_y_bytes + n_w_bytes)
    return ret
# isort: off
# fmt: off
import triton
import triton.language as tl
from vllm.kvprune.triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
from vllm.kvprune.triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
from vllm.kvprune.triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
from vllm.kvprune.triton_kernels.tensor_details.layout_details.cdna4_scale import unswizzle_mx_scale_cdna4
from vllm.kvprune.triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale
from vllm.kvprune.triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
@triton.jit
def _zero_masked_rows(
        pid_m, pid_n,
        Y, stride_y_m, stride_y_n,
        N,
        ScatterSrcIndx, num_idxs,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Store zeros into the rows of one Y tile whose ScatterSrcIndx entry is -1.

    The tile is ``BLOCK_M x BLOCK_N`` at tile coordinate ``(pid_m, pid_n)``.
    Columns at or beyond ``N`` are never written.
    """
    rows = BLOCK_M * pid_m.to(tl.int64) + tl.arange(0, BLOCK_M)
    cols = BLOCK_N * pid_n + tl.arange(0, BLOCK_N)
    # Rows at or beyond num_idxs load 0, so they never compare equal to -1.
    scatter_src = tl.load(ScatterSrcIndx + rows, mask=rows < num_idxs, other=0)
    row_is_masked = scatter_src == -1
    col_in_bounds = cols < N
    store_mask = row_is_masked[:, None] & col_in_bounds[None, :]
    out_ptrs = Y + rows[:, None] * stride_y_m + cols[None, :] * stride_y_n
    zeros = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    tl.store(out_ptrs, zeros, mask=store_mask)
# Name formatter handed to `triton.jit(repr=...)` for the `_matmul_ogs` kernel.
# `make_matmul_repr` uses the order list to index the (Y, X, W) operand triple,
# so [0, 1, 2] leaves that order unchanged.
_matmul_ogs_repr = make_matmul_repr("_matmul_ogs", [0, 1, 2])
@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
            repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
def _matmul_ogs(
        Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
        YExpectedScale, YActualScale, YChecksumScale,
        stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
        X, XPtr, stride_x_z, stride_x_m, stride_x_k,
        XScale,
        XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
        W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
        WScale,
        WMxScale, stride_w_mx_e, stride_w_mx_k, stride_w_mx_n,
        B, stride_b_e, # Bias
        NRows, M, N, K, # shapes
        # expt data
        Betas, Gammas,
        GatherIndx,
        ScatterSrcIndx, num_idxs,
        WriteBackIndx, writeback_size,
        ExptHist, ExptOffs, ExptOffsSum, ExptData,
        # true grid size
        batch_size, grid_m, grid_n,
        # Out scale
        out_alpha,
        # fused activation function
        ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr,
        # epilogue transform
        EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
        # MoE config
        N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
        # precision config
        MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr,
        FLEXPOINT_SATURATE_INF: tl.constexpr,
        PER_BATCH_SCALE: tl.constexpr,
        # optimization config
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
        GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
        # One of ["HOPPER", "BLACKWELL", None]
        SWIZZLE_MX_VALUE: tl.constexpr,
        # One of ["HOPPER", "BLACKWELL", None]
        SWIZZLE_MX_SCALE: tl.constexpr,
        EPILOGUE_SUBTILE: tl.constexpr,
        EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
        W_CACHE_MODIFIER: tl.constexpr,
        NUM_SMS: tl.constexpr,
        X_TMA_MODE: tl.constexpr,
        Y_TMA_MODE: tl.constexpr,
        TOKENS_PER_EXPT_FOR_ANNOTATION=None,
        UPCAST_INDICES: tl.constexpr = False,
        SWAP_XW: tl.constexpr = False,
        IS_EPILOGUE_QUANT_MXFP8: tl.constexpr = False):
    """Triton kernel: one program accumulates one (BLOCK_M, BLOCK_N) tile of X @ W into Y.

    Outline of the body (most steps are gated on constexpr flags or on an
    argument being None):

    1. Decode the program id into (pid_e, pid_m, pid_n, pid_k), optionally via
       ``xcd_swizzle`` and always via ``swizzle2d``. Programs past the real
       tile count only zero masked output rows and return.
    2. Resolve the expert/batch offsets: from ``pid_e`` when ``ExptData`` is
       None, otherwise from ``ExptData`` / ``ExptHist`` / ``ExptOffs``.
    3. Build the X and W tile pointers (X rows optionally indirected through
       ``GatherIndx``) and, for microscaled operands, the scale pointers.
    4. Loop over K in steps of ``BLOCK_K * SPLIT_K``, accumulating with
       ``tl.dot`` or ``tl.dot_scaled``.
    5. Epilogue: multiply by the flexpoint scales, add ``bias * betas``, apply
       ``out_alpha``, ``ACTIVATION_FN`` and ``gammas``, then either run the
       microscaled ``EPILOGUE_FN`` path or ``float_to_flex``. The result is
       stored to Y under a mask, with rows optionally redirected through
       ``WriteBackIndx``.

    Arguments such as ``YPtr``, ``XPtr``, ``WPtr``, ``NUM_SMS``, ``X_TMA_MODE``,
    ``Y_TMA_MODE`` and ``SWAP_XW`` are not read in this body.
    """
    # Tell the compiler it may assume strides and grid sizes are non-negative.
    tl.assume(stride_y_k >= 0)
    tl.assume(stride_y_z >= 0)
    tl.assume(stride_y_m >= 0)
    tl.assume(stride_y_n >= 0)
    tl.assume(stride_x_z >= 0)
    tl.assume(stride_x_m >= 0)
    tl.assume(stride_x_k >= 0)
    tl.assume(stride_w_e >= 0)
    tl.assume(stride_w_k >= 0)
    tl.assume(stride_w_n >= 0)
    if stride_w_mx_e is not None:
        tl.assume(stride_w_mx_e >= 0)
    if stride_w_mx_k is not None:
        tl.assume(stride_w_mx_k >= 0)
    if stride_w_mx_n is not None:
        tl.assume(stride_w_mx_n >= 0)
    if B is not None:
        tl.assume(stride_b_e >= 0)
    tl.assume(batch_size >= 0)
    tl.assume(grid_m >= 0)
    tl.assume(grid_n >= 0)

    # Compile-time validation of the microscaling (mx) configuration.
    is_w_microscaled: tl.constexpr = WMxScale is not None
    MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
    if is_w_microscaled:
        w_type: tl.constexpr = W.dtype.element_ty
        # uint8 weights are treated as packed mxfp4.
        is_mxfp4: tl.constexpr = w_type == tl.uint8
        tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
                         "mx_weight_ptr must be uint8 or fp8")
        tl.static_assert(WMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
        tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
        tl.static_assert(SWIZZLE_MX_VALUE == "HOPPER_VALUE" or SWIZZLE_MX_VALUE is None, "Only Hopper swizzling is supported for values")
    else:
        tl.static_assert(SWIZZLE_MX_VALUE is None)
        tl.static_assert(SWIZZLE_MX_SCALE is None)
    is_x_microscaled: tl.constexpr = XMxScale is not None
    if is_x_microscaled:
        x_type: tl.constexpr = X.dtype.element_ty
        tl.static_assert(is_w_microscaled)
        tl.static_assert(x_type == tl.float8e4nv, "mx_act_ptr must be float8e4nv")
        tl.static_assert(XMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
        tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
    # The output is written in microscaled form when an output-scale stride is given.
    is_out_microscaled: tl.constexpr = stride_y_mx_z is not None

    # Output width after the activation's reduction along N.
    OUT_BLOCK_N: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
    yN = N // ACTIVATION_REDUCTION_N

    pid = tl.program_id(0)
    if ExptOffsSum is not None and XCD_SWIZZLE > 1:
        # Determine how much padding there is on the expert data. This allows us to
        # know the true grid size and avoid processing padding tiles.
        padding_m = grid_m - tl.load(ExptOffsSum)
    else:
        padding_m: tl.constexpr = 0

    HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None
    index_type: tl.constexpr = tl.int64 if UPCAST_INDICES else tl.int32

    unpadded_m = grid_m - padding_m
    tl.assume(unpadded_m >= 0)
    total_actual_tiles = batch_size * unpadded_m * grid_n * SPLIT_K
    # Programs past the real tiles do no matmul work; at most they zero masked rows.
    if padding_m > 0 and pid >= total_actual_tiles:
        # NOTE(review): this asserts batch_size == 0 on the padding path — confirm the intended condition.
        tl.device_assert(batch_size == 0)
        pid_mn = pid - total_actual_tiles
        if pid_mn < padding_m * grid_n:
            pid_m, pid_n = swizzle2d(pid_mn, padding_m, grid_n, GROUP_M)
            # set masked out rows to 0
            if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
                _zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N)
        return

    # swizzle program ids
    pid_emnk = pid
    if XCD_SWIZZLE != 1:
        pid_emnk = xcd_swizzle(pid_emnk, total_actual_tiles, XCD_SWIZZLE)
    # Split the linear id into expert/batch, (m, n) tile and split-K slice.
    pid_e = pid_emnk // (unpadded_m * grid_n * SPLIT_K)
    pid_mnk = pid_emnk % (unpadded_m * grid_n * SPLIT_K)
    pid_k = pid_mnk % SPLIT_K
    pid_mn = pid_mnk // SPLIT_K
    pid_m, pid_n = swizzle2d(pid_mn, unpadded_m, grid_n, GROUP_M)
    # For split-k, advance to the output k slice
    if SPLIT_K > 1:
        Y += pid_k.to( index_type) * stride_y_k
        if is_out_microscaled:
            # NOTE(review): the output scale pointer advances by stride_x_mx_k, an
            # X-scale stride; this kernel receives no output-scale K stride — confirm intent.
            YActualScale += pid_k.to(index_type) * stride_x_mx_k
    # set masked out rows to 0
    if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
        _zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N)
    # unpack expert data
    if ExptData is None:
        tl.static_assert(M is not None)
        expt_id, start_z, start_m, block_id = pid_e, pid_e, 0, pid_m
    else:
        tl.static_assert(M is None)
        expt_data = tl.load(ExptData + pid_m)
        # An entry of -1 means there is nothing to compute for this tile.
        if expt_data == -1:
            return
        # The expert id is in the low 16 bits and the block id in the bits above.
        expt_id = expt_data & 0x0000FFFF
        block_id = expt_data >> 16
        # M is read per expert from the histogram here.
        M = tl.load(ExptHist + expt_id)
        start_m = tl.load(ExptOffs + expt_id)
        start_z = 0
    expt_id, block_id = expt_id.to(index_type), block_id.to(index_type)
    start_m, start_z = start_m.to(index_type), start_z.to(index_type)
    pid_n, pid_k = pid_n.to(index_type), pid_k.to(index_type)
    # A pointers
    offs_x_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M)
    # Row offsets wrap modulo M so loads stay in range; out-of-range rows are masked at the store.
    offs_x_m = tl.max_contiguous(tl.multiple_of(offs_x_m % M, BLOCK_M), BLOCK_M)
    X += start_z * stride_x_z
    if GatherIndx is None:
        X += start_m * stride_x_m
    else:
        GatherIndx += start_m
        # no needs to bounds-check here because `offs_x_m` wraps around M dim
        offs_x_m = tl.load(GatherIndx + offs_x_m) // N_EXPTS_ACT
    offs_k = BLOCK_K * pid_k + tl.arange(0, BLOCK_K)
    XPtrs = X + offs_x_m.to(index_type)[:, None] * stride_x_m + offs_k.to(index_type)[None, :] * stride_x_k

    # TODO: refactor if/else when triton front end improves
    if is_w_microscaled:
        if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
            tl.static_assert(is_mxfp4, "Only mxfp4 is supported for HOPPER swizzling")
            tl.static_assert(not is_x_microscaled)
            # We have pack 2 fp4 values in a byte but we divide the dimension by 2
            # when swizzling
            W_K_DIVISOR: tl.constexpr = 1
            W_K_MULTIPLIER: tl.constexpr = 2
            W_N_DIVISOR: tl.constexpr = 4
        else:
            # We have pack 2 fp4 values in a byte
            W_K_DIVISOR: tl.constexpr = 2 if is_mxfp4 else 1
            W_K_MULTIPLIER: tl.constexpr = 1
            W_N_DIVISOR: tl.constexpr = 1

        if W_TRANSPOSE:
            # When weight is transposed, 2 fp4 values are packed per Byte along
            # the contiguous dimension, K.
            PACKED_BLOCK_K_W: tl.constexpr = (BLOCK_K // W_K_DIVISOR) * W_K_MULTIPLIER
            PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_N_DIVISOR
        else:
            # When weight is not transposed, fp4 values are *not* packed along
            # the contiguous dimension, N.
            PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
            PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_K_DIVISOR
        # Number of scale entries along K per BLOCK_K of values.
        MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR

        WMxScale += expt_id * stride_w_mx_e

        # Per-layout shape of the scale tile and its stride along K.
        if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
            # TODO: support non W_TRANSPOSE with blackwell swizzling
            tl.static_assert(W_TRANSPOSE)
            tl.static_assert(BLOCK_N % 128 == 0)
            tl.static_assert(MX_SCALE_BLOCK_K % 4 == 0)
            PACKED_MX_BLOCK: tl.constexpr = (MX_SCALE_BLOCK_K // 4) * 32 * 4 * 4
            SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 128
            stride_scale_k: tl.constexpr = 1
        elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
            # TODO: support non W_TRANSPOSE with Hopper swizzling
            tl.static_assert(W_TRANSPOSE)
            n_warps: tl.constexpr = tl.extra.cuda.num_warps()
            tl.static_assert(BLOCK_N % (2 * n_warps * 2 * 8) == 0)
            tl.static_assert(MX_SCALE_BLOCK_K % 2 == 0)
            PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * 32
            SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 32
            stride_scale_k = stride_w_mx_k
        elif SWIZZLE_MX_SCALE == "CDNA4_SCALE":
            tl.static_assert(stride_w_mx_k is not None)
            tl.static_assert(stride_w_mx_n is not None)
            NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 32
            PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * NON_K_PRESHUFFLE_BLOCK_SIZE
            SCALE_BLOCK_N: tl.constexpr = BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE
            stride_scale_k = stride_w_mx_k
        else:
            PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K
            SCALE_BLOCK_N: tl.constexpr = BLOCK_N
            stride_scale_k = stride_w_mx_k
        offs_n_scale = (pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N)) % N
        offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N)
        # K dimension must be the last dimension for the scales
        offs_k_scale = PACKED_MX_BLOCK * pid_k + tl.arange(0, PACKED_MX_BLOCK)
        WMxScalePtrs = WMxScale + offs_k_scale.to(index_type)[None, :] * stride_scale_k + offs_n_scale.to(index_type)[:, None] * stride_w_mx_n
    else:
        # No weight microscaling: nothing is packed and there are no scale pointers.
        WMxScalePtrs = None
        offs_k_scale = None
        W_K_DIVISOR: tl.constexpr = 1
        W_K_MULTIPLIER: tl.constexpr = 1
        W_N_DIVISOR: tl.constexpr = 1
        PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
        PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N

    # B pointers
    offs_w_n = pid_n * PACKED_BLOCK_N_W + tl.arange(0, PACKED_BLOCK_N_W)
    # Column offsets wrap modulo the (packed) N extent; the store mask handles the tail.
    offs_w_n = tl.max_contiguous(tl.multiple_of(offs_w_n % (N // W_N_DIVISOR), PACKED_BLOCK_N_W), PACKED_BLOCK_N_W)

    if is_x_microscaled:
        XMxScale += start_z.to(index_type) * stride_x_mx_z
        if GatherIndx is None:
            XMxScale += start_m * stride_x_mx_m
        offs_x_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K)
        XMxScalePtrs = XMxScale + offs_x_m.to(index_type)[:, None] * stride_x_mx_m + offs_x_k_scale.to(index_type)[None, :] * stride_x_mx_k
    else:
        XMxScalePtrs = None

    offs_w_k = PACKED_BLOCK_K_W * pid_k + tl.arange(0, PACKED_BLOCK_K_W)
    W += expt_id * stride_w_e
    WPtrs = W + (offs_w_k.to(index_type)[:, None] * stride_w_k + offs_w_n.to(index_type)[None, :] * stride_w_n)
    # compute output
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # The offsets stay fixed while the pointers advance by BLOCK_K * SPLIT_K per
    # iteration, so k shrinks by the same step; `offs < k` then acts as the
    # in-bounds check against K.
    for k in range(K, BLOCK_K * pid_k, -(BLOCK_K * SPLIT_K)):
        if EVEN_K:
            # EVEN_K: every K mask is all-true.
            mask_k = tl.full([BLOCK_K], True, dtype=tl.int1)
            mask_k_w = tl.full([PACKED_BLOCK_K_W], True, dtype=tl.int1)
            if is_w_microscaled and SWIZZLE_MX_SCALE is None:
                mask_k_scale = tl.full([PACKED_MX_BLOCK], True, dtype=tl.int1)
            if is_x_microscaled:
                mask_x_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1)
        else:
            mask_k = offs_k < k
            mask_k_w = offs_w_k < ((k // (W_K_DIVISOR if W_TRANSPOSE else 1)) * W_K_MULTIPLIER)
            if is_w_microscaled and SWIZZLE_MX_SCALE is None:
                mask_k_scale = offs_k_scale * MX_PACK_DIVISOR < k
            if is_x_microscaled:
                mask_x_k_scale = offs_x_k_scale * MX_PACK_DIVISOR < k

        x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
        w = tl.load(WPtrs, mask=mask_k_w[:, None], other=0.0, cache_modifier=W_CACHE_MODIFIER)
        if is_w_microscaled:
            x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
            w_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)

            if is_x_microscaled:
                x_scales = tl.load(XMxScalePtrs, mask=mask_x_k_scale[None, :])
            elif x_format == "fp16" or x_format == "bf16":
                x_scales: tl.constexpr = None
            else:
                # Scale of 1 in E8M0 format
                x_scales = tl.full((BLOCK_M, MX_SCALE_BLOCK_K), 127, dtype=tl.uint8)

            # Load the weight scales, undoing the layout-specific swizzle if any.
            if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
                w_scales = unswizzle_mx_scale_bw(tl.load(WMxScalePtrs))
            elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
                # Handshake with the swizzling code
                num_warps: tl.constexpr = tl.extra.cuda.num_warps()
                w_scales = unswizzle_mxfp4_scale_hopper(tl.load(WMxScalePtrs), mx_axis=1, num_warps=num_warps)
            elif SWIZZLE_MX_SCALE == "CDNA4_SCALE":
                w_scales = unswizzle_mx_scale_cdna4(tl.load(WMxScalePtrs), BLOCK_N, MX_SCALE_BLOCK_K)
            else:
                w_scales = tl.load(WMxScalePtrs, mask=mask_k_scale[None, :])

            if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
                # Handshake with the swizzling code
                tl.static_assert(x_format == "bf16")
                tl.static_assert(w_format == "e2m1")
                w = mxfp4_to_bf16_triton(w.trans(), w_scales, 1)
                tl.static_assert(w.dtype == tl.bfloat16)
                # Accumulate in transposed form (w on the left, x transposed on
                # the right), then transpose the accumulator back.
                acc = acc.trans()
                x = x.trans()
                # w = w.trans()
                acc = tl.dot(w, x, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
                acc = acc.trans()
            else:
                rhs_k_pack: tl.constexpr = W_TRANSPOSE or not is_w_microscaled or W_K_DIVISOR != 2
                acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True, rhs_k_pack=rhs_k_pack)
            # Advance the scale pointers to the next K step.
            if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
                WMxScalePtrs += (MX_SCALE_BLOCK_K // 4 * SPLIT_K) * stride_w_mx_k
            else:
                WMxScalePtrs += (PACKED_MX_BLOCK * SPLIT_K) * stride_w_mx_k
            if is_x_microscaled:
                XMxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K) * stride_x_mx_k
        else:
            acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
        # Advance the value pointers to the next K step.
        XPtrs += (BLOCK_K * SPLIT_K) * stride_x_k
        WPtrs += (PACKED_BLOCK_K_W * SPLIT_K) * stride_w_k
    # bias + scale
    offs_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M)
    offs_y_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N)
    mask_m = offs_m < M
    mask_n = offs_y_n < N
    if B is not None:
        BPtrs = B + expt_id * stride_b_e + offs_y_n
        # Only the pid_k == 0 program loads the bias; other split-K programs use
        # zeros (presumably so the bias is counted once when slices are combined).
        if pid_k == 0:
            bias = tl.load(BPtrs, mask=mask_n, other=0)
        else:
            bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
    else:
        bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
    # betas scale the bias per row; gammas scale the final output per row.
    if Betas is not None:
        betas = tl.load(Betas + start_m + offs_m, mask=mask_m, other=0.0)
    else:
        betas = tl.full([BLOCK_M], 1, dtype=tl.float32)
    if Gammas is not None:
        gammas = tl.load(Gammas + start_m + offs_m, mask=mask_m, other=0.0)
    else:
        gammas = tl.full([BLOCK_M], 1, dtype=tl.float32)
    # flexpoint
    x_scale = load_scale(XScale)
    if PER_BATCH_SCALE:
        w_scale = load_scale(WScale + expt_id)
    else:
        w_scale = load_scale(WScale)
    acc *= x_scale * w_scale
    acc = acc + bias[None, :] * betas[:, None]
    if out_alpha is not None:
        acc *= out_alpha
    if ACTIVATION_FN is not None:
        out = ACTIVATION_FN(acc, *activation_fn_args)
        tl.static_assert(out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})")
        # The activation may shrink N, so recompute output columns and their mask.
        offs_y_n = OUT_BLOCK_N * pid_n + tl.arange(0, OUT_BLOCK_N)
        mask_n = offs_y_n < yN
    else:
        tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided")
        out = acc
    out *= gammas[:, None]
    # write-back
    Y += start_z.to(index_type) * stride_y_z
    if WriteBackIndx is not None:
        WriteBackIndx += start_m
        # Rows outside writeback_size load -1 and are dropped from the row mask.
        dst_idx = tl.load(WriteBackIndx + offs_m, mask=start_m + offs_m < writeback_size, other=-1)
        mask_m = mask_m & (dst_idx != -1)
        offs_y_m = dst_idx
    else:
        Y += start_m * stride_y_m
        offs_y_m = offs_m

    YPtrs = Y + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n.to(index_type)[None, :] * stride_y_n
    mask = mask_m[:, None] & mask_n[None, :]
    if is_out_microscaled:
        # Microscaled output: EPILOGUE_FN returns the values and their scales;
        # the scales are stored separately through YActualScale.
        MX_SCALE_BLOCK_N: tl.constexpr = BLOCK_N // MXFP_BLOCK_SIZE
        N_MX_BLOCK: tl.constexpr = tl.cdiv(N, MXFP_BLOCK_SIZE)
        tl.static_assert(EPILOGUE_FN is not None)
        out, out_scale = EPILOGUE_FN(out, mask, *epilogue_fn_args)
        tl.static_assert(BLOCK_N % MX_SCALE_BLOCK_N == 0, "")
        offs_y_n_scale = MX_SCALE_BLOCK_N * pid_n + tl.arange(0, MX_SCALE_BLOCK_N)
        mask_n_scale = offs_y_n_scale < N_MX_BLOCK
        YActualScale += start_z.to(index_type) * stride_y_mx_z
        if WriteBackIndx is None:
            YActualScale += start_m * stride_y_mx_m
            YActualScalePtrs = YActualScale + offs_y_m.to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
        else:
            # NOTE(review): the scatter path shifts the scale row index by -NRows;
            # the reason is not visible in this kernel — confirm against the caller.
            YActualScalePtrs = YActualScale + (offs_y_m - NRows).to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
        tl.store(YActualScalePtrs, out_scale, mask=mask_m[:, None] & mask_n_scale[None, :])
    else:
        out = float_to_flex(out, YExpectedScale, YActualScale, YChecksumScale, mask, Y, FLEXPOINT_SATURATE_INF)
        if EPILOGUE_FN is not None and not IS_EPILOGUE_QUANT_MXFP8:
            out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtrs.dtype.element_ty)
    tl.store(YPtrs, out, mask=mask)
# isort: off
# fmt: off
import torch
import triton
import triton.language as tl
from triton.tools.ragged_tma import load_ragged, store_ragged
from vllm.kvprune.triton_kernels import target_info
from vllm.kvprune.triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
from vllm.kvprune.triton_kernels.numerics_details.flexpoint import (
float_to_flex,
load_scale,
nan_propagating_absmax_reduce,
compute_scale,
)
from vllm.kvprune.triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
@triton.constexpr_function
def cuda_capability_geq(major, minor):
    """Return ``target_info.cuda_capability_geq(major, minor)``.

    Decorated with ``triton.constexpr_function``, so the call can be used where
    a compile-time value is required inside kernels.
    """
    return target_info.cuda_capability_geq(major, minor)
@triton.constexpr_function
def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
    """Return the element dtype of a pointer tensor or a tensor descriptor.

    For a ``tl.tensor`` this is ``dtype.element_ty``; for a
    ``tl.tensor_descriptor`` it is ``dtype``.

    Raises:
        ValueError: If the argument is neither of those two types.
    """
    if isinstance(tensor_or_desc, tl.tensor):
        return tensor_or_desc.dtype.element_ty
    if isinstance(tensor_or_desc, tl.tensor_descriptor):
        return tensor_or_desc.dtype
    raise ValueError(f"Invalid type: {type(tensor_or_desc)}")
@triton.jit
def _load_tile_attrs(
        tile_id, num_tiles, grid_m, grid_n, padding_m,
        M, ExptData, ExptHist, ExptOffs,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, SPLIT_K: tl.constexpr,
        GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr):
    """Decode a linear tile id into the attributes needed to process that tile.

    Returns:
        ``(expt_id, start_z, start_m, eM, off_m, off_n, pid_k)``. ``off_m`` and
        ``off_n`` are the tile's starting row and column (block index times
        ``BLOCK_M`` / ``BLOCK_N``). ``eM`` is the ``ExptHist`` entry for the
        expert, or -1 when ``ExptData`` is None.
    """
    # unpack and swizzle program ids
    pid_emnk = tile_id
    if XCD_SWIZZLE != 1:
        # NOTE(review): the swizzle domain is num_tiles // SPLIT_K, while the id is
        # decomposed with a SPLIT_K factor below — confirm this is intended for SPLIT_K > 1.
        pid_emnk = xcd_swizzle(pid_emnk, num_tiles // SPLIT_K, XCD_SWIZZLE)
    # (grid_m - padding_m) is the number of row tiles with padding removed.
    pid_e = pid_emnk // ((grid_m - padding_m) * grid_n * SPLIT_K)
    pid_mnk = pid_emnk % ((grid_m - padding_m) * grid_n * SPLIT_K)
    if SPLIT_K > 1:
        pid_k = pid_mnk % SPLIT_K
        pid_mn = pid_mnk // SPLIT_K
    else:
        # Without split-K the K slice index is the compile-time constant 0.
        pid_k: tl.constexpr = 0
        pid_mn = pid_mnk
    pid_m, pid_n = swizzle2d(pid_mn, (grid_m - padding_m), grid_n, GROUP_M)
    # unpack expert data
    if ExptData is None:
        tl.static_assert(M is not None)
        expt_id, start_z, start_m, block_id, eM = pid_e, pid_e, 0, pid_m, -1
    else:
        tl.static_assert(M is None)
        expt_data = tl.load(ExptData + pid_m)
        # The expert id is in the low 16 bits and the block id in the bits above.
        expt_id = expt_data & 0x0000FFFF
        block_id = expt_data >> 16
        eM = tl.load(ExptHist + expt_id)
        start_m = tl.load(ExptOffs + expt_id)
        start_z = 0
    off_m = BLOCK_M * block_id
    off_n = BLOCK_N * pid_n
    return expt_id, start_z, start_m, eM, off_m, off_n, pid_k
@triton.jit
def _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, offs, mask):
    """Load write-back destination indices and the mask of valid destinations.

    Offsets are read only where ``mask`` holds and ``offs < writeback_size``;
    every other lane receives -1.

    Returns:
        ``(indices, valid)`` where ``valid`` is ``indices != -1``.
    """
    load_mask = mask & (offs < writeback_size)
    dst_idx = tl.load(WriteBackIndx + offs, mask=load_mask, other=-1)
    dst_valid = dst_idx != -1
    return (dst_idx, dst_valid)
# Name formatter handed to `triton.jit(repr=...)` for the `_p_matmul_ogs` kernel;
# generated names start with "_p_matmul_ogs".
_matmul_ogs_repr = make_matmul_repr("_p_matmul_ogs", [0, 1, 2])
@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
def _p_matmul_ogs(
Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
YExpectedScale, YActualScale, YChecksumScale,
stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
X, XPtr, stride_x_z, stride_x_m, stride_x_k,
XScale,
XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
WScale,
MxScale, stride_mx_e, stride_mx_k, stride_mx_n,
B, stride_b_e, # Bias
NRows, M, N, K, # shapes
# expt data
Betas, Gammas,
GatherIndx,
ScatterSrcIndx, num_idxs,
WriteBackIndx, writeback_size,
ExptHist, ExptOffs, ExptOffsSum, ExptData,
# true grid size
batch_size, grid_m, grid_n,
# Out scale
out_alpha,
# fused activation function
ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr,
# epilogue transform
EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
# MoE config
N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
# precision config
MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr,
FLEXPOINT_SATURATE_INF: tl.constexpr,
PER_BATCH_SCALE: tl.constexpr,
# optimization config
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
# NYI: Must be None
SWIZZLE_MX_VALUE: tl.constexpr,
# One of ["BLACKWELL", None]
SWIZZLE_MX_SCALE: tl.constexpr,
EPILOGUE_SUBTILE: tl.constexpr,
EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
W_CACHE_MODIFIER: tl.constexpr,
NUM_SMS: tl.constexpr,
X_TMA_MODE: tl.constexpr,
Y_TMA_MODE: tl.constexpr,
TOKENS_PER_EXPT_FOR_ANNOTATION=None,
UPCAST_INDICES:tl.constexpr=False,
SWAP_XW: tl.constexpr = False,
IS_EPILOGUE_QUANT_MXFP8: tl.constexpr = False):
# tl.static_assert(SWIZZLE_MX_VALUE is None, "NYI. Value swizzling")
# why is this faster than using host-side tensor descriptor?!
if Y_TMA_MODE is not None:
Y = tl.make_tensor_descriptor(YPtr, Y.shape, Y.strides[:-1] + (1,), Y.block_shape)
is_microscaled_format: tl.constexpr = MxScale is not None
tl.static_assert(not is_microscaled_format or W_TRANSPOSE, "NYI. Non-transposed mxfp4 weights")
MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
if is_microscaled_format:
w_type: tl.constexpr = get_dtype(W)
tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
"mx_weight_ptr must be uint8")
tl.static_assert(get_dtype(MxScale) == tl.uint8, "mx_scale_ptr must be uint8")
tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
tl.static_assert(SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" or SWIZZLE_MX_SCALE is None, "Only Blackwell swizzling is supported for scales")
# We have pack 2 fp4 values in a byte
W_PACK_DIVISOR: tl.constexpr = 2 if w_type == tl.uint8 else 1
PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K // W_PACK_DIVISOR
MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR
else:
W_PACK_DIVISOR: tl.constexpr = 1
MX_SCALE_BLOCK_K: tl.constexpr = 1
PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
tl.static_assert(SWIZZLE_MX_SCALE is None)
if ExptOffsSum is not None:
# Determine how much padding there is on the expert data. This allows us to
# know the true grid size and avoid processing padding tiles.
padding_m = grid_m - tl.load(ExptOffsSum)
else:
padding_m: tl.constexpr = 0
index_type: tl.constexpr = tl.int64
USE_FLEXPOINT_SCALE: tl.constexpr = YActualScale is not None or YChecksumScale is not None
HAS_SCATTER: tl.constexpr = WriteBackIndx is not None
HAS_GATHER: tl.constexpr = GatherIndx is not None
USE_GATHER_TMA: tl.constexpr = HAS_GATHER and X_TMA_MODE == "dense"
USE_SCATTER_TMA: tl.constexpr = HAS_SCATTER and Y_TMA_MODE == "dense"
if EPILOGUE_SUBTILE is None:
SUBTILE_FACTOR: tl.constexpr = 1
else:
SUBTILE_FACTOR: tl.constexpr = EPILOGUE_SUBTILE
EPILOGUE_BLOCK_N: tl.constexpr = BLOCK_N // SUBTILE_FACTOR
OUT_BLOCK_N: tl.constexpr = EPILOGUE_BLOCK_N // ACTIVATION_REDUCTION_N
yN = N // ACTIVATION_REDUCTION_N
# set masked out rows to 0
if HAS_SCATTER and N_EXPTS_ACT == 1:
# Iterate with reversed pids so that later pids will get more tiles if the number of
# tiles isn't evenly divisible by the number of SMs.
# The main loop after this iterates in the forward direction such that earlier
# pids get more tiles if the number of tiles isn't evenly divisible.
# This helps balance the work across the SMs.
for pid_mnk in range(NUM_SMS - tl.program_id(0) - 1, batch_size * grid_m * grid_n * SPLIT_K, NUM_SMS):
pid_k = pid_mnk % SPLIT_K
pid_mn = pid_mnk // SPLIT_K
pid_m, pid_n = swizzle2d(pid_mn, grid_m, grid_n, GROUP_M)
z = tl.zeros([BLOCK_M, BLOCK_N // ACTIVATION_REDUCTION_N], dtype=tl.float32)
offs_m = z.shape[0] * pid_m + tl.arange(0, z.shape[0])
offs_n = z.shape[1] * pid_n + tl.arange(0, z.shape[1])
src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0)
YPtrs = YPtr + offs_m.to(index_type)[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
mask_n = offs_n < yN
mask = (src_idx == -1)[:, None] & mask_n[None, :]
tl.store(YPtrs + pid_k * stride_y_k, z, mask=mask)
k_tiles = tl.cdiv(K, BLOCK_K * SPLIT_K)
num_tiles = batch_size * (grid_m - padding_m) * grid_n * SPLIT_K
# If true, do not share loop-carried variables between the prologue and the
# epilogue to enable better pipelining with mmav5
INDEPENDENT_EPILOGUE: tl.constexpr = cuda_capability_geq(10, 0)
# start negative; will be incremented at the top of the loop
if INDEPENDENT_EPILOGUE:
tile_id1 = tl.program_id(0) - NUM_SMS
# Keep track of local max for updating flexpoint scales.
THREADS_PER_BLOCK: tl.constexpr = tl.extra.cuda.num_threads()
local_absmax = tl.full([THREADS_PER_BLOCK], 0.0, tl.uint32)
DISALLOW_ACC_MULTI_BUFFER: tl.constexpr = is_microscaled_format and BLOCK_M * BLOCK_N >= 128 * 256
for tile_id in tl.range(tl.program_id(0), num_tiles, NUM_SMS, flatten=True, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER, warp_specialize=True):
expt_id, start_z, start_m, eM, off_m, off_n, pid_k = _load_tile_attrs(
tile_id, num_tiles, grid_m, grid_n, padding_m,
M, ExptData, ExptHist, ExptOffs,
BLOCK_M, BLOCK_N, SPLIT_K,
GROUP_M, XCD_SWIZZLE)
# Base pointers and offsets.
if X_TMA_MODE is None:
XBase = X + start_z.to(index_type) * stride_x_z
offs_x_k = tl.arange(0, BLOCK_K)[None, :] * stride_x_k
if SPLIT_K > 1:
offs_x_k += pid_k.to(index_type) * BLOCK_K * stride_x_k
if USE_GATHER_TMA:
offs_m = off_m + tl.arange(0, BLOCK_M)
mask_m = offs_m < (M if M is not None else eM)
if ExptData is None:
offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, mask=mask_m)
# Bump rows to account for the Z offset.
offs_x_m += start_z * (stride_x_z // stride_x_m)
offs_x_m = tl.where(mask_m, offs_x_m, -1)
else:
offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m,
mask=mask_m, other=-N_EXPTS_ACT) // N_EXPTS_ACT
elif X_TMA_MODE is None:
tl.static_assert(HAS_GATHER)
offs_m = off_m + tl.arange(0, BLOCK_M)
if M is not None:
offs_m = tl.max_contiguous(tl.multiple_of(offs_m % M, BLOCK_M), BLOCK_M)
else:
offs_m = tl.max_contiguous(tl.multiple_of(offs_m % eM, BLOCK_M), BLOCK_M)
# no needs to bounds-check here because `offs_m` wraps around M dim
offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m) // N_EXPTS_ACT
offs_x_m = offs_m.to(index_type)[:, None] * stride_x_m
acc = tl.zeros((BLOCK_N, BLOCK_M) if SWAP_XW else (BLOCK_M, BLOCK_N), dtype=tl.float32)
for ki in tl.range(k_tiles, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER):
off_k = pid_k * BLOCK_K + ki * BLOCK_K * SPLIT_K
off_k_w = pid_k * PACKED_BLOCK_K_W + ki * PACKED_BLOCK_K_W * SPLIT_K
off_k_mx = pid_k * MX_SCALE_BLOCK_K + ki * MX_SCALE_BLOCK_K * SPLIT_K
# --- load x ---
if USE_GATHER_TMA:
x = X.gather(offs_x_m, off_k)
elif X_TMA_MODE == "dense":
x = X.load([start_z, start_m + off_m, off_k])
x = x.reshape(BLOCK_M, BLOCK_K)
elif X_TMA_MODE == "ragged":
x = load_ragged(X, start_m, eM, [start_z, off_m, off_k], ragged_dim=1)
x = x.reshape(BLOCK_M, BLOCK_K)
else:
tl.static_assert(X_TMA_MODE is None)
XPtrs = XBase + offs_x_m + offs_x_k
XBase += BLOCK_K * SPLIT_K * stride_x_k
mask_k = tl.arange(0, BLOCK_K) < K - off_k
if EVEN_K:
if SPLIT_K > 1:
x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
else:
x = tl.load(XPtrs)
else:
x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
# --- load w ---
if W_TRANSPOSE:
w = tl.reshape(W.load([expt_id, off_n, off_k_w]), W.block_shape[1:]).T
else:
w = tl.reshape(W.load([expt_id, off_k_w, off_n]), W.block_shape[1:])
# --- load w_scale ---
if is_microscaled_format:
x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
mx_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
if x_format == "fp16" or x_format == "bf16":
x_scales: tl.constexpr = None
else:
x_scales = tl.full((BLOCK_M, BLOCK_K // MX_PACK_DIVISOR), 127, dtype=tl.uint8)
if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
flattened_expt_n_idx = expt_id * ((N + 127) // 128) + (off_n // 128)
w_scales = MxScale.load([0, flattened_expt_n_idx, pid_k * MX_SCALE_BLOCK_K // 4 + ki * (MX_SCALE_BLOCK_K // 4 * SPLIT_K), 0, 0])
w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * w_scales.shape[-2] * w_scales.shape[-1]))
w_scales = unswizzle_mx_scale_bw(w_scales)
else:
w_scales = MxScale.load([expt_id, off_k_mx, off_n])
w_scales = tl.reshape(w_scales, *w_scales.shape[1:]).T
# --- update accumulator ---
if is_microscaled_format:
if SWAP_XW:
acc = tl.dot_scaled(w.T, w_scales, mx_format, x.T, x_scales, x_format, acc=acc, fast_math=True)
else:
acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, mx_format, acc=acc, fast_math=True)
else:
if SWAP_XW:
acc = tl.dot(w.T, x.T, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
else:
acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
if INDEPENDENT_EPILOGUE:
tile_id1 += NUM_SMS
expt_id1, start_z1, start_m1, eM1, off_m1, off_n1, pid_k1 = _load_tile_attrs(
tile_id1, num_tiles, grid_m, grid_n, padding_m,
M, ExptData, ExptHist, ExptOffs,
BLOCK_M, BLOCK_N, SPLIT_K,
GROUP_M, XCD_SWIZZLE)
else:
tile_id1, expt_id1, start_z1, start_m1, eM1 = tile_id, expt_id, start_z, start_m, eM
off_m1, off_n1, pid_k1 = off_m, off_n, pid_k
offs_m = off_m1 + tl.arange(0, BLOCK_M)
mask_m = offs_m < (M if M is not None else eM1)
if USE_SCATTER_TMA:
offs_y_m, mask_m = _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, start_m1 + offs_m, mask_m)
MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE
if SPLIT_K > 1:
# Compute the split k offset in number of rows, and add it to offs_y_m.
# This allows us to write to the correct slice in the output tensor while using
# a 2D TMA scatter.
tl.device_assert(stride_y_k // stride_y_m == tl.cdiv(stride_y_k, stride_y_m))
split_k_row_offs = pid_k1 * (stride_y_k // stride_y_m)
offs_y_m = tl.where(mask_m, offs_y_m + split_k_row_offs, offs_y_m)
elif Y_TMA_MODE is None:
tl.static_assert(HAS_SCATTER)
offs_y_m, mask_m = _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, start_m1 + offs_m, mask_m)
MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE
else:
offs_y_m = start_m1 + offs_m
MASK_ACC = False if USE_GATHER_TMA else USE_FLEXPOINT_SCALE
# bias + scale
offs_y_n = off_n1 + tl.arange(0, BLOCK_N)
mask_n = offs_y_n < N
if B is not None:
BPtrs = B + expt_id1 * stride_b_e + offs_y_n
if pid_k1 == 0:
bias = tl.load(BPtrs, mask=mask_n, other=0)
else:
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
else:
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
if Betas is not None:
betas = tl.load(Betas + start_m1 + offs_m, mask=mask_m, other=0.0)
else:
betas = tl.full([BLOCK_M], 1, dtype=tl.float32)
if Gammas is not None:
gammas = tl.load(Gammas + start_m1 + offs_m, mask=mask_m, other=0.0)
else:
gammas = tl.full([BLOCK_M], 1, dtype=tl.float32)
x_scale = load_scale(XScale)
if PER_BATCH_SCALE:
w_scale = load_scale(WScale + expt_id1)
else:
w_scale = load_scale(WScale)
accs = (acc,)
biases = (bias,)
if SUBTILE_FACTOR >= 2:
acc0, acc1 = acc.reshape(BLOCK_M, 2, BLOCK_N // 2).permute(0, 2, 1).split()
accs = (acc0, acc1)
bias0, bias1 = bias.reshape(2, BLOCK_N // 2).permute(1, 0).split()
biases = (bias0, bias1)
if SUBTILE_FACTOR >= 4:
acc00, acc01 = acc0.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
acc10, acc11 = acc1.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
accs = (acc00, acc01, acc10, acc11)
bias00, bias01 = bias0.reshape(2, BLOCK_N // 4).permute(1, 0).split()
bias10, bias11 = bias1.reshape(2, BLOCK_N // 4).permute(1, 0).split()
biases = (bias00, bias01, bias10, bias11)
tl.static_assert(EPILOGUE_BLOCK_N == BLOCK_N // SUBTILE_FACTOR)
tl.static_assert(len(accs) == SUBTILE_FACTOR)
for a_i in tl.static_range(len(accs)):
acc_tile = accs[a_i]
acc_tile *= x_scale * w_scale
if SWAP_XW:
acc_tile = acc_tile.T
acc_tile = acc_tile + biases[a_i][None, :] * betas[:, None]
if out_alpha is not None:
acc_tile *= out_alpha
if ACTIVATION_FN is not None:
out = ACTIVATION_FN(acc_tile, *activation_fn_args)
tl.static_assert(out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})")
else:
tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided")
out = acc_tile
out *= gammas[:, None]
if MASK_ACC:
out = tl.where(mask_m[:, None], out, 0.0)
# Flexpoint
out_view = tl.reshape(out, [out.numel // THREADS_PER_BLOCK, THREADS_PER_BLOCK], can_reorder=True)
local_absmax = tl.maximum(local_absmax, nan_propagating_absmax_reduce(out_view, axis=0))
out = float_to_flex(
out, YExpectedScale,
None, # ActualScale: local absmax is tracked and updated after the loop
YChecksumScale,
None, # mask: out is manually masked to 0
YPtr, FLEXPOINT_SATURATE_INF
)
if EPILOGUE_FN is not None:
out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtr.dtype.element_ty, pid=len(accs)*tile_id1 + a_i)
out_off_n = off_n1 // ACTIVATION_REDUCTION_N + a_i * OUT_BLOCK_N
out = out.to(YPtr.dtype.element_ty)
if USE_SCATTER_TMA:
# Convert -1 offsets to INT_MAX. We do this by clearing the leading bit. Note that
# there shouldn't be any other negative values.
offs_y_m = (offs_y_m.to(tl.uint32, bitcast=True) & 0x7FFFFFFF).to(tl.int32, bitcast=True)
Y.scatter(out, offs_y_m, out_off_n)
elif Y_TMA_MODE == "dense":
out = tl.reshape(out, [1] + out.shape)
off_kz = pid_k * batch_size + start_z1
Y.store([off_kz, off_m1, out_off_n], out)
elif Y_TMA_MODE == "ragged":
out = tl.reshape(out, [1] + out.shape)
store_ragged(Y, start_m1, eM1, [pid_k, off_m1, out_off_n], out, ragged_dim=1)
else:
tl.static_assert(Y_TMA_MODE is None)
offs_y_n = out_off_n + tl.arange(0, OUT_BLOCK_N)
mask_n = offs_y_n < yN
YPtrs = YPtr + pid_k1.to(index_type) * stride_y_k + start_z1.to(index_type) * stride_y_z + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n[None, :] * stride_y_n
mask = mask_m[:, None] & mask_n[None, :]
tl.store(YPtrs, out, mask=mask)
# Update the flexpoint scales
if YActualScale is not None:
tl.atomic_max(YActualScale, compute_scale(local_absmax.to(tl.float32, bitcast=True), YPtr), sem="relaxed")
# Cache of allocator callbacks, keyed by device.
_per_device_alloc_fns = {}


def get_per_device_per_stream_alloc_fn(device):
    """Return the scratch-buffer allocator bound to ``device``.

    The allocator is created on first request and cached, so repeated calls
    with the same ``device`` return the same callable. Each allocator keeps one
    ``int8`` buffer per stream and only reallocates when the requested size
    exceeds the buffer currently held for that stream.
    """
    cached_fn = _per_device_alloc_fns.get(device)
    if cached_fn is not None:
        return cached_fn

    buffers_by_stream = {}

    def alloc_fn(size: int, alignment: int, stream):
        # Only 128-byte alignment requests are accepted.
        assert alignment == 128
        current = buffers_by_stream.get(stream)
        if current is None or current.numel() < size:
            buffers_by_stream[stream] = torch.empty(size, device=device, dtype=torch.int8)
            # Tag freshly allocated buffers with the "ignore" hibernate marker.
            buffers_by_stream[stream].__hibernate__ = {"type": "ignore"}
        return buffers_by_stream[stream]

    _per_device_alloc_fns[device] = alloc_fn
    return alloc_fn
from vllm.kvprune.triton_kernels.numerics_details.flexpoint import (
float_to_flex,
load_scale,
)
from vllm.kvprune.triton_kernels.numerics_details.mxfp import quantize_mxfp8_fn
import triton
import triton.language as tl
@triton.jit
def _reduce_grouped(
    X,
    stride_xb: tl.uint64,
    stride_xm: tl.uint64,
    stride_xn,  #
    XScale,  # input scalar flex scale
    Out,
    stride_om: tl.uint64,
    stride_on,  # output tensor
    OutExpectedScale,
    OutActualScale,
    OutChecksumScale,  # output scalar flex scales
    InIndx,
    B,
    N,  #
    XMxScale,
    stride_mxb: tl.uint64,
    stride_mxs: tl.uint64,  # optional per-32-col input MXFP scales (uint8)
    OutMxScale,
    stride_omxs: tl.uint64,  # optional per-32-col output MXFP scales (uint8)
    # fused activation function
    ACTIVATION_FN: tl.constexpr,
    activation_fn_args,
    ACTIVATION_REDUCTION_N: tl.constexpr,
    # epilogue transform
    EPILOGUE_FN: tl.constexpr,
    epilogue_fn_args,
    #
    HAS_IN_MX_SCALE: tl.constexpr,
    HAS_OUT_MX_SCALE: tl.constexpr,
    FLEXPOINT_SATURATE_INF: tl.constexpr,
    K: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Reduce groups of input rows (and their ``B`` partial slices) into one output row.

    Program ``pid_t`` writes output row ``pid_t``. It reads ``K`` row indices
    from ``InIndx[pid_t * K : pid_t * K + K]`` (or uses ``pid_t`` itself when
    ``InIndx`` is None); an index of ``-1`` marks an entry whose loads are
    fully masked and therefore contribute zeros. For every tile of ``BLOCK_N``
    columns the kernel:

    1. sums, per selected row, the ``B`` slices spaced ``stride_xb`` apart,
       optionally multiplying each group of 32 values by its decoded
       ``XMxScale`` entry;
    2. applies ``ACTIVATION_FN`` (if given) to that per-row sum;
    3. accumulates the ``K`` rows, multiplies by ``load_scale(XScale)``;
    4. optionally quantizes with ``quantize_mxfp8_fn``, converts through
       ``float_to_flex`` and stores the tile (and its output scales).

    ``EPILOGUE_FN`` / ``epilogue_fn_args`` are accepted but not used in this body.
    """
    pid_t = tl.program_id(0)
    BLOCK_N_OUT: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
    # persistent along N: single program on N, iterate tiles of size BLOCK_N
    start = pid_t * K
    # load indices into a tuple
    if InIndx is None:
        indxs = (pid_t,)
    else:
        indxs = ()
        for i in tl.static_range(0, K):
            indxs = indxs + (tl.load(InIndx + start + i),)
    # determine first valid topk row
    # NOTE(review): `fi` is computed but never read afterwards in this kernel.
    fi = indxs[(K - 1)]
    for i in tl.static_range(K - 2, -1, -1):
        fi = tl.where(indxs[i] != -1, indxs[i], fi)
    # Column pointers for the first tile; they are advanced at the end of each
    # iteration of the tile loop below.
    XPtrs = X + tl.arange(0, BLOCK_N) * stride_xn
    OutPtrs = Out + tl.arange(0, BLOCK_N_OUT) * stride_on
    if HAS_IN_MX_SCALE:
        XScalePtrs = XMxScale + tl.arange(0, BLOCK_N // 32) * stride_xn
    if HAS_OUT_MX_SCALE:
        OutScalePtrs = OutMxScale + tl.arange(0, BLOCK_N_OUT // 32) * stride_on
    x_scale = load_scale(XScale)
    for n_curr in tl.range(0, N, BLOCK_N, num_stages=4):
        acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32)
        x_n_mask = tl.arange(0, BLOCK_N) < N - n_curr
        x_n_mask_scale = tl.arange(0, BLOCK_N // 32) < tl.cdiv(N - n_curr, 32)
        # accumulate contributions for this tile
        for i in tl.static_range(0, K):
            curr = tl.zeros([BLOCK_N], dtype=tl.float32)
            # iterate over split_k partial values
            for b in tl.range(0, B):
                is_valid = indxs[i] != -1
                x_row_ptr = XPtrs + indxs[i] * stride_xm + b * stride_xb
                vals = tl.load(x_row_ptr, mask=x_n_mask & is_valid, other=0.0)
                vals = vals.to(tl.float32)
                if HAS_IN_MX_SCALE:
                    scale_row_ptr = XScalePtrs + indxs[i] * stride_mxs + b * stride_mxb
                    scale = tl.load(
                        scale_row_ptr, mask=x_n_mask_scale & is_valid, other=0.0
                    )
                    # Place the uint8 scale in the float32 exponent field and
                    # reinterpret the bits as a float multiplier.
                    scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
                    vals = vals.reshape([BLOCK_N // 32, 32])
                    vals = (scale[:, None] * vals).reshape([BLOCK_N])
                curr += vals
            # apply nonlinearity to split-k output
            if ACTIVATION_FN is not None:
                curr = ACTIVATION_FN(curr[None, :], *activation_fn_args)
                curr = tl.reshape(curr, [curr.shape[-1]])
            # update final accumulator
            acc += curr
        acc *= x_scale
        # Compute per-32-col MXFP scales for this tile if requested
        Nrem = (N - n_curr) // ACTIVATION_REDUCTION_N
        out_n_mask = tl.arange(0, BLOCK_N_OUT) < Nrem
        out_n_mask_scale = tl.arange(0, BLOCK_N_OUT // 32) < tl.cdiv(Nrem, 32)
        if HAS_OUT_MX_SCALE:
            acc, acc_scale = quantize_mxfp8_fn(acc[None, :], out_n_mask[None, :])
            acc = tl.reshape(acc, [acc.shape[-1]])
            acc_scale = tl.reshape(acc_scale, [acc_scale.shape[-1]])
        # Convert to flexpoint output if configured (scalar scales)
        acc = float_to_flex(
            acc,
            OutExpectedScale,
            OutActualScale,
            OutChecksumScale,
            None,
            Out,
            FLEXPOINT_SATURATE_INF,
        )
        # write-back for this tile
        out_ptr = OutPtrs + pid_t * stride_om
        tl.store(out_ptr, acc, mask=out_n_mask)
        if HAS_OUT_MX_SCALE:
            out_scale_ptr = OutScalePtrs + pid_t * stride_omxs
            tl.store(out_scale_ptr, acc_scale, mask=out_n_mask_scale)
        # advance all column pointers to the next tile
        XPtrs += BLOCK_N * stride_xn
        OutPtrs += BLOCK_N_OUT * stride_on
        if HAS_IN_MX_SCALE:
            XScalePtrs += BLOCK_N // 32 * stride_xn
        if HAS_OUT_MX_SCALE:
            # Step with the output stride so the advance matches how
            # OutScalePtrs was initialised above (previously used stride_xn).
            OutScalePtrs += BLOCK_N_OUT // 32 * stride_on
# isort: off
# fmt: off
from dataclasses import dataclass
import triton
from vllm.kvprune.triton_kernels.target_info import get_cdna_version
import torch
from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
@dataclass
class OptFlags:
block_m: int
block_n: int
block_k: int
num_warps: int
num_stages: int
group_m: int
xcd_swizzle: int
w_cache_modifier: str
split_k: int
is_persistent: bool
fused_scatter: bool
idle_sms: int
epilogue_subtile: int | None
arch: str
target_kernel_kwargs: dict
def __post_init__(self):
if self.fused_scatter and self.split_k != 1:
raise ValueError("Not supported")
def make_default_opt_flags_amd(
    out_dtype,
    lhs_dtype,
    rhs_dtype,
    precision_config,
    m,
    n,
    k,
    routing_data,
    can_use_persistent_tma,
    can_use_fused_scatter,
    enforce_bitwise_invariance,
    epilogue_effective_itemsize,
    constraints,
):
    """Build the default ``OptFlags`` for the HIP backend.

    Values present (and not None) in ``constraints`` override the heuristics
    for ``block_m``, ``block_n``, ``block_k``, ``split_k``, ``fused_scatter``,
    ``is_persistent`` and ``epilogue_subtile``; any other key is rejected.

    Returns:
        The populated ``OptFlags`` instance.

    Raises:
        AssertionError: if ``constraints`` contains an unsupported key, or the
            resulting flags disagree with a non-None constraint value.
    """
    constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"]
    assert not any(c not in constraints_supported for c in constraints), constraints.keys()
    # tokens per expert
    if routing_data is None:
        tokens_per_expt = m
    elif routing_data.expected_tokens_per_expt is None:
        tokens_per_expt = max(1, m // routing_data.n_expts_tot)
    else:
        tokens_per_expt = routing_data.expected_tokens_per_expt
    is_cdna4 = get_cdna_version() == 4
    # block_m
    if constraints.get("block_m", None):
        block_m = constraints["block_m"]
    elif enforce_bitwise_invariance:
        block_m = 256 if is_cdna4 else 128
    elif tokens_per_expt >= 512 and n >= 2048:
        block_m = 256 if is_cdna4 else 128
    elif is_cdna4 and m >= 512:
        block_m = 128
    else:
        block_m = max(32, min(triton.next_power_of_2(tokens_per_expt), 64))
    if routing_data is not None:
        grid_m = routing_data.n_blocks(m, block_m)
    else:
        grid_m = triton.cdiv(m, block_m)
    # group_m:
    group_m = 4
    # number of xcds
    num_xcds = 8
    xcd_swizzle = num_xcds
    # block_nk:
    block_n, block_k = opt_flags_amd.compute_block_nk(
        n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config
    )
    # Replace block_k if provided in constraints.
    # TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
    if constraints.get("block_k", None) is not None:
        block_k = constraints["block_k"]
    if constraints.get("block_n", None) is not None:
        block_n = constraints["block_n"]
    is_persistent = constraints.get("is_persistent", False)
    # split_k:
    if constraints.get("split_k", None) is not None:
        split_k = constraints["split_k"]
    elif is_persistent or enforce_bitwise_invariance:
        split_k = 1
    else:
        grid_size = grid_m * ((n + block_n - 1) // block_n)
        n_cu = torch.cuda.get_device_properties(0).multi_processor_count
        # An empty grid (e.g. m == 0 gives grid_m == 0) has nothing to split;
        # fall back to 1 instead of dividing by zero.
        split_k = max(1, n_cu // grid_size) if grid_size > 0 else 1
    # w_cache_modifier:
    w_cache_modifier = ".cg" if block_m <= 32 else None
    # num_warps, num_stages
    num_warps = 2 if (m is not None and m <= 16) else 8
    num_stages = 2
    # AMD-specific
    target_kernel_kwargs = {"waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1}
    epilogue_subtile = constraints.get('epilogue_subtile', None)
    if epilogue_subtile is None:
        epilogue_subtile = 1
    ret = OptFlags(
        block_m=block_m,
        block_n=block_n,
        block_k=block_k,
        num_warps=num_warps,
        num_stages=num_stages,
        group_m=group_m,
        xcd_swizzle=xcd_swizzle,
        w_cache_modifier=w_cache_modifier,
        split_k=split_k,
        is_persistent=is_persistent,
        fused_scatter=constraints.get('fused_scatter', False),
        idle_sms=0,
        epilogue_subtile=epilogue_subtile,
        arch=None,
        target_kernel_kwargs=target_kernel_kwargs,
    )
    # check constraints
    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
    return ret
def make_default_opt_flags_nvidia(
    out_dtype,
    lhs_dtype,
    rhs_dtype,
    precision_config,
    m,
    n,
    k,
    routing_data,
    can_use_persistent_tma,
    can_use_fused_scatter,
    enforce_bitwise_invariance,
    epilogue_effective_itemsize,
    constraints,
):
    """Build the default ``OptFlags`` for the CUDA backend.

    Non-None entries of ``constraints`` take precedence over the heuristics.
    The accepted keys are ``block_m``, ``block_k``, ``split_k``,
    ``is_persistent``, ``fused_scatter``, ``epilogue_subtile``, ``num_stages``
    and ``idle_sms``.

    Returns:
        The populated ``OptFlags`` instance.

    Raises:
        AssertionError: for an unsupported constraint key, when no candidate
            epilogue subtile yields at least one pipeline stage, or when the
            final flags contradict a non-None constraint.
    """
    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms"]
    assert all(name in constraints_supported for name in constraints), constraints.keys()

    # Expected number of tokens routed to each expert.
    if routing_data is None:
        tokens_per_expt = m
    else:
        tokens_per_expt = routing_data.expected_tokens_per_expt
        if tokens_per_expt is None:
            tokens_per_expt = max(1, m // routing_data.n_expts_tot)

    # Program-id swizzling.
    group_m = 8
    xcd_swizzle = 1

    # Tile size along M.
    requested_block_m = constraints.get("block_m", None)
    if requested_block_m:
        block_m = requested_block_m
    elif enforce_bitwise_invariance:
        block_m = 128
    else:
        block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))

    # Tile size along N.
    arch = None
    block_n = opt_flags_nvidia.compute_block_n(n, arch, precision_config)

    # Persistent-kernel decision.
    grid_size = opt_flags_nvidia.compute_grid_size(routing_data, m, n, block_m, block_n)
    n_sms = torch.cuda.get_device_properties(0).multi_processor_count
    tiles_per_sm = grid_size / n_sms
    supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9)
    requested_persistent = constraints.get("is_persistent", None)
    if requested_persistent is not None:
        is_persistent = requested_persistent
    else:
        has_simple_epilogue = precision_config.max_num_imprecise_acc is None
        is_persistent = (
            supports_persistent
            and has_simple_epilogue
            and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1)
            and out_dtype.itemsize < 4
        )
        # TEMP CHANGE
        if not (precision_config.act_scale is None and precision_config.out_scale is None):
            is_persistent = False

    # Tile size along K.
    requested_block_k = constraints.get("block_k", None)
    if requested_block_k is not None:
        block_k = requested_block_k
    else:
        block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config)

    # Split-K factor.
    requested_split_k = constraints.get("split_k", None)
    if requested_split_k is not None:
        split_k = requested_split_k
    elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
        split_k = 1
    else:
        estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, m, n, block_m, block_n)
        split_k = opt_flags_nvidia.compute_split_k(block_k, k, estimated_actual_grid_size)
    if split_k > 1:
        # With split_k, results are written in f32. Use that for the following computations.
        out_dtype = torch.float32

    # Pick the epilogue subtile giving the most pipeline stages; ties keep the
    # earliest candidate.
    compute_num_stages_args = (
        precision_config,
        is_persistent,
        block_m,
        block_n,
        block_k,
        out_dtype,
        lhs_dtype,
        rhs_dtype,
    )
    requested_subtile = constraints.get("epilogue_subtile", None)
    subtiles_to_check = [1, 2, 4] if requested_subtile is None else [requested_subtile]
    stage_counts = [
        (subtile, opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, subtile, epilogue_effective_itemsize))
        for subtile in subtiles_to_check
    ]
    epilogue_subtile, num_stages = max(stage_counts, key=lambda pair: pair[1])
    assert num_stages >= 1
    requested_num_stages = constraints.get("num_stages", None)
    if requested_num_stages:
        num_stages = requested_num_stages

    # Fused scatter scratchpad.
    requested_fused_scatter = constraints.get("fused_scatter", None)
    if requested_fused_scatter is not None:
        fused_scatter = requested_fused_scatter
    else:
        fused_scatter = can_use_fused_scatter and split_k == 1

    # Handshake with the HBM swizzling
    num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, precision_config)

    ret = OptFlags(
        block_m=block_m,
        block_n=block_n,
        block_k=block_k,
        num_warps=num_warps,
        num_stages=num_stages,
        group_m=group_m,
        xcd_swizzle=xcd_swizzle,
        w_cache_modifier=None,
        split_k=split_k,
        is_persistent=is_persistent,
        fused_scatter=fused_scatter,
        idle_sms=constraints.get("idle_sms", 0),
        epilogue_subtile=epilogue_subtile,
        arch=arch,
        target_kernel_kwargs=dict(),
    )
    # Every explicit constraint must be reflected in the result.
    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
    return ret
# --------------
# User Interface
# --------------
# User-supplied constraints; make_opt_flags() forwards this dict to the
# backend-specific default factories.
_opt_flags_constraints: dict = dict()
# Manual override; when not None, make_opt_flags() returns it unchanged.
_opt_flags: OptFlags | None = None
def update_opt_flags_constraints(constraints: dict[str, int]):
    """Merge ``constraints`` into the module-level constraint registry.

    Existing keys are overwritten; keys not mentioned are left untouched.
    """
    # The registry is mutated in place, so no `global` rebinding is needed.
    _opt_flags_constraints.update(constraints)
def reset_opt_flags_constraints():
    """Discard all registered constraints by rebinding the registry to a new empty dict."""
    global _opt_flags_constraints
    _opt_flags_constraints = {}
def set_opt_flags(opt_flags: "OptFlags | None"):
    """Install a manual flags override, or clear it by passing ``None``.

    Args:
        opt_flags: the flags to install, or ``None`` to remove a previously
            installed override.

    Raises:
        AssertionError: when installing flags while constraints are registered,
            or while another override is still installed.
    """
    global _opt_flags
    if opt_flags is None:
        # Clearing must always succeed: the "already set" assertion below asks
        # callers to reset to None first, and this is how they do it.
        _opt_flags = None
        return
    assert not _opt_flags_constraints, "setting constraints is incompatible with manual flags override"
    assert not _opt_flags, "opt_flags already set; please reset to None first"
    _opt_flags = opt_flags
class InapplicableConstraint(Exception):
    """Signals that a requested constraint cannot be enforced for the current call."""
def make_opt_flags(
    out_dtype,
    lhs_dtype,
    rhs_dtype,
    precision_config,
    m,
    n,
    k,
    routing_data,
    can_use_persistent_tma,
    can_use_fused_scatter,
    epilogue_effective_itemsize,
):
    """Return the ``OptFlags`` to use for a call.

    A manual override installed via ``set_opt_flags`` is returned unchanged.
    Otherwise the flags are computed by the default factory matching the
    active Triton backend ("hip" or "cuda"), honouring the registered
    constraints.

    Raises:
        InapplicableConstraint: if ``is_persistent`` or ``fused_scatter`` is
            constrained to a truthy value but the corresponding capability
            flag passed by the caller is false.
        AssertionError: if an override and constraints are both set, or the
            active backend is neither "hip" nor "cuda".
    """
    if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma:
        raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
    if _opt_flags_constraints.get("fused_scatter", False) and not can_use_fused_scatter:
        raise InapplicableConstraint("cannot enforce `fused_scatter=True` constraint")
    enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance
    if _opt_flags is not None:
        assert not _opt_flags_constraints
        return _opt_flags
    args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, m, n, k,
            routing_data, can_use_persistent_tma, can_use_fused_scatter,
            enforce_bitwise_invariance, epilogue_effective_itemsize,
            _opt_flags_constraints]
    backend = triton.runtime.driver.active.get_current_target().backend
    if backend == "hip":
        return make_default_opt_flags_amd(*args)
    if backend == "cuda":
        return make_default_opt_flags_nvidia(*args)
    # Raise explicitly: a bare `assert False` disappears under `python -O`,
    # which would make this function silently return None.
    raise AssertionError(f"unsupported backend: {backend!r}")
import torch
import triton
from vllm.kvprune.triton_kernels.target_info import get_cdna_version
from vllm.kvprune.triton_kernels.tensor import bitwidth
def compute_block_nk(
    n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config
):
    """Pick the N and K tile sizes for the AMD backend.

    Args:
        n: size of the N dimension, or None when unknown.
        block_m: tile size already chosen along M.
        grid_m: number of tiles along M.
        num_xcds: multiplier applied to the estimated work when sizing block_n.
        lhs_dtype, rhs_dtype: operand dtypes; their bit widths determine block_k.
        precision_config: only ``weight_scale`` is read here.

    Returns:
        ``(block_n, block_k)``.
    """
    lhs_width = bitwidth(lhs_dtype) / 8
    rhs_width = bitwidth(rhs_dtype) / 8
    # block_n:
    n_cu = torch.cuda.get_device_properties(0).multi_processor_count
    if n is not None:
        # Small positive powers of two are used as-is. `n > 0` is required
        # because 0 also satisfies `n & (n - 1) == 0` and would otherwise give
        # a zero-width tile.
        if 0 < n <= 128 and (n & (n - 1)) == 0:
            block_n = n
        else:
            block_n = max(
                32, min(256, triton.next_power_of_2(grid_m * n * num_xcds // n_cu))
            )
    elif block_m > 64:
        block_n = 256
    else:
        block_n = 128
    is_cdna4 = get_cdna_version() == 4
    if is_cdna4 and block_m == 128:
        block_n = 512
    # block_k needs to match the cacheline size (128B)
    block_k = int(128 // min(lhs_width, rhs_width))
    # TODO: block_k = 128 seems to work better for now.
    # perhaps due to increased number of k loops to pipeline
    if precision_config.weight_scale is not None and not is_cdna4:
        block_k = 128
    return block_n, block_k
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment