Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ox696c
ktransformers
Commits
bb35dc5b
Commit
bb35dc5b
authored
Feb 13, 2025
by
Atream
Browse files
init support for MLA using Attention kernel
parent
62011fd6
Changes
5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
531 additions
and
242 deletions
+531
-242
ktransformers/models/custom_cache.py
ktransformers/models/custom_cache.py
+37
-8
ktransformers/operators/attention.py
ktransformers/operators/attention.py
+104
-232
ktransformers/operators/triton_attention.py
ktransformers/operators/triton_attention.py
+379
-0
ktransformers/util/utils.py
ktransformers/util/utils.py
+2
-2
test_prompt.txt
test_prompt.txt
+9
-0
No files found.
ktransformers/models/custom_cache.py
View file @
bb35dc5b
...
@@ -51,13 +51,33 @@ class StaticCache(transformers.StaticCache):
...
@@ -51,13 +51,33 @@ class StaticCache(transformers.StaticCache):
cache_shape
=
(
max_batch_size
,
self
.
num_key_value_heads
,
self
.
max_cache_len
,
self
.
head_dim
)
cache_shape
=
(
max_batch_size
,
self
.
num_key_value_heads
,
self
.
max_cache_len
,
self
.
head_dim
)
if
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
or
config
.
architectures
[
0
]
==
"DeepseekV3ForCausalLM"
:
if
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
or
config
.
architectures
[
0
]
==
"DeepseekV3ForCausalLM"
:
# TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
# TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
# key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim)
self
.
page_size
=
64
# value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim)
self
.
max_pages
=
(
self
.
max_cache_len
+
self
.
page_size
-
1
)
//
self
.
page_size
key_shape
=
(
max_batch_size
,
1
,
self
.
max_cache_len
,
config
.
qk_rope_head_dim
)
key_shape
=
(
self
.
max_pages
,
self
.
page_size
,
1
,
config
.
qk_rope_head_dim
)
value_shape
=
(
max_batch_size
,
1
,
self
.
max_cache_len
,
config
.
kv_lora_rank
)
value_shape
=
(
self
.
max_pages
,
self
.
page_size
,
1
,
config
.
kv_lora_rank
)
# TODO: support real page table
self
.
page_table_map
=
dict
()
self
.
page_table_list
=
[]
for
idx
in
range
(
config
.
num_hidden_layers
):
if
isinstance
(
device
,
dict
):
target_device
=
device
[
f
"blk.
{
idx
}
.self_attn"
][
"generate_device"
]
else
:
target_device
=
device
if
target_device
not
in
self
.
page_table_map
:
page_table
=
torch
.
zeros
((
max_batch_size
,
self
.
max_pages
),
dtype
=
torch
.
int32
,
device
=
target_device
)
for
seq_id
in
range
(
max_batch_size
):
page_table
[
seq_id
,
:]
=
torch
.
arange
(
seq_id
*
self
.
max_pages
,
seq_id
*
self
.
max_pages
+
self
.
max_pages
,
dtype
=
torch
.
int32
,
device
=
target_device
)
self
.
page_table_map
[
target_device
]
=
page_table
self
.
page_table_list
.
append
(
self
.
page_table_map
[
target_device
])
self
.
is_MLA
=
True
self
.
is_page
=
True
else
:
else
:
key_shape
=
cache_shape
key_shape
=
cache_shape
value_shape
=
cache_shape
value_shape
=
cache_shape
self
.
is_MLA
=
False
self
.
past_tokens
=
[]
self
.
past_tokens
=
[]
self
.
num_hidden_layers
=
config
.
num_hidden_layers
self
.
num_hidden_layers
=
config
.
num_hidden_layers
...
@@ -104,10 +124,19 @@ class StaticCache(transformers.StaticCache):
...
@@ -104,10 +124,19 @@ class StaticCache(transformers.StaticCache):
cache_position
=
cache_kwargs
.
get
(
"cache_position"
)
cache_position
=
cache_kwargs
.
get
(
"cache_position"
)
k_out
=
self
.
key_cache
[
layer_idx
]
k_out
=
self
.
key_cache
[
layer_idx
]
v_out
=
self
.
value_cache
[
layer_idx
]
v_out
=
self
.
value_cache
[
layer_idx
]
self
.
past_tokens
[
layer_idx
]
+=
cache_position
.
size
(
0
)
#print(cache_position)
#print(cache_position)
if
self
.
is_MLA
:
page_idx
=
cache_position
//
self
.
page_size
page_offset
=
cache_position
%
self
.
page_size
#print("page_idx", page_idx)
#print("page_offset", page_offset)
k_out
[
page_idx
,
page_offset
,
...]
=
key_states
v_out
[
page_idx
,
page_offset
,
...]
=
value_states
return
k_out
,
v_out
,
self
.
page_table_list
[
layer_idx
]
else
:
k_out
[:,
:,
cache_position
]
=
key_states
k_out
[:,
:,
cache_position
]
=
key_states
v_out
[:,
:,
cache_position
]
=
value_states
v_out
[:,
:,
cache_position
]
=
value_states
self
.
past_tokens
[
layer_idx
]
+=
cache_position
.
size
(
0
)
return
k_out
,
v_out
return
k_out
,
v_out
def
get_seq_length
(
self
,
layer_idx
:
Optional
[
int
]
=
0
)
->
int
:
def
get_seq_length
(
self
,
layer_idx
:
Optional
[
int
]
=
0
)
->
int
:
...
...
ktransformers/operators/attention.py
View file @
bb35dc5b
...
@@ -21,9 +21,13 @@ from ktransformers.util.custom_gguf import GGUFLoader
...
@@ -21,9 +21,13 @@ from ktransformers.util.custom_gguf import GGUFLoader
import
logging
import
logging
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.cache_utils
import
Cache
from
transformers.cache_utils
import
Cache
from
flash_attn
import
flash_attn_with_kvcache
,
flash_attn_func
from
ktransformers.operators.triton_attention
import
decode_attention_fwd_grouped
logger
=
logging
.
getLogger
(
"attention"
)
logger
=
logging
.
getLogger
(
"attention"
)
class
KDeepseekV3Attention
(
BaseInjectedModule
,
DeepseekV3Attention
):
# V3 MLA is same to V2
class
KDeepseekV2Attention
(
BaseInjectedModule
,
DeepseekV2Attention
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
"""Multi-headed attention from 'Attention Is All You Need' paper"""
attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -39,7 +43,6 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
...
@@ -39,7 +43,6 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
self
.
orig_module
.
__init__
(
orig_module
.
config
,
self
.
orig_module
.
__init__
(
orig_module
.
config
,
orig_module
.
layer_idx
)
orig_module
.
layer_idx
)
self
.
chunck_size
=
chunck_size
# TODO, generate chunck_size automatically.
self
.
chunck_size
=
chunck_size
# TODO, generate chunck_size automatically.
self
.
softmax_scale
=
self
.
q_head_dim
**
(
-
0.5
)
def
get_absorbed
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
get_absorbed
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
not
(
hasattr
(
self
,
'q_absorb'
)
and
hasattr
(
self
,
'out_absorb'
)):
if
not
(
hasattr
(
self
,
'q_absorb'
)
and
hasattr
(
self
,
'out_absorb'
)):
...
@@ -52,7 +55,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
...
@@ -52,7 +55,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
self
.
out_absorb
=
nn
.
Linear
(
self
.
kv_lora_rank
,
self
.
num_heads
*
self
.
v_head_dim
,
self
.
out_absorb
=
nn
.
Linear
(
self
.
kv_lora_rank
,
self
.
num_heads
*
self
.
v_head_dim
,
bias
=
False
,
dtype
=
out_absorb
.
dtype
,
device
=
out_absorb
.
device
)
bias
=
False
,
dtype
=
out_absorb
.
dtype
,
device
=
out_absorb
.
device
)
self
.
out_absorb
.
weight
.
data
=
out_absorb
self
.
out_absorb
.
weight
.
data
=
out_absorb
del
self
.
orig_module
.
kv_b_proj
#
del self.orig_module.kv_b_proj
q_absorb
=
self
.
q_absorb
.
weight
.
view
(
self
.
num_heads
,
self
.
qk_nope_head_dim
,
self
.
kv_lora_rank
)
q_absorb
=
self
.
q_absorb
.
weight
.
view
(
self
.
num_heads
,
self
.
qk_nope_head_dim
,
self
.
kv_lora_rank
)
out_absorb
=
self
.
out_absorb
.
weight
.
view
(
self
.
num_heads
,
self
.
v_head_dim
,
self
.
kv_lora_rank
)
out_absorb
=
self
.
out_absorb
.
weight
.
view
(
self
.
num_heads
,
self
.
v_head_dim
,
self
.
kv_lora_rank
)
return
q_absorb
,
out_absorb
return
q_absorb
,
out_absorb
...
@@ -96,7 +99,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
...
@@ -96,7 +99,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
cos
,
sin
=
self
.
rotary_emb
(
q_pe
,
position_ids
)
cos
,
sin
=
self
.
rotary_emb
(
q_pe
,
position_ids
)
q_pe
,
k_pe
=
apply_rotary_pos_emb
_v3
(
q_pe
,
k_pe
,
cos
,
sin
)
q_pe
,
k_pe
=
apply_rotary_pos_emb
(
q_pe
,
k_pe
,
cos
,
sin
)
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
# Specific to RoPE models
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
# Specific to RoPE models
...
@@ -151,7 +154,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
...
@@ -151,7 +154,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
attn_output
=
self
.
o_proj
(
attn_output
)
attn_output
=
self
.
o_proj
(
attn_output
)
return
attn_output
,
attn_weights
,
past_key_value
return
attn_output
,
None
,
past_key_value
def
forward
(
def
forward
(
self
,
self
,
...
@@ -164,109 +167,9 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
...
@@ -164,109 +167,9 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
cache_position
:
Optional
[
torch
.
LongTensor
]
=
None
,
cache_position
:
Optional
[
torch
.
LongTensor
]
=
None
,
**
kwargs
,
**
kwargs
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
Tuple
[
torch
.
Tensor
]]]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
Tuple
[
torch
.
Tensor
]]]:
if
"padding_mask"
in
kwargs
:
warnings
.
warn
(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz
,
q_len
,
_
=
hidden_states
.
size
()
if
q_len
<=
self
.
chunck_size
:
return
self
.
forward_chunck
(
hidden_states
,
attention_mask
,
position_ids
,
past_key_value
,
output_attentions
,
use_cache
,
cache_position
,
**
kwargs
)
assert
output_attentions
==
False
,
"output_attentions is not supported when using chunked attention"
attn_output
=
None
attn_weight
=
None
cur_idx
=
0
while
cur_idx
<
q_len
:
if
attention_mask
is
not
None
:
chunk_mask
=
attention_mask
[:,
:,
cur_idx
:
min
(
cur_idx
+
self
.
chunck_size
,
q_len
),
...]
else
:
# generate chunk_mask automatically.
self
.
attn_mask
=
\
torch
.
zeros
(
1
,
1
,
self
.
chunck_size
,
past_key_value
.
max_cache_len
,
device
=
hidden_states
.
device
)
\
if
self
.
attn_mask
is
None
\
else
self
.
attn_mask
self
.
attn_mask
[:,
:,
:,
cur_idx
:
min
(
cur_idx
+
self
.
chunck_size
,
past_key_value
.
max_cache_len
)]
=
\
-
1e+38
*
torch
.
triu
(
torch
.
ones
(
self
.
chunck_size
,
self
.
chunck_size
,
device
=
hidden_states
.
device
),
diagonal
=
1
)
\
[:,:
min
(
self
.
chunck_size
,
min
(
past_key_value
.
max_cache_len
-
cur_idx
,
self
.
chunck_size
))]
self
.
attn_mask
[:,
:,
:,
cur_idx
+
self
.
chunck_size
:]
=
-
1e+38
self
.
attn_mask
[:,
:,
:,
:
cur_idx
]
=
0
chunk_mask
=
torch
.
narrow
(
self
.
attn_mask
,
2
,
0
,
min
(
self
.
chunck_size
,
q_len
-
cur_idx
))
cur_output
,
cur_attn_weight
=
self
.
forward_chunck
(
hidden_states
[:,
cur_idx
:
min
(
cur_idx
+
self
.
chunck_size
,
q_len
),
...],
chunk_mask
,
position_ids
[:,
cur_idx
:
min
(
cur_idx
+
self
.
chunck_size
,
q_len
)],
past_key_value
,
output_attentions
,
use_cache
,
cache_position
[
cur_idx
:
min
(
cur_idx
+
self
.
chunck_size
,
q_len
)],
**
kwargs
)
cur_idx
+=
self
.
chunck_size
if
attn_output
is
None
:
attn_output
=
cur_output
attn_weight
=
cur_attn_weight
else
:
attn_output
=
torch
.
cat
((
attn_output
,
cur_output
),
dim
=-
2
)
attn_weight
=
torch
.
cat
((
attn_weight
,
cur_attn_weight
),
dim
=-
2
)
return
attn_output
,
attn_weight
,
past_key_value
class
KDeepseekV2Attention
(
BaseInjectedModule
,
DeepseekV2Attention
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
def
__init__
(
self
,
key
:
str
,
gguf_loader
:
GGUFLoader
,
config
:
PretrainedConfig
,
orig_module
:
nn
.
Module
,
device
:
str
=
"cuda"
,
chunck_size
:
int
=
1000
,
**
kwargs
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
orig_module
.
config
,
orig_module
.
layer_idx
)
self
.
chunck_size
=
chunck_size
# TODO, generate chunck_size automatically.
def
get_absorbed
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
not
(
hasattr
(
self
,
'q_absorb'
)
and
hasattr
(
self
,
'out_absorb'
)):
kv_b_proj
=
self
.
kv_b_proj
.
weight
.
view
(
self
.
num_heads
,
-
1
,
self
.
kv_lora_rank
)
q_absorb
=
kv_b_proj
[:,
:
self
.
qk_nope_head_dim
,
:].
reshape
(
-
1
,
self
.
kv_lora_rank
)
out_absorb
=
kv_b_proj
[:,
self
.
qk_nope_head_dim
:,
:].
reshape
(
-
1
,
self
.
kv_lora_rank
)
self
.
q_absorb
=
nn
.
Linear
(
self
.
kv_lora_rank
,
self
.
num_heads
*
self
.
qk_nope_head_dim
,
bias
=
False
,
dtype
=
q_absorb
.
dtype
,
device
=
q_absorb
.
device
)
self
.
q_absorb
.
weight
.
data
=
q_absorb
self
.
out_absorb
=
nn
.
Linear
(
self
.
kv_lora_rank
,
self
.
num_heads
*
self
.
v_head_dim
,
bias
=
False
,
dtype
=
out_absorb
.
dtype
,
device
=
out_absorb
.
device
)
self
.
out_absorb
.
weight
.
data
=
out_absorb
del
self
.
orig_module
.
kv_b_proj
q_absorb
=
self
.
q_absorb
.
weight
.
view
(
self
.
num_heads
,
self
.
qk_nope_head_dim
,
self
.
kv_lora_rank
)
out_absorb
=
self
.
out_absorb
.
weight
.
view
(
self
.
num_heads
,
self
.
v_head_dim
,
self
.
kv_lora_rank
)
return
q_absorb
,
out_absorb
def
forward_chunck
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
past_key_value
:
Optional
[
Cache
]
=
None
,
output_attentions
:
bool
=
False
,
use_cache
:
bool
=
False
,
cache_position
:
Optional
[
torch
.
LongTensor
]
=
None
,
**
kwargs
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
Tuple
[
torch
.
Tensor
]]]:
bsz
,
q_len
,
_
=
hidden_states
.
size
()
bsz
,
q_len
,
_
=
hidden_states
.
size
()
if
self
.
q_lora_rank
is
None
:
if
self
.
q_lora_rank
is
None
:
q
=
self
.
q_proj
(
hidden_states
)
q
=
self
.
q_proj
(
hidden_states
)
else
:
else
:
...
@@ -282,140 +185,111 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
...
@@ -282,140 +185,111 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
)
)
compressed_kv
=
self
.
kv_a_layernorm
(
compressed_kv
)
compressed_kv
=
self
.
kv_a_layernorm
(
compressed_kv
)
k_pe
=
k_pe
.
view
(
bsz
,
q_len
,
1
,
self
.
qk_rope_head_dim
).
transpose
(
1
,
2
)
k_pe
=
k_pe
.
view
(
bsz
,
q_len
,
1
,
self
.
qk_rope_head_dim
).
transpose
(
1
,
2
)
compressed_kv
=
compressed_kv
.
view
(
bsz
,
q_len
,
1
,
self
.
kv_lora_rank
)
kv_seq_len
=
k_pe
.
shape
[
-
2
]
if
past_key_value
is
not
None
:
if
self
.
layer_idx
is
None
:
raise
ValueError
(
f
"The cache structure has changed since version v4.36. If you are using
{
self
.
__class__
.
__name__
}
"
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
cos
,
sin
=
self
.
rotary_emb
(
q_pe
,
position_ids
)
cos
,
sin
=
self
.
rotary_emb
(
q_pe
,
position_ids
)
q_pe
,
k_pe
=
apply_rotary_pos_emb
(
q_pe
,
k_pe
,
cos
,
sin
)
q_pe
,
k_pe
=
apply_rotary_pos_emb
(
q_pe
,
k_pe
,
cos
,
sin
)
k_pe
=
k_pe
.
transpose
(
1
,
2
)
# [bsz, q_len, 1, self.qk_rope_head_dim]
# decode
if
q_len
==
1
:
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
# Specific to RoPE models
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
# Specific to RoPE models
compressed_kv
=
compressed_kv
.
unsqueeze
(
1
)
k_pe
,
compressed_kv
,
page_table
=
past_key_value
.
update
(
k_pe
,
compressed_kv
,
self
.
layer_idx
,
cache_kwargs
)
k_pe
,
compressed_kv
=
past_key_value
.
update
(
k_pe
,
compressed_kv
,
self
.
layer_idx
,
cache_kwargs
)
compressed_kv
=
compressed_kv
.
squeeze
(
1
)
#if cache_position is not None:
# compressed_kv = compressed_kv[:,: cache_position[-1] + 1,:]
# k_pe = k_pe[:,:,: cache_position[-1] + 1,:]
q_absorb
,
out_absorb
=
self
.
get_absorbed
()
q_nope
=
torch
.
matmul
(
q_nope
,
q_absorb
)
# q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
attn_weights
=
(
torch
.
matmul
(
q_pe
,
k_pe
.
mT
)
+
torch
.
matmul
(
q_nope
,
compressed_kv
.
unsqueeze
(
-
3
).
mT
))
*
self
.
softmax_scale
# q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]
"""
q_absorb
,
out_absorb
=
self
.
get_absorbed
()
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
q_nope
=
torch
.
matmul
(
q_nope
,
q_absorb
)
# batched MM
raise ValueError(
# q_nope [bsz, self.num_heads, q_len, self.kv_lora_rank]
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
# q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
f" {attn_weights.size()}"
query_states
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
# k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
# compressed_kv [bsz, q_len, 1, self.kv_lora_rank]
key_states
=
torch
.
cat
([
compressed_kv
,
k_pe
],
dim
=-
1
)
query_states
=
query_states
.
squeeze
(
2
)
attn_output
=
torch
.
zeros_like
(
q_nope
)
attn_logits
=
torch
.
empty
(
(
bsz
,
self
.
num_heads
,
1
,
#num_kv_splits # follow vLLM, fix it TODO
self
.
kv_lora_rank
+
1
,
),
dtype
=
torch
.
float32
,
device
=
attn_output
.
device
)
)
assert attention_mask is not None
"""
"""
if
attention_mask
is
not
None
:
print("query_states", torch.isnan(query_states).any())
print("key_states", torch.isnan(key_states[:,:,0,:]).any())
print("compressed_kv", torch.isnan(compressed_kv[:,:,0,:]).any())
print("position_ids", torch.isnan(position_ids).any())
"""
"""
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
"""
#causal_mask = attention_mask[:, :, :, : kv_seq_len]
attn_weights
=
attn_weights
+
attention_mask
# upcast attention to fp32
# flash attn doesn't support head_dim bigger than 256
attn_weights
=
nn
.
functional
.
softmax
(
# use vLLM triton attention kernel for MQA
attn_weights
,
dim
=-
1
,
dtype
=
torch
.
float32
decode_attention_fwd_grouped
(
query_states
,
key_states
,
compressed_kv
,
attn_output
,
).
to
(
q_pe
.
dtype
)
page_table
,
attn_weights
=
nn
.
functional
.
dropout
(
position_ids
.
squeeze
(
0
).
to
(
torch
.
int32
),
attn_logits
,
attn_weights
,
p
=
self
.
attention_dropout
,
training
=
self
.
training
1
,
#num_kv_splits # follow vLLM, fix it TODO
)
self
.
softmax_scale
,
attn_output
=
torch
.
einsum
(
'bhql,blc->bhqc'
,
attn_weights
,
compressed_kv
)
past_key_value
.
page_size
)
attn_output
=
torch
.
matmul
(
attn_output
,
out_absorb
.
mT
)
attn_output
=
torch
.
matmul
(
attn_output
,
out_absorb
.
mT
)
if
attn_output
.
size
()
!=
(
bsz
,
self
.
num_heads
,
q_len
,
self
.
v_head_dim
):
raise
ValueError
(
f
"`attn_output` should be of size
{
(
bsz
,
self
.
num_heads
,
q_len
,
self
.
v_head_dim
)
}
, but is"
f
"
{
attn_output
.
size
()
}
"
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
self
.
num_heads
*
self
.
v_head_dim
)
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
self
.
num_heads
*
self
.
v_head_dim
)
attn_output
=
self
.
o_proj
(
attn_output
)
attn_output
=
self
.
o_proj
(
attn_output
)
#print("attn_output", torch.isnan(attn_output).any())
return
attn_output
,
None
,
past_key_value
return
attn_output
,
None
,
past_key_value
else
:
def
forward
(
if
past_key_value
is
not
None
:
self
,
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
# Specific to RoPE models
hidden_states
:
torch
.
Tensor
,
k_pe
.
squeeze
(
0
)
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
compressed_kv
.
squeeze
(
0
)
position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
past_key_value
.
update
(
k_pe
,
compressed_kv
,
self
.
layer_idx
,
cache_kwargs
)
past_key_value
:
Optional
[
Cache
]
=
None
,
k_pe
.
unsqueeze
(
0
)
output_attentions
:
bool
=
False
,
compressed_kv
.
unsqueeze
(
0
)
use_cache
:
bool
=
False
,
cache_position
:
Optional
[
torch
.
LongTensor
]
=
None
,
k_pe
=
k_pe
[:,
:
q_len
]
**
kwargs
,
compressed_kv
=
compressed_kv
[:,
:
q_len
]
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
Tuple
[
torch
.
Tensor
]]]:
kv
=
(
if
"padding_mask"
in
kwargs
:
self
.
kv_b_proj
(
compressed_kv
)
warnings
.
warn
(
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
)
bsz
,
q_len
,
_
=
hidden_states
.
size
()
k_nope
,
value_states
=
torch
.
split
(
kv
,
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
query_states
=
k_pe
.
new_empty
(
bsz
,
self
.
num_heads
,
q_len
,
self
.
q_head_dim
)
if
q_len
<=
self
.
chunck_size
:
query_states
[:,
:,
:,
:
self
.
qk_nope_head_dim
]
=
q_nope
return
self
.
forward_chunck
(
query_states
[:,
:,
:,
self
.
qk_nope_head_dim
:]
=
q_pe
hidden_states
,
attention_mask
,
key_states
=
k_pe
.
new_empty
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
q_head_dim
)
position_ids
,
key_states
[:,
:,
:,
:
self
.
qk_nope_head_dim
]
=
k_nope
past_key_value
,
key_states
[:,
:,
:,
self
.
qk_nope_head_dim
:]
=
k_pe
output_attentions
,
use_cache
,
query_states
=
query_states
.
transpose
(
1
,
2
)
cache_position
,
value_states
=
value_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
v_head_dim
)
**
kwargs
value_states_padded
=
torch
.
nn
.
functional
.
pad
(
value_states
,
[
0
,
query_states
.
shape
[
-
1
]
-
value_states
.
shape
[
-
1
]],
value
=
0
)
attn_output
=
flash_attn_func
(
query_states
,
key_states
,
value_states_padded
,
softmax_scale
=
self
.
softmax_scale
,
causal
=
True
,
)
)
assert
output_attentions
==
False
,
"output_attentions is not supported when using chunked attention"
if
self
.
q_head_dim
!=
self
.
v_head_dim
:
attn_output
=
None
attn_output
=
attn_output
[:,
:,
:,
:
self
.
v_head_dim
]
cur_idx
=
0
while
cur_idx
<
q_len
:
if
attention_mask
is
not
None
:
chunk_mask
=
attention_mask
[:,
:,
cur_idx
:
min
(
cur_idx
+
self
.
chunck_size
,
q_len
),
...]
else
:
# generate chunk_mask automatically.
self
.
attn_mask
=
\
torch
.
zeros
(
1
,
1
,
self
.
chunck_size
,
past_key_value
.
max_cache_len
,
device
=
hidden_states
.
device
)
\
if
self
.
attn_mask
is
None
\
else
self
.
attn_mask
self
.
attn_mask
[:,
:,
:,
cur_idx
:
min
(
cur_idx
+
self
.
chunck_size
,
past_key_value
.
max_cache_len
)]
=
\
-
1e+38
*
torch
.
triu
(
torch
.
ones
(
self
.
chunck_size
,
self
.
chunck_size
,
device
=
hidden_states
.
device
),
diagonal
=
1
)
\
[:,:
min
(
self
.
chunck_size
,
min
(
past_key_value
.
max_cache_len
-
cur_idx
,
self
.
chunck_size
))]
self
.
attn_mask
[:,
:,
:,
cur_idx
+
self
.
chunck_size
:]
=
-
1e+38
self
.
attn_mask
[:,
:,
:,
:
cur_idx
]
=
0
chunk_mask
=
torch
.
narrow
(
self
.
attn_mask
,
2
,
0
,
min
(
self
.
chunck_size
,
q_len
-
cur_idx
))
cur_output
,
_
,
_
=
self
.
forward_chunck
(
hidden_states
[:,
cur_idx
:
min
(
cur_idx
+
self
.
chunck_size
,
q_len
),
...],
chunk_mask
,
position_ids
[:,
cur_idx
:
min
(
cur_idx
+
self
.
chunck_size
,
q_len
)],
past_key_value
,
output_attentions
,
use_cache
,
cache_position
[
cur_idx
:
min
(
cur_idx
+
self
.
chunck_size
,
q_len
)],
**
kwargs
)
cur_idx
+=
self
.
chunck_size
if
attn_output
is
None
:
attn_output
=
cur_output
else
:
attn_output
=
torch
.
cat
((
attn_output
,
cur_output
),
dim
=-
2
)
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
self
.
num_heads
*
self
.
v_head_dim
).
contiguous
()
attn_output
=
self
.
o_proj
(
attn_output
)
return
attn_output
,
None
,
past_key_value
return
attn_output
,
None
,
past_key_value
def
rotate_half
(
x
):
def
rotate_half
(
x
):
"""Rotates half the hidden dims of the input."""
"""Rotates half the hidden dims of the input."""
x1
=
x
[...,
:
x
.
shape
[
-
1
]
//
2
]
x1
=
x
[...,
:
x
.
shape
[
-
1
]
//
2
]
...
@@ -423,8 +297,6 @@ def rotate_half(x):
...
@@ -423,8 +297,6 @@ def rotate_half(x):
return
torch
.
cat
((
-
x2
,
x1
),
dim
=-
1
)
return
torch
.
cat
((
-
x2
,
x1
),
dim
=-
1
)
class
KLlamaAttention
(
BaseInjectedModule
):
class
KLlamaAttention
(
BaseInjectedModule
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
"""Multi-headed attention from 'Attention Is All You Need' paper"""
...
...
ktransformers/operators/triton_attention.py
0 → 100644
View file @
bb35dc5b
import
triton
import
triton.language
as
tl
@triton.jit
def tanh(x):
    # Hyperbolic tangent expressed through the sigmoid primitive, using the
    # identity tanh(x) = 2 * sigmoid(2 * x) - 1.
    twice_x = 2 * x
    return 2 * tl.sigmoid(twice_x) - 1
@triton.jit
def _fwd_grouped_kernel_stage1(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    # Stage 1 of the split-KV grouped decode attention.
    #
    # One program instance handles one (batch entry, block of query heads,
    # KV split) triple, taken from launch-grid axes 0, 1 and 2 respectively.
    # It walks its slice of the KV sequence in tiles of BLOCK_N positions,
    # keeps a running maximum / running sum of the softmax (rescaling the
    # accumulator whenever the maximum grows), and finally writes to Att_Out:
    #   * [batch, head, split, 0:Lv] -> the partial output, already divided
    #                                   by the split-local softmax sum
    #   * [batch, head, split, Lv]   -> e_max + log(e_sum) for that split
    #
    # Addressing conventions visible in this kernel:
    #   * Q is addressed as batch * stride_qbs + head * stride_qh + d.
    #   * Req_to_tokens is read as a page table: row = batch index, column =
    #     logical page (position // PAGE_SIZE), value = page number used to
    #     build the flat KV slot page * PAGE_SIZE + position % PAGE_SIZE.
    #   * K_Buffer / V_Buffer are addressed as slot * stride_*bs +
    #     kv_head * stride_*h + d.
    #   * The innermost dimension of Q, K_Buffer, V_Buffer and Att_Out is
    #     indexed without a stride, i.e. it is assumed to be contiguous.
    #   * B_Seqlen[batch] is used as the number of valid KV positions.
    cur_batch = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    # tl.cdiv(kv_group_num, BLOCK_H) head-blocks share one KV head.
    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
    split_kv_id = tl.program_id(2)

    # Number of query heads this program actually owns: a full BLOCK_H when
    # the group is larger than the block, otherwise the whole group.
    if kv_group_num > BLOCK_H:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num
    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    # Mask off lanes beyond this program's heads and beyond the last head.
    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
    mask_h = mask_h & (cur_head < q_head_num)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    # BLOCK_* sizes may exceed the true dims Lk / Lv, so mask the padding.
    mask_d = offs_d < Lk
    mask_dv = offs_dv < Lv
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    # The page-table row is selected by the batch index itself.
    cur_batch_req_idx = cur_batch

    # q tile: [BLOCK_H, BLOCK_DMODEL]
    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
    q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)

    if BLOCK_DPE > 0:
        # Trailing part of the head dimension, located right after the first
        # BLOCK_DMODEL elements; loaded as a separate [BLOCK_H, BLOCK_DPE] tile.
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        mask_dpe = offs_dpe < Lk
        off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :])
        qpe = tl.load(Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0)

    # Half-open range [split_kv_start, split_kv_end) of KV positions owned by
    # this split; trailing splits may be empty.
    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

    # Running softmax state, one entry per head lane.
    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    # NOTE(review): indentation was reconstructed; the final stores are assumed
    # to sit inside this guard so that empty splits write nothing — confirm.
    if split_kv_end > split_kv_start:
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            # Translate logical positions to flat KV slots via the page table.
            kv_page_number = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n // PAGE_SIZE,
                mask=offs_n < split_kv_end,
                other=0,
            )
            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
            # k tile is loaded transposed: [BLOCK_DMODEL, BLOCK_N]
            offs_buf_k = (kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_d[:, None])
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
                other=0.0,
            )
            # qk: [BLOCK_H, BLOCK_N]
            qk = tl.dot(q, k.to(q.dtype))
            if BLOCK_DPE > 0:
                # Add the contribution of the trailing head-dim part, read
                # from the same K_Buffer at offsets offs_dpe.
                offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_dpe[:, None])
                kpe = tl.load(
                    K_Buffer + offs_buf_kpe,
                    mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
                    other=0.0,
                )
                qk += tl.dot(qpe, kpe.to(qpe.dtype))
            qk *= sm_scale

            # Optional soft capping of the logits into (-logit_cap, logit_cap).
            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            # Invalid head lanes / out-of-range positions get -inf so they
            # contribute exp(-inf) = 0 to the softmax.
            qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf"))

            # v tile: [BLOCK_N, BLOCK_DV]
            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + cur_kv_head * stride_buf_vh + offs_dv[None, :])
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                other=0.0,
            )

            # Running-max softmax update: rescale what has been accumulated
            # so far to the new maximum, then add this tile.
            n_e_max = tl.maximum(tl.max(qk, 1), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            acc *= re_scale[:, None]
            acc += tl.dot(p.to(v.dtype), v)

            e_sum = e_sum * re_scale + tl.sum(p, 1)
            e_max = n_e_max

        # Normalised partial output for this split.
        offs_mid_o = (cur_batch * stride_mid_ob + cur_head[:, None] * stride_mid_oh + split_kv_id * stride_mid_os + offs_dv[None, :])
        tl.store(
            Att_Out + offs_mid_o,
            acc / e_sum[:, None],
            mask=(mask_h[:, None]) & (mask_dv[None, :]),
        )

        # Log-sum-exp of this split, stored in the extra slot at index Lv.
        offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os + Lv)
        tl.store(
            Att_Out + offs_mid_o_1,
            e_max + tl.log(e_sum),
            mask=mask_h,
        )
def _decode_grouped_att_m_fwd(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    B_Seqlen,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap,
):
    """Launch ``_fwd_grouped_kernel_stage1`` over (batch, head blocks, KV splits).

    Derives the compile-time block sizes from the buffer shapes, builds the
    launch grid and forwards tensors and strides to the kernel. Nothing is
    returned; the kernel writes its per-split results into ``att_out``.

    Args:
        q: Query tensor; dim 0 is used as the batch size and dim 1 as the
            number of query heads.
        k_buffer: Key buffer. Its last dim is the key head size and its
            second-to-last dim the number of KV heads; the stride arguments
            assume a ``(..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)`` layout.
        v_buffer: Value buffer, same layout assumption as ``k_buffer``.
        att_out: Intermediate output; strides 0, 1 and 2 are passed as the
            batch, head and split strides.
        Req_to_tokens: Page table handed to the kernel (row stride 0 is used).
        B_Seqlen: Per-batch KV lengths handed to the kernel.
        num_kv_splits: Number of KV splits, i.e. the size of grid axis 2.
        sm_scale: Scale applied to the attention logits.
        page_size: Number of KV slots per page.
        logit_cap: Soft cap for the logits; the kernel ignores values <= 0.
    """
    kv_tile = 32  # KV positions per inner-loop step (kernel's BLOCK_N)
    heads_per_program = 16  # query-head lanes per program (kernel's BLOCK_H)

    k_dim = k_buffer.shape[-1]
    v_dim = v_buffer.shape[-1]

    # [TODO] work around shmem limit on MI3xx
    # TODO: support hip
    # if is_hip_ and Lk >= 576:
    #     BLOCK = 16

    # Key sizes 576 and 288 are split into a power-of-two main block plus a
    # trailing block that the kernel loads separately; every other size is
    # padded to the next power of two and has no trailing block.
    split_table = {576: (512, 64), 288: (256, 32)}
    if k_dim in split_table:
        main_block, trailing_block = split_table[k_dim]
    else:
        main_block = triton.next_power_of_2(k_dim)
        trailing_block = 0
    v_block = triton.next_power_of_2(v_dim)

    batch = q.shape[0]
    head_num = q.shape[1]
    # Query heads served by each KV head.
    kv_group_num = head_num // k_buffer.shape[-2]

    launch_grid = (
        batch,
        triton.cdiv(head_num, min(heads_per_program, kv_group_num)),
        num_kv_splits,
    )

    # TODO: support hip. The reference launch options for that backend are:
    #   # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
    #   # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
    #   extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}

    _fwd_grouped_kernel_stage1[launch_grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) for both buffers.
        k_buffer.stride(-3),
        k_buffer.stride(-2),
        v_buffer.stride(-3),
        v_buffer.stride(-2),
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=main_block,
        BLOCK_DPE=trailing_block,
        BLOCK_DV=v_block,
        BLOCK_N=kv_tile,
        BLOCK_H=heads_per_program,
        NUM_KV_SPLITS=num_kv_splits,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        num_warps=4,
        num_stages=2,
        Lk=k_dim,
        Lv=v_dim,
    )
