Commit bb35dc5b authored by Atream's avatar Atream
Browse files

init support for MLA using Attention kernel

parent 62011fd6
...@@ -51,13 +51,33 @@ class StaticCache(transformers.StaticCache): ...@@ -51,13 +51,33 @@ class StaticCache(transformers.StaticCache):
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM": if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM":
# TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
# key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim) self.page_size = 64
# value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim) self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
key_shape = (max_batch_size, 1, self.max_cache_len, config.qk_rope_head_dim) key_shape = (self.max_pages, self.page_size, 1, config.qk_rope_head_dim)
value_shape = (max_batch_size, 1, self.max_cache_len, config.kv_lora_rank) value_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank)
# TODO: support real page table
self.page_table_map = dict()
self.page_table_list = []
for idx in range(config.num_hidden_layers):
if isinstance(device, dict):
target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
else:
target_device = device
if target_device not in self.page_table_map:
page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
for seq_id in range(max_batch_size):
page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
self.page_table_map[target_device] = page_table
self.page_table_list.append(self.page_table_map[target_device])
self.is_MLA = True
self.is_page = True
else: else:
key_shape = cache_shape key_shape = cache_shape
value_shape = cache_shape value_shape = cache_shape
self.is_MLA = False
self.past_tokens = [] self.past_tokens = []
self.num_hidden_layers = config.num_hidden_layers self.num_hidden_layers = config.num_hidden_layers
...@@ -104,11 +124,20 @@ class StaticCache(transformers.StaticCache): ...@@ -104,11 +124,20 @@ class StaticCache(transformers.StaticCache):
cache_position = cache_kwargs.get("cache_position") cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx] k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx] v_out = self.value_cache[layer_idx]
#print(cache_position)
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
self.past_tokens[layer_idx] += cache_position.size(0) self.past_tokens[layer_idx] += cache_position.size(0)
return k_out, v_out #print(cache_position)
if self.is_MLA:
page_idx = cache_position // self.page_size
page_offset = cache_position % self.page_size
#print("page_idx", page_idx)
#print("page_offset", page_offset)
k_out[page_idx, page_offset, ...] = key_states
v_out[page_idx, page_offset, ...] = value_states
return k_out, v_out, self.page_table_list[layer_idx]
else:
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
return k_out, v_out
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model.""" """Returns the sequence length of the cached states that were seen by the model."""
......
This diff is collapsed.
import triton
import triton.language as tl
@triton.jit
def tanh(x):
# Tanh is just a scaled sigmoid
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def _fwd_grouped_kernel_stage1(
Q,
K_Buffer,
V_Buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
Att_Out,
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
stride_buf_vbs,
stride_buf_vh,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_H: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
split_kv_id = tl.program_id(2)
if kv_group_num > BLOCK_H:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lk
mask_dv = offs_dv < Lv
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_req_idx = cur_batch
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[
None, :]
q = tl.load(Q + offs_q,
mask=(mask_h[:, None]) & (mask_d[None, :]),
other=0.0)
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh +
offs_dpe[None, :])
qpe = tl.load(Q + off_qpe,
mask=(mask_h[:, None]) & (mask_dpe[None, :]),
other=0.0)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
cur_batch_seq_len)
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end,
other=0,
)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_buf_k = (kv_loc[None, :] * stride_buf_kbs +
cur_kv_head * stride_buf_kh + offs_d[:, None])
k = tl.load(
K_Buffer + offs_buf_k,
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
other=0.0,
)
qk = tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0:
offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs +
cur_kv_head * stride_buf_kh +
offs_dpe[:, None])
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) &
(mask_dpe[:, None]),
other=0.0,
)
qk += tl.dot(qpe, kpe.to(qpe.dtype))
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end),
qk, float("-inf"))
offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
cur_kv_head * stride_buf_vh + offs_dv[None, :])
v = tl.load(
V_Buffer + offs_buf_v,
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *= re_scale[:, None]
acc += tl.dot(p.to(v.dtype), v)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_mid_o = (cur_batch * stride_mid_ob +
cur_head[:, None] * stride_mid_oh +
split_kv_id * stride_mid_os + offs_dv[None, :])
tl.store(
Att_Out + offs_mid_o,
acc / e_sum[:, None],
mask=(mask_h[:, None]) & (mask_dv[None, :]),
)
offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
split_kv_id * stride_mid_os + Lv)
tl.store(
Att_Out + offs_mid_o_1,
e_max + tl.log(e_sum),
mask=mask_h,
)
def _decode_grouped_att_m_fwd(
q,
k_buffer,
v_buffer,
att_out,
Req_to_tokens,
B_Seqlen,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
):
BLOCK = 32
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
# [TODO] work around shmem limit on MI3xx
# TODO: support hip
#if is_hip_ and Lk >= 576:
# BLOCK = 16
if Lk == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
elif Lk == 288:
BLOCK_DMODEL = 256
BLOCK_DPE = 32
else:
BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DPE = 0
BLOCK_DV = triton.next_power_of_2(Lv)
batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2]
BLOCK_H = 16
NUM_KV_SPLITS = num_kv_splits
grid = (
batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
NUM_KV_SPLITS,
)
extra_kargs = {}
# TODO: support hip
"""
if is_hip_:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
"""
_fwd_grouped_kernel_stage1[grid](
q,
k_buffer,
v_buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV,
BLOCK_N=BLOCK,
BLOCK_H=BLOCK_H,
NUM_KV_SPLITS=NUM_KV_SPLITS,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
num_warps=4,
num_stages=2,
Lk=Lk,
Lv=Lv,
**extra_kargs,
)
@triton.jit
def _fwd_kernel_stage2(
Mid_O,
o,
B_Seqlen,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_obs,
stride_oh,
NUM_KV_SPLITS: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
offs_d = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lv
e_sum = 0.0
e_max = -float("inf")
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
for split_kv_id in range(0, NUM_KV_SPLITS):
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
cur_batch_seq_len)
if split_kv_end > split_kv_start:
tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os,
mask=mask_d,
other=0.0)
tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
n_e_max = tl.maximum(tlogic, e_max)
old_scale = tl.exp(e_max - n_e_max)
acc *= old_scale
exp_logic = tl.exp(tlogic - n_e_max)
acc += exp_logic * tv
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
tl.store(
o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
acc / e_sum,
mask=mask_d,
)
def _decode_softmax_reducev_fwd(
logits,
q,
o,
v_buffer,
b_seq_len,
num_kv_splits,
):
batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1]
BLOCK_DV = triton.next_power_of_2(Lv)
NUM_KV_SPLITS = num_kv_splits
extra_kargs = {}
# TODO: support hip
"""
if is_hip_:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
"""
grid = (batch, head_num)
_fwd_kernel_stage2[grid](
logits,
o,
b_seq_len,
logits.stride(0),
logits.stride(1),
logits.stride(2),
o.stride(0),
o.stride(1),
NUM_KV_SPLITS=NUM_KV_SPLITS,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
num_warps=4,
num_stages=2,
**extra_kargs,
)
def decode_attention_fwd_grouped(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap=0.0,
):
_decode_grouped_att_m_fwd(
q,
k_buffer,
v_buffer,
attn_logits,
req_to_token,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
num_kv_splits)
\ No newline at end of file
...@@ -133,7 +133,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud ...@@ -133,7 +133,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
) )
else: else:
past_key_values = None past_key_values = None
cache_position = torch.arange(seq_length, device=torch_device) cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.long)
generated_ids = torch.zeros( generated_ids = torch.zeros(
batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
) )
...@@ -178,7 +178,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud ...@@ -178,7 +178,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
generated_ids[:, seq_length] = next_token generated_ids[:, seq_length] = next_token
tokens.append(int(next_token)) tokens.append(int(next_token))
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1) inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
cache_position = torch.tensor([seq_length], device=torch_device) cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.long)
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
seq_length += 1 seq_length += 1
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment