Commit 48c6dc42 authored by Shengyu Liu's avatar Shengyu Liu
Browse files

nits

parent c741387b
......@@ -72,10 +72,8 @@ def flash_mla_with_kvcache(
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
Different modes (including fp8/bf16, sparsity, and model version (i.e. V3.2 or MODEL1)) has different KV cache layouts. See comments below for details.
Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details.
The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks.
Besides, some kernels also have their own requirements on the layout of k cache, including:
- For sparse fp8 decoding kernel on F3, k_cache.stride(0) must be a multiple of 656B (for V32) or 576B (for MODEL1). Padding is needed sometimes.
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
head_dim_v: Head_dim of v. Must be 512
......@@ -88,7 +86,7 @@ def flash_mla_with_kvcache(
Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block),
where t is the k-th token of the j-th q-sequence in the i-th batch.
attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0.
extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. This is used to support MODEL1. Their format requirements are the same as k_cache and indices_in_kvcache respectively.
extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively.
topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking.
For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2:
......@@ -99,11 +97,6 @@ def flash_mla_with_kvcache(
- Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on.
- Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy.
For DeepSeek MODEL1:
head_dim should be 512 while head_dim_v is also 512.
In FP8+sparse mode, every block can be divided into two parts. The first parts stores NoPE0, RoPE0, NoPE1, RoPE1, ... while the second part stores scale factors: 7xue8m0, 1Bpad, 7xue8m0, 1Bpad, ...
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment