OpenDAS / ktransformers · Commit 885a91e7
Unverified commit 885a91e7, authored Feb 14, 2025 by Atream, committed by GitHub on Feb 14, 2025.

Merge pull request #294 from kvcache-ai/feat-fast-MLA

Feat fast mla

Parents: 67389086, 1084d4e4
Changes: 5 changed files, with 637 additions and 211 deletions (+637, -211).
ktransformers/models/custom_cache.py (+50, -12)
ktransformers/operators/attention.py (+197, -197)
ktransformers/operators/triton_attention.py (+379, -0)
ktransformers/util/utils.py (+2, -2)
test_prompt.txt (+9, -0)
ktransformers/models/custom_cache.py (view file @ 885a91e7)
...
...
@@ -51,13 +51,34 @@ class StaticCache(transformers.StaticCache):
        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM":
            # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
            # key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim)
            # value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim)
            key_shape = (max_batch_size, 1, self.max_cache_len, config.qk_rope_head_dim)
            value_shape = (max_batch_size, 1, self.max_cache_len, config.kv_lora_rank)
            self.page_size = 64
            self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
            latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
            self.kv_lora_rank = config.kv_lora_rank
            self.qk_rope_head_dim = config.qk_rope_head_dim
            # TODO: support real page table
            self.page_table_map = dict()
            self.page_table_list = []
            for idx in range(config.num_hidden_layers):
                if isinstance(device, dict):
                    target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
                else:
                    target_device = device
                if target_device not in self.page_table_map:
                    page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
                    for seq_id in range(max_batch_size):
                        page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
                    self.page_table_map[target_device] = page_table
                self.page_table_list.append(self.page_table_map[target_device])
            self.is_MLA = True
            self.is_page = True
        else:
            key_shape = cache_shape
            value_shape = cache_shape
            self.is_MLA = False
        self.past_tokens = []
        self.num_hidden_layers = config.num_hidden_layers
...
...
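The paging arithmetic above is easy to check by hand. A small standalone sketch, assuming hypothetical sizes (max_batch_size=2, max_cache_len=200) and the commit's page_size of 64, reproduces the page-table construction: each sequence is statically assigned a contiguous block of pages, matching the `# TODO: support real page table` note above.

import torch

max_batch_size, max_cache_len, page_size = 2, 200, 64    # hypothetical sizes
max_pages = (max_cache_len + page_size - 1) // page_size  # ceil(200 / 64) = 4

# One row per sequence; row i lists the page numbers owned by sequence i.
page_table = torch.zeros((max_batch_size, max_pages), dtype=torch.int32)
for seq_id in range(max_batch_size):
    page_table[seq_id, :] = torch.arange(seq_id * max_pages, seq_id * max_pages + max_pages, dtype=torch.int32)

print(page_table)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]], dtype=torch.int32)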
@@ -68,10 +89,17 @@ class StaticCache(transformers.StaticCache):
                target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
            else:
                target_device = device
            if self.is_MLA:
                new_layer_key_cache = torch.zeros(latent_shape, dtype=self.dtype, device=target_device)
                new_layer_value_cache = None
                torch._dynamo.mark_static_address(new_layer_key_cache)
            else:
                new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
                new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
                torch._dynamo.mark_static_address(new_layer_key_cache)
                torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)
            self.past_tokens.append(0)
...
...
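What the MLA branch buys in memory: each cached token is a single latent vector of kv_lora_rank + qk_rope_head_dim values (the per-layer value cache is left as None), instead of full per-head key/value tensors. A rough per-token, per-layer comparison, using DeepSeek-V3-like dimensions as an assumption rather than values taken from this diff:

kv_lora_rank, qk_rope_head_dim = 512, 64                   # assumed MLA dimensions
num_heads, qk_nope_head_dim, v_head_dim = 128, 128, 128    # assumed head layout
bytes_per_elem = 2                                         # fp16 / bf16

latent_bytes = (kv_lora_rank + qk_rope_head_dim) * bytes_per_elem                                # 1152
full_kv_bytes = num_heads * (qk_nope_head_dim + qk_rope_head_dim + v_head_dim) * bytes_per_elem  # 81920

print(latent_bytes, full_kv_bytes, full_kv_bytes / latent_bytes)  # 1152 81920 -> roughly 71x smaller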
@@ -104,10 +132,20 @@ class StaticCache(transformers.StaticCache):
        cache_position = cache_kwargs.get("cache_position")
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
        self.past_tokens[layer_idx] += cache_position.size(0)
        #print(cache_position)
        if self.is_MLA:
            page_idx = cache_position // self.page_size
            page_offset = cache_position % self.page_size
            # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
            #print("page_idx", page_idx)
            #print("page_offset", page_offset)
            k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
            k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
            return k_out, self.page_table_list[layer_idx]
        else:
            k_out[:, :, cache_position] = key_states
            v_out[:, :, cache_position] = value_states
            self.past_tokens[layer_idx] += cache_position.size(0)
            return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
...
...
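In the MLA path of `update`, the flat cache position is split into a page index and an in-page offset before the latent vector is scattered into the paged buffer. A minimal sketch of that mapping with the commit's page_size of 64 (positions chosen arbitrarily):

import torch

page_size = 64
cache_position = torch.tensor([0, 1, 63, 64, 130], dtype=torch.long)

page_idx = cache_position // page_size    # tensor([0, 0, 0, 1, 2])
page_offset = cache_position % page_size  # tensor([0, 1, 63, 0, 2])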
ktransformers/operators/attention.py (view file @ 885a91e7)
This diff is collapsed.
ktransformers/operators/triton_attention.py (new file, mode 100644; view file @ 885a91e7)
import triton
import triton.language as tl


@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def _fwd_grouped_kernel_stage1(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
    split_kv_id = tl.program_id(2)

    if kv_group_num > BLOCK_H:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num
    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
    mask_h = mask_h & (cur_head < q_head_num)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lk
    mask_dv = offs_dv < Lv

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_req_idx = cur_batch

    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
    q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)

    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        mask_dpe = offs_dpe < Lk
        off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :])
        qpe = tl.load(Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0)

    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    if split_kv_end > split_kv_start:
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            kv_page_number = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n // PAGE_SIZE,
                mask=offs_n < split_kv_end,
                other=0,
            )
            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
            offs_buf_k = (kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_d[:, None])
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
                other=0.0,
            )
            qk = tl.dot(q, k.to(q.dtype))
            if BLOCK_DPE > 0:
                offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_dpe[:, None])
                kpe = tl.load(
                    K_Buffer + offs_buf_kpe,
                    mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
                    other=0.0,
                )
                qk += tl.dot(qpe, kpe.to(qpe.dtype))
            qk *= sm_scale

            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf"))

            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + cur_kv_head * stride_buf_vh + offs_dv[None, :])
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                other=0.0,
            )

            n_e_max = tl.maximum(tl.max(qk, 1), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            acc *= re_scale[:, None]
            acc += tl.dot(p.to(v.dtype), v)

            e_sum = e_sum * re_scale + tl.sum(p, 1)
            e_max = n_e_max

        offs_mid_o = (cur_batch * stride_mid_ob + cur_head[:, None] * stride_mid_oh + split_kv_id * stride_mid_os + offs_dv[None, :])
        tl.store(
            Att_Out + offs_mid_o,
            acc / e_sum[:, None],
            mask=(mask_h[:, None]) & (mask_dv[None, :]),
        )

        offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os + Lv)
        tl.store(
            Att_Out + offs_mid_o_1,
            e_max + tl.log(e_sum),
            mask=mask_h,
        )


def _decode_grouped_att_m_fwd(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    B_Seqlen,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap,
):
    BLOCK = 32
    Lk = k_buffer.shape[-1]
    Lv = v_buffer.shape[-1]

    # [TODO] work around shmem limit on MI3xx
    # TODO: support hip
    # if is_hip_ and Lk >= 576:
    #     BLOCK = 16

    if Lk == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
    elif Lk == 288:
        BLOCK_DMODEL = 256
        BLOCK_DPE = 32
    else:
        BLOCK_DMODEL = triton.next_power_of_2(Lk)
        BLOCK_DPE = 0
    BLOCK_DV = triton.next_power_of_2(Lv)

    batch, head_num = q.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

    BLOCK_H = 16
    NUM_KV_SPLITS = num_kv_splits
    grid = (
        batch,
        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
        NUM_KV_SPLITS,
    )

    extra_kargs = {}
    # TODO: support hip
    """
    if is_hip_:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {
            "waves_per_eu": 4,
            "matrix_instr_nonkdim": 16,
            "kpack": 2
        }
    """

    _fwd_grouped_kernel_stage1[grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        k_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        v_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        v_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        BLOCK_N=BLOCK,
        BLOCK_H=BLOCK_H,
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        num_warps=4,
        num_stages=2,
        Lk=Lk,
        Lv=Lv,
        **extra_kargs,
    )


@triton.jit
def _fwd_kernel_stage2(
    Mid_O,
    o,
    B_Seqlen,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_obs,
    stride_oh,
    NUM_KV_SPLITS: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)

    offs_d = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lv

    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv

    for split_kv_id in range(0, NUM_KV_SPLITS):
        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

        if split_kv_end > split_kv_start:
            tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0)
            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
            n_e_max = tl.maximum(tlogic, e_max)

            old_scale = tl.exp(e_max - n_e_max)
            acc *= old_scale
            exp_logic = tl.exp(tlogic - n_e_max)
            acc += exp_logic * tv

            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

    tl.store(
        o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,
        mask=mask_d,
    )


def _decode_softmax_reducev_fwd(
    logits,
    q,
    o,
    v_buffer,
    b_seq_len,
    num_kv_splits,
):
    batch, head_num = q.shape[0], q.shape[1]
    Lv = v_buffer.shape[-1]
    BLOCK_DV = triton.next_power_of_2(Lv)

    NUM_KV_SPLITS = num_kv_splits

    extra_kargs = {}
    # TODO: support hip
    """
    if is_hip_:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {
            "waves_per_eu": 4,
            "matrix_instr_nonkdim": 16,
            "kpack": 2
        }
    """

    grid = (batch, head_num)
    _fwd_kernel_stage2[grid](
        logits,
        o,
        b_seq_len,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        BLOCK_DV=BLOCK_DV,
        Lv=Lv,
        num_warps=4,
        num_stages=2,
        **extra_kargs,
    )


def decode_attention_fwd_grouped(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    _decode_grouped_att_m_fwd(
        q,
        k_buffer,
        v_buffer,
        attn_logits,
        req_to_token,
        b_seq_len,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
\ No newline at end of file
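The file implements a two-stage split-KV ("flash decoding") decode kernel: stage 1 computes partial attention over each of NUM_KV_SPLITS slices of the paged KV cache and stores the partial output together with a log-sum-exp term, and stage 2 combines the partial results with a numerically stable online softmax. Since the attention.py diff is collapsed above, the call below is a hypothetical driver whose shapes are inferred from how the kernel indexes its arguments; treat it as a sketch, not the operator's actual code.

import torch
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped

batch, q_heads, page_size, max_pages = 1, 128, 64, 64  # assumed sizes
kv_lora_rank, rope_dim = 512, 64                       # Lk = 576 selects the BLOCK_DMODEL=512 / BLOCK_DPE=64 path
num_kv_splits = 8
device, dtype = "cuda", torch.bfloat16

# Query already projected into the absorbed latent space (kv_lora_rank + rope dims).
q = torch.randn(batch, q_heads, kv_lora_rank + rope_dim, device=device, dtype=dtype)

# Paged latent cache as allocated by StaticCache: (max_pages, page_size, 1, 576).
# The "value" buffer is just a view of the first kv_lora_rank channels.
kv_cache = torch.randn(max_pages, page_size, 1, kv_lora_rank + rope_dim, device=device, dtype=dtype)
v_cache = kv_cache[..., :kv_lora_rank]

page_table = torch.arange(max_pages, dtype=torch.int32, device=device).unsqueeze(0)  # (batch, max_pages)
b_seq_len = torch.tensor([1000], dtype=torch.int32, device=device)

o = torch.empty(batch, q_heads, kv_lora_rank, device=device, dtype=dtype)
# Stage-1 scratch: per split, Lv output channels plus one slot for the log-sum-exp.
attn_logits = torch.empty(batch, q_heads, num_kv_splits, kv_lora_rank + 1, device=device, dtype=torch.float32)

decode_attention_fwd_grouped(
    q, kv_cache, v_cache, o,
    page_table, b_seq_len, attn_logits,
    num_kv_splits,
    sm_scale=(128 + rope_dim) ** -0.5,  # illustrative; the real scale comes from the model config
    page_size=page_size,
)

The extra slot at offset Lv in attn_logits is where stage 1 stores the log-sum-exp that stage 2 reads back when merging the splits.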
ktransformers/util/utils.py (view file @ 885a91e7)
...
...
@@ -133,7 +133,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
        )
    else:
        past_key_values = None
    cache_position = torch.arange(seq_length, device=torch_device)                    # removed
    cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.long)  # added
    generated_ids = torch.zeros(batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device)
...
...
@@ -178,7 +178,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
        generated_ids[:, seq_length] = next_token
        tokens.append(int(next_token))
        inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
        cache_position = torch.tensor([seq_length], device=torch_device)                    # removed
        cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.long)  # added
        position_ids = cache_position.unsqueeze(0)
        seq_length += 1
...
...
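Both utils.py hunks only make the index dtype explicit: torch.arange over an int length and torch.tensor([seq_length]) already default to int64, so adding dtype=torch.long mainly guards against the position tensors ever arriving as 32-bit integers, which some indexing ops reject. A minimal illustration (hypothetical, not from the diff):

import torch

src = torch.arange(6, dtype=torch.float32).reshape(2, 3)
idx = torch.tensor([[0, 2]], dtype=torch.int32)

# src.gather(1, idx)                     # RuntimeError: gather(): Expected dtype int64 for index
out = src.gather(1, idx.to(torch.long))  # tensor([[0., 2.]])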
test_prompt.txt (new file, mode 100644; view file @ 885a91e7)
This diff is collapsed.