Commit 9c26b762 authored by chenych's avatar chenych
Browse files

Add fp8 inference

parent f229f29b
......@@ -15,12 +15,11 @@ DeepSeek-V4 系列在架构与优化方面引入了多项关键升级:
| 软件 | 版本 |
| :------: |:-------:|
| DTK | 26.04 |
| python | 3.10.12 |
| torch | 2.9.0+das.opt1.dtk2604.20260331.g4e3c1e7 |
| tilelang | 0.1.7.post3+cpu.git52700923 |
| Python | 3.10.12 |
| Torch | 2.9.0+das.opt1.dtk2604.20260331.g4e3c1e7 |
| Tilelang | 0.1.7.post3+cpu.git52700923 |
当前仅支持镜像:harbor.sourcefind.cn:5443/dcu/admin/base/custom:torch-2.9.0-ubuntu22.04-dtk26.04-deepseek-v4-0425
当前仅支持镜像: harbor.sourcefind.cn:5443/dcu/admin/base/custom:torch-2.9.0-ubuntu22.04-dtk26.04-deepseek-v4-0425
- 挂载地址`-v`根据实际模型情况修改
......@@ -44,8 +43,6 @@ docker run -it \
更多镜像可前往[光源](https://sourcefind.cn/#/service-list)下载使用。
关于本项目DCU显卡所需的特殊深度学习库可从[光合](https://developer.sourcefind.cn/tool/)开发者社区下载安装。
## 数据集
`暂无`
......@@ -55,11 +52,11 @@ docker run -it \
## 推理
### Pytorch
#### 单机推理
##### BF16
1. 模型转换与切分
```bash
#注意将脚本中对应的路径及参数设置成用户实际值
#其中:INPUT_FP8_HF_PATH为模型下载路径;OUTPUT_BF16_HF_PATH为bf16模型存放路径;SAVE_PATH为切分好的模型路径;mp根据实际卡数调整
#其中:INPUT_FP8_HF_PATH为原始模型路径;OUTPUT_BF16_HF_PATH为bf16模型存放路径;SAVE_PATH为切分好的模型路径;mp根据实际卡数调整
cd convert_weight
bash convert_weight.sh
```
......@@ -67,10 +64,25 @@ bash convert_weight.sh
2. 启动对话推理
```bash
#注意将脚本中对应的路径及参数设置成用户实际值
cd ../inference
cd ../inference-bf16
sh start_torch.sh
```
##### FP8
1. 模型转换与切分
```bash
#注意将脚本中对应的路径及参数设置成用户实际值
#其中:--hf-ckpt-path为原始模型路径;--save-path为切分好的FP8模型的存放路径;MP根据实际卡数调整(默认为8)
cd inference-fp8
bash cast_fp4_to_fp8.sh
```
2. 启动对话推理
```bash
#注意将脚本中对应的路径及参数设置成用户实际值
sh start_torch_fp8.sh
```
## 效果展示
**注意**:首次对话时由于kernel编译,可能会出现dtk hipcc编译警告,属于正常现象
......
#!/bin/bash
# Cast the original FP4 checkpoint to FP8 and shard it for model parallelism.
# EXPERTS: total number of routed experts in the model.
export EXPERTS=256
# MP: model-parallel degree; adjust to the actual number of cards.
export MP=8
# --hf-ckpt-path: path of the original model; --save-path: output directory
# for the sharded FP8 checkpoint (one model{i}-mp{MP}.safetensors per rank).
python convert.py --hf-ckpt-path deepseek-ai/DeepSeek-V4-Flash \
--save-path /path/of/DeepSeek-V4-Flash-FP8-MP8 \
--n-experts ${EXPERTS} \
--model-parallel ${MP} \
--expert-dtype fp8
{
"vocab_size": 129280,
"dim": 4096,
"moe_inter_dim": 2048,
"n_layers": 43,
"n_hash_layers": 3,
"n_heads": 64,
"n_routed_experts": 256,
"n_shared_experts": 1,
"n_activated_experts": 6,
"score_func": "sqrtsoftplus",
"route_scale": 1.5,
"swiglu_limit": 10.0,
"q_lora_rank": 1024,
"head_dim": 512,
"rope_head_dim": 64,
"o_groups": 8,
"o_lora_rank": 1024,
"window_size": 128,
"original_seq_len": 65536,
"rope_theta": 10000,
"rope_factor": 16,
"beta_fast": 32,
"beta_slow": 1,
"index_n_heads": 64,
"index_head_dim": 128,
"index_topk": 512,
"hc_mult": 4,
"hc_sinkhorn_iters": 20,
"dtype": "fp8",
"scale_fmt": "ue8m0",
"compress_rope_theta": 160000,
"compress_ratios": [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0]
}
\ No newline at end of file
import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange
import torch
from safetensors.torch import safe_open, save_file
# Lookup table decoding the 16 possible e2m1fn (FP4) bit patterns to float:
# codes 0-7 are the positive values, codes 8-15 their negated mirrors
# (the sign bit is the top bit of the 4-bit code).
FP4_TABLE = torch.tensor([
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
    0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0
], dtype=torch.float32)
def cast_e2m1fn_to_e4m3fn(x: torch.Tensor, scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Casts a tensor from e2m1fn to e4m3fn losslessly.

    Args:
        x: packed FP4 weights viewed as int8, shape (out_dim, in_dim // 2);
           each byte holds two e2m1fn codes (low nibble first).
        scale: per-32-element FP4 scales, shape (out_dim, in_dim // 32).

    Returns:
        (weight_fp8, scale_e8m0): weight of shape (out_dim, in_dim) in
        float8_e4m3fn, and one power-of-two scale per 128x128 block in
        float8_e8m0fnu.

    Losslessness relies on the FP4 scales being powers of two (e8m0-style):
    within each 128x128 block, values are rescaled by scale / block_max_scale,
    a power-of-two ratio bounded so that 6.0 * 2^6 = 384 still fits in e4m3fn.
    """
    assert x.dtype == torch.int8
    assert x.ndim == 2
    out_dim, in_dim = x.size()
    in_dim *= 2
    fp8_block_size = 128
    fp4_block_size = 32
    assert in_dim % fp8_block_size == 0 and out_dim % fp8_block_size == 0
    assert scale.size(0) == out_dim and scale.size(1) == in_dim // fp4_block_size
    # Unpack the two FP4 codes per byte; interleave (low, high) to restore
    # the original element order along the input dimension.
    x = x.view(torch.uint8)
    low = x & 0x0F
    high = (x >> 4) & 0x0F
    x = torch.stack([FP4_TABLE[low.long()], FP4_TABLE[high.long()]], dim=-1).flatten(2)
    # max_fp4 (6.0) * MAX_OFFSET must fit in e4m3fn (max 448)
    # 6.0 * 2^6 = 384 < 448; 6.0 * 2^7 = 768 > 448; so MAX_OFFSET_BITS = 6
    MAX_OFFSET_BITS = 6
    bOut = out_dim // fp8_block_size
    bIn = in_dim // fp8_block_size
    # bOut, bIn, 128, 128
    x = x.view(bOut, fp8_block_size, bIn, fp8_block_size).transpose(1, 2)
    # bOut, bIn, 128*4  (four 32-wide FP4 scale groups per 128-wide FP8 block)
    scale = scale.float().view(bOut, fp8_block_size, bIn, -1).transpose(1, 2).flatten(2)
    ## bOut, bIn, 1  -- one shared scale per 128x128 block
    scale_max_offset_bits = scale.amax(dim=-1, keepdim=True) / (2**MAX_OFFSET_BITS)
    # bOut, bIn, 128*4 -- per-group power-of-two multiplier relative to block scale
    offset = scale / scale_max_offset_bits
    # bOut, bIn, 128, 128
    offset = offset.unflatten(-1, (fp8_block_size, -1)).repeat_interleave(fp4_block_size, dim=-1)
    x = (x * offset).transpose(1, 2).reshape(out_dim, in_dim)
    return x.to(torch.float8_e4m3fn), scale_max_offset_bits.squeeze(-1).to(torch.float8_e8m0fnu)
# Maps a HF checkpoint key fragment to (renamed_key, shard_dim), where
# shard_dim is the weight dimension split across model-parallel ranks
# (None = parameter replicated on every rank).
mapping = {
    "embed_tokens": ("embed", 0),
    "input_layernorm": ("attn_norm", None),
    "post_attention_layernorm": ("ffn_norm", None),
    "q_proj": ("wq", 0),
    "q_a_proj": ("wq_a", None),
    "q_a_layernorm": ("q_norm", None),
    "q_b_proj": ("wq_b", 0),
    "kv_a_proj_with_mqa": ("wkv_a", None),
    "kv_a_layernorm": ("kv_norm", None),
    "kv_b_proj": ("wkv_b", 0),
    "o_proj": ("wo", 1),
    "gate_proj": ("w1", 0),
    "down_proj": ("w2", 1),
    "up_proj": ("w3", 0),
    "lm_head": ("head", 0),
    "embed": ("embed", 0),
    "wq_b": ("wq_b", 0),
    "wo_a": ("wo_a", 0),
    "wo_b": ("wo_b", 1),
    "head": ("head", 0),
    "attn_sink": ("attn_sink", 0),
    "weights_proj": ("weights_proj", 0),
}
def main(hf_ckpt_path, save_path, n_experts, mp, expert_dtype):
    """
    Converts and saves model checkpoint files into a specified format.

    Args:
        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
        save_path (str): Path to the directory where the converted checkpoint files will be saved.
        n_experts (int): Total number of experts in the model.
        mp (int): Model parallelism factor.
        expert_dtype (str | None): "fp8" to cast packed FP4 expert weights to
            float8_e4m3fn; anything else keeps them as float4_e2m1fn_x2.

    Returns:
        None
    """
    torch.set_num_threads(8)
    n_local_experts = n_experts // mp
    state_dicts = [{} for _ in range(mp)]
    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for name in f.keys():
                param: torch.Tensor = f.get_tensor(name)
                # Strip the HF "model." prefix and skip MTP embeddings/head,
                # which are not needed for inference.
                if name.startswith("model."):
                    name = name[len("model."):]
                if name.startswith("mtp.") and ("emb" in name or name.endswith("head.weight")):
                    continue
                # Normalize HF naming to this repo's parameter names.
                name = name.replace("self_attn", "attn")
                name = name.replace("mlp", "ffn")
                name = name.replace("weight_scale_inv", "scale")
                name = name.replace("e_score_correction_bias", "bias")
                if any(x in name for x in ["hc", "attn_sink", "tie2eid", "ape"]): # without .weight
                    key = name.split(".")[-1]
                else:
                    key = name.split(".")[-2]
                if key in mapping:
                    new_key, dim = mapping[key]
                else:
                    new_key, dim = key, None
                name = name.replace(key, new_key)
                for i in range(mp):
                    new_param = param
                    if "experts" in name and "shared_experts" not in name:
                        # Routed experts are distributed whole: rank i keeps
                        # experts [i*n_local, (i+1)*n_local).
                        idx = int(name.split(".")[-3])
                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
                            continue
                    elif dim is not None:
                        # Dense weights are sliced along their shard dimension.
                        assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
                        shard_size = param.size(dim) // mp
                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
                    state_dicts[i][name] = new_param
    os.makedirs(save_path, exist_ok=True)
    for i in trange(mp):
        # Snapshot the keys: entries are popped/replaced while iterating.
        names = list(state_dicts[i].keys())
        for name in names:
            if name.endswith("wo_a.weight"):
                # Dequantize wo_a to bf16: multiply each 128x128 block by its scale.
                weight = state_dicts[i][name]
                scale = state_dicts[i].pop(name.replace("weight", "scale"))
                weight = weight.unflatten(0, (-1, 128)).unflatten(-1, (-1, 128)).float() * scale[:, None, :, None].float()
                state_dicts[i][name] = weight.flatten(2, 3).flatten(0, 1).bfloat16()
            elif "experts" in name and state_dicts[i][name].dtype == torch.int8:
                # int8 here means packed FP4 expert weights (two codes per byte).
                if expert_dtype == "fp8":
                    scale_name = name.replace("weight", "scale")
                    weight = state_dicts[i].pop(name)
                    scale = state_dicts[i].pop(scale_name)
                    state_dicts[i][name], state_dicts[i][scale_name] = cast_e2m1fn_to_e4m3fn(weight, scale)
                else:
                    state_dicts[i][name] = state_dicts[i][name].view(torch.float4_e2m1fn_x2)
        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
    for file in ["tokenizer.json", "tokenizer_config.json"]:
        old_file_path = os.path.join(hf_ckpt_path, file)
        new_file_path = os.path.join(save_path, file)
        if os.path.exists(old_file_path):
            shutil.copyfile(old_file_path, new_file_path)
if __name__ == "__main__":
    # CLI entry point: parse conversion/sharding options and run the converter.
    cli = ArgumentParser()
    cli.add_argument("--hf-ckpt-path", type=str, required=True)
    cli.add_argument("--save-path", type=str, required=True)
    cli.add_argument("--n-experts", type=int, required=True)
    cli.add_argument("--model-parallel", type=int, required=True)
    cli.add_argument("--expert-dtype", type=str, choices=["fp8", "fp4"], required=False, default=None)
    opts = cli.parse_args()
    assert opts.n_experts % opts.model_parallel == 0, "Number of experts must be divisible by model parallelism"
    main(opts.hf_ckpt_path, opts.save_path, opts.n_experts, opts.model_parallel, opts.expert_dtype)
import os
import json
import sys
from argparse import ArgumentParser
from typing import List
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model
from model import Transformer, ModelArgs
from encoding_dsv4 import encode_messages, parse_message_from_completion_text
def sample(logits, temperature: float = 1.0):
    """Draw token ids from `logits` using the Gumbel-max trick.

    Dividing the softmax probabilities by i.i.d. Exponential(1) noise and
    taking the argmax is distributed like torch.multinomial but avoids its
    GPU-to-CPU synchronization. Temperature is floored at 1e-5."""
    scaled = logits / max(temperature, 1e-5)
    probs = torch.softmax(scaled, dim=-1, dtype=torch.float32)
    noise = torch.empty_like(probs).exponential_(1)
    return probs.div_(noise).argmax(dim=-1)
@torch.inference_mode()
def generate(
    model: Transformer,
    prompt_tokens: List[List[int]],
    max_new_tokens: int,
    eos_id: int,
    temperature: float = 1.0
) -> List[List[int]]:
    """Batch generation over variable-length prompts.

    Prompts are placed left-aligned in a (-1)-padded token buffer. The first
    forward pass processes tokens up to the shortest prompt (prefill phase);
    subsequent passes generate one token at a time (decode phase). For
    positions still inside a longer prompt, the ground-truth prompt token
    overrides the model's prediction. Generation stops early once every
    sequence has emitted `eos_id` outside its prompt.

    Returns, per prompt, at most `max_new_tokens` completion tokens truncated
    at (and including) the first eos.
    """
    prompt_lens = [len(t) for t in prompt_tokens]
    assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
    total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long)
    for i, t in enumerate(prompt_tokens):
        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long)
    prev_pos = 0
    finished = torch.tensor([False] * len(prompt_tokens))
    # -1 marks padding, so this mask is True exactly on prompt positions.
    prompt_mask = tokens != -1
    for cur_pos in range(min(prompt_lens), total_len):
        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if temperature > 0:
            next_token = sample(logits, temperature)
        else:
            next_token = logits.argmax(dim=-1)
        # Inside a prompt, keep the ground-truth token instead of the prediction.
        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
        tokens[:, cur_pos] = next_token
        finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
        prev_pos = cur_pos
        if finished.all():
            break
    completion_tokens = []
    for i, toks in enumerate(tokens.tolist()):
        toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
        if eos_id in toks:
            # Truncate after the first eos but keep the eos token itself.
            toks = toks[:toks.index(eos_id)]
            toks.append(eos_id)
        completion_tokens.append(toks)
    return completion_tokens
def main(
    ckpt_path: str,
    config: str,
    input_file: str = "",
    interactive: bool = True,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
) -> None:
    """Load the sharded checkpoint and run either an interactive chat REPL or
    batch completion over prompts read from `input_file` (separated by blank
    lines). Distributed setup comes from WORLD_SIZE/RANK/LOCAL_RANK env vars;
    only rank 0 prints and reads stdin, broadcasting prompts to other ranks.
    """
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    if world_size > 1:
        dist.init_process_group("nccl")
    # Silence every rank except rank 0 by shadowing print.
    global print
    if rank != 0:
        print = lambda *_, **__: None
    torch.cuda.set_device(local_rank)
    torch.cuda.memory._set_allocator_settings("expandable_segments:True")
    torch.set_default_dtype(torch.bfloat16)
    torch.set_num_threads(8)
    torch.manual_seed(33377335)
    with open(config) as f:
        args = ModelArgs(**json.load(f))
    if interactive:
        args.max_batch_size = 1
    print(args)
    with torch.device("cuda"):
        model = Transformer(args)
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    print("load model")
    # Each rank loads only its own shard (written by convert.py as model{i}-mp{N}).
    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"), strict=False)
    torch.set_default_device("cuda")
    print("I'm DeepSeek 👋")
    if interactive:
        messages = []
        while True:
            # Rank 0 owns stdin; other ranks receive the prompt via broadcast.
            if world_size == 1:
                prompt = input(">>> ")
            elif rank == 0:
                prompt = input(">>> ")
                objects = [prompt]
                dist.broadcast_object_list(objects, 0)
            else:
                objects = [None]
                dist.broadcast_object_list(objects, 0)
                prompt = objects[0]
            if prompt == "/exit":
                break
            elif prompt == "/clear":
                messages.clear()
                continue
            messages.append({"role": "user", "content": prompt})
            prompt_tokens = tokenizer.encode(encode_messages(messages, thinking_mode="chat"))
            completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
            completion = tokenizer.decode(completion_tokens[0])
            print(completion)
            messages.append(parse_message_from_completion_text(completion, thinking_mode="chat"))
    else:
        # Batch mode: one single-turn conversation per blank-line-separated prompt.
        with open(input_file) as f:
            prompts = f.read().split("\n\n")
        prompt_tokens = [tokenizer.encode(encode_messages([{"role": "user", "content": prompt}], thinking_mode="chat")) for prompt in prompts]
        completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
        completions = tokenizer.batch_decode(completion_tokens)
        for prompt, completion in zip(prompts, completions):
            print("Prompt:", prompt)
            print("Completion:", completion)
            print()
    if world_size > 1:
        dist.destroy_process_group()
if __name__ == "__main__":
    # CLI entry point: batch mode needs --input-file, chat mode needs --interactive.
    cli = ArgumentParser()
    cli.add_argument("--ckpt-path", type=str, required=True)
    cli.add_argument("--config", type=str, required=True)
    cli.add_argument("--input-file", type=str, default="")
    cli.add_argument("--interactive", action="store_true")
    cli.add_argument("--max-new-tokens", type=int, default=300)
    cli.add_argument("--temperature", type=float, default=0.6)
    opts = cli.parse_args()
    assert opts.input_file or opts.interactive, "Either input-file or interactive mode must be specified"
    main(opts.ckpt_path, opts.config, opts.input_file, opts.interactive, opts.max_new_tokens, opts.temperature)
import torch
import tilelang
import tilelang.language as T
from typing import Tuple, Optional
# Only warnings and above from tilelang; INFO is noisy during JIT compiles.
tilelang.set_log_level("WARNING")

# Compiler pass tweaks applied to every @tilelang.jit kernel in this file.
pass_configs = {
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,
}

# Short dtype-name aliases used in kernel signatures.
FP8 = "float8_e4m3"
FP4 = "float4_e2m1fn"
# FE8M0 = "float8_e8m0fnu"
BF16 = "bfloat16"
FP32 = "float32"
INT32 = "int32"
def fast_log2_ceil(x):
    """Compute ceil(log2(x)) via IEEE 754 bit manipulation. Avoids slow log/ceil intrinsics.

    Extracts the biased fp32 exponent; a nonzero mantissa means x lies
    strictly between two powers of two, so 1 is added. Assumes x is a
    positive, finite float32 value (not checked).
    """
    bits_x = T.reinterpret("uint32", x)
    # Biased exponent field: bits 23..30 of the fp32 representation.
    exp_x = (bits_x >> 23) & 0xFF
    # Low 23 bits = mantissa; nonzero => not an exact power of two.
    man_bits = bits_x & ((1 << 23) - 1)
    return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
def fast_pow2(x):
    """Compute 2^x for integer x via IEEE 754 bit manipulation.

    Builds the fp32 bit pattern directly: biased exponent x + 127, zero
    mantissa. Assumes -127 < x < 128 so the exponent field does not overflow.
    """
    bits_x = (x + 127) << 23
    return T.reinterpret("float32", bits_x)
def fast_round_scale(amax, fp8_max_inv):
    # Round the quantization scale amax / fp8_max UP to the next power of two,
    # as required by MXFP-style scale formats (e.g. ue8m0) that store only an
    # exponent.
    return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(
    N, block_size=128, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32,
    round_scale=False, inplace=False
):
    """Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16.

    Builds a tilelang kernel over an (M, N) input, quantizing each row in
    `block_size`-wide groups:
      X: (M, N) input in `in_dtype`
      Y: (M, N) output; FP8 normally, or `in_dtype` when inplace (fake quant)
      S: (M, ceil(N / block_size)) per-group scales in `scale_dtype`
    With round_scale=True, scales are rounded up to powers of two.
    """
    M = T.symbolic("M")
    # float8_e4m3fn representable range.
    fp8_min = -448.0
    fp8_max = 448.0
    fp8_max_inv = 1 / fp8_max
    num_stages = 0 if round_scale or inplace else 2
    blk_m = 32
    group_size = block_size
    # Internal computation in FP32; scale_dtype controls output storage format.
    compute_dtype = FP32
    out_dtype = in_dtype if inplace else out_dtype

    @T.prim_func
    def act_quant_kernel_(
        X: T.Tensor[(M, N), in_dtype],
        Y: T.Tensor[(M, N), out_dtype],
        S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
    ):
        with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
            pid_m,
            pid_n,
        ):
            x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
            x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
            amax_local = T.alloc_fragment((blk_m,), compute_dtype)
            s_local = T.alloc_fragment((blk_m,), compute_dtype)
            y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
            y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
            for _ in T.Pipelined(1, num_stages=num_stages):
                T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
                T.copy(x_shared, x_local)
                T.reduce_absmax(x_local, amax_local, dim=1)
                for i in T.Parallel(blk_m):
                    # Clamp amax away from zero so the scale stays finite.
                    amax_local[i] = T.max(amax_local[i], 1e-4)
                    if round_scale:
                        s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
                    else:
                        s_local[i] = amax_local[i] * fp8_max_inv
                if inplace:
                    # Fused quant+dequant: divide, clamp, cast, multiply back.
                    # NOTE(review): the inner T.Cast uses out_dtype, which when
                    # inplace equals in_dtype (BF16) — so the round-trip never
                    # passes through FP8 here; confirm whether the inner cast
                    # was meant to be FP8.
                    for i, j in T.Parallel(blk_m, group_size):
                        y_local[i, j] = T.Cast(
                            out_dtype,
                            T.Cast(compute_dtype, T.Cast(out_dtype, T.clamp(
                                x_local[i, j] / s_local[i], fp8_min, fp8_max
                            ))) * s_local[i],
                        )
                else:
                    for i, j in T.Parallel(blk_m, group_size):
                        y_local[i, j] = T.clamp(
                            x_local[i, j] / s_local[i], fp8_min, fp8_max
                        )
                for i in T.Parallel(blk_m):
                    S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
                T.copy(y_local, y_shared)
                T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])

    return act_quant_kernel_
def act_quant(
    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None,
    scale_dtype: torch.dtype = torch.float32, inplace: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
    """Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16.
    When scale_fmt is set, scales are rounded to power-of-2 (MXFP).

    Returns (quantized, scales) normally, or `x` itself (overwritten with the
    fake-quantized values) when inplace=True.

    NOTE(review): the kernel always uses an FP32 scale view (`tl_dtype`),
    while `s` is allocated with `scale_dtype` — confirm behavior if a non-fp32
    scale_dtype (e.g. float8_e8m0fnu) is ever passed.
    """
    N = x.size(-1)
    assert N % block_size == 0
    # tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
    tl_dtype = FP32
    z = x.contiguous()
    y = torch.empty_like(z) if inplace else torch.empty_like(z, dtype=torch.float8_e4m3fn)
    s = z.new_empty(*z.size()[:-1], N // block_size, dtype=scale_dtype)
    kernel = act_quant_kernel(
        N, block_size, scale_dtype=tl_dtype,
        round_scale=scale_fmt is not None, inplace=inplace,
    )
    kernel(z.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
    if inplace:
        x.copy_(y)
        return x
    return y, s
@tilelang.jit(pass_configs=pass_configs)
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
    """Build a blockwise-scaled FP8 GEMM kernel: C[M, N] = A[M, K] @ B[N, K]^T.

    A carries one scale per (row, 128-wide K group); B carries one scale per
    128x128 tile. Each K tile's partial product is scale-corrected into a
    separate accumulator, so raw FP8 products never accumulate across tiles.
    """
    assert out_dtype in [BF16, FP32]
    M = T.symbolic("M")
    group_size = 128
    block_M = 32
    block_N = 128
    block_K = 128

    @T.prim_func
    def fp8_gemm_kernel_(
        A: T.Tensor[(M, K), FP8],
        B: T.Tensor[(N, K), FP8],
        C: T.Tensor[(M, N), out_dtype],
        scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), scale_dtype],
        scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), scale_dtype],
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):
            A_shared = T.alloc_shared((block_M, block_K), FP8)
            B_shared = T.alloc_shared((block_N, block_K), FP8)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
            Scale_C_shared = T.alloc_shared((block_M), FP32)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
            # Improve L2 Cache
            T.use_swizzle(panel_size=10)
            T.clear(C_local)
            T.clear(C_local_accum)
            K_iters = T.ceildiv(K, block_K)
            for k in T.Pipelined(K_iters, num_stages=2):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                # Cast scales to FP32 for computation; scales_b has one value per block_N group
                Scale_B = T.Cast(FP32, scales_b[bx * block_N // group_size, k])
                for i in T.Parallel(block_M):
                    Scale_C_shared[i] = T.Cast(FP32, scales_a[by * block_M + i, k]) * Scale_B
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
                # Separate accumulator for scale-corrected results (2x accumulation precision)
                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
                # Reset the raw-product accumulator for the next K tile.
                T.clear(C_local)
            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return fp8_gemm_kernel_
def fp8_gemm(
    a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
    scale_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """C[M,N] = A[M,K] @ B[N,K]^T with per-128 block FP8 scaling on both A and B.

    Args:
        a: FP8 activations of shape (..., K), contiguous.
        a_s: per-(row, 128-wide K group) activation scales.
        b: FP8 weights of shape (N, K), contiguous.
        b_s: per-(128x128 tile) weight scales.
        scale_dtype: dtype the caller stores scales in (kept for interface
            compatibility; the kernel always computes with FP32 scales).

    Returns:
        Tensor of shape (..., N) in the current torch default dtype.
    """
    assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
    assert a_s.is_contiguous() and b_s.is_contiguous(), (
        "Scaling factor tensors must be contiguous"
    )
    # Always view scales as FP32 inside the kernel. The previous code selected
    # FE8M0 for float8_e8m0fnu scales, but that alias is commented out at
    # module level, which made this branch raise NameError; act_quant already
    # pins its kernel scale view to FP32 in the same way.
    tl_dtype = FP32
    K = a.size(-1)
    M = a.numel() // K
    N = b.size(0)
    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
    kernel = fp8_gemm_kernel(N, K, scale_dtype=tl_dtype)
    kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
    return c
@tilelang.jit(pass_configs=pass_configs)
def sparse_attn_kernel(h_orig: int, d: int, scale=None):
    """Sparse multi-head attention via index gathering + online softmax (FlashAttention-style).

    For each (batch, seq_pos), gathers top-k KV positions by index, computes
    attention with a numerically stable running max/sum, and adds a learnable
    per-head attn_sink bias to the softmax denominator. When h_orig > 16, the
    heads are split into chunks of 16 handled by separate program instances
    (REPLICATE_H), each covering absolute heads [H0, H0 + h).
    """
    b = T.symbolic("b")
    m = T.symbolic("m")
    n = T.symbolic("n")
    topk = T.symbolic("topk")
    if scale is None:
        scale = (1.0 / d) ** 0.5
    num_stages = 0
    threads = 256
    block = 32
    num_blocks = tilelang.cdiv(topk, block)
    padded_H = max(tilelang.math.next_power_of_2(h_orig), 16)
    max_block_m = 16
    if h_orig > max_block_m:
        assert h_orig % max_block_m == 0, f"h should be a multiple of {max_block_m}"
        REPLICATE_H = h_orig // max_block_m
    else:
        REPLICATE_H = 1
    h = padded_H if REPLICATE_H == 1 else max_block_m

    @T.prim_func
    def sparse_attn_kernel_(
        q: T.Tensor[(b, m, h_orig, d), BF16],
        kv: T.Tensor[(b, n, d), BF16],
        o: T.Tensor[(b, m, h_orig, d), BF16],
        attn_sink: T.Tensor[(h_orig,), FP32],
        topk_idxs: T.Tensor[(b, m, topk), INT32],
    ):
        with T.Kernel(m * REPLICATE_H, b, threads=threads) as (bx, by):
            q_shared = T.alloc_fragment((h, d), BF16)
            kv_shared = T.alloc_shared((block, d), BF16)
            acc_s_cast = T.alloc_shared((h, block), BF16)
            idxs = T.alloc_fragment(block, INT32)
            acc_s = T.alloc_fragment((h, block), FP32)
            acc_o = T.alloc_fragment((h, d), FP32)
            scores_max = T.alloc_fragment(h, FP32)
            scores_max_prev = T.alloc_fragment(h, FP32)
            scores_scale = T.alloc_fragment(h, FP32)
            scores_sum = T.alloc_fragment(h, FP32)
            sum_exp = T.alloc_fragment(h, FP32)
            T.clear(acc_o)
            T.clear(sum_exp)
            T.fill(scores_max, -T.infinity(FP32))
            # This instance handles sequence position s_i, heads [H0, H1).
            s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
            H0 = (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * h)
            H1 = H0 + h
            T.copy(q[by, s_i, H0:H1, :], q_shared)
            for t in T.Pipelined(num_blocks, num_stages=num_stages):
                # Gather this tile's KV indices; -1 marks out-of-range (masked) slots.
                for i in T.Parallel(block):
                    idxs[i] = T.if_then_else(t * block + i < topk, topk_idxs[by, s_i, t * block + i], -1)
                for i, j in T.Parallel(block, d):
                    kv_shared[i, j] = T.if_then_else(idxs[i] != -1, kv[by, idxs[i], j], 0)
                # Seed masked columns with -inf so they vanish after exp().
                for i, j in T.Parallel(h, block):
                    acc_s[i, j] = T.if_then_else(idxs[j] != -1, 0, -T.infinity(FP32))
                T.gemm(q_shared, kv_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                for i, j in T.Parallel(h, block):
                    acc_s[i, j] *= scale
                # Online softmax: rescale previous accumulators by exp(old_max - new_max).
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(h):
                    scores_scale[i] = T.exp(scores_max_prev[i] - scores_max[i])
                for i, j in T.Parallel(h, block):
                    acc_s[i, j] = T.exp(acc_s[i, j] - scores_max[i])
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(h):
                    sum_exp[i] = sum_exp[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(h, d):
                    acc_o[i, j] *= scores_scale[i]
                T.gemm(acc_s_cast, kv_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
            # Sink logit contributes only to the denominator. Index by the
            # ABSOLUTE head (H0 + i): previously this read attn_sink[i], which
            # reused the first chunk's sinks for every head chunk when
            # REPLICATE_H > 1 (q/o are offset by H0, so attn_sink must be too).
            for i in T.Parallel(h):
                sum_exp[i] += T.exp(attn_sink[H0 + i] - scores_max[i])
            for i, j in T.Parallel(h, d):
                acc_o[i, j] /= sum_exp[i]
            o_shared = T.alloc_shared((h, d), BF16)
            T.copy(acc_o, o_shared)
            T.copy(o_shared, o[by, s_i, H0:H1, :])

    return sparse_attn_kernel_
def sparse_attn(
    q: torch.Tensor, kv: torch.Tensor, attn_sink: torch.Tensor, topk_idxs: torch.Tensor, softmax_scale: float
) -> torch.Tensor:
    """Host wrapper for the sparse-attention kernel.

    Pads the head dimension up to 16 (zero query heads and zero sink logits)
    for kernel efficiency, runs the kernel, then strips the padded heads from
    the output."""
    batch, seqlen, n_heads, head_dim = q.size()
    padded = n_heads < 16
    if padded:
        extra = 16 - n_heads
        q = torch.cat([q, q.new_zeros(batch, seqlen, extra, head_dim)], dim=2)
        attn_sink = torch.cat([attn_sink, attn_sink.new_zeros(extra)])
    out = torch.empty_like(q)
    kernel = sparse_attn_kernel(q.size(2), head_dim, softmax_scale)
    kernel(q, kv, out, attn_sink, topk_idxs)
    return out.narrow(2, 0, n_heads).contiguous() if padded else out
@tilelang.jit(pass_configs=pass_configs)
def hc_split_sinkhorn_kernel(hc: int, sinkhorn_iters: int, eps: float):
    """Fused kernel for splitting hyper-connection (hc) mix logits.

    For each of n rows, the (2 + hc) * hc logits in `mixes` are split into:
      pre  (hc,): sigmoid gate (+ eps), using hc_scale[0] / hc_base[0:hc]
      post (hc,): 2 * sigmoid gate, using hc_scale[1] / hc_base[hc:2*hc]
      comb (hc, hc): row-softmax (+ eps) followed by `sinkhorn_iters` rounds of
        alternating row/column normalization (Sinkhorn iteration) toward a
        doubly stochastic matrix; eps stabilizes every division.
    """
    n = T.symbolic("n")
    mix_hc = (2 + hc) * hc
    threads = 64

    @T.prim_func
    def hc_split_sinkhorn_kernel_(
        mixes: T.Tensor[(n, mix_hc), FP32],
        hc_scale: T.Tensor[(3,), FP32],
        hc_base: T.Tensor[(mix_hc,), FP32],
        pre: T.Tensor[(n, hc), FP32],
        post: T.Tensor[(n, hc), FP32],
        comb: T.Tensor[(n, hc, hc), FP32],
    ):
        with T.Kernel(n, threads=threads) as i:
            mixes_shared = T.alloc_shared(mix_hc, FP32)
            comb_frag = T.alloc_fragment((hc, hc), FP32)
            T.copy(mixes[i, :], mixes_shared)
            for j in T.Parallel(hc):
                pre[i, j] = T.sigmoid(mixes_shared[j] * hc_scale[0] + hc_base[j]) + eps
            for j in T.Parallel(hc):
                post[i, j] = 2 * T.sigmoid(mixes_shared[j + hc] * hc_scale[1] + hc_base[j + hc])
            # Remaining hc*hc logits form the combination matrix.
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = mixes_shared[j * hc + k + hc * 2] * hc_scale[2] + hc_base[j * hc + k + hc * 2]
            row_sum = T.alloc_fragment(hc, FP32)
            col_sum = T.alloc_fragment(hc, FP32)
            # comb = comb.softmax(-1) + eps  (max-subtracted for stability)
            row_max = T.alloc_fragment(hc, FP32)
            T.reduce_max(comb_frag, row_max, dim=1)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = T.exp(comb_frag[j, k] - row_max[j])
            T.reduce_sum(comb_frag, row_sum, dim=1)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = comb_frag[j, k] / row_sum[j] + eps
            # comb = comb / (comb.sum(-2) + eps)
            T.reduce_sum(comb_frag, col_sum, dim=0)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
            # Remaining Sinkhorn rounds (the softmax above counts as round 1's
            # row normalization).
            for _ in T.serial(sinkhorn_iters - 1):
                # comb = comb / (comb.sum(-1) + eps)
                T.reduce_sum(comb_frag, row_sum, dim=1)
                for j, k in T.Parallel(hc, hc):
                    comb_frag[j, k] = comb_frag[j, k] / (row_sum[j] + eps)
                # comb = comb / (comb.sum(-2) + eps)
                T.reduce_sum(comb_frag, col_sum, dim=0)
                for j, k in T.Parallel(hc, hc):
                    comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
            T.copy(comb_frag, comb[i, :, :])

    return hc_split_sinkhorn_kernel_
def hc_split_sinkhorn(mixes: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, hc_mult: int = 4, sinkhorn_iters: int = 20, eps: float = 1e-6):
    """Host wrapper: split (b, s, (2+hc_mult)*hc_mult) mix logits into
    pre (b, s, hc), post (b, s, hc) and the Sinkhorn-normalized combination
    matrix comb (b, s, hc, hc) via the fused kernel."""
    batch, seqlen, _ = mixes.size()
    pre = mixes.new_empty(batch, seqlen, hc_mult)
    post = mixes.new_empty(batch, seqlen, hc_mult)
    comb = mixes.new_empty(batch, seqlen, hc_mult, hc_mult)
    kernel = hc_split_sinkhorn_kernel(hc_mult, sinkhorn_iters, eps)
    flat_mixes = mixes.view(-1, (2 + hc_mult) * hc_mult)
    kernel(
        flat_mixes, hc_scale, hc_base,
        pre.view(-1, hc_mult), post.view(-1, hc_mult), comb.view(-1, hc_mult, hc_mult),
    )
    return pre, post, comb
import math
from dataclasses import dataclass
from typing import Tuple, Optional, Literal
from functools import lru_cache
from contextlib import contextmanager
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
from kernel import act_quant, fp8_gemm, sparse_attn, hc_split_sinkhorn
try:
from scipy.linalg import hadamard
except ImportError:
hadamard = None
# Tensor-parallel runtime globals. NOTE(review): presumably reassigned at
# startup by the model/launcher code (not visible in this chunk) — confirm
# before relying on the defaults in a distributed run.
world_size = 1
rank = 0
# Quantization group sizes: FP8 scales cover 128 elements, FP4 scales cover 32.
block_size = 128
fp4_block_size = 32
# Defaults for Linear weight dtype and activation-quantization scale handling.
default_dtype = torch.bfloat16
scale_fmt = None
scale_dtype = torch.float32
@contextmanager
def set_dtype(dtype):
    """Temporarily switch torch's default dtype to `dtype` for the duration of
    the `with` block, restoring the previous default on exit even when the
    body raises."""
    saved = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(saved)
@dataclass
class ModelArgs:
    """Model hyperparameters. Field names match the config JSON keys.

    Defaults here describe a small debug configuration; production values come
    from the checkpoint's config file.
    """
    max_batch_size: int = 4
    max_seq_len: int = 4096
    # dtype: main weight precision; expert_dtype overrides it for MoE experts.
    dtype: Literal["bf16", "fp8"] = "fp8"
    # scale_fmt "ue8m0" rounds activation-quant scales to powers of two (MXFP).
    scale_fmt: Literal[None, "ue8m0"] = None
    expert_dtype: Literal[None, "fp4", "fp8"] = None
    scale_dtype: Literal["fp32", "fp8"] = "fp32"
    vocab_size: int = 129280
    dim: int = 4096
    moe_inter_dim: int = 4096
    n_layers: int = 7
    n_hash_layers: int = 0
    n_mtp_layers: int = 1
    n_heads: int = 64
    # moe
    n_routed_experts: int = 8
    n_shared_experts: int = 1
    n_activated_experts: int = 2
    score_func: Literal["softmax", "sigmoid", "sqrtsoftplus"] = "sqrtsoftplus"
    route_scale: float = 1.
    swiglu_limit: float = 0.
    # mqa
    q_lora_rank: int = 1024
    head_dim: int = 512
    rope_head_dim: int = 64
    norm_eps: float = 1e-6
    o_groups: int = 8
    o_lora_rank: int = 1024
    window_size: int = 128
    compress_ratios: Tuple[int] = (0, 0, 4, 128, 4, 128, 4, 0)
    # yarn
    compress_rope_theta: float = 40000.0
    original_seq_len: int = 0
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    # index
    index_n_heads: int = 64
    index_head_dim: int = 128
    index_topk: int = 512
    # hc
    hc_mult: int = 4
    hc_sinkhorn_iters: int = 20
    hc_eps: float = 1e-6
class ParallelEmbedding(nn.Module):
    """Vocab-sharded embedding: each rank owns vocab_size // world_size
    consecutive rows. Tokens outside this rank's slice produce zeros, and an
    all_reduce sums the partial embeddings across ranks."""

    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
        self.part_vocab_size = vocab_size // world_size
        # This rank covers token ids [vocab_start_idx, vocab_end_idx).
        self.vocab_start_idx = rank * self.part_vocab_size
        self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out_of_range = None
        if world_size > 1:
            out_of_range = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
            # Shift to local row indices; clamp foreign tokens to row 0 so the
            # lookup stays in bounds (their rows are zeroed below).
            x = x - self.vocab_start_idx
            x[out_of_range] = 0
        y = F.embedding(x, self.weight)
        if world_size > 1:
            y[out_of_range] = 0
            dist.all_reduce(y)
        return y
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Dispatches to fp4_gemm / fp8_gemm / F.linear based on weight dtype.
    For quantized weights, x is first quantized to FP8 via act_quant.

    `bias` must be None (asserted); the quantized GEMMs take no bias term.

    NOTE(review): `fp4_gemm` is not among the names imported from `kernel` at
    the top of this file, so the FP4 branch would raise NameError if taken —
    confirm whether FP4 inference is supported in this build.
    """
    assert bias is None
    if weight.dtype == torch.float4_e2m1fn_x2:
        x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
        return fp4_gemm(x, s, weight, weight.scale, scale_dtype)
    elif weight.dtype == torch.float8_e4m3fn:
        x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
        return fp8_gemm(x, s, weight, weight.scale, scale_dtype)
    else:
        return F.linear(x, weight)
class Linear(nn.Module):
    """Linear layer supporting BF16, FP8, and FP4 weight formats with per-block scaling.

    For quantized dtypes, the scale parameter is also attached to the weight
    as `weight.scale` so `linear()` can reach it through the weight alone.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        dtype = dtype or default_dtype
        if dtype == torch.float4_e2m1fn_x2:
            # FP4: weight is [out, in//2] in float4_e2m1fn_x2, logically [out, in] in fp4
            # Scale is [out, in//32] in float8_e8m0fnu (1 scale per 32 fp4 elements along K)
            self.weight = nn.Parameter(torch.empty(out_features, in_features // 2, dtype=torch.float4_e2m1fn_x2))
            scale_out_features = out_features
            scale_in_features = in_features // fp4_block_size
            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
        elif dtype == torch.float8_e4m3fn:
            # FP8: one scale per 128x128 weight block.
            # NOTE(review): the FP8-converted expert checkpoints store their
            # scales as float8_e8m0fnu (see the converter), while this
            # parameter is fp32 — confirm the load path converts dtypes.
            self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
            scale_out_features = (out_features + block_size - 1) // block_size
            scale_in_features = (in_features + block_size - 1) // block_size
            # self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
        else:
            # Unquantized (e.g. bf16) weight: no scale parameter.
            self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
            self.register_parameter("scale", None)
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.bias)
class ColumnParallelLinear(Linear):
    """Shards output dim across TP ranks. No all-reduce needed on output.

    Each rank holds out_features // world_size rows of the weight; the output
    stays sharded along its last dimension.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
        self.part_out_features = out_features // world_size
        super().__init__(in_features, self.part_out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Identical to Linear.forward (no communication needed).
        return linear(x, self.weight, self.bias)
class RowParallelLinear(Linear):
    """Row-parallel linear layer.

    The input dimension is split evenly across tensor-parallel ranks; each rank
    computes a partial matmul over its input slice, and the partial results are
    summed with an all-reduce before the bias is applied.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
        part = in_features // world_size
        self.part_in_features = part
        super().__init__(part, out_features, bias, dtype)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        partial = linear(x, self.weight, None)
        if world_size > 1:
            # Accumulate partial products in fp32 before reducing across ranks.
            partial = partial.float()
            dist.all_reduce(partial)
        if self.bias is not None:
            partial += self.bias
        return partial.type_as(x)
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: x / rms(x) * weight, computed in fp32."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # rmsnorm in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient.
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
    def forward(self, x: torch.Tensor):
        in_dtype = x.dtype
        xf = x.float()
        # Normalize by the RMS over the last dimension, then apply the learned gain.
        inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = xf * inv_rms
        return (self.weight * normed).to(in_dtype)
@lru_cache(2)
def precompute_freqs_cis(dim, seqlen, original_seq_len, base, factor, beta_fast, beta_slow) -> torch.Tensor:
    """Precomputes complex exponentials for rotary embeddings with YaRN scaling.
    When original_seq_len > 0, applies frequency interpolation with a smooth
    linear ramp between beta_fast and beta_slow correction ranges.

    Returns a complex tensor of shape [seqlen, dim // 2] on the default device.
    Cache size 2 presumably covers the two rope configurations used by
    Attention (compressed vs pure sliding-window) -- confirm if more are added.
    """
    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        # Dimension index whose rotary frequency completes `num_rotations` full
        # rotations over max_seq_len (inverse of the RoPE wavelength formula).
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        # Integer dimension range [low, high] over which to blend scaled and
        # unscaled frequencies, clamped to valid dimension indices.
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim-1)
    def linear_ramp_factor(min, max, dim):
        # 0 below `min`, 1 above `max`, linear in between.
        if min == max:
            max += 0.001  # avoid division by zero when the range collapses
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if original_seq_len > 0:
        # YaRN: interpolate (divide by `factor`) the low-frequency dims while
        # leaving high-frequency dims untouched, blended by the smooth ramp.
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth
    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    # Unit-magnitude complex numbers e^{i * t * freq}.
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, inverse: bool = False) -> torch.Tensor:
    """Applies rotary positional embeddings in-place. Uses conjugate for inverse (de-rotation).

    x: [b, s, d] or [b, s, h, d] with even last dim; modified in place and returned.
    freqs_cis: complex [s, d // 2] slice matching x's sequence positions.
    """
    y = x  # keep a handle on the original storage for the in-place copy below
    # Pair consecutive channels into complex numbers: [..., d] -> [..., d/2] complex.
    x = torch.view_as_complex(x.float().unflatten(-1, (-1, 2)))
    if inverse:
        # The conjugate rotates by the negative angle, undoing a prior rotation.
        freqs_cis = freqs_cis.conj()
    if x.ndim == 3:
        # [b, s, d/2]: broadcast over batch only.
        freqs_cis = freqs_cis.view(1, x.size(1), x.size(-1))
    else:
        # [b, s, h, d/2]: broadcast over batch and heads.
        freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    x = torch.view_as_real(x * freqs_cis).flatten(-2)
    y.copy_(x)  # write the rotated values back into the caller's tensor (may downcast)
    return y
def hadamard_transform_ref(x, scale=1.0):
    """Reference Hadamard transform along the last dimension.

    x: (..., dim)
    out: (..., dim) -- input padded to the next power of two, multiplied by the
    Hadamard matrix, scaled, then truncated back to `dim`.
    """
    if hadamard is None:
        raise ImportError("Please install scipy")
    orig_shape = x.shape
    n = orig_shape[-1]
    flat = x.reshape(-1, n)
    # Hadamard matrices exist for power-of-two sizes; pad up if needed.
    padded_n = 1 << math.ceil(math.log2(n))
    if padded_n != n:
        flat = F.pad(flat, (0, padded_n - n))
    h_mat = torch.tensor(hadamard(padded_n, dtype=float), dtype=x.dtype, device=x.device)
    transformed = F.linear(flat, h_mat) * scale
    return transformed[..., :n].reshape(*orig_shape)
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    """Applies randomized Hadamard rotation to spread information across dims before FP8 quant."""
    # Only bf16 activations are expected on this path.
    assert x.dtype == torch.bfloat16
    # from fast_hadamard_transform import hadamard_transform
    # NOTE(review): the fast_hadamard_transform kernel above is a faster drop-in;
    # the scipy-based reference is used here. Scale keeps the transform orthonormal.
    return hadamard_transform_ref(x, scale=x.size(-1) ** -0.5)
@lru_cache(1)
def get_window_topk_idxs(window_size: int, bsz: int, seqlen: int, start_pos: int):
    """Builds per-query indices into the sliding-window KV ring buffer.

    Returns a LongTensor expanded to [bsz, q, window_size] (shared storage);
    -1 marks masked (unfilled or future) slots. Cached (size 1) so repeated
    calls with identical arguments reuse the same tensor.
    """
    if start_pos >= window_size - 1:
        # Steady-state decode: the ring buffer is full. List slots from the
        # oldest entry up to the slot just written at start_pos % window_size.
        start_pos %= window_size
        matrix = torch.cat([torch.arange(start_pos + 1, window_size), torch.arange(0, start_pos + 1)], dim=0)
    elif start_pos > 0:
        # Early decode: only slots 0..start_pos are populated; pad the rest with -1.
        matrix = F.pad(torch.arange(start_pos + 1), (0, window_size - start_pos - 1), value=-1)
    else:
        # Prefill: row i attends to positions [i - window_size + 1, i] clamped at 0;
        # future positions (index > i) are masked with -1.
        base = torch.arange(seqlen).unsqueeze(1)
        matrix = (base - window_size + 1).clamp(0) + torch.arange(min(seqlen, window_size))
        matrix = torch.where(matrix > base, -1, matrix)
    return matrix.unsqueeze(0).expand(bsz, -1, -1)
@lru_cache(2)
def get_compress_topk_idxs(ratio: int, bsz: int, seqlen: int, start_pos: int, offset: int):
    """Builds per-query indices into the compressed-KV region of the cache.

    `offset` shifts indices past the sliding-window region. A query at position
    p may see its first (p + 1) // ratio compressed blocks; -1 marks masked
    entries. Returns [bsz, q, n_blocks] (expanded view, shared storage).
    """
    if start_pos > 0:
        # Decode: every fully-formed compressed block so far is visible.
        matrix = torch.arange(0, (start_pos + 1) // ratio) + offset
    else:
        # Prefill: row-wise causal mask over compressed block indices.
        matrix = torch.arange(seqlen // ratio).repeat(seqlen, 1)
        mask = matrix >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
        matrix = torch.where(mask, -1, matrix + offset)
    return matrix.unsqueeze(0).expand(bsz, -1, -1)
class Compressor(nn.Module):
    """Compresses KV cache via learned gated pooling over `compress_ratio` consecutive tokens.
    When overlap=True (ratio==4), uses overlapping windows for smoother compression boundaries."""
    def __init__(self, args: ModelArgs, compress_ratio: int = 4, head_dim: int = 512, rotate: bool = False):
        super().__init__()
        self.dim = args.dim
        self.head_dim = head_dim
        self.rope_head_dim = args.rope_head_dim
        self.nope_head_dim = head_dim - args.rope_head_dim
        self.compress_ratio = compress_ratio
        # Overlapping compression is only enabled at ratio 4.
        self.overlap = compress_ratio == 4
        # When True, apply the Hadamard rotation to compressed KV (used by Indexer).
        self.rotate = rotate
        # coff = 2 with overlap (two half-windows of dims), else 1.
        coff = 1 + self.overlap
        # Learned additive position embedding over positions inside a window.
        self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
        # wkv and wgate in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient.
        # When overlap, the first half of dims is for overlapping compression, second half for normal.
        self.wkv = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
        self.wgate = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
        self.norm = RMSNorm(self.head_dim, args.norm_eps)
        self.kv_cache: torch.Tensor = None # assigned lazily from Attention.kv_cache
        # State buffers for decode-phase incremental compression.
        # With overlap: state[:, :ratio] = overlapping window, state[:, ratio:] = current window.
        self.register_buffer("kv_state", torch.zeros(args.max_batch_size, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32), persistent=False)
        self.register_buffer("score_state", torch.full((args.max_batch_size, coff * compress_ratio, coff * self.head_dim), float("-inf"), dtype=torch.float32), persistent=False)
        self.freqs_cis: torch.Tensor = None  # assigned lazily by the owning module
    def overlap_transform(self, tensor: torch.Tensor, value=0):
        """Rearranges [b,s,r,2d] windows so each output window of 2r entries pairs the
        previous window's overlap half with the current window's normal half.
        `value` fills entries that have no predecessor (the first window)."""
        # tensor: [b,s,r,2d]
        b, s, _, _ = tensor.size()
        ratio, d = self.compress_ratio, self.head_dim
        new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
        new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
        new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
        return new_tensor
    def forward(self, x: torch.Tensor, start_pos: int):
        """Compresses `x` into the compressed-KV cache.

        Prefill (start_pos == 0): pools every full window of `ratio` tokens and
        stashes the trailing partial window in the state buffers.
        Decode: accumulates one token per call and emits a compressed entry
        every `ratio` steps. Returns the newly compressed KV, or None when no
        window completed on this call.
        """
        assert self.kv_cache is not None
        bsz, seqlen, _ = x.size()
        ratio, overlap, d, rd = self.compress_ratio, self.overlap, self.head_dim, self.rope_head_dim
        dtype = x.dtype
        # compression need fp32
        x = x.float()
        kv = self.wkv(x)
        score = self.wgate(x)
        if start_pos == 0:
            should_compress = seqlen >= ratio
            remainder = seqlen % ratio
            cutoff = seqlen - remainder
            offset = ratio if overlap else 0
            if overlap and cutoff >= ratio:
                # Seed the overlap half of the state with the last full window.
                self.kv_state[:bsz, :ratio] = kv[:, cutoff-ratio : cutoff]
                self.score_state[:bsz, :ratio] = score[:, cutoff-ratio : cutoff] + self.ape
            if remainder > 0:
                # Stash the trailing partial window for later decode steps.
                kv, self.kv_state[:bsz, offset : offset+remainder] = kv.split([cutoff, remainder], dim=1)
                self.score_state[:bsz, offset : offset+remainder] = score[:, cutoff:] + self.ape[:remainder]
                score = score[:, :cutoff]
            kv = kv.unflatten(1, (-1, ratio))
            score = score.unflatten(1, (-1, ratio)) + self.ape
            if overlap:
                kv = self.overlap_transform(kv, 0)
                score = self.overlap_transform(score, float("-inf"))
            # Gated pooling: softmax over the window dimension weights each token.
            kv = (kv * score.softmax(dim=2)).sum(dim=2)
        else:
            should_compress = (start_pos + 1) % self.compress_ratio == 0
            score += self.ape[start_pos % ratio]
            if overlap:
                self.kv_state[:bsz, ratio + start_pos % ratio] = kv.squeeze(1)
                self.score_state[:bsz, ratio + start_pos % ratio] = score.squeeze(1)
                if should_compress:
                    # Combine the overlap half (first ratio slots, first d dims)
                    # with the current window (last ratio slots, last d dims).
                    kv_state = torch.cat([self.kv_state[:bsz, :ratio, :d], self.kv_state[:bsz, ratio:, d:]], dim=1)
                    score_state = torch.cat([self.score_state[:bsz, :ratio, :d], self.score_state[:bsz, ratio:, d:]], dim=1)
                    kv = (kv_state * score_state.softmax(dim=1)).sum(dim=1, keepdim=True)
                    # The current window becomes the next block's overlap window.
                    self.kv_state[:bsz, :ratio] = self.kv_state[:bsz, ratio:]
                    self.score_state[:bsz, :ratio] = self.score_state[:bsz, ratio:]
            else:
                self.kv_state[:bsz, start_pos % ratio] = kv.squeeze(1)
                self.score_state[:bsz, start_pos % ratio] = score.squeeze(1)
                if should_compress:
                    kv = (self.kv_state[:bsz] * self.score_state[:bsz].softmax(dim=1)).sum(dim=1, keepdim=True)
        if not should_compress:
            return
        kv = self.norm(kv.to(dtype))
        if start_pos == 0:
            # One rotary position per compressed block (stride = ratio).
            freqs_cis = self.freqs_cis[:cutoff:ratio]
        else:
            freqs_cis = self.freqs_cis[start_pos + 1 - self.compress_ratio].unsqueeze(0)
        apply_rotary_emb(kv[..., -rd:], freqs_cis)
        if self.rotate:
            kv = rotate_activation(kv)
            # fp4_act_quant(kv, fp4_block_size, True)
        else:
            # FP8-simulate the non-rope dims in place; rope dims keep full precision.
            act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
        if start_pos == 0:
            self.kv_cache[:bsz, :seqlen // ratio] = kv
        else:
            self.kv_cache[:bsz, start_pos // ratio] = kv.squeeze(1)
        return kv
class Indexer(torch.nn.Module):
    """Selects top-k compressed KV positions for sparse attention via learned scoring.
    Has its own Compressor (with Hadamard rotation) to build compressed KV for scoring."""
    def __init__(self, args: ModelArgs, compress_ratio: int = 4):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.index_n_heads
        self.n_local_heads = args.index_n_heads // world_size
        self.head_dim = args.index_head_dim
        self.rope_head_dim = args.rope_head_dim
        self.index_topk = args.index_topk
        self.q_lora_rank = args.q_lora_rank
        self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
        self.weights_proj = ColumnParallelLinear(self.dim, self.n_heads, dtype=torch.bfloat16)
        self.softmax_scale = self.head_dim ** -0.5
        self.compress_ratio = compress_ratio
        # Dedicated compressor with Hadamard rotation enabled (rotate=True).
        self.compressor = Compressor(args, compress_ratio, self.head_dim, True)
        self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len // compress_ratio, self.head_dim), persistent=False)
        self.freqs_cis = None  # assigned by Attention before first use
    def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, offset: int):
        """Scores compressed KV blocks per query and returns top-k block indices.

        x: [b, s, dim] hidden states; qr: [b, s, q_lora_rank] normalized low-rank query.
        offset: shift so indices point into the compressed region of the KV cache.
        Returns [b, s, k] indices; -1 marks masked (future) blocks during prefill.
        """
        bsz, seqlen, _ = x.size()
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
        ratio = self.compress_ratio
        rd = self.rope_head_dim
        end_pos = start_pos + seqlen
        if self.compressor.kv_cache is None:
            # Lazy wiring: the compressor writes straight into this module's cache.
            self.compressor.kv_cache = self.kv_cache
            self.compressor.freqs_cis = self.freqs_cis
        q = self.wq_b(qr)
        q = q.unflatten(-1, (self.n_local_heads, self.head_dim))
        apply_rotary_emb(q[..., -rd:], freqs_cis)
        q = rotate_activation(q)
        # use fp4 simulation for q and kv in indexer
        # fp4_act_quant(q, fp4_block_size, True)
        self.compressor(x, start_pos)
        weights = self.weights_proj(x) * (self.softmax_scale * self.n_heads ** -0.5)
        # We performed QAT here, kv could also use fp8 format, though current implementation uses bf16
        index_score = torch.einsum("bshd,btd->bsht", q, self.kv_cache[:bsz, :end_pos // ratio])
        # ReLU then a head-weighted sum collapses the head dimension.
        index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2)
        if world_size > 1:
            # Index heads are sharded across ranks; sum the partial scores.
            dist.all_reduce(index_score)
        if start_pos == 0:
            # Causal mask over compressed blocks during prefill.
            mask = torch.arange(seqlen // ratio).repeat(seqlen, 1) >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
            index_score += torch.where(mask, float("-inf"), 0)
        topk_idxs = index_score.topk(min(self.index_topk, end_pos // ratio), dim=-1)[1]
        if start_pos == 0:
            # Re-mask: topk may still pick -inf entries when few blocks exist.
            mask = topk_idxs >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
            topk_idxs = torch.where(mask, -1, topk_idxs + offset)
        else:
            topk_idxs += offset
        return topk_idxs
class Attention(nn.Module):
    """Multi-head Latent Attention (MLA) with sliding window + optional KV compression.
    Uses low-rank Q projection (wq_a -> q_norm -> wq_b) and grouped low-rank O projection."""
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size
        self.q_lora_rank = args.q_lora_rank
        self.o_lora_rank = args.o_lora_rank
        self.head_dim = args.head_dim
        self.rope_head_dim = args.rope_head_dim
        self.nope_head_dim = args.head_dim - args.rope_head_dim
        self.n_groups = args.o_groups
        self.n_local_groups = self.n_groups // world_size
        self.window_size = args.window_size
        # Per-layer setting: 0/None disables compression; ratio 4 also enables the Indexer.
        self.compress_ratio = args.compress_ratios[layer_id]
        self.eps = args.norm_eps
        # Learned per-head term passed to sparse_attn (presumably an attention-sink
        # logit -- confirm against the sparse_attn kernel).
        self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32))
        self.wq_a = Linear(self.dim, self.q_lora_rank)
        self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
        self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
        self.wkv = Linear(self.dim, self.head_dim)
        self.kv_norm = RMSNorm(self.head_dim, self.eps)
        self.wo_a = ColumnParallelLinear(self.n_heads * self.head_dim // self.n_groups, self.n_groups * args.o_lora_rank, dtype=torch.bfloat16)
        self.wo_b = RowParallelLinear(self.n_groups * args.o_lora_rank, self.dim)
        self.softmax_scale = self.head_dim ** -0.5
        if self.compress_ratio:
            self.compressor = Compressor(args, self.compress_ratio, self.head_dim)
            if self.compress_ratio == 4:
                self.indexer = Indexer(args, self.compress_ratio)
            else:
                self.indexer = None
        # Cache layout: [0, window_size) = sliding-window ring buffer,
        # [window_size, ...) = compressed entries (when compression is enabled).
        kv_cache_size = args.window_size + (args.max_seq_len // self.compress_ratio if self.compress_ratio else 0)
        self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, kv_cache_size, self.head_dim), persistent=False)
        if self.compress_ratio:
            original_seq_len, rope_theta = args.original_seq_len, args.compress_rope_theta
        else:
            # disable YaRN and use base rope_theta in pure sliding-window attention
            original_seq_len, rope_theta = 0, args.rope_theta
        freqs_cis = precompute_freqs_cis(self.rope_head_dim, args.max_seq_len, original_seq_len,
                                         rope_theta, args.rope_factor, args.beta_fast, args.beta_slow)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
    def forward(self, x: torch.Tensor, start_pos: int):
        """Attends over the sliding window plus selected compressed KV entries.

        x: [b, s, dim]; start_pos == 0 means prefill, > 0 means one-token decode.
        Returns the attention output projected back to [b, s, dim].
        """
        bsz, seqlen, _ = x.size()
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
        win = self.window_size
        ratio = self.compress_ratio
        rd = self.rope_head_dim
        if self.compress_ratio and self.compressor.kv_cache is None:
            # Lazy wiring: compressed entries live after the window region of kv_cache.
            self.compressor.kv_cache = self.kv_cache[:, win:]
            self.compressor.freqs_cis = self.freqs_cis
            if self.indexer is not None:
                self.indexer.freqs_cis = self.freqs_cis
        # q
        qr = q = self.q_norm(self.wq_a(x))
        q = self.wq_b(q).unflatten(-1, (self.n_local_heads, self.head_dim))
        # Per-head RMS normalization of q (no learned gain).
        q *= torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps)
        apply_rotary_emb(q[..., -rd:], freqs_cis)
        # win kv & topk_idxs
        kv = self.wkv(x)
        kv = self.kv_norm(kv)
        apply_rotary_emb(kv[..., -rd:], freqs_cis)
        # FP8-simulate non-rope dims to match QAT; rope dims stay bf16 for positional precision
        act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
        topk_idxs = get_window_topk_idxs(win, bsz, seqlen, start_pos)
        if self.compress_ratio:
            # Compressed indices start after the in-flight kv (prefill, where kv
            # and compressed kv are concatenated) or after the window region (decode).
            offset = kv.size(1) if start_pos == 0 else win
            if self.indexer is not None:
                compress_topk_idxs = self.indexer(x, qr, start_pos, offset)
            else:
                compress_topk_idxs = get_compress_topk_idxs(ratio, bsz, seqlen, start_pos, offset)
            topk_idxs = torch.cat([topk_idxs, compress_topk_idxs], dim=-1)
        topk_idxs = topk_idxs.int()
        # compress kv & attn
        if start_pos == 0:
            # Prefill: write the last `win` tokens into the ring buffer so that
            # slot index == position % win.
            if seqlen <= win:
                self.kv_cache[:bsz, :seqlen] = kv
            else:
                cutoff = seqlen % win
                self.kv_cache[:bsz, cutoff: win], self.kv_cache[:bsz, :cutoff] = kv[:, -win:].split([win - cutoff, cutoff], dim=1)
            if self.compress_ratio:
                if (kv_compress := self.compressor(x, start_pos)) is not None:
                    kv = torch.cat([kv, kv_compress], dim=1)
            # We performed QAT here, kv could also use fp8 format, though current implementation uses bf16
            o = sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale)
        else:
            self.kv_cache[:bsz, start_pos % win] = kv.squeeze(1)
            if self.compress_ratio:
                self.compressor(x, start_pos)
            o = sparse_attn(q, self.kv_cache[:bsz], self.attn_sink, topk_idxs, self.softmax_scale)
        # De-rotate the output's rope dims (inverse rotation).
        apply_rotary_emb(o[..., -rd:], freqs_cis, True)
        # o
        o = o.view(bsz, seqlen, self.n_local_groups, -1)
        wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1)
        # NOTE: wo_a is FP8 in checkpoint; could do FP8 einsum here for better perf,
        # but using BF16 for simplicity.
        o = torch.einsum("bsgd,grd->bsgr", o, wo_a)
        x = self.wo_b(o.flatten(2))
        return x
class Gate(nn.Module):
    """MoE gating: computes expert routing scores and selects top-k experts.
    Supports hash-based routing (first n_hash_layers) where expert indices are
    predetermined per token ID, and score-based routing (remaining layers)."""
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.score_func = args.score_func
        self.route_scale = args.route_scale
        # Early layers route by token id via a lookup table instead of scores.
        self.hash = layer_id < args.n_hash_layers
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        if self.hash:
            # tid2eid: predetermined expert ids per token id, [vocab_size, topk].
            # NOTE(review): torch.gather below requires int64 indices but tid2eid
            # is declared int32 -- confirm indices are cast before gather or that
            # this path is exercised with a compatible dtype.
            self.tid2eid = nn.Parameter(torch.empty(args.vocab_size, args.n_activated_experts, dtype=torch.int32), requires_grad=False)
            self.bias = None
        else:
            self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32))
    def forward(self, x: torch.Tensor, input_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns (weights, indices): routing weights and expert ids, each [t, topk].

        x: [t, dim] flattened tokens; input_ids: [t] token ids (used for hash routing).
        """
        scores = linear(x.float(), self.weight.float())
        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1)
        elif self.score_func == "sigmoid":
            scores = scores.sigmoid()
        else:
            # Fallback score function: sqrt(softplus(s)).
            scores = F.softplus(scores).sqrt()
        original_scores = scores
        # Bias shifts scores for expert selection (topk) but does not affect routing weights.
        if self.bias is not None:
            scores = scores + self.bias
        if self.hash:
            indices = self.tid2eid[input_ids]
        else:
            indices = scores.topk(self.topk, dim=-1)[1]
        weights = original_scores.gather(1, indices)
        if self.score_func != "softmax":
            # Renormalize so the selected weights sum to 1 (softmax already does).
            weights /= weights.sum(dim=-1, keepdim=True)
        weights *= self.route_scale
        return weights, indices
class Expert(nn.Module):
    """SwiGLU feed-forward expert (w1 gate, w3 up, w2 down); gate/up math runs in fp32.

    A positive `swiglu_limit` clamps the up projection to [-limit, limit] and
    the gate projection to at most `limit` before the SiLU gating.
    """
    def __init__(self, dim: int, inter_dim: int, dtype=None, swiglu_limit=0):
        super().__init__()
        self.w1, self.w2, self.w3 = (
            Linear(dim, inter_dim, dtype=dtype),
            Linear(inter_dim, dim, dtype=dtype),
            Linear(dim, inter_dim, dtype=dtype),
        )
        self.swiglu_limit = swiglu_limit
    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the SwiGLU FFN; `weights` optionally scales the gated hidden state."""
        in_dtype = x.dtype
        gate, up = self.w1(x).float(), self.w3(x).float()
        limit = self.swiglu_limit
        if limit > 0:
            up = torch.clamp(up, min=-limit, max=limit)
            gate = torch.clamp(gate, max=limit)
        hidden = F.silu(gate) * up
        if weights is not None:
            hidden = weights * hidden
        return self.w2(hidden.to(in_dtype))
class MoE(nn.Module):
    """Mixture-of-Experts: gate routes each token to top-k routed experts + 1 shared expert.
    Experts are sharded across TP ranks; each rank handles n_routed_experts // world_size experts."""
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id
        self.dim = args.dim
        assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
        self.n_routed_experts = args.n_routed_experts
        self.n_local_experts = args.n_routed_experts // world_size
        self.n_activated_experts = args.n_activated_experts
        # This rank owns experts [experts_start_idx, experts_end_idx).
        self.experts_start_idx = rank * self.n_local_experts
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
        self.gate = Gate(layer_id, args)
        # Map the config string to the torch dtype used for routed-expert weights;
        # None falls through to the module-level default dtype.
        expert_dtype = None
        if args.expert_dtype == "fp4":
            expert_dtype = torch.float4_e2m1fn_x2
        elif args.expert_dtype == "fp8":
            expert_dtype = torch.float8_e4m3fn
        # BUG FIX: `expert_dtype` was computed above but ignored -- experts were
        # hard-coded to torch.float8_e4m3fn, so fp4/default configurations still
        # allocated fp8 weights. Pass the selected dtype through instead.
        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim, dtype=expert_dtype, swiglu_limit=args.swiglu_limit) if self.experts_start_idx <= i < self.experts_end_idx else None
                                      for i in range(self.n_routed_experts)])
        assert args.n_shared_experts == 1
        # Shared expert uses the default dtype and no swiglu clamp.
        self.shared_experts = Expert(args.dim, args.moe_inter_dim)
    def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        """Routes tokens to local experts, accumulates in fp32, and all-reduces across ranks.

        x: [..., dim] hidden states; input_ids: matching token ids (for hash routing).
        Returns a tensor of the same shape and dtype as x.
        """
        shape = x.size()
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x, input_ids.flatten())
        # fp32 accumulator: partial expert outputs are summed across ranks.
        y = torch.zeros_like(x, dtype=torch.float32)
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        for i in range(self.experts_start_idx, self.experts_end_idx):
            if counts[i] == 0:
                continue
            expert = self.experts[i]
            # Rows (tokens) routed to expert i and their slot in the topk selection.
            idx, top = torch.where(indices == i)
            y[idx] += expert(x[idx], weights[idx, top, None])
        if world_size > 1:
            dist.all_reduce(y)
        y += self.shared_experts(x)
        return y.type_as(x).view(shape)
class Block(nn.Module):
    """Transformer block with Hyper-Connections (HC) mixing.
    Instead of a simple residual, HC maintains `hc_mult` copies of the hidden state.
    hc_pre: reduces hc copies -> 1 via learned weighted sum (pre-weights from Sinkhorn).
    hc_post: expands 1 -> hc copies via learned post-weights + combination matrix."""
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id
        self.norm_eps = args.norm_eps
        self.attn = Attention(layer_id, args)
        self.ffn = MoE(layer_id, args)
        self.attn_norm = RMSNorm(args.dim, self.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, self.norm_eps)
        self.hc_mult = hc_mult = args.hc_mult
        self.hc_sinkhorn_iters = args.hc_sinkhorn_iters
        self.hc_eps = args.hc_eps
        # (2 + hc) * hc mixing outputs: presumably pre (hc) + post (hc) + comb (hc*hc),
        # matching the split performed by hc_split_sinkhorn -- confirm in the helper.
        mix_hc = (2 + hc_mult) * hc_mult
        hc_dim = hc_mult * args.dim
        # HC parameters are kept in fp32 (set_dtype temporarily overrides the default dtype).
        with set_dtype(torch.float32):
            self.hc_attn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
            self.hc_ffn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
            self.hc_attn_base = nn.Parameter(torch.empty(mix_hc))
            self.hc_ffn_base = nn.Parameter(torch.empty(mix_hc))
            self.hc_attn_scale = nn.Parameter(torch.empty(3))
            self.hc_ffn_scale = nn.Parameter(torch.empty(3))
    def hc_pre(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
        """Reduces the hc copies to a single stream.

        Returns (y, post, comb): the reduced state plus the post-weights and
        combination matrix reused by hc_post after the sublayer runs.
        """
        # x: [b,s,hc,d], hc_fn: [mix_hc,hc*d], hc_scale: [3], hc_base: [mix_hc], y: [b,s,hc,d]
        shape, dtype = x.size(), x.dtype
        x = x.flatten(2).float()
        # RMS-normalized linear mix over the concatenated copies.
        rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
        mixes = F.linear(x, hc_fn) * rsqrt
        pre, post, comb = hc_split_sinkhorn(mixes, hc_scale, hc_base, self.hc_mult, self.hc_sinkhorn_iters, self.hc_eps)
        # Weighted sum over the hc dimension collapses copies to one stream.
        y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
        return y.to(dtype), post, comb
    def hc_post(self, x: torch.Tensor, residual: torch.Tensor, post: torch.Tensor, comb: torch.Tensor):
        """Expands the sublayer output back to hc copies and mixes in the residual copies."""
        # x: [b,s,d], residual: [b,s,hc,d], post: [b,s,hc], comb: [b,s,hc,hc], y: [b,s,hc,d]
        y = post.unsqueeze(-1) * x.unsqueeze(-2) + torch.sum(comb.unsqueeze(-1) * residual.unsqueeze(-2), dim=2)
        return y.type_as(x)
    def forward(self, x: torch.Tensor, start_pos: int, input_ids: Optional[torch.Tensor]) -> torch.Tensor:
        """Pre-norm attention and MoE sublayers, each wrapped by hc_pre / hc_post."""
        residual = x
        x, post, comb = self.hc_pre(x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base)
        x = self.attn_norm(x)
        x = self.attn(x, start_pos)
        x = self.hc_post(x, residual, post, comb)
        residual = x
        x, post, comb = self.hc_pre(x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base)
        x = self.ffn_norm(x)
        x = self.ffn(x, input_ids)
        x = self.hc_post(x, residual, post, comb)
        return x
class ParallelHead(nn.Module):
    """Vocab-parallel LM head with a final Hyper-Connections reduction.

    Each rank scores its own contiguous vocab shard; the per-rank logits are
    all-gathered and concatenated into the full vocabulary.
    """
    def __init__(self, vocab_size: int, dim: int, norm_eps: float = 1e-6, hc_eps: float = 1e-6):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.norm_eps = norm_eps
        self.hc_eps = hc_eps
        self.part_vocab_size = (vocab_size // world_size)
        # lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later.
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim, dtype=torch.float32))
    def get_logits(self, x):
        # Only the last position is scored: generation needs just the next-token logits.
        return F.linear(x[:, -1].float(), self.weight)
    def forward(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, norm: RMSNorm):
        """Reduces the hc copies, normalizes, and returns full-vocab logits for the last position."""
        # x: [b,s,hc,d]
        x = self.hc_head(x, hc_fn, hc_scale, hc_base)
        logits = self.get_logits(norm(x))
        if world_size > 1:
            # Gather every rank's vocab shard and stitch the full distribution.
            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
            dist.all_gather(all_logits, logits)
            logits = torch.cat(all_logits, dim=-1)
        return logits
    def hc_head(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
        """Sigmoid-gated (non-Sinkhorn) reduction of the hc copies to one stream."""
        shape, dtype = x.size(), x.dtype
        x = x.flatten(2).float()
        rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
        mixes = F.linear(x, hc_fn) * rsqrt
        # hc_eps keeps the gate strictly positive.
        pre = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps
        y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
        return y.to(dtype)
class MTPBlock(Block):
    """Multi-Token-Prediction block: fuses the trunk's hidden state with the
    embedding of the current input tokens, runs one HC Block, and emits logits
    through the shared head. `embed` and `head` are wired in by Transformer."""
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__(layer_id, args)
        self.e_proj = Linear(args.dim, args.dim)
        self.h_proj = Linear(args.dim, args.dim)
        self.enorm = RMSNorm(args.dim, args.norm_eps)
        self.hnorm = RMSNorm(args.dim, args.norm_eps)
        self.norm = RMSNorm(args.dim, args.norm_eps)
        self.hc_mult = hc_mult = args.hc_mult
        hc_dim = hc_mult * args.dim
        # Head-reduction HC parameters use the simpler sigmoid gate (see ParallelHead.hc_head).
        with set_dtype(torch.float32):
            self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
            self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
            self.hc_head_scale = nn.Parameter(torch.empty(1))
        # Shared with the main model; assigned externally (see Transformer.__init__).
        self.embed: ParallelEmbedding = None
        self.head: ParallelHead = None
    @torch.inference_mode()
    def forward(self, x: torch.Tensor, start_pos: int, input_ids: torch.Tensor) -> torch.Tensor:
        """Returns next-step logits from (trunk hidden states, current tokens)."""
        # x: [b,s,hc,d]
        assert self.embed is not None and self.head is not None
        e = self.embed(input_ids)
        e = self.enorm(e)
        x = self.hnorm(x)
        # Broadcast the projected embedding across the hc copies.
        x = self.e_proj(e).unsqueeze(2) + self.h_proj(x)
        x = super().forward(x, start_pos, input_ids)
        logits = self.head(x, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
        return logits
class Transformer(nn.Module):
    """Full DeepSeek-V4 model: embed -> HC-expand -> N blocks -> HC-head -> logits.
    Sets global state (world_size, rank, default_dtype, scale_fmt, scale_dtype) in __init__."""
    def __init__(self, args: ModelArgs):
        # NOTE: module-level globals are configured BEFORE super().__init__() /
        # submodule construction so every layer built below sees these settings.
        global world_size, rank, default_dtype, scale_fmt, scale_dtype
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if dist.is_initialized() else 0
        # "fp8" stores weights in float8_e4m3fn; anything else uses bf16.
        default_dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
        scale_fmt = "ue8m0" if args.scale_dtype == "fp8" else args.scale_fmt
        scale_dtype = torch.float8_e8m0fnu if args.scale_dtype == "fp8" else torch.float32
        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.norm_eps = args.norm_eps
        self.hc_eps = args.hc_eps
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        self.norm = RMSNorm(args.dim, self.norm_eps)
        self.head = ParallelHead(args.vocab_size, args.dim, self.norm_eps, self.hc_eps)
        self.mtp = torch.nn.ModuleList()
        for layer_id in range(args.n_mtp_layers):
            self.mtp.append(MTPBlock(args.n_layers + layer_id, args))
            # MTP blocks share the main embedding table and LM head.
            self.mtp[-1].embed = self.embed
            self.mtp[-1].head = self.head
        self.hc_mult = hc_mult = args.hc_mult
        hc_dim = hc_mult * args.dim
        # Final head-reduction HC parameters, kept in fp32.
        with set_dtype(torch.float32):
            self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
            self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
            self.hc_head_scale = nn.Parameter(torch.empty(1))
    @torch.inference_mode()
    def forward(self, input_ids: torch.Tensor, start_pos: int = 0):
        """Returns last-position logits for `input_ids` placed at `start_pos`."""
        h = self.embed(input_ids)
        # Expand to hc_mult copies for Hyper-Connections
        h = h.unsqueeze(2).repeat(1, 1, self.hc_mult, 1)
        for layer in self.layers:
            h = layer(h, start_pos, input_ids)
        logits = self.head(h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
        return logits
if __name__ == "__main__":
    # Smoke test: prefill on random token ids, then decode steps, then one MTP block.
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device("cuda")
    torch.manual_seed(0)
    # n_hash_layers=0 forces score-based routing in every Gate (no tid2eid tables needed).
    args = ModelArgs(n_hash_layers=0)
    x = torch.randint(0, args.vocab_size, (2, 128))
    model = Transformer(args)
    print(model(x).size())
    # Decode: one token per step, continuing from position 128.
    for i in range(128, 150):
        print(i, model(x[:, 0:1], i).size())
    # MTP block expects [b, s, hc_mult, dim] hidden states.
    h = torch.randn(2, 128, args.hc_mult, args.dim)
    mtp = model.mtp[0]
    print(mtp(h, 0, x).size())
    print(mtp(h[:, 0:1], 1, x[:, 0:1]).size())
#!/bin/bash
# Launch interactive generation with the sharded FP8 checkpoint via torchrun.
export NCCL_ALGO=Ring
export NCCL_PROTO=Simple
# MP: model-parallel degree; must equal the number of shards the checkpoint was split into.
export MP=8
export CONFIG=config.json
torchrun --nproc-per-node ${MP} generate.py --ckpt-path /path/of/DeepSeek-V4-Flash-FP8-MP8 --config ${CONFIG} --interactive
# 模型唯一标识
modelCode=2397
# 模型名称
modelName=DeepSeek-V4
# 模型描述
modelDescription=DeepSeek-V4:迈向高效百万上下文智能。
# 运行过程
processType=推理
# 算法类别
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment