"git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "7123d831328321e854b78047effe7a57192a764f"
Commit 5e61101f authored by luopl's avatar luopl
Browse files

Initial commit

parents
MIT License
Copyright (c) 2023 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
## DeepSeek-V4
## Paper
[DeepSeek-V4: Towards Highly Efficient Million-Token Context Intelligence](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro/blob/main/DeepSeek_V4.pdf)
## Model Introduction
A preview release of the DeepSeek-V4 series, comprising two powerful Mixture-of-Experts (MoE) language models: DeepSeek-V4-Pro (1.6T total parameters, 49B activated) and DeepSeek-V4-Flash (284B total parameters, 13B activated). Both support a context length of one million tokens.
The DeepSeek-V4 series introduces several key upgrades in architecture and optimization:
- Hybrid attention architecture: a hybrid attention mechanism combining Compressed Sparse Attention (CSA) and Heavily Compressed Attention (HCA) significantly improves long-context efficiency. At a one-million-token context, DeepSeek-V4-Pro needs only 27% of the per-token inference FLOPs and 10% of the KV cache of DeepSeek-V3.2.
- Manifold-Constrained Hyper-Connections (mHC): building on conventional residual connections, mHC stabilizes cross-layer signal propagation while preserving the model's expressiveness (see the sketch after this list).
- Muon optimizer: the Muon optimizer is adopted for faster convergence and higher training stability.
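For reference, the mHC mixing matrices are produced at inference time by `hc_split_sinkhorn_kernel` in `inference/kernel.py`: a small hc x hc combination matrix is row-softmaxed and then Sinkhorn-normalized toward a (near) doubly stochastic matrix. A minimal PyTorch sketch of that normalization follows; the function name and shapes are illustrative, not part of the released code.
```python
import torch

def sinkhorn_comb(logits: torch.Tensor, n_iters: int = 20, eps: float = 1e-6) -> torch.Tensor:
    """Normalize an (..., hc, hc) mixing matrix toward a doubly stochastic one,
    mirroring hc_split_sinkhorn_kernel in inference/kernel.py."""
    comb = logits.softmax(dim=-1) + eps                       # row-wise softmax, as in the kernel
    comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)      # normalize columns
    for _ in range(n_iters - 1):
        comb = comb / (comb.sum(dim=-1, keepdim=True) + eps)  # normalize rows
        comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)  # normalize columns
    return comb
```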
## Environment Dependencies
| Software | Version |
| :------: |:-------:|
| DTK | 26.04 |
| python | 3.10.12 |
| torch | 2.9.0+das.opt1.dtk2604.20260331.g4e3c1e7 |
| tilelang | 0.1.7.post3+cpu.git52700923 |
Recommended image: harbor.sourcefind.cn:5443/dcu/admin/base/custom:torch-2.9.0-ubuntu22.04-dtk26.04-deepseek-v4-0425
- Adjust the mount paths (`-v`) to match your actual model and data locations.
```bash
docker run -it \
--shm-size 200g \
--network=host \
--name deepseek-v4 \
--privileged \
--device=/dev/kfd \
--device=/dev/dri \
--device=/dev/mkfd \
--group-add video \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
-u root \
-v /opt/hyhal/:/opt/hyhal/:ro \
-v /path/your_code_data/:/path/your_code_data/ \
harbor.sourcefind.cn:5443/dcu/admin/base/custom:torch-2.9.0-ubuntu22.04-dtk26.04-deepseek-v4-0425 bash
```
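Inside the started container, a quick sanity check of the key packages can confirm the expected builds (an illustrative check; adjust to your image):
```bash
python3 -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python3 -c "import importlib.metadata as m; print(m.version('tilelang'))"
```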
More images are available for download from [光源](https://sourcefind.cn/#/service-list).
The specialized deep-learning libraries required by the DCU GPUs in this project can be downloaded and installed from the [光合](https://developer.sourcefind.cn/tool/) developer community.
## Dataset
`N/A`
## Training
`N/A`
## Inference
### PyTorch
#### Single-Node Inference
1. Model conversion and sharding
```bash
# Note: set the paths and parameters in the script to your actual values
# INPUT_FP8_HF_PATH is the downloaded model path; OUTPUT_BF16_HF_PATH is where the bf16 model will be stored; SAVE_PATH is the path for the sharded model; adjust mp to the actual number of cards
cd convert_weight
bash convert_weight.sh
```
2. Start interactive chat inference
```bash
cd ../inference
# Note: set the paths and parameters in the script to your actual values
sh start_torch.sh
```
## Results
<div align=center>
<img src="./doc/result_dcu.png"/>
</div>
### Accuracy
`DCU accuracy is consistent with GPU; inference framework: PyTorch.`
## Pretrained Weights
| Model | Weight Size | DCU Model | Minimum Number of Cards | Download |
|:-----:|:----:|:------:|:------:|:----------:|
| DeepSeek-V4-Flash | 158B | BW1100 | 8 | [Hugging Face](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash) |
## Source Repository & Issue Feedback
- https://developer.sourcefind.cn/codes/modelzoo/deepseek-v4
## References
- https://github.com/deepseek-ai
import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange
import torch
from safetensors.torch import safe_open, save_file
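# E2M1 FP4 lookup: 4-bit code -> value; codes 0-7 are positive, codes 8-15 (sign bit set) are negative.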
FP4_TABLE = torch.tensor([
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0
], dtype=torch.float32)
def cast_e2m1fn_to_e4m3fn(x: torch.Tensor, scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Casts a tensor from e2m1fn to e4m3fn losslessly.
"""
assert x.dtype == torch.int8
assert x.ndim == 2
out_dim, in_dim = x.size()
in_dim *= 2
fp8_block_size = 128
fp4_block_size = 32
assert in_dim % fp8_block_size == 0 and out_dim % fp8_block_size == 0
assert scale.size(0) == out_dim and scale.size(1) == in_dim // fp4_block_size
x = x.view(torch.uint8)
low = x & 0x0F
high = (x >> 4) & 0x0F
x = torch.stack([FP4_TABLE[low.long()], FP4_TABLE[high.long()]], dim=-1).flatten(2)
# max_fp4 (6.0) * MAX_OFFSET must fit in e4m3fn (max 448)
# 6.0 * 2^6 = 384 < 448; 6.0 * 2^7 = 768 > 448; so MAX_OFFSET_BITS = 6
MAX_OFFSET_BITS = 6
bOut = out_dim // fp8_block_size
bIn = in_dim // fp8_block_size
# bOut, bIn, 128, 128
x = x.view(bOut, fp8_block_size, bIn, fp8_block_size).transpose(1, 2)
# bOut, bIn, 128*4
scale = scale.float().view(bOut, fp8_block_size, bIn, -1).transpose(1, 2).flatten(2)
## bOut, bIn, 1
scale_max_offset_bits = scale.amax(dim=-1, keepdim=True) / (2**MAX_OFFSET_BITS)
# bOut, bIn, 128*4
offset = scale / scale_max_offset_bits
# bOut, bIn, 128, 128
offset = offset.unflatten(-1, (fp8_block_size, -1)).repeat_interleave(fp4_block_size, dim=-1)
x = (x * offset).transpose(1, 2).reshape(out_dim, in_dim)
return x.to(torch.float8_e4m3fn), scale_max_offset_bits.squeeze(-1).to(torch.float8_e8m0fnu)
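# Maps Hugging Face parameter-name fragments to (internal name, shard dimension).
# A shard dimension of None means the tensor is replicated across model-parallel ranks;
# expert tensors are instead assigned whole to the rank that owns that expert.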
mapping = {
"embed_tokens": ("embed", 0),
"input_layernorm": ("attn_norm", None),
"post_attention_layernorm": ("ffn_norm", None),
"q_proj": ("wq", 0),
"q_a_proj": ("wq_a", None),
"q_a_layernorm": ("q_norm", None),
"q_b_proj": ("wq_b", 0),
"kv_a_proj_with_mqa": ("wkv_a", None),
"kv_a_layernorm": ("kv_norm", None),
"kv_b_proj": ("wkv_b", 0),
"o_proj": ("wo", 1),
"gate_proj": ("w1", 0),
"down_proj": ("w2", 1),
"up_proj": ("w3", 0),
"lm_head": ("head", 0),
"embed": ("embed", 0),
"wq_b": ("wq_b", 0),
"wo_a": ("wo_a", 0),
"wo_b": ("wo_b", 1),
"head": ("head", 0),
"attn_sink": ("attn_sink", 0),
"weights_proj": ("weights_proj", 0),
}
def main(hf_ckpt_path, save_path, n_experts, mp, expert_dtype):
"""
Converts and saves model checkpoint files into a specified format.
Args:
hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
save_path (str): Path to the directory where the converted checkpoint files will be saved.
n_experts (int): Total number of experts in the model.
        mp (int): Model parallelism factor.
        expert_dtype (str, optional): Requested expert dtype ("fp8" or "fp4"); currently unused in this function.
Returns:
None
"""
torch.set_num_threads(8)
n_local_experts = n_experts // mp
state_dicts = [{} for _ in range(mp)]
for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
with safe_open(file_path, framework="pt", device="cpu") as f:
for name in f.keys():
param: torch.Tensor = f.get_tensor(name)
if name.startswith("model."):
name = name[len("model."):]
if name.startswith("mtp.") and ("emb" in name or name.endswith("head.weight")):
continue
name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias")
if any(x in name for x in ["hc", "attn_sink", "tie2eid", "ape"]): # without .weight
key = name.split(".")[-1]
else:
key = name.split(".")[-2]
if key in mapping:
new_key, dim = mapping[key]
else:
new_key, dim = key, None
name = name.replace(key, new_key)
for i in range(mp):
new_param = param
if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3])
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
continue
                    elif dim is not None:
                        if "wo_a" not in name and "wo_b" not in name:
                            assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
                            shard_size = param.size(dim) // mp
                            new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
                        else:
                            # wo_a / wo_b are only sharded 8 ways; for mp > 8, groups of mp/8
                            # consecutive ranks reuse the same shard.
                            num_projection_groups = mp // 8
                            new_mp = mp // num_projection_groups
                            new_i = i // num_projection_groups
                            shard_size = param.size(dim) // new_mp
                            assert shard_size == 1024
                            new_param = param.narrow(dim, new_i * shard_size, shard_size).contiguous()
state_dicts[i][name] = new_param
os.makedirs(save_path, exist_ok=True)
for i in trange(mp):
names = list(state_dicts[i].keys())
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
for file in ["tokenizer.json", "tokenizer_config.json"]:
old_file_path = os.path.join(hf_ckpt_path, file)
new_file_path = os.path.join(save_path, file)
shutil.copyfile(old_file_path, new_file_path)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--hf-ckpt-path", type=str, required=True)
parser.add_argument("--save-path", type=str, required=True)
parser.add_argument("--n-experts", type=int, required=True)
parser.add_argument("--model-parallel", type=int, required=True)
parser.add_argument("--expert-dtype", type=str, choices=["fp8", "fp4"], required=False, default=None)
args = parser.parse_args()
assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel, args.expert_dtype)
"""
DeepSeek-V3.2 FP8/FP4 -> BF16 Converter
Key differences from V3:
- Scale format: ue8m0 (unsigned E8M0, power-of-2 scales, may be stored as uint8)
- weight_dequant uses block-reshape approach (from V3.2 model.py:490)
- DSA (DeepSeek Attention): each layer has `indexer` submodule
- indexer.wq_b, indexer.wk: FP8 (have weight_scale_inv)
- indexer.k_norm (weight+bias), indexer.weights_proj: BF16 (no scale)
- MTP layer (layer 61): enorm, hnorm, eh_proj, shared_head.norm/head, embed_tokens
- No dependency on kernel.py (pure PyTorch)
- 62 layers (0-61), 163 shards, ~92k keys
- Experts weights use FP4 (MXFP4 E2M1, packed 2 per uint8, group_size=32, E8M0 scale)
"""
import os
import json
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm
import torch
from safetensors.torch import load_file, save_file
BLOCK_SIZE = 128
FP4_GROUP_SIZE = 32
# E2M1 FP4 lookup table: 4-bit index -> float value
# Bit layout: sign(1) | exponent(2) | mantissa(1), bias=1
_FP4_E2M1_LUT = torch.tensor([
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
], dtype=torch.bfloat16)
def dequant_fp4_weight(weight_packed: torch.Tensor, scale_fp8: torch.Tensor) -> torch.Tensor:
"""
Dequantize MXFP4 weight to bf16 by unpacking FP4 E2M1 nibbles via LUT.
Each int8 byte stores 2 FP4 values (low nibble + high nibble).
float4_e2m1fn_x2 does NOT support .to(bfloat16), so we manually unpack.
Scale is E8M0, one per group of 32 elements along the last dimension.
Args:
weight_packed: [out_features, in_features/2], int8, each byte = 2 FP4 values
scale_fp8: [out_features, in_features/32], float8_e8m0fnu, E8M0 scale per group of 32
Returns:
bf16 tensor [out_features, in_features]
"""
out_features, packed_in = weight_packed.shape
in_features = packed_in * 2
# Unpack two FP4 values per byte via nibble extraction + LUT
raw = weight_packed.to(torch.uint8)
low_nibble = (raw & 0x0F).to(torch.long)
high_nibble = (raw >> 4).to(torch.long)
lut = _FP4_E2M1_LUT.to(weight_packed.device)
low_vals = lut[low_nibble]
high_vals = lut[high_nibble]
# Interleave: [low_0, high_0, low_1, high_1, ...]
fp4_values = torch.stack([low_vals, high_vals], dim=-1).reshape(out_features, in_features)
# Decode E8M0 scale and expand to match fp4_values
scale = decode_e8m0_scale(scale_fp8)
if scale.dim() == 2 and scale.shape[0] == out_features:
# Scale already shaped [out_features, num_groups_per_row]
num_groups_per_row = scale.shape[1]
else:
# Flat scale: compute groups per row from total count
total_scales = scale.numel()
num_groups_per_row = total_scales // out_features
scale = scale.reshape(out_features, num_groups_per_row)
actual_group_size = in_features // num_groups_per_row
scale = scale.unsqueeze(-1).expand(-1, -1, actual_group_size).reshape(out_features, in_features)
return (fp4_values * scale).to(torch.bfloat16)
def is_expert_weight(name: str) -> bool:
"""Check if a weight belongs to an expert (MoE) layer, excluding shared_experts."""
return "experts" in name and "shared_experts" not in name
def decode_e8m0_scale(scale: torch.Tensor) -> torch.Tensor:
"""
Decode E8M0 (unsigned 8-bit exponent-only) scale to float32.
E8M0 stores only the exponent: value = 2^(exp - 127), same as IEEE 754 exponent encoding.
If scale is already float32, return as-is.
"""
if scale.dtype == torch.float32:
return scale
if scale.dtype in (torch.bfloat16, torch.float16):
return scale.float()
# float8_e8m0fnu: .to(int32) does value conversion not bit reinterpret,
# so use native .float() which handles E8M0 correctly
if scale.dtype == torch.float8_e8m0fnu:
return scale.float()
# uint8 / int8 raw E8M0 bytes: interpret as IEEE 754 exponent
if scale.element_size() == 1:
# Reconstruct float32 from exponent: set exponent bits in IEEE 754 float32
# float32 = sign(1) + exponent(8) + mantissa(23)
# E8M0 value stored as raw exponent byte -> float = 2^(byte - 127)
exp_bits = scale.to(torch.int32) << 23
return exp_bits.view(torch.float32)
return scale.float()
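# Worked example (illustrative): a raw E8M0 byte of 133 encodes 2**(133 - 127) = 64.0;
# shifting it into the float32 exponent field (133 << 23) and bit-reinterpreting yields exactly 64.0.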
def weight_dequant(weight: torch.Tensor, scale: torch.Tensor, block_size: int = BLOCK_SIZE) -> torch.Tensor:
"""
Dequantize FP8 weight to BF16 using block-wise scale.
Based on V3.2 model.py:490-495, works on both CPU and CUDA.
Each (block_size x block_size) block of the weight is multiplied by one scale value.
Handles non-aligned dimensions (e.g. kv_a_proj_with_mqa shape 576x7168, 576 % 128 != 0)
by padding to the nearest multiple of block_size, dequantizing, then trimming.
"""
shape = weight.shape
assert weight.dim() == 2, f"Expected 2D weight, got {weight.dim()}D"
M, N = shape
# Decode E8M0 scale if needed
scale = decode_e8m0_scale(scale)
# Pad to nearest multiple of block_size if needed
pad_m = (block_size - M % block_size) % block_size
pad_n = (block_size - N % block_size) % block_size
if pad_m or pad_n:
weight = torch.nn.functional.pad(weight, (0, pad_n, 0, pad_m))
Mp, Np = weight.shape
# V3.2 dequant: reshape into blocks, scale, reshape back
weight = weight.view(
Mp // block_size, block_size,
Np // block_size, block_size
).transpose(1, 2).contiguous().view(-1, block_size * block_size)
weight = (weight.float() * scale.reshape(-1, 1)).to(torch.bfloat16)
weight = weight.view(
Mp // block_size, Np // block_size,
block_size, block_size
).transpose(1, 2).contiguous().view(Mp, Np)
# Trim padding
if pad_m or pad_n:
weight = weight[:M, :N]
return weight
def main(fp8_path, bf16_path, device="cuda"):
torch.set_default_dtype(torch.bfloat16)
if device == "cuda" and not torch.cuda.is_available():
print("CUDA not available, falling back to CPU")
device = "cpu"
os.makedirs(bf16_path, exist_ok=True)
# 1. Copy non-safetensor files (config.json, tokenizer, etc.)
print("Copying auxiliary files...")
for file_path in glob(os.path.join(fp8_path, "*")):
fname = os.path.basename(file_path)
if fname.endswith(".safetensors") or fname == "model.safetensors.index.json":
continue
dst = os.path.join(bf16_path, fname)
if os.path.isfile(file_path):
shutil.copy2(file_path, dst)
print(f" Copied {fname}")
# 2. Load model index
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
model_index = json.load(f)
weight_map = model_index["weight_map"]
# 3. Pre-build scale_inv lookup: weight_name -> scale_inv_name
# V3.2 naming: "xxx.weight" -> "xxx.weight_scale_inv"
    scale_inv_map = {}
    all_scale_names = set()
    for name in weight_map:
        # Handle both "xxx.weight_scale_inv" (HF/V3.2 naming) and "xxx.scale" style names
        if name.endswith("_scale_inv"):
            weight_name = name[:-len("_scale_inv")]
        elif name.endswith("scale"):
            weight_name = name[:-len("scale")] + "weight"
        else:
            continue
        if weight_name in weight_map:
            all_scale_names.add(name)
            scale_inv_map[weight_name] = name
# Separate expert (FP4) vs non-expert (FP8) scale mappings
fp4_scale_map = {k: v for k, v in scale_inv_map.items() if is_expert_weight(k)}
fp8_scale_map = {k: v for k, v in scale_inv_map.items() if not is_expert_weight(k)}
print(f"Model: DeepSeek (DeepseekForCausalLM)")
print(f"Device: {device}")
print(f"Total keys in index: {len(weight_map)}")
print(f"FP4 expert weights with scale: {len(fp4_scale_map)}")
print(f"FP8 non-expert weights with scale: {len(fp8_scale_map)}")
print(f"Scale entries: {len(all_scale_names)}")
# Cache for loaded safetensor files on CPU (handles cross-shard scale lookups)
# All shards are loaded to CPU to avoid GPU OOM; individual tensors are moved to
# GPU for dequant one at a time.
loaded_files = {}
use_cuda = (device == "cuda")
def get_tensor(tensor_name):
"""Load tensor from the correct shard file (CPU), with caching."""
file_name = weight_map.get(tensor_name)
if file_name is None:
raise KeyError(tensor_name)
if file_name not in loaded_files:
file_path = os.path.join(fp8_path, file_name)
loaded_files[file_name] = load_file(file_path, device="cpu")
return loaded_files[file_name][tensor_name]
# 4. Process safetensor files one by one
safetensor_files = sorted(glob(os.path.join(fp8_path, "*.safetensors")))
converted_count = 0
kept_count = 0
for safetensor_file in tqdm(safetensor_files, desc="Converting FP8 -> BF16"):
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cpu")
loaded_files[file_name] = current_state_dict
new_state_dict = {}
for weight_name, weight in current_state_dict.items():
# Skip scale_inv tensors (will be removed from output)
if weight_name in all_scale_names:
continue
if weight.element_size() == 1 and weight_name in scale_inv_map:
scale_inv_name = scale_inv_map[weight_name]
try:
scale_inv = get_tensor(scale_inv_name)
if is_expert_weight(weight_name):
# FP4 expert weight -> dequantize using MXFP4 logic
# print(f" FP4 dequant: {weight_name}, weight={weight.shape}, scale={scale_inv.shape}, dtype={scale_inv.dtype}")
if use_cuda:
result = dequant_fp4_weight(weight.cuda(), scale_inv.cuda())
new_state_dict[weight_name] = result.cpu()
else:
new_state_dict[weight_name] = dequant_fp4_weight(weight, scale_inv)
else:
# FP8 non-expert weight -> dequantize using block-wise scale
if use_cuda:
result = weight_dequant(weight.cuda(), scale_inv.cuda())
new_state_dict[weight_name] = result.cpu()
else:
new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
converted_count += 1
except KeyError:
print(f" Warning: scale '{scale_inv_name}' not loadable for {weight_name}, keeping raw")
new_state_dict[weight_name] = weight
else:
# BF16/FP32 weights: norms, biases, gate.weight, indexer.weights_proj,
# indexer.k_norm, MTP layers (enorm, hnorm, eh_proj, shared_head), embed, lm_head, etc.
new_state_dict[weight_name] = weight
kept_count += 1
# Save converted shard
save_file(new_state_dict, os.path.join(bf16_path, file_name))
# Memory management: keep at most 2 cached shard files on CPU
# (needed for cross-shard scale lookups, e.g. weight in shard N, scale in shard N+1)
while len(loaded_files) > 2:
oldest = next(iter(loaded_files))
del loaded_files[oldest]
# 5. Update model index: remove all scale_inv entries
new_weight_map = {k: v for k, v in weight_map.items() if k not in all_scale_names}
new_index = {
"metadata": model_index.get("metadata", {}),
"weight_map": new_weight_map,
}
with open(os.path.join(bf16_path, "model.safetensors.index.json"), "w") as f:
json.dump(new_index, f, indent=2)
print(f"\nDone!")
print(f" FP8/FP4 -> BF16 converted: {converted_count}")
print(f" Already BF16/FP32 (kept as-is): {kept_count}")
print(f" Scale entries removed: {len(all_scale_names)}")
print(f" Output keys: {len(new_weight_map)} (was {len(weight_map)})")
print(f" Output saved to: {bf16_path}")
if __name__ == "__main__":
parser = ArgumentParser(description="Convert DeepSeek-V3.2 FP8/FP4 checkpoint to BF16")
parser.add_argument("--input-fp8-hf-path", type=str, required=True,
help="Path to the FP8 HuggingFace model directory (DeepSeek-V3.2)")
parser.add_argument("--output-bf16-hf-path", type=str, required=True,
help="Path to the output BF16 model directory")
parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"],
help="Device for dequantization (default: cuda)")
args = parser.parse_args()
main(args.input_fp8_hf_path, args.output_bf16_hf_path, args.device)
#!/usr/bin/env bash
set -euo pipefail
# Step 1: fp4/fp8 -> bf16
INPUT_FP8_HF_PATH="deepseek-ai/DeepSeek-V4-Flash"
OUTPUT_BF16_HF_PATH="deepseek-ai/DeepSeek-V4-Flash-bf16"
python3 convert_weight.py \
--input-fp8-hf-path "${INPUT_FP8_HF_PATH}" \
--output-bf16-hf-path "${OUTPUT_BF16_HF_PATH}"
# Step 2: bf16 -> bf16-mp${MP}
MP=8
HF_CKPT_BF16_PATH="${OUTPUT_BF16_HF_PATH}"
SAVE_PATH="deepseek-ai/DeepSeek-V4-Flash-bf16-mp${MP}"
EXPERTS=256
python3 convert.py \
--hf-ckpt-path "${HF_CKPT_BF16_PATH}" \
--save-path "${SAVE_PATH}" \
--n-experts "${EXPERTS}" \
--model-parallel "${MP}"
# Inference code for DeepSeek models
First, convert the Hugging Face model weight files to the format used by this project.
```bash
export EXPERTS=256
export MP=4
export CONFIG=config.json
python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
```
Then chat with the DeepSeek model at will!
```bash
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
```
Or run batch inference from a file.
```bash
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file ${FILE}
```
Or run multi-node inference.
```bash
torchrun --nnodes ${NODES} --nproc-per-node $((MP / NODES)) --node-rank $RANK --master-addr $ADDR generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file ${FILE}
```
If you want to use FP8 experts, remove `"expert_dtype": "fp4"` from `config.json` and pass `--expert-dtype fp8` to `convert.py`.
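For example, reusing the variables from the conversion step above:
```bash
python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP} --expert-dtype fp8
```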
{
"vocab_size": 129280,
"dim": 4096,
"moe_inter_dim": 2048,
"n_layers": 43,
"n_hash_layers": 3,
"n_heads": 64,
"n_routed_experts": 256,
"n_shared_experts": 1,
"n_activated_experts": 6,
"score_func": "sqrtsoftplus",
"route_scale": 1.5,
"swiglu_limit": 10.0,
"q_lora_rank": 1024,
"head_dim": 512,
"rope_head_dim": 64,
"o_groups": 8,
"o_lora_rank": 1024,
"window_size": 128,
"original_seq_len": 65536,
"rope_theta": 10000,
"rope_factor": 16,
"beta_fast": 32,
"beta_slow": 1,
"index_n_heads": 64,
"index_head_dim": 128,
"index_topk": 512,
"hc_mult": 4,
"hc_sinkhorn_iters": 20,
"dtype": "fp8",
"scale_fmt": "ue8m0",
"compress_rope_theta": 160000,
"compress_ratios": [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0]
}
import os
import json
import sys
from argparse import ArgumentParser
from typing import List
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model
from model import Transformer, ModelArgs
current_dir = os.path.dirname(os.path.abspath(__file__))
encoding_dir = os.path.join(current_dir, '../encoding')
sys.path.insert(0, os.path.abspath(encoding_dir))
from encoding_dsv4 import encode_messages, parse_message_from_completion_text
def sample(logits, temperature: float = 1.0):
"""Gumbel-max trick: equivalent to multinomial sampling but faster on GPU,
since it avoids the GPU-to-CPU sync in torch.multinomial."""
logits = logits / max(temperature, 1e-5)
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
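# Why sample() works: argmax_i(p_i / E_i) with i.i.d. E_i ~ Exponential(1) equals
# argmax_i(log p_i - log E_i), and -log E_i is standard Gumbel noise, so the argmax is an
# exact draw from the categorical distribution given by the probabilities (Gumbel-max trick).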
@torch.inference_mode()
def generate(
model: Transformer,
prompt_tokens: List[List[int]],
max_new_tokens: int,
eos_id: int,
temperature: float = 1.0
) -> List[List[int]]:
"""Batch generation with left-padded prompts.
The first forward pass processes [min_prompt_len:] tokens (prefill phase).
Subsequent passes generate one token at a time (decode phase). For positions
still within a prompt, the ground-truth token overrides the model's prediction.
"""
prompt_lens = [len(t) for t in prompt_tokens]
assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long)
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long)
prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens))
prompt_mask = tokens != -1
for cur_pos in range(min(prompt_lens), total_len):
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
next_token = sample(logits, temperature)
else:
next_token = logits.argmax(dim=-1)
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
prev_pos = cur_pos
if finished.all():
break
completion_tokens = []
for i, toks in enumerate(tokens.tolist()):
toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
if eos_id in toks:
toks = toks[:toks.index(eos_id)]
toks.append(eos_id)
completion_tokens.append(toks)
return completion_tokens
def main(
ckpt_path: str,
config: str,
input_file: str = "",
interactive: bool = True,
max_new_tokens: int = 100,
temperature: float = 1.0,
) -> None:
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if world_size > 1:
dist.init_process_group("nccl")
global print
if rank != 0:
print = lambda *_, **__: None
torch.cuda.set_device(local_rank)
torch.cuda.memory._set_allocator_settings("expandable_segments:True")
torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(8)
torch.manual_seed(33377335)
with open(config) as f:
args = ModelArgs(**json.load(f))
if interactive:
args.max_batch_size = 1
print(args)
with torch.device("cuda"):
model = Transformer(args)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
print("load model")
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"), strict=False)
torch.set_default_device("cuda")
print("I'm DeepSeek 👋")
if interactive:
messages = []
while True:
if world_size == 1:
prompt = input(">>> ")
elif rank == 0:
prompt = input(">>> ")
objects = [prompt]
dist.broadcast_object_list(objects, 0)
else:
objects = [None]
dist.broadcast_object_list(objects, 0)
prompt = objects[0]
if prompt == "/exit":
break
elif prompt == "/clear":
messages.clear()
continue
messages.append({"role": "user", "content": prompt})
prompt_tokens = tokenizer.encode(encode_messages(messages, thinking_mode="chat"))
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
completion = tokenizer.decode(completion_tokens[0])
print(completion)
messages.append(parse_message_from_completion_text(completion, thinking_mode="chat"))
else:
with open(input_file) as f:
prompts = f.read().split("\n\n")
prompt_tokens = [tokenizer.encode(encode_messages([{"role": "user", "content": prompt}], thinking_mode="chat")) for prompt in prompts]
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
completions = tokenizer.batch_decode(completion_tokens)
for prompt, completion in zip(prompts, completions):
print("Prompt:", prompt)
print("Completion:", completion)
print()
if world_size > 1:
dist.destroy_process_group()
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--ckpt-path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--input-file", type=str, default="")
parser.add_argument("--interactive", action="store_true")
parser.add_argument("--max-new-tokens", type=int, default=300)
parser.add_argument("--temperature", type=float, default=0.6)
args = parser.parse_args()
assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
import torch
import tilelang
import tilelang.language as T
from typing import Tuple, Optional
tilelang.set_log_level("WARNING")
pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,
}
FP8 = "float8_e4m3"
FP4 = "float4_e2m1fn"
# FE8M0 = "float8_e8m0fnu"
BF16 = "bfloat16"
FP32 = "float32"
INT32 = "int32"
def fast_log2_ceil(x):
"""Compute ceil(log2(x)) via IEEE 754 bit manipulation. Avoids slow log/ceil intrinsics."""
bits_x = T.reinterpret("uint32", x)
exp_x = (bits_x >> 23) & 0xFF
man_bits = bits_x & ((1 << 23) - 1)
return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
def fast_pow2(x):
"""Compute 2^x for integer x via IEEE 754 bit manipulation."""
bits_x = (x + 127) << 23
return T.reinterpret("float32", bits_x)
def fast_round_scale(amax, fp8_max_inv):
return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
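# fast_round_scale rounds a per-block scale up to a power of two (UE8M0 / MXFP style).
# Worked example (illustrative): amax = 100, fp8_max = 448 -> amax / 448 ~= 0.223,
# ceil(log2(0.223)) = -2, so the returned scale is 2**-2 = 0.25.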
@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(
N, block_size=128, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32,
round_scale=False, inplace=False
):
"""Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16."""
M = T.symbolic("M")
fp8_min = -448.0
fp8_max = 448.0
fp8_max_inv = 1 / fp8_max
num_stages = 0 if round_scale or inplace else 2
blk_m = 32
group_size = block_size
# Internal computation in FP32; scale_dtype controls output storage format.
compute_dtype = FP32
out_dtype = in_dtype if inplace else out_dtype
@T.prim_func
def act_quant_kernel_(
X: T.Tensor[(M, N), in_dtype],
Y: T.Tensor[(M, N), out_dtype],
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
):
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
pid_m,
pid_n,
):
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
amax_local = T.alloc_fragment((blk_m,), compute_dtype)
s_local = T.alloc_fragment((blk_m,), compute_dtype)
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
for _ in T.Pipelined(1, num_stages=num_stages):
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
T.copy(x_shared, x_local)
T.reduce_absmax(x_local, amax_local, dim=1)
for i in T.Parallel(blk_m):
amax_local[i] = T.max(amax_local[i], 1e-4)
if round_scale:
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
else:
s_local[i] = amax_local[i] * fp8_max_inv
if inplace:
for i, j in T.Parallel(blk_m, group_size):
y_local[i, j] = T.Cast(
out_dtype,
T.Cast(compute_dtype, T.Cast(out_dtype, T.clamp(
x_local[i, j] / s_local[i], fp8_min, fp8_max
))) * s_local[i],
)
else:
for i, j in T.Parallel(blk_m, group_size):
y_local[i, j] = T.clamp(
x_local[i, j] / s_local[i], fp8_min, fp8_max
)
for i in T.Parallel(blk_m):
S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
T.copy(y_local, y_shared)
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
return act_quant_kernel_
def act_quant(
x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None,
scale_dtype: torch.dtype = torch.float32, inplace: bool = False,
) -> torch.Tensor:
"""Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16.
When scale_fmt is set, scales are rounded to power-of-2 (MXFP)."""
N = x.size(-1)
assert N % block_size == 0
# tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
tl_dtype = FP32
z = x.contiguous()
y = torch.empty_like(z) if inplace else torch.empty_like(z, dtype=torch.float8_e4m3fn)
s = z.new_empty(*z.size()[:-1], N // block_size, dtype=scale_dtype)
kernel = act_quant_kernel(
N, block_size, scale_dtype=tl_dtype,
round_scale=scale_fmt is not None, inplace=inplace,
)
kernel(z.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
if inplace:
x.copy_(y)
return x
return y, s
@tilelang.jit(pass_configs=pass_configs)
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
assert out_dtype in [BF16, FP32]
M = T.symbolic("M")
group_size = 128
block_M = 32
block_N = 128
block_K = 128
@T.prim_func
def fp8_gemm_kernel_(
A: T.Tensor[(M, K), FP8],
B: T.Tensor[(N, K), FP8],
C: T.Tensor[(M, N), out_dtype],
scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), scale_dtype],
scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), scale_dtype],
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
bx,
by,
):
A_shared = T.alloc_shared((block_M, block_K), FP8)
B_shared = T.alloc_shared((block_N, block_K), FP8)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
Scale_C_shared = T.alloc_shared((block_M), FP32)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
T.clear(C_local_accum)
K_iters = T.ceildiv(K, block_K)
for k in T.Pipelined(K_iters, num_stages=2):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
# Cast scales to FP32 for computation; scales_b has one value per block_N group
Scale_B = T.Cast(FP32, scales_b[bx * block_N // group_size, k])
for i in T.Parallel(block_M):
Scale_C_shared[i] = T.Cast(FP32, scales_a[by * block_M + i, k]) * Scale_B
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
# Separate accumulator for scale-corrected results (2x accumulation precision)
for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
T.clear(C_local)
T.copy(C_local_accum, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return fp8_gemm_kernel_
def fp8_gemm(
a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
scale_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""C[M,N] = A[M,K] @ B[N,K]^T with per-128 block FP8 scaling on both A and B."""
assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
assert a_s.is_contiguous() and b_s.is_contiguous(), (
"Scaling factor tensors must be contiguous"
)
    # E8M0 kernel scales are disabled (FE8M0 is commented out above), so fall back to FP32 scales,
    # matching act_quant:
    # tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
    tl_dtype = FP32
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
kernel = fp8_gemm_kernel(N, K, scale_dtype=tl_dtype)
kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
return c
@tilelang.jit(pass_configs=pass_configs)
def sparse_attn_kernel(h_orig: int, d: int, scale=None):
"""Sparse multi-head attention via index gathering + online softmax (FlashAttention-style).
For each (batch, seq_pos), gathers top-k KV positions by index, computes attention
with numerically stable running max/sum, and includes a learnable attn_sink bias."""
b = T.symbolic("b")
m = T.symbolic("m")
n = T.symbolic("n")
topk = T.symbolic("topk")
if scale is None:
scale = (1.0 / d) ** 0.5
num_stages = 0
threads = 256
block = 32
num_blocks = tilelang.cdiv(topk, block)
padded_H = max(tilelang.math.next_power_of_2(h_orig), 16)
max_block_m = 16
if h_orig > max_block_m:
assert h_orig % max_block_m == 0, f"h should be a multiple of {max_block_m}"
REPLICATE_H = h_orig // max_block_m
else:
REPLICATE_H = 1
h = padded_H if REPLICATE_H == 1 else max_block_m
@T.prim_func
def sparse_attn_kernel_(
q: T.Tensor[(b, m, h_orig, d), BF16],
kv: T.Tensor[(b, n, d), BF16],
o: T.Tensor[(b, m, h_orig, d), BF16],
attn_sink: T.Tensor[(h_orig,), FP32],
topk_idxs: T.Tensor[(b, m, topk), INT32],
):
with T.Kernel(m * REPLICATE_H, b, threads=threads) as (bx, by):
q_shared = T.alloc_fragment((h, d), BF16)
kv_shared = T.alloc_shared((block, d), BF16)
# o_shared = T.alloc_shared((h, d), BF16)
acc_s_cast = T.alloc_shared((h, block), BF16)
idxs = T.alloc_fragment(block, INT32)
acc_s = T.alloc_fragment((h, block), FP32)
acc_o = T.alloc_fragment((h, d), FP32)
scores_max = T.alloc_fragment(h, FP32)
scores_max_prev = T.alloc_fragment(h, FP32)
scores_scale = T.alloc_fragment(h, FP32)
scores_sum = T.alloc_fragment(h, FP32)
sum_exp = T.alloc_fragment(h, FP32)
T.clear(acc_o)
T.clear(sum_exp)
T.fill(scores_max, -T.infinity(FP32))
s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
H0 = (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * h)
H1 = H0 + h
T.copy(q[by, s_i, H0:H1, :], q_shared)
for t in T.Pipelined(num_blocks, num_stages=num_stages):
for i in T.Parallel(block):
idxs[i] = T.if_then_else(t * block + i < topk, topk_idxs[by, s_i, t * block + i], -1)
for i, j in T.Parallel(block, d):
kv_shared[i, j] = T.if_then_else(idxs[i] != -1, kv[by, idxs[i], j], 0)
for i, j in T.Parallel(h, block):
acc_s[i, j] = T.if_then_else(idxs[j] != -1, 0, -T.infinity(FP32))
T.gemm(q_shared, kv_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(h, block):
acc_s[i, j] *= scale
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(h):
scores_scale[i] = T.exp(scores_max_prev[i] - scores_max[i])
for i, j in T.Parallel(h, block):
acc_s[i, j] = T.exp(acc_s[i, j] - scores_max[i])
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(h):
sum_exp[i] = sum_exp[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(h, d):
acc_o[i, j] *= scores_scale[i]
T.gemm(acc_s_cast, kv_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i in T.Parallel(h):
sum_exp[i] += T.exp(attn_sink[i] - scores_max[i])
for i, j in T.Parallel(h, d):
acc_o[i, j] /= sum_exp[i]
o_shared = T.alloc_shared((h, d), BF16)
T.copy(acc_o, o_shared)
T.copy(o_shared, o[by, s_i, H0:H1, :])
return sparse_attn_kernel_
def sparse_attn(
q: torch.Tensor, kv: torch.Tensor, attn_sink: torch.Tensor, topk_idxs: torch.Tensor, softmax_scale: float
) -> torch.Tensor:
b, s, h, d = q.size()
# print(f"Teng {q.size()=}")
# Pad heads to 16 for kernel efficiency (stripped after)
if h < 16:
q = torch.cat([q, q.new_zeros(b, s, 16 - h, d)], dim=2)
attn_sink = torch.cat([attn_sink, attn_sink.new_zeros(16 - h)])
o = torch.empty_like(q)
kernel = sparse_attn_kernel(q.size(2), d, softmax_scale)
kernel(q, kv, o, attn_sink, topk_idxs)
if h < 16:
o = o.narrow(2, 0, h).contiguous()
return o
@tilelang.jit(pass_configs=pass_configs)
def hc_split_sinkhorn_kernel(hc: int, sinkhorn_iters: int, eps: float):
n = T.symbolic("n")
mix_hc = (2 + hc) * hc
threads = 64
@T.prim_func
def hc_split_sinkhorn_kernel_(
mixes: T.Tensor[(n, mix_hc), FP32],
hc_scale: T.Tensor[(3,), FP32],
hc_base: T.Tensor[(mix_hc,), FP32],
pre: T.Tensor[(n, hc), FP32],
post: T.Tensor[(n, hc), FP32],
comb: T.Tensor[(n, hc, hc), FP32],
):
with T.Kernel(n, threads=threads) as i:
mixes_shared = T.alloc_shared(mix_hc, FP32)
comb_frag = T.alloc_fragment((hc, hc), FP32)
T.copy(mixes[i, :], mixes_shared)
for j in T.Parallel(hc):
pre[i, j] = T.sigmoid(mixes_shared[j] * hc_scale[0] + hc_base[j]) + eps
for j in T.Parallel(hc):
post[i, j] = 2 * T.sigmoid(mixes_shared[j + hc] * hc_scale[1] + hc_base[j + hc])
for j, k in T.Parallel(hc, hc):
comb_frag[j, k] = mixes_shared[j * hc + k + hc * 2] * hc_scale[2] + hc_base[j * hc + k + hc * 2]
row_sum = T.alloc_fragment(hc, FP32)
col_sum = T.alloc_fragment(hc, FP32)
# comb = comb.softmax(-1) + eps
row_max = T.alloc_fragment(hc, FP32)
T.reduce_max(comb_frag, row_max, dim=1)
for j, k in T.Parallel(hc, hc):
comb_frag[j, k] = T.exp(comb_frag[j, k] - row_max[j])
T.reduce_sum(comb_frag, row_sum, dim=1)
for j, k in T.Parallel(hc, hc):
comb_frag[j, k] = comb_frag[j, k] / row_sum[j] + eps
# comb = comb / (comb.sum(-2) + eps)
T.reduce_sum(comb_frag, col_sum, dim=0)
for j, k in T.Parallel(hc, hc):
comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
for _ in T.serial(sinkhorn_iters - 1):
# comb = comb / (comb.sum(-1) + eps)
T.reduce_sum(comb_frag, row_sum, dim=1)
for j, k in T.Parallel(hc, hc):
comb_frag[j, k] = comb_frag[j, k] / (row_sum[j] + eps)
# comb = comb / (comb.sum(-2) + eps)
T.reduce_sum(comb_frag, col_sum, dim=0)
for j, k in T.Parallel(hc, hc):
comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
T.copy(comb_frag, comb[i, :, :])
return hc_split_sinkhorn_kernel_
def hc_split_sinkhorn(mixes: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, hc_mult: int = 4, sinkhorn_iters: int = 20, eps: float = 1e-6):
b, s, _ = mixes.size()
pre = mixes.new_empty(b, s, hc_mult)
post = mixes.new_empty(b, s, hc_mult)
comb = mixes.new_empty(b, s, hc_mult, hc_mult)
kernel = hc_split_sinkhorn_kernel(hc_mult, sinkhorn_iters, eps)
kernel(mixes.view(-1, (2 + hc_mult) * hc_mult), hc_scale, hc_base,
pre.view(-1, hc_mult), post.view(-1, hc_mult), comb.view(-1, hc_mult, hc_mult))
return pre, post, comb
#!/usr/bin/env bash
set -euo pipefail
export NCCL_ALGO="Ring"
export NCCL_PROTO="Simple"
export MP=8
export CONFIG="config.json"
export CKPT_PATH="deepseek-ai/DeepSeek-V4-Flash-bf16-mp8"
torchrun \
--nproc-per-node "${MP}" \
generate.py \
--ckpt-path "${CKPT_PATH}" \
--config "${CONFIG}" \
--interactive
# Unique model identifier
modelCode=2397
# Model name
modelName=DeepSeek-V4
# Model description
modelDescription=DeepSeek-V4: Towards Highly Efficient Million-Token Context Intelligence.
# Process type
processType=Inference
# Algorithm category
appCategory=Dialogue Q&A
# Framework type
frameType=pytorch
# Accelerator type
accelerateType=BW1100