"vscode:/vscode.git/clone" did not exist on "4339af758e012a06fda296ad2884683d615225fc"
README.md 7.52 KB
Newer Older
1
2
3
4
5
## Directory Structure

```
deepseek_v32/
├── README.md                           # This file
├── figures/                            # Figures and diagrams
├── inference/                          # Inference implementation folder
├── fp8_lighting_indexer.py             # FP8 Lightning Indexer implementation
├── sparse_mla_fwd.py                   # Sparse MLA forward implementation
├── sparse_mla_fwd_pipelined.py         # Pipelined implementation of sparse MLA forward pass
└── topk_selector.py                    # Top-k selector implementation
```

## File Descriptions

### Architecture Overview

![DeepSeek V3.2 Architecture](./figures/v32_arch.png)

The architecture diagram above highlights three key components (shown in green) that correspond to our kernel implementations:

1. **Lightning Indexer** (`fp8_lighting_indexer.py`) - Computes FP8 similarity scores between query and key index vectors, which define the sparse attention pattern
2. **Top-k Selector** (`topk_selector.py`) - Selects the top-k most relevant tokens for sparse attention computation
3. **Multi-Query Attention** (`sparse_mla_fwd.py` and `sparse_mla_fwd_pipelined.py`) - Core attention mechanism: the sparse MLA (Multi-head Latent Attention) forward pass

### Lightning Indexer

In the architecture diagram, the Lightning Indexer sits at the bottom right. It takes the input hidden states and produces compressed representations `{q^A_{t,i}}`, `{k^R_t}`, and `{w^I_{t,j}}`. These FP8-quantized index vectors are what the indexer kernel scores against each other; the resulting logits feed the top-k selector.

The main kernel `mqa_attn_return_logits_kernel` computes similarity scores between query and key indices:

```python
T.gemm(
    index_k_shared,                   # key index vectors for this KV block
    index_q_shared,                   # query index vectors (all heads) for this query block
    s,                                # output tile: [block_N, block_Q * heads]
    transpose_B=True,                 # s = K_idx @ Q_idx^T
    clear_accum=True,                 # zero the accumulator before the matmul
    policy=T.GemmWarpPolicy.FullCol,
)
```

After the matmul, we apply ReLU and aggregate across heads with learned weights:

```python
for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
    s_reshaped[bn_i, bq_i, h_i] = (
        T.max(s[bn_i, bq_i * heads + h_i], 0) * weights[bq_i, h_i]
    ) * index_k_scale_fragment[bn_i]

T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
```
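For intuition, the same computation can be written as a dense PyTorch reference. This is only a sketch: it ignores FP8 quantization, blocking, and the per-key scale factor, and the tensor shapes below are assumptions rather than the kernel's actual signature.

```python
import torch

def indexer_logits_ref(q_idx, k_idx, weights):
    """Dense sketch of the indexer logits.

    q_idx:   [seq_len, heads, dim]    query index vectors
    k_idx:   [seq_len_kv, dim]        key index vectors
    weights: [seq_len, heads]         learned per-head weights
    returns: [seq_len, seq_len_kv]    logits
    """
    # raw similarity scores: scores[t, s, h] = q_idx[t, h] . k_idx[s]
    scores = torch.einsum("thd,sd->tsh", q_idx, k_idx)
    # ReLU, weight each head, then sum over heads (mirrors the reduce_sum above)
    return (scores.relu() * weights[:, None, :]).sum(dim=-1)
```

The kernel additionally multiplies by `index_k_scale_fragment` to undo the FP8 quantization scale of the keys, which is omitted in this sketch.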

The result is a `[seq_len, seq_len_kv]` logits matrix. For long sequences, the kernel uses per-token bounds (`CuSeqLenKS`, `CuSeqLenKE`) to skip irrelevant KV positions:

```python
for bq_i in T.serial(block_Q):
    cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv))
for bq_i in T.serial(block_Q):
    cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv))
```

The pipelined loop then only processes keys in the `[cu_k_s_min, cu_k_e_max)` range, which is crucial for handling variable-length sequences in distributed training.
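The kernel realizes this by skipping the work entirely rather than masking it, but a dense implementation would express the same restriction as a mask. A sketch, where `ks`/`ke` stand in for per-token values read from `CuSeqLenKS`/`CuSeqLenKE`:

```python
import torch

def apply_kv_bounds_ref(logits, ks, ke):
    """Restrict each query token to its valid KV range [ks[t], ke[t]).

    logits: [seq_len, seq_len_kv]
    ks, ke: [seq_len] integer tensors (start inclusive, end exclusive)
    """
    kv_pos = torch.arange(logits.shape[-1], device=logits.device)
    valid = (kv_pos[None, :] >= ks[:, None]) & (kv_pos[None, :] < ke[:, None])
    return logits.masked_fill(~valid, float("-inf"))
```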

### Top-k Selector

The Top-k Selector takes the logits matrix from the indexer and picks the top-k indices for each query. In the architecture diagram, this sits between the Lightning Indexer and the Multi-Query Attention block. The output indices tell the attention layer which KV tokens to actually load and process.

The implementation uses a radix-sort-based approach that processes floats as unsigned integers. Stage 1 does a quick 8-bit pass over the whole sequence:

```python
for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
    input_idx = s*BLOCK_SIZE+tx
    if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
        inval_int16 = convert_to_uint16(input[bx, input_idx])
        T.atomic_add(s_histogram[inval_int16], 1)
```

The `convert_to_uint16` function maps floats to uint16 such that larger floats map to larger integers. After building a histogram and doing a cumulative sum, we find the threshold bin:

```python
if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
    s_threshold_bin_id[0] = tx
```
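The order-preserving float-to-integer mapping that `convert_to_uint16` performs can be illustrated with the standard bit trick. The NumPy sketch below is an illustration only; the kernel's exact bit layout and the width used per pass may differ.

```python
import numpy as np

def ordered_uint16_ref(x: np.ndarray) -> np.ndarray:
    """Map float32 values to uint16 so that unsigned order matches float order."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    # Positive floats: set the sign bit so they sort above all negatives.
    # Negative floats: flip every bit so "more negative" sorts lower.
    ordered = np.where((bits >> 31) == 0, bits | np.uint32(0x80000000), ~bits)
    # Keep the top 16 bits as the coarse bucket key (an assumption about the layout).
    return (ordered >> np.uint32(16)).astype(np.uint16)
```

Any monotone mapping like this lets the selector bucket values with an integer histogram instead of comparing floats directly.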

Elements above the threshold go directly to the output. Elements in the threshold bin get collected for further processing:

```python
if l_bin_id32 > l_threshold_bin_id:
    pos = T.atomic_add(s_histogram[l_bin_id32+1], 1, return_prev=True)
    index[bx, pos] = input_idx
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
    pos = T.atomic_add(s_num_input[0], 1, return_prev=True)
    s_input_idx[0, pos] = input_idx
```

Stage 2 refines the threshold bin with up to 4 rounds of 8-bit radix sort, processing progressively higher bits. This gives exact top-k selection without sorting the entire sequence.
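As a correctness reference (not the kernel's algorithm), the expected output for one logits row can be computed with a partial sort; per-row start/end bounds and tie-breaking are ignored here:

```python
import numpy as np

def topk_indices_ref(logits_row: np.ndarray, topk: int) -> np.ndarray:
    """Indices of the topk largest entries of one logits row (unordered)."""
    # argpartition places the indices of the topk largest values in the last topk slots
    return np.argpartition(logits_row, -topk)[-topk:]

row = np.random.randn(4096).astype(np.float32)
assert set(topk_indices_ref(row, 64)) == set(np.argsort(row)[-64:])
```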

### Sparse MLA Forward

The Sparse MLA kernel is where the actual attention computation happens. In the architecture diagram, this is the large "Multi-Query Attention (Core Attention)" block at the top. It takes the selected top-k indices and computes attention only over those tokens.

Turning dense MLA into sparse MLA requires surprisingly few changes - essentially just modifying how we iterate and load KV tokens. The key difference from dense MLA (see `../deepseek_mla/example_mla_decode.py`) is the iteration pattern. Dense MLA iterates over all KV positions:

```python
# Dense MLA: iterate over full sequence
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
    T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
    # ... compute attention over this block
```

Sparse MLA only loads KV positions selected by the top-k selector:

```python
# Sparse MLA: iterate over selected indices only
for i_i in T.Pipelined(NI, num_stages=num_stages):
    for bi_i, d_i in T.Parallel(BI, D):
        KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i]
    # ... compute attention over selected tokens
```

This reduces compute from O(seq_len * seq_len_kv) to O(seq_len * topk). The causal mask is enforced by checking whether each index position is valid:

```python
for bi_i in T.Parallel(BI):
    mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i
```

Beyond this sparse indexing, the rest of the attention computation (online softmax, output accumulation) follows the same pattern as dense MLA.
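Putting it together, a gather-then-attend PyTorch reference for a single batch element and KV group looks like the sketch below. The actual kernel also handles batch and group dimensions, a separate rope component of the latent, and low-precision details, all glossed over here.

```python
import torch

def sparse_mla_fwd_ref(q, kv, indices, sm_scale):
    """Sparse attention over gathered KV latents (single batch/group).

    q:       [seq_len, heads, dim]   queries (all heads share the same KV)
    kv:      [seq_len_kv, dim]       latent KV cache
    indices: [seq_len, topk]         selected KV positions per query token
    """
    sel = kv[indices]                                    # [seq_len, topk, dim]
    logits = torch.einsum("thd,tkd->thk", q, sel) * sm_scale
    # causal mask: a selected position is valid only if it does not exceed
    # the query position (the `<= max_kv_i` check in the kernel)
    t = torch.arange(q.shape[0], device=q.device)[:, None]
    logits = logits.masked_fill((indices > t)[:, None, :], float("-inf"))
    probs = logits.softmax(dim=-1)
    return torch.einsum("thk,tkd->thd", probs, sel)      # [seq_len, heads, dim]
```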

### Sparse MLA Forward (Pipelined)

The pipelined version (`sparse_mla_fwd_pipelined.py`) is a manual pipeline implementation designed to match the schedule of [FlashMLA](https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/sm90/prefill/sparse/fwd.cu). It achieves close to 600 TFlops on H800 SXM by carefully orchestrating memory and compute pipelines.

The key difference is splitting the warp groups into specialized roles:

```python
if tx < 128:
    # Consumer 0: computes left half of output (D//2 dimensions)
    # Handles QK matmul, softmax, and PV for left half

elif tx >= 128 and tx < 256:
    # Consumer 1: computes right half of output (D//2 dimensions)
    # Only does PV matmul for right half

elif tx >= 256:
    # Producer: loads KV data from global memory
    # Uses async copy with barriers to feed consumers
```

The producer thread group (tx >= 256) uses double buffering with barriers to keep consumers fed:

```python
# Producer alternates between two buffers
for i_i in T.serial(T.ceildiv(NI, 2)):
    # Buffer 0
    T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
    # ... load KV into buffer 0
    T.cp_async_barrier_noinc(bar_k_0_ready[0])

    # Buffer 1
    T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
    # ... load KV into buffer 1
    T.cp_async_barrier_noinc(bar_k_1_ready[0])
```

Consumer threads wait on barriers and process buffers as they become ready. This manual orchestration hides memory latency behind compute, which is why it outperforms the simpler auto-pipelined version. The output dimension is also split in half so that the two consumer groups can work in parallel on different parts of the matmul.
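That split is exact because each output column of the PV product depends only on the corresponding column of V; both consumers use the same softmax probabilities, so the two halves can be computed independently. A small PyTorch check with made-up tile sizes:

```python
import torch

# Made-up tile sizes, just to check the column split of the PV matmul.
P = torch.randn(64, 128)      # softmax probabilities for one tile [block_Q, block_KV]
V = torch.randn(128, 512)     # value/latent tile                  [block_KV, D]
D = V.shape[1]

left = P @ V[:, :D // 2]      # what consumer 0 would produce
right = P @ V[:, D // 2:]     # what consumer 1 would produce
assert torch.allclose(torch.cat([left, right], dim=1), P @ V, atol=1e-5)
```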