@triton.jit
def _fwd_kernel_stage2(
    Mid_O,
    o,
    B_Seqlen,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_obs,
    stride_oh,
    NUM_KV_SPLITS: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
):
    # Stage 2 of the split-KV decode attention: one program per (batch, head)
    # merges the per-split partial results stored in Mid_O into the final
    # output row o[batch, head, 0:Lv].
    #
    # For every split, Mid_O holds at [batch, head, split, 0:Lv] a partial
    # output and at [batch, head, split, Lv] a scalar that is used here as
    # the log-weight of that partial output. The splits are combined with a
    # running maximum so that the exponentials stay in range.
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)

    offs_d = tl.arange(0, BLOCK_DV)
    # BLOCK_DV may be larger than the true value dim Lv; mask the padding.
    mask_d = offs_d < Lv

    # Running state of the merge: sum of weights, max log-weight, weighted sum.
    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    # Base offsets of the partial-output vector and of the log-weight slot for
    # split 0; split_kv_id * stride_mid_os is added per iteration.
    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv

    for split_kv_id in range(0, NUM_KV_SPLITS):
        # Recompute the split's KV range only to detect empty splits.
        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

        # Empty splits are skipped, so their Mid_O entries are never read.
        if split_kv_end > split_kv_start:
            tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0)
            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
            n_e_max = tl.maximum(tlogic, e_max)

            # Rescale what has been merged so far to the new maximum, then
            # add this split with weight exp(tlogic - n_e_max).
            old_scale = tl.exp(e_max - n_e_max)
            acc *= old_scale
            exp_logic = tl.exp(tlogic - n_e_max)
            acc += exp_logic * tv

            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

    # NOTE(review): if cur_batch_seq_len is 0 every split is skipped and e_sum
    # stays 0.0, making this a 0 / 0 division — confirm callers never pass an
    # empty sequence.
    tl.store(
        o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,
        mask=mask_d,
    )
