OpenDAS / ktransformers · Commit bb35dc5b

init support for MLA using Attention kernel

Authored Feb 13, 2025 by Atream
Parent: 62011fd6

Showing 5 changed files with 531 additions and 242 deletions (+531 −242):
ktransformers/models/custom_cache.py         +37   −8
ktransformers/operators/attention.py         +104  −232
ktransformers/operators/triton_attention.py  +379  −0
ktransformers/util/utils.py                  +2    −2
test_prompt.txt                              +9    −0
ktransformers/models/custom_cache.py (+37 −8)

@@ -51,13 +51,33 @@ class StaticCache(transformers.StaticCache):
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM":
             # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
-            # key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim)
-            # value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim)
-            key_shape = (max_batch_size, 1, self.max_cache_len, config.qk_rope_head_dim)
-            value_shape = (max_batch_size, 1, self.max_cache_len, config.kv_lora_rank)
+            self.page_size = 64
+            self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
+            key_shape = (self.max_pages, self.page_size, 1, config.qk_rope_head_dim)
+            value_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank)
+            # TODO: support real page table
+            self.page_table_map = dict()
+            self.page_table_list = []
+            for idx in range(config.num_hidden_layers):
+                if isinstance(device, dict):
+                    target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
+                else:
+                    target_device = device
+                if target_device not in self.page_table_map:
+                    page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
+                    for seq_id in range(max_batch_size):
+                        page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
+                    self.page_table_map[target_device] = page_table
+                self.page_table_list.append(self.page_table_map[target_device])
+            self.is_MLA = True
+            self.is_page = True
         else:
             key_shape = cache_shape
             value_shape = cache_shape
+            self.is_MLA = False
         self.past_tokens = []
         self.num_hidden_layers = config.num_hidden_layers

@@ -104,10 +124,19 @@ class StaticCache(transformers.StaticCache):
         cache_position = cache_kwargs.get("cache_position")
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]
+        self.past_tokens[layer_idx] += cache_position.size(0)
         #print(cache_position)
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
-        self.past_tokens[layer_idx] += cache_position.size(0)
-        return k_out, v_out
+        if self.is_MLA:
+            page_idx = cache_position // self.page_size
+            page_offset = cache_position % self.page_size
+            #print("page_idx", page_idx)
+            #print("page_offset", page_offset)
+            k_out[page_idx, page_offset, ...] = key_states
+            v_out[page_idx, page_offset, ...] = value_states
+            return k_out, v_out, self.page_table_list[layer_idx]
+        else:
+            k_out[:, :, cache_position] = key_states
+            v_out[:, :, cache_position] = value_states
+            return k_out, v_out

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
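The cache change above replaces the flat (batch, heads, seq, dim) MLA layout with a paged one: a slot is addressed by integer division and modulo on cache_position, and a per-sequence page table maps logical pages to physical pages. Below is a minimal standalone sketch of that indexing, not taken from the repository; every name and size in it is illustrative.

    # Standalone sketch of the page-index arithmetic used by the new update path.
    import torch

    page_size = 64
    max_pages = 8                        # enough for a 512-token toy cache
    rope_dim = 64

    k_cache = torch.zeros(max_pages, page_size, 1, rope_dim)
    page_table = torch.arange(max_pages, dtype=torch.int32)   # identity table, like the commit's placeholder

    def write_tokens(k_cache, page_table, cache_position, k_pe_new):
        """Scatter new k_pe rows into the paged cache, mirroring the paged StaticCache.update branch."""
        logical_page = cache_position // page_size
        page_offset = cache_position % page_size
        physical_page = page_table[logical_page].long()       # a "real" page table would remap here
        k_cache[physical_page, page_offset, ...] = k_pe_new
        return k_cache

    # Four decode steps that cross a page boundary (positions 62..65).
    cache_position = torch.arange(62, 66)
    k_cache = write_tokens(k_cache, page_table, cache_position, torch.randn(4, 1, rope_dim))

With the identity page table built in __init__, physical and logical pages coincide; the TODO in the diff notes that a real page table (remapping and reuse) is still missing.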
ktransformers/operators/attention.py (+104 −232)

@@ -21,9 +21,13 @@ from ktransformers.util.custom_gguf import GGUFLoader
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
+from flash_attn import flash_attn_with_kvcache, flash_attn_func
+from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
 logger = logging.getLogger("attention")

-class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
+# V3 MLA is same to V2
+class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
     attn_mask: Optional[torch.Tensor] = None

@@ -39,7 +43,6 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
         self.orig_module.__init__(orig_module.config,
             orig_module.layer_idx)
         self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
-        self.softmax_scale = self.q_head_dim ** (-0.5)

     def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
         if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):

@@ -52,7 +55,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
             self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
                                         bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
             self.out_absorb.weight.data = out_absorb
-            del self.orig_module.kv_b_proj
+            # del self.orig_module.kv_b_proj
         q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
         out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
         return q_absorb, out_absorb

@@ -96,7 +99,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(q_pe, position_ids)
-        q_pe, k_pe = apply_rotary_pos_emb_v3(q_pe, k_pe, cos, sin)
+        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models

@@ -151,7 +154,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
         attn_output = self.o_proj(attn_output)
-        return attn_output, attn_weights, past_key_value
+        return attn_output, None, past_key_value

     def forward(
         self,

@@ -164,109 +167,9 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if "padding_mask" in kwargs:
-            warnings.warn("Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`")
-        bsz, q_len, _ = hidden_states.size()
-        if q_len <= self.chunck_size:
-            return self.forward_chunck(hidden_states, attention_mask, position_ids, past_key_value,
-                                       output_attentions, use_cache, cache_position, **kwargs)
-        assert output_attentions == False, "output_attentions is not supported when using chunked attention"
-        attn_output = None
-        attn_weight = None
-        cur_idx = 0
-        while cur_idx < q_len:
-            if attention_mask is not None:
-                chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]
-            else:
-                # generate chunk_mask automatically.
-                self.attn_mask = \
-                    torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \
-                    if self.attn_mask is None else self.attn_mask
-                self.attn_mask[:, :, :, cur_idx:min(cur_idx + self.chunck_size, past_key_value.max_cache_len)] = \
-                    -1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1) \
-                    [:, :min(self.chunck_size, min(past_key_value.max_cache_len - cur_idx, self.chunck_size))]
-                self.attn_mask[:, :, :, cur_idx + self.chunck_size:] = -1e+38
-                self.attn_mask[:, :, :, :cur_idx] = 0
-                chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len - cur_idx))
-            cur_output, cur_attn_weight = self.forward_chunck(
-                hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],
-                chunk_mask,
-                position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],
-                past_key_value,
-                output_attentions,
-                use_cache,
-                cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],
-                **kwargs
-            )
-            cur_idx += self.chunck_size
-            if attn_output is None:
-                attn_output = cur_output
-                attn_weight = cur_attn_weight
-            else:
-                attn_output = torch.cat((attn_output, cur_output), dim=-2)
-                attn_weight = torch.cat((attn_weight, cur_attn_weight), dim=-2)
-        return attn_output, attn_weight, past_key_value
-
-class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
-    attn_mask: Optional[torch.Tensor] = None
-
-    def __init__(self,
-                 key: str,
-                 gguf_loader: GGUFLoader,
-                 config: PretrainedConfig,
-                 orig_module: nn.Module,
-                 device: str = "cuda",
-                 chunck_size: int = 1000,
-                 **kwargs):
-        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
-        self.orig_module.__init__(orig_module.config, orig_module.layer_idx)
-        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
-
-    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
-        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
-            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
-            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
-            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
-            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
-                                      bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
-            self.q_absorb.weight.data = q_absorb
-            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
-                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
-            self.out_absorb.weight.data = out_absorb
-            del self.orig_module.kv_b_proj
-        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
-        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
-        return q_absorb, out_absorb
-
-    def forward_chunck(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

         if self.q_lora_rank is None:
             q = self.q_proj(hidden_states)
         else:

@@ -282,140 +185,111 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         )
         compressed_kv = self.kv_a_layernorm(compressed_kv)
         k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
-        kv_seq_len = k_pe.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(q_pe, position_ids)
-        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
-
-        if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
-            compressed_kv = compressed_kv.unsqueeze(1)
-            k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
-            compressed_kv = compressed_kv.squeeze(1)
-            #if cache_position is not None:
-            #    compressed_kv = compressed_kv[:,: cache_position[-1] + 1,:]
-            #    k_pe = k_pe[:,:,: cache_position[-1] + 1,:]
-        q_absorb, out_absorb = self.get_absorbed()
-        q_nope = torch.matmul(q_nope, q_absorb)
-        # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
-        attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
-        """
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-        assert attention_mask is not None
-        """
-        if attention_mask is not None:
-            """
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            """
-            #causal_mask = attention_mask[:, :, :, : kv_seq_len]
-            attn_weights = attn_weights + attention_mask
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(
-            attn_weights, dim=-1, dtype=torch.float32
-        ).to(q_pe.dtype)
-        attn_weights = nn.functional.dropout(
-            attn_weights, p=self.attention_dropout, training=self.training
-        )
-        attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
-        attn_output = torch.matmul(attn_output, out_absorb.mT)
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
-        attn_output = self.o_proj(attn_output)
-        return attn_output, None, past_key_value
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if "padding_mask" in kwargs:
-            warnings.warn("Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`")
-        bsz, q_len, _ = hidden_states.size()
-        if q_len <= self.chunck_size:
-            return self.forward_chunck(hidden_states, attention_mask, position_ids, past_key_value,
-                                       output_attentions, use_cache, cache_position, **kwargs)
-        assert output_attentions == False, "output_attentions is not supported when using chunked attention"
-        attn_output = None
-        cur_idx = 0
-        while cur_idx < q_len:
-            if attention_mask is not None:
-                chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]
-            else:
-                # generate chunk_mask automatically.
-                self.attn_mask = \
-                    torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \
-                    if self.attn_mask is None else self.attn_mask
-                self.attn_mask[:, :, :, cur_idx:min(cur_idx + self.chunck_size, past_key_value.max_cache_len)] = \
-                    -1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1) \
-                    [:, :min(self.chunck_size, min(past_key_value.max_cache_len - cur_idx, self.chunck_size))]
-                self.attn_mask[:, :, :, cur_idx + self.chunck_size:] = -1e+38
-                self.attn_mask[:, :, :, :cur_idx] = 0
-                chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len - cur_idx))
-            cur_output, _, _ = self.forward_chunck(
-                hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],
-                chunk_mask,
-                position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],
-                past_key_value,
-                output_attentions,
-                use_cache,
-                cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],
-                **kwargs
-            )
-            cur_idx += self.chunck_size
-            if attn_output is None:
-                attn_output = cur_output
-            else:
-                attn_output = torch.cat((attn_output, cur_output), dim=-2)
-        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim).contiguous()
-        attn_output = self.o_proj(attn_output)
-        return attn_output, None, past_key_value
+        compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
+        cos, sin = self.rotary_emb(q_pe, position_ids)
+        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
+        k_pe = k_pe.transpose(1, 2) # [bsz, q_len, 1, self.qk_rope_head_dim]
+
+        # decode
+        if q_len == 1:
+            if past_key_value is not None:
+                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
+                k_pe, compressed_kv, page_table = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
+
+            q_absorb, out_absorb = self.get_absorbed()
+            # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]
+            q_nope = torch.matmul(q_nope, q_absorb) # batched MM
+            # q_nope [bsz, self.num_heads, q_len, self.kv_lora_rank]
+            # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
+            query_states = torch.cat([q_nope, q_pe], dim=-1)
+            # k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
+            # compressed_kv [bsz, q_len, 1, self.kv_lora_rank]
+            key_states = torch.cat([compressed_kv, k_pe], dim=-1)
+            query_states = query_states.squeeze(2)
+            attn_output = torch.zeros_like(q_nope)
+            attn_logits = torch.empty(
+                (
+                    bsz,
+                    self.num_heads,
+                    1, #num_kv_splits # follow vLLM, fix it TODO
+                    self.kv_lora_rank + 1,
+                ),
+                dtype=torch.float32,
+                device=attn_output.device
+            )
+            """
+            print("query_states", torch.isnan(query_states).any())
+            print("key_states", torch.isnan(key_states[:,:,0,:]).any())
+            print("compressed_kv", torch.isnan(compressed_kv[:,:,0,:]).any())
+            print("position_ids", torch.isnan(position_ids).any())
+            """
+            # flash attn doesn't support head_dim bigger than 256
+            # use vLLM triton attention kernel for MQA
+            decode_attention_fwd_grouped(query_states, key_states, compressed_kv, attn_output,
+                             page_table,
+                             position_ids.squeeze(0).to(torch.int32), attn_logits,
+                             1, #num_kv_splits # follow vLLM, fix it TODO
+                             self.softmax_scale,
+                             past_key_value.page_size)
+            attn_output = torch.matmul(attn_output, out_absorb.mT)
+            attn_output = attn_output.transpose(1, 2).contiguous()
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+            attn_output = self.o_proj(attn_output)
+            #print("attn_output", torch.isnan(attn_output).any())
+            return attn_output, None, past_key_value
+        else:
+            if past_key_value is not None:
+                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
+                k_pe.squeeze(0)
+                compressed_kv.squeeze(0)
+                past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
+                k_pe.unsqueeze(0)
+                compressed_kv.unsqueeze(0)
+            k_pe = k_pe[:, :q_len]
+            compressed_kv = compressed_kv[:, :q_len]
+            kv = (
+                self.kv_b_proj(compressed_kv)
+                .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            )
+            k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+            query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
+            query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
+            key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
+            key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
+            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
+            query_states = query_states.transpose(1, 2)
+            value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
+            value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
+            attn_output = flash_attn_func(
+                query_states,
+                key_states,
+                value_states_padded,
+                softmax_scale=self.softmax_scale,
+                causal=True,
+            )
+            if self.q_head_dim != self.v_head_dim:
+                attn_output = attn_output[:, :, :, :self.v_head_dim]
+            attn_output = attn_output.reshape(
+                bsz, q_len, self.num_heads * self.v_head_dim
+            ).contiguous()
+            attn_output = self.o_proj(attn_output)
+            return attn_output, None, past_key_value

 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]

@@ -423,8 +297,6 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)

 class KLlamaAttention(BaseInjectedModule):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
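The rewritten attention keeps the "absorbed" MLA trick from get_absorbed: instead of materialising full per-head keys and values, the kv_b_proj weight is split so queries are projected into the compressed kv_lora_rank space and the attention output is projected back out of it, which is what lets decode attend directly over the compressed cache. A rough standalone sketch of that algebra follows; the dimensions are toy values, the RoPE part and softmax scaling are omitted, and none of this is the repository's API.

    # Standalone sketch of MLA weight absorption (illustrative shapes only).
    import torch

    num_heads, nope_dim, v_head_dim, kv_lora_rank = 4, 128, 128, 512
    kv_b_proj_w = torch.randn(num_heads * (nope_dim + v_head_dim), kv_lora_rank)

    w = kv_b_proj_w.view(num_heads, nope_dim + v_head_dim, kv_lora_rank)
    q_absorb = w[:, :nope_dim, :]        # (heads, nope_dim, kv_lora_rank)
    out_absorb = w[:, nope_dim:, :]      # (heads, v_head_dim, kv_lora_rank)

    bsz, q_len, kv_len = 1, 1, 16
    q_nope = torch.randn(bsz, num_heads, q_len, nope_dim)
    compressed_kv = torch.randn(bsz, kv_len, kv_lora_rank)

    # Score and read directly in the compressed space instead of expanding K/V per head.
    q_lora = torch.matmul(q_nope, q_absorb)                            # (bsz, heads, q_len, kv_lora_rank)
    attn = torch.softmax(q_lora @ compressed_kv.unsqueeze(1).mT, dim=-1)
    ctx_lora = attn @ compressed_kv.unsqueeze(1)                       # still in kv_lora_rank space
    attn_output = torch.matmul(ctx_lora, out_absorb.mT)                # (bsz, heads, q_len, v_head_dim)

This is exactly why the decode path above can hand the Triton kernel a query of width kv_lora_rank + qk_rope_head_dim and a value buffer that is just the compressed KV, applying out_absorb only once after the kernel returns.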
ktransformers/operators/triton_attention.py (new file, 0 → 100644, +379)

import triton
import triton.language as tl


@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def _fwd_grouped_kernel_stage1(
    Q, K_Buffer, V_Buffer, sm_scale, Req_to_tokens, B_Seqlen, Att_Out,
    stride_req_to_tokens_b, stride_qbs, stride_qh,
    stride_buf_kbs, stride_buf_kh, stride_buf_vbs, stride_buf_vh,
    stride_mid_ob, stride_mid_oh, stride_mid_os,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
    split_kv_id = tl.program_id(2)

    if kv_group_num > BLOCK_H:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num
    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
    mask_h = mask_h & (cur_head < q_head_num)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lk
    mask_dv = offs_dv < Lv
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_req_idx = cur_batch

    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
    q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)

    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        mask_dpe = offs_dpe < Lk
        off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :])
        qpe = tl.load(Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0)

    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    if split_kv_end > split_kv_start:
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            kv_page_number = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n // PAGE_SIZE,
                mask=offs_n < split_kv_end,
                other=0,
            )
            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
            offs_buf_k = (kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_d[:, None])
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
                other=0.0,
            )
            qk = tl.dot(q, k.to(q.dtype))
            if BLOCK_DPE > 0:
                offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_dpe[:, None])
                kpe = tl.load(
                    K_Buffer + offs_buf_kpe,
                    mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
                    other=0.0,
                )
                qk += tl.dot(qpe, kpe.to(qpe.dtype))
            qk *= sm_scale

            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf"))

            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + cur_kv_head * stride_buf_vh + offs_dv[None, :])
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                other=0.0,
            )

            n_e_max = tl.maximum(tl.max(qk, 1), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            acc *= re_scale[:, None]
            acc += tl.dot(p.to(v.dtype), v)

            e_sum = e_sum * re_scale + tl.sum(p, 1)
            e_max = n_e_max

        offs_mid_o = (cur_batch * stride_mid_ob + cur_head[:, None] * stride_mid_oh + split_kv_id * stride_mid_os + offs_dv[None, :])
        tl.store(
            Att_Out + offs_mid_o,
            acc / e_sum[:, None],
            mask=(mask_h[:, None]) & (mask_dv[None, :]),
        )
        offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os + Lv)
        tl.store(
            Att_Out + offs_mid_o_1,
            e_max + tl.log(e_sum),
            mask=mask_h,
        )


def _decode_grouped_att_m_fwd(
    q, k_buffer, v_buffer, att_out, Req_to_tokens, B_Seqlen,
    num_kv_splits, sm_scale, page_size, logit_cap,
):
    BLOCK = 32
    Lk = k_buffer.shape[-1]
    Lv = v_buffer.shape[-1]

    # [TODO] work around shmem limit on MI3xx
    # TODO: support hip
    #if is_hip_ and Lk >= 576:
    #    BLOCK = 16

    if Lk == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
    elif Lk == 288:
        BLOCK_DMODEL = 256
        BLOCK_DPE = 32
    else:
        BLOCK_DMODEL = triton.next_power_of_2(Lk)
        BLOCK_DPE = 0
    BLOCK_DV = triton.next_power_of_2(Lv)

    batch, head_num = q.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

    BLOCK_H = 16
    NUM_KV_SPLITS = num_kv_splits
    grid = (
        batch,
        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
        NUM_KV_SPLITS,
    )

    extra_kargs = {}
    # TODO: support hip
    """
    if is_hip_:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
    """

    _fwd_grouped_kernel_stage1[grid](
        q, k_buffer, v_buffer, sm_scale, Req_to_tokens, B_Seqlen, att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        k_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        v_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        v_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        BLOCK_N=BLOCK,
        BLOCK_H=BLOCK_H,
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        num_warps=4,
        num_stages=2,
        Lk=Lk,
        Lv=Lv,
        **extra_kargs,
    )


@triton.jit
def _fwd_kernel_stage2(
    Mid_O, o, B_Seqlen,
    stride_mid_ob, stride_mid_oh, stride_mid_os, stride_obs, stride_oh,
    NUM_KV_SPLITS: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)

    offs_d = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lv

    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv

    for split_kv_id in range(0, NUM_KV_SPLITS):
        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

        if split_kv_end > split_kv_start:
            tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0)
            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
            n_e_max = tl.maximum(tlogic, e_max)

            old_scale = tl.exp(e_max - n_e_max)
            acc *= old_scale
            exp_logic = tl.exp(tlogic - n_e_max)
            acc += exp_logic * tv

            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

    tl.store(
        o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,
        mask=mask_d,
    )


def _decode_softmax_reducev_fwd(
    logits, q, o, v_buffer, b_seq_len, num_kv_splits,
):
    batch, head_num = q.shape[0], q.shape[1]
    Lv = v_buffer.shape[-1]
    BLOCK_DV = triton.next_power_of_2(Lv)

    NUM_KV_SPLITS = num_kv_splits

    extra_kargs = {}
    # TODO: support hip
    """
    if is_hip_:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
    """

    grid = (batch, head_num)
    _fwd_kernel_stage2[grid](
        logits, o, b_seq_len,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        BLOCK_DV=BLOCK_DV,
        Lv=Lv,
        num_warps=4,
        num_stages=2,
        **extra_kargs,
    )


def decode_attention_fwd_grouped(
    q, k_buffer, v_buffer, o, req_to_token, b_seq_len, attn_logits,
    num_kv_splits, sm_scale, page_size, logit_cap=0.0,
):
    _decode_grouped_att_m_fwd(
        q, k_buffer, v_buffer, attn_logits, req_to_token, b_seq_len,
        num_kv_splits, sm_scale, page_size, logit_cap,
    )
    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
\ No newline at end of file
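For reference, a hedged usage sketch of decode_attention_fwd_grouped that mirrors the call site in attention.py: a query of shape (batch, heads, kv_lora_rank + rope_dim), paged key/value buffers with one slot per token and a single KV head, a per-sequence page table, and a float32 attn_logits scratch buffer with kv_lora_rank + 1 slots per split for the running log-sum-exp. It assumes a CUDA device with triton installed; every size and the softmax scale below are illustrative, not the model's real configuration.

    import torch
    from ktransformers.operators.triton_attention import decode_attention_fwd_grouped

    bsz, num_heads = 1, 16                       # illustrative head count
    kv_lora_rank, rope_dim = 512, 64
    page_size, max_pages, seq_len = 64, 8, 70
    num_kv_splits = 1                            # the commit also hard-codes 1 for now

    # One cache slot per token: compressed KV plus the RoPE part of the key, single KV head.
    compressed_kv = torch.randn(max_pages * page_size, 1, kv_lora_rank, device="cuda", dtype=torch.float16)
    k_pe = torch.randn(max_pages * page_size, 1, rope_dim, device="cuda", dtype=torch.float16)
    k_buffer = torch.cat([compressed_kv, k_pe], dim=-1)   # keys are [compressed_kv | k_pe]
    v_buffer = compressed_kv                              # values are the compressed KV itself

    q = torch.randn(bsz, num_heads, kv_lora_rank + rope_dim, device="cuda", dtype=torch.float16)
    o = torch.zeros(bsz, num_heads, kv_lora_rank, device="cuda", dtype=torch.float16)
    page_table = torch.arange(max_pages, dtype=torch.int32, device="cuda").unsqueeze(0)   # (bsz, max_pages)
    b_seq_len = torch.tensor([seq_len], dtype=torch.int32, device="cuda")
    attn_logits = torch.empty(bsz, num_heads, num_kv_splits, kv_lora_rank + 1,
                              dtype=torch.float32, device="cuda")

    sm_scale = 0.1  # illustrative softmax scale; the model passes self.softmax_scale
    decode_attention_fwd_grouped(q, k_buffer, v_buffer, o, page_table, b_seq_len, attn_logits,
                                 num_kv_splits, sm_scale, page_size)
    # o now holds the per-head attention output, still in the compressed kv_lora_rank space.

Stage 1 computes partial outputs and log-sum-exp per KV split and per head group; stage 2 merges the splits, which is why attn_logits reserves one extra slot beyond kv_lora_rank.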
ktransformers/util/utils.py (+2 −2)

@@ -133,7 +133,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         )
     else:
         past_key_values = None
-    cache_position = torch.arange(seq_length, device=torch_device)
+    cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.long)
     generated_ids = torch.zeros(
         batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
     )

@@ -178,7 +178,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         generated_ids[:, seq_length] = next_token
         tokens.append(int(next_token))
         inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
-        cache_position = torch.tensor([seq_length], device=torch_device)
+        cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.long)
         position_ids = cache_position.unsqueeze(0)
         seq_length += 1
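The only change here is an explicit dtype=torch.long on cache_position. The commit does not state the motivation, but the new paged cache path floor-divides cache_position by page_size and uses the result for advanced indexing, which requires an integer index tensor; a small sketch of that constraint (values are illustrative, not from the repository):

    import torch

    page_size = 64
    cache_position = torch.arange(62, 66, dtype=torch.long)        # explicit integer positions
    page_idx, page_offset = cache_position // page_size, cache_position % page_size
    buf = torch.zeros(8, page_size, 4)
    buf[page_idx, page_offset] = 1.0                               # advanced indexing needs integer indices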
test_prompt.txt (new file, 0 → 100644, +9; diff collapsed)