Commit 5e61101f authored by luopl

Initial commit
MIT License
Copyright (c) 2023 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
## DeepSeek-V4
## Paper
[DeepSeek-V4: Towards Highly Efficient Million-Token Context Intelligence](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro/blob/main/DeepSeek_V4.pdf)
## Model Overview
This is a preview release of the DeepSeek-V4 series, comprising two powerful Mixture-of-Experts (MoE) language models: DeepSeek-V4-Pro (1.6T total parameters, 49B activated) and DeepSeek-V4-Flash (284B total parameters, 13B activated). Both support a context length of one million tokens.
The DeepSeek-V4 series introduces several key upgrades in architecture and optimization:
- Hybrid attention architecture: a hybrid attention mechanism combining Compressed Sparse Attention (CSA) with Heavily Compressed Attention (HCA), which substantially improves long-context efficiency. At a one-million-token context, DeepSeek-V4-Pro requires only 27% of the per-token inference FLOPs and 10% of the KV cache of DeepSeek-V3.2.
- Manifold-Constrained Hyper-Connections (mHC): mHC is introduced on top of conventional residual connections to stabilize cross-layer signal propagation while preserving the model's expressive power.
- Muon optimizer: the Muon optimizer is adopted for faster convergence and higher training stability.
## Requirements
| Software | Version |
| :------: |:-------:|
| DTK | 26.04 |
| python | 3.10.12 |
| torch | 2.9.0+das.opt1.dtk2604.20260331.g4e3c1e7 |
| tilelang | 0.1.7.post3+cpu.git52700923 |
Recommended image: harbor.sourcefind.cn:5443/dcu/admin/base/custom:torch-2.9.0-ubuntu22.04-dtk26.04-deepseek-v4-0425
- Adjust the `-v` mount paths according to the actual location of your model and code/data:
```bash
docker run -it \
--shm-size 200g \
--network=host \
--name deepseek-v4 \
--privileged \
--device=/dev/kfd \
--device=/dev/dri \
--device=/dev/mkfd \
--group-add video \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
-u root \
-v /opt/hyhal/:/opt/hyhal/:ro \
-v /path/your_code_data/:/path/your_code_data/ \
harbor.sourcefind.cn:5443/dcu/admin/base/custom:torch-2.9.0-ubuntu22.04-dtk26.04-deepseek-v4-0425 bash
```
More images are available for download from [光源](https://sourcefind.cn/#/service-list).
The specialized deep learning libraries required by this project for DCU GPUs can be downloaded from the [光合](https://developer.sourcefind.cn/tool/) developer community.
## Dataset
`N/A`
## Training
`N/A`
## Inference
### pytorch
#### Single-node inference
1. Model conversion and sharding
```bash
# Set the paths and parameters in the script to your actual values.
# INPUT_FP8_HF_PATH: download path of the original model; OUTPUT_BF16_HF_PATH: where to store the bf16 model; SAVE_PATH: where to store the sharded model; mp: adjust according to the actual number of cards.
cd convert_weight
bash convert_weight.sh
```
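For reference, the variables configured inside `convert_weight.sh` (the full script appears later in this repository) follow the pattern below; the paths here are placeholders and must be replaced with your own:
```bash
# Sketch of the variables set in convert_weight.sh (placeholder paths; adjust MP to your card count)
INPUT_FP8_HF_PATH="/path/to/DeepSeek-V4-Flash"        # downloaded FP8/FP4 Hugging Face checkpoint
OUTPUT_BF16_HF_PATH="/path/to/DeepSeek-V4-Flash-bf16" # dequantized bf16 checkpoint
SAVE_PATH="/path/to/DeepSeek-V4-Flash-bf16-mp8"       # sharded checkpoint consumed by generate.py
MP=8                                                  # model-parallel degree (one shard per card)
EXPERTS=256                                           # total number of routed experts
```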
2. Launch chat inference
```bash
cd ../inference
# Set the paths and parameters in the script to your actual values.
sh start_torch.sh
```
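`start_torch.sh` is not reproduced in this document. Judging from the `generate.py` usage documented later in this repository, the command it wraps is presumably a `torchrun` launch along these lines (paths and values are placeholders):
```bash
# Hypothetical equivalent of start_torch.sh, based on the generate.py usage shown below
MP=8                                             # number of cards / model-parallel shards
SAVE_PATH="/path/to/DeepSeek-V4-Flash-bf16-mp8"  # sharded checkpoint produced in step 1
CONFIG="config.json"
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
```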
## Results
<div align=center>
<img src="./doc/result_dcu.png"/>
</div>
### Accuracy
`DCU accuracy is consistent with GPU. Inference framework: PyTorch.`
## Pretrained Weights
| Model | Weight Size | DCU Type | Min. Cards Required | Download |
|:-----:|:----:|:------:|:------:|:----------:|
| DeepSeek-V4-Flash | 158B | BW1100 | 8 | [Hugging Face](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash) |
## Source Repository & Issue Feedback
- https://developer.sourcefind.cn/codes/modelzoo/deepseek-v4
## References
- https://github.com/deepseek-ai
import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange
import torch
from safetensors.torch import safe_open, save_file
FP4_TABLE = torch.tensor([
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0
], dtype=torch.float32)
def cast_e2m1fn_to_e4m3fn(x: torch.Tensor, scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Casts a tensor from e2m1fn to e4m3fn losslessly.
"""
assert x.dtype == torch.int8
assert x.ndim == 2
out_dim, in_dim = x.size()
in_dim *= 2
fp8_block_size = 128
fp4_block_size = 32
assert in_dim % fp8_block_size == 0 and out_dim % fp8_block_size == 0
assert scale.size(0) == out_dim and scale.size(1) == in_dim // fp4_block_size
x = x.view(torch.uint8)
low = x & 0x0F
high = (x >> 4) & 0x0F
x = torch.stack([FP4_TABLE[low.long()], FP4_TABLE[high.long()]], dim=-1).flatten(2)
# max_fp4 (6.0) * MAX_OFFSET must fit in e4m3fn (max 448)
# 6.0 * 2^6 = 384 < 448; 6.0 * 2^7 = 768 > 448; so MAX_OFFSET_BITS = 6
MAX_OFFSET_BITS = 6
bOut = out_dim // fp8_block_size
bIn = in_dim // fp8_block_size
# bOut, bIn, 128, 128
x = x.view(bOut, fp8_block_size, bIn, fp8_block_size).transpose(1, 2)
# bOut, bIn, 128*4
scale = scale.float().view(bOut, fp8_block_size, bIn, -1).transpose(1, 2).flatten(2)
## bOut, bIn, 1
scale_max_offset_bits = scale.amax(dim=-1, keepdim=True) / (2**MAX_OFFSET_BITS)
# bOut, bIn, 128*4
offset = scale / scale_max_offset_bits
# bOut, bIn, 128, 128
offset = offset.unflatten(-1, (fp8_block_size, -1)).repeat_interleave(fp4_block_size, dim=-1)
x = (x * offset).transpose(1, 2).reshape(out_dim, in_dim)
return x.to(torch.float8_e4m3fn), scale_max_offset_bits.squeeze(-1).to(torch.float8_e8m0fnu)
mapping = {
"embed_tokens": ("embed", 0),
"input_layernorm": ("attn_norm", None),
"post_attention_layernorm": ("ffn_norm", None),
"q_proj": ("wq", 0),
"q_a_proj": ("wq_a", None),
"q_a_layernorm": ("q_norm", None),
"q_b_proj": ("wq_b", 0),
"kv_a_proj_with_mqa": ("wkv_a", None),
"kv_a_layernorm": ("kv_norm", None),
"kv_b_proj": ("wkv_b", 0),
"o_proj": ("wo", 1),
"gate_proj": ("w1", 0),
"down_proj": ("w2", 1),
"up_proj": ("w3", 0),
"lm_head": ("head", 0),
"embed": ("embed", 0),
"wq_b": ("wq_b", 0),
"wo_a": ("wo_a", 0),
"wo_b": ("wo_b", 1),
"head": ("head", 0),
"attn_sink": ("attn_sink", 0),
"weights_proj": ("weights_proj", 0),
}
def main(hf_ckpt_path, save_path, n_experts, mp, expert_dtype):
"""
Converts and saves model checkpoint files into a specified format.
Args:
hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
save_path (str): Path to the directory where the converted checkpoint files will be saved.
n_experts (int): Total number of experts in the model.
        mp (int): Model parallelism factor.
        expert_dtype (str, optional): Expert weight dtype ("fp8" or "fp4"); accepted but not used in this function.
Returns:
None
"""
torch.set_num_threads(8)
n_local_experts = n_experts // mp
state_dicts = [{} for _ in range(mp)]
for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
with safe_open(file_path, framework="pt", device="cpu") as f:
for name in f.keys():
param: torch.Tensor = f.get_tensor(name)
if name.startswith("model."):
name = name[len("model."):]
if name.startswith("mtp.") and ("emb" in name or name.endswith("head.weight")):
continue
name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias")
if any(x in name for x in ["hc", "attn_sink", "tie2eid", "ape"]): # without .weight
key = name.split(".")[-1]
else:
key = name.split(".")[-2]
if key in mapping:
new_key, dim = mapping[key]
else:
new_key, dim = key, None
name = name.replace(key, new_key)
for i in range(mp):
new_param = param
if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3])
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
continue
                    elif dim is not None:
                        print(f"Processing parameter {name} with shape {param.shape} for model parallel shard {i}")
                        if "wo_a" not in name and "wo_b" not in name:
                            assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
                            shard_size = param.size(dim) // mp
                            new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
                        else:
                            # wo_a / wo_b are split into mp // num_projection_groups (= 8) shards along `dim`;
                            # each group of num_projection_groups consecutive ranks loads the same shard.
                            num_projection_groups = mp // 8
                            new_mp = mp // num_projection_groups
                            new_i = i // num_projection_groups
                            shard_size = param.size(dim) // new_mp
                            assert shard_size == 1024
                            new_param = param.narrow(dim, new_i * shard_size, shard_size).contiguous()
state_dicts[i][name] = new_param
os.makedirs(save_path, exist_ok=True)
for i in trange(mp):
names = list(state_dicts[i].keys())
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
for file in ["tokenizer.json", "tokenizer_config.json"]:
old_file_path = os.path.join(hf_ckpt_path, file)
new_file_path = os.path.join(save_path, file)
shutil.copyfile(old_file_path, new_file_path)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--hf-ckpt-path", type=str, required=True)
parser.add_argument("--save-path", type=str, required=True)
parser.add_argument("--n-experts", type=int, required=True)
parser.add_argument("--model-parallel", type=int, required=True)
parser.add_argument("--expert-dtype", type=str, choices=["fp8", "fp4"], required=False, default=None)
args = parser.parse_args()
assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel, args.expert_dtype)
"""
DeepSeek-V3.2 FP8/FP4 -> BF16 Converter
Key differences from V3:
- Scale format: ue8m0 (unsigned E8M0, power-of-2 scales, may be stored as uint8)
- weight_dequant uses block-reshape approach (from V3.2 model.py:490)
- DSA (DeepSeek Attention): each layer has `indexer` submodule
- indexer.wq_b, indexer.wk: FP8 (have weight_scale_inv)
- indexer.k_norm (weight+bias), indexer.weights_proj: BF16 (no scale)
- MTP layer (layer 61): enorm, hnorm, eh_proj, shared_head.norm/head, embed_tokens
- No dependency on kernel.py (pure PyTorch)
- 62 layers (0-61), 163 shards, ~92k keys
- Experts weights use FP4 (MXFP4 E2M1, packed 2 per uint8, group_size=32, E8M0 scale)
"""
import os
import json
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm
import torch
from safetensors.torch import load_file, save_file
BLOCK_SIZE = 128
FP4_GROUP_SIZE = 32
# E2M1 FP4 lookup table: 4-bit index -> float value
# Bit layout: sign(1) | exponent(2) | mantissa(1), bias=1
_FP4_E2M1_LUT = torch.tensor([
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
], dtype=torch.bfloat16)
def dequant_fp4_weight(weight_packed: torch.Tensor, scale_fp8: torch.Tensor) -> torch.Tensor:
"""
Dequantize MXFP4 weight to bf16 by unpacking FP4 E2M1 nibbles via LUT.
Each int8 byte stores 2 FP4 values (low nibble + high nibble).
float4_e2m1fn_x2 does NOT support .to(bfloat16), so we manually unpack.
Scale is E8M0, one per group of 32 elements along the last dimension.
Args:
weight_packed: [out_features, in_features/2], int8, each byte = 2 FP4 values
scale_fp8: [out_features, in_features/32], float8_e8m0fnu, E8M0 scale per group of 32
Returns:
bf16 tensor [out_features, in_features]
"""
out_features, packed_in = weight_packed.shape
in_features = packed_in * 2
# Unpack two FP4 values per byte via nibble extraction + LUT
raw = weight_packed.to(torch.uint8)
low_nibble = (raw & 0x0F).to(torch.long)
high_nibble = (raw >> 4).to(torch.long)
lut = _FP4_E2M1_LUT.to(weight_packed.device)
low_vals = lut[low_nibble]
high_vals = lut[high_nibble]
# Interleave: [low_0, high_0, low_1, high_1, ...]
fp4_values = torch.stack([low_vals, high_vals], dim=-1).reshape(out_features, in_features)
# Decode E8M0 scale and expand to match fp4_values
scale = decode_e8m0_scale(scale_fp8)
if scale.dim() == 2 and scale.shape[0] == out_features:
# Scale already shaped [out_features, num_groups_per_row]
num_groups_per_row = scale.shape[1]
else:
# Flat scale: compute groups per row from total count
total_scales = scale.numel()
num_groups_per_row = total_scales // out_features
scale = scale.reshape(out_features, num_groups_per_row)
actual_group_size = in_features // num_groups_per_row
scale = scale.unsqueeze(-1).expand(-1, -1, actual_group_size).reshape(out_features, in_features)
return (fp4_values * scale).to(torch.bfloat16)
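# Worked example (illustrative): a packed byte 0x21 holds low nibble 0x1 and high nibble 0x2,
# which _FP4_E2M1_LUT maps to 0.5 and 1.0. If the E8M0 scale byte for that group is 128,
# i.e. 2^(128 - 127) = 2.0, the dequantized bf16 values become [1.0, 2.0].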
def is_expert_weight(name: str) -> bool:
"""Check if a weight belongs to an expert (MoE) layer, excluding shared_experts."""
return "experts" in name and "shared_experts" not in name
def decode_e8m0_scale(scale: torch.Tensor) -> torch.Tensor:
"""
Decode E8M0 (unsigned 8-bit exponent-only) scale to float32.
E8M0 stores only the exponent: value = 2^(exp - 127), same as IEEE 754 exponent encoding.
If scale is already float32, return as-is.
"""
if scale.dtype == torch.float32:
return scale
if scale.dtype in (torch.bfloat16, torch.float16):
return scale.float()
# float8_e8m0fnu: .to(int32) does value conversion not bit reinterpret,
# so use native .float() which handles E8M0 correctly
if scale.dtype == torch.float8_e8m0fnu:
return scale.float()
# uint8 / int8 raw E8M0 bytes: interpret as IEEE 754 exponent
if scale.element_size() == 1:
# Reconstruct float32 from exponent: set exponent bits in IEEE 754 float32
# float32 = sign(1) + exponent(8) + mantissa(23)
# E8M0 value stored as raw exponent byte -> float = 2^(byte - 127)
exp_bits = scale.to(torch.int32) << 23
return exp_bits.view(torch.float32)
return scale.float()
def weight_dequant(weight: torch.Tensor, scale: torch.Tensor, block_size: int = BLOCK_SIZE) -> torch.Tensor:
"""
Dequantize FP8 weight to BF16 using block-wise scale.
Based on V3.2 model.py:490-495, works on both CPU and CUDA.
Each (block_size x block_size) block of the weight is multiplied by one scale value.
Handles non-aligned dimensions (e.g. kv_a_proj_with_mqa shape 576x7168, 576 % 128 != 0)
by padding to the nearest multiple of block_size, dequantizing, then trimming.
"""
shape = weight.shape
assert weight.dim() == 2, f"Expected 2D weight, got {weight.dim()}D"
M, N = shape
# Decode E8M0 scale if needed
scale = decode_e8m0_scale(scale)
# Pad to nearest multiple of block_size if needed
pad_m = (block_size - M % block_size) % block_size
pad_n = (block_size - N % block_size) % block_size
if pad_m or pad_n:
weight = torch.nn.functional.pad(weight, (0, pad_n, 0, pad_m))
Mp, Np = weight.shape
# V3.2 dequant: reshape into blocks, scale, reshape back
weight = weight.view(
Mp // block_size, block_size,
Np // block_size, block_size
).transpose(1, 2).contiguous().view(-1, block_size * block_size)
weight = (weight.float() * scale.reshape(-1, 1)).to(torch.bfloat16)
weight = weight.view(
Mp // block_size, Np // block_size,
block_size, block_size
).transpose(1, 2).contiguous().view(Mp, Np)
# Trim padding
if pad_m or pad_n:
weight = weight[:M, :N]
return weight
def main(fp8_path, bf16_path, device="cuda"):
torch.set_default_dtype(torch.bfloat16)
if device == "cuda" and not torch.cuda.is_available():
print("CUDA not available, falling back to CPU")
device = "cpu"
os.makedirs(bf16_path, exist_ok=True)
# 1. Copy non-safetensor files (config.json, tokenizer, etc.)
print("Copying auxiliary files...")
for file_path in glob(os.path.join(fp8_path, "*")):
fname = os.path.basename(file_path)
if fname.endswith(".safetensors") or fname == "model.safetensors.index.json":
continue
dst = os.path.join(bf16_path, fname)
if os.path.isfile(file_path):
shutil.copy2(file_path, dst)
print(f" Copied {fname}")
# 2. Load model index
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
model_index = json.load(f)
weight_map = model_index["weight_map"]
# 3. Pre-build scale_inv lookup: weight_name -> scale_inv_name
# V3.2 naming: "xxx.weight" -> "xxx.weight_scale_inv"
scale_inv_map = {}
all_scale_names = set()
for name in weight_map:
if name.endswith("scale"):
weight_name = name[:-len("scale")] + "weight"
if weight_name in weight_map:
all_scale_names.add(name)
scale_inv_map[weight_name] = name
# Separate expert (FP4) vs non-expert (FP8) scale mappings
fp4_scale_map = {k: v for k, v in scale_inv_map.items() if is_expert_weight(k)}
fp8_scale_map = {k: v for k, v in scale_inv_map.items() if not is_expert_weight(k)}
print(f"Model: DeepSeek (DeepseekForCausalLM)")
print(f"Device: {device}")
print(f"Total keys in index: {len(weight_map)}")
print(f"FP4 expert weights with scale: {len(fp4_scale_map)}")
print(f"FP8 non-expert weights with scale: {len(fp8_scale_map)}")
print(f"Scale entries: {len(all_scale_names)}")
# Cache for loaded safetensor files on CPU (handles cross-shard scale lookups)
# All shards are loaded to CPU to avoid GPU OOM; individual tensors are moved to
# GPU for dequant one at a time.
loaded_files = {}
use_cuda = (device == "cuda")
def get_tensor(tensor_name):
"""Load tensor from the correct shard file (CPU), with caching."""
file_name = weight_map.get(tensor_name)
if file_name is None:
raise KeyError(tensor_name)
if file_name not in loaded_files:
file_path = os.path.join(fp8_path, file_name)
loaded_files[file_name] = load_file(file_path, device="cpu")
return loaded_files[file_name][tensor_name]
# 4. Process safetensor files one by one
safetensor_files = sorted(glob(os.path.join(fp8_path, "*.safetensors")))
converted_count = 0
kept_count = 0
for safetensor_file in tqdm(safetensor_files, desc="Converting FP8 -> BF16"):
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cpu")
loaded_files[file_name] = current_state_dict
new_state_dict = {}
for weight_name, weight in current_state_dict.items():
# Skip scale_inv tensors (will be removed from output)
if weight_name in all_scale_names:
continue
if weight.element_size() == 1 and weight_name in scale_inv_map:
scale_inv_name = scale_inv_map[weight_name]
try:
scale_inv = get_tensor(scale_inv_name)
if is_expert_weight(weight_name):
# FP4 expert weight -> dequantize using MXFP4 logic
# print(f" FP4 dequant: {weight_name}, weight={weight.shape}, scale={scale_inv.shape}, dtype={scale_inv.dtype}")
if use_cuda:
result = dequant_fp4_weight(weight.cuda(), scale_inv.cuda())
new_state_dict[weight_name] = result.cpu()
else:
new_state_dict[weight_name] = dequant_fp4_weight(weight, scale_inv)
else:
# FP8 non-expert weight -> dequantize using block-wise scale
if use_cuda:
result = weight_dequant(weight.cuda(), scale_inv.cuda())
new_state_dict[weight_name] = result.cpu()
else:
new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
converted_count += 1
except KeyError:
print(f" Warning: scale '{scale_inv_name}' not loadable for {weight_name}, keeping raw")
new_state_dict[weight_name] = weight
else:
# BF16/FP32 weights: norms, biases, gate.weight, indexer.weights_proj,
# indexer.k_norm, MTP layers (enorm, hnorm, eh_proj, shared_head), embed, lm_head, etc.
new_state_dict[weight_name] = weight
kept_count += 1
# Save converted shard
save_file(new_state_dict, os.path.join(bf16_path, file_name))
# Memory management: keep at most 2 cached shard files on CPU
# (needed for cross-shard scale lookups, e.g. weight in shard N, scale in shard N+1)
while len(loaded_files) > 2:
oldest = next(iter(loaded_files))
del loaded_files[oldest]
# 5. Update model index: remove all scale_inv entries
new_weight_map = {k: v for k, v in weight_map.items() if k not in all_scale_names}
new_index = {
"metadata": model_index.get("metadata", {}),
"weight_map": new_weight_map,
}
with open(os.path.join(bf16_path, "model.safetensors.index.json"), "w") as f:
json.dump(new_index, f, indent=2)
print(f"\nDone!")
print(f" FP8/FP4 -> BF16 converted: {converted_count}")
print(f" Already BF16/FP32 (kept as-is): {kept_count}")
print(f" Scale entries removed: {len(all_scale_names)}")
print(f" Output keys: {len(new_weight_map)} (was {len(weight_map)})")
print(f" Output saved to: {bf16_path}")
if __name__ == "__main__":
parser = ArgumentParser(description="Convert DeepSeek-V3.2 FP8/FP4 checkpoint to BF16")
parser.add_argument("--input-fp8-hf-path", type=str, required=True,
help="Path to the FP8 HuggingFace model directory (DeepSeek-V3.2)")
parser.add_argument("--output-bf16-hf-path", type=str, required=True,
help="Path to the output BF16 model directory")
parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"],
help="Device for dequantization (default: cuda)")
args = parser.parse_args()
main(args.input_fp8_hf_path, args.output_bf16_hf_path, args.device)
#!/usr/bin/env bash
set -euo pipefail
# Step 1: fp4/fp8 -> bf16
INPUT_FP8_HF_PATH="deepseek-ai/DeepSeek-V4-Flash"
OUTPUT_BF16_HF_PATH="deepseek-ai/DeepSeek-V4-Flash-bf16"
python3 convert_weight.py \
--input-fp8-hf-path "${INPUT_FP8_HF_PATH}" \
--output-bf16-hf-path "${OUTPUT_BF16_HF_PATH}"
# Step 2: bf16 -> sharded bf16 for model parallelism
MP=8
HF_CKPT_BF16_PATH="${OUTPUT_BF16_HF_PATH}"
SAVE_PATH="deepseek-ai/DeepSeek-V4-Flash-bf16-mp16"
EXPERTS=256
python3 convert.py \
--hf-ckpt-path "${HF_CKPT_BF16_PATH}" \
--save-path "${SAVE_PATH}" \
--n-experts "${EXPERTS}" \
--model-parallel "${MP}"
# Inference code for DeepSeek models
First, convert the Hugging Face model weight files to the format used by this project.
```bash
export EXPERTS=256
export MP=4
export CONFIG=config.json
python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
```
Then chat with the DeepSeek model at will!
```bash
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
```
Or run batch inference from a file.
```bash
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file ${FILE}
```
Or run multi-node inference.
```bash
torchrun --nnodes ${NODES} --nproc-per-node $((MP / NODES)) --node-rank $RANK --master-addr $ADDR generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file ${FILE}
```
If you want to use fp8, simply remove `"expert_dtype": "fp4"` from `config.json` and pass `--expert-dtype fp8` to `convert.py`, as in the example below.
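For example, a hypothetical fp8 conversion run (reusing the flags defined by `convert.py`; the path variables are placeholders) would look like:
```bash
# Hypothetical fp8 conversion; HF_CKPT_PATH and SAVE_PATH are placeholders for your own paths
python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} \
    --n-experts ${EXPERTS} --model-parallel ${MP} --expert-dtype fp8
```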
{
"vocab_size": 129280,
"dim": 4096,
"moe_inter_dim": 2048,
"n_layers": 43,
"n_hash_layers": 3,
"n_heads": 64,
"n_routed_experts": 256,
"n_shared_experts": 1,
"n_activated_experts": 6,
"score_func": "sqrtsoftplus",
"route_scale": 1.5,
"swiglu_limit": 10.0,
"q_lora_rank": 1024,
"head_dim": 512,
"rope_head_dim": 64,
"o_groups": 8,
"o_lora_rank": 1024,
"window_size": 128,
"original_seq_len": 65536,
"rope_theta": 10000,
"rope_factor": 16,
"beta_fast": 32,
"beta_slow": 1,
"index_n_heads": 64,
"index_head_dim": 128,
"index_topk": 512,
"hc_mult": 4,
"hc_sinkhorn_iters": 20,
"dtype": "fp8",
"scale_fmt": "ue8m0",
"compress_rope_theta": 160000,
"compress_ratios": [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0]
}
"""
DeepSeek-V4 Encoding
A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages
with tool calling, thinking mode, and quick instruction task support.
"""
from typing import Any, Dict, List, Union, Optional, Tuple
import copy
import json
import re
# ============================================================
# Special Tokens
# ============================================================
bos_token: str = "<|begin▁of▁sentence|>"
eos_token: str = "<|end▁of▁sentence|>"
thinking_start_token: str = "<think>"
thinking_end_token: str = "</think>"
dsml_token: str = "|DSML|"
USER_SP_TOKEN = "<|User|>"
ASSISTANT_SP_TOKEN = "<|Assistant|>"
LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>"
# Task special tokens for internal classification tasks
DS_TASK_SP_TOKENS = {
"action": "<|action|>",
"query": "<|query|>",
"authority": "<|authority|>",
"domain": "<|domain|>",
"title": "<|title|>",
"read_url": "<|read_url|>",
}
VALID_TASKS = set(DS_TASK_SP_TOKENS.keys())
# ============================================================
# Templates
# ============================================================
system_msg_template: str = "{content}"
user_msg_template: str = "{content}"
latest_reminder_msg_template: str = "{content}"
assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token
assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}"
thinking_template: str = "{reasoning_content}"
response_format_template: str = (
"## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
)
tool_call_template: str = (
"<{dsml_token}invoke name=\"{name}\">\n{arguments}\n</{dsml_token}invoke>"
)
tool_calls_template = (
"<{dsml_token}{tc_block_name}>\n{tool_calls}\n</{dsml_token}{tc_block_name}>"
)
tool_calls_block_name: str = "tool_calls"
tool_output_template: str = (
"<tool_result>{content}</tool_result>"
)
REASONING_EFFORT_MAX = (
"Reasoning Effort: Absolute maximum with no shortcuts permitted.\n"
"You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n"
"Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n"
)
TOOLS_TEMPLATE = """## Tools
You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following:
<{dsml_token}tool_calls>
<{dsml_token}invoke name="$TOOL_NAME">
<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{dsml_token}parameter>
...
</{dsml_token}invoke>
<{dsml_token}invoke name="$TOOL_NAME2">
...
</{dsml_token}invoke>
</{dsml_token}tool_calls>
String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside {thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response.
Otherwise, output directly after {thinking_end_token} with tool calls or final response.
### Available Tool Schemas
{tool_schemas}
You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
"""
# ============================================================
# Utility Functions
# ============================================================
def to_json(value: Any) -> str:
"""Serialize a value to JSON string."""
try:
return json.dumps(value, ensure_ascii=False)
    except Exception:
return json.dumps(value, ensure_ascii=True)
def tools_from_openai_format(tools):
"""Extract function definitions from OpenAI-format tool list."""
return [tool["function"] for tool in tools]
def tool_calls_from_openai_format(tool_calls):
"""Convert OpenAI-format tool calls to internal format."""
return [
{
"name": tool_call["function"]["name"],
"arguments": tool_call["function"]["arguments"],
}
for tool_call in tool_calls
]
def tool_calls_to_openai_format(tool_calls):
"""Convert internal tool calls to OpenAI format."""
return [
{
"type": "function",
"function": {
"name": tool_call["name"],
"arguments": tool_call["arguments"],
}
}
for tool_call in tool_calls
]
def encode_arguments_to_dsml(tool_call: Dict[str, str]) -> str:
"""
Encode tool call arguments into DSML parameter format.
Args:
tool_call: Dict with "name" and "arguments" (JSON string) keys.
Returns:
DSML-formatted parameter string.
"""
p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}</{dsml_token}parameter>'
P_dsml_strs = []
try:
arguments = json.loads(tool_call["arguments"])
except Exception as err:
arguments = {"arguments": tool_call["arguments"]}
for k, v in arguments.items():
p_dsml_str = p_dsml_template.format(
dsml_token=dsml_token,
key=k,
is_str="true" if isinstance(v, str) else "false",
value=v if isinstance(v, str) else to_json(v),
)
P_dsml_strs.append(p_dsml_str)
return "\n".join(P_dsml_strs)
def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
"""
Decode DSML parameters back to a tool call dict.
Args:
tool_name: Name of the tool.
tool_args: Dict mapping param_name -> (value, is_string_flag).
Returns:
Dict with "name" and "arguments" (JSON string) keys.
"""
def _decode_value(key: str, value: str, string: str):
if string == "true":
value = to_json(value)
return f"{to_json(key)}: {value}"
tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"
return dict(name=tool_name, arguments=tool_args_json)
def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
"""
Render tool schemas into the system prompt format.
Args:
tools: List of tool schema dicts (each with name, description, parameters).
Returns:
Formatted tools section string.
"""
tools_json = [to_json(t) for t in tools]
return TOOLS_TEMPLATE.format(
tool_schemas="\n".join(tools_json),
dsml_token=dsml_token,
thinking_start_token=thinking_start_token,
thinking_end_token=thinking_end_token,
)
def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
"""Find the index of the last user/developer message."""
last_user_index = -1
for idx in range(len(messages) - 1, -1, -1):
if messages[idx].get("role") in ["user", "developer"]:
last_user_index = idx
break
return last_user_index
# ============================================================
# Message Rendering
# ============================================================
def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str, drop_thinking: bool = True, reasoning_effort: Optional[str] = None) -> str:
"""
Render a single message at the given index into its encoded string form.
This is the core function that converts each message in the conversation
into the DeepSeek-V4 format.
Args:
index: Index of the message to render.
messages: Full list of messages in the conversation.
thinking_mode: Either "chat" or "thinking".
drop_thinking: Whether to drop reasoning content from earlier turns.
reasoning_effort: Optional reasoning effort level ("max", "high", or None).
Returns:
Encoded string for this message.
"""
assert 0 <= index < len(messages)
assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
prompt = ""
msg = messages[index]
last_user_idx = find_last_user_index(messages)
role = msg.get("role")
content = msg.get("content")
tools = msg.get("tools")
response_format = msg.get("response_format")
tool_calls = msg.get("tool_calls")
reasoning_content = msg.get("reasoning_content")
wo_eos = msg.get("wo_eos", False)
if tools:
tools = tools_from_openai_format(tools)
if tool_calls:
tool_calls = tool_calls_from_openai_format(tool_calls)
# Reasoning effort prefix (only at index 0 in thinking mode with max effort)
assert reasoning_effort in ['max', None, 'high'], f"Invalid reasoning effort: {reasoning_effort}"
if index == 0 and thinking_mode == "thinking" and reasoning_effort == 'max':
prompt += REASONING_EFFORT_MAX
if role == "system":
prompt += system_msg_template.format(content=content or "")
if tools:
prompt += "\n\n" + render_tools(tools)
if response_format:
prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
elif role == "developer":
assert content, f"Invalid message for role `{role}`: {msg}"
content_developer = USER_SP_TOKEN
content_developer += content
if tools:
content_developer += "\n\n" + render_tools(tools)
if response_format:
content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))
prompt += user_msg_template.format(content=content_developer)
elif role == "user":
prompt += USER_SP_TOKEN
# Handle content blocks (tool results mixed with text)
content_blocks = msg.get("content_blocks")
if content_blocks:
parts = []
for block in content_blocks:
block_type = block.get("type")
if block_type == "text":
parts.append(block.get("text", ""))
elif block_type == "tool_result":
tool_content = block.get("content", "")
if isinstance(tool_content, list):
text_parts = []
for b in tool_content:
if b.get("type") == "text":
text_parts.append(b.get("text", ""))
else:
text_parts.append(f"[Unsupported {b.get('type')}]")
tool_content = "\n\n".join(text_parts)
parts.append(tool_output_template.format(content=tool_content))
else:
parts.append(f"[Unsupported {block_type}]")
prompt += "\n\n".join(parts)
else:
prompt += content or ""
elif role == "latest_reminder":
prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format(content=content)
elif role == "tool":
raise NotImplementedError("deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()")
elif role == "assistant":
thinking_part = ""
tc_content = ""
if tool_calls:
tc_list = [
tool_call_template.format(
dsml_token=dsml_token,
name=tc.get("name"),
arguments=encode_arguments_to_dsml(tc)
)
for tc in tool_calls
]
tc_content += '\n\n' + tool_calls_template.format(
dsml_token=dsml_token,
tool_calls="\n".join(tc_list),
tc_block_name=tool_calls_block_name,
)
summary_content = content or ""
rc = reasoning_content or ""
# Check if previous message has a task - if so, this is a task output (no thinking)
prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None
if thinking_mode == "thinking" and not prev_has_task:
if not drop_thinking or index > last_user_idx:
thinking_part = thinking_template.format(reasoning_content=rc) + thinking_end_token
else:
thinking_part = ""
if wo_eos:
prompt += assistant_msg_wo_eos_template.format(
reasoning=thinking_part,
content=summary_content,
tool_calls=tc_content,
)
else:
prompt += assistant_msg_template.format(
reasoning=thinking_part,
content=summary_content,
tool_calls=tc_content,
)
else:
raise NotImplementedError(f"Unknown role: {role}")
# Append transition tokens based on what follows
if index + 1 < len(messages) and messages[index + 1].get("role") not in ["assistant", "latest_reminder"]:
return prompt
task = messages[index].get("task")
if task is not None:
# Task special token for internal classification tasks
assert task in VALID_TASKS, f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}"
task_sp_token = DS_TASK_SP_TOKENS[task]
if task != "action":
# Non-action tasks: append task sp token directly after the message
prompt += task_sp_token
else:
# Action task: append Assistant + thinking token + action sp token
prompt += ASSISTANT_SP_TOKEN
prompt += thinking_end_token if thinking_mode != "thinking" else thinking_start_token
prompt += task_sp_token
elif messages[index].get("role") in ["user", "developer"]:
# Normal generation: append Assistant + thinking token
prompt += ASSISTANT_SP_TOKEN
if not drop_thinking and thinking_mode == "thinking":
prompt += thinking_start_token
elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx:
prompt += thinking_start_token
else:
prompt += thinking_end_token
return prompt
# ============================================================
# Preprocessing
# ============================================================
def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Merge tool messages into the preceding user message using content_blocks format.
DeepSeek-V4 does not have a standalone "tool" role; instead, tool results
are encoded as <tool_result> blocks within user messages.
This function converts a standard OpenAI-format conversation (with separate
"tool" role messages) into V4 format where tool results are merged into
user messages.
Args:
messages: List of message dicts in OpenAI format.
Returns:
Processed message list with tool messages merged into user messages.
"""
merged: List[Dict[str, Any]] = []
for msg in messages:
msg = copy.deepcopy(msg)
role = msg.get("role")
if role == "tool":
# Convert tool message to a user message with tool_result block
tool_block = {
"type": "tool_result",
"tool_use_id": msg.get("tool_call_id", ""),
"content": msg.get("content", ""),
}
# Merge into previous message if it's already a user (merged tool)
if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1]:
merged[-1]["content_blocks"].append(tool_block)
else:
merged.append({
"role": "user",
"content_blocks": [tool_block],
})
elif role == "user":
text_block = {"type": "text", "text": msg.get("content", "")}
if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1] and merged[-1].get("task") is None:
merged[-1]["content_blocks"].append(text_block)
else:
new_msg = {
"role": "user",
"content": msg.get("content", ""),
"content_blocks": [text_block],
}
# Preserve extra fields (task, wo_eos, mask, etc.)
for key in ("task", "wo_eos", "mask"):
if key in msg:
new_msg[key] = msg[key]
merged.append(new_msg)
else:
merged.append(msg)
return merged
def sort_tool_results_by_call_order(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Sort tool_result blocks within user messages by the order of tool_calls
in the preceding assistant message.
Args:
messages: Preprocessed message list (after merge_tool_messages).
Returns:
Message list with sorted tool result blocks.
"""
last_tool_call_order: Dict[str, int] = {}
for msg in messages:
role = msg.get("role")
if role == "assistant" and msg.get("tool_calls"):
last_tool_call_order = {}
for idx, tc in enumerate(msg["tool_calls"]):
tc_id = tc.get("id") or tc.get("function", {}).get("id", "")
if tc_id:
last_tool_call_order[tc_id] = idx
elif role == "user" and msg.get("content_blocks"):
tool_blocks = [b for b in msg["content_blocks"] if b.get("type") == "tool_result"]
if len(tool_blocks) > 1 and last_tool_call_order:
sorted_blocks = sorted(
tool_blocks,
key=lambda b: last_tool_call_order.get(b.get("tool_use_id", ""), 0)
)
sorted_idx = 0
new_blocks = []
for block in msg["content_blocks"]:
if block.get("type") == "tool_result":
new_blocks.append(sorted_blocks[sorted_idx])
sorted_idx += 1
else:
new_blocks.append(block)
msg["content_blocks"] = new_blocks
return messages
# ============================================================
# Main Encoding Function
# ============================================================
def encode_messages(
messages: List[Dict[str, Any]],
thinking_mode: str,
context: Optional[List[Dict[str, Any]]] = None,
drop_thinking: bool = True,
add_default_bos_token: bool = True,
reasoning_effort: Optional[str] = None,
) -> str:
"""
Encode a list of messages into the DeepSeek-V4 prompt format.
This is the main entry point for encoding conversations. It handles:
- BOS token insertion
- Thinking mode with optional reasoning content dropping
- Tool message merging into user messages
- Multi-turn conversation context
Args:
messages: List of message dicts to encode.
thinking_mode: Either "chat" or "thinking".
context: Optional preceding context messages (already encoded prefix).
drop_thinking: If True, drop reasoning_content from earlier assistant turns
(only keep reasoning for messages after the last user message).
add_default_bos_token: Whether to prepend BOS token at conversation start.
reasoning_effort: Optional reasoning effort level ("max", "high", or None).
Returns:
The encoded prompt string.
"""
context = context if context else []
# Preprocess: merge tool messages and sort tool results
messages = merge_tool_messages(messages)
messages = sort_tool_results_by_call_order(context + messages)[len(context):]
if context:
context = merge_tool_messages(context)
context = sort_tool_results_by_call_order(context)
full_messages = context + messages
prompt = bos_token if add_default_bos_token and len(context) == 0 else ""
# Resolve drop_thinking: if any message has tools defined, don't drop thinking
effective_drop_thinking = drop_thinking
if any(m.get("tools") for m in full_messages):
effective_drop_thinking = False
if thinking_mode == "thinking" and effective_drop_thinking:
full_messages = _drop_thinking_messages(full_messages)
# After dropping, recalculate how many messages to render
# (context may have shrunk too)
num_to_render = len(full_messages) - len(_drop_thinking_messages(context))
context_len = len(full_messages) - num_to_render
else:
num_to_render = len(messages)
context_len = len(context)
for idx in range(num_to_render):
prompt += render_message(
idx + context_len,
full_messages,
thinking_mode=thinking_mode,
drop_thinking=effective_drop_thinking,
reasoning_effort=reasoning_effort,
)
return prompt
def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Drop reasoning_content and non-essential messages before the last user message.
Behavior:
- Messages with role in ["user", "system", "tool", "latest_reminder"] are always kept.
- Messages at or after the last user index are always kept.
- Assistant messages before the last user get reasoning_content removed.
- Developer messages before the last user are dropped entirely.
"""
last_user_idx = find_last_user_index(messages)
result = []
keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"}
for idx, msg in enumerate(messages):
role = msg.get("role")
if role in keep_roles or idx >= last_user_idx:
result.append(msg)
elif role == "assistant":
msg = copy.copy(msg)
msg.pop("reasoning_content", None)
result.append(msg)
# developer and other roles before last_user_idx are dropped
return result
# ============================================================
# Parsing (Decoding model output)
# ============================================================
def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
"""
Read text from index until one of the stop strings is found.
Returns:
Tuple of (new_index, content_before_stop, matched_stop_string_or_None).
"""
min_pos = len(text)
matched_stop = None
for s in stop:
pos = text.find(s, index)
if pos != -1 and pos < min_pos:
min_pos = pos
matched_stop = s
if matched_stop:
content = text[index:min_pos]
return min_pos + len(matched_stop), content, matched_stop
else:
content = text[index:]
return len(text), content, None
def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Dict[str, str]]]:
"""
Parse DSML tool calls from text starting at the given index.
Args:
index: Starting position in text.
text: The full text to parse.
Returns:
Tuple of (new_index, last_stop_token, list_of_tool_call_dicts).
Each tool call dict has "name" and "arguments" keys.
"""
tool_calls: List[Dict[str, Any]] = []
stop_token = None
tool_calls_end_token = f"</{dsml_token}{tool_calls_block_name}>"
while index < len(text):
index, _, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token])
if _ != ">\n":
raise ValueError(f"Tool call format error: expected '>\\n' but got '{_}'")
if stop_token == tool_calls_end_token:
break
if stop_token is None:
raise ValueError("Missing special token in tool calls")
index, tool_name_content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
if len(p_tool_name) != 1:
raise ValueError(f"Tool name format error: '{tool_name_content}'")
tool_name = p_tool_name[0]
tool_args: Dict[str, Tuple[str, str]] = {}
while stop_token == f"<{dsml_token}parameter":
index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"])
param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
if len(param_kv) != 1:
raise ValueError(f"Parameter format error: '{param_content}'")
param_name, string, param_value = param_kv[0]
if param_name in tool_args:
raise ValueError(f"Duplicate parameter name: '{param_name}'")
tool_args[param_name] = (param_value, string)
index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
if content != ">\n":
raise ValueError(f"Parameter format error: expected '>\\n' but got '{content}'")
tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
tool_calls.append(tool_call)
return index, stop_token, tool_calls
def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]:
"""
Parse a model completion text into a structured assistant message.
This function takes the raw text output from the model (a single assistant turn)
and extracts:
- reasoning_content (thinking block)
- content (summary/response)
- tool_calls (if any)
NOTE: This function is designed to parse only correctly formatted strings and
will raise ValueError for malformed output.
Args:
text: The raw completion text (including EOS token).
thinking_mode: Either "chat" or "thinking".
Returns:
Dict with keys: "role", "content", "reasoning_content", "tool_calls".
tool_calls are in OpenAI format.
"""
summary_content, reasoning_content, tool_calls = "", "", []
index, stop_token = 0, None
tool_calls_start_token = f"\n\n<{dsml_token}{tool_calls_block_name}"
is_thinking = thinking_mode == "thinking"
is_tool_calling = False
if is_thinking:
index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
reasoning_content = content_delta
assert stop_token == thinking_end_token, "Invalid thinking format: missing </think>"
index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
summary_content = content_delta
if stop_token == tool_calls_start_token:
is_tool_calling = True
else:
assert stop_token == eos_token, "Invalid format: missing EOS token"
if is_tool_calling:
index, stop_token, tool_calls = parse_tool_calls(index, text)
index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
assert not tool_ends_text, "Unexpected content after tool calls"
assert len(text) == index and stop_token in [eos_token, None], "Unexpected content at end"
for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
assert sp_token not in summary_content and sp_token not in reasoning_content, \
f"Unexpected special token '{sp_token}' in content"
return {
"role": "assistant",
"content": summary_content,
"reasoning_content": reasoning_content,
"tool_calls": tool_calls_to_openai_format(tool_calls)
}
import os
import json
import sys
from argparse import ArgumentParser
from typing import List
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model
from model import Transformer, ModelArgs
current_dir = os.path.dirname(os.path.abspath(__file__))
encoding_dir = os.path.join(current_dir, '../encoding')
sys.path.insert(0, os.path.abspath(encoding_dir))
from encoding_dsv4 import encode_messages, parse_message_from_completion_text
def sample(logits, temperature: float = 1.0):
"""Gumbel-max trick: equivalent to multinomial sampling but faster on GPU,
since it avoids the GPU-to-CPU sync in torch.multinomial."""
logits = logits / max(temperature, 1e-5)
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
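# Note (added sketch): with e_i ~ Exponential(1), argmax_i(p_i / e_i) = argmax_i(log p_i - log e_i),
# and -log e_i follows a standard Gumbel distribution, so the argmax is distributed exactly as Categorical(p).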
@torch.inference_mode()
def generate(
model: Transformer,
prompt_tokens: List[List[int]],
max_new_tokens: int,
eos_id: int,
temperature: float = 1.0
) -> List[List[int]]:
"""Batch generation with left-padded prompts.
The first forward pass processes [min_prompt_len:] tokens (prefill phase).
Subsequent passes generate one token at a time (decode phase). For positions
still within a prompt, the ground-truth token overrides the model's prediction.
"""
prompt_lens = [len(t) for t in prompt_tokens]
assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long)
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long)
prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens))
prompt_mask = tokens != -1
for cur_pos in range(min(prompt_lens), total_len):
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
next_token = sample(logits, temperature)
else:
next_token = logits.argmax(dim=-1)
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
prev_pos = cur_pos
if finished.all():
break
completion_tokens = []
for i, toks in enumerate(tokens.tolist()):
toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
if eos_id in toks:
toks = toks[:toks.index(eos_id)]
toks.append(eos_id)
completion_tokens.append(toks)
return completion_tokens
def main(
ckpt_path: str,
config: str,
input_file: str = "",
interactive: bool = True,
max_new_tokens: int = 100,
temperature: float = 1.0,
) -> None:
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if world_size > 1:
dist.init_process_group("nccl")
global print
if rank != 0:
print = lambda *_, **__: None
torch.cuda.set_device(local_rank)
torch.cuda.memory._set_allocator_settings("expandable_segments:True")
torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(8)
torch.manual_seed(33377335)
with open(config) as f:
args = ModelArgs(**json.load(f))
if interactive:
args.max_batch_size = 1
print(args)
with torch.device("cuda"):
model = Transformer(args)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
print("load model")
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"), strict=False)
torch.set_default_device("cuda")
print("I'm DeepSeek 👋")
if interactive:
messages = []
while True:
if world_size == 1:
prompt = input(">>> ")
elif rank == 0:
prompt = input(">>> ")
objects = [prompt]
dist.broadcast_object_list(objects, 0)
else:
objects = [None]
dist.broadcast_object_list(objects, 0)
prompt = objects[0]
if prompt == "/exit":
break
elif prompt == "/clear":
messages.clear()
continue
messages.append({"role": "user", "content": prompt})
prompt_tokens = tokenizer.encode(encode_messages(messages, thinking_mode="chat"))
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
completion = tokenizer.decode(completion_tokens[0])
print(completion)
messages.append(parse_message_from_completion_text(completion, thinking_mode="chat"))
else:
with open(input_file) as f:
prompts = f.read().split("\n\n")
prompt_tokens = [tokenizer.encode(encode_messages([{"role": "user", "content": prompt}], thinking_mode="chat")) for prompt in prompts]
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
completions = tokenizer.batch_decode(completion_tokens)
for prompt, completion in zip(prompts, completions):
print("Prompt:", prompt)
print("Completion:", completion)
print()
if world_size > 1:
dist.destroy_process_group()
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--ckpt-path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--input-file", type=str, default="")
parser.add_argument("--interactive", action="store_true")
parser.add_argument("--max-new-tokens", type=int, default=300)
parser.add_argument("--temperature", type=float, default=0.6)
args = parser.parse_args()
assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
import torch
import tilelang
import tilelang.language as T
from typing import Tuple, Optional
tilelang.set_log_level("WARNING")
pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,
}
FP8 = "float8_e4m3"
FP4 = "float4_e2m1fn"
# FE8M0 = "float8_e8m0fnu"
BF16 = "bfloat16"
FP32 = "float32"
INT32 = "int32"
def fast_log2_ceil(x):
"""Compute ceil(log2(x)) via IEEE 754 bit manipulation. Avoids slow log/ceil intrinsics."""
bits_x = T.reinterpret("uint32", x)
exp_x = (bits_x >> 23) & 0xFF
man_bits = bits_x & ((1 << 23) - 1)
return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
def fast_pow2(x):
"""Compute 2^x for integer x via IEEE 754 bit manipulation."""
bits_x = (x + 127) << 23
return T.reinterpret("float32", bits_x)
def fast_round_scale(amax, fp8_max_inv):
return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
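# Worked example (illustrative): for amax = 1000 and fp8_max = 448, amax * fp8_max_inv ≈ 2.23,
# ceil(log2(2.23)) = 2, so the scale rounds up to 2^2 = 4.0 and the quantized values
# (e.g. 1000 / 4 = 250) stay within the e4m3 range of ±448.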
@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(
N, block_size=128, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32,
round_scale=False, inplace=False
):
"""Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16."""
M = T.symbolic("M")
fp8_min = -448.0
fp8_max = 448.0
fp8_max_inv = 1 / fp8_max
num_stages = 0 if round_scale or inplace else 2
blk_m = 32
group_size = block_size
# Internal computation in FP32; scale_dtype controls output storage format.
compute_dtype = FP32
out_dtype = in_dtype if inplace else out_dtype
@T.prim_func
def act_quant_kernel_(
X: T.Tensor[(M, N), in_dtype],
Y: T.Tensor[(M, N), out_dtype],
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
):
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
pid_m,
pid_n,
):
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
amax_local = T.alloc_fragment((blk_m,), compute_dtype)
s_local = T.alloc_fragment((blk_m,), compute_dtype)
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
for _ in T.Pipelined(1, num_stages=num_stages):
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
T.copy(x_shared, x_local)
T.reduce_absmax(x_local, amax_local, dim=1)
for i in T.Parallel(blk_m):
amax_local[i] = T.max(amax_local[i], 1e-4)
if round_scale:
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
else:
s_local[i] = amax_local[i] * fp8_max_inv
if inplace:
for i, j in T.Parallel(blk_m, group_size):
y_local[i, j] = T.Cast(
out_dtype,
T.Cast(compute_dtype, T.Cast(out_dtype, T.clamp(
x_local[i, j] / s_local[i], fp8_min, fp8_max
))) * s_local[i],
)
else:
for i, j in T.Parallel(blk_m, group_size):
y_local[i, j] = T.clamp(
x_local[i, j] / s_local[i], fp8_min, fp8_max
)
for i in T.Parallel(blk_m):
S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
T.copy(y_local, y_shared)
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
return act_quant_kernel_
def act_quant(
x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None,
scale_dtype: torch.dtype = torch.float32, inplace: bool = False,
) -> torch.Tensor:
"""Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16.
When scale_fmt is set, scales are rounded to power-of-2 (MXFP)."""
N = x.size(-1)
assert N % block_size == 0
# tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
tl_dtype = FP32
z = x.contiguous()
y = torch.empty_like(z) if inplace else torch.empty_like(z, dtype=torch.float8_e4m3fn)
s = z.new_empty(*z.size()[:-1], N // block_size, dtype=scale_dtype)
kernel = act_quant_kernel(
N, block_size, scale_dtype=tl_dtype,
round_scale=scale_fmt is not None, inplace=inplace,
)
kernel(z.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
if inplace:
x.copy_(y)
return x
return y, s
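# Illustrative usage (assumed shapes): for a bf16 tensor x of shape [M, N] with N % 128 == 0:
#   y, s = act_quant(x)             # y: [M, N] float8_e4m3fn, s: [M, N // 128] float32 scales
#   x = act_quant(x, inplace=True)  # fused quantize + dequantize, x overwritten in bf16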
@tilelang.jit(pass_configs=pass_configs)
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
assert out_dtype in [BF16, FP32]
M = T.symbolic("M")
group_size = 128
block_M = 32
block_N = 128
block_K = 128
@T.prim_func
def fp8_gemm_kernel_(
A: T.Tensor[(M, K), FP8],
B: T.Tensor[(N, K), FP8],
C: T.Tensor[(M, N), out_dtype],
scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), scale_dtype],
scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), scale_dtype],
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
bx,
by,
):
A_shared = T.alloc_shared((block_M, block_K), FP8)
B_shared = T.alloc_shared((block_N, block_K), FP8)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
Scale_C_shared = T.alloc_shared((block_M), FP32)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
T.clear(C_local_accum)
K_iters = T.ceildiv(K, block_K)
for k in T.Pipelined(K_iters, num_stages=2):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
# Cast scales to FP32 for computation; scales_b has one value per block_N group
Scale_B = T.Cast(FP32, scales_b[bx * block_N // group_size, k])
for i in T.Parallel(block_M):
Scale_C_shared[i] = T.Cast(FP32, scales_a[by * block_M + i, k]) * Scale_B
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
# Separate accumulator for scale-corrected results (2x accumulation precision)
for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
T.clear(C_local)
T.copy(C_local_accum, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return fp8_gemm_kernel_
def fp8_gemm(
a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
scale_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""C[M,N] = A[M,K] @ B[N,K]^T with per-128 block FP8 scaling on both A and B."""
assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
assert a_s.is_contiguous() and b_s.is_contiguous(), (
"Scaling factor tensors must be contiguous"
)
# tl_dtype = FP32
tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
kernel = fp8_gemm_kernel(N, K, scale_dtype=tl_dtype)
kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
return c
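# Usage sketch (illustrative only): pair act_quant with fp8_gemm for a block-scaled matmul.
# The weight quantization below is a plain-PyTorch stand-in for what the checkpoint already
# provides (one scale per 128x128 weight block); the helper name is hypothetical.
def _demo_fp8_gemm():
    import torch
    prev = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)  # the kernel emits BF16 outputs; match the default dtype
    M, K, N = 256, 512, 384
    a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    a_q, a_s = act_quant(a)  # activations: [M, K] FP8 payload + [M, K//128] scales
    b = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
    blocks = b.float().view(N // 128, 128, K // 128, 128)
    b_s = blocks.abs().amax(dim=(1, 3)) / 448.0  # 448 is the max magnitude of float8_e4m3fn
    b_q = (blocks / b_s[:, None, :, None]).view(N, K).to(torch.float8_e4m3fn)
    c = fp8_gemm(a_q, a_s, b_q, b_s)  # C[M, N] = A @ B^T with per-block scale correction
    torch.set_default_dtype(prev)
    return c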
@tilelang.jit(pass_configs=pass_configs)
def sparse_attn_kernel(h_orig: int, d: int, scale=None):
"""Sparse multi-head attention via index gathering + online softmax (FlashAttention-style).
For each (batch, seq_pos), gathers top-k KV positions by index, computes attention
with numerically stable running max/sum, and includes a learnable attn_sink bias."""
b = T.symbolic("b")
m = T.symbolic("m")
n = T.symbolic("n")
topk = T.symbolic("topk")
if scale is None:
scale = (1.0 / d) ** 0.5
num_stages = 0
threads = 256
block = 32
num_blocks = tilelang.cdiv(topk, block)
padded_H = max(tilelang.math.next_power_of_2(h_orig), 16)
max_block_m = 16
if h_orig > max_block_m:
assert h_orig % max_block_m == 0, f"h should be a multiple of {max_block_m}"
REPLICATE_H = h_orig // max_block_m
else:
REPLICATE_H = 1
h = padded_H if REPLICATE_H == 1 else max_block_m
@T.prim_func
def sparse_attn_kernel_(
q: T.Tensor[(b, m, h_orig, d), BF16],
kv: T.Tensor[(b, n, d), BF16],
o: T.Tensor[(b, m, h_orig, d), BF16],
attn_sink: T.Tensor[(h_orig,), FP32],
topk_idxs: T.Tensor[(b, m, topk), INT32],
):
with T.Kernel(m * REPLICATE_H, b, threads=threads) as (bx, by):
q_shared = T.alloc_fragment((h, d), BF16)
kv_shared = T.alloc_shared((block, d), BF16)
# o_shared = T.alloc_shared((h, d), BF16)
acc_s_cast = T.alloc_shared((h, block), BF16)
idxs = T.alloc_fragment(block, INT32)
acc_s = T.alloc_fragment((h, block), FP32)
acc_o = T.alloc_fragment((h, d), FP32)
scores_max = T.alloc_fragment(h, FP32)
scores_max_prev = T.alloc_fragment(h, FP32)
scores_scale = T.alloc_fragment(h, FP32)
scores_sum = T.alloc_fragment(h, FP32)
sum_exp = T.alloc_fragment(h, FP32)
T.clear(acc_o)
T.clear(sum_exp)
T.fill(scores_max, -T.infinity(FP32))
s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
H0 = (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * h)
H1 = H0 + h
T.copy(q[by, s_i, H0:H1, :], q_shared)
for t in T.Pipelined(num_blocks, num_stages=num_stages):
for i in T.Parallel(block):
idxs[i] = T.if_then_else(t * block + i < topk, topk_idxs[by, s_i, t * block + i], -1)
for i, j in T.Parallel(block, d):
kv_shared[i, j] = T.if_then_else(idxs[i] != -1, kv[by, idxs[i], j], 0)
for i, j in T.Parallel(h, block):
acc_s[i, j] = T.if_then_else(idxs[j] != -1, 0, -T.infinity(FP32))
T.gemm(q_shared, kv_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(h, block):
acc_s[i, j] *= scale
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(h):
scores_scale[i] = T.exp(scores_max_prev[i] - scores_max[i])
for i, j in T.Parallel(h, block):
acc_s[i, j] = T.exp(acc_s[i, j] - scores_max[i])
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(h):
sum_exp[i] = sum_exp[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(h, d):
acc_o[i, j] *= scores_scale[i]
T.gemm(acc_s_cast, kv_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i in T.Parallel(h):
sum_exp[i] += T.exp(attn_sink[i] - scores_max[i])
for i, j in T.Parallel(h, d):
acc_o[i, j] /= sum_exp[i]
o_shared = T.alloc_shared((h, d), BF16)
T.copy(acc_o, o_shared)
T.copy(o_shared, o[by, s_i, H0:H1, :])
return sparse_attn_kernel_
def sparse_attn(
q: torch.Tensor, kv: torch.Tensor, attn_sink: torch.Tensor, topk_idxs: torch.Tensor, softmax_scale: float
) -> torch.Tensor:
b, s, h, d = q.size()
# print(f"Teng {q.size()=}")
# Pad heads to 16 for kernel efficiency (stripped after)
if h < 16:
q = torch.cat([q, q.new_zeros(b, s, 16 - h, d)], dim=2)
attn_sink = torch.cat([attn_sink, attn_sink.new_zeros(16 - h)])
o = torch.empty_like(q)
kernel = sparse_attn_kernel(q.size(2), d, softmax_scale)
kernel(q, kv, o, attn_sink, topk_idxs)
if h < 16:
o = o.narrow(2, 0, h).contiguous()
return o
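# Usage sketch (illustrative only): shapes expected by sparse_attn. Queries are
# [b, s, h, d], the shared latent KV is [b, n, d], and topk_idxs holds int32 gather
# positions with -1 marking unused slots. The helper name is hypothetical.
def _demo_sparse_attn():
    import torch
    b, s, h, d, n, topk = 1, 4, 16, 64, 32, 8
    q = torch.randn(b, s, h, d, dtype=torch.bfloat16, device="cuda")
    kv = torch.randn(b, n, d, dtype=torch.bfloat16, device="cuda")
    attn_sink = torch.zeros(h, dtype=torch.float32, device="cuda")
    topk_idxs = torch.randint(0, n, (b, s, topk), dtype=torch.int32, device="cuda")
    o = sparse_attn(q, kv, attn_sink, topk_idxs, softmax_scale=d ** -0.5)
    assert o.shape == q.shape
    return o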
@tilelang.jit(pass_configs=pass_configs)
def hc_split_sinkhorn_kernel(hc: int, sinkhorn_iters: int, eps: float):
n = T.symbolic("n")
mix_hc = (2 + hc) * hc
threads = 64
@T.prim_func
def hc_split_sinkhorn_kernel_(
mixes: T.Tensor[(n, mix_hc), FP32],
hc_scale: T.Tensor[(3,), FP32],
hc_base: T.Tensor[(mix_hc,), FP32],
pre: T.Tensor[(n, hc), FP32],
post: T.Tensor[(n, hc), FP32],
comb: T.Tensor[(n, hc, hc), FP32],
):
with T.Kernel(n, threads=threads) as i:
mixes_shared = T.alloc_shared(mix_hc, FP32)
comb_frag = T.alloc_fragment((hc, hc), FP32)
T.copy(mixes[i, :], mixes_shared)
for j in T.Parallel(hc):
pre[i, j] = T.sigmoid(mixes_shared[j] * hc_scale[0] + hc_base[j]) + eps
for j in T.Parallel(hc):
post[i, j] = 2 * T.sigmoid(mixes_shared[j + hc] * hc_scale[1] + hc_base[j + hc])
for j, k in T.Parallel(hc, hc):
comb_frag[j, k] = mixes_shared[j * hc + k + hc * 2] * hc_scale[2] + hc_base[j * hc + k + hc * 2]
row_sum = T.alloc_fragment(hc, FP32)
col_sum = T.alloc_fragment(hc, FP32)
# comb = comb.softmax(-1) + eps
row_max = T.alloc_fragment(hc, FP32)
T.reduce_max(comb_frag, row_max, dim=1)
for j, k in T.Parallel(hc, hc):
comb_frag[j, k] = T.exp(comb_frag[j, k] - row_max[j])
T.reduce_sum(comb_frag, row_sum, dim=1)
for j, k in T.Parallel(hc, hc):
comb_frag[j, k] = comb_frag[j, k] / row_sum[j] + eps
# comb = comb / (comb.sum(-2) + eps)
T.reduce_sum(comb_frag, col_sum, dim=0)
for j, k in T.Parallel(hc, hc):
comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
for _ in T.serial(sinkhorn_iters - 1):
# comb = comb / (comb.sum(-1) + eps)
T.reduce_sum(comb_frag, row_sum, dim=1)
for j, k in T.Parallel(hc, hc):
comb_frag[j, k] = comb_frag[j, k] / (row_sum[j] + eps)
# comb = comb / (comb.sum(-2) + eps)
T.reduce_sum(comb_frag, col_sum, dim=0)
for j, k in T.Parallel(hc, hc):
comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
T.copy(comb_frag, comb[i, :, :])
return hc_split_sinkhorn_kernel_
def hc_split_sinkhorn(mixes: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, hc_mult: int = 4, sinkhorn_iters: int = 20, eps: float = 1e-6):
b, s, _ = mixes.size()
pre = mixes.new_empty(b, s, hc_mult)
post = mixes.new_empty(b, s, hc_mult)
comb = mixes.new_empty(b, s, hc_mult, hc_mult)
kernel = hc_split_sinkhorn_kernel(hc_mult, sinkhorn_iters, eps)
kernel(mixes.view(-1, (2 + hc_mult) * hc_mult), hc_scale, hc_base,
pre.view(-1, hc_mult), post.view(-1, hc_mult), comb.view(-1, hc_mult, hc_mult))
return pre, post, comb
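# Plain-PyTorch reference of the same math (illustrative only), handy for sanity-checking
# the fused kernel: sigmoid "pre"/"post" gates plus a Sinkhorn-normalized combination
# matrix. The helper name is hypothetical.
def _hc_split_sinkhorn_ref(mixes: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor,
                           hc: int = 4, iters: int = 20, eps: float = 1e-6):
    pre = torch.sigmoid(mixes[..., :hc] * hc_scale[0] + hc_base[:hc]) + eps
    post = 2 * torch.sigmoid(mixes[..., hc:2 * hc] * hc_scale[1] + hc_base[hc:2 * hc])
    comb = (mixes[..., 2 * hc:] * hc_scale[2] + hc_base[2 * hc:]).unflatten(-1, (hc, hc))
    comb = comb.softmax(-1) + eps          # row softmax, as in the kernel
    comb = comb / (comb.sum(-2, keepdim=True) + eps)
    for _ in range(iters - 1):             # alternate row/column normalization
        comb = comb / (comb.sum(-1, keepdim=True) + eps)
        comb = comb / (comb.sum(-2, keepdim=True) + eps)
    return pre, post, comb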
import math
from dataclasses import dataclass
from typing import Tuple, Optional, Literal
from functools import lru_cache
from contextlib import contextmanager
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
# from kernel import act_quant, fp4_act_quant, fp8_gemm, fp4_gemm, sparse_attn, hc_split_sinkhorn
from kernel import sparse_attn, hc_split_sinkhorn
try:
from scipy.linalg import hadamard
except ImportError:
hadamard = None
world_size = 1
rank = 0
block_size = 128
fp4_block_size = 32
default_dtype = torch.bfloat16
scale_fmt = None
scale_dtype = torch.float32
@contextmanager
def set_dtype(dtype):
"""Temporarily override torch default dtype, restoring it on exit (even if an exception occurs)."""
prev = torch.get_default_dtype()
torch.set_default_dtype(dtype)
try:
yield
finally:
torch.set_default_dtype(prev)
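# Usage sketch (illustrative only): allocate FP32 parameters inside an otherwise-BF16 model.
# The helper name is hypothetical.
def _demo_set_dtype():
    prev = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)
    with set_dtype(torch.float32):
        w = torch.empty(4, 4)  # created in float32 while the context is active
    assert w.dtype == torch.float32 and torch.get_default_dtype() == torch.bfloat16
    torch.set_default_dtype(prev)  # restore the ambient default
    return w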
@dataclass
class ModelArgs:
"""Model hyperparameters. Field names match the config JSON keys."""
max_batch_size: int = 4
max_seq_len: int = 4096
dtype: Literal["bf16", "fp8"] = "fp8"
scale_fmt: Literal[None, "ue8m0"] = "ue8m0"
expert_dtype: Literal[None, "fp4", "fp8"] = None
scale_dtype: Literal["fp32", "fp8"] = "fp8"
vocab_size: int = 129280
dim: int = 4096
moe_inter_dim: int = 4096
n_layers: int = 7
n_hash_layers: int = 0
n_mtp_layers: int = 1
n_heads: int = 64
# moe
n_routed_experts: int = 8
n_shared_experts: int = 1
n_activated_experts: int = 2
score_func: Literal["softmax", "sigmoid", "sqrtsoftplus"] = "sqrtsoftplus"
route_scale: float = 1.
swiglu_limit: float = 0.
# mqa
q_lora_rank: int = 1024
head_dim: int = 512
rope_head_dim: int = 64
norm_eps: float = 1e-6
o_groups: int = 8
o_lora_rank: int = 1024
window_size: int = 128
compress_ratios: Tuple[int, ...] = (0, 0, 4, 128, 4, 128, 4, 0)
# yarn
compress_rope_theta: float = 40000.0
original_seq_len: int = 0
rope_theta: float = 10000.0
rope_factor: float = 40
beta_fast: int = 32
beta_slow: int = 1
# index
index_n_heads: int = 64
index_head_dim: int = 128
index_topk: int = 512
# hc
hc_mult: int = 4
hc_sinkhorn_iters: int = 20
hc_eps: float = 1e-6
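# Usage sketch (illustrative only): ModelArgs fields mirror the config JSON keys, so a
# config dict can be splatted straight into the dataclass. The keys shown are a hypothetical subset.
def _demo_model_args() -> ModelArgs:
    cfg = {"dim": 4096, "n_layers": 7, "n_heads": 64, "dtype": "bf16"}
    return ModelArgs(**cfg)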
class ParallelEmbedding(nn.Module):
"""Embedding sharded along the vocab dimension. Each rank holds vocab_size // world_size rows.
Out-of-range indices are zero-masked before all_reduce to combine partial embeddings."""
def __init__(self, vocab_size: int, dim: int):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
self.part_vocab_size = (vocab_size // world_size)
self.vocab_start_idx = rank * self.part_vocab_size
self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
if world_size > 1:
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
x = x - self.vocab_start_idx
x[mask] = 0
y = F.embedding(x, self.weight)
if world_size > 1:
y[mask] = 0
dist.all_reduce(y)
return y
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Dispatches to fp4_gemm / fp8_gemm / F.linear based on weight dtype.
For quantized weights, x is first quantized to FP8 via act_quant."""
assert bias is None
return F.linear(x, weight)
class Linear(nn.Module):
"""Linear layer supporting BF16, FP8, and FP4 weight formats with per-block scaling."""
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
super().__init__()
self.in_features = in_features
self.out_features = out_features
dtype = dtype or default_dtype
if dtype == torch.float4_e2m1fn_x2:
# FP4: weight is [out, in//2] in float4_e2m1fn_x2, logically [out, in] in fp4
# Scale is [out, in//32] in float8_e8m0fnu (1 scale per 32 fp4 elements along K)
self.weight = nn.Parameter(torch.empty(out_features, in_features // 2, dtype=torch.float4_e2m1fn_x2))
scale_out_features = out_features
scale_in_features = in_features // fp4_block_size
self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
elif dtype == torch.float8_e4m3fn:
self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
scale_out_features = (out_features + block_size - 1) // block_size
scale_in_features = (in_features + block_size - 1) // block_size
self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
else:
self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
self.register_parameter("scale", None)
if bias:
self.bias = nn.Parameter(torch.empty(out_features))
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return linear(x, self.weight, self.bias)
class ColumnParallelLinear(Linear):
"""Shards output dim across TP ranks. No all-reduce needed on output."""
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
self.part_out_features = out_features // world_size
super().__init__(in_features, self.part_out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return linear(x, self.weight, self.bias)
class RowParallelLinear(Linear):
"""Shards input dim across TP ranks. All-reduce on output to sum partial results."""
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
self.part_in_features = in_features // world_size
super().__init__(self.part_in_features, out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = linear(x, self.weight, None)
if world_size > 1:
y = y.float()
dist.all_reduce(y)
if self.bias is not None:
y += self.bias
return y.type_as(x)
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.eps = eps
# rmsnorm weights in the checkpoint are stored in bf16, while the parameter here is kept in fp32 for convenience.
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
def forward(self, x: torch.Tensor):
dtype = x.dtype
x = x.float()
var = x.square().mean(-1, keepdim=True)
x = x * torch.rsqrt(var + self.eps)
return (self.weight * x).to(dtype)
@lru_cache(2)
def precompute_freqs_cis(dim, seqlen, original_seq_len, base, factor, beta_fast, beta_slow) -> torch.Tensor:
"""Precomputes complex exponentials for rotary embeddings with YaRN scaling.
When original_seq_len > 0, applies frequency interpolation with a smooth
linear ramp between beta_fast and beta_slow correction ranges."""
def find_correction_dim(num_rotations, dim, base, max_seq_len):
return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
return max(low, 0), min(high, dim-1)
def linear_ramp_factor(min, max, dim):
if min == max:
max += 0.001
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
if original_seq_len > 0:
low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_seq_len)
smooth = 1 - linear_ramp_factor(low, high, dim // 2)
freqs = freqs / factor * (1 - smooth) + freqs * smooth
t = torch.arange(seqlen)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, inverse: bool = False) -> torch.Tensor:
"""Applies rotary positional embeddings in-place. Uses conjugate for inverse (de-rotation)."""
y = x
x = torch.view_as_complex(x.float().unflatten(-1, (-1, 2)))
if inverse:
freqs_cis = freqs_cis.conj()
if x.ndim == 3:
freqs_cis = freqs_cis.view(1, x.size(1), x.size(-1))
else:
freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
x = torch.view_as_real(x * freqs_cis).flatten(-2)
y.copy_(x)
return y
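# Usage sketch (illustrative only): rotating the rope dims and then de-rotating with
# inverse=True recovers the original tensor. The helper name is hypothetical.
def _demo_rotary_roundtrip():
    seqlen, heads, rope_dim = 16, 4, 64
    freqs_cis = precompute_freqs_cis(rope_dim, seqlen, 0, 10000.0, 40, 32, 1)
    x = torch.randn(1, seqlen, heads, rope_dim, dtype=torch.float32)
    y = apply_rotary_emb(x.clone(), freqs_cis)                # in-place rotation of the clone
    z = apply_rotary_emb(y.clone(), freqs_cis, inverse=True)  # conjugate undoes it
    assert torch.allclose(x, z, atol=1e-5)
    return y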
def hadamard_transform_ref(x, scale=1.0):
"""
x: (..., dim)
out: (..., dim)
"""
if hadamard is None:
raise ImportError("Please install scipy")
x_shape = x.shape
dim = x.shape[-1]
x = x.reshape(-1, dim)
log_dim = math.ceil(math.log2(dim))
dim_padded = 2 ** log_dim
if dim != dim_padded:
x = F.pad(x, (0, dim_padded - dim))
out = F.linear(x, torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device))
out = out * scale
return out[..., :dim].reshape(*x_shape)
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
"""Applies randomized Hadamard rotation to spread information across dims before FP8 quant."""
assert x.dtype == torch.bfloat16
# from fast_hadamard_transform import hadamard_transform
return hadamard_transform_ref(x, scale=x.size(-1) ** -0.5)
@lru_cache(1)
def get_window_topk_idxs(window_size: int, bsz: int, seqlen: int, start_pos: int):
if start_pos >= window_size - 1:
start_pos %= window_size
matrix = torch.cat([torch.arange(start_pos + 1, window_size), torch.arange(0, start_pos + 1)], dim=0)
elif start_pos > 0:
matrix = F.pad(torch.arange(start_pos + 1), (0, window_size - start_pos - 1), value=-1)
else:
base = torch.arange(seqlen).unsqueeze(1)
matrix = (base - window_size + 1).clamp(0) + torch.arange(min(seqlen, window_size))
matrix = torch.where(matrix > base, -1, matrix)
return matrix.unsqueeze(0).expand(bsz, -1, -1)
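# Usage sketch (illustrative only): for a short prefill, each row lists the (at most
# window_size) most recent positions visible to that query, with -1 padding slots that
# fall outside the causal window. The helper name is hypothetical.
def _demo_window_idxs():
    idxs = get_window_topk_idxs(window_size=4, bsz=1, seqlen=6, start_pos=0)
    assert idxs.shape == (1, 6, 4) and idxs[0, 0].tolist() == [0, -1, -1, -1]
    return idxs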
@lru_cache(2)
def get_compress_topk_idxs(ratio: int, bsz: int, seqlen: int, start_pos: int, offset: int):
if start_pos > 0:
matrix = torch.arange(0, (start_pos + 1) // ratio) + offset
else:
matrix = torch.arange(seqlen // ratio).repeat(seqlen, 1)
mask = matrix >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
matrix = torch.where(mask, -1, matrix + offset)
return matrix.unsqueeze(0).expand(bsz, -1, -1)
class Compressor(nn.Module):
"""Compresses KV cache via learned gated pooling over `compress_ratio` consecutive tokens.
When overlap=True (ratio==4), uses overlapping windows for smoother compression boundaries."""
def __init__(self, args: ModelArgs, compress_ratio: int = 4, head_dim: int = 512, rotate: bool = False):
super().__init__()
self.dim = args.dim
self.head_dim = head_dim
self.rope_head_dim = args.rope_head_dim
self.nope_head_dim = head_dim - args.rope_head_dim
self.compress_ratio = compress_ratio
self.overlap = compress_ratio == 4
self.rotate = rotate
coff = 1 + self.overlap
self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
# wkv and wgate in the checkpoint are stored in bf16, while the parameters here are kept in fp32 for convenience.
# When overlap is enabled, the first half of the dims serves the overlapping window and the second half the regular one.
self.wkv = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
self.wgate = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
self.norm = RMSNorm(self.head_dim, args.norm_eps)
self.kv_cache: torch.Tensor = None # assigned lazily from Attention.kv_cache
# State buffers for decode-phase incremental compression.
# With overlap: state[:, :ratio] = overlapping window, state[:, ratio:] = current window.
self.register_buffer("kv_state", torch.zeros(args.max_batch_size, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32), persistent=False)
self.register_buffer("score_state", torch.full((args.max_batch_size, coff * compress_ratio, coff * self.head_dim), float("-inf"), dtype=torch.float32), persistent=False)
self.freqs_cis: torch.Tensor = None
def overlap_transform(self, tensor: torch.Tensor, value=0):
# tensor: [b,s,r,2d]
b, s, _, _ = tensor.size()
ratio, d = self.compress_ratio, self.head_dim
new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
return new_tensor
def forward(self, x: torch.Tensor, start_pos: int):
assert self.kv_cache is not None
bsz, seqlen, _ = x.size()
ratio, overlap, d, rd = self.compress_ratio, self.overlap, self.head_dim, self.rope_head_dim
dtype = x.dtype
# compression needs fp32 precision
x = x.float()
kv = self.wkv(x)
score = self.wgate(x)
if start_pos == 0:
should_compress = seqlen >= ratio
remainder = seqlen % ratio
cutoff = seqlen - remainder
offset = ratio if overlap else 0
if overlap and cutoff >= ratio:
self.kv_state[:bsz, :ratio] = kv[:, cutoff-ratio : cutoff]
self.score_state[:bsz, :ratio] = score[:, cutoff-ratio : cutoff] + self.ape
if remainder > 0:
kv, self.kv_state[:bsz, offset : offset+remainder] = kv.split([cutoff, remainder], dim=1)
self.score_state[:bsz, offset : offset+remainder] = score[:, cutoff:] + self.ape[:remainder]
score = score[:, :cutoff]
kv = kv.unflatten(1, (-1, ratio))
score = score.unflatten(1, (-1, ratio)) + self.ape
if overlap:
kv = self.overlap_transform(kv, 0)
score = self.overlap_transform(score, float("-inf"))
kv = (kv * score.softmax(dim=2)).sum(dim=2)
else:
should_compress = (start_pos + 1) % self.compress_ratio == 0
score += self.ape[start_pos % ratio]
if overlap:
self.kv_state[:bsz, ratio + start_pos % ratio] = kv.squeeze(1)
self.score_state[:bsz, ratio + start_pos % ratio] = score.squeeze(1)
if should_compress:
kv_state = torch.cat([self.kv_state[:bsz, :ratio, :d], self.kv_state[:bsz, ratio:, d:]], dim=1)
score_state = torch.cat([self.score_state[:bsz, :ratio, :d], self.score_state[:bsz, ratio:, d:]], dim=1)
kv = (kv_state * score_state.softmax(dim=1)).sum(dim=1, keepdim=True)
self.kv_state[:bsz, :ratio] = self.kv_state[:bsz, ratio:]
self.score_state[:bsz, :ratio] = self.score_state[:bsz, ratio:]
else:
self.kv_state[:bsz, start_pos % ratio] = kv.squeeze(1)
self.score_state[:bsz, start_pos % ratio] = score.squeeze(1)
if should_compress:
kv = (self.kv_state[:bsz] * self.score_state[:bsz].softmax(dim=1)).sum(dim=1, keepdim=True)
if not should_compress:
return
kv = self.norm(kv.to(dtype))
if start_pos == 0:
freqs_cis = self.freqs_cis[:cutoff:ratio]
else:
freqs_cis = self.freqs_cis[start_pos + 1 - self.compress_ratio].unsqueeze(0)
apply_rotary_emb(kv[..., -rd:], freqs_cis)
if self.rotate:
kv = rotate_activation(kv)
# fp4_act_quant(kv, fp4_block_size, True)
# else:
# act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
if start_pos == 0:
self.kv_cache[:bsz, :seqlen // ratio] = kv
else:
self.kv_cache[:bsz, start_pos // ratio] = kv.squeeze(1)
return kv
class Indexer(torch.nn.Module):
"""Selects top-k compressed KV positions for sparse attention via learned scoring.
Has its own Compressor (with Hadamard rotation) to build compressed KV for scoring."""
def __init__(self, args: ModelArgs, compress_ratio: int = 4):
super().__init__()
self.dim = args.dim
self.n_heads = args.index_n_heads
self.n_local_heads = args.index_n_heads // world_size
self.head_dim = args.index_head_dim
self.rope_head_dim = args.rope_head_dim
self.index_topk = args.index_topk
self.q_lora_rank = args.q_lora_rank
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim, dtype=torch.bfloat16)
self.weights_proj = ColumnParallelLinear(self.dim, self.n_heads, dtype=torch.bfloat16)
self.softmax_scale = self.head_dim ** -0.5
self.compress_ratio = compress_ratio
self.compressor = Compressor(args, compress_ratio, self.head_dim, True)
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len // compress_ratio, self.head_dim), persistent=False)
self.freqs_cis = None
def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, offset: int):
bsz, seqlen, _ = x.size()
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
ratio = self.compress_ratio
rd = self.rope_head_dim
end_pos = start_pos + seqlen
if self.compressor.kv_cache is None:
self.compressor.kv_cache = self.kv_cache
self.compressor.freqs_cis = self.freqs_cis
q = self.wq_b(qr)
q = q.unflatten(-1, (self.n_local_heads, self.head_dim))
apply_rotary_emb(q[..., -rd:], freqs_cis)
q = rotate_activation(q)
# use fp4 simulation for q and kv in indexer
# fp4_act_quant(q, fp4_block_size, True)
self.compressor(x, start_pos)
weights = self.weights_proj(x) * (self.softmax_scale * self.n_heads ** -0.5)
# QAT was applied here, so kv could also be stored in FP8; the current implementation keeps BF16.
index_score = torch.einsum("bshd,btd->bsht", q, self.kv_cache[:bsz, :end_pos // ratio])
index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2)
if world_size > 1:
dist.all_reduce(index_score)
if start_pos == 0:
mask = torch.arange(seqlen // ratio).repeat(seqlen, 1) >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
index_score += torch.where(mask, float("-inf"), 0)
topk_idxs = index_score.topk(min(self.index_topk, end_pos // ratio), dim=-1)[1]
if start_pos == 0:
mask = topk_idxs >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
topk_idxs = torch.where(mask, -1, topk_idxs + offset)
else:
topk_idxs += offset
return topk_idxs
class Attention(nn.Module):
"""Multi-head Latent Attention (MLA) with sliding window + optional KV compression.
Uses low-rank Q projection (wq_a -> q_norm -> wq_b) and grouped low-rank O projection."""
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.layer_id = layer_id
self.dim = args.dim
self.n_heads = args.n_heads
self.n_local_heads = args.n_heads // world_size
self.q_lora_rank = args.q_lora_rank
self.o_lora_rank = args.o_lora_rank
self.head_dim = args.head_dim
self.rope_head_dim = args.rope_head_dim
self.nope_head_dim = args.head_dim - args.rope_head_dim
self.n_groups = args.o_groups
self.n_local_groups = self.n_groups // world_size
self.window_size = args.window_size
self.compress_ratio = args.compress_ratios[layer_id]
self.eps = args.norm_eps
self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32))
self.wq_a = Linear(self.dim, self.q_lora_rank, dtype=torch.bfloat16)
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim, dtype=torch.bfloat16)
self.wkv = Linear(self.dim, self.head_dim, dtype=torch.bfloat16)
self.kv_norm = RMSNorm(self.head_dim, self.eps)
self.wo_a = ColumnParallelLinear(self.n_heads * self.head_dim // self.n_groups, self.n_groups * args.o_lora_rank, dtype=torch.bfloat16)
self.wo_b = RowParallelLinear(self.n_groups * args.o_lora_rank, self.dim)
self.softmax_scale = self.head_dim ** -0.5
if self.compress_ratio:
self.compressor = Compressor(args, self.compress_ratio, self.head_dim)
if self.compress_ratio == 4:
self.indexer = Indexer(args, self.compress_ratio)
else:
self.indexer = None
kv_cache_size = args.window_size + (args.max_seq_len // self.compress_ratio if self.compress_ratio else 0)
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, kv_cache_size, self.head_dim), persistent=False)
if self.compress_ratio:
original_seq_len, rope_theta = args.original_seq_len, args.compress_rope_theta
else:
# disable YaRN and use base rope_theta in pure sliding-window attention
original_seq_len, rope_theta = 0, args.rope_theta
freqs_cis = precompute_freqs_cis(self.rope_head_dim, args.max_seq_len, original_seq_len,
rope_theta, args.rope_factor, args.beta_fast, args.beta_slow)
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
def forward(self, x: torch.Tensor, start_pos: int):
bsz, seqlen, _ = x.size()
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
win = self.window_size
ratio = self.compress_ratio
rd = self.rope_head_dim
if self.compress_ratio and self.compressor.kv_cache is None:
self.compressor.kv_cache = self.kv_cache[:, win:]
self.compressor.freqs_cis = self.freqs_cis
if self.indexer is not None:
self.indexer.freqs_cis = self.freqs_cis
# q
qr = q = self.q_norm(self.wq_a(x))
q = self.wq_b(q).unflatten(-1, (self.n_local_heads, self.head_dim))
q *= torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps)
apply_rotary_emb(q[..., -rd:], freqs_cis)
# win kv & topk_idxs
kv = self.wkv(x)
kv = self.kv_norm(kv)
apply_rotary_emb(kv[..., -rd:], freqs_cis)
# FP8-simulate non-rope dims to match QAT; rope dims stay bf16 for positional precision
# act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
topk_idxs = get_window_topk_idxs(win, bsz, seqlen, start_pos)
if self.compress_ratio:
offset = kv.size(1) if start_pos == 0 else win
if self.indexer is not None:
compress_topk_idxs = self.indexer(x, qr, start_pos, offset)
else:
compress_topk_idxs = get_compress_topk_idxs(ratio, bsz, seqlen, start_pos, offset)
topk_idxs = torch.cat([topk_idxs, compress_topk_idxs], dim=-1)
topk_idxs = topk_idxs.int()
# compress kv & attn
if start_pos == 0:
if seqlen <= win:
self.kv_cache[:bsz, :seqlen] = kv
else:
cutoff = seqlen % win
self.kv_cache[:bsz, cutoff: win], self.kv_cache[:bsz, :cutoff] = kv[:, -win:].split([win - cutoff, cutoff], dim=1)
if self.compress_ratio:
if (kv_compress := self.compressor(x, start_pos)) is not None:
kv = torch.cat([kv, kv_compress], dim=1)
# QAT was applied here, so kv could also be stored in FP8; the current implementation keeps BF16.
o = sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale)
else:
self.kv_cache[:bsz, start_pos % win] = kv.squeeze(1)
if self.compress_ratio:
self.compressor(x, start_pos)
o = sparse_attn(q, self.kv_cache[:bsz], self.attn_sink, topk_idxs, self.softmax_scale)
apply_rotary_emb(o[..., -rd:], freqs_cis, True)
# o
o = o.view(bsz, seqlen, self.n_local_groups, -1)
wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1)
# NOTE: wo_a is FP8 in checkpoint; could do FP8 einsum here for better perf,
# but using BF16 for simplicity.
o = torch.einsum("bsgd,grd->bsgr", o, wo_a)
x = self.wo_b(o.flatten(2))
return x
class Gate(nn.Module):
"""MoE gating: computes expert routing scores and selects top-k experts.
Supports hash-based routing (first n_hash_layers) where expert indices are
predetermined per token ID, and score-based routing (remaining layers)."""
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.dim = args.dim
self.topk = args.n_activated_experts
self.score_func = args.score_func
self.route_scale = args.route_scale
self.hash = layer_id < args.n_hash_layers
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
if self.hash:
self.tid2eid = nn.Parameter(torch.empty(args.vocab_size, args.n_activated_experts, dtype=torch.int32), requires_grad=False)
self.bias = None
else:
self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32))
def forward(self, x: torch.Tensor, input_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
scores = linear(x.float(), self.weight.float())
if self.score_func == "softmax":
scores = scores.softmax(dim=-1)
elif self.score_func == "sigmoid":
scores = scores.sigmoid()
else:
scores = F.softplus(scores).sqrt()
original_scores = scores
# Bias shifts scores for expert selection (topk) but does not affect routing weights.
if self.bias is not None:
scores = scores + self.bias
if self.hash:
indices = self.tid2eid[input_ids]
else:
indices = scores.topk(self.topk, dim=-1)[1]
weights = original_scores.gather(1, indices)
if self.score_func != "softmax":
weights /= weights.sum(dim=-1, keepdim=True)
weights *= self.route_scale
return weights, indices
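# Routing math sketch (illustrative only, mirrors Gate.forward for the score-based path):
# the bias only steers which experts are selected, while the routing weights are gathered
# from the unbiased scores and renormalized. The helper name is hypothetical.
def _demo_routing():
    scores = F.softplus(torch.randn(5, 8)).sqrt()           # "sqrtsoftplus" scores for 8 routed experts
    bias = torch.randn(8)
    indices = (scores + bias).topk(2, dim=-1)[1]             # selection uses the biased scores
    weights = scores.gather(1, indices)                      # weights use the original scores
    weights = weights / weights.sum(dim=-1, keepdim=True)    # renormalize (non-softmax score_func)
    return weights, indices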
class Expert(nn.Module):
"""Single MoE expert: SwiGLU FFN (w1, w2, w3). Computation in float32 for stability."""
def __init__(self, dim: int, inter_dim: int, dtype=None, swiglu_limit=0):
super().__init__()
self.w1 = Linear(dim, inter_dim, dtype=dtype)
self.w2 = Linear(inter_dim, dim, dtype=dtype)
self.w3 = Linear(dim, inter_dim, dtype=dtype)
self.swiglu_limit = swiglu_limit
def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
dtype = x.dtype
gate = self.w1(x).float()
up = self.w3(x).float()
if self.swiglu_limit > 0:
up = torch.clamp(up, min=-self.swiglu_limit, max=self.swiglu_limit)
gate = torch.clamp(gate, max=self.swiglu_limit)
x = F.silu(gate) * up
if weights is not None:
x = weights * x
return self.w2(x.to(dtype))
class MoE(nn.Module):
"""Mixture-of-Experts: gate routes each token to top-k routed experts + 1 shared expert.
Experts are sharded across TP ranks; each rank handles n_routed_experts // world_size experts."""
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.layer_id = layer_id
self.dim = args.dim
assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
self.n_routed_experts = args.n_routed_experts
self.n_local_experts = args.n_routed_experts // world_size
self.n_activated_experts = args.n_activated_experts
self.experts_start_idx = rank * self.n_local_experts
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
self.gate = Gate(layer_id, args)
if args.expert_dtype == "fp4":
expert_dtype = torch.float4_e2m1fn_x2
elif args.expert_dtype == "fp8":
expert_dtype = torch.float8_e4m3fn
else:
expert_dtype = None
# NOTE: expert_dtype selects the quantized expert format; the routed experts below are
# instantiated in BF16 for this inference path, so it is currently unused.
self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim, dtype=torch.bfloat16, swiglu_limit=args.swiglu_limit) if self.experts_start_idx <= i < self.experts_end_idx else None
for i in range(self.n_routed_experts)])
assert args.n_shared_experts == 1
# no swiglu_limit
self.shared_experts = Expert(args.dim, args.moe_inter_dim, dtype=torch.bfloat16)
def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
shape = x.size()
x = x.view(-1, self.dim)
weights, indices = self.gate(x, input_ids.flatten())
y = torch.zeros_like(x, dtype=torch.float32)
# Per-expert token counts. bincount is computed on the CPU (with explicit synchronization
# around the device<->host transfer) as a workaround; the all-device equivalent would be
# torch.bincount(indices.flatten(), minlength=self.n_routed_experts).
torch.cuda.synchronize()
indices_cpu = indices.flatten().cpu()
counts_cpu = torch.bincount(indices_cpu, minlength=self.n_routed_experts)
counts = counts_cpu.cuda()
torch.cuda.synchronize()
for i in range(self.experts_start_idx, self.experts_end_idx):
if counts[i] == 0:
continue
expert = self.experts[i]
idx, top = torch.where(indices == i)
y[idx] += expert(x[idx], weights[idx, top, None])
if world_size > 1:
dist.all_reduce(y)
y += self.shared_experts(x)
return y.type_as(x).view(shape)
class Block(nn.Module):
"""Transformer block with Hyper-Connections (HC) mixing.
Instead of a simple residual, HC maintains `hc_mult` copies of the hidden state.
hc_pre: reduces hc copies -> 1 via learned weighted sum (pre-weights from Sinkhorn).
hc_post: expands 1 -> hc copies via learned post-weights + combination matrix."""
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.layer_id = layer_id
self.norm_eps = args.norm_eps
self.attn = Attention(layer_id, args)
self.ffn = MoE(layer_id, args)
self.attn_norm = RMSNorm(args.dim, self.norm_eps)
self.ffn_norm = RMSNorm(args.dim, self.norm_eps)
self.hc_mult = hc_mult = args.hc_mult
self.hc_sinkhorn_iters = args.hc_sinkhorn_iters
self.hc_eps = args.hc_eps
mix_hc = (2 + hc_mult) * hc_mult
hc_dim = hc_mult * args.dim
with set_dtype(torch.float32):
self.hc_attn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
self.hc_ffn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
self.hc_attn_base = nn.Parameter(torch.empty(mix_hc))
self.hc_ffn_base = nn.Parameter(torch.empty(mix_hc))
self.hc_attn_scale = nn.Parameter(torch.empty(3))
self.hc_ffn_scale = nn.Parameter(torch.empty(3))
def hc_pre(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
# x: [b,s,hc,d], hc_fn: [mix_hc,hc*d], hc_scale: [3], hc_base: [mix_hc], y: [b,s,hc,d]
shape, dtype = x.size(), x.dtype
x = x.flatten(2).float()
rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
mixes = F.linear(x, hc_fn) * rsqrt
pre, post, comb = hc_split_sinkhorn(mixes, hc_scale, hc_base, self.hc_mult, self.hc_sinkhorn_iters, self.hc_eps)
y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
return y.to(dtype), post, comb
def hc_post(self, x: torch.Tensor, residual: torch.Tensor, post: torch.Tensor, comb: torch.Tensor):
# x: [b,s,d], residual: [b,s,hc,d], post: [b,s,hc], comb: [b,s,hc,hc], y: [b,s,hc,d]
y = post.unsqueeze(-1) * x.unsqueeze(-2) + torch.sum(comb.unsqueeze(-1) * residual.unsqueeze(-2), dim=2)
return y.type_as(x)
def forward(self, x: torch.Tensor, start_pos: int, input_ids: Optional[torch.Tensor]) -> torch.Tensor:
residual = x
x, post, comb = self.hc_pre(x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base)
x = self.attn_norm(x)
x = self.attn(x, start_pos)
x = self.hc_post(x, residual, post, comb)
residual = x
x, post, comb = self.hc_pre(x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base)
x = self.ffn_norm(x)
x = self.ffn(x, input_ids)
x = self.hc_post(x, residual, post, comb)
return x
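# Shape sketch of the Hyper-Connections mixing (illustrative only; the pre/post/comb tensors
# here are random stand-ins for the Sinkhorn outputs, and the tanh stands in for the
# norm + attention/FFN sublayer). The helper name is hypothetical.
def _demo_hyper_connections():
    b, s, hc, d = 1, 3, 4, 8
    residual = torch.randn(b, s, hc, d)
    pre, post = torch.rand(b, s, hc), torch.rand(b, s, hc)
    comb = torch.rand(b, s, hc, hc)
    x = (pre.unsqueeze(-1) * residual).sum(dim=2)            # hc_pre: hc copies -> one stream [b, s, d]
    x = torch.tanh(x)                                        # sublayer stand-in
    y = post.unsqueeze(-1) * x.unsqueeze(-2) + (comb.unsqueeze(-1) * residual.unsqueeze(-2)).sum(dim=2)
    assert y.shape == residual.shape                         # hc_post: back to hc copies [b, s, hc, d]
    return y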
class ParallelHead(nn.Module):
def __init__(self, vocab_size: int, dim: int, norm_eps: float = 1e-6, hc_eps: float = 1e-6):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
self.norm_eps = norm_eps
self.hc_eps = hc_eps
self.part_vocab_size = (vocab_size // world_size)
# lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later.
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim, dtype=torch.float32))
def get_logits(self, x):
return F.linear(x[:, -1].float(), self.weight)
def forward(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, norm: RMSNorm):
# x: [b,s,hc,d]
x = self.hc_head(x, hc_fn, hc_scale, hc_base)
logits = self.get_logits(norm(x))
if world_size > 1:
all_logits = [torch.empty_like(logits) for _ in range(world_size)]
dist.all_gather(all_logits, logits)
logits = torch.cat(all_logits, dim=-1)
return logits
def hc_head(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
shape, dtype = x.size(), x.dtype
x = x.flatten(2).float()
rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
mixes = F.linear(x, hc_fn) * rsqrt
pre = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps
y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
return y.to(dtype)
class MTPBlock(Block):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__(layer_id, args)
self.e_proj = Linear(args.dim, args.dim, dtype=torch.bfloat16)
self.h_proj = Linear(args.dim, args.dim, dtype=torch.bfloat16)
self.enorm = RMSNorm(args.dim, args.norm_eps)
self.hnorm = RMSNorm(args.dim, args.norm_eps)
self.norm = RMSNorm(args.dim, args.norm_eps)
self.hc_mult = hc_mult = args.hc_mult
hc_dim = hc_mult * args.dim
with set_dtype(torch.float32):
self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
self.hc_head_scale = nn.Parameter(torch.empty(1))
self.embed: ParallelEmbedding = None
self.head: ParallelHead = None
@torch.inference_mode()
def forward(self, x: torch.Tensor, start_pos: int, input_ids: torch.Tensor) -> torch.Tensor:
# x: [b,s,hc,d]
assert self.embed is not None and self.head is not None
e = self.embed(input_ids)
e = self.enorm(e)
x = self.hnorm(x)
x = self.e_proj(e).unsqueeze(2) + self.h_proj(x)
x = super().forward(x, start_pos, input_ids)
logits = self.head(x, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
return logits
class Transformer(nn.Module):
"""Full DeepSeek-V4 model: embed -> HC-expand -> N blocks -> HC-head -> logits.
Sets global state (world_size, rank, default_dtype, scale_fmt, scale_dtype) in __init__."""
def __init__(self, args: ModelArgs):
# Only world_size and rank are promoted to module globals; the dtype/scale assignments
# below stay local, leaving Linear on its BF16 defaults with the plain F.linear path.
global world_size, rank
world_size = dist.get_world_size() if dist.is_initialized() else 1
rank = dist.get_rank() if dist.is_initialized() else 0
default_dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
scale_fmt = "ue8m0" if args.scale_dtype == "fp8" else args.scale_fmt
scale_dtype = torch.float8_e8m0fnu if args.scale_dtype == "fp8" else torch.float32
super().__init__()
self.max_seq_len = args.max_seq_len
self.norm_eps = args.norm_eps
self.hc_eps = args.hc_eps
self.embed = ParallelEmbedding(args.vocab_size, args.dim)
self.layers = torch.nn.ModuleList()
for layer_id in range(args.n_layers):
self.layers.append(Block(layer_id, args))
self.norm = RMSNorm(args.dim, self.norm_eps)
self.head = ParallelHead(args.vocab_size, args.dim, self.norm_eps, self.hc_eps)
self.mtp = torch.nn.ModuleList()
for layer_id in range(args.n_mtp_layers):
self.mtp.append(MTPBlock(args.n_layers + layer_id, args))
self.mtp[-1].embed = self.embed
self.mtp[-1].head = self.head
self.hc_mult = hc_mult = args.hc_mult
hc_dim = hc_mult * args.dim
with set_dtype(torch.float32):
self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
self.hc_head_scale = nn.Parameter(torch.empty(1))
@torch.inference_mode()
def forward(self, input_ids: torch.Tensor, start_pos: int = 0):
h = self.embed(input_ids)
# Expand to hc_mult copies for Hyper-Connections
h = h.unsqueeze(2).repeat(1, 1, self.hc_mult, 1)
for layer in self.layers:
h = layer(h, start_pos, input_ids)
logits = self.head(h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
return logits
if __name__ == "__main__":
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device("cuda")
torch.manual_seed(0)
args = ModelArgs(n_hash_layers=0)
x = torch.randint(0, args.vocab_size, (2, 128))
model = Transformer(args)
print(model(x).size())
for i in range(128, 150):
print(i, model(x[:, 0:1], i).size())
h = torch.randn(2, 128, args.hc_mult, args.dim)
mtp = model.mtp[0]
print(mtp(h, 0, x).size())
print(mtp(h[:, 0:1], 1, x[:, 0:1]).size())
#!/usr/bin/env bash
set -euo pipefail
export NCCL_ALGO="Ring"
export NCCL_PROTO="Simple"
# MP: tensor-parallel degree (one process per DCU); CKPT_PATH: converted/sharded checkpoint directory.
export MP=8
export CONFIG="config.json"
export CKPT_PATH="deepseek-ai/DeepSeek-V4-Flash-bf16-mp8"
torchrun \
--nproc-per-node "${MP}" \
generate.py \
--ckpt-path "${CKPT_PATH}" \
--config "${CONFIG}" \
--interactive
# Unique model identifier
modelCode=2397
# Model name
modelName=DeepSeek-V4
# Model description
modelDescription=DeepSeek-V4: Towards Highly Efficient Million-Token Context Intelligence.
# Run stage
processType=Inference
# Algorithm category
appCategory=Dialogue Q&A
# Framework type
frameType=pytorch
# Accelerator card type
accelerateType=BW1100