def _decode_softmax_reducev_fwd(
    logits,
    q,
    o,
    v_buffer,
    b_seq_len,
    num_kv_splits,
):
    """Launch ``_fwd_kernel_stage2`` with one program per (batch, head).

    The kernel merges the per-split partial results held in ``logits`` and
    writes the final attention output into ``o``. Nothing is returned.

    Args:
        logits: Intermediate buffer produced by stage 1; strides 0, 1 and 2
            are passed as the batch, head and split strides.
        q: Query tensor, used only for its first two sizes (batch, heads).
        o: Output tensor; strides 0 and 1 are passed as batch / head strides.
        v_buffer: Value buffer, used only for the size of its last dim.
        b_seq_len: Per-batch KV lengths handed to the kernel.
        num_kv_splits: Number of KV splits stage 1 was launched with.
    """
    n_batch = q.shape[0]
    n_heads = q.shape[1]
    v_dim = v_buffer.shape[-1]
    # The kernel works on a power-of-two block and masks the padding itself.
    v_block = triton.next_power_of_2(v_dim)

    # TODO: support hip. The reference launch options for that backend are:
    #   # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
    #   # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
    #   extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}

    launch_grid = (n_batch, n_heads)
    _fwd_kernel_stage2[launch_grid](
        logits,
        o,
        b_seq_len,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        NUM_KV_SPLITS=num_kv_splits,
        BLOCK_DV=v_block,
        Lv=v_dim,
        num_warps=4,
        num_stages=2,
    )
