import math from dataclasses import dataclass from typing import Tuple, Optional, Literal from functools import lru_cache from contextlib import contextmanager import torch from torch import nn import torch.nn.functional as F import torch.distributed as dist # from kernel import act_quant, fp4_act_quant, fp8_gemm, fp4_gemm, sparse_attn, hc_split_sinkhorn from kernel import sparse_attn, hc_split_sinkhorn try: from scipy.linalg import hadamard except ImportError: hadamard = None world_size = 1 rank = 0 block_size = 128 fp4_block_size = 32 default_dtype = torch.bfloat16 scale_fmt = None scale_dtype = torch.float32 @contextmanager def set_dtype(dtype): """Temporarily override torch default dtype, restoring it on exit (even if an exception occurs).""" prev = torch.get_default_dtype() torch.set_default_dtype(dtype) try: yield finally: torch.set_default_dtype(prev) @dataclass class ModelArgs: """Model hyperparameters. Field names match the config JSON keys.""" max_batch_size: int = 4 max_seq_len: int = 4096 dtype: Literal["bf16", "fp8"] = "fp8" scale_fmt: Literal[None, "ue8m0"] = "ue8m0" expert_dtype: Literal[None, "fp4", "fp8"] = None scale_dtype: Literal["fp32", "fp8"] = "fp8" vocab_size: int = 129280 dim: int = 4096 moe_inter_dim: int = 4096 n_layers: int = 7 n_hash_layers: int = 0 n_mtp_layers: int = 1 n_heads: int = 64 # moe n_routed_experts: int = 8 n_shared_experts: int = 1 n_activated_experts: int = 2 score_func: Literal["softmax", "sigmoid", "sqrtsoftplus"] = "sqrtsoftplus" route_scale: float = 1. swiglu_limit: float = 0. # mqa q_lora_rank: int = 1024 head_dim: int = 512 rope_head_dim: int = 64 norm_eps: float = 1e-6 o_groups: int = 8 o_lora_rank: int = 1024 window_size: int = 128 compress_ratios: Tuple[int] = (0, 0, 4, 128, 4, 128, 4, 0) # yarn compress_rope_theta: float = 40000.0 original_seq_len: int = 0 rope_theta: float = 10000.0 rope_factor: float = 40 beta_fast: int = 32 beta_slow: int = 1 # index index_n_heads: int = 64 index_head_dim: int = 128 index_topk: int = 512 # hc hc_mult: int = 4 hc_sinkhorn_iters: int = 20 hc_eps: float = 1e-6 class ParallelEmbedding(nn.Module): """Embedding sharded along the vocab dimension. Each rank holds vocab_size // world_size rows. Out-of-range indices are zero-masked before all_reduce to combine partial embeddings.""" def __init__(self, vocab_size: int, dim: int): super().__init__() self.vocab_size = vocab_size self.dim = dim assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})" self.part_vocab_size = (vocab_size // world_size) self.vocab_start_idx = rank * self.part_vocab_size self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: if world_size > 1: mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx) x = x - self.vocab_start_idx x[mask] = 0 y = F.embedding(x, self.weight) if world_size > 1: y[mask] = 0 dist.all_reduce(y) return y def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: """Dispatches to fp4_gemm / fp8_gemm / F.linear based on weight dtype. For quantized weights, x is first quantized to FP8 via act_quant.""" assert bias is None return F.linear(x, weight) class Linear(nn.Module): """Linear layer supporting BF16, FP8, and FP4 weight formats with per-block scaling.""" def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): super().__init__() self.in_features = in_features self.out_features = out_features dtype = dtype or default_dtype if dtype == torch.float4_e2m1fn_x2: # FP4: weight is [out, in//2] in float4_e2m1fn_x2, logically [out, in] in fp4 # Scale is [out, in//32] in float8_e8m0fnu (1 scale per 32 fp4 elements along K) self.weight = nn.Parameter(torch.empty(out_features, in_features // 2, dtype=torch.float4_e2m1fn_x2)) scale_out_features = out_features scale_in_features = in_features // fp4_block_size self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu)) elif dtype == torch.float8_e4m3fn: self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype)) scale_out_features = (out_features + block_size - 1) // block_size scale_in_features = (in_features + block_size - 1) // block_size self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu)) else: self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype)) self.register_parameter("scale", None) if bias: self.bias = nn.Parameter(torch.empty(out_features)) else: self.register_parameter("bias", None) def forward(self, x: torch.Tensor) -> torch.Tensor: return linear(x, self.weight, self.bias) class ColumnParallelLinear(Linear): """Shards output dim across TP ranks. No all-reduce needed on output.""" def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})" self.part_out_features = out_features // world_size super().__init__(in_features, self.part_out_features, bias, dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: return linear(x, self.weight, self.bias) class RowParallelLinear(Linear): """Shards input dim across TP ranks. All-reduce on output to sum partial results.""" def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})" self.part_in_features = in_features // world_size super().__init__(self.part_in_features, out_features, bias, dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: y = linear(x, self.weight, None) if world_size > 1: y = y.float() dist.all_reduce(y) if self.bias is not None: y += self.bias return y.type_as(x) class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.dim = dim self.eps = eps # rmsnorm in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient. self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) def forward(self, x: torch.Tensor): dtype = x.dtype x = x.float() var = x.square().mean(-1, keepdim=True) x = x * torch.rsqrt(var + self.eps) return (self.weight * x).to(dtype) @lru_cache(2) def precompute_freqs_cis(dim, seqlen, original_seq_len, base, factor, beta_fast, beta_slow) -> torch.Tensor: """Precomputes complex exponentials for rotary embeddings with YaRN scaling. When original_seq_len > 0, applies frequency interpolation with a smooth linear ramp between beta_fast and beta_slow correction ranges.""" def find_correction_dim(num_rotations, dim, base, max_seq_len): return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base)) def find_correction_range(low_rot, high_rot, dim, base, max_seq_len): low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) return max(low, 0), min(high, dim-1) def linear_ramp_factor(min, max, dim): if min == max: max += 0.001 linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) if original_seq_len > 0: low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_seq_len) smooth = 1 - linear_ramp_factor(low, high, dim // 2) freqs = freqs / factor * (1 - smooth) + freqs * smooth t = torch.arange(seqlen) freqs = torch.outer(t, freqs) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) return freqs_cis def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, inverse: bool = False) -> torch.Tensor: """Applies rotary positional embeddings in-place. Uses conjugate for inverse (de-rotation).""" y = x x = torch.view_as_complex(x.float().unflatten(-1, (-1, 2))) if inverse: freqs_cis = freqs_cis.conj() if x.ndim == 3: freqs_cis = freqs_cis.view(1, x.size(1), x.size(-1)) else: freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) x = torch.view_as_real(x * freqs_cis).flatten(-2) y.copy_(x) return y def hadamard_transform_ref(x, scale=1.0): """ x: (..., dim) out: (..., dim) """ if hadamard is None: raise ImportError("Please install scipy") x_shape = x.shape dim = x.shape[-1] x = x.reshape(-1, dim) log_dim = math.ceil(math.log2(dim)) dim_padded = 2 ** log_dim if dim != dim_padded: x = F.pad(x, (0, dim_padded - dim)) out = F.linear(x, torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device)) out = out * scale return out[..., :dim].reshape(*x_shape) def rotate_activation(x: torch.Tensor) -> torch.Tensor: """Applies randomized Hadamard rotation to spread information across dims before FP8 quant.""" assert x.dtype == torch.bfloat16 # from fast_hadamard_transform import hadamard_transform return hadamard_transform_ref(x, scale=x.size(-1) ** -0.5) @lru_cache(1) def get_window_topk_idxs(window_size: int, bsz: int, seqlen: int, start_pos: int): if start_pos >= window_size - 1: start_pos %= window_size matrix = torch.cat([torch.arange(start_pos + 1, window_size), torch.arange(0, start_pos + 1)], dim=0) elif start_pos > 0: matrix = F.pad(torch.arange(start_pos + 1), (0, window_size - start_pos - 1), value=-1) else: base = torch.arange(seqlen).unsqueeze(1) matrix = (base - window_size + 1).clamp(0) + torch.arange(min(seqlen, window_size)) matrix = torch.where(matrix > base, -1, matrix) return matrix.unsqueeze(0).expand(bsz, -1, -1) @lru_cache(2) def get_compress_topk_idxs(ratio: int, bsz: int, seqlen: int, start_pos: int, offset: int): if start_pos > 0: matrix = torch.arange(0, (start_pos + 1) // ratio) + offset else: matrix = torch.arange(seqlen // ratio).repeat(seqlen, 1) mask = matrix >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio matrix = torch.where(mask, -1, matrix + offset) return matrix.unsqueeze(0).expand(bsz, -1, -1) class Compressor(nn.Module): """Compresses KV cache via learned gated pooling over `compress_ratio` consecutive tokens. When overlap=True (ratio==4), uses overlapping windows for smoother compression boundaries.""" def __init__(self, args: ModelArgs, compress_ratio: int = 4, head_dim: int = 512, rotate: bool = False): super().__init__() self.dim = args.dim self.head_dim = head_dim self.rope_head_dim = args.rope_head_dim self.nope_head_dim = head_dim - args.rope_head_dim self.compress_ratio = compress_ratio self.overlap = compress_ratio == 4 self.rotate = rotate coff = 1 + self.overlap self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32)) # wkv and wgate in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient. # When overlap, the first half of dims is for overlapping compression, second half for normal. self.wkv = Linear(self.dim, coff * self.head_dim, dtype=torch.float32) self.wgate = Linear(self.dim, coff * self.head_dim, dtype=torch.float32) self.norm = RMSNorm(self.head_dim, args.norm_eps) self.kv_cache: torch.Tensor = None # assigned lazily from Attention.kv_cache # State buffers for decode-phase incremental compression. # With overlap: state[:, :ratio] = overlapping window, state[:, ratio:] = current window. self.register_buffer("kv_state", torch.zeros(args.max_batch_size, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32), persistent=False) self.register_buffer("score_state", torch.full((args.max_batch_size, coff * compress_ratio, coff * self.head_dim), float("-inf"), dtype=torch.float32), persistent=False) self.freqs_cis: torch.Tensor = None def overlap_transform(self, tensor: torch.Tensor, value=0): # tensor: [b,s,r,2d] b, s, _, _ = tensor.size() ratio, d = self.compress_ratio, self.head_dim new_tensor = tensor.new_full((b, s, 2 * ratio, d), value) new_tensor[:, :, ratio:] = tensor[:, :, :, d:] new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d] return new_tensor def forward(self, x: torch.Tensor, start_pos: int): assert self.kv_cache is not None bsz, seqlen, _ = x.size() ratio, overlap, d, rd = self.compress_ratio, self.overlap, self.head_dim, self.rope_head_dim dtype = x.dtype # compression need fp32 x = x.float() kv = self.wkv(x) score = self.wgate(x) if start_pos == 0: should_compress = seqlen >= ratio remainder = seqlen % ratio cutoff = seqlen - remainder offset = ratio if overlap else 0 if overlap and cutoff >= ratio: self.kv_state[:bsz, :ratio] = kv[:, cutoff-ratio : cutoff] self.score_state[:bsz, :ratio] = score[:, cutoff-ratio : cutoff] + self.ape if remainder > 0: kv, self.kv_state[:bsz, offset : offset+remainder] = kv.split([cutoff, remainder], dim=1) self.score_state[:bsz, offset : offset+remainder] = score[:, cutoff:] + self.ape[:remainder] score = score[:, :cutoff] kv = kv.unflatten(1, (-1, ratio)) score = score.unflatten(1, (-1, ratio)) + self.ape if overlap: kv = self.overlap_transform(kv, 0) score = self.overlap_transform(score, float("-inf")) kv = (kv * score.softmax(dim=2)).sum(dim=2) else: should_compress = (start_pos + 1) % self.compress_ratio == 0 score += self.ape[start_pos % ratio] if overlap: self.kv_state[:bsz, ratio + start_pos % ratio] = kv.squeeze(1) self.score_state[:bsz, ratio + start_pos % ratio] = score.squeeze(1) if should_compress: kv_state = torch.cat([self.kv_state[:bsz, :ratio, :d], self.kv_state[:bsz, ratio:, d:]], dim=1) score_state = torch.cat([self.score_state[:bsz, :ratio, :d], self.score_state[:bsz, ratio:, d:]], dim=1) kv = (kv_state * score_state.softmax(dim=1)).sum(dim=1, keepdim=True) self.kv_state[:bsz, :ratio] = self.kv_state[:bsz, ratio:] self.score_state[:bsz, :ratio] = self.score_state[:bsz, ratio:] else: self.kv_state[:bsz, start_pos % ratio] = kv.squeeze(1) self.score_state[:bsz, start_pos % ratio] = score.squeeze(1) if should_compress: kv = (self.kv_state[:bsz] * self.score_state[:bsz].softmax(dim=1)).sum(dim=1, keepdim=True) if not should_compress: return kv = self.norm(kv.to(dtype)) if start_pos == 0: freqs_cis = self.freqs_cis[:cutoff:ratio] else: freqs_cis = self.freqs_cis[start_pos + 1 - self.compress_ratio].unsqueeze(0) apply_rotary_emb(kv[..., -rd:], freqs_cis) if self.rotate: kv = rotate_activation(kv) # fp4_act_quant(kv, fp4_block_size, True) # else: # act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True) if start_pos == 0: self.kv_cache[:bsz, :seqlen // ratio] = kv else: self.kv_cache[:bsz, start_pos // ratio] = kv.squeeze(1) return kv class Indexer(torch.nn.Module): """Selects top-k compressed KV positions for sparse attention via learned scoring. Has its own Compressor (with Hadamard rotation) to build compressed KV for scoring.""" def __init__(self, args: ModelArgs, compress_ratio: int = 4): super().__init__() self.dim = args.dim self.n_heads = args.index_n_heads self.n_local_heads = args.index_n_heads // world_size self.head_dim = args.index_head_dim self.rope_head_dim = args.rope_head_dim self.index_topk = args.index_topk self.q_lora_rank = args.q_lora_rank self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim, dtype=torch.bfloat16) self.weights_proj = ColumnParallelLinear(self.dim, self.n_heads, dtype=torch.bfloat16) self.softmax_scale = self.head_dim ** -0.5 self.compress_ratio = compress_ratio self.compressor = Compressor(args, compress_ratio, self.head_dim, True) self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len // compress_ratio, self.head_dim), persistent=False) self.freqs_cis = None def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, offset: int): bsz, seqlen, _ = x.size() freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen] ratio = self.compress_ratio rd = self.rope_head_dim end_pos = start_pos + seqlen if self.compressor.kv_cache is None: self.compressor.kv_cache = self.kv_cache self.compressor.freqs_cis = self.freqs_cis q = self.wq_b(qr) q = q.unflatten(-1, (self.n_local_heads, self.head_dim)) apply_rotary_emb(q[..., -rd:], freqs_cis) q = rotate_activation(q) # use fp4 simulation for q and kv in indexer # fp4_act_quant(q, fp4_block_size, True) self.compressor(x, start_pos) weights = self.weights_proj(x) * (self.softmax_scale * self.n_heads ** -0.5) # We performed QAT here, kv could also use fp8 format, though current implementation uses bf16 index_score = torch.einsum("bshd,btd->bsht", q, self.kv_cache[:bsz, :end_pos // ratio]) index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2) if world_size > 1: dist.all_reduce(index_score) if start_pos == 0: mask = torch.arange(seqlen // ratio).repeat(seqlen, 1) >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio index_score += torch.where(mask, float("-inf"), 0) topk_idxs = index_score.topk(min(self.index_topk, end_pos // ratio), dim=-1)[1] if start_pos == 0: mask = topk_idxs >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio topk_idxs = torch.where(mask, -1, topk_idxs + offset) else: topk_idxs += offset return topk_idxs class Attention(nn.Module): """Multi-head Latent Attention (MLA) with sliding window + optional KV compression. Uses low-rank Q projection (wq_a -> q_norm -> wq_b) and grouped low-rank O projection.""" def __init__(self, layer_id: int, args: ModelArgs): super().__init__() self.layer_id = layer_id self.dim = args.dim self.n_heads = args.n_heads self.n_local_heads = args.n_heads // world_size self.q_lora_rank = args.q_lora_rank self.o_lora_rank = args.o_lora_rank self.head_dim = args.head_dim self.rope_head_dim = args.rope_head_dim self.nope_head_dim = args.head_dim - args.rope_head_dim self.n_groups = args.o_groups self.n_local_groups = self.n_groups // world_size self.window_size = args.window_size self.compress_ratio = args.compress_ratios[layer_id] self.eps = args.norm_eps self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32)) self.wq_a = Linear(self.dim, self.q_lora_rank, dtype=torch.bfloat16) self.q_norm = RMSNorm(self.q_lora_rank, self.eps) self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim, dtype=torch.bfloat16) self.wkv = Linear(self.dim, self.head_dim, dtype=torch.bfloat16) self.kv_norm = RMSNorm(self.head_dim, self.eps) self.wo_a = ColumnParallelLinear(self.n_heads * self.head_dim // self.n_groups, self.n_groups * args.o_lora_rank, dtype=torch.bfloat16) self.wo_b = RowParallelLinear(self.n_groups * args.o_lora_rank, self.dim) self.softmax_scale = self.head_dim ** -0.5 if self.compress_ratio: self.compressor = Compressor(args, self.compress_ratio, self.head_dim) if self.compress_ratio == 4: self.indexer = Indexer(args, self.compress_ratio) else: self.indexer = None kv_cache_size = args.window_size + (args.max_seq_len // self.compress_ratio if self.compress_ratio else 0) self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, kv_cache_size, self.head_dim), persistent=False) if self.compress_ratio: original_seq_len, rope_theta = args.original_seq_len, args.compress_rope_theta else: # disable YaRN and use base rope_theta in pure sliding-window attention original_seq_len, rope_theta = 0, args.rope_theta freqs_cis = precompute_freqs_cis(self.rope_head_dim, args.max_seq_len, original_seq_len, rope_theta, args.rope_factor, args.beta_fast, args.beta_slow) self.register_buffer("freqs_cis", freqs_cis, persistent=False) def forward(self, x: torch.Tensor, start_pos: int): bsz, seqlen, _ = x.size() freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen] win = self.window_size ratio = self.compress_ratio rd = self.rope_head_dim if self.compress_ratio and self.compressor.kv_cache is None: self.compressor.kv_cache = self.kv_cache[:, win:] self.compressor.freqs_cis = self.freqs_cis if self.indexer is not None: self.indexer.freqs_cis = self.freqs_cis # q qr = q = self.q_norm(self.wq_a(x)) q = self.wq_b(q).unflatten(-1, (self.n_local_heads, self.head_dim)) q *= torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps) apply_rotary_emb(q[..., -rd:], freqs_cis) # win kv & topk_idxs kv = self.wkv(x) kv = self.kv_norm(kv) apply_rotary_emb(kv[..., -rd:], freqs_cis) # FP8-simulate non-rope dims to match QAT; rope dims stay bf16 for positional precision # act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True) topk_idxs = get_window_topk_idxs(win, bsz, seqlen, start_pos) if self.compress_ratio: offset = kv.size(1) if start_pos == 0 else win if self.indexer is not None: compress_topk_idxs = self.indexer(x, qr, start_pos, offset) else: compress_topk_idxs = get_compress_topk_idxs(ratio, bsz, seqlen, start_pos, offset) topk_idxs = torch.cat([topk_idxs, compress_topk_idxs], dim=-1) topk_idxs = topk_idxs.int() # compress kv & attn if start_pos == 0: if seqlen <= win: self.kv_cache[:bsz, :seqlen] = kv else: cutoff = seqlen % win self.kv_cache[:bsz, cutoff: win], self.kv_cache[:bsz, :cutoff] = kv[:, -win:].split([win - cutoff, cutoff], dim=1) if self.compress_ratio: if (kv_compress := self.compressor(x, start_pos)) is not None: kv = torch.cat([kv, kv_compress], dim=1) # We performed QAT here, kv could also use fp8 format, though current implementation uses bf16 o = sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale) else: self.kv_cache[:bsz, start_pos % win] = kv.squeeze(1) if self.compress_ratio: self.compressor(x, start_pos) o = sparse_attn(q, self.kv_cache[:bsz], self.attn_sink, topk_idxs, self.softmax_scale) apply_rotary_emb(o[..., -rd:], freqs_cis, True) # o o = o.view(bsz, seqlen, self.n_local_groups, -1) wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1) # NOTE: wo_a is FP8 in checkpoint; could do FP8 einsum here for better perf, # but using BF16 for simplicity. o = torch.einsum("bsgd,grd->bsgr", o, wo_a) x = self.wo_b(o.flatten(2)) return x class Gate(nn.Module): """MoE gating: computes expert routing scores and selects top-k experts. Supports hash-based routing (first n_hash_layers) where expert indices are predetermined per token ID, and score-based routing (remaining layers).""" def __init__(self, layer_id: int, args: ModelArgs): super().__init__() self.dim = args.dim self.topk = args.n_activated_experts self.score_func = args.score_func self.route_scale = args.route_scale self.hash = layer_id < args.n_hash_layers self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim)) if self.hash: self.tid2eid = nn.Parameter(torch.empty(args.vocab_size, args.n_activated_experts, dtype=torch.int32), requires_grad=False) self.bias = None else: self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) def forward(self, x: torch.Tensor, input_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: scores = linear(x.float(), self.weight.float()) if self.score_func == "softmax": scores = scores.softmax(dim=-1) elif self.score_func == "sigmoid": scores = scores.sigmoid() else: scores = F.softplus(scores).sqrt() original_scores = scores # Bias shifts scores for expert selection (topk) but does not affect routing weights. if self.bias is not None: scores = scores + self.bias if self.hash: indices = self.tid2eid[input_ids] else: indices = scores.topk(self.topk, dim=-1)[1] weights = original_scores.gather(1, indices) if self.score_func != "softmax": weights /= weights.sum(dim=-1, keepdim=True) weights *= self.route_scale return weights, indices class Expert(nn.Module): """Single MoE expert: SwiGLU FFN (w1, w2, w3). Computation in float32 for stability.""" def __init__(self, dim: int, inter_dim: int, dtype=None, swiglu_limit=0): super().__init__() self.w1 = Linear(dim, inter_dim, dtype=dtype) self.w2 = Linear(inter_dim, dim, dtype=dtype) self.w3 = Linear(dim, inter_dim, dtype=dtype) self.swiglu_limit = swiglu_limit def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor: dtype = x.dtype gate = self.w1(x).float() up = self.w3(x).float() if self.swiglu_limit > 0: up = torch.clamp(up, min=-self.swiglu_limit, max=self.swiglu_limit) gate = torch.clamp(gate, max=self.swiglu_limit) x = F.silu(gate) * up if weights is not None: x = weights * x return self.w2(x.to(dtype)) class MoE(nn.Module): """Mixture-of-Experts: gate routes each token to top-k routed experts + 1 shared expert. Experts are sharded across TP ranks; each rank handles n_routed_experts // world_size experts.""" def __init__(self, layer_id: int, args: ModelArgs): super().__init__() self.layer_id = layer_id self.dim = args.dim assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})" self.n_routed_experts = args.n_routed_experts self.n_local_experts = args.n_routed_experts // world_size self.n_activated_experts = args.n_activated_experts self.experts_start_idx = rank * self.n_local_experts self.experts_end_idx = self.experts_start_idx + self.n_local_experts self.gate = Gate(layer_id, args) if args.expert_dtype == "fp4": expert_dtype = torch.float4_e2m1fn_x2 elif args.expert_dtype == "fp8": expert_dtype = torch.float8_e4m3fn else: None # expert_dtype = torch.float4_e2m1fn_x2 if args.expert_dtype == "fp4" else None self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim, dtype=torch.bfloat16, swiglu_limit=args.swiglu_limit) if self.experts_start_idx <= i < self.experts_end_idx else None for i in range(self.n_routed_experts)]) assert args.n_shared_experts == 1 # no swiglu_limit self.shared_experts = Expert(args.dim, args.moe_inter_dim, dtype=torch.bfloat16) def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: shape = x.size() x = x.view(-1, self.dim) weights, indices = self.gate(x, input_ids.flatten()) y = torch.zeros_like(x, dtype=torch.float32) #### torch.cuda.synchronize() indices_cpu = indices.flatten().cpu() counts_cpu = torch.bincount(indices_cpu, minlength=self.n_routed_experts) counts = counts_cpu.cuda() #counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist() torch.cuda.synchronize() for i in range(self.experts_start_idx, self.experts_end_idx): if counts[i] == 0: continue expert = self.experts[i] idx, top = torch.where(indices == i) y[idx] += expert(x[idx], weights[idx, top, None]) if world_size > 1: dist.all_reduce(y) y += self.shared_experts(x) return y.type_as(x).view(shape) class Block(nn.Module): """Transformer block with Hyper-Connections (HC) mixing. Instead of a simple residual, HC maintains `hc_mult` copies of the hidden state. hc_pre: reduces hc copies -> 1 via learned weighted sum (pre-weights from Sinkhorn). hc_post: expands 1 -> hc copies via learned post-weights + combination matrix.""" def __init__(self, layer_id: int, args: ModelArgs): super().__init__() self.layer_id = layer_id self.norm_eps = args.norm_eps self.attn = Attention(layer_id, args) self.ffn = MoE(layer_id, args) self.attn_norm = RMSNorm(args.dim, self.norm_eps) self.ffn_norm = RMSNorm(args.dim, self.norm_eps) self.hc_mult = hc_mult = args.hc_mult self.hc_sinkhorn_iters = args.hc_sinkhorn_iters self.hc_eps = args.hc_eps mix_hc = (2 + hc_mult) * hc_mult hc_dim = hc_mult * args.dim with set_dtype(torch.float32): self.hc_attn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim)) self.hc_ffn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim)) self.hc_attn_base = nn.Parameter(torch.empty(mix_hc)) self.hc_ffn_base = nn.Parameter(torch.empty(mix_hc)) self.hc_attn_scale = nn.Parameter(torch.empty(3)) self.hc_ffn_scale = nn.Parameter(torch.empty(3)) def hc_pre(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor): # x: [b,s,hc,d], hc_fn: [mix_hc,hc*d], hc_scale: [3], hc_base: [mix_hc], y: [b,s,hc,d] shape, dtype = x.size(), x.dtype x = x.flatten(2).float() rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps) mixes = F.linear(x, hc_fn) * rsqrt pre, post, comb = hc_split_sinkhorn(mixes, hc_scale, hc_base, self.hc_mult, self.hc_sinkhorn_iters, self.hc_eps) y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2) return y.to(dtype), post, comb def hc_post(self, x: torch.Tensor, residual: torch.Tensor, post: torch.Tensor, comb: torch.Tensor): # x: [b,s,d], residual: [b,s,hc,d], post: [b,s,hc], comb: [b,s,hc,hc], y: [b,s,hc,d] y = post.unsqueeze(-1) * x.unsqueeze(-2) + torch.sum(comb.unsqueeze(-1) * residual.unsqueeze(-2), dim=2) return y.type_as(x) def forward(self, x: torch.Tensor, start_pos: int, input_ids: Optional[torch.Tensor]) -> torch.Tensor: residual = x x, post, comb = self.hc_pre(x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base) x = self.attn_norm(x) x = self.attn(x, start_pos) x = self.hc_post(x, residual, post, comb) residual = x x, post, comb = self.hc_pre(x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base) x = self.ffn_norm(x) x = self.ffn(x, input_ids) x = self.hc_post(x, residual, post, comb) return x class ParallelHead(nn.Module): def __init__(self, vocab_size: int, dim: int, norm_eps: float = 1e-6, hc_eps: float = 1e-6): super().__init__() self.vocab_size = vocab_size self.dim = dim self.norm_eps = norm_eps self.hc_eps = hc_eps self.part_vocab_size = (vocab_size // world_size) # lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later. self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim, dtype=torch.float32)) def get_logits(self, x): return F.linear(x[:, -1].float(), self.weight) def forward(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, norm: RMSNorm): # x: [b,s,hc,d] x = self.hc_head(x, hc_fn, hc_scale, hc_base) logits = self.get_logits(norm(x)) if world_size > 1: all_logits = [torch.empty_like(logits) for _ in range(world_size)] dist.all_gather(all_logits, logits) logits = torch.cat(all_logits, dim=-1) return logits def hc_head(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor): shape, dtype = x.size(), x.dtype x = x.flatten(2).float() rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps) mixes = F.linear(x, hc_fn) * rsqrt pre = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2) return y.to(dtype) class MTPBlock(Block): def __init__(self, layer_id: int, args: ModelArgs): super().__init__(layer_id, args) self.e_proj = Linear(args.dim, args.dim, dtype=torch.bfloat16) self.h_proj = Linear(args.dim, args.dim, dtype=torch.bfloat16) self.enorm = RMSNorm(args.dim, args.norm_eps) self.hnorm = RMSNorm(args.dim, args.norm_eps) self.norm = RMSNorm(args.dim, args.norm_eps) self.hc_mult = hc_mult = args.hc_mult hc_dim = hc_mult * args.dim with set_dtype(torch.float32): self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim)) self.hc_head_base = nn.Parameter(torch.empty(hc_mult)) self.hc_head_scale = nn.Parameter(torch.empty(1)) self.embed: ParallelEmbedding = None self.head: ParallelHead = None @torch.inference_mode() def forward(self, x: torch.Tensor, start_pos: int, input_ids: torch.Tensor) -> torch.Tensor: # x: [b,s,hc,d] assert self.embed is not None and self.head is not None e = self.embed(input_ids) e = self.enorm(e) x = self.hnorm(x) x = self.e_proj(e).unsqueeze(2) + self.h_proj(x) x = super().forward(x, start_pos, input_ids) logits = self.head(x, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm) return logits class Transformer(nn.Module): """Full DeepSeek-V4 model: embed -> HC-expand -> N blocks -> HC-head -> logits. Sets global state (world_size, rank, default_dtype, scale_fmt, scale_dtype) in __init__.""" def __init__(self, args: ModelArgs): # global world_size, rank, default_dtype, scale_fmt, scale_dtype global world_size, rank world_size = dist.get_world_size() if dist.is_initialized() else 1 rank = dist.get_rank() if dist.is_initialized() else 0 default_dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16 scale_fmt = "ue8m0" if args.scale_dtype == "fp8" else args.scale_fmt scale_dtype = torch.float8_e8m0fnu if args.scale_dtype == "fp8" else torch.float32 super().__init__() self.max_seq_len = args.max_seq_len self.norm_eps = args.norm_eps self.hc_eps = args.hc_eps self.embed = ParallelEmbedding(args.vocab_size, args.dim) self.layers = torch.nn.ModuleList() for layer_id in range(args.n_layers): self.layers.append(Block(layer_id, args)) self.norm = RMSNorm(args.dim, self.norm_eps) self.head = ParallelHead(args.vocab_size, args.dim, self.norm_eps, self.hc_eps) self.mtp = torch.nn.ModuleList() for layer_id in range(args.n_mtp_layers): self.mtp.append(MTPBlock(args.n_layers + layer_id, args)) self.mtp[-1].embed = self.embed self.mtp[-1].head = self.head self.hc_mult = hc_mult = args.hc_mult hc_dim = hc_mult * args.dim with set_dtype(torch.float32): self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim)) self.hc_head_base = nn.Parameter(torch.empty(hc_mult)) self.hc_head_scale = nn.Parameter(torch.empty(1)) @torch.inference_mode() def forward(self, input_ids: torch.Tensor, start_pos: int = 0): h = self.embed(input_ids) # Expand to hc_mult copies for Hyper-Connections h = h.unsqueeze(2).repeat(1, 1, self.hc_mult, 1) for layer in self.layers: h = layer(h, start_pos, input_ids) logits = self.head(h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm) return logits if __name__ == "__main__": torch.set_default_dtype(torch.bfloat16) torch.set_default_device("cuda") torch.manual_seed(0) args = ModelArgs(n_hash_layers=0) x = torch.randint(0, args.vocab_size, (2, 128)) model = Transformer(args) print(model(x).size()) for i in range(128, 150): print(i, model(x[:, 0:1], i).size()) h = torch.randn(2, 128, args.hc_mult, args.dim) mtp = model.mtp[0] print(mtp(h, 0, x).size()) print(mtp(h[:, 0:1], 1, x[:, 0:1]).size())