OpenDAS / ktransformers / Commits

Commit 4d1d561d, authored Aug 28, 2024 by chenxl

    [feature] release 0.1.3

Parent: 67f8b370

Showing 18 changed files with 2466 additions and 105 deletions (+2466, -105).
ktransformers/operators/attention.py                                          +163   -10
ktransformers/operators/cpuinfer.py                                           +737   -9
ktransformers/operators/dynamic_attention.py                                  +775   -0
ktransformers/operators/experts.py                                            +2     -2
ktransformers/operators/models.py                                             +694   -62
ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml       +1     -1
ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml         +1     -1
ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml                   +7     -1
ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml    +1     -1
ktransformers/optimize/optimize_rules/Internlm2_5-7b-Chat-1m.yaml             +28    -0
ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml  +1     -1
ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml            +8     -1
ktransformers/server/config/config.py                                         +24    -4
ktransformers/util/cuda_graph_runner.py                                       +2     -1
ktransformers/util/custom_gguf.py                                             +1     -2
ktransformers/util/utils.py                                                   +17    -7
pyproject.toml                                                                +2     -1
requirements-local_chat.txt                                                   +2     -1
ktransformers/operators/attention.py
...
...
@@ -7,16 +7,22 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import torch
from torch import nn
import warnings
import torch.nn.functional as F
from ktransformers.operators.models import KLlamaModel
from ktransformers.models.configuration_deepseek import DeepseekV2Config
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache

logger = logging.getLogger("attention")


class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    attn_mask: Optional[torch.Tensor] = None

    def __init__(self,
                 key: str,
...
...
@@ -24,10 +30,12 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 device: str = "cuda",
+                chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
        self.orig_module.__init__(orig_module.config, orig_module.layer_idx)
+       self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.

    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
...
...
@@ -157,9 +165,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz
,
q_len
,
_
=
hidden_states
.
size
()
chunck_size
=
256
# TODO, generate chunck_size automatically.
if
q_len
<=
chunck_size
:
if
q_len
<=
self
.
chunck_size
:
return
self
.
forward_chunck
(
hidden_states
,
attention_mask
,
...
...
@@ -176,24 +183,170 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
        cur_idx = 0
        while cur_idx < q_len:
            if attention_mask is not None:
-               chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + chunck_size, q_len), ...]
+               chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]
            else:
                chunk_mask = None
+               # generate chunk_mask automatically.
+               self.attn_mask = \
+                   torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \
+                   if self.attn_mask is None \
+                   else self.attn_mask
+               self.attn_mask[:, :, :, cur_idx:min(cur_idx + self.chunck_size, past_key_value.max_cache_len)] = \
+                   -1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1) \
+                   [:, :min(self.chunck_size, min(past_key_value.max_cache_len - cur_idx, self.chunck_size))]
+               self.attn_mask[:, :, :, cur_idx + self.chunck_size:] = -1e+38
+               self.attn_mask[:, :, :, :cur_idx] = 0
+               chunck_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len - cur_idx))

            cur_output, _, _ = self.forward_chunck(
-               hidden_states[:, cur_idx:min(cur_idx + chunck_size, q_len), ...],
-               chunk_mask,
-               position_ids[:, cur_idx:min(cur_idx + chunck_size, q_len)],
+               hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],
+               chunck_mask,
+               position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],
                past_key_value,
                output_attentions,
                use_cache,
-               cache_position[cur_idx:min(cur_idx + chunck_size, q_len)],
+               cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],
                **kwargs
            )
-           cur_idx += chunck_size
+           cur_idx += self.chunck_size
            if attn_output is None:
                attn_output = cur_output
            else:
                attn_output = torch.cat((attn_output, cur_output), dim=-2)
        return attn_output, None, past_key_value


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


class KLlamaAttention(BaseInjectedModule):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
        self.orig_module.__init__(orig_module.config, orig_module.layer_idx)

    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """Applies Rotary Position Embedding to the query and key tensors.

        Args:
            q (`torch.Tensor`): The query tensor.
            k (`torch.Tensor`): The key tensor.
            cos (`torch.Tensor`): The cosine part of the rotary embedding.
            sin (`torch.Tensor`): The sine part of the rotary embedding.
            position_ids (`torch.Tensor`, *optional*):
                Deprecated and unused.
            unsqueeze_dim (`int`, *optional*, defaults to 1):
                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
        Returns:
            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
        """
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if q_len == 1:
            position_ids = position_ids[0][-1].unsqueeze(0).unsqueeze(0)
            query_states = query_states[:, :, -1:]
            key_states = key_states[:, :, -1:]

        attn_output = KLlamaModel.dynamic_sdpa.apply(
            self.layer_idx,
            bsz,
            position_ids[0][0],
            query_states.transpose(1, 2).to(torch.float16),
            key_states.transpose(1, 2).to(torch.float16),
            value_states.transpose(1, 2).to(torch.float16),
            mode="prefill" if q_len > 1 else "generate",
        )

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
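The key change in attention.py above is that long prefills are now processed in fixed-size chunks (`self.chunck_size`) and the per-chunk outputs are concatenated along the sequence dimension. A minimal, self-contained sketch of that loop is shown below; `chunked_forward` and the lambda are illustrative stand-ins, not the module's actual `forward_chunck` API, and the toy callable replaces the real attention computation.

import torch

def chunked_forward(hidden_states: torch.Tensor, chunk_size: int, forward_chunk):
    # hidden_states: [bsz, q_len, hidden]; forward_chunk processes one sequence slice at a time.
    q_len = hidden_states.size(1)
    outputs = []
    cur_idx = 0
    while cur_idx < q_len:
        end = min(cur_idx + chunk_size, q_len)
        outputs.append(forward_chunk(hidden_states[:, cur_idx:end]))
        cur_idx = end
    # Concatenate along the sequence dimension, as the real forward() does with torch.cat(..., dim=-2).
    return torch.cat(outputs, dim=1)

# Toy usage: an "attention" that just scales its input, applied in chunks of 256 tokens.
x = torch.randn(1, 1000, 64)
y = chunked_forward(x, chunk_size=256, forward_chunk=lambda chunk: chunk * 2)
assert y.shape == x.shape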
ktransformers/operators/cpuinfer.py
#!/usr/bin/env python
# coding=utf-8
"""
Description  : This script defines the `CPUInferKVCache` and `CPUInfer` classes for performing inference
               with a Key-Value Cache on the CPU. The `CPUInferKVCache` class is responsible for configuring
               and managing key-value caches, updating and retrieving cache data, and handling attention
               operations. It supports different cache types (e.g., Q4_0, FP16) and retrieval strategies
               (e.g., shared, separate). The `CPUInfer` class handles task submission and synchronization
               on the CPU, with optional CUDA stream integration for tasks involving GPU acceleration.
               These classes facilitate efficient caching and memory management for deep learning models
               that leverage key-value attention mechanisms, particularly on CPU-based systems.
Author       : djw
Date         : 2024-08-26 23:25:24
Version      : 1.0.0
LastEditors  : djw
LastEditTime : 2024-08-26 23:25:24
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import sys, os
from typing import Any
import torch

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from ktransformers.server.config.config import Config


class CPUInferKVCache:
    def __init__(
        self,
        layer_num: int = 32,
        kv_head_num: int = 8,
        q_head_num: int = 32,
        head_dim: int = 128,
        block_len: int = 256,
        anchor_num: int = 4,
        anchor_type: str = "FIXED",
        kv_type: str = "Q4_0",
        retrieval_type: str = "SHARED",
        layer_step: int = 1,
        token_step: int = 1,
        layer_offset: int = 0,
        max_thread_num: int = 32,
        max_batch_size: int = 4,
        max_block_num: int = 512,
    ):
        if anchor_type == "FIXED":
            anchor_type = cpuinfer_ext.kvcache.AnchorType.FIXED
        elif anchor_type == "QUEST":
            anchor_type = cpuinfer_ext.kvcache.AnchorType.QUEST
        elif anchor_type == "DYNAMIC":
            anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC
        elif anchor_type == "BLOCK_MEAN":
            anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MEAN
        elif anchor_type == "BLOCK_MAX":
            anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MAX
        else:
            raise ValueError(f"Unknown anchor type: {anchor_type}")

        if kv_type == "FP16":
            kv_type = cpuinfer_ext.kvcache.ggml_type.FP16
        elif kv_type == "FP32":
            assert False, "FP32 is not supported yet."
            kv_type = cpuinfer_ext.kvcache.ggml_type.FP32
        elif kv_type == "Q4_0":
            kv_type = cpuinfer_ext.kvcache.ggml_type.Q4_0
        elif kv_type == "Q8_0":
            kv_type = cpuinfer_ext.kvcache.ggml_type.Q8_0
        else:
            raise ValueError(f"Unknown kv type: {kv_type}")

        if retrieval_type == "SHARED":
            retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER
        elif retrieval_type == "INDIVIDUAL":
            retrieval_type = cpuinfer_ext.kvcache.RetrievalType.QHEAD
        elif retrieval_type == "SEPARATE":
            retrieval_type = cpuinfer_ext.kvcache.RetrievalType.KVHEAD

        self.config = cpuinfer_ext.kvcache.KVCacheConfig(
            layer_num, kv_head_num, q_head_num, head_dim, block_len, anchor_num,
            anchor_type, kv_type, retrieval_type, layer_step, token_step, layer_offset,
            max_block_num, max_batch_size, max_thread_num,
        )
        self.kvcache = cpuinfer_ext.kvcache.KVCache(self.config)

    def load_kvcache(self, tensor_file_path: str):
        if not os.path.exists(tensor_file_path):
            raise FileNotFoundError(f"The file {tensor_file_path} does not exist.")
        return self.kvcache.load_kvcache(tensor_file_path)

    def dump_kvcache(self, block_table: torch.Tensor, cache_total_len: int, tensor_file_path: str):
        assert (block_table.dim() == 1 and block_table.dtype == torch.int
                and block_table.is_contiguous() and block_table.device == torch.device("cpu")), \
            "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                block_table.dim(), block_table.size(), block_table.dtype,
                block_table.is_contiguous(), block_table.device)
        assert (cache_total_len > 0 and cache_total_len <= self.config.block_len * block_table.size(0)), \
            "cache_total_len: {}".format(cache_total_len)
        if not os.path.exists(os.path.dirname(tensor_file_path)):
            os.makedirs(os.path.dirname(tensor_file_path))
        return self.kvcache.dump_kvcache(block_table.data_ptr(), cache_total_len, tensor_file_path)

    def update_cache_total_len(self, cache_total_len: int):
        assert cache_total_len > 0, "cache_total_len: {}".format(cache_total_len)
        self.kvcache.update_cache_total_len(cache_total_len)

    # q_in: (bsz, q_len, q_head_num, head_dim)
    # output: (bsz, q_len, q_head_num, head_dim)
    # attn_lse: (bsz, q_len, q_head_num)
    # block_table: (bsz, max_block_num)
    def attn(
        self,
        q_in: torch.Tensor,
        output: torch.Tensor,
        attn_lse: torch.Tensor,
        layer_idx: int,
        generate_token_idx: int,
        block_table: torch.Tensor | None = None,
        cache_seqlens: torch.Tensor | None = None,
        pick_block_num: int | None = None,
        init_block_num: int | None = None,
        local_block_num: int | None = None,
    ):
        assert (q_in.dim() == 4 and q_in.size(2) == self.config.q_head_num
                and q_in.size(3) == self.config.head_dim and q_in.dtype == torch.float16
                and q_in.is_contiguous() and q_in.device == torch.device("cpu")), \
            "q_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                q_in.dim(), q_in.size(), q_in.dtype, q_in.is_contiguous(), q_in.device)
        batch_size = q_in.size(0)
        q_len = q_in.size(1)
        assert (block_table is None) or (
            block_table.dim() == 2 and block_table.size(0) == batch_size
            and block_table.dtype == torch.int and block_table.is_contiguous()
            and block_table.device == torch.device("cpu")), \
            "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                block_table.dim(), block_table.size(), block_table.dtype,
                block_table.is_contiguous(), block_table.device)
        max_block_num = block_table.size(1) if block_table is not None else 0
        assert (output.dim() == 4 and output.size(0) == batch_size
                and output.size(2) == self.config.q_head_num and output.size(1) == q_len
                and output.size(3) == self.config.head_dim and output.dtype == torch.float16
                and output.is_contiguous() and output.device == torch.device("cpu")), \
            "output dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                output.dim(), output.size(), output.dtype, output.is_contiguous(), output.device)
        assert (attn_lse.dim() == 3 and attn_lse.size(0) == batch_size
                and attn_lse.size(1) == q_len and attn_lse.size(2) == self.config.q_head_num
                and attn_lse.dtype == torch.float32 and attn_lse.is_contiguous()
                and attn_lse.device == torch.device("cpu")), \
            "attn_lse dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                attn_lse.dim(), attn_lse.size(), attn_lse.dtype, attn_lse.is_contiguous(), attn_lse.device)
        assert (layer_idx >= 0 and layer_idx < self.config.layer_num), "layer_idx: {}".format(layer_idx)
        assert (cache_seqlens is None) or (
            cache_seqlens.dim() == 1 and cache_seqlens.size(0) == batch_size
            and cache_seqlens.dtype == torch.int and cache_seqlens.is_contiguous()
            and cache_seqlens.device == torch.device("cpu")), \
            "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                cache_seqlens.dim(), cache_seqlens.size(), cache_seqlens.dtype,
                cache_seqlens.is_contiguous(), cache_seqlens.device)
        return self.kvcache.attn(
            q_in.data_ptr(), output.data_ptr(), attn_lse.data_ptr(), layer_idx,
            generate_token_idx, q_len, batch_size, max_block_num,
            block_table.data_ptr() if block_table is not None else 0,
            cache_seqlens.data_ptr() if cache_seqlens is not None else 0,
            pick_block_num, init_block_num, local_block_num,
        )

    # k_in: (block_len, kv_head_num, head_dim)
    # v_in: (block_len, kv_head_num, head_dim)
    def update_kvcache_one_block_fp16(self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int):
        assert (k_in.dim() == 3 and k_in.size(1) == self.config.block_len
                and k_in.size(0) == self.config.kv_head_num and k_in.size(2) == self.config.head_dim
                and k_in.dtype == torch.float16 and k_in.is_contiguous()
                and k_in.device == torch.device("cpu")), \
            "k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device)
        assert (v_in.dim() == 3 and v_in.size(1) == self.config.block_len
                and v_in.size(0) == self.config.kv_head_num and v_in.size(2) == self.config.head_dim
                and v_in.dtype == torch.float16 and v_in.is_contiguous()
                and v_in.device == torch.device("cpu")), \
            "v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device)
        assert (layer_id >= 0 and layer_id < self.config.layer_num), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.update_one_block_fp16(k_in.data_ptr(), v_in.data_ptr(), layer_id, block_idx)

    def get_kvcache_one_block_fp16(self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int):
        assert (k_in.dim() == 3 and k_in.size(1) == self.config.block_len
                and k_in.size(0) == self.config.kv_head_num and k_in.size(2) == self.config.head_dim
                and k_in.dtype == torch.float16 and k_in.is_contiguous()
                and k_in.device == torch.device("cpu")), \
            "k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device)
        assert (v_in.dim() == 3 and v_in.size(1) == self.config.block_len
                and v_in.size(0) == self.config.kv_head_num and v_in.size(2) == self.config.head_dim
                and v_in.dtype == torch.float16 and v_in.is_contiguous()
                and v_in.device == torch.device("cpu")), \
            "v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device)
        assert (layer_id >= 0 and layer_id < self.config.layer_num), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.get_one_block_fp16(k_in.data_ptr(), v_in.data_ptr(), layer_id, block_idx)

    def update_importance_one_block(self, importance: torch.Tensor, layer_id: int, block_idx: int):
        assert (importance.dim() == 1 and importance.size(0) == self.config.block_len
                and importance.dtype == torch.float16 and importance.is_contiguous()
                and importance.device == torch.device("cpu")), \
            "importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                importance.dim(), importance.size(), importance.dtype,
                importance.is_contiguous(), importance.device)
        assert (layer_id >= 0 and layer_id < self.config.layer_num), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.update_importance_one_block(importance.data_ptr(), layer_id, block_idx)

    def get_importance_one_block(self, importance: torch.Tensor, layer_id: int, block_idx: int):
        assert (importance.dim() == 1 and importance.size(0) == self.config.block_len
                and importance.dtype == torch.float16 and importance.is_contiguous()
                and importance.device == torch.device("cpu")), \
            "importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                importance.dim(), importance.size(), importance.dtype,
                importance.is_contiguous(), importance.device)
        assert (layer_id >= 0 and layer_id < self.config.layer_num), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.get_importance_one_block(importance.data_ptr(), layer_id, block_idx)

    def get_anchor_one_block(self, anchor: torch.Tensor, layer_id: int, block_idx: int):
        assert (anchor.dim() == 3 and anchor.size(0) == self.config.kv_head_num
                and anchor.size(1) == self.config.anchor_num and anchor.size(2) == self.config.head_dim
                and anchor.dtype == torch.float16 and anchor.is_contiguous()
                and anchor.device == torch.device("cpu")), \
            "anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                anchor.dim(), anchor.size(), anchor.dtype, anchor.is_contiguous(), anchor.device)
        assert (layer_id >= 0 and layer_id < self.config.layer_num), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.get_anchor_one_block(anchor.data_ptr(), layer_id, block_idx)

    def update_anchor_one_block(self, anchor: torch.Tensor, layer_id: int, block_idx: int):
        assert (anchor.dim() == 3 and anchor.size(0) == self.config.kv_head_num
                and anchor.size(1) == self.config.anchor_num and anchor.size(2) == self.config.head_dim
                and anchor.dtype == torch.float16 and anchor.is_contiguous()
                and anchor.device == torch.device("cpu")), \
            "anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                anchor.dim(), anchor.size(), anchor.dtype, anchor.is_contiguous(), anchor.device)
        assert (layer_id >= 0 and layer_id < self.config.layer_num), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.update_anchor_one_block(anchor.data_ptr(), layer_id, block_idx)

    def calc_anchor_all_layers(self, block_table: torch.Tensor, cache_seqlens: torch.Tensor):
        assert (block_table.dim() == 2 and block_table.size(0) == cache_seqlens.size(0)
                and block_table.dtype == torch.int and block_table.is_contiguous()
                and block_table.device == torch.device("cpu")), \
            "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                block_table.dim(), block_table.size(), block_table.dtype,
                block_table.is_contiguous(), block_table.device)
        assert (cache_seqlens.dim() == 1 and cache_seqlens.dtype == torch.int
                and cache_seqlens.is_contiguous() and cache_seqlens.device == torch.device("cpu")), \
            "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                cache_seqlens.dim(), cache_seqlens.size(), cache_seqlens.dtype,
                cache_seqlens.is_contiguous(), cache_seqlens.device)
        batch_size = block_table.size(0)
        max_block_num = block_table.size(1)
        return self.kvcache.calc_anchor_all_layers(
            block_table.data_ptr(), cache_seqlens.data_ptr(), batch_size, max_block_num)

    def clear_importance_all_layers(self, block_table: torch.Tensor, cache_seqlens: torch.Tensor):
        assert (block_table.dim() == 2 and block_table.size(0) == cache_seqlens.size(0)
                and block_table.dtype == torch.int and block_table.is_contiguous()
                and block_table.device == torch.device("cpu")), \
            "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                block_table.dim(), block_table.size(), block_table.dtype,
                block_table.is_contiguous(), block_table.device)
        assert (cache_seqlens.dim() == 1 and cache_seqlens.dtype == torch.int
                and cache_seqlens.is_contiguous() and cache_seqlens.device == torch.device("cpu")), \
            "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
                cache_seqlens.dim(), cache_seqlens.size(), cache_seqlens.dtype,
                cache_seqlens.is_contiguous(), cache_seqlens.device)
        batch_size = block_table.size(0)
        max_block_num = block_table.size(1)
        return self.kvcache.clear_importance_all_layers(
            block_table.data_ptr(), cache_seqlens.data_ptr(), batch_size, max_block_num)

    def get_cache_total_len(self):
        return self.kvcache.get_cache_total_len()

    def update_kvcache_q4(
        self, k_in: torch.Tensor, k_scales: torch.Tensor, v_in: torch.Tensor, v_scales: torch.Tensor,
        layer_id: int, seq_offset: int | None = None, seq_len: int | None = None,
        block_table: torch.Tensor | None = None,
    ):
        raise NotImplementedError

    def update_kvcache_fp16(
        self, k_in: torch.Tensor, v_in: torch.Tensor, layer_idx,
        block_table: torch.Tensor, max_block_num, past_len: torch.Tensor, q_len,
    ):
        batch_size = block_table.size(0)
        return self.kvcache.get_kvcache_fp16(
            k_in.data_ptr(), v_in.data_ptr(), layer_idx, block_table.data_ptr(),
            batch_size, max_block_num, past_len.data_ptr(), q_len)

    def get_kvcache_q4(
        self, k_in: torch.Tensor, k_scales: torch.Tensor, v_in: torch.Tensor, v_scales: torch.Tensor,
        layer_id: int, seq_offset: int | None = None, seq_len: int | None = None,
        block_table: torch.Tensor | None = None,
    ):
        raise NotImplementedError

    def get_kvcache_fp16(
        self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, layer_idx,
        block_table: torch.Tensor, max_block_num, past_len: torch.Tensor,
    ):
        batch_size = block_table.size(0)
        return self.kvcache.get_kvcache_fp16(
            k_in.data_ptr(), v_in.data_ptr(), layer_idx, block_table.data_ptr(),
            batch_size, max_block_num, past_len.data_ptr())

    def get_and_update_kvcache_fp16(
        self, k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor, layer_idx,
        block_table: torch.Tensor, max_block_num, past_len: torch.Tensor, q_len,
    ):
        batch_size = block_table.size(0)
        return self.kvcache.get_and_update_kvcache_fp16(
            k_cache_cpu.data_ptr(), v_cache_cpu.data_ptr(), layer_idx, block_table.data_ptr(),
            batch_size, max_block_num, past_len.data_ptr(), q_len)

    def update_importance(
        self, importance_cache: torch.Tensor, layer_idx, block_table: torch.Tensor,
        max_block_num, offset: torch.Tensor, width,
    ):
        batch_size = block_table.size(0)
        return self.kvcache.update_importance(
            importance_cache.data_ptr(), layer_idx, block_table.data_ptr(),
            batch_size, max_block_num, offset.data_ptr(), width)

    # attn_sparsity: ((bsz, q_len, q_head_num), dtype = torch.float32)
    def get_attn_sparsity(
        self, q_in: torch.Tensor, attn_sparsity: torch.Tensor, layer_idx: int,
        block_table: torch.Tensor, cache_seqlens: torch.Tensor,
        block_table_origin: torch.Tensor, cache_seqlens_origin: torch.Tensor,
        generate_token_idx: int = 0, topk: int | None = None, local: int | None = None,
    ):
        batch_size = block_table.size(0)
        max_block_num = block_table.size(1)
        max_block_num_origin = block_table_origin.size(1)
        q_len = q_in.size(1)
        if topk is None or local is None or topk + local >= max_block_num:
            topk = -1
            local = -1
        return self.kvcache.get_attn_sparsity(
            q_in.data_ptr(), attn_sparsity.data_ptr(), layer_idx, generate_token_idx,
            q_len, batch_size, max_block_num, block_table.data_ptr(), cache_seqlens.data_ptr(),
            block_table_origin.data_ptr(), cache_seqlens_origin.data_ptr(),
            max_block_num_origin, topk, local,
        )

    def attn_with_kvcache(
        self, q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor,
        output: torch.Tensor, attn_lse: torch.Tensor, layer_idx: int,
        block_table: torch.Tensor, cache_seqlens: torch.Tensor,
        generate_token_idx: int = 0, topk: int | None = None, local: int | None = None,
    ):
        batch_size = block_table.size(0)
        max_block_num = block_table.size(1)
        q_len = q_in.size(1)
        if topk is None or local is None or topk + local >= max_block_num:
            topk = -1
            local = -1
        return self.kvcache.attn_with_kvcache(
            q_in.data_ptr(), k_in.data_ptr(), v_in.data_ptr(), output.data_ptr(),
            attn_lse.data_ptr(), layer_idx, generate_token_idx, q_len, batch_size,
            max_block_num, block_table.data_ptr(), cache_seqlens.data_ptr(), topk, local,
        )

    def get_all_kvcache_one_layer(self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int):
        return self.kvcache.get_all_kvcache_one_layer(k_in.data_ptr(), v_in.data_ptr(), layer_id)

    def get_importance(self, importance: torch.Tensor, block_table: torch.Tensor):
        raise NotImplementedError

    def get_anchor(self, anchor: torch.Tensor, block_table: torch.Tensor):
        raise NotImplementedError


class CPUInfer:
    cpu_infer = None

    def __init__(self, cpu_infer: int = Config().cpu_infer):
        if CPUInfer.cpu_infer is None:
            CPUInfer.cpu_infer = cpuinfer_ext.CPUInfer(cpu_infer)

    cpuinfer = None

    def __init__(self, thread_num):
        CPUInfer.cpuinfer = cpuinfer_ext.CPUInfer(thread_num)

    def submit(self, task):
        CPUInfer.cpuinfer.submit(task)

    def submit_with_cuda_stream(self, current_cuda_stream, task):
        CPUInfer.cpuinfer.submit_with_cuda_stream(current_cuda_stream, task)

    def sync(self):
        CPUInfer.cpuinfer.sync()

    def sync_with_cuda_stream(self, current_cuda_stream):
        CPUInfer.cpuinfer.sync_with_cuda_stream(current_cuda_stream)

    def __getattribute__(self, __name: str) -> Any:
        return CPUInfer.cpu_infer.__getattribute__(__name)

    def __setattr__(self, __name: str, __value: Any) -> None:
        return CPUInfer.cpu_infer.__setattr__(__name, __value)
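The submit/sync pattern above is easiest to see end to end. The sketch below is illustrative only: it assumes the compiled `cpuinfer_ext` extension has been built under `ktransformers_ext/build`, and the tensor shapes simply follow the shape comments on `CPUInferKVCache.attn`; exact defaults and return values depend on the extension that is actually built.

import torch
from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache

cpu_infer = CPUInfer(32)  # CPU worker threads
kvcache = CPUInferKVCache(layer_num=32, kv_head_num=8, q_head_num=32,
                          head_dim=128, block_len=256, kv_type="FP16")

bsz, q_len = 1, 1
q = torch.zeros(bsz, q_len, 32, 128, dtype=torch.float16)    # (bsz, q_len, q_head_num, head_dim)
out = torch.zeros_like(q)                                    # attention output, same shape as q
lse = torch.zeros(bsz, q_len, 32, dtype=torch.float32)       # (bsz, q_len, q_head_num)
block_table = torch.arange(4, dtype=torch.int32).view(1, -1) # (bsz, max_block_num)
cache_seqlens = torch.tensor([1024], dtype=torch.int32)      # tokens already cached per sequence

# CPUInferKVCache methods build task objects; CPUInfer runs them on its CPU thread pool.
cpu_infer.submit(kvcache.attn(q, out, lse, layer_idx=0, generate_token_idx=0,
                              block_table=block_table, cache_seqlens=cache_seqlens))
cpu_infer.sync()  # block until the CPU attention task has written `out` and `lse`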
ktransformers/operators/dynamic_attention.py
new file mode 100644
#!/usr/bin/env python
# coding=utf-8
"""
Description :
Author : Jianwei Dong
Date : 2024-08-26 23:25:24
Version : 1.0.0
LastEditors : Jianwei Dong
LastEditTime : 2024-08-26 23:25:24
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import
torch
from
transformers
import
AutoConfig
import
sys
,
os
import
logging
logger
=
logging
.
getLogger
(
"dynamic_attention"
)
sys
.
path
.
append
(
os
.
path
.
dirname
(
__file__
)
+
"/../ktransformers_ext/cpu_backend"
)
from
ktransformers.operators.cpuinfer
import
CPUInfer
,
CPUInferKVCache
from
flash_attn
import
flash_attn_func
,
flash_attn_with_kvcache
import
math
import
json
class
DynamicScaledDotProductAttention
:
remaining_length
:
int
def
__init__
(
self
,
max_seq_len
:
int
,
block_size
:
int
,
config
:
AutoConfig
,
device
:
torch
.
device
,
local_windows_len
:
int
,
topk
:
int
,
threads_num
:
int
,
anchor_type
:
str
=
"DYNAMIC"
,
kv_type
:
str
=
"FP16"
,
dense_layer_num
:
int
=
0
,
anchor_num
:
int
=
1
,
block_selection_mode
:
str
=
"SHARED"
,
layer_step
:
int
=
1
,
token_step
:
int
=
1
,
preselect_block
:
bool
=
False
,
preselect_block_count
:
int
=
96
,
prefill_chunk_size
:
int
=
20480
,
use_attn_sparsity
:
bool
=
False
,
):
# assert anchor_num == 1
# assert anchor_type == "DYNAMIC"
self
.
remaining_length
=
0
valid_anchor_types
=
[
"DYNAMIC"
,
"FIXED"
,
"BLOCK_MEAN"
,
"BLOCK_MAX"
,
"QUEST"
]
assert
anchor_type
in
valid_anchor_types
if
anchor_type
==
"QUEST"
:
assert
anchor_num
==
2
elif
anchor_type
!=
"FIXED"
and
anchor_type
!=
"DYNAMIC"
:
assert
anchor_num
==
1
valid_kv_types
=
[
"FP16"
,
"FP32"
,
"Q4_0"
,
"Q8_0"
]
assert
kv_type
in
valid_kv_types
if
kv_type
!=
"FP16"
and
kv_type
!=
"FP32"
:
assert
block_size
%
32
==
0
valid_block_selection_modes
=
[
"SHARED"
,
"SEPARATE"
]
# individual
assert
block_selection_mode
in
valid_block_selection_modes
self
.
max_seq_len
=
max_seq_len
self
.
block_num
=
max_seq_len
//
block_size
self
.
block_size
=
block_size
self
.
anchor_type
=
anchor_type
self
.
kv_type
=
kv_type
self
.
anchor_num
=
anchor_num
self
.
threads_num
=
threads_num
self
.
layer_step
=
layer_step
self
.
token_step
=
token_step
self
.
preselect_block
=
preselect_block
self
.
preselect_block_count
=
preselect_block_count
self
.
block_selection_mode
=
block_selection_mode
self
.
use_attn_sparsity
=
use_attn_sparsity
# model config
self
.
kv_head_num
=
config
.
num_key_value_heads
self
.
q_head_num
=
config
.
num_attention_heads
self
.
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
self
.
layer_num
=
config
.
num_hidden_layers
self
.
device
=
device
self
.
local_windows_len
=
local_windows_len
self
.
local_block_num
=
self
.
local_windows_len
//
self
.
block_size
+
1
self
.
prefill_chunk_size
=
prefill_chunk_size
self
.
topk
=
topk
self
.
dense_layer_num
=
dense_layer_num
# self.dense_layer_num = 32
self
.
cache_key_states
=
torch
.
zeros
(
(
self
.
block_num
,
block_size
,
self
.
kv_head_num
,
self
.
head_dim
),
device
=
device
,
dtype
=
torch
.
float16
,
)
self
.
cache_value_states
=
torch
.
zeros
(
(
self
.
block_num
,
block_size
,
self
.
kv_head_num
,
self
.
head_dim
),
device
=
device
,
dtype
=
torch
.
float16
,
)
# [max_num_block, block_size, head_num]
self
.
cache_importance
=
torch
.
zeros
(
(
self
.
block_num
,
block_size
,
self
.
q_head_num
),
device
=
device
,
dtype
=
torch
.
float16
,
)
# key_states: [bsz, q_len, kv_head_num, head_dim]
# value_states: [bsz, q_len, kv_head_num, head_dim]
# query_states: [bsz, q_len, q_head_num, head_dim]
self
.
q_in_cpu
=
torch
.
zeros
(
(
1
,
1
,
self
.
q_head_num
,
self
.
head_dim
),
device
=
"cpu"
,
dtype
=
torch
.
float16
,
pin_memory
=
True
,
)
self
.
k_in_cpu
=
torch
.
zeros
(
(
1
,
1
,
self
.
kv_head_num
,
self
.
head_dim
),
device
=
"cpu"
,
dtype
=
torch
.
float16
,
pin_memory
=
True
,
)
self
.
v_in_cpu
=
torch
.
zeros
(
(
1
,
1
,
self
.
kv_head_num
,
self
.
head_dim
),
device
=
"cpu"
,
dtype
=
torch
.
float16
,
pin_memory
=
True
,
)
self
.
cache_seqlens_cpu
=
torch
.
empty
(
(
1
,),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
True
)
self
.
cache_seqlens_cuda
=
torch
.
empty
((
1
,),
device
=
device
,
dtype
=
torch
.
int32
)
self
.
prefix_block_table
=
torch
.
arange
(
self
.
block_num
,
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
True
).
view
(
1
,
-
1
)
self
.
block_table_cpu
=
torch
.
arange
(
self
.
block_num
,
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
True
).
view
(
1
,
-
1
)
# assert (
# self.local_windows_len // self.block_size + 1 + self.preselect_block_count
# <= self.block_num
# )
self
.
output_cpu
=
torch
.
empty
(
(
1
,
1
,
self
.
q_head_num
,
self
.
head_dim
),
device
=
"cpu"
,
dtype
=
torch
.
float16
,
pin_memory
=
True
,
)
self
.
lse_cpu
=
torch
.
empty
(
(
1
,
1
,
self
.
q_head_num
),
device
=
"cpu"
,
dtype
=
torch
.
float32
,
pin_memory
=
True
)
self
.
output_cuda
=
torch
.
empty
(
(
1
,
1
,
self
.
q_head_num
,
self
.
head_dim
),
device
=
device
,
dtype
=
torch
.
float16
)
self
.
attn_sparsity
=
torch
.
zeros
(
(
1
,
1
,
self
.
q_head_num
),
device
=
"cpu"
,
dtype
=
torch
.
float32
,
pin_memory
=
True
)
if
preselect_block
==
True
:
self
.
preselect_block_table
=
torch
.
zeros
(
self
.
layer_num
,
self
.
preselect_block_count
,
device
=
device
,
dtype
=
torch
.
int32
,
)
self
.
preselect_block_num
=
0
# block_num before preselect
self
.
evict_tokens
=
0
self
.
cpu_infer
=
CPUInfer
(
threads_num
)
self
.
local_thread
=
CPUInferKVCache
(
self
.
layer_num
,
self
.
kv_head_num
,
self
.
q_head_num
,
self
.
head_dim
,
self
.
block_size
,
anchor_num
=
self
.
anchor_num
,
anchor_type
=
anchor_type
,
kv_type
=
self
.
kv_type
,
retrieval_type
=
self
.
block_selection_mode
,
layer_step
=
self
.
layer_step
,
token_step
=
self
.
token_step
,
layer_offset
=
self
.
dense_layer_num
%
self
.
layer_step
,
max_batch_size
=
1
,
max_block_num
=
self
.
block_num
,
max_thread_num
=
self
.
threads_num
,
)
print
(
f
"local_windows_len:
{
local_windows_len
}
, topk:
{
topk
}
, dense_layer_num:
{
dense_layer_num
}
, kv_type:
{
self
.
kv_type
}
, anchor_type:
{
self
.
anchor_type
}
, preselect_block:
{
self
.
preselect_block
}
, preselect_block_count:
{
self
.
preselect_block_count
}
, token_step:
{
self
.
token_step
}
, layer_step:
{
self
.
layer_step
}
"
)
self
.
shape_mask
=
(
self
.
q_head_num
,
self
.
block_size
,
self
.
block_size
,
)
mask
=
torch
.
zeros
(
self
.
shape_mask
,
dtype
=
torch
.
uint8
,
device
=
device
).
contiguous
()
elm_idx
=
torch
.
arange
(
self
.
block_size
,
device
=
device
)
for
i
in
range
(
mask
.
size
(
-
2
)):
idx
=
i
+
mask
.
size
(
-
1
)
-
mask
.
size
(
-
2
)
-
elm_idx
idx
=
idx
[
idx
>=
0
]
mask
[...,
i
,
idx
]
=
1
self
.
tril_mask
=
mask
self
.
triu_mask
=
mask
^
1
self
.
generate_token_idx
=
0
def
get_attn_score_one_block
(
self
,
batch_idx
:
int
,
max_block_num
:
int
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offset
:
int
,
width
:
int
,
mask_mode
:
str
|
None
=
None
,
use_softmax
:
bool
=
True
,
):
n_rep
=
self
.
q_head_num
//
self
.
kv_head_num
importance
=
self
.
cache_importance
.
view
(
-
1
,
self
.
q_head_num
)
importance
=
importance
.
narrow
(
0
,
batch_idx
*
max_block_num
+
offset
,
width
)
n_gqa_
=
self
.
q_head_num
//
self
.
kv_head_num
for
head_idx
in
range
(
self
.
q_head_num
):
key_item
=
key
[...,
head_idx
//
n_gqa_
,
:].
view
(
key
.
size
(
0
),
-
1
)
qk
=
torch
.
einsum
(
"qd,kd->qk"
,
query
[:,
head_idx
,:],
key_item
)
# (num_attention_heads, len_q, len_k)
if
mask_mode
==
"tril"
:
mask
=
self
.
tril_mask
mask
=
mask
[
0
,
-
qk
.
size
(
-
2
)
:,
-
qk
.
size
(
-
1
)
:]
qk
=
qk
*
mask
elif
mask_mode
==
"triu"
:
mask
=
self
.
triu_mask
mask
=
mask
[
0
,
-
qk
.
size
(
-
2
)
:,
-
qk
.
size
(
-
1
)
:]
qk
=
qk
*
mask
if
use_softmax
:
qk
=
torch
.
nn
.
functional
.
softmax
(
qk
/
math
.
sqrt
(
self
.
head_dim
),
dim
=-
1
,
dtype
=
torch
.
float32
).
to
(
torch
.
float16
)
qk
=
torch
.
sum
(
qk
,
dim
=-
2
)
importance
[...,
head_idx
]
+=
qk
def
get_preselect_block_table_and_attn_score
(
self
,
layer_idx
:
int
,
batch_size
:
int
,
offset
:
torch
.
Tensor
,
width
:
int
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
union_with_last_layer
:
bool
=
True
,
):
max_seqs_len
=
offset
.
max
().
item
()
+
width
max_block_num
=
(
max_seqs_len
+
self
.
block_size
-
1
)
//
self
.
block_size
for
batch_idx
in
range
(
batch_size
):
query_cur
=
query
[
batch_idx
][
-
128
:]
self
.
get_attn_score_one_block
(
batch_idx
,
max_block_num
,
query_cur
,
key
[
batch_idx
][:
offset
[
batch_idx
].
item
()
+
width
],
0
,
offset
[
batch_idx
].
item
()
+
width
,
mask_mode
=
None
,
)
if
self
.
preselect_block
:
self
.
prefill_block_num
=
max
(
0
,
max_block_num
-
self
.
local_windows_len
//
self
.
block_size
)
self
.
evict_tokens
=
(
max
(
self
.
prefill_block_num
-
self
.
preselect_block_count
,
0
)
*
self
.
block_size
)
if
self
.
prefill_block_num
!=
0
:
importance_cache
=
self
.
cache_importance
.
narrow
(
0
,
0
,
self
.
prefill_block_num
*
batch_size
).
view
(
batch_size
,
self
.
prefill_block_num
,
self
.
block_size
,
self
.
q_head_num
)
importance_r
=
importance_cache
[:,
1
:,
:
self
.
block_size
//
4
]
pad_r
=
torch
.
zeros_like
(
importance_r
[:,
:
1
])
importance_r
=
torch
.
cat
((
importance_r
,
pad_r
),
dim
=
1
)
importance_l
=
importance_cache
[:,
:
-
1
,
-
self
.
block_size
//
4
:]
pad_l
=
torch
.
zeros_like
(
importance_l
[:,
:
1
])
importance_l
=
torch
.
cat
((
pad_l
,
importance_l
),
dim
=
1
)
importance
=
torch
.
cat
(
(
importance_l
,
importance_cache
,
importance_r
),
dim
=
2
)
importance
=
importance
.
mean
(
dim
=-
1
)
importance
=
importance
.
mean
(
dim
=-
1
)
# importance: (batch_size, max_block_num)
topk
=
min
(
self
.
preselect_block_count
,
self
.
prefill_block_num
)
values
,
indices
=
torch
.
topk
(
importance
,
k
=
topk
,
dim
=
1
,
)
self
.
preselect_block_table
[
layer_idx
:
layer_idx
+
1
,
:
topk
,
].
copy_
(
indices
)
if
union_with_last_layer
and
layer_idx
==
31
:
for
tmp_layer_idx
in
range
(
self
.
layer_num
-
1
):
for
i
in
range
(
1
,
min
(
topk
,
6
)):
x
=
self
.
preselect_block_table
[
-
1
,
i
]
if
x
not
in
self
.
preselect_block_table
[
tmp_layer_idx
]:
self
.
preselect_block_table
[
tmp_layer_idx
,
topk
-
i
]
=
x
if
self
.
anchor_type
==
"DYNAMIC"
:
importance_cache
=
self
.
cache_importance
.
narrow
(
0
,
0
,
max_block_num
*
batch_size
).
view
(
batch_size
,
max_block_num
*
self
.
block_size
,
self
.
q_head_num
)
importance_cache_cpu
=
torch
.
empty_like
(
importance_cache
,
device
=
"cpu"
,
pin_memory
=
True
)
importance_cache_cpu
.
copy_
(
importance_cache
)
block_table_cpu
=
self
.
prefix_block_table
[:,
:
max_block_num
].
to
(
"cpu"
)
offset_cpu
=
offset
.
contiguous
().
to
(
"cpu"
)
self
.
cpu_infer
.
submit
(
self
.
local_thread
.
update_importance
(
importance_cache_cpu
,
layer_idx
,
block_table_cpu
,
max_block_num
,
offset_cpu
,
width
,
)
)
self
.
cpu_infer
.
sync
()
importance_cache
=
self
.
cache_importance
.
narrow
(
0
,
0
,
max_block_num
*
batch_size
).
view
(
batch_size
,
max_block_num
*
self
.
block_size
,
self
.
q_head_num
)
importance_cache
.
zero_
()
# key: [bsz, past_len, head_num, head_dim] float16
# query: [bsz, q_len, q_head_num, head_dim] float16
def
get_attn_score
(
self
,
layer_idx
:
int
,
batch_size
:
int
,
offset
:
torch
.
Tensor
,
width
:
int
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
):
max_seqs_len
=
offset
.
max
().
item
()
+
width
max_block_num
=
(
max_seqs_len
+
self
.
block_size
-
1
)
//
self
.
block_size
for
batch_idx
in
range
(
batch_size
):
for
idx
in
range
(
width
//
self
.
block_size
):
offset_cur
=
idx
*
self
.
block_size
query_cur
=
query
[
batch_idx
,
offset_cur
:
offset_cur
+
self
.
block_size
]
self
.
get_attn_score_one_block
(
batch_idx
,
max_block_num
,
query_cur
,
key
[
batch_idx
,
offset
[
batch_idx
]
+
offset_cur
:
offset
[
batch_idx
]
+
offset_cur
+
self
.
block_size
,
],
offset
[
batch_idx
].
item
()
+
offset_cur
,
self
.
block_size
,
mask_mode
=
"tril"
,
use_softmax
=
False
,
)
offset_key
=
(
offset
[
batch_idx
].
item
()
+
idx
*
self
.
block_size
-
self
.
local_windows_len
)
if
offset_key
>=
0
:
self
.
get_attn_score_one_block
(
batch_idx
,
max_block_num
,
query_cur
,
key
[
batch_idx
,
offset_key
:
offset_key
+
self
.
block_size
],
offset_key
,
self
.
block_size
,
mask_mode
=
"triu"
,
use_softmax
=
False
,
)
offset_key
=
max
(
0
,
offset_key
+
self
.
block_size
)
width_key
=
(
offset
[
batch_idx
].
item
()
+
idx
*
self
.
block_size
-
offset_key
)
if
width_key
>
0
:
self
.
get_attn_score_one_block
(
batch_idx
,
max_block_num
,
query_cur
,
key
[
batch_idx
,
offset_key
:
offset_key
+
width_key
],
offset_key
,
width_key
,
mask_mode
=
None
,
use_softmax
=
False
,
)
importance_cache
=
self
.
cache_importance
.
narrow
(
0
,
0
,
max_block_num
*
batch_size
).
view
(
batch_size
,
max_block_num
*
self
.
block_size
,
self
.
q_head_num
)
importance_cache_cpu
=
torch
.
empty_like
(
importance_cache
,
device
=
"cpu"
,
pin_memory
=
True
)
importance_cache_cpu
.
copy_
(
importance_cache
)
block_table_cpu
=
self
.
prefix_block_table
[:,
:
max_block_num
].
to
(
"cpu"
)
offset_cpu
=
offset
.
contiguous
().
to
(
"cpu"
)
self
.
cpu_infer
.
submit
(
self
.
local_thread
.
update_importance
(
importance_cache_cpu
,
layer_idx
,
block_table_cpu
,
max_block_num
,
offset_cpu
,
width
,
)
)
self
.
cpu_infer
.
sync
()
importance_cache
.
zero_
()
# key: [bsz, q_len, head_num, head_dim] float16
# value: [bsz, q_len, head_num, head_dim] float16
def
swap_in_and_swap_out
(
self
,
layer_idx
,
past_len
,
q_len
,
key
,
value
):
batch_size
=
1
max_seqs_len
=
past_len
.
max
().
item
()
+
q_len
max_block_num
=
(
max_seqs_len
+
self
.
block_size
-
1
)
//
self
.
block_size
k_cache
=
self
.
cache_key_states
.
narrow
(
0
,
0
,
max_block_num
*
batch_size
).
view
(
batch_size
,
max_block_num
*
self
.
block_size
,
self
.
kv_head_num
,
self
.
head_dim
)
v_cache
=
self
.
cache_value_states
.
narrow
(
0
,
0
,
max_block_num
*
batch_size
).
view
(
batch_size
,
max_block_num
*
self
.
block_size
,
self
.
kv_head_num
,
self
.
head_dim
)
for
batch_idx
in
range
(
batch_size
):
offset
=
past_len
[
batch_idx
]
width
=
q_len
k_cache
[
batch_idx
][
offset
:
offset
+
width
].
copy_
(
key
[
batch_idx
].
view
(
-
1
,
self
.
kv_head_num
,
self
.
head_dim
)
)
v_cache
[
batch_idx
][
offset
:
offset
+
width
].
copy_
(
value
[
batch_idx
].
view
(
-
1
,
self
.
kv_head_num
,
self
.
head_dim
)
)
k_cache_cpu
=
torch
.
empty_like
(
k_cache
,
device
=
"cpu"
,
pin_memory
=
True
)
v_cache_cpu
=
torch
.
empty_like
(
v_cache
,
device
=
"cpu"
,
pin_memory
=
True
)
k_cache_cpu
.
copy_
(
k_cache
)
v_cache_cpu
.
copy_
(
v_cache
)
cur_block_num
=
(
q_len
+
past_len
[
0
].
item
()
+
self
.
block_size
-
1
)
//
self
.
block_size
block_table_cpu
=
self
.
prefix_block_table
[:,
:
cur_block_num
].
to
(
"cpu"
)
past_len_cpu
=
past_len
.
contiguous
().
to
(
"cpu"
)
self
.
cpu_infer
.
submit
(
self
.
local_thread
.
get_and_update_kvcache_fp16
(
k_cache_cpu
,
v_cache_cpu
,
layer_idx
,
block_table_cpu
,
max_block_num
,
past_len_cpu
,
q_len
,
)
)
self
.
cpu_infer
.
sync
()
k_cache
.
copy_
(
k_cache_cpu
)
v_cache
.
copy_
(
v_cache_cpu
)
return
k_cache
,
v_cache
def
calc_anchor
(
self
,
cache_seqlens
:
int
):
cur_block_num
=
(
cache_seqlens
+
self
.
block_size
-
1
)
//
self
.
block_size
block_table_cpu
=
self
.
prefix_block_table
[:,
:
cur_block_num
].
to
(
"cpu"
)
cache_seqlens_cpu
=
torch
.
tensor
(
[
cache_seqlens
],
device
=
"cpu"
,
dtype
=
torch
.
int32
)
self
.
cpu_infer
.
submit
(
self
.
local_thread
.
calc_anchor_all_layers
(
block_table_cpu
,
cache_seqlens_cpu
,
)
)
self
.
cpu_infer
.
sync
()
def
clear_importance
(
self
,
cache_seqlens
:
int
):
print
(
f
"clear importance:
{
cache_seqlens
}
"
)
cur_block_num
=
(
cache_seqlens
+
self
.
block_size
-
1
)
//
self
.
block_size
block_table_cpu
=
self
.
prefix_block_table
[:,
:
cur_block_num
].
to
(
"cpu"
)
cache_seqlens_cpu
=
torch
.
tensor
(
[
cache_seqlens
],
device
=
"cpu"
,
dtype
=
torch
.
int32
)
self
.
cpu_infer
.
submit
(
self
.
local_thread
.
clear_importance_all_layers
(
block_table_cpu
,
cache_seqlens_cpu
,
)
)
self
.
cpu_infer
.
sync
()
def
clear_kvcache
(
self
,
cache_seqlens
:
int
):
cur_block_num
=
(
cache_seqlens
+
self
.
block_size
-
1
)
//
self
.
block_size
block_table_cpu
=
self
.
prefix_block_table
[:,
:
cur_block_num
].
to
(
"cpu"
)
cache_seqlens_cpu
=
torch
.
tensor
(
[
cache_seqlens
],
device
=
"cpu"
,
dtype
=
torch
.
int32
)
self
.
cpu_infer
.
submit
(
self
.
local_thread
.
clear_kvcache_all_layers
(
block_table_cpu
,
cache_seqlens_cpu
,
)
)
self
.
cpu_infer
.
sync
()
def
get_attn_sparsity
(
self
,
q_in
:
torch
.
Tensor
,
layer_idx
:
int
,
block_table
:
torch
.
Tensor
,
cache_seqlens
:
torch
.
Tensor
,
block_table_origin
:
torch
.
Tensor
,
cache_seqlens_origin
:
torch
.
Tensor
,
generate_token_idx
:
int
=
0
,
topk
:
int
|
None
=
None
,
local
:
int
|
None
=
None
,
output_path
:
str
=
"./attn_sparsity.json"
,
):
self
.
attn_sparsity
.
zero_
()
self
.
pcinfer
.
submit
(
self
.
local_thread
.
get_attn_sparsity
(
q_in
,
self
.
attn_sparsity
,
layer_idx
,
block_table
,
cache_seqlens
,
block_table_origin
,
cache_seqlens_origin
,
generate_token_idx
,
topk
,
local
,
)
)
self
.
cpu_infer
.
sync
()
with
open
(
output_path
,
"a"
)
as
file
:
for
head_idx
in
range
(
self
.
q_head_num
):
sparsity
=
self
.
attn_sparsity
[
0
][
0
][
head_idx
].
item
()
json_obj
=
{
"token_idx"
:
generate_token_idx
,
"layer_idx"
:
layer_idx
,
"head_idx"
:
head_idx
,
"sparsity"
:
sparsity
,
}
json
.
dump
(
json_obj
,
file
)
file
.
write
(
"
\n
"
)
def
apply
(
self
,
layer_idx
:
int
,
bsz
:
int
,
past_len
:
int
,
query_states
:
torch
.
Tensor
,
key_states
:
torch
.
Tensor
,
value_states
:
torch
.
Tensor
,
mode
:
str
=
"prefill"
,
generate_token_idx
:
int
=
-
1
,
):
# key_states: [bsz, q_len, kv_head_num, head_dim]
# value_states: [bsz, q_len, kv_head_num, head_dim]
# query_states: [bsz, q_len, q_head_num, head_dim]
assert
query_states
.
dtype
==
torch
.
float16
assert
key_states
.
dtype
==
torch
.
float16
assert
value_states
.
dtype
==
torch
.
float16
assert
key_states
.
size
(
2
)
==
self
.
kv_head_num
assert
value_states
.
size
(
2
)
==
self
.
kv_head_num
assert
query_states
.
size
(
2
)
==
self
.
q_head_num
q_len
=
query_states
.
size
(
1
)
batch_size
=
query_states
.
size
(
0
)
self
.
cache_seqlens_cuda
.
fill_
(
past_len
)
last_chunk
=
False
if
self
.
remaining_length
<=
self
.
prefill_chunk_size
and
q_len
!=
1
:
last_chunk
=
True
device
=
query_states
.
device
if
layer_idx
==
0
:
if
q_len
==
1
:
self
.
generate_token_idx
+=
1
elif
last_chunk
:
self
.
generate_token_idx
=
-
1
if
mode
==
"prefill"
:
key
,
value
=
self
.
swap_in_and_swap_out
(
layer_idx
,
self
.
cache_seqlens_cuda
,
q_len
,
key_states
,
value_states
,
)
if
last_chunk
and
(
self
.
anchor_type
==
"DYNAMIC"
or
self
.
preselect_block
):
self
.
get_preselect_block_table_and_attn_score
(
layer_idx
,
bsz
,
self
.
cache_seqlens_cuda
,
q_len
,
                query_states,
                key,
            )
            output = flash_attn_with_kvcache(
                q=query_states, k_cache=key, v_cache=value,
                cache_seqlens=self.cache_seqlens_cuda + q_len, causal=True,
            )
            return output.transpose(1, 2)
        elif mode == "generate":
            assert self.generate_token_idx >= 0
            self.q_in_cpu.copy_(query_states, non_blocking=True)
            self.k_in_cpu.copy_(key_states, non_blocking=True)
            self.v_in_cpu.copy_(value_states, non_blocking=True)
            self.cache_seqlens_cpu.copy_(self.cache_seqlens_cuda, non_blocking=True)
            # print(layer_idx)
            if layer_idx < self.dense_layer_num:
                self.block_table_cpu.copy_(self.prefix_block_table, non_blocking=True)
                self.cpu_infer.submit_with_cuda_stream(
                    torch.cuda.current_stream("cuda").cuda_stream,
                    self.local_thread.attn_with_kvcache(
                        q_in=self.q_in_cpu, k_in=self.k_in_cpu, v_in=self.v_in_cpu,
                        output=self.output_cpu, attn_lse=self.lse_cpu, layer_idx=layer_idx,
                        block_table=self.block_table_cpu, cache_seqlens=self.cache_seqlens_cpu,
                    ),
                )
            else:
                if self.preselect_block:
                    self.cache_seqlens_cpu.copy_(self.cache_seqlens_cuda - self.evict_tokens, non_blocking=True)
                    if self.preselect_block_count < self.prefill_block_num:
                        self.block_table_cpu[:, : self.preselect_block_count].copy_(
                            self.preselect_block_table[layer_idx : layer_idx + 1], non_blocking=True,
                        )
                        self.block_table_cpu[
                            :, self.preselect_block_count : self.preselect_block_count + self.local_block_num,
                        ].copy_(
                            self.prefix_block_table[
                                :, self.prefill_block_num : self.prefill_block_num + self.local_block_num,
                            ],
                            non_blocking=True,
                        )
                    # print("submit_with_cuda_stream")
                    self.cpu_infer.submit_with_cuda_stream(
                        torch.cuda.current_stream("cuda").cuda_stream,
                        self.local_thread.attn_with_kvcache(
                            q_in=self.q_in_cpu, k_in=self.k_in_cpu, v_in=self.v_in_cpu,
                            output=self.output_cpu, attn_lse=self.lse_cpu, layer_idx=layer_idx,
                            generate_token_idx=self.generate_token_idx,
                            block_table=self.block_table_cpu, cache_seqlens=self.cache_seqlens_cpu,
                            topk=(self.topk if self.topk <= self.preselect_block_count else None),
                            local=self.local_windows_len // self.block_size,
                        ),
                    )
                    # print("submit_with_cuda_stream enqueue\n")
                else:
                    self.block_table_cpu.copy_(self.prefix_block_table, non_blocking=True)
                    self.cpu_infer.submit_with_cuda_stream(
                        torch.cuda.current_stream("cuda").cuda_stream,
                        self.local_thread.attn_with_kvcache(
                            q_in=self.q_in_cpu, k_in=self.k_in_cpu, v_in=self.v_in_cpu,
                            output=self.output_cpu, attn_lse=self.lse_cpu, layer_idx=layer_idx,
                            generate_token_idx=self.generate_token_idx,
                            block_table=self.block_table_cpu, cache_seqlens=self.cache_seqlens_cpu,
                            topk=self.topk, local=self.local_windows_len // self.block_size,
                        ),
                    )
            self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream("cuda").cuda_stream)
            # print("submit_with_cuda_stream finished\n")
            self.output_cuda.copy_(self.output_cpu, non_blocking=True)
            return self.output_cuda.transpose(1, 2)

    def save(self, path: str, length: int):
        cur_block_num = (length + self.block_size - 1) // self.block_size
        block_table_cpu = self.prefix_block_table[0, :cur_block_num].to("cpu")
        cache_seqlens_cpu = torch.tensor([length], device="cpu", dtype=torch.int32)
        self.cpu_infer.submit(
            self.local_thread.dump_kvcache(block_table_cpu, cache_seqlens_cpu, path)
        )
        self.cpu_infer.sync()

    def load(self, path: str, length: int):
        self.cpu_infer.submit(self.local_thread.load_kvcache(path))
        self.cpu_infer.sync()
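For orientation, a minimal sketch of how the dump/load path above might be driven. The `attn` handle and the path/length values are hypothetical placeholders; in practice the operator is constructed by the injection framework rather than by hand.

# Hedged sketch: persisting and restoring the CPU-side KV cache around a long prompt.
# `attn` stands for an already-injected attention operator exposing the save()/load()
# methods reconstructed above; the path and length values are placeholders.
import os

def checkpoint_kvcache(attn, path: str, prompt_len: int):
    # Dump every cache block that covers the first `prompt_len` tokens to disk.
    attn.save(path, prompt_len)

def restore_kvcache(attn, path: str, prompt_len: int):
    # Reload the dumped blocks so a later request can skip re-prefilling the same prompt.
    if os.path.exists(path):
        attn.load(path, prompt_len)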
ktransformers/operators/experts.py  View file @ 4d1d561d
...
@@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang, chenht2022
Date : 2024-07-25 11:25:24
Version : 0.1.0
LastEditors : Azure
-LastEditTime : 2024-08-15 02:36:29
+LastEditTime : 2024-08-27 03:50:23
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
...
@@ -436,7 +436,7 @@ class KExpertsTorch(KExpertsBase):
        final_hidden_states.index_add_(0, top_x, current_hidden_states)
-        return final_hidden_states.to(org_dtype, device=org_device)
+        return final_hidden_states.to(dtype=org_dtype, device=org_device)

EXPERTS_MAP = {
    "KExpertsCPU": KExpertsCPU,
...
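The experts.py hunk swaps a positional dtype for explicit keywords. A small hedged illustration of the keyword form when both dtype and device are supplied (the behaviour note in the comment is my reading of why the change was made, not something stated in the commit):

import torch

x = torch.ones(4, dtype=torch.float32)
# Passing a dtype positionally while also giving `device=` can trip Tensor.to()'s
# overload resolution in some PyTorch versions; spelling both as keywords selects the
# overload that accepts dtype and device together, matching the fixed line above.
y = x.to(dtype=torch.float16, device="cpu")
print(y.dtype, y.device)  # torch.float16 cpu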
ktransformers/operators/models.py  View file @ 4d1d561d

#!/usr/bin/env python
# coding=utf-8
-'''
+"""
Description :
Author : Azure-Tang
Date : 2024-07-25 11:25:24
Version : 1.0.0
LastEditors : Azure
-LastEditTime : 2024-08-14 14:53:05
+LastEditTime : 2024-08-27 07:29:04
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
-'''
+"""
import inspect
import math
...
@@ -19,7 +19,10 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from ktransformers.operators.dynamic_attention import DynamicScaledDotProductAttention
+from ktransformers.server.config.config import Config
+import os
+import yaml
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import (
...
@@ -40,19 +43,35 @@ from transformers.utils import (
    logging,
    replace_return_docstrings,
)
-from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock, Qwen2MoeMLP, Qwen2MoeDecoderLayer
-from ktransformers.models.modeling_deepseek import BaseModelOutputWithPast, DeepseekV2DecoderLayer, DeepseekV2MoE
+from ktransformers.models.modeling_qwen2_moe import (
+    Qwen2MoeSparseMoeBlock,
+    Qwen2MoeMLP,
+    Qwen2MoeDecoderLayer,
+)
+from ktransformers.models.modeling_deepseek import (
+    BaseModelOutputWithPast,
+    DeepseekV2DecoderLayer,
+    DeepseekV2MoE,
+)
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
+from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.utils import InferenceState
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.configuration_utils import PretrainedConfig
+from ktransformers.models.modeling_llama import (
+    LlamaDecoderLayer,
+    LlamaRMSNorm,
+    LlamaRotaryEmbedding,
+)

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)

logger = logging.get_logger(__name__)
...
@@ -151,6 +170,7 @@ QWEN2MOE_INPUTS_DOCSTRING = r"""
            the complete sequence length.
"""

@add_start_docstrings(
    "The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.",
    QWEN2MOE_START_DOCSTRING,
...
@@ -162,18 +182,21 @@ class KQwen2MoeModel(BaseInjectedModule):
    Args:
        config: Qwen2MoeConfig
    """

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        device: str = "cuda",
        per_layer_prefill_intput_threshold: int = 30000,  # if None, no per-layer prefill
        transfer_map: dict = None,
        **kwargs,
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
        self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
        self.transfer_map = transfer_map
        self.stream_device_map = dict()
...
@@ -192,29 +215,47 @@ class KQwen2MoeModel(BaseInjectedModule):
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        per_layer_prefill_intput_threshold: (int | None) = None,  # if None or 0, close per-layer prefill
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        # print(f'Total length of input_ids: {input_ids.size(1)}, {input_ids.size()}')
        if per_layer_prefill_intput_threshold is None:
            per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
        per_layer_prefill_flag = False
        seq_lenth = (inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1))
        if (per_layer_prefill_intput_threshold and per_layer_prefill_intput_threshold < seq_lenth):
            per_layer_prefill_flag = True
            for layer in self.layers:
                self.load_layer_to(layer, InferenceState.UNLOAD)
        else:
            pass
        output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
        output_router_logits = (output_router_logits if output_router_logits is not None else self.config.output_router_logits)
        output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
...
@@ -243,15 +284,23 @@ class KQwen2MoeModel(BaseInjectedModule):
            inputs_embeds = inputs_embeds.to("cuda")

        if cache_position is None:
            past_seen_tokens = (past_key_values.get_seq_length() if past_key_values is not None else 0)
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
        hidden_states = inputs_embeds
...
@@ -263,7 +312,7 @@ class KQwen2MoeModel(BaseInjectedModule):
        next_decoder_cache = None

        for i, decoder_layer in enumerate(self.layers):
            if self.transfer_map is not None and i in self.transfer_map:
                prev_stream = torch.cuda.current_stream()
                cur_device = self.transfer_map[i]
                if cur_device not in self.stream_device_map:
...
@@ -271,11 +320,25 @@ class KQwen2MoeModel(BaseInjectedModule):
                torch.cuda.set_device(cur_device)
                self.stream_device_map[cur_device].wait_stream(prev_stream)
                torch.cuda.set_stream(self.stream_device_map[cur_device])
                hidden_states = hidden_states.to(self.transfer_map[i], non_blocking=True)
                causal_mask = (causal_mask.to(self.transfer_map[i], non_blocking=True) if causal_mask is not None else None)
                position_ids = (position_ids.to(self.transfer_map[i], non_blocking=True) if position_ids is not None else None)
                cache_position = (cache_position.to(self.transfer_map[i], non_blocking=True) if cache_position is not None else None)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
...
@@ -323,7 +386,6 @@ class KQwen2MoeModel(BaseInjectedModule):
        hidden_states = self.norm(hidden_states)
        if per_layer_prefill_flag:
            per_layer_prefill_flag = False
            for layer in self.layers:
...
@@ -333,12 +395,22 @@ class KQwen2MoeModel(BaseInjectedModule):
        next_cache = None
        if use_cache:
            next_cache = (next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                if v is not None
            )
        return MoeModelOutputWithPast(
...
@@ -349,11 +421,13 @@ class KQwen2MoeModel(BaseInjectedModule):
            router_logits=all_router_logits,
        )

    def load_layer_to(self, layer: Qwen2MoeDecoderLayer, target: InferenceState):
        assert isinstance(layer, Qwen2MoeDecoderLayer), "module should be nn.ModuleList of decoder layers"
        # TODO Support restore to original device, not only cuda
        device = "cpu" if target == InferenceState.UNLOAD else "cuda"
        # attn
        layer.self_attn.q_proj.set_inference_mode(target)
...
@@ -458,18 +532,21 @@ class KDeepseekV2Model(BaseInjectedModule):
    Args:
        config: DeepseekV2Config
    """

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        device: str = "cuda",
        per_layer_prefill_intput_threshold: int = 30000,  # if None, no per-layer prefill
        transfer_map: dict = None,
        **kwargs,
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
        self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
        self.transfer_map = transfer_map
        self.stream_device_map = dict()
...
@@ -487,15 +564,23 @@ class KDeepseekV2Model(BaseInjectedModule):
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        per_layer_prefill_intput_threshold: (int | None) = None,  # if None, no per-layer prefill
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if per_layer_prefill_intput_threshold is None:
            per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
        per_layer_prefill_flag = False
        seq_lenth = (inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1))
        if (per_layer_prefill_intput_threshold and per_layer_prefill_intput_threshold < seq_lenth):
            per_layer_prefill_flag = True
            for layer in self.layers:
                self.load_layer_to(layer, InferenceState.UNLOAD)
            torch.cuda.empty_cache()
        else:
            pass
...
@@ -542,9 +627,13 @@ class KDeepseekV2Model(BaseInjectedModule):
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if cache_position is None:
            past_seen_tokens = (past_key_values.get_seq_length() if past_key_values is not None else 0)
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)
        if position_ids is None:
...
@@ -556,15 +645,17 @@ class KDeepseekV2Model(BaseInjectedModule):
            inputs_embeds = self.embed_tokens(input_ids)
            input_ids = input_ids.to(org_device)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
+        if per_layer_prefill_flag:
+            causal_mask = None
+        else:
+            causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)

        # embed positions
        hidden_states = inputs_embeds
        if per_layer_prefill_flag:
            print(f"Total length of input_ids: {hidden_states.size(1)}")

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
...
@@ -576,7 +667,7 @@ class KDeepseekV2Model(BaseInjectedModule):
        t_f = 0
        for i, decoder_layer in enumerate(self.layers):
            if self.transfer_map is not None and i in self.transfer_map:
                prev_stream = torch.cuda.current_stream()
                cur_device = self.transfer_map[i]
                if cur_device not in self.stream_device_map:
...
@@ -584,10 +675,24 @@ class KDeepseekV2Model(BaseInjectedModule):
                torch.cuda.set_device(cur_device)
                self.stream_device_map[cur_device].wait_stream(prev_stream)
                torch.cuda.set_stream(self.stream_device_map[cur_device])
                hidden_states = hidden_states.to(self.transfer_map[i], non_blocking=True)
                causal_mask = (causal_mask.to(self.transfer_map[i], non_blocking=True) if causal_mask is not None else None)
                position_ids = (position_ids.to(self.transfer_map[i], non_blocking=True) if position_ids is not None else None)
                cache_position = (cache_position.to(self.transfer_map[i], non_blocking=True) if cache_position is not None else None)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
...
@@ -622,12 +727,12 @@ class KDeepseekV2Model(BaseInjectedModule):
            t5 = time.time()
            if per_layer_prefill_flag:
                # print(f"to cpu")
                self.load_layer_to(decoder_layer, InferenceState.UNLOAD)
                torch.cuda.empty_cache()
            t6 = time.time()
            t_gpu += t4 - t3
            t_cpu += t6 - t5
            t_f += t5 - t4

            hidden_states = layer_outputs[0]
...
@@ -648,7 +753,9 @@ class KDeepseekV2Model(BaseInjectedModule):
            torch.cuda.empty_cache()
            t7 = time.time()
            print(f"total time: {t7 - t3}, \n layer num {len(self.layers)}, gpu time: {t_gpu}, cpu time: {t_cpu}, forward time: {t_f}, restore time: {t7 - t6}")

        # add hidden states from the last decoder layer
        if output_hidden_states:
...
@@ -674,16 +781,18 @@ class KDeepseekV2Model(BaseInjectedModule):
            attentions=all_self_attns,
        )

    def load_layer_to(self, layer: DeepseekV2DecoderLayer, target: InferenceState):
        assert isinstance(layer, DeepseekV2DecoderLayer), "module should be nn.ModuleList of decoder layers"

        # TODO Support restore to original device, not only cuda
        device = "cpu" if target == InferenceState.UNLOAD else "cuda"

        # TODO Support DFS to auto use {to, set_inference_mode} according to the module type
        # attn
        layer.self_attn.to(device)  #

        # mlp
        if isinstance(layer.mlp, DeepseekV2MoE):
...
@@ -702,3 +811,526 @@ class KDeepseekV2Model(BaseInjectedModule):
        # layer norm
        layer.input_layernorm.to(device)
        layer.post_attention_layernorm.to(device)


LLAMA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`LlamaConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
LLAMA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class KLlamaModel(BaseInjectedModule):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    dynamic_sdpa = None

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        device: str = "cuda",
        per_layer_prefill_intput_threshold: int = 30000,  # if None, no per-layer prefill
        transfer_map: dict = None,
        **kwargs,
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
        self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
        self.transfer_map = transfer_map
        self.stream_device_map = dict()
        user_path: str = os.path.expanduser('~')
        localstore_path: str = os.path.join(user_path, '.ktransformers')
        config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME)
        with open(config_path, "r") as file:
            config_yaml = yaml.safe_load(file.read())
        self.long_context_config = config_yaml.get("long_context")
        self.ext_config = config_yaml.get("ext")
        KLlamaModel.dynamic_sdpa = DynamicScaledDotProductAttention(
            max_seq_len=self.long_context_config["max_seq_len"],
            block_size=self.long_context_config["block_size"],
            config=config,
            device=torch.device("cuda"),
            local_windows_len=self.long_context_config["local_windows_len"],
            topk=self.long_context_config["second_select_num"],
            threads_num=self.ext_config["cpu_infer"],
            anchor_type=self.long_context_config["anchor_type"],
            kv_type=self.long_context_config["kv_type"],
            dense_layer_num=self.long_context_config["dense_layer_num"],
            anchor_num=self.long_context_config["anchor_num"],
            preselect_block=self.long_context_config["preselect_block"],
            block_selection_mode=self.long_context_config["head_select_mode"],
            preselect_block_count=self.long_context_config["preselect_block_count"],
            layer_step=self.long_context_config["layer_step"],
            token_step=self.long_context_config["token_step"],
            prefill_chunk_size=self.long_context_config["chunk_size"],
            use_attn_sparsity=False,
        )

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
        output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        return_legacy_cache = False
        if (use_cache and not isinstance(past_key_values, Cache) and not self.training):
            # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
            )

        if cache_position is None:
            past_seen_tokens = (past_key_values.get_seq_length() if past_key_values is not None else 0)
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device="cuda")
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = None
        chunck_size = self.long_context_config["chunk_size"]
        cur_idx = 0
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids.to("cpu"))
        q_len = cache_position.size(0)
        # generate
        if q_len == 1:
            x = inputs_embeds[:, -1:, :]
            position_ids = position_ids[:, -1:]
            return self.forward_chunk(
                x, causal_mask, position_ids, past_key_values, output_attentions,
                use_cache, cache_position, output_hidden_states, return_dict,
            )
        elif q_len <= chunck_size:
            inputs_embeds = inputs_embeds.to('cuda')
            output = self.forward_chunk(
                inputs_embeds, causal_mask, position_ids, past_key_values, output_attentions,
                use_cache, cache_position, output_hidden_states, return_dict,
            )
            KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1)
            KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1)
            return output

        cur_idx = 0
        assert (output_attentions == False), "output_attentions is not supported when using chunked attention"
        attn_output = None
        # prefill
        KLlamaModel.dynamic_sdpa.remaining_length = q_len
        while cur_idx < q_len:
            print(f'current prefill length: {cur_idx}')
            chunk_mask = None
            if inputs_embeds.device.type == 'cpu':
                tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)].to("cuda")
            else:
                tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)]
            output_with_past = self.forward_chunk(
                tmp_inputs_embeds,
                chunk_mask,
                position_ids[:, cur_idx : min(cur_idx + chunck_size, q_len)],
                past_key_values,
                output_attentions,
                use_cache,
                cache_position[cur_idx : min(cur_idx + chunck_size, q_len)],
            )
            cur_output = output_with_past.last_hidden_state
            KLlamaModel.dynamic_sdpa.remaining_length -= (min(cur_idx + chunck_size, q_len) - cur_idx)
            cur_idx += chunck_size
            # if attn_output is None:
            attn_output = cur_output
            # else:
            #     attn_output = torch.cat((attn_output, cur_output), dim=-2)
        KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1)
        KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1)
        return BaseModelOutputWithPast(last_hidden_state=attn_output)

    def forward_chunk(
        self,
        inputs_embeds,
        causal_mask,
        position_ids,
        past_key_values,
        output_attentions,
        use_cache,
        cache_position,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = (past_key_values.get_seq_length() if past_key_values is not None else 0)
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_length()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        if attention_mask is not None and attention_mask.dim() == 4:
            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
            if attention_mask.max() != 0:
                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
            causal_mask = attention_mask
        else:
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = (causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :])
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
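The chunked prefill that KLlamaModel.forward implements above is easier to follow in isolation. Below is a hedged, stand-alone sketch of the same control flow; run_chunk stands in for forward_chunk and the tensors are toy values, so this is an illustration rather than the module's API:

# Hedged sketch of the chunked-prefill control flow used by KLlamaModel.forward above.
# `run_chunk` stands in for forward_chunk(); tensors and chunk size are toy values.
import torch

def chunked_prefill(inputs_embeds: torch.Tensor, chunk_size: int, run_chunk):
    q_len = inputs_embeds.size(1)
    cur_idx = 0
    last_hidden = None
    while cur_idx < q_len:
        end = min(cur_idx + chunk_size, q_len)
        # Each chunk attends against the KV cache built by the previous chunks,
        # so only the final chunk's hidden states are needed for the next token.
        last_hidden = run_chunk(inputs_embeds[:, cur_idx:end])
        cur_idx = end
    return last_hidden

# Toy usage: a "model" that just averages each chunk.
fake_chunk = lambda x: x.mean(dim=1, keepdim=True)
out = chunked_prefill(torch.randn(1, 10, 8), chunk_size=4, run_chunk=fake_chunk)
print(out.shape)  # torch.Size([1, 1, 8])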
ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml  View file @ 4d1d561d
...
@@ -225,4 +225,4 @@
    class: "default"
    kwargs:
      generate_device: "cuda:3"
-      prefill_device: "cuda:3"
\ No newline at end of file
+      prefill_device: "cuda:3"
ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml  View file @ 4d1d561d
...
@@ -123,4 +123,4 @@
    class: "default"
    kwargs:
      generate_device: "cuda:1"
-      prefill_device: "cuda:1"
\ No newline at end of file
+      prefill_device: "cuda:1"
ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml  View file @ 4d1d561d
...
@@ -6,7 +6,7 @@
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
-    name: "^model\\.layers\\.(?!.*self_attn).*$"  # regular expression
+    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
...
@@ -41,6 +41,12 @@
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
+- match:
+    name: "^model$"
+  replace:
+    class: "ktransformers.operators.models.KDeepseekV2Model"
+    kwargs:
+      per_layer_prefill_intput_threshold: 2000  # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
...
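The new `^model$` rule wires KDeepseekV2Model in with per_layer_prefill_intput_threshold: 2000, matching the threshold check reconstructed in models.py above. A small hedged sketch of the gating it configures (the function name is illustrative, not part of the codebase):

# Hedged sketch of the gating the YAML threshold controls: when the prompt is longer
# than the threshold, layers are unloaded up front and streamed through the GPU one at
# a time during prefill. This mirrors the reconstructed forward() logic, nothing more.
def should_prefill_layer_wise(seq_len: int, threshold: int | None) -> bool:
    # None or 0 disables per-layer prefill, matching the "# 0 is close layer wise prefill" comment.
    return bool(threshold) and threshold < seq_len

print(should_prefill_layer_wise(8192, 2000))  # True  -> unload layers, prefill layer by layer
print(should_prefill_layer_wise(512, 2000))   # False -> keep the normal resident path
print(should_prefill_layer_wise(8192, 0))     # False -> feature disabled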
ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml  View file @ 4d1d561d
...
@@ -123,4 +123,4 @@
    class: "default"
    kwargs:
      generate_device: "cuda:1"
-      prefill_device: "cuda:1"
\ No newline at end of file
+      prefill_device: "cuda:1"
ktransformers/optimize/optimize_rules/Internlm2_5-7b-Chat-1m.yaml  0 → 100644  View file @ 4d1d561d

- match:
    class: ktransformers.models.modeling_llama.LlamaRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.RotaryEmbeddingV2
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    class: ktransformers.models.modeling_llama.LlamaModel
  replace:
    class: ktransformers.operators.models.KLlamaModel
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KLlamaAttention
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml  View file @ 4d1d561d
...
@@ -109,4 +109,4 @@
    class: "default"
    kwargs:
      generate_device: "cuda:1"
-      prefill_device: "cuda:1"
\ No newline at end of file
+      prefill_device: "cuda:1"
ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml  View file @ 4d1d561d

+- match:
+    name: "^model\\.layers\\..*\\."
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
  replace:
...
@@ -54,4 +61,4 @@
    class: "default"
    kwargs:
      generate_device: "cuda"
-      prefill_device: "cuda"
\ No newline at end of file
+      prefill_device: "cuda"
ktransformers/server/config/config.py  View file @ 4d1d561d
...
@@ -5,10 +5,11 @@ Description :
Author : unicornchan
Date : 2024-06-11 16:35:42
Version : 1.0.0
-LastEditors : chenxl
-LastEditTime : 2024-07-27 01:55:42
+LastEditors : WuHao
+LastEditTime : 2024-08-12 06:31:14
'''
import os
+import shutil
import yaml
from ktransformers.server.config.singleton import Singleton
...
@@ -30,10 +31,18 @@ class Config(metaclass=Singleton):
            os.path.dirname(os.path.dirname(__file__)))
        config_yaml: str = os.path.join(base_path, "configs", Config.CONFIG_FILE_NAME)
+       user_path: str = os.path.expanduser('~')
+       localstore_path: str = os.path.join(user_path, '.ktransformers')
+       config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME)
        if not os.path.exists(config_yaml):
            print(f"Can't find config file, {config_yaml}")
            exit(-1)
-       with open(config_yaml, 'r', encoding="utf-8") as fp:
+       if not os.path.exists(localstore_path):
+           os.mkdir(localstore_path)
+       if not os.path.exists(config_path):
+           shutil.copyfile(config_yaml, config_path)
+       with open(config_path, 'r', encoding="utf-8") as fp:
            config = yaml.safe_load(fp)
        return config
...
@@ -51,6 +60,8 @@ class Config(metaclass=Singleton):
        cfg = Config.load()
        self.base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+       self.user_path: str = os.path.expanduser('~')
+       self.localstore_path: str = os.path.join(self.user_path, '.ktransformers')
        # log configs
        self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"]))
        self.log_file = cfg["log"]["file"]
...
@@ -83,11 +94,20 @@ class Config(metaclass=Singleton):
        self.model_name: str = self.model.get("name", "")
        self.model_device: str = self.model.get("device", "cuda:0")
        self.gguf_path: str = self.model.get("gguf_path", "")
        self.model_cache_lens = self.model.get("cache_lens")
        # web config
        self.web: dict = cfg.get("web", {})
        self.web_cross_domain: bool = self.web.get("open_cross_domain", True)
        self.mount_web: bool = self.web.get("mount", False)
        self.ext: dict = cfg.get("ext", {})
        self.cpu_infer = self.ext.get("cpu_infer", 10)
        # file config
        self.local_store_configs: dict = cfg.get("local_store", {})
        self.file_upload_dir: str = os.path.join(self.localstore_path, self.local_store_configs.get("file_upload_dir", ""))
        self.assistant_store_dir: str = os.path.join(self.localstore_path, self.local_store_configs.get("assistant_store_dir", ""))
        # long context config
        self.long_context_config: dict = cfg.get("long_context", {})
\ No newline at end of file
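The Config changes above copy the packaged config.yaml into ~/.ktransformers on first use and read it from there afterwards, which is what KLlamaModel relies on for its long_context section. A hedged sketch of that first-run behaviour; the file name is assumed to match Config.CONFIG_FILE_NAME and is hard-coded here only for illustration:

# Hedged sketch: seed a user-editable copy of the packaged config on first run,
# then always read the copy so user edits survive package upgrades.
import os
import shutil

def ensure_user_config(packaged_yaml: str, file_name: str = "config.yaml") -> str:
    localstore = os.path.join(os.path.expanduser("~"), ".ktransformers")
    user_yaml = os.path.join(localstore, file_name)
    os.makedirs(localstore, exist_ok=True)
    if not os.path.exists(user_yaml):
        shutil.copyfile(packaged_yaml, user_yaml)  # copy once, never overwrite
    return user_yaml  # later loads read the user-editable copy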
ktransformers/util/cuda_graph_runner.py  View file @ 4d1d561d
...
@@ -46,7 +46,8 @@ class CUDAGraphRunner:
        capture_stream.wait_stream(torch.cuda.current_stream())
        torch.cuda.set_device(main_device)
        torch.cuda.set_stream(capture_stream)
-       past_key_values.change_seq_length(-1)
+       if past_key_values != None:
+           past_key_values.change_seq_length(-1)
        torch.cuda.synchronize(self.main_device)
        #self.graph.debug_dump("cuda_graph_hooked.dot")
...
ktransformers/util/custom_gguf.py  View file @ 4d1d561d
...
@@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang, chenht2022
Date : 2024-07-26 08:48:54
Version : 1.0.0
LastEditors : kkk1nak0
-LastEditTime : 2024-08-12 07:21:55
+LastEditTime : 2024-08-14 08:20:45
Adapted from https://github.com/99991/pygguf/blob/main/gguf.py
Copyright (c) 2023-2024 The ggml authors
Copyright (c) 2024 Thomas Germer
...
@@ -294,7 +294,6 @@ class GGUFLoader:
        else:
            values = GGML_DEQUANTIZE[ggml_name](data)
            values = torch.from_numpy(values)
-        values = values.view(shape[::-1])
        if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
            n_head = self.gguf_file_meta['llama.attention.head_count']
...
ktransformers/util/utils.py  View file @ 4d1d561d
...
@@ -84,7 +84,8 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
    else:
        module.load()

-def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True):
+def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True, mode = 'normal'):
    import os
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    torch._dynamo.config.suppress_errors = True
...
@@ -110,7 +111,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
                       cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True)[0]
-        past_key_values.change_seq_length(1)
+        if past_key_values != None:
+            past_key_values.change_seq_length(1)
        for device in all_cuda_device:
            torch.cuda.synchronize(device)
        #print(logits)
...
@@ -125,18 +127,26 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
        torch.cuda.set_device(torch_device)
    with torch.no_grad():
        stream = TextStreamer(tokenizer)
-        past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=seq_length + max_new_tokens, device=device_map, dtype=model.dtype)
+        if mode != 'long_context':
+            past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=seq_length + max_new_tokens, device=device_map, dtype=model.dtype)
+        else:
+            past_key_values = None
        cache_position = torch.arange(seq_length, device=torch_device)
        generated_ids = torch.zeros(batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device)
        generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
-        past_key_values.cur_idx = cache_position
+        if past_key_values != None:
+            past_key_values.cur_idx = cache_position
        start_time = time.time()

-        inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
+        if mode == "long_context":
+            inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
+        else:
+            inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
        logits = model(inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True)[0][:, -1, :].unsqueeze(0).clone().to(torch_device)
...
@@ -184,7 +194,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
            tokens.append(next_token.int())
            seq_length += 1

-            if next_token[0].item() == tokenizer.eos_token_id:
+            if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>':
                print(stream.end(), end="", flush=True)
                break
            else:
...
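The new mode parameter threads the long-context path through generation: it skips the StaticCache and keeps embeddings on CPU so KLlamaModel's chunked prefill can manage memory itself. A hedged wrapper showing how a caller might select it; the model and tokenizer are assumed to come from the usual local_chat setup, and disabling CUDA graphs on the long-context path is a defensive assumption rather than something this diff requires:

# Hedged sketch: choosing between the normal and long-context generation paths.
from ktransformers.util.utils import prefill_and_generate

def generate(model, tokenizer, input_ids, long_context: bool = False):
    mode = "long_context" if long_context else "normal"
    # Defensive assumption: keep CUDA graphs off for the long-context path,
    # since it runs without a StaticCache.
    return prefill_and_generate(model, tokenizer, input_ids,
                                max_new_tokens=512,
                                use_cuda_graph=not long_context,
                                mode=mode)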
pyproject.toml  View file @ 4d1d561d
...
@@ -27,7 +27,8 @@ dependencies = [
    "wheel",
    "colorlog",
    "build",
-    "fire"
+    "fire",
+    "protobuf"
]
requires-python = ">=3.10"
...
requirements-local_chat.txt  View file @ 4d1d561d
...
@@ -3,4 +3,5 @@ transformers
numpy
torch>=2.3.0
packaging
-cpufeature
\ No newline at end of file
+cpufeature
+protobuf
\ No newline at end of file