def decode_attention_fwd_grouped(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    """Run grouped decode attention over a paged KV cache in two stages.

    Stage 1 computes, for each KV split, a partial attention output and its
    log-weight into ``attn_logits``; stage 2 merges the splits and writes the
    final result into ``o``. Both stages work in place; nothing is returned.

    Args:
        q: Query tensor.
        k_buffer: Paged key buffer.
        v_buffer: Paged value buffer.
        o: Output tensor, filled by stage 2.
        req_to_token: Page table forwarded to stage 1.
        b_seq_len: Per-batch KV lengths.
        attn_logits: Scratch buffer shared between the two stages.
        num_kv_splits: Number of KV splits used by both stages.
        sm_scale: Scale applied to the attention logits.
        page_size: Number of KV slots per page.
        logit_cap: Optional soft cap for the logits (default ``0.0``).
    """
    # Stage 1: per-split partial outputs.
    _decode_grouped_att_m_fwd(
        q=q,
        k_buffer=k_buffer,
        v_buffer=v_buffer,
        att_out=attn_logits,
        Req_to_tokens=req_to_token,
        B_Seqlen=b_seq_len,
        num_kv_splits=num_kv_splits,
        sm_scale=sm_scale,
        page_size=page_size,
        logit_cap=logit_cap,
    )
    # Stage 2: merge the splits into the final output.
    _decode_softmax_reducev_fwd(
        logits=attn_logits,
        q=q,
        o=o,
        v_buffer=v_buffer,
        b_seq_len=b_seq_len,
        num_kv_splits=num_kv_splits,
    )
\ No newline at end of file
ktransformers/util/utils.py
View file @
bb35dc5b
...
@@ -133,7 +133,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
...
@@ -133,7 +133,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
)
)
else
:
else
:
past_key_values
=
None
past_key_values
=
None
cache_position
=
torch
.
arange
(
seq_length
,
device
=
torch_device
)
cache_position
=
torch
.
arange
(
seq_length
,
device
=
torch_device
,
dtype
=
torch
.
long
)
generated_ids
=
torch
.
zeros
(
generated_ids
=
torch
.
zeros
(
batch_size
,
seq_length
+
max_new_tokens
+
1
,
dtype
=
torch
.
int
,
device
=
torch_device
batch_size
,
seq_length
+
max_new_tokens
+
1
,
dtype
=
torch
.
int
,
device
=
torch_device
)
)
...
@@ -178,7 +178,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
...
@@ -178,7 +178,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
generated_ids
[:,
seq_length
]
=
next_token
generated_ids
[:,
seq_length
]
=
next_token
tokens
.
append
(
int
(
next_token
))
tokens
.
append
(
int
(
next_token
))
inputs
=
torch
.
cat
((
inputs
,
next_token
.
unsqueeze
(
0
)),
dim
=-
1
)
inputs
=
torch
.
cat
((
inputs
,
next_token
.
unsqueeze
(
0
)),
dim
=-
1
)
cache_position
=
torch
.
tensor
([
seq_length
],
device
=
torch_device
)
cache_position
=
torch
.
tensor
([
seq_length
],
device
=
torch_device
,
dtype
=
torch
.
long
)
position_ids
=
cache_position
.
unsqueeze
(
0
)
position_ids
=
cache_position
.
unsqueeze
(
0
)
seq_length
+=
1
seq_length
+=
1
...
...
test_prompt.txt
0 → 100644
View file @
bb35dc5b
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment