Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ox696c
ktransformers
Commits
f4c198bd
Commit
f4c198bd
authored
Feb 25, 2025
by
Atream
Browse files
support absorb for prefill long context
parent
e9b1216a
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
93 additions
and
33 deletions
+93
-33
ktransformers/local_chat.py
ktransformers/local_chat.py
+2
-2
ktransformers/operators/attention.py
ktransformers/operators/attention.py
+36
-16
ktransformers/operators/flashinfer_wrapper.py
ktransformers/operators/flashinfer_wrapper.py
+22
-8
ktransformers/operators/models.py
ktransformers/operators/models.py
+4
-2
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
+1
-0
ktransformers/server/backend/interfaces/ktransformers.py
ktransformers/server/backend/interfaces/ktransformers.py
+5
-0
ktransformers/server/backend/interfaces/transformers.py
ktransformers/server/backend/interfaces/transformers.py
+1
-1
ktransformers/util/utils.py
ktransformers/util/utils.py
+22
-4
No files found.
ktransformers/local_chat.py
View file @
f4c198bd
...
...
@@ -28,7 +28,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from
ktransformers.models.modeling_deepseek_v3
import
DeepseekV3ForCausalLM
from
ktransformers.models.modeling_llama
import
LlamaForCausalLM
from
ktransformers.models.modeling_mixtral
import
MixtralForCausalLM
from
ktransformers.util.utils
import
prefill_and_generate
from
ktransformers.util.utils
import
prefill_and_generate
,
get_compute_capability
from
ktransformers.server.config.config
import
Config
from
ktransformers.operators.flashinfer_wrapper
import
flashinfer_enabled
...
...
@@ -168,7 +168,7 @@ def local_chat(
assert
Config
().
long_context_config
[
'max_seq_len'
]
>
input_tensor
.
shape
[
1
]
+
max_new_tokens
,
\
"please change max_seq_len in ~/.ktransformers/config.yaml"
if
system
!=
"Windows"
and
(
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
or
"DeepseekV3ForCausalLM"
)
and
flashinfer_enabled
:
if
system
!=
"Windows"
and
(
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
or
"DeepseekV3ForCausalLM"
)
and
flashinfer_enabled
and
get_compute_capability
()
>=
8
:
generated
=
prefill_and_generate
(
model
,
tokenizer
,
input_tensor
.
cuda
(),
max_new_tokens
,
use_cuda_graph
,
mode
=
mode
,
force_think
=
force_think
,
use_flashinfer_mla
=
True
,
num_heads
=
config
.
num_attention_heads
,
head_dim_ckv
=
config
.
kv_lora_rank
,
head_dim_kpe
=
config
.
qk_rope_head_dim
,
q_head_dim
=
config
.
qk_rope_head_dim
+
config
.
qk_nope_head_dim
...
...
ktransformers/operators/attention.py
View file @
f4c198bd
...
...
@@ -16,6 +16,7 @@ from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_ro
from
typing
import
Optional
,
Tuple
from
ktransformers.operators.base_operator
import
BaseInjectedModule
from
ktransformers.util.custom_gguf
import
GGUFLoader
from
ktransformers.util.utils
import
get_compute_capability
import
logging
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.cache_utils
import
Cache
...
...
@@ -48,12 +49,14 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
prefill_device
:
str
=
"cuda"
,
generate_device
:
str
=
"cuda"
,
chunck_size
:
int
=
1000
,
absorb_for_prefill
:
bool
=
False
,
**
kwargs
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
generate_device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
orig_module
.
config
,
orig_module
.
layer_idx
)
self
.
chunck_size
=
chunck_size
# TODO, generate chunck_size automatically.
self
.
mla_wrapper
=
None
self
.
absorb_for_prefill
=
absorb_for_prefill
def
get_absorbed
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
not
(
hasattr
(
self
,
'q_absorb'
)
and
hasattr
(
self
,
'out_absorb'
)):
...
...
@@ -242,7 +245,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
q_nope
=
q_nope
.
transpose
(
1
,
2
)
# q_len is 1, no GPU overhead, same below
q_nope
=
torch
.
matmul
(
q_nope
,
q_absorb
)
# batched MM
q_nope
=
q_nope
.
transpose
(
1
,
2
)
assert
q_nope
.
is_contiguous
()
#
assert q_nope.is_contiguous()
# q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
...
...
@@ -282,6 +285,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
# out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
attn_output
=
attn_output
.
transpose
(
1
,
2
)
attn_output
=
torch
.
matmul
(
attn_output
,
out_absorb
.
mT
)
attn_output
=
attn_output
.
transpose
(
1
,
2
)
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
self
.
num_heads
*
self
.
v_head_dim
)
attn_output
=
self
.
o_proj
(
attn_output
)
...
...
@@ -380,7 +384,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
# decode
if
q_len
==
1
:
if
q_len
==
1
or
self
.
absorb_for_prefill
:
if
past_key_value
is
not
None
:
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
# Specific to RoPE models
compressed_kv_with_k_pe
,
page_table
=
past_key_value
.
update
(
compressed_kv
,
k_pe
,
self
.
layer_idx
,
cache_kwargs
)
...
...
@@ -395,27 +399,41 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
q_nope
=
q_nope
.
transpose
(
1
,
2
)
# q_len is 1, no GPU overhead, same below
q_nope
=
torch
.
matmul
(
q_nope
,
q_absorb
)
# batched MM
q_nope
=
q_nope
.
transpose
(
1
,
2
)
assert
q_nope
.
is_contiguous
()
q_nope
=
q_nope
.
contiguous
()
#assert q_nope.is_contiguous()
# q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
q_nope
.
squeeze_
(
1
)
q_pe
.
squeeze_
(
1
)
q_nope
.
squeeze_
(
0
)
q_pe
.
squeeze_
(
0
)
# flash attn doesn't support head_dim bigger than 256, use flashinfer
if
self
.
mla_wrapper
is
None
:
self
.
mla_wrapper
=
MLAWrapperSingleton
.
get_instance
(
self
.
device
,
1
,
past_key_value
.
max_pages
,
use_cuda_graph
=
True
)
if
self
.
mla_wrapper
.
need_plan
:
self
.
mla_wrapper
.
need_plan
=
False
if
self
.
mla_wrapper
.
need_plan
:
self
.
mla_wrapper
.
need_plan
=
False
if
q_len
==
1
:
self
.
mla_wrapper
.
plan
(
None
,
None
,
None
,
position_ids
.
squeeze
(
1
)
+
1
,
self
.
num_heads
,
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
,
past_key_value
.
page_size
,
self
.
softmax_scale
,
q_nope
.
dtype
,
compressed_kv
.
dtype
)
position_ids
.
squeeze
(
1
)
+
1
,
self
.
num_heads
,
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
,
past_key_value
.
page_size
,
self
.
softmax_scale
,
q_nope
.
dtype
,
compressed_kv
.
dtype
)
else
:
qo_indptr
=
torch
.
tensor
([
0
,
q_len
],
dtype
=
torch
.
int32
,
device
=
self
.
device
)
kv_len_arr
=
torch
.
tensor
([
position_ids
[
0
,
-
1
].
item
()
+
1
],
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
mla_wrapper
.
plan
(
qo_indptr
,
None
,
None
,
kv_len_arr
,
self
.
num_heads
,
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
,
past_key_value
.
page_size
,
self
.
softmax_scale
,
q_nope
.
dtype
,
compressed_kv
.
dtype
)
attn_output
=
self
.
mla_wrapper
.
run
(
q_nope
,
q_pe
,
compressed_kv
,
k_pe
).
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
kv_lora_rank
)
"""
...
...
@@ -443,6 +461,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
# out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
attn_output
=
attn_output
.
transpose
(
1
,
2
)
# [bsz, self.num_heads, q_len, self.kv_lora_rank]
attn_output
=
torch
.
matmul
(
attn_output
,
out_absorb
.
mT
)
# [bsz, self.num_heads, q_len, self.v_head_dim]
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
# [bsz, q_len, self.num_heads, self.kv_lora_rank]
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
self
.
num_heads
*
self
.
v_head_dim
)
# [bsz, q_len, self.num_heads * self.v_head_dim]
attn_output
=
self
.
o_proj
(
attn_output
)
...
...
@@ -571,7 +590,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
cache_position
:
Optional
[
torch
.
LongTensor
]
=
None
,
**
kwargs
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
Tuple
[
torch
.
Tensor
]]]:
if
os
.
name
==
'nt'
:
if
os
.
name
==
'nt'
or
get_compute_capability
()
<
8
:
print
(
"for Windows or GPU before ampere, use forward_windows"
)
return
self
.
forward_windows
(
hidden_states
,
attention_mask
,
...
...
ktransformers/operators/flashinfer_wrapper.py
View file @
f4c198bd
...
...
@@ -9,7 +9,7 @@ flashinfer_enabled = False
try
:
import
flashinfer
flashinfer_enabled
=
False
# disabled now, TODO:use new version of flashinfer and enabl
e
flashinfer_enabled
=
Tru
e
print
(
"found flashinfer"
)
except
ImportError
:
...
...
@@ -132,14 +132,14 @@ class MLAWrapper():
head_dim_ckv
,
head_dim_kpe
,
page_size
,
Fals
e
,
# causal
is False for decoding
Tru
e
,
# causal
sm_scale
,
q_data_type
,
kv_data_type
,
)
def
run
(
self
,
q_nope
,
q_pe
,
ckv
,
k_pe
,
return_lse
=
False
):
return
self
.
wrapper
.
run
(
q_nope
,
q_pe
,
ckv
,
k_pe
,
return_lse
)
return
self
.
wrapper
.
run
(
q_nope
,
q_pe
,
ckv
,
k_pe
,
return_lse
=
return_lse
)
class
MLAWrapperSingleton
():
wrappers
:
dict
=
{}
...
...
@@ -179,6 +179,17 @@ class MLAWrapperSingleton():
sm_scale
,
q_data_type
,
kv_data_type
,)
wrapper
.
need_plan
=
False
@
classmethod
def
need_plan_all
(
cls
):
for
device
,
wrapper
in
cls
.
wrappers
.
items
():
wrapper
.
need_plan
=
True
@
classmethod
def
reset_buffer
(
cls
):
for
device
,
wrapper
in
cls
.
wrappers
.
items
():
wrapper
.
qo_indptr_buf
[
1
]
=
1
if
__name__
==
"__main__"
:
...
...
@@ -187,8 +198,9 @@ if __name__ == "__main__":
page_size
=
64
num_heads
=
128
q_nope
=
torch
.
randn
((
1
,
num_heads
,
512
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
q_pe
=
torch
.
randn
((
1
,
num_heads
,
64
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
q_len
=
10
q_nope
=
torch
.
randn
((
q_len
,
num_heads
,
512
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
q_pe
=
torch
.
randn
((
q_len
,
num_heads
,
64
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
ckv
=
torch
.
randn
((
max_pages
,
page_size
,
512
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
k_pe
=
torch
.
randn
((
max_pages
,
page_size
,
64
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
...
...
@@ -199,10 +211,10 @@ if __name__ == "__main__":
max_pages
,
)
kv_len_arr
=
torch
.
tensor
([
10
],
dtype
=
torch
.
int32
,
device
=
"cuda"
)
kv_len_arr
=
torch
.
tensor
([
q_len
],
dtype
=
torch
.
int32
,
device
=
"cuda"
)
qo_indptr
=
torch
.
tensor
([
0
,
q_len
],
dtype
=
torch
.
int32
,
device
=
"cuda"
)
wrapper
.
plan
(
None
,
qo_indptr
,
None
,
None
,
kv_len_arr
,
...
...
@@ -216,6 +228,7 @@ if __name__ == "__main__":
)
attn_output
=
wrapper
.
run
(
q_nope
,
q_pe
,
ckv
,
k_pe
)
print
(
attn_output
.
shape
)
k
=
(
torch
.
cat
([
ckv
,
k_pe
],
dim
=-
1
)
...
...
@@ -235,6 +248,7 @@ if __name__ == "__main__":
False
,
192
**
(
-
0.5
)
)
print
(
attn_ref
.
shape
)
torch
.
testing
.
assert_close
(
attn_output
,
attn_ref
,
rtol
=
1e-3
,
atol
=
1e-3
)
print
(
"test past"
)
\ No newline at end of file
ktransformers/operators/models.py
View file @
f4c198bd
...
...
@@ -56,7 +56,7 @@ from ktransformers.models.modeling_deepseek import (
from
transformers.models.qwen2_moe.configuration_qwen2_moe
import
Qwen2MoeConfig
from
ktransformers.models.configuration_llama
import
LlamaConfig
from
ktransformers.operators.base_operator
import
BaseInjectedModule
from
ktransformers.util.utils
import
InferenceState
from
ktransformers.util.utils
import
InferenceState
,
get_compute_capability
from
ktransformers.util.custom_gguf
import
GGUFLoader
from
transformers.configuration_utils
import
PretrainedConfig
from
ktransformers.models.modeling_llama
import
(
...
...
@@ -649,7 +649,9 @@ class KDeepseekV2Model(BaseInjectedModule):
if
per_layer_prefill_flag
:
causal_mask
=
None
else
:
if
os
.
name
==
'nt'
:
if
os
.
name
==
'nt'
or
get_compute_capability
()
<
8
:
print
(
"for Windows or GPU before ampere, use forward_windows"
)
# only use mask in forward windows or can't flash attn
causal_mask
=
self
.
_update_causal_mask
(
attention_mask
,
inputs_embeds
,
cache_position
,
past_key_values
,
output_attentions
)
...
...
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
View file @
f4c198bd
...
...
@@ -60,6 +60,7 @@
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
absorb_for_prefill
:
False
# change this to True to enable long context(prefill may slower).
-
match
:
name
:
"
^model$"
replace
:
...
...
ktransformers/server/backend/interfaces/ktransformers.py
View file @
f4c198bd
...
...
@@ -14,6 +14,7 @@ from ktransformers.models.custom_cache import StaticCache
from
ktransformers.util.cuda_graph_runner
import
CUDAGraphRunner
from
ktransformers.local_chat
import
custom_models
,
default_optimize_rules
from
ktransformers.util.utils
import
get_device
from
ktransformers.operators.flashinfer_wrapper
import
flashinfer_enabled
,
MLAWrapperSingleton
warm_uped
=
False
...
...
@@ -186,6 +187,8 @@ class KTransformersInterface(TransformersInterface):
input_ids
=
input_ids
.
to
(
"cpu"
)
inputs_embeds
=
self
.
model
.
model
.
embed_tokens
(
input_ids
).
to
(
device
)
torch
.
cuda
.
set_device
(
device
)
if
flashinfer_enabled
:
MLAWrapperSingleton
.
need_plan_all
()
if
self
.
use_static_cache
:
logits
=
self
.
model
(
inputs_embeds
=
inputs_embeds
,
...
...
@@ -198,6 +201,8 @@ class KTransformersInterface(TransformersInterface):
else
:
logits
=
self
.
model
(
inputs_embeds
=
inputs_embeds
,
return_dict
=
False
)[
0
]
if
flashinfer_enabled
:
MLAWrapperSingleton
.
reset_buffer
()
self
.
prepare_logits_wrapper
(
input_ids
,
device
)
next_token
=
self
.
logits_to_token
(
logits
[
0
,
-
1
,
:])
yield
self
.
append_new_tokens
(
next_token
)
...
...
ktransformers/server/backend/interfaces/transformers.py
View file @
f4c198bd
...
...
@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
for
i
in
range
(
1
,
self
.
args
.
max_new_tokens
):
with
torch
.
backends
.
cuda
.
sdp_kernel
(
enable_flash
=
False
,
enable_mem_efficient
=
False
,
enable_math
=
True
):
if
i
>
1
and
flashinfer_enabled
:
if
flashinfer_enabled
:
MLAWrapperSingleton
.
plan_all
(
None
,
None
,
None
,
self
.
active_cache_position
.
to
(
torch
.
int32
)
+
1
,
num_heads
=
self
.
model
.
config
.
num_attention_heads
,
head_dim_ckv
=
self
.
model
.
config
.
kv_lora_rank
,
head_dim_kpe
=
self
.
model
.
config
.
qk_rope_head_dim
,
page_size
=
self
.
cache
.
page_size
,
...
...
ktransformers/util/utils.py
View file @
f4c198bd
...
...
@@ -21,6 +21,18 @@ from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
warm_uped
=
False
def
get_compute_capability
(
device
:
torch
.
device
=
None
):
if
torch
.
cuda
.
is_available
():
if
device
is
None
:
num_gpus
=
torch
.
cuda
.
device_count
()
min_compute_capability_major
=
100
for
gpu_id
in
range
(
num_gpus
):
gpu_props
=
torch
.
cuda
.
get_device_properties
(
gpu_id
)
min_compute_capability_major
=
min
(
min_compute_capability_major
,
gpu_props
.
major
)
return
min_compute_capability_major
else
:
return
torch
.
cuda
.
get_device_properties
(
device
)
def
set_module
(
model
,
submodule_key
,
module
):
tokens
=
submodule_key
.
split
(
'.'
)
sub_tokens
=
tokens
[:
-
1
]
...
...
@@ -153,6 +165,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
inputs_embeds
=
model
.
model
.
embed_tokens
(
inputs
.
to
(
"cpu"
))
else
:
inputs_embeds
=
model
.
model
.
embed_tokens
(
inputs
.
to
(
"cpu"
)).
to
(
torch_device
)
if
use_flashinfer_mla
:
MLAWrapperSingleton
.
need_plan_all
()
logits
=
model
(
inputs_embeds
=
inputs_embeds
,
cache_position
=
cache_position
,
past_key_values
=
past_key_values
,
return_dict
=
False
,
use_cache
=
True
)[
0
][:,
-
1
,:].
unsqueeze
(
0
).
clone
().
to
(
torch_device
)
...
...
@@ -175,6 +190,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
else
:
next_token
=
torch
.
argmax
(
next_token_scores
,
dim
=-
1
)
first_token_time
=
time
.
time
()
-
start_time
if
use_flashinfer_mla
:
MLAWrapperSingleton
.
reset_buffer
()
prefill_count
=
seq_length
prefill_time
=
first_token_time
...
...
@@ -192,15 +210,15 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
start_time
=
time
.
time
()
for
i
in
range
(
1
,
max_new_tokens
):
if
use_flashinfer_mla
:
MLAWrapperSingleton
.
plan_all
(
None
,
None
,
None
,
position_ids
.
squeeze
(
1
)
+
1
,
num_heads
,
head_dim_ckv
,
head_dim_kpe
,
past_key_values
.
page_size
,
q_head_dim
**
(
-
0.5
),
torch
.
bfloat16
,
torch
.
bfloat16
)
global
warm_uped
if
use_cuda_graph
and
(
(
warm_uped
==
True
and
int
(
i
)
==
1
)
or
(
warm_uped
==
False
and
int
(
i
)
==
2
)
):
warm_uped
=
True
cuda_graph_runner
=
CUDAGraphRunner
()
cuda_graph_runner
.
capture
(
model
,
next_token
.
unsqueeze
(
0
),
position_ids
,
cache_position
,
past_key_values
,
torch_device
,
return_dict
=
False
,
use_cache
=
True
)
if
i
>
1
and
use_flashinfer_mla
:
MLAWrapperSingleton
.
plan_all
(
None
,
None
,
None
,
position_ids
.
squeeze
(
1
)
+
1
,
num_heads
,
head_dim_ckv
,
head_dim_kpe
,
past_key_values
.
page_size
,
q_head_dim
**
(
-
0.5
),
torch
.
bfloat16
,
torch
.
bfloat16
)
next_token
=
decode_one_tokens
(
cuda_graph_runner
,
next_token
.
unsqueeze
(
0
),
position_ids
,
cache_position
,
past_key_values
,
use_cuda_graph
).
to
(
torch_device
)
inputs
=
torch
.
cat
((
inputs
,
next_token
.
unsqueeze
(
0
)),
dim
=-
1
)
generated_ids
[:,
cache_position
]
=
next_token
.
int
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment