Commit 6021ef32 authored by Lei Wang, committed by GitHub

[Example] Add topk into sparse mla example and append some docs (#901)

* Remove unused `fp8_mqa_logits.py` file and update README.md to reflect new directory structure and file descriptions for deepseek_v32 example. Added sections for architecture overview, Lightning Indexer, Top-k Selector, and Sparse MLA Forward implementations.

* Update linting configurations and improve code formatting in deepseek_v32 example scripts

- Added per-file ignores for the inference directory in `pyproject.toml`.
- Refactored code in `topk_selector.py`, `convert.py`, `generate.py`, `kernel.py`, and `model.py` to enhance readability by adjusting spacing and line breaks.
- Ensured consistent formatting across function definitions and assertions for better clarity.

* Refactor test functions in deepseek_v32 example scripts for improved clarity and consistency

- Updated `fp8_lighting_indexer.py` to define a dedicated test function for the lightning indexer.
- Refactored `sparse_mla_fwd_pipelined.py` and `sparse_mla_fwd.py` to standardize test function parameters and improve readability.
- Enhanced `topk_selector.py` by introducing a test function with parameters for batch size and sequence length.
- Ensured all test functions are invoked correctly in the main execution block.

* Enhance test functions in deepseek_v32 example scripts with CUDA requirements and parameterization

- Added CUDA requirements decorators to `test_example_sparse_mla_fwd` and `test_example_sparse_mla_fwd_pipelined`.
- Parameterized test functions to use specific small shapes for testing, improving test coverage and clarity.

* lint fix

* Update README.md to correct image path for DeepSeek V3.2 architecture diagram
@@ -3,7 +3,166 @@
```
deepseek_v32/
├── README.md                    # This file
├── figures/                     # Figures and diagrams
├── inference/                   # Inference implementation folder
├── fp8_lighting_indexer.py      # FP8 lightning indexer
├── sparse_mla_fwd.py            # Sparse MLA forward implementation
├── sparse_mla_fwd_pipelined.py  # Pipelined implementation of sparse MLA forward pass
└── topk_selector.py             # Top-k selector implementation
```
## File Descriptions
### Architecture Overview
![DeepSeek V3.2 Architecture](./figures/v32_arch.png)
The architecture diagram above highlights three key components (shown in green) that correspond to our kernel implementations:
1. **Lightning Indexer** (`fp8_lighting_indexer.py`) - Efficiently indexes and processes sparse attention patterns using FP8 precision
2. **Top-k Selector** (`topk_selector.py`) - Selects the top-k most relevant tokens for sparse attention computation
3. **Multi-Query Attention** (`sparse_mla_fwd.py` and `sparse_mla_fwd_pipelined.py`) - Core attention mechanism implementation with sparse MLA (Multi-head Latent Attention) forward pass
### Lightning Indexer
Looking at the architecture diagram, the Lightning Indexer sits at the bottom right. It takes the input hidden states and produces compressed representations `{q^A_{t,i}}`, `{k^R_t}`, and `{w^I_{t,j}}`. These FP8-quantized index vectors are what feed into the top-k selector.
The main kernel `mqa_attn_return_logits_kernel` computes similarity scores between query and key indices:
```python
T.gemm(
index_k_shared,
index_q_shared,
s,
transpose_B=True,
clear_accum=True,
policy=T.GemmWarpPolicy.FullCol,
)
```
After the matmul, we apply ReLU and aggregate across heads with learned weights:
```python
for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
s_reshaped[bn_i, bq_i, h_i] = (
T.max(s[bn_i, bq_i * heads + h_i], 0) * weights[bq_i, h_i]
) * index_k_scale_fragment[bn_i]
T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
```
The result is a `[seq_len, seq_len_kv]` logits matrix. For long sequences, the kernel uses per-token bounds (`CuSeqLenKS`, `CuSeqLenKE`) to skip irrelevant KV positions:
```python
for bq_i in T.serial(block_Q):
cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv))
for bq_i in T.serial(block_Q):
cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv))
```
The pipelined loop then only processes keys in the `[cu_k_s_min, cu_k_e_max)` range, which is crucial for handling variable-length sequences in distributed training.
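To make the scoring concrete, here is a rough dense PyTorch reference of the same math (a sketch only: it ignores the FP8 quantization scales and the per-token `CuSeqLenKS`/`CuSeqLenKE` bounds, and the names are illustrative):

```python
import torch

def ref_index_logits(q, k, weights):
    # q: [S, H, D] queries, k: [SKV, D] keys, weights: [S, H] per-head weights
    s = torch.einsum("shd,kd->shk", q.float(), k.float())  # per-head scores
    s = s.relu() * weights.float().unsqueeze(-1)           # ReLU, then head weights
    return s.sum(dim=1)                                    # [S, SKV] logits matrix
```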
### Top-k Selector
The Top-k Selector takes the logits matrix from the indexer and picks the top-k indices for each query. In the architecture diagram, this sits between the Lightning Indexer and the Multi-Query Attention block. The output indices tell the attention layer which KV tokens to actually load and process.
The implementation uses a radix-sort-based approach that processes floats as unsigned integers. Stage 1 does a quick 8-bit pass over the whole sequence:
```python
for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
    input_idx = s * BLOCK_SIZE + tx
if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
inval_int16 = convert_to_uint16(input[bx, input_idx])
T.atomic_add(s_histogram[inval_int16], 1)
```
The `convert_to_uint16` function maps floats to uint16 such that larger floats map to larger integers. After building a histogram and doing a cumulative sum, we find the threshold bin:
```python
if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
s_threshold_bin_id[0] = tx
```
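The exact bit manipulation behind `convert_to_uint16` lives in `topk_selector.py`; an order-preserving mapping of this kind can be sketched in NumPy as follows (illustrative, not the kernel's literal code):

```python
import numpy as np

def float_to_sortable_uint16(x: np.ndarray) -> np.ndarray:
    # Order-preserving float -> integer key: set the sign bit on non-negative
    # floats, flip all bits on negative floats, then keep the high 16 bits.
    bits = x.astype(np.float32).view(np.uint32)
    bits = np.where(bits & 0x80000000, ~bits, bits | 0x80000000)
    return (bits >> 16).astype(np.uint16)

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=np.float32)
assert np.all(np.diff(float_to_sortable_uint16(x).astype(np.int64)) > 0)
```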
Elements above the threshold go directly to the output. Elements in the threshold bin get collected for further processing:
```python
if l_bin_id32 > l_threshold_bin_id:
    pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True)
index[bx, pos] = input_idx
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
pos = T.atomic_add(s_num_input[0], 1, return_prev=True)
s_input_idx[0, pos] = input_idx
```
Stage 2 refines the threshold bin with up to 4 rounds of 8-bit radix sort, processing progressively higher bits. This gives exact top-k selection without sorting the entire sequence.
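The two-stage idea is easy to prototype on the CPU. Below is a compact NumPy sketch of the histogram-threshold scheme (reusing the `float_to_sortable_uint16` helper above; a single 8-bit pass plus an exact tie-break stands in for the extra radix rounds):

```python
import numpy as np

def histogram_topk(scores: np.ndarray, k: int) -> np.ndarray:
    keys = float_to_sortable_uint16(scores) >> 8          # 8-bit bucket per element
    hist = np.bincount(keys.astype(np.int64), minlength=256)
    order = np.arange(255, -1, -1)                        # walk buckets high to low
    counts = np.cumsum(hist[order])
    threshold_bin = order[np.searchsorted(counts, k)]     # first bucket covering k
    out = np.flatnonzero(keys > threshold_bin)            # definitely in the top-k
    rest = np.flatnonzero(keys == threshold_bin)          # threshold bucket: refine
    refine = rest[np.argsort(scores[rest])[::-1][:k - out.size]]
    return np.concatenate([out, refine])

scores = np.random.randn(4096).astype(np.float32)
idx = histogram_topk(scores, 64)
assert np.allclose(np.sort(scores[idx]), np.sort(scores)[-64:])
```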
### Sparse MLA Forward
The Sparse MLA kernel is where the actual attention computation happens. In the architecture diagram, this is the large "Multi-Query Attention (Core Attention)" block at the top. It takes the selected top-k indices and computes attention only over those tokens.
Turning dense MLA into sparse MLA requires surprisingly few changes - essentially just modifying how we iterate and load KV tokens. The key difference from dense MLA (see `../deepseek_mla/example_mla_decode.py`) is the iteration pattern. Dense MLA iterates over all KV positions:
```python
# Dense MLA: iterate over full sequence
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
# ... compute attention over this block
```
Sparse MLA only loads KV positions selected by the top-k selector:
```python
# Sparse MLA: iterate over selected indices only
for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i]
# ... compute attention over selected tokens
```
This reduces compute from O(seq_len * seq_len_kv) to O(seq_len * topk). The causal mask is enforced by checking whether each index position is valid:
```python
for bi_i in T.Parallel(BI):
mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i
```
Beyond this sparse indexing, the rest of the attention computation (online softmax, output accumulation) follows the same pattern as dense MLA.
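For intuition, the gather-then-attend pattern in plain PyTorch might look like this single-head sketch (illustrative names; the latent `kv` serves as both keys and values, as in MLA, and each row is assumed to keep at least one valid index):

```python
import torch

def sparse_attn_ref(q, kv, indices):
    # q: [S, D], kv: [SKV, D] latent KV cache, indices: [S, topk] from the selector
    k_sel = kv[indices]                                      # [S, topk, D] gather
    logits = torch.einsum("sd,std->st", q, k_sel) * q.size(-1) ** -0.5
    pos = torch.arange(q.size(0), device=q.device).unsqueeze(-1)
    logits = logits.masked_fill(indices > pos, float("-inf"))  # causal validity check
    return torch.einsum("st,std->sd", logits.softmax(dim=-1), k_sel)
```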
### Sparse MLA Forward (Pipelined)
The pipelined version (`sparse_mla_fwd_pipelined.py`) is a manual pipeline implementation designed to match the schedule of [FlashMLA](https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/sm90/prefill/sparse/fwd.cu). It achieves close to 600 TFlops on H800 SXM by carefully orchestrating memory and compute pipelines.
The key difference is splitting the warp groups into specialized roles:
```python
if tx < 128:
# Consumer 0: computes left half of output (D//2 dimensions)
# Handles QK matmul, softmax, and PV for left half
elif tx >= 128 and tx < 256:
# Consumer 1: computes right half of output (D//2 dimensions)
# Only does PV matmul for right half
elif tx >= 256:
# Producer: loads KV data from global memory
# Uses async copy with barriers to feed consumers
```
The producer thread group (tx >= 256) uses double buffering with barriers to keep consumers fed:
```python
# Producer alternates between two buffers
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
# ... load KV into buffer 0
T.cp_async_barrier_noinc(bar_k_0_ready[0])
# Buffer 1
T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
# ... load KV into buffer 1
T.cp_async_barrier_noinc(bar_k_1_ready[0])
```
Consumer threads wait on barriers and process buffers as they become ready. This manual orchestration hides memory latency behind compute, which is why it outperforms the simpler auto-pipelined version. The output dimension is also split in half so that the two consumer groups can work in parallel on different parts of the matmul.
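Stripped of the GPU specifics, this is the classic two-buffer producer/consumer handshake. A toy CPU analogy (a sketch only: Python threads and semaphores stand in for warp groups and barrier phase bits):

```python
import threading

NUM_TILES = 8
buffers = [None, None]
ready = [threading.Semaphore(0), threading.Semaphore(0)]   # like bar_k_*_ready
free = [threading.Semaphore(1), threading.Semaphore(1)]    # like bar_k_*_free

def producer():
    for i in range(NUM_TILES):
        b = i % 2                       # alternate between buffer 0 and buffer 1
        free[b].acquire()               # wait until the consumer released buffer b
        buffers[b] = f"kv_tile_{i}"     # stands in for the async KV load
        ready[b].release()              # announce: buffer b now holds tile i

def consumer():
    for i in range(NUM_TILES):
        b = i % 2
        ready[b].acquire()              # wait for tile i to arrive
        _ = buffers[b]                  # stands in for QK / softmax / PV compute
        free[b].release()               # hand buffer b back to the producer

threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```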
...@@ -258,10 +258,7 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, ...@@ -258,10 +258,7 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
cost = mask.sum() cost = mask.sum()
return logits, cost return logits, cost
def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1):
if __name__ == "__main__":
torch.manual_seed(0)
S, SKV, H, HKV, D, kv_stride = 4096, 8192, 32, 1, 64, 1
q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
weights = torch.randn(S, H, device="cuda", dtype=torch.float32) weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
...@@ -304,3 +301,6 @@ if __name__ == "__main__": ...@@ -304,3 +301,6 @@ if __name__ == "__main__":
logits_tflops = logits_flops / (logits_ms * 1e-3) / 1e12 logits_tflops = logits_flops / (logits_ms * 1e-3) / 1e12
print(f"logits_tflops: {logits_tflops}, logits_ms: {logits_ms}") print(f"logits_tflops: {logits_tflops}, logits_ms: {logits_ms}")
print(f"cost_ref: {cost_ref}") print(f"cost_ref: {cost_ref}")
if __name__ == "__main__":
test_fp8_lighting_indexer()
# DeepSeek V3.2
First, convert the Hugging Face model weights to the format required by our inference demo. Set `MP` to match your available GPU count:
```bash
cd inference
export EXPERTS=256
export MP=8  # example value; set to your available GPU count
python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
```
Launch the interactive chat interface and start exploring DeepSeek's capabilities:
```bash
export CONFIG=config_671B_v3.2.json
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
```
{
"vocab_size": 129280,
"dim": 7168,
"inter_dim": 18432,
"moe_inter_dim": 2048,
"n_layers": 61,
"n_dense_layers": 3,
"n_heads": 128,
"n_routed_experts": 256,
"n_shared_experts": 1,
"n_activated_experts": 8,
"n_expert_groups": 8,
"n_limited_groups": 4,
"route_scale": 2.5,
"score_func": "sigmoid",
"q_lora_rank": 1536,
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"dtype": "fp8",
"scale_fmt": "ue8m0",
"index_n_heads": 64,
"index_head_dim": 128,
"index_topk": 2048
}
import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange
import torch
from safetensors.torch import safe_open, save_file
mapping = {
"embed_tokens": ("embed", 0),
"input_layernorm": ("attn_norm", None),
"post_attention_layernorm": ("ffn_norm", None),
"q_proj": ("wq", 0),
"q_a_proj": ("wq_a", None),
"q_a_layernorm": ("q_norm", None),
"q_b_proj": ("wq_b", 0),
"kv_a_proj_with_mqa": ("wkv_a", None),
"kv_a_layernorm": ("kv_norm", None),
"kv_b_proj": ("wkv_b", 0),
"o_proj": ("wo", 1),
"gate": ("gate", None),
"gate_proj": ("w1", 0),
"down_proj": ("w2", 1),
"up_proj": ("w3", 0),
"norm": ("norm", None),
"lm_head": ("head", 0),
"scale": ("scale", None),
"wq_b": ("wq_b", None),
"wk": ("wk", None),
"k_norm": ("k_norm", None),
"weights_proj": ("weights_proj", None),
}
def main(hf_ckpt_path, save_path, n_experts, mp):
"""
Converts and saves model checkpoint files into a specified format.
Args:
hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
save_path (str): Path to the directory where the converted checkpoint files will be saved.
n_experts (int): Total number of experts in the model.
mp (int): Model parallelism factor.
Returns:
None
"""
torch.set_num_threads(8)
n_local_experts = n_experts // mp
state_dicts = [{} for _ in range(mp)]
for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
with safe_open(file_path, framework="pt", device="cpu") as f:
for name in f.keys():
if "model.layers.61" in name:
continue
param: torch.Tensor = f.get_tensor(name)
if name.startswith("model."):
name = name[len("model."):]
name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias")
key = name.split(".")[-2]
assert key in mapping, f"Key {key} not found in mapping"
new_key, dim = mapping[key]
name = name.replace(key, new_key)
for i in range(mp):
new_param = param
if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3])
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
continue
elif dim is not None:
assert param.size(
dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
shard_size = param.size(dim) // mp
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
state_dicts[i][name] = new_param
os.makedirs(save_path, exist_ok=True)
for i in trange(mp):
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
new_file_path = os.path.join(save_path, os.path.basename(file_path))
shutil.copyfile(file_path, new_file_path)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--hf-ckpt-path", type=str, required=True)
parser.add_argument("--save-path", type=str, required=True)
parser.add_argument("--n-experts", type=int, required=True)
parser.add_argument("--model-parallel", type=int, required=True)
args = parser.parse_args()
assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
import os
import json
from argparse import ArgumentParser
from typing import List
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model
from model import Transformer, ModelArgs
def sample(logits, temperature: float = 1.0):
"""
Samples a token from the logits using temperature scaling.
Args:
logits (torch.Tensor): The logits tensor for token predictions.
temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
Returns:
torch.Tensor: The sampled token.
"""
logits = logits / max(temperature, 1e-5)
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
    # Gumbel-max trick: dividing probabilities by i.i.d. Exponential(1) noise
    # and taking argmax draws a sample from the categorical distribution.
    return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
@torch.inference_mode()
def generate(model: Transformer,
prompt_tokens: List[List[int]],
max_new_tokens: int,
eos_id: int,
temperature: float = 1.0) -> List[List[int]]:
"""
Generates new tokens based on the given prompt tokens using the specified model.
Args:
model (Transformer): The transformer model used for token generation.
prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
max_new_tokens (int): The maximum number of new tokens to generate.
eos_id (int): The end-of-sequence token ID.
temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
Returns:
List[List[int]]: A list of lists containing the generated tokens for each sequence.
"""
prompt_lens = [len(t) for t in prompt_tokens]
assert max(
prompt_lens
) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
prompt_mask = tokens != -1
for cur_pos in range(min(prompt_lens), total_len):
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
next_token = sample(logits, temperature)
else:
next_token = logits.argmax(dim=-1)
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
prev_pos = cur_pos
if finished.all():
break
completion_tokens = []
for i, toks in enumerate(tokens.tolist()):
toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
if eos_id in toks:
toks = toks[:toks.index(eos_id)]
completion_tokens.append(toks)
return completion_tokens
def main(
ckpt_path: str,
config: str,
input_file: str = "",
interactive: bool = True,
max_new_tokens: int = 100,
temperature: float = 1.0,
) -> None:
"""
Main function to load the model and perform interactive or batch text generation.
Args:
ckpt_path (str): Path to the model checkpoint directory.
config (str): Path to the model configuration file.
input_file (str, optional): Path to a file containing input prompts. Defaults to "".
interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
temperature (float, optional): Temperature for sampling. Defaults to 1.0.
"""
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if world_size > 1:
dist.init_process_group("nccl")
global print
if rank != 0:
print = lambda *_, **__: None
torch.cuda.set_device(local_rank)
torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(8)
torch.manual_seed(33377335)
with open(config) as f:
args = ModelArgs(**json.load(f))
print(args)
with torch.device("cuda"):
model = Transformer(args)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
print("load model")
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
print("I'm DeepSeek 👋")
if interactive:
messages = []
while True:
if world_size == 1:
prompt = input(">>> ")
elif rank == 0:
prompt = input(">>> ")
objects = [prompt]
dist.broadcast_object_list(objects, 0)
else:
objects = [None]
dist.broadcast_object_list(objects, 0)
prompt = objects[0]
if prompt == "/exit":
break
elif prompt == "/clear":
messages.clear()
continue
messages.append({"role": "user", "content": prompt})
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
completion_tokens = generate(model, [prompt_tokens], max_new_tokens,
tokenizer.eos_token_id, temperature)
completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
print(completion)
messages.append({"role": "assistant", "content": completion})
else:
with open(input_file) as f:
prompts = f.read().split("\n\n")
assert len(
prompts
) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
prompt_tokens = [
tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
add_generation_prompt=True) for prompt in prompts
]
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id,
temperature)
completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
for prompt, completion in zip(prompts, completions):
print("Prompt:", prompt)
print("Completion:", completion)
print()
if world_size > 1:
dist.destroy_process_group()
if __name__ == "__main__":
"""
Command-line interface for distributed text generation.
Arguments:
--ckpt-path (str): Path to the model checkpoint directory.
--config (str): Path to the model configuration file.
--input-file (str, optional): File containing prompts for batch processing.
--interactive (bool, optional): Enable interactive mode for generating text.
--max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
        --temperature (float, optional): Temperature for sampling. Defaults to 0.6.
Raises:
AssertionError: If neither input-file nor interactive mode is specified.
"""
parser = ArgumentParser()
parser.add_argument("--ckpt-path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--input-file", type=str, default="")
parser.add_argument("--interactive", action="store_true")
parser.add_argument("--max-new-tokens", type=int, default=200)
parser.add_argument("--temperature", type=float, default=0.6)
args = parser.parse_args()
assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens,
args.temperature)
import torch
import tilelang
import tilelang.language as T
from typing import Tuple, Optional
tilelang.set_log_level("WARNING")
pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
}
FP8 = "float8_e4m3"
BF16 = "bfloat16"
FP32 = "float32"
def fast_log2_ceil(x):
    # Read ceil(log2(x)) straight from the float32 bit pattern: take the
    # unbiased exponent, plus one whenever any mantissa bit is set.
    bits_x = T.reinterpret("uint32", x)
    exp_x = (bits_x >> 23) & 0xFF
    man_bits = bits_x & ((1 << 23) - 1)
    return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
def fast_pow2(x):
    # Construct 2^x by writing the biased exponent into the float32 exponent field.
    bits_x = (x + 127) << 23
    return T.reinterpret("float32", bits_x)
def fast_round_scale(amax, fp8_max_inv):
    # Round the scale up to the nearest power of two (ue8m0-style scales).
    return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False):
M = T.symbolic("M")
fp8_min = -448.0
fp8_max = 448.0
fp8_max_inv = 1 / fp8_max
num_stages = 0 if round_scale else 2
blk_m = 32
group_size = 128
@T.prim_func
def act_quant_kernel_(
X: T.Tensor[(M, N), in_dtype],
Y: T.Tensor[(M, N), out_dtype],
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
):
with T.Kernel(
T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
pid_m,
pid_n,
):
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
amax_local = T.alloc_fragment((blk_m,), scale_dtype)
s_local = T.alloc_fragment((blk_m,), scale_dtype)
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
for _ in T.Pipelined(1, num_stages=num_stages):
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
T.copy(x_shared, x_local)
T.reduce_absmax(x_local, amax_local, dim=1)
for i in T.Parallel(blk_m):
amax_local[i] = T.max(amax_local[i], 1e-4)
if round_scale:
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
else:
s_local[i] = amax_local[i] * fp8_max_inv
for i, j in T.Parallel(blk_m, group_size):
y_local[i, j] = T.clamp(x_local[i, j] / s_local[i], fp8_min, fp8_max)
for i in T.Parallel(blk_m):
S[pid_m * blk_m + i, pid_n] = s_local[i]
T.copy(y_local, y_shared)
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
return act_quant_kernel_
def act_quant(x: torch.Tensor,
block_size: int = 128,
scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
scale_fmt (Optional[str], optional): The format of the scale. Default is None.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.
"""
assert x.is_contiguous(), "Input tensor must be contiguous"
assert x.size(-1) % block_size == 0, (
f"Last dimension size must be divisible by block_size (block_size={block_size})")
N = x.size(-1)
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
return y, s
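# Hypothetical usage sketch (shapes illustrative): quantize a bf16 activation
# into FP8 with one scale per 128-wide group along the last dimension:
#   x = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16)
#   y, s = act_quant(x)  # y: float8_e4m3fn [32, 256], s: float32 [32, 2]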
@tilelang.jit(pass_configs=pass_configs)
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"):
assert out_dtype in [BF16, "float32"]
M = T.symbolic("M")
group_size = 128
block_M = 32
block_N = 128
block_K = 128
@T.prim_func
def fp8_gemm_kernel_(
A: T.Tensor[(M, K), FP8],
B: T.Tensor[(N, K), FP8],
C: T.Tensor[(M, N), out_dtype],
scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32],
scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), FP32],
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
bx,
by,
):
A_shared = T.alloc_shared((block_M, block_K), FP8)
B_shared = T.alloc_shared((block_N, block_K), FP8)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
Scale_C_shared = T.alloc_shared((block_M), FP32)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
T.clear(C_local_accum)
K_iters = T.ceildiv(K, block_K)
for k in T.Pipelined(K_iters, num_stages=4):
# Load A into shared memory
T.copy(A[by * block_M, k * block_K], A_shared)
# Load B into shared memory
T.copy(B[bx * block_N, k * block_K], B_shared)
# Load scale into shared memory
Scale_B = scales_b[bx * block_N // group_size, k]
for i in T.Parallel(block_M):
Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
# Promote to enable 2xAcc
for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
T.clear(C_local)
# TMA store
T.copy(C_local_accum, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return fp8_gemm_kernel_
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor,
b_s: torch.Tensor) -> torch.Tensor:
"""
Perform a matrix multiplication using FP8 precision.
Args:
a (torch.Tensor): The first input matrix, must be contiguous.
a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
b (torch.Tensor): The second input matrix, must be contiguous.
b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
Returns:
torch.Tensor: The result of the matrix multiplication.
"""
assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
assert a_s.is_contiguous() and b_s.is_contiguous(), (
"Scaling factor tensors must be contiguous")
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
kernel = fp8_gemm_kernel(N, K)
kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
return c
@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
def fp8_index_kernel(h: int, d: int):
b = T.symbolic("b")
m = T.symbolic("m")
n = T.symbolic("n")
blk_n1 = 512
blk_n2 = 128
@T.prim_func
def fp8_index_kernel_(
q: T.Tensor[(b, m, h, d), FP8],
q_s: T.Tensor[(b, m, h), FP32],
k: T.Tensor[(b, n, d), FP8],
k_s: T.Tensor[(b, n), FP32],
o: T.Tensor[(b, m, n), FP32],
) -> None:
with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
q_smem = T.alloc_shared((h, d), FP8)
T.copy(q[i_b, i_m, 0, 0], q_smem)
q_s_frag = T.alloc_fragment(h, FP32)
T.copy(q_s[i_b, i_m, 0], q_s_frag)
for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
k_smem = T.alloc_shared((blk_n2, d), FP8)
T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
k_s_frag = T.alloc_fragment(blk_n2, FP32)
T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
logits = T.alloc_fragment((blk_n2, h), FP32)
T.gemm(
k_smem,
q_smem,
logits,
transpose_A=False,
transpose_B=True,
clear_accum=True,
)
for i_h, i3_n in T.Parallel(h, blk_n2):
logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
logits_sum = T.alloc_fragment(blk_n2, FP32)
T.reduce_sum(logits, logits_sum, dim=1)
for i3_n in T.Parallel(blk_n2):
logits_sum[i3_n] *= k_s_frag[i3_n]
T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])
return fp8_index_kernel_
def fp8_index(
q: torch.Tensor,
q_s: torch.Tensor,
k: torch.Tensor,
k_s: torch.Tensor,
) -> torch.Tensor:
"""
Perform index score using FP8 precision.
Args:
q (torch.Tensor): The Q tensor, must be contiguous.
q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
k (torch.Tensor): The K tensor, must be contiguous.
k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
    Dataflow:
        fp8 q @ fp8 k -> fp32 logits
        relu(fp32 logits) * q_s (weights) -> fp32 logits
        fp32 logits -> fp32 logits_sum
        fp32 logits_sum * k_s (e8m0) -> fp32 index_score
"""
return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)
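# Hypothetical usage sketch (shapes illustrative): score m queries against n keys:
#   q: [b, m, h, d] fp8, q_s: [b, m, h] fp32, k: [b, n, d] fp8, k_s: [b, n] fp32
#   index_score = fp8_index(q, q_s, k, k_s)  # -> [b, m, n] fp32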
import math
from dataclasses import dataclass
from typing import Tuple, Optional, Literal
from einops import rearrange
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
from kernel import act_quant, fp8_gemm, fp8_index
world_size = 1
rank = 0
block_size = 128
@dataclass
class ModelArgs:
"""
Data class for defining model arguments and hyperparameters.
Attributes:
max_batch_size (int): Maximum batch size.
max_seq_len (int): Maximum sequence length.
dtype (Literal["bf16", "fp8"]): Data type for computations.
scale_fmt (Optional[str]): Format for quantization scale.
vocab_size (int): Vocabulary size.
dim (int): Model dimension.
inter_dim (int): Intermediate dimension for MLP layers.
moe_inter_dim (int): Intermediate dimension for MoE layers.
n_layers (int): Number of transformer layers.
n_dense_layers (int): Number of dense layers in the model.
n_heads (int): Number of attention heads.
n_routed_experts (int): Number of routed experts for MoE layers.
n_shared_experts (int): Number of shared experts for MoE layers.
n_activated_experts (int): Number of activated experts in MoE layers.
n_expert_groups (int): Number of expert groups.
n_limited_groups (int): Number of limited groups for MoE routing.
score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
route_scale (float): Scaling factor for routing scores.
q_lora_rank (int): LoRA rank for query projections.
kv_lora_rank (int): LoRA rank for key-value projections.
qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
v_head_dim (int): Dimension for value projections.
original_seq_len (int): Original sequence length.
rope_theta (float): Base for rotary positional encoding.
rope_factor (float): Scaling factor for extended sequence lengths.
beta_fast (int): Fast beta correction factor.
beta_slow (int): Slow beta correction factor.
mscale (float): Scaling factor for extended attention.
index_head_dim (int): Dimension for index head.
index_topk (int): Top-k for index head.
"""
max_batch_size: int = 8
max_seq_len: int = 4096 * 4
dtype: Literal["bf16", "fp8"] = "bf16"
scale_fmt: Optional[str] = None
vocab_size: int = 102400
dim: int = 2048
inter_dim: int = 10944
moe_inter_dim: int = 1408
n_layers: int = 27
n_dense_layers: int = 1
n_heads: int = 16
# moe
n_routed_experts: int = 64
n_shared_experts: int = 2
n_activated_experts: int = 6
n_expert_groups: int = 1
n_limited_groups: int = 1
score_func: Literal["softmax", "sigmoid"] = "softmax"
route_scale: float = 1.
# mla
q_lora_rank: int = 0
kv_lora_rank: int = 512
qk_nope_head_dim: int = 128
qk_rope_head_dim: int = 64
v_head_dim: int = 128
# yarn
original_seq_len: int = 4096
rope_theta: float = 10000.0
rope_factor: float = 40
beta_fast: int = 32
beta_slow: int = 1
mscale: float = 1.
# index
index_n_heads: int = 64
index_head_dim: int = 128
index_topk: int = 2048
class ParallelEmbedding(nn.Module):
"""
Embedding layer with parallelism support across distributed processes.
Args:
vocab_size (int): Vocabulary size.
dim (int): Embedding dimension.
"""
def __init__(self, vocab_size: int, dim: int):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
self.part_vocab_size = (vocab_size // world_size)
self.vocab_start_idx = rank * self.part_vocab_size
self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for parallel embedding layer.
Args:
x (torch.Tensor): Input tensor containing token indices.
Returns:
torch.Tensor: Embedded representations.
"""
if world_size > 1:
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
x = x - self.vocab_start_idx
x[mask] = 0
y = F.embedding(x, self.weight)
if world_size > 1:
y[mask] = 0
dist.all_reduce(y)
return y
def linear(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
scale_fmt: Optional[str] = None) -> torch.Tensor:
"""
Applies a linear transformation to the incoming data: y = xA^T + b.
This function supports specialized implementations based on quantization
and tensor formats.
Args:
x (torch.Tensor): The input tensor.
weight (torch.Tensor): The weight tensor. It may be quantized and
requires dequantization for certain cases.
bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
scale_fmt (Optional[str]): The format of scaling factors.
Returns:
torch.Tensor: The result of the linear transformation, which may involve
quantization-aware computations depending on the input parameters.
Notes:
- If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version
is used for computation.
- For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
"""
assert bias is None
if weight.dtype != torch.float8_e4m3fn:
return F.linear(x, weight)
else:
x, scale = act_quant(x, block_size, scale_fmt)
return fp8_gemm(x, scale, weight, weight.scale)
class Linear(nn.Module):
"""
Custom linear layer with support for quantized weights and optional bias.
Args:
in_features (int): Number of input features.
out_features (int): Number of output features.
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
dtype = torch.bfloat16
scale_fmt: Optional[str] = None
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(
torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
if self.weight.element_size() == 1:
scale_out_features = (out_features + block_size - 1) // block_size
scale_in_features = (in_features + block_size - 1) // block_size
self.weight.scale = self.scale = nn.Parameter(
torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
else:
self.register_parameter("scale", None)
if bias:
self.bias = nn.Parameter(torch.empty(out_features))
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the custom linear layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor after linear computation.
"""
return linear(x, self.weight, self.bias, self.scale_fmt)
class ColumnParallelLinear(Linear):
"""
Linear layer with column parallelism, splitting output features across distributed processes.
Args:
in_features (int): Number of input features.
out_features (int): Total number of output features.
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
self.part_out_features = out_features // world_size
super().__init__(in_features, self.part_out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for column parallel linear layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with column-parallel computation.
"""
y = linear(x, self.weight, self.bias, self.scale_fmt)
return y
class RowParallelLinear(Linear):
"""
Linear layer with row parallelism, splitting input features across distributed processes.
Args:
in_features (int): Total number of input features.
out_features (int): Number of output features.
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = False,
reduce_output=True,
dtype=None):
assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
self.part_in_features = in_features // world_size
self.reduce_output = reduce_output
super().__init__(self.part_in_features, out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for row parallel linear layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with row-parallel computation.
"""
y = linear(x, self.weight, None, self.scale_fmt)
if self.reduce_output and world_size > 1:
y = y.float()
dist.all_reduce(y)
if self.bias is not None:
y += self.bias
return y.type_as(x)
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization (RMSNorm).
Args:
dim (int): Dimension of the input tensor.
eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
"""
Forward pass for RMSNorm.
Args:
            x (torch.Tensor): Input tensor.
            residual (Optional[torch.Tensor]): Optional residual stream to fold in before normalization.
        Returns:
            torch.Tensor: Normalized tensor, or a tuple of (normalized tensor, updated residual)
            when `residual` is given.
"""
dtype = x.dtype
if residual is None:
x = x.float()
var = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(var + self.eps)
return (self.weight * x).to(dtype)
else:
x = residual = x.float() + residual.float()
var = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(var + self.eps)
return (self.weight * x).to(dtype), residual.to(dtype)
class LayerNorm(nn.Module):
"""
Layer Normalization.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
def forward(self, x: torch.Tensor):
return F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps).type_as(x)
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
"""
Precomputes frequency-based complex exponential values for rotary positional embeddings.
Args:
args (ModelArgs): Model arguments containing positional embedding parameters.
Returns:
torch.Tensor: Precomputed complex exponential values for positional embeddings.
"""
dim = args.qk_rope_head_dim
seqlen = args.max_seq_len
beta_fast = args.beta_fast
beta_slow = args.beta_slow
base = args.rope_theta
factor = args.rope_factor
def find_correction_dim(num_rotations, dim, base, max_seq_len):
"""
Computes the correction dimension for a given number of rotations in the rotary positional embedding.
Args:
num_rotations (float): Number of rotations to compute the correction for.
dim (int): Dimensionality of the embedding space.
base (float): Base value for the exponential computation.
max_seq_len (int): Maximum sequence length.
Returns:
float: The correction dimension based on the input parameters.
"""
return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
"""
Computes the range of correction dimensions for rotary positional embeddings.
Args:
low_rot (float): Lower bound for the number of rotations.
high_rot (float): Upper bound for the number of rotations.
dim (int): Dimensionality of the embedding space.
base (float): Base value for the exponential computation.
max_seq_len (int): Maximum sequence length.
Returns:
Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
"""
low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
return max(low, 0), min(high, dim - 1)
def linear_ramp_factor(min, max, dim):
"""
Computes a linear ramp function used to smooth values between a minimum and maximum range.
Args:
min (float): Minimum value for the ramp function.
max (float): Maximum value for the ramp function.
dim (int): Dimensionality of the ramp tensor.
Returns:
torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
clamped to the range [0, 1].
"""
if min == max:
max += 0.001
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
freqs = 1.0 / (base**(torch.arange(0, dim, 2, dtype=torch.float32) / dim))
if seqlen > args.original_seq_len:
low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
smooth = 1 - linear_ramp_factor(low, high, dim // 2)
freqs = freqs / factor * (1 - smooth) + freqs * smooth
t = torch.arange(seqlen)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""
Applies rotary positional embeddings to the input tensor.
Args:
x (torch.Tensor): Input tensor with positional embeddings to be applied.
freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
Returns:
torch.Tensor: Tensor with rotary embeddings applied.
"""
dtype = x.dtype
x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
y = torch.view_as_real(x * freqs_cis).flatten(3)
return y.to(dtype)
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
assert x.dtype == torch.bfloat16
from fast_hadamard_transform import hadamard_transform
hidden_size = x.size(-1)
return hadamard_transform(x, scale=hidden_size**-0.5)
class Indexer(torch.nn.Module):
    """Lightning indexer: scores cached keys against each query with FP8 logits and
    returns the per-query top-k token indices consumed by sparse attention."""
def __init__(self, args: ModelArgs):
super().__init__()
self.dim: int = args.dim
self.n_heads: int = args.index_n_heads
self.n_local_heads = args.index_n_heads // world_size
self.head_dim: int = args.index_head_dim
self.rope_head_dim: int = args.qk_rope_head_dim
self.index_topk: int = args.index_topk
self.q_lora_rank: int = args.q_lora_rank
self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
self.wk = Linear(self.dim, self.head_dim)
self.k_norm = LayerNorm(self.head_dim)
self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.get_default_dtype())
self.softmax_scale = self.head_dim**-0.5
self.scale_fmt = args.scale_fmt
self.register_buffer(
"k_cache",
torch.zeros(
args.max_batch_size, args.max_seq_len, self.head_dim, dtype=torch.float8_e4m3fn),
persistent=False)
self.register_buffer(
"k_scale_cache",
torch.zeros(
args.max_batch_size,
args.max_seq_len,
self.head_dim // block_size,
dtype=torch.float32),
persistent=False)
def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor]):
bsz, seqlen, _ = x.size()
end_pos = start_pos + seqlen
q = self.wq_b(qr)
q = rearrange(q, 'b s (h d) -> b s h d', d=self.head_dim)
q_pe, q_nope = torch.split(
q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
q_pe = apply_rotary_emb(q_pe, freqs_cis)
q = torch.cat([q_pe, q_nope], dim=-1)
k = self.wk(x)
k = self.k_norm(k)
k_pe, k_nope = torch.split(
k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis).squeeze(2)
k = torch.cat([k_pe, k_nope], dim=-1)
q = rotate_activation(q)
k = rotate_activation(k)
q_fp8, q_scale = act_quant(q, block_size, self.scale_fmt)
k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
self.k_cache[:bsz, start_pos:end_pos] = k_fp8
self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
weights = self.weights_proj(x) * self.n_heads**-0.5
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
index_score = fp8_index(q_fp8.contiguous(), weights,
self.k_cache[:bsz, :end_pos].contiguous(),
self.k_scale_cache[:bsz, :end_pos].contiguous())
if mask is not None:
index_score += mask
topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1]
        # Sanity check: the selected indices must agree across model-parallel ranks.
        topk_indices_ = topk_indices.clone()
dist.broadcast(topk_indices_, src=0)
assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}"
return topk_indices
def weight_dequant(weight, scale):
    """Dequantizes a 128x128-block-quantized weight back to the default dtype using its per-block scales."""
    shape = weight.shape
assert weight.dim() == 2
weight = weight.view(shape[0] // block_size, block_size, shape[1] // block_size,
block_size).transpose(1, 2).contiguous().view(-1, block_size * block_size)
weight = (weight.float() * scale.view(-1, 1).float()).to(torch.get_default_dtype()).view(
shape[0] // block_size, shape[1] // block_size, block_size,
block_size).transpose(1, 2).contiguous().view(shape)
return weight
class MLA(nn.Module):
"""
Multi-Head Latent Attention (MLA) Layer.
Attributes:
dim (int): Dimensionality of the input features.
n_heads (int): Number of attention heads.
n_local_heads (int): Number of local attention heads for distributed systems.
q_lora_rank (int): Rank for low-rank query projection.
kv_lora_rank (int): Rank for low-rank key/value projection.
qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
qk_head_dim (int): Total dimensionality of query/key projections.
v_head_dim (int): Dimensionality of value projections.
softmax_scale (float): Scaling factor for softmax in attention computation.
"""
def __init__(self, args: ModelArgs):
super().__init__()
self.dim = args.dim
self.n_heads = args.n_heads
self.n_local_heads = args.n_heads // world_size
self.q_lora_rank = args.q_lora_rank
self.kv_lora_rank = args.kv_lora_rank
self.qk_nope_head_dim = args.qk_nope_head_dim
self.qk_rope_head_dim = args.qk_rope_head_dim
self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
self.v_head_dim = args.v_head_dim
self.wq_a = Linear(self.dim, self.q_lora_rank)
self.q_norm = RMSNorm(self.q_lora_rank)
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
self.kv_norm = RMSNorm(self.kv_lora_rank)
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank,
self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
self.softmax_scale = self.qk_head_dim**-0.5
if args.max_seq_len > args.original_seq_len:
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
self.softmax_scale = self.softmax_scale * mscale * mscale
self.indexer = Indexer(args)
self.register_buffer(
"kv_cache",
torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank),
persistent=False)
self.register_buffer(
"pe_cache",
torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim),
persistent=False)
self.dequant_wkv_b = None
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor]):
"""
Forward pass for the Multi-Head Latent Attention (MLA) Layer.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
start_pos (int): Starting position in the sequence for caching.
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
Returns:
torch.Tensor: Output tensor with the same shape as the input.
"""
bsz, seqlen, _ = x.size()
end_pos = start_pos + seqlen
qr = self.q_norm(self.wq_a(x))
q = self.wq_b(qr)
q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_pe = apply_rotary_emb(q_pe, freqs_cis)
kv = self.wkv_a(x)
kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv = self.kv_norm(kv)
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
self.kv_cache[:bsz, start_pos:end_pos] = kv
self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
if mask is not None: # MHA prefill
q = torch.cat([q_nope, q_pe], dim=-1)
kv = self.wkv_b(kv)
kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
scores = torch.einsum("bshd,bthd->bsht", q.float(), k.float()) * self.softmax_scale
# indexer
topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
index_mask = torch.full((bsz, seqlen, seqlen), float("-inf"),
device=x.device).scatter_(-1, topk_indices, 0)
index_mask += mask
scores += index_mask.unsqueeze(2)
scores = scores.softmax(dim=-1, dtype=torch.float32)
x = torch.einsum("bsht,bthd->bshd", scores.type_as(x), v)
else: # MHA decode
if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
scores = (torch.einsum("bshc,btc->bsht", q_nope.float(),
self.kv_cache[:bsz, :end_pos].float()) +
torch.einsum("bshr,btr->bsht", q_pe.float(),
self.pe_cache[:bsz, :end_pos].float())) * self.softmax_scale
# indexer
topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
index_mask = torch.full((bsz, 1, end_pos), float("-inf"),
device=x.device).scatter_(-1, topk_indices, 0)
scores += index_mask.unsqueeze(2)
scores = scores.softmax(dim=-1, dtype=torch.float32)
x = torch.einsum("bsht,btc->bshc", scores.type_as(x), self.kv_cache[:bsz, :end_pos])
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
x = self.wo(x.flatten(2))
return x
class MLP(nn.Module):
"""
Multi-Layer Perceptron (MLP) used as a feed-forward layer.
Attributes:
w1 (nn.Module): Linear layer for input-to-hidden transformation.
w2 (nn.Module): Linear layer for hidden-to-output transformation.
w3 (nn.Module): Additional linear layer for feature transformation.
"""
def __init__(self, dim: int, inter_dim: int, reduce_output: bool = True):
"""
Initializes the MLP layer.
Args:
dim (int): Input and output dimensionality.
inter_dim (int): Hidden layer dimensionality.
"""
super().__init__()
self.w1 = ColumnParallelLinear(dim, inter_dim)
self.w2 = RowParallelLinear(inter_dim, dim, reduce_output=reduce_output)
self.w3 = ColumnParallelLinear(dim, inter_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the MLP layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after MLP computation.
"""
return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))
class Gate(nn.Module):
"""
Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.
Attributes:
dim (int): Dimensionality of input features.
topk (int): Number of top experts activated for each input.
n_groups (int): Number of groups for routing.
topk_groups (int): Number of groups to route inputs to.
score_func (str): Scoring function ('softmax' or 'sigmoid').
route_scale (float): Scaling factor for routing weights.
weight (torch.nn.Parameter): Learnable weights for the gate.
bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
"""
def __init__(self, args: ModelArgs):
"""
Initializes the Gate module.
Args:
args (ModelArgs): Model arguments containing gating parameters.
"""
super().__init__()
self.dim = args.dim
self.topk = args.n_activated_experts
self.n_groups = args.n_expert_groups
self.topk_groups = args.n_limited_groups
self.score_func = args.score_func
self.route_scale = args.route_scale
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        # Only the 7168-dim (DeepSeek-V3-scale) models carry the expert-score
        # correction bias (`e_score_correction_bias` in HF checkpoints).
        self.bias = nn.Parameter(torch.empty(args.n_routed_experts,
                                             dtype=torch.float32)) if self.dim == 7168 else None
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for the gating mechanism.
Args:
x (torch.Tensor): Input tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
"""
scores = linear(x.float(), self.weight.float())
if self.score_func == "softmax":
scores = scores.softmax(dim=-1)
else:
scores = scores.sigmoid()
original_scores = scores
if self.bias is not None:
scores = scores + self.bias
if self.n_groups > 1:
scores = scores.view(x.size(0), self.n_groups, -1)
if self.bias is None:
group_scores = scores.amax(dim=-1)
else:
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
indices = group_scores.topk(self.topk_groups, dim=-1)[1]
mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
indices = scores.topk(self.topk, dim=-1)[1]
weights = original_scores.gather(1, indices)
if self.score_func == "sigmoid":
weights /= weights.sum(dim=-1, keepdim=True)
weights *= self.route_scale
return weights, indices
class Expert(nn.Module):
"""
Expert layer for Mixture-of-Experts (MoE) models.
Attributes:
w1 (nn.Module): Linear layer for input-to-hidden transformation.
w2 (nn.Module): Linear layer for hidden-to-output transformation.
w3 (nn.Module): Additional linear layer for feature transformation.
"""
def __init__(self, dim: int, inter_dim: int):
"""
Initializes the Expert layer.
Args:
dim (int): Input and output dimensionality.
inter_dim (int): Hidden layer dimensionality.
"""
super().__init__()
self.w1 = Linear(dim, inter_dim)
self.w2 = Linear(inter_dim, dim)
self.w3 = Linear(dim, inter_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the Expert layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after expert computation.
"""
return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))
class MoE(nn.Module):
"""
Mixture-of-Experts (MoE) module.
Attributes:
dim (int): Dimensionality of input features.
n_routed_experts (int): Total number of experts in the model.
n_local_experts (int): Number of experts handled locally in distributed systems.
n_activated_experts (int): Number of experts activated for each input.
gate (nn.Module): Gating mechanism to route inputs to experts.
experts (nn.ModuleList): List of expert modules.
shared_experts (nn.Module): Shared experts applied to all inputs.
"""
def __init__(self, args: ModelArgs):
"""
Initializes the MoE module.
Args:
args (ModelArgs): Model arguments containing MoE parameters.
"""
super().__init__()
self.dim = args.dim
assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
self.n_routed_experts = args.n_routed_experts
self.n_local_experts = args.n_routed_experts // world_size
self.n_activated_experts = args.n_activated_experts
self.experts_start_idx = rank * self.n_local_experts
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
self.gate = Gate(args)
self.experts = nn.ModuleList([
Expert(args.dim, args.moe_inter_dim)
if self.experts_start_idx <= i < self.experts_end_idx else None
for i in range(self.n_routed_experts)
])
self.shared_experts = MLP(
args.dim, args.n_shared_experts * args.moe_inter_dim, reduce_output=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the MoE module.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after expert routing and computation.
"""
shape = x.size()
x = x.view(-1, self.dim)
weights, indices = self.gate(x)
y = torch.zeros_like(x, dtype=torch.float32)
counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
for i in range(self.experts_start_idx, self.experts_end_idx):
if counts[i] == 0:
continue
expert = self.experts[i]
idx, top = torch.where(indices == i)
y[idx] += expert(x[idx]) * weights[idx, top, None]
y += self.shared_experts(x)
if world_size > 1:
dist.all_reduce(y)
return y.type_as(x).view(shape)
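The dispatch loop relies on `torch.where(indices == i)` returning, for expert `i`, the row of every routed token together with the top-k slot that selected it, so each expert output is scaled by exactly the right routing weight. A small self-contained sketch of that indexing (toy sizes, and the `* 2.0` stand-in for a real expert, are assumptions for illustration):

```python
import torch

tokens, dim, n_experts, topk = 5, 8, 4, 2
x = torch.randn(tokens, dim)
# topk over random scores guarantees each token picks distinct experts
indices = torch.rand(tokens, n_experts).topk(topk, dim=-1)[1]
weights = torch.rand(tokens, topk)

y = torch.zeros_like(x)
for i in range(n_experts):
    rows, slots = torch.where(indices == i)  # tokens routed to expert i, and their slot
    if rows.numel() == 0:
        continue
    y[rows] += (x[rows] * 2.0) * weights[rows, slots, None]  # `* 2.0` stands in for expert i
```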
class Block(nn.Module):
"""
Transformer block combining attention and feed-forward layers.
Attributes:
attn (nn.Module): Attention layer (MLA).
ffn (nn.Module): Feed-forward network (MLP or MoE).
attn_norm (nn.Module): Layer normalization for attention.
ffn_norm (nn.Module): Layer normalization for feed-forward network.
"""
def __init__(self, layer_id: int, args: ModelArgs):
"""
Initializes the Transformer block.
Args:
layer_id (int): Layer index in the transformer.
args (ModelArgs): Model arguments containing block parameters.
"""
super().__init__()
self.attn = MLA(args)
self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
self.attn_norm = RMSNorm(args.dim)
self.ffn_norm = RMSNorm(args.dim)
    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor], start_pos: int,
                freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for the Transformer block.
        Args:
            x (torch.Tensor): Input tensor.
            residual (Optional[torch.Tensor]): Residual stream from the previous block, or None for the first block.
            start_pos (int): Starting position in the sequence.
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Output tensor and updated residual stream.
        """
if residual is None:
x, residual = self.attn_norm(x), x
else:
x, residual = self.attn_norm(x, residual)
x = self.attn(x, start_pos, freqs_cis, mask)
x, residual = self.ffn_norm(x, residual)
x = self.ffn(x)
return x, residual
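Note the two-return convention above: `attn_norm` and `ffn_norm` are assumed to fuse the residual add with normalization, returning both the normalized activations and the updated residual stream. A plain-PyTorch sketch of that assumed contract (`fused_rms_norm` is a hypothetical stand-in for the two-argument path, not the repo's RMSNorm):

```python
from typing import Optional
import torch

def fused_rms_norm(x: torch.Tensor, residual: Optional[torch.Tensor] = None, eps: float = 1e-6):
    # Hypothetical reference for the assumed `norm(x, residual) -> (x, residual)` contract.
    if residual is not None:
        x = x + residual          # fold the incoming residual into the stream
    residual = x                  # the pre-norm sum becomes the new residual
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x, residual
```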
class Transformer(nn.Module):
"""
Transformer model with positional embeddings, multiple layers, and output projection.
Attributes:
max_seq_len (int): Maximum sequence length for the transformer.
embed (nn.Module): Embedding layer for input tokens.
layers (torch.nn.ModuleList): List of transformer blocks.
norm (nn.Module): Layer normalization applied after all blocks.
head (nn.Module): Output projection layer mapping to vocabulary size.
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
"""
def __init__(self, args: ModelArgs):
"""
Initializes the Transformer model.
Args:
args (ModelArgs): Model arguments containing transformer parameters.
"""
global world_size, rank
world_size = dist.get_world_size() if dist.is_initialized() else 1
rank = dist.get_rank() if dist.is_initialized() else 0
Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
Linear.scale_fmt = args.scale_fmt
super().__init__()
self.max_seq_len = args.max_seq_len
self.embed = ParallelEmbedding(args.vocab_size, args.dim)
self.layers = torch.nn.ModuleList()
for layer_id in range(args.n_layers):
self.layers.append(Block(layer_id, args))
self.norm = RMSNorm(args.dim)
# lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later.
self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.float32)
self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int = 0):
"""
Forward pass for the Transformer model.
Args:
tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0.
Returns:
torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
"""
seqlen = tokens.size(1)
freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]
mask = torch.full(
(seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) if seqlen > 1 else None
h, residual = self.embed(tokens), None
for layer in self.layers:
h, residual = layer(h, residual, start_pos, freqs_cis, mask)
h, _ = self.norm(h, residual)
logits = self.head(h[:, -1].float())
if world_size > 1:
all_logits = [torch.empty_like(logits) for _ in range(world_size)]
dist.all_gather(all_logits, logits)
logits = torch.cat(all_logits, dim=-1)
return logits
if __name__ == "__main__":
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device("cuda")
torch.manual_seed(0)
args = ModelArgs()
x = torch.randint(0, args.vocab_size, (2, 128))
model = Transformer(args)
print(model(x).size())
torch
transformers
safetensors
fast_hadamard_transform
tilelang==0.1.6
...@@ -231,19 +231,15 @@ def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True):
    return o.to(torch.bfloat16)


def test_sparse_mla_fwd(B=1,
                        S=4096,
                        SKV=4096,
                        H=128,
                        HKV=1,
                        DQK=576,
                        DV=512,
                        topk=2048,
                        dtype=torch.bfloat16):
    torch.random.manual_seed(0)
    q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True)
    kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True)
...@@ -273,4 +269,5 @@ def test_sparse_mla_fwd():

if __name__ == "__main__":
    test_sparse_mla_fwd(
        B=1, S=4096, SKV=32768, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16)
...@@ -397,14 +397,17 @@ def ref_sparse_mla_fwd_interface(q,
    return o.to(torch.bfloat16)


def test_sparse_mla_fwd_pipelined(B=1,
                                  S=4096,
                                  SKV=4096,
                                  H=128,
                                  HKV=1,
                                  DQK=576,
                                  DV=512,
                                  topk=2048,
                                  dtype=torch.bfloat16,
                                  q_start_s_index=1024):
    KV_stride = 1
    torch.random.manual_seed(0)
    q = torch.randn((B, S, H, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10
...@@ -426,14 +429,14 @@ def test_sparse_mla_fwd(test_correctness=False):

    def fn():
        out, lse = kernel(q, kv, indices, q_start_s_index_t)
        if q_start_s_index == 0 and KV_stride > 1:
            out[:, :KV_stride - 1, :, :] = 0
        return out, lse

    tl_out, tl_lse = fn()
    ref_out = ref_sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride)
    # print(f"tl_out: {tl_out}")
    # print(f"ref_out: {ref_out}")
    torch.testing.assert_close(tl_out, ref_out, rtol=1e-3, atol=1e-3)
...@@ -452,4 +455,9 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_correctness", action="store_true")
    args = parser.parse_args()
    if args.test_correctness:
        B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 2048, 128, 1, 576, 512, 2048, torch.bfloat16
    else:
        B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16
    test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
# ruff: noqa
import tilelang.testing
from topk_selector import test_topk_selector
from fp8_lighting_indexer import test_fp8_lighting_indexer
from sparse_mla_fwd import test_sparse_mla_fwd
from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
def test_example_topk_selector():
test_topk_selector()
def test_example_fp8_lighting_indexer():
test_fp8_lighting_indexer()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd():
# small shapes for testing
test_sparse_mla_fwd(S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd_pipelined():
# small shapes for testing
test_sparse_mla_fwd_pipelined(S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256)
if __name__ == "__main__":
tilelang.testing.main()
import torch
import tilelang
import tilelang.language as T
pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True,
}
def convert_to_uint16(x):
hval = T.Cast("float16", x)
bits_uint = T.reinterpret("uint16", hval)
bits_uint = T.if_then_else(x < 0, ~bits_uint & (0xFFFF), bits_uint | (0x8000))
return bits_uint >> 8
def convert_to_uint32(x):
bits_uint = T.reinterpret("uint32", x)
bits_uint = T.if_then_else(
x < 0,
~bits_uint & T.Cast("uint32", (0xFFFFFFFF)),
bits_uint | T.Cast("uint32", (0x80000000)),
)
return bits_uint
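Both helpers implement the standard order-preserving float-to-unsigned transform: flip every bit of a negative value, set the sign bit of a non-negative one, so that unsigned comparisons on the bit patterns agree with floating-point order (`convert_to_uint16` additionally keeps only the top 8 bits as the radix bucket). A NumPy sketch of the same mapping, for intuition only:

```python
import numpy as np

def float32_to_ordered_uint32(x: np.ndarray) -> np.ndarray:
    bits = x.view(np.uint32)
    # Negatives: flip every bit; non-negatives: set the sign bit.
    return np.where(x < 0, ~bits, bits | np.uint32(0x80000000))

vals = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=np.float32)
keys = float32_to_ordered_uint32(vals)
assert (np.argsort(keys) == np.argsort(vals)).all()  # identical ordering
```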
@tilelang.jit(pass_configs=pass_configs)
def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
batch = T.symbolic("batch")
seq_len = T.symbolic("seq_len")
RADIX = 1 << 8
BLOCK_SIZE = 1024
SMEM_INPUT_SIZE = 4096 # assume the threshold bucket size after first pass is less than 4K
@T.prim_func
def tl_topk_kernel(
input: T.Tensor[(batch, seq_len), in_dtype],
index: T.Tensor[(batch, topk), out_dtype],
starts: T.Tensor[(batch), out_dtype],
ends: T.Tensor[(batch), out_dtype],
):
with T.Kernel(batch, threads=BLOCK_SIZE) as (bx):
tx = T.get_thread_binding()
s_threshold_bin_id = T.alloc_shared([1], "int32")
s_histogram = T.alloc_shared([RADIX + 1], "int32")
s_num_input = T.alloc_shared([2], "int32")
s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], "int32")
l_threshold_bin_id = T.alloc_var("int32")
l_new_topk = T.alloc_var("int32")
l_num_input = T.alloc_var("int32")
l_bin_id32 = T.alloc_var("int32")
l_val = T.alloc_var("int32")
l_start_pos = T.alloc_var("int32")
l_start_idx = T.alloc_var("int32")
l_end_idx = T.alloc_var("int32")
l_out_pos = T.alloc_var("int32")
l_new_topk = topk
l_start_idx = starts[bx]
l_end_idx = ends[bx]
            # stage 1: coarse 8-bit radix pass over the high byte of each key
T.fill(s_histogram, 0)
T.fill(s_num_input[0], 0)
T.sync_threads()
for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
input_idx = s * BLOCK_SIZE + tx
if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
inval_int16 = convert_to_uint16(input[bx, input_idx])
T.atomic_add(s_histogram[inval_int16], 1)
T.sync_threads()
            # parallel suffix sum (Hillis-Steele): s_histogram[tx] becomes the
            # count of elements whose bucket id is >= tx
if tx < RADIX:
for i in T.serial(8):
offset = 1 << i
T.sync_threads(3, RADIX)
if tx < RADIX - offset:
l_val = s_histogram[tx] + s_histogram[tx + offset]
T.sync_threads(3, RADIX)
if tx < RADIX - offset:
s_histogram[tx] = l_val
# find threshold bin id
T.sync_threads(3, RADIX)
if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
s_threshold_bin_id[0] = tx
T.sync_threads()
l_threshold_bin_id = s_threshold_bin_id[0]
l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1]
T.sync_threads()
            # emit elements whose bucket is above the threshold; elements that
            # fall exactly in the threshold bucket are queued for the tail pass
for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
T.sync_threads()
input_idx = s * BLOCK_SIZE + tx
if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
bin_id = convert_to_uint16(input[bx, input_idx])
l_bin_id32 = T.Cast("int32", bin_id)
if l_bin_id32 > l_threshold_bin_id:
                        pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True)
index[bx, pos] = input_idx
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
                        pos = T.atomic_add(s_num_input[0], 1, return_prev=True)
s_input_idx[0, pos] = input_idx
            # stage 2: tail passes refine the threshold bucket one byte at a time
for round in T.serial(4):
if l_new_topk <= 0:
T.loop_break()
r_idx = round % 2
l_start_pos = topk - l_new_topk
T.sync_threads()
T.fill(s_histogram, 0)
if tx == 0:
s_num_input[r_idx ^ 1] = 0
T.sync_threads()
l_num_input = s_num_input[r_idx]
for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
if s * BLOCK_SIZE + tx < l_num_input:
l_bin_id32 = T.Cast("int32", ((
convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >>
(24 - round * 8)) & 0xFF))
T.atomic_add(s_histogram[l_bin_id32], 1)
T.sync_threads()
                # parallel suffix sum over the round histogram (same as stage 1)
if tx < RADIX:
for i in T.serial(8):
offset = 1 << i
T.sync_threads(3, RADIX)
if tx < RADIX - offset:
l_val = s_histogram[tx] + s_histogram[tx + offset]
T.sync_threads(3, RADIX)
if tx < RADIX - offset:
s_histogram[tx] = l_val
# find threshold bin id
T.sync_threads(3, RADIX)
if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
s_threshold_bin_id[0] = tx
T.sync_threads()
l_threshold_bin_id = s_threshold_bin_id[0]
l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1]
T.sync_threads()
for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
T.sync_threads()
if s * BLOCK_SIZE + tx < l_num_input:
l_bin_id32 = T.Cast("int32", ((
convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >>
(24 - round * 8)) & 0xFF))
if l_bin_id32 > l_threshold_bin_id:
pos = T.atomic_add(
s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
if round == 3:
l_out_pos = T.atomic_add(
s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
if l_out_pos < topk:
index[bx, l_out_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
else:
pos = T.atomic_add(s_num_input[r_idx ^ 1], 1, return_prev=True)
s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx,
s * BLOCK_SIZE + tx]
return tl_topk_kernel
def tl_topk(input, starts, ends, topk):
batch, seq_len = input.shape
indexes = torch.zeros(batch, topk, dtype=torch.int32, device=input.device)
kernel = tl_topk_impl(topk)
kernel(input, indexes, starts, ends)
return indexes
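`tl_topk_impl` is an 8-bits-per-round radix select: histogram the current key byte, suffix-sum to find the bucket where the running count crosses `topk`, emit everything above that bucket, then recurse into the threshold bucket with the next byte. A pure-NumPy reference sketch of the same algorithm (the function name and sizes are illustrative, not part of the kernel API):

```python
import numpy as np

def radix_select_topk_indices(vals: np.ndarray, k: int) -> np.ndarray:
    # Order-preserving keys, as in convert_to_uint32 above.
    bits = vals.view(np.uint32)
    keys = np.where(vals < 0, ~bits, bits | np.uint32(0x80000000))
    cand = np.arange(len(vals))
    out = []
    for shift in (24, 16, 8, 0):                     # one byte per round
        b = (keys[cand] >> shift) & 0xFF
        hist = np.bincount(b, minlength=256)
        suffix = np.cumsum(hist[::-1])[::-1]         # suffix[j] = count in buckets >= j
        thr = int(np.nonzero(suffix >= k)[0].max())  # threshold bucket
        out.append(cand[b > thr])                    # above threshold: definitely top-k
        k -= len(out[-1])
        if k == 0:
            break
        cand = cand[b == thr]                        # refine inside the threshold bucket
    else:
        out.append(cand[:k])                         # exact ties on the full key
    return np.concatenate(out)

vals = np.random.randn(32 * 1024).astype(np.float32)
idx = radix_select_topk_indices(vals, 2048)
np.testing.assert_allclose(np.sort(vals[idx]), np.sort(vals)[-2048:])
```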
def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048):
torch.manual_seed(1)
input = torch.randn(batch, seq_len, dtype=torch.float32).cuda()
starts = torch.zeros(batch, dtype=torch.int32).cuda()
ends = torch.ones(batch, dtype=torch.int32).cuda() * seq_len
indexes = tl_topk(input, starts, ends, topk)
print(indexes)
indexes_ref = torch.topk(input, topk, dim=-1)[1]
print(indexes_ref)
    # Compare the overlap between the reference and TileLang top-k index sets
for i in range(batch):
ref_np = indexes_ref[i].cpu().to(torch.int32).numpy()
trt_np = indexes[i].cpu().to(torch.int32).numpy()
set_ref = set(ref_np)
set_trt = set(trt_np)
intersection = set_ref & set_trt
print("selected/all:", len(intersection), "/", len(set_ref), "=",
len(intersection) / len(set_ref))
# Performance test with CUDA events
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# Warmup
for _ in range(5):
_ = tl_topk(input, starts, ends, topk)
torch.cuda.synchronize()
n_iters = 20
start_event.record()
for _ in range(n_iters):
_ = tl_topk(input, starts, ends, topk)
end_event.record()
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
print(f"Average tl_topk time: {elapsed_time_ms / n_iters:.3f} ms")
# Torch topk time
start_event.record()
for _ in range(n_iters):
_ = torch.topk(input, topk, dim=-1)[1]
end_event.record()
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
print(f"Average torch.topk time: {elapsed_time_ms / n_iters:.3f} ms")
if __name__ == "__main__":
test_topk_selector()
...@@ -57,3 +57,4 @@ ignore = [
]

[tool.ruff.lint.per-file-ignores]
"3rdparty/**/*" = ["ALL"]
"examples/deepseek_v32/inference/**/*" = ["ALL"]