ox696c / ktransformers · Commits

Commit 1084d4e4, authored Feb 14, 2025 by Atream
linux support triton MLA kernel
Parent: bb35dc5b

Showing 2 changed files with 198 additions and 61 deletions:
    ktransformers/models/custom_cache.py    +18  -9
    ktransformers/operators/attention.py    +180 -52
ktransformers/models/custom_cache.py
@@ -53,8 +53,9 @@ class StaticCache(transformers.StaticCache):
             # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
             self.page_size = 64
             self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
-            key_shape = (self.max_pages, self.page_size, 1, config.qk_rope_head_dim)
-            value_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank)
+            latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
+            self.kv_lora_rank = config.kv_lora_rank
+            self.qk_rope_head_dim = config.qk_rope_head_dim
             # TODO: support real page table
             self.page_table_map = dict()
             self.page_table_list = []
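
A quick numeric check of the paging arithmetic introduced above (illustrative values only; the real ones come from the model config): max_pages is a ceiling division of max_cache_len by page_size, and each slot of the latent cache now holds the kv_lora_rank latent plus the qk_rope_head_dim rotary part for a single shared KV head.

    # Illustrative sketch; names mirror the fields above but the numbers are made up.
    max_cache_len = 8192
    page_size = 64
    kv_lora_rank, qk_rope_head_dim = 512, 64

    max_pages = (max_cache_len + page_size - 1) // page_size      # ceil(8192 / 64) = 128
    latent_shape = (max_pages, page_size, 1, kv_lora_rank + qk_rope_head_dim)
    print(max_pages, latent_shape)                                # 128 (128, 64, 1, 576)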
@@ -88,10 +89,17 @@ class StaticCache(transformers.StaticCache):
                 target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
             else:
                 target_device = device

-            new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
-            new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
+            if self.is_MLA:
+                new_layer_key_cache = torch.zeros(latent_shape, dtype=self.dtype, device=target_device)
+                new_layer_value_cache = None
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+            else:
+                new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
+                new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+                torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
             self.past_tokens.append(0)
@@ -129,11 +137,12 @@ class StaticCache(transformers.StaticCache):
         if self.is_MLA:
             page_idx = cache_position // self.page_size
             page_offset = cache_position % self.page_size
+            # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
             #print("page_idx", page_idx)
             #print("page_offset", page_offset)
-            k_out[page_idx, page_offset, ...] = key_states
-            v_out[page_idx, page_offset, ...] = value_states
-            return k_out, v_out, self.page_table_list[layer_idx]
+            k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
+            k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
+            return k_out, self.page_table_list[layer_idx]
         else:
             k_out[:, :, cache_position] = key_states
             v_out[:, :, cache_position] = value_states
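
A small standalone illustration (hypothetical sizes) of the indexing in the MLA branch above: a flat cache_position is mapped to a (page, offset) slot, and the latent part passed in as key_states and the rotary part passed in as value_states share one k_out buffer, split at kv_lora_rank along the last axis.

    import torch

    page_size, max_pages = 64, 4
    kv_lora_rank, qk_rope_head_dim = 512, 64
    k_out = torch.zeros(max_pages, page_size, 1, kv_lora_rank + qk_rope_head_dim)

    cache_position = torch.tensor([130])
    page_idx = cache_position // page_size        # tensor([2])
    page_offset = cache_position % page_size      # tensor([2])

    key_states = torch.randn(1, 1, kv_lora_rank)          # compressed latent for this token
    value_states = torch.randn(1, 1, qk_rope_head_dim)    # rotary (k_pe) part for this token
    k_out[page_idx, page_offset, :, :kv_lora_rank] = key_states
    k_out[page_idx, page_offset, :, kv_lora_rank:] = value_states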
...
ktransformers/operators/attention.py
@@ -13,8 +13,6 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
 from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
-from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention
-from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
@@ -23,8 +21,15 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
 from flash_attn import flash_attn_with_kvcache, flash_attn_func
 from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+import os
 logger = logging.getLogger("attention")

+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)

 # V3 MLA is same to V2
 class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
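
A quick sanity check of the rotate_half helper added at module scope above: it swaps the two halves of the last dimension and negates the new first half, which is the standard RoPE rotation.

    import torch

    x = torch.tensor([1., 2., 3., 4.])
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    print(torch.cat((-x2, x1), dim=-1))    # tensor([-3., -4.,  1.,  2.])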
@@ -80,6 +85,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         q_nope, q_pe = torch.split(
             q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
         )
+        # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
+        # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
         compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
         compressed_kv, k_pe = torch.split(
@@ -103,16 +110,37 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
-            compressed_kv = compressed_kv.unsqueeze(1)
-            k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
-            compressed_kv = compressed_kv.squeeze(1)
-            #if cache_position is not None:
-            #    compressed_kv = compressed_kv[:,: cache_position[-1] + 1,:]
-            #    k_pe = k_pe[:,:,: cache_position[-1] + 1,:]
+            # compressed_kv [bsz, q_len, self.kv_lora_rank]
+            # k_pe [bsz, 1, q_len, self.qk_rope_head_dim]
+            k_pe = k_pe.transpose(1, 2)
+            compressed_kv = compressed_kv.unsqueeze(2)
+            compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+            compressed_kv, k_pe = torch.split(
+                compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            # k_pe [pages, page_size, 1, self.qk_rope_head_dim]
+            # compressed_kv [pages, page_size, 1, self.kv_lora_rank]

         q_absorb, out_absorb = self.get_absorbed()
+        if hasattr(self.orig_module, 'kv_b_proj'):
+            del self.orig_module.kv_b_proj
+        # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
+        # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
+        k_pe = k_pe.view(bsz, 1, -1, self.qk_rope_head_dim)[:,:,:attention_mask.size(-1),:]
+        compressed_kv = compressed_kv.view(bsz, 1, -1, self.kv_lora_rank)[:,:,:attention_mask.size(-1),:]
+        # k_pe [bsz, 1, cache_len, self.qk_rope_head_dim]
+        # compressed_kv [bsz, 1, cache_len, self.kv_lora_rank]
         q_nope = torch.matmul(q_nope, q_absorb)
-        attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
+        #print(q_pe.shape)
+        #print(k_pe.shape)
+        #print(q_nope.shape)
+        #print(compressed_kv.shape)
+        attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
+        #attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
+        compressed_kv = compressed_kv.squeeze(1)
         """
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
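
A shape-level sketch of the absorbed-MLA path in this hunk, with made-up dimensions: instead of expanding the latent cache into per-head keys and values through kv_b_proj (which is presumably why kv_b_proj is deleted above), the key half of that projection is folded into the query (q_absorb) and the value half into the output (out_absorb), so both the score and the weighted sum stay in kv_lora_rank space.

    import torch

    bsz, heads, q_len, kv_len = 1, 8, 2, 16
    nope_dim, rope_dim, lora_rank, v_dim = 128, 64, 512, 128

    q_nope = torch.randn(bsz, heads, q_len, nope_dim)
    q_pe = torch.randn(bsz, heads, q_len, rope_dim)
    compressed_kv = torch.randn(bsz, 1, kv_len, lora_rank)    # shared latent KV, one "head"
    k_pe = torch.randn(bsz, 1, kv_len, rope_dim)

    q_absorb = torch.randn(heads, nope_dim, lora_rank)        # key up-projection folded into the query
    out_absorb = torch.randn(heads, v_dim, lora_rank)         # value up-projection folded into the output

    q_latent = torch.matmul(q_nope, q_absorb)                 # [bsz, heads, q_len, lora_rank]
    scores = torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_latent, compressed_kv.mT)
    attn = torch.softmax(scores * (nope_dim + rope_dim) ** -0.5, dim=-1)  # placeholder scale; the module uses self.softmax_scale
    ctx_latent = torch.matmul(attn, compressed_kv)            # weighted sum stays in lora_rank space
    attn_output = torch.matmul(ctx_latent, out_absorb.mT)     # [bsz, heads, q_len, v_dim]
    print(attn_output.shape)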
@@ -156,25 +184,25 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         return attn_output, None, past_key_value

-    def forward(
+    def forward_linux(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

         if self.q_lora_rank is None:
             q = self.q_proj(hidden_states)
         else:
             q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
-        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
+        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)
         q_nope, q_pe = torch.split(
             q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
         )
@@ -184,38 +212,42 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
         )
         compressed_kv = self.kv_a_layernorm(compressed_kv)
-        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
         compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
         cos, sin = self.rotary_emb(q_pe, position_ids)
-        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
-        k_pe = k_pe.transpose(1, 2)
-        # k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
+        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
+        # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
         # decode
         if q_len == 1:
             if past_key_value is not None:
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
-                k_pe, compressed_kv, page_table = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
+                compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+                compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank] # for speed
+                # compressed_kv_with_k_pe [bsz, q_len, 1, self.kv_lora_rank + self.qk_rope_head_dim]
+                # compressed_kv [bsz, q_len, 1, self.kv_lora_rank]

-            # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
+            # q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim]
             # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]
             q_absorb, out_absorb = self.get_absorbed()
+            q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
             q_nope = torch.matmul(q_nope, q_absorb) # batched MM
-            # q_nope [bsz, self.num_heads, q_len, self.kv_lora_rank]
-            # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
+            q_nope = q_nope.transpose(1, 2)
+            assert q_nope.is_contiguous()
+            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
+            # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
             query_states = torch.cat([q_nope, q_pe], dim=-1)
-            # k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
-            # compressed_kv [bsz, q_len, 1, self.kv_lora_rank]
-            key_states = torch.cat([compressed_kv, k_pe], dim=-1)
-            query_states = query_states.squeeze(2)
+            query_states = query_states.squeeze(1)
             attn_output = torch.zeros_like(q_nope)
+            # [bsz, q_len, self.num_heads, self.kv_lora_rank]
             attn_logits = torch.empty(
                 (
                     bsz,
                     self.num_heads,
-                    1, #num_kv_splits # follow vLLM, fix it TODO
+                    4, #num_kv_splits # follow vLLM, fix it TODO
                     self.kv_lora_rank + 1,
                 ),
                 dtype=torch.float32,
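
On the unsqueeze_dim=2 argument introduced above: cos and sin come out of the rotary embedding as [bsz, q_len, rope_dim], and the axis that gets unsqueezed has to line up with the head axis of q and k. A minimal sketch (hypothetical apply_rope helper mirroring the transformers-style signature) showing that dim 2 matches the [bsz, q_len, num_heads, rope_dim] layout kept on the Linux path, while dim 1 matches the [bsz, num_heads, q_len, rope_dim] layout:

    import torch

    def apply_rope(x, cos, sin, unsqueeze_dim):
        # x carries rope_dim on its last axis; cos/sin are [bsz, q_len, rope_dim]
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        def rotate_half(t):
            t1, t2 = t[..., : t.shape[-1] // 2], t[..., t.shape[-1] // 2 :]
            return torch.cat((-t2, t1), dim=-1)
        return x * cos + rotate_half(x) * sin

    bsz, q_len, heads, rope_dim = 1, 4, 8, 16
    cos, sin = torch.randn(bsz, q_len, rope_dim), torch.randn(bsz, q_len, rope_dim)
    q_bshd = torch.randn(bsz, q_len, heads, rope_dim)    # layout on the Linux path
    q_bhsd = q_bshd.transpose(1, 2)                      # layout on the Windows/chunked path
    a = apply_rope(q_bshd, cos, sin, unsqueeze_dim=2)
    b = apply_rope(q_bhsd, cos, sin, unsqueeze_dim=1)
    assert torch.allclose(a.transpose(1, 2), b)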
@@ -224,22 +256,25 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             """
             print("query_states", torch.isnan(query_states).any())
-            print("key_states", torch.isnan(key_states[:,:,0,:]).any())
+            print("compressed_kv_with_k_pe", torch.isnan(compressed_kv_with_k_pe[:,:,0,:]).any())
             print("compressed_kv", torch.isnan(compressed_kv[:,:,0,:]).any())
             print("position_ids", torch.isnan(position_ids).any())
             """
             # flash attn doesn't support head_dim bigger than 256
             # use vLLM triton attention kernel for MQA
-            decode_attention_fwd_grouped(query_states, key_states, compressed_kv, attn_output,
+            decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
                                          page_table,
                                          position_ids.squeeze(0).to(torch.int32), attn_logits,
-                                         1, #num_kv_splits # follow vLLM, fix it TODO
+                                         4, #num_kv_splits # follow vLLM, fix it TODO
                                          self.softmax_scale,
                                          past_key_value.page_size)
-            attn_output = torch.matmul(attn_output, out_absorb.mT)
-            attn_output = attn_output.transpose(1, 2).contiguous()
+            # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
+            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
+            attn_output = attn_output.transpose(1, 2)
+            attn_output = torch.matmul(attn_output, out_absorb.mT)
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
             attn_output = self.o_proj(attn_output)
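
The attn_logits buffer passed to the kernel above is (bsz, num_heads, 4, kv_lora_rank + 1), which matches a flash-decoding style split-KV reduction: each of the num_kv_splits partials keeps an output of width kv_lora_rank plus one slot for its log-sum-exp, and the partials are merged at the end. The Triton kernel itself is not part of this diff, so the following is only a generic PyTorch sketch of that reduction (assumed shapes, one query token, one shared KV head), not the kernel's implementation:

    import torch

    def split_kv_decode(q, k, v, num_kv_splits=4):
        # q: [heads, d_qk]; k: [seq, d_qk]; v: [seq, d_v]
        acc = torch.zeros(num_kv_splits, q.shape[0], v.shape[-1])
        lse = torch.full((num_kv_splits, q.shape[0]), float("-inf"))
        for s, idx in enumerate(torch.chunk(torch.arange(k.shape[0]), num_kv_splits)):
            scores = q @ k[idx].T                                  # [heads, chunk]
            m = scores.max(dim=-1).values
            p = torch.exp(scores - m[:, None])
            acc[s] = (p @ v[idx]) / p.sum(dim=-1, keepdim=True)    # normalized partial output
            lse[s] = m + torch.log(p.sum(dim=-1))                  # per-split log-sum-exp
        g = torch.logsumexp(lse, dim=0)                            # global log-sum-exp per head
        return (acc * torch.exp(lse - g)[..., None]).sum(dim=0)    # rescale and merge the splits

    q = torch.randn(8, 576) * 576 ** -0.5
    k, v = torch.randn(64, 576), torch.randn(64, 512)
    ref = torch.softmax(q @ k.T, dim=-1) @ v
    assert torch.allclose(split_kv_decode(q, k, v), ref, atol=1e-5)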
@@ -250,7 +285,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
                 k_pe.squeeze(0)
                 compressed_kv.squeeze(0)
-                past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
+                past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
                 k_pe.unsqueeze(0)
                 compressed_kv.unsqueeze(0)
@@ -261,7 +296,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                 .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
             )
             k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-            query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+            query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
             query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
             query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
@@ -269,7 +304,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
             key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
-            query_states = query_states.transpose(1, 2)
             value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
             value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
@@ -289,12 +323,106 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             ).contiguous()
             attn_output = self.o_proj(attn_output)
             return attn_output, None, past_key_value

-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
+    def forward_windows(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+        bsz, q_len, _ = hidden_states.size()
+
+        if q_len <= self.chunck_size:
+            return self.forward_chunck(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position,
+                **kwargs
+            )
+
+        assert output_attentions == False, "output_attentions is not supported when using chunked attention"
+        attn_output = None
+        cur_idx = 0
+        while cur_idx < q_len:
+            if attention_mask is not None:
+                chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]
+            else:
+                # generate chunk_mask automatically.
+                self.attn_mask = \
+                    torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \
+                    if self.attn_mask is None \
+                    else self.attn_mask
+                self.attn_mask[:, :, :, cur_idx:min(cur_idx + self.chunck_size, past_key_value.max_cache_len)] = \
+                    -1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\
+                    [:, :min(self.chunck_size, min(past_key_value.max_cache_len - cur_idx, self.chunck_size))]
+                self.attn_mask[:, :, :, cur_idx + self.chunck_size:] = -1e+38
+                self.attn_mask[:, :, :, :cur_idx] = 0
+                chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len - cur_idx))
+
+            cur_output, _, _ = self.forward_chunck(
+                hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],
+                chunk_mask,
+                position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],
+                **kwargs
+            )
+            cur_idx += self.chunck_size
+            if attn_output is None:
+                attn_output = cur_output
+            else:
+                attn_output = torch.cat((attn_output, cur_output), dim=-2)
+
+        return attn_output, None, past_key_value
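
A condensed sketch of the control flow in forward_windows above (hypothetical helper, mask handling omitted): long prompts are processed chunck_size tokens at a time so the attention mask and intermediate buffers stay bounded, and the per-chunk outputs are concatenated along the sequence dimension.

    import torch

    def chunked_forward(hidden_states, chunck_size, forward_chunck):
        q_len = hidden_states.shape[1]
        outputs, cur_idx = [], 0
        while cur_idx < q_len:
            end = min(cur_idx + chunck_size, q_len)
            outputs.append(forward_chunck(hidden_states[:, cur_idx:end, ...]))
            cur_idx += chunck_size
        return torch.cat(outputs, dim=-2)

    # toy usage: an "attention" stub that just echoes its input chunk
    out = chunked_forward(torch.randn(1, 10, 16), chunck_size=4, forward_chunck=lambda x: x)
    print(out.shape)    # torch.Size([1, 10, 16])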
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if os.name == 'nt':
+            return self.forward_windows(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position,
+                **kwargs,
+            )
+        else:
+            return self.forward_linux(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position,
+                **kwargs,
+            )

 class KLlamaAttention(BaseInjectedModule):
...