Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ktransformers
Commits
b443c7df
Unverified
Commit
b443c7df
authored
Feb 25, 2025
by
Atream
Committed by
GitHub
Feb 25, 2025
Browse files
Merge pull request #657 from kvcache-ai/feat-absorb-for-long-prefill
Feat absorb for long prefill
parents
050b745a
f4c198bd
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
193 additions
and
43 deletions
+193
-43
ktransformers/local_chat.py
ktransformers/local_chat.py
+2
-3
ktransformers/operators/attention.py
ktransformers/operators/attention.py
+39
-19
ktransformers/operators/experts.py
ktransformers/operators/experts.py
+3
-3
ktransformers/operators/flashinfer_wrapper.py
ktransformers/operators/flashinfer_wrapper.py
+22
-8
ktransformers/operators/models.py
ktransformers/operators/models.py
+9
-4
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
+1
-0
ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml
ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml
+86
-0
ktransformers/server/backend/interfaces/ktransformers.py
ktransformers/server/backend/interfaces/ktransformers.py
+5
-0
ktransformers/server/backend/interfaces/transformers.py
ktransformers/server/backend/interfaces/transformers.py
+1
-1
ktransformers/util/custom_gguf.py
ktransformers/util/custom_gguf.py
+2
-0
ktransformers/util/utils.py
ktransformers/util/utils.py
+23
-5
No files found.
ktransformers/local_chat.py
View file @
b443c7df
...
...
@@ -28,7 +28,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from
ktransformers.models.modeling_deepseek_v3
import
DeepseekV3ForCausalLM
from
ktransformers.models.modeling_llama
import
LlamaForCausalLM
from
ktransformers.models.modeling_mixtral
import
MixtralForCausalLM
from
ktransformers.util.utils
import
prefill_and_generate
from
ktransformers.util.utils
import
prefill_and_generate
,
get_compute_capability
from
ktransformers.server.config.config
import
Config
from
ktransformers.operators.flashinfer_wrapper
import
flashinfer_enabled
...
...
@@ -64,7 +64,6 @@ def local_chat(
force_think
:
bool
=
False
,
):
torch
.
set_grad_enabled
(
False
)
Config
().
cpu_infer
=
cpu_infer
...
...
@@ -169,7 +168,7 @@ def local_chat(
assert
Config
().
long_context_config
[
'max_seq_len'
]
>
input_tensor
.
shape
[
1
]
+
max_new_tokens
,
\
"please change max_seq_len in ~/.ktransformers/config.yaml"
if
system
!=
"Windows"
and
(
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
or
"DeepseekV3ForCausalLM"
)
and
flashinfer_enabled
:
if
system
!=
"Windows"
and
(
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
or
"DeepseekV3ForCausalLM"
)
and
flashinfer_enabled
and
get_compute_capability
()
>=
8
:
generated
=
prefill_and_generate
(
model
,
tokenizer
,
input_tensor
.
cuda
(),
max_new_tokens
,
use_cuda_graph
,
mode
=
mode
,
force_think
=
force_think
,
use_flashinfer_mla
=
True
,
num_heads
=
config
.
num_attention_heads
,
head_dim_ckv
=
config
.
kv_lora_rank
,
head_dim_kpe
=
config
.
qk_rope_head_dim
,
q_head_dim
=
config
.
qk_rope_head_dim
+
config
.
qk_nope_head_dim
...
...
ktransformers/operators/attention.py
View file @
b443c7df
...
...
@@ -16,6 +16,7 @@ from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_ro
from
typing
import
Optional
,
Tuple
from
ktransformers.operators.base_operator
import
BaseInjectedModule
from
ktransformers.util.custom_gguf
import
GGUFLoader
from
ktransformers.util.utils
import
get_compute_capability
import
logging
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.cache_utils
import
Cache
...
...
@@ -48,12 +49,14 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
prefill_device
:
str
=
"cuda"
,
generate_device
:
str
=
"cuda"
,
chunck_size
:
int
=
1000
,
absorb_for_prefill
:
bool
=
False
,
**
kwargs
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
generate_device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
orig_module
.
config
,
orig_module
.
layer_idx
)
self
.
chunck_size
=
chunck_size
# TODO, generate chunck_size automatically.
self
.
mla_wrapper
=
None
self
.
absorb_for_prefill
=
absorb_for_prefill
def
get_absorbed
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
not
(
hasattr
(
self
,
'q_absorb'
)
and
hasattr
(
self
,
'out_absorb'
)):
...
...
@@ -242,7 +245,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
q_nope
=
q_nope
.
transpose
(
1
,
2
)
# q_len is 1, no GPU overhead, same below
q_nope
=
torch
.
matmul
(
q_nope
,
q_absorb
)
# batched MM
q_nope
=
q_nope
.
transpose
(
1
,
2
)
assert
q_nope
.
is_contiguous
()
#
assert q_nope.is_contiguous()
# q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
...
...
@@ -282,6 +285,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
# out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
attn_output
=
attn_output
.
transpose
(
1
,
2
)
attn_output
=
torch
.
matmul
(
attn_output
,
out_absorb
.
mT
)
attn_output
=
attn_output
.
transpose
(
1
,
2
)
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
self
.
num_heads
*
self
.
v_head_dim
)
attn_output
=
self
.
o_proj
(
attn_output
)
...
...
@@ -380,7 +384,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
# decode
if
q_len
==
1
:
if
q_len
==
1
or
self
.
absorb_for_prefill
:
if
past_key_value
is
not
None
:
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
# Specific to RoPE models
compressed_kv_with_k_pe
,
page_table
=
past_key_value
.
update
(
compressed_kv
,
k_pe
,
self
.
layer_idx
,
cache_kwargs
)
...
...
@@ -395,27 +399,41 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
q_nope
=
q_nope
.
transpose
(
1
,
2
)
# q_len is 1, no GPU overhead, same below
q_nope
=
torch
.
matmul
(
q_nope
,
q_absorb
)
# batched MM
q_nope
=
q_nope
.
transpose
(
1
,
2
)
assert
q_nope
.
is_contiguous
()
q_nope
=
q_nope
.
contiguous
()
#assert q_nope.is_contiguous()
# q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
q_nope
.
squeeze_
(
1
)
q_pe
.
squeeze_
(
1
)
q_nope
.
squeeze_
(
0
)
q_pe
.
squeeze_
(
0
)
# flash attn doesn't support head_dim bigger than 256, use flashinfer
if
self
.
mla_wrapper
is
None
:
self
.
mla_wrapper
=
MLAWrapperSingleton
.
get_instance
(
self
.
device
,
1
,
past_key_value
.
max_pages
,
use_cuda_graph
=
True
)
if
self
.
mla_wrapper
.
need_plan
:
self
.
mla_wrapper
.
need_plan
=
False
if
self
.
mla_wrapper
.
need_plan
:
self
.
mla_wrapper
.
need_plan
=
False
if
q_len
==
1
:
self
.
mla_wrapper
.
plan
(
None
,
None
,
None
,
position_ids
.
squeeze
(
1
)
+
1
,
self
.
num_heads
,
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
,
past_key_value
.
page_size
,
self
.
softmax_scale
,
q_nope
.
dtype
,
compressed_kv
.
dtype
)
position_ids
.
squeeze
(
1
)
+
1
,
self
.
num_heads
,
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
,
past_key_value
.
page_size
,
self
.
softmax_scale
,
q_nope
.
dtype
,
compressed_kv
.
dtype
)
else
:
qo_indptr
=
torch
.
tensor
([
0
,
q_len
],
dtype
=
torch
.
int32
,
device
=
self
.
device
)
kv_len_arr
=
torch
.
tensor
([
position_ids
[
0
,
-
1
].
item
()
+
1
],
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
mla_wrapper
.
plan
(
qo_indptr
,
None
,
None
,
kv_len_arr
,
self
.
num_heads
,
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
,
past_key_value
.
page_size
,
self
.
softmax_scale
,
q_nope
.
dtype
,
compressed_kv
.
dtype
)
attn_output
=
self
.
mla_wrapper
.
run
(
q_nope
,
q_pe
,
compressed_kv
,
k_pe
).
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
kv_lora_rank
)
"""
...
...
@@ -441,10 +459,11 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
# mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
# attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
# out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
attn_output
=
attn_output
.
transpose
(
1
,
2
)
attn_output
=
torch
.
matmul
(
attn_output
,
out_absorb
.
mT
)
attn_output
=
attn_output
.
transpose
(
1
,
2
)
# [bsz, self.num_heads, q_len, self.kv_lora_rank]
attn_output
=
torch
.
matmul
(
attn_output
,
out_absorb
.
mT
)
# [bsz, self.num_heads, q_len, self.v_head_dim]
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
# [bsz, q_len, self.num_heads, self.kv_lora_rank]
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
self
.
num_heads
*
self
.
v_head_dim
)
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
self
.
num_heads
*
self
.
v_head_dim
)
# [bsz, q_len, self.num_heads * self.v_head_dim]
attn_output
=
self
.
o_proj
(
attn_output
)
return
attn_output
,
None
,
past_key_value
...
...
@@ -571,7 +590,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
cache_position
:
Optional
[
torch
.
LongTensor
]
=
None
,
**
kwargs
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
Tuple
[
torch
.
Tensor
]]]:
if
os
.
name
==
'nt'
:
if
os
.
name
==
'nt'
or
get_compute_capability
()
<
8
:
print
(
"for Windows or GPU before ampere, use forward_windows"
)
return
self
.
forward_windows
(
hidden_states
,
attention_mask
,
...
...
ktransformers/operators/experts.py
View file @
b443c7df
...
...
@@ -159,7 +159,7 @@ class KExpertsCPU(KExpertsBase):
down_ptr
=
ctypes
.
addressof
(
ctypes
.
cast
(
self
.
down
.
ctypes
.
data
,
ctypes
.
POINTER
(
ctypes
.
c_uint64
)).
contents
)
#
print(self.gate_
q
type, self.up_
q
type, self.down_
q
type)
#print(self.gate_type, self.up_type, self.down_type)
n_routed_experts
=
self
.
n_routed_experts
# n_routed_experts = len(self.orig_module)
moe_config
=
MOEConfig
(
...
...
@@ -459,9 +459,9 @@ class KExpertsTorch(KExpertsBase):
self
.
up
[
i
]
=
w
[
"up"
][
i
,
...].
to
(
device
=
device
,
dtype
=
self
.
dtype
)
self
.
down
[
i
]
=
w
[
"down"
][
i
,
...].
to
(
device
=
device
,
dtype
=
self
.
dtype
)
self
.
up
=
torch
.
cat
(
self
.
gate
,
dim
=
0
)
self
.
up
=
torch
.
cat
(
self
.
up
,
dim
=
0
)
self
.
gate
=
torch
.
cat
(
self
.
gate
,
dim
=
0
)
self
.
down
=
torch
.
cat
(
self
.
gate
,
dim
=
0
)
self
.
down
=
torch
.
cat
(
self
.
down
,
dim
=
0
)
return
def
unload
(
self
):
...
...
ktransformers/operators/flashinfer_wrapper.py
View file @
b443c7df
...
...
@@ -9,7 +9,7 @@ flashinfer_enabled = False
try
:
import
flashinfer
flashinfer_enabled
=
False
# disabled now, TODO:use new version of flashinfer and enabl
e
flashinfer_enabled
=
Tru
e
print
(
"found flashinfer"
)
except
ImportError
:
...
...
@@ -132,14 +132,14 @@ class MLAWrapper():
head_dim_ckv
,
head_dim_kpe
,
page_size
,
Fals
e
,
# causal
is False for decoding
Tru
e
,
# causal
sm_scale
,
q_data_type
,
kv_data_type
,
)
def
run
(
self
,
q_nope
,
q_pe
,
ckv
,
k_pe
,
return_lse
=
False
):
return
self
.
wrapper
.
run
(
q_nope
,
q_pe
,
ckv
,
k_pe
,
return_lse
)
return
self
.
wrapper
.
run
(
q_nope
,
q_pe
,
ckv
,
k_pe
,
return_lse
=
return_lse
)
class
MLAWrapperSingleton
():
wrappers
:
dict
=
{}
...
...
@@ -179,6 +179,17 @@ class MLAWrapperSingleton():
sm_scale
,
q_data_type
,
kv_data_type
,)
wrapper
.
need_plan
=
False
@
classmethod
def
need_plan_all
(
cls
):
for
device
,
wrapper
in
cls
.
wrappers
.
items
():
wrapper
.
need_plan
=
True
@
classmethod
def
reset_buffer
(
cls
):
for
device
,
wrapper
in
cls
.
wrappers
.
items
():
wrapper
.
qo_indptr_buf
[
1
]
=
1
if
__name__
==
"__main__"
:
...
...
@@ -187,8 +198,9 @@ if __name__ == "__main__":
page_size
=
64
num_heads
=
128
q_nope
=
torch
.
randn
((
1
,
num_heads
,
512
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
q_pe
=
torch
.
randn
((
1
,
num_heads
,
64
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
q_len
=
10
q_nope
=
torch
.
randn
((
q_len
,
num_heads
,
512
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
q_pe
=
torch
.
randn
((
q_len
,
num_heads
,
64
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
ckv
=
torch
.
randn
((
max_pages
,
page_size
,
512
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
k_pe
=
torch
.
randn
((
max_pages
,
page_size
,
64
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
...
...
@@ -199,10 +211,10 @@ if __name__ == "__main__":
max_pages
,
)
kv_len_arr
=
torch
.
tensor
([
10
],
dtype
=
torch
.
int32
,
device
=
"cuda"
)
kv_len_arr
=
torch
.
tensor
([
q_len
],
dtype
=
torch
.
int32
,
device
=
"cuda"
)
qo_indptr
=
torch
.
tensor
([
0
,
q_len
],
dtype
=
torch
.
int32
,
device
=
"cuda"
)
wrapper
.
plan
(
None
,
qo_indptr
,
None
,
None
,
kv_len_arr
,
...
...
@@ -216,6 +228,7 @@ if __name__ == "__main__":
)
attn_output
=
wrapper
.
run
(
q_nope
,
q_pe
,
ckv
,
k_pe
)
print
(
attn_output
.
shape
)
k
=
(
torch
.
cat
([
ckv
,
k_pe
],
dim
=-
1
)
...
...
@@ -235,6 +248,7 @@ if __name__ == "__main__":
False
,
192
**
(
-
0.5
)
)
print
(
attn_ref
.
shape
)
torch
.
testing
.
assert_close
(
attn_output
,
attn_ref
,
rtol
=
1e-3
,
atol
=
1e-3
)
print
(
"test past"
)
\ No newline at end of file
ktransformers/operators/models.py
View file @
b443c7df
...
...
@@ -56,7 +56,7 @@ from ktransformers.models.modeling_deepseek import (
from
transformers.models.qwen2_moe.configuration_qwen2_moe
import
Qwen2MoeConfig
from
ktransformers.models.configuration_llama
import
LlamaConfig
from
ktransformers.operators.base_operator
import
BaseInjectedModule
from
ktransformers.util.utils
import
InferenceState
from
ktransformers.util.utils
import
InferenceState
,
get_compute_capability
from
ktransformers.util.custom_gguf
import
GGUFLoader
from
transformers.configuration_utils
import
PretrainedConfig
from
ktransformers.models.modeling_llama
import
(
...
...
@@ -649,9 +649,14 @@ class KDeepseekV2Model(BaseInjectedModule):
if
per_layer_prefill_flag
:
causal_mask
=
None
else
:
causal_mask
=
self
.
_update_causal_mask
(
attention_mask
,
inputs_embeds
,
cache_position
,
past_key_values
,
output_attentions
)
if
os
.
name
==
'nt'
or
get_compute_capability
()
<
8
:
print
(
"for Windows or GPU before ampere, use forward_windows"
)
# only use mask in forward windows or can't flash attn
causal_mask
=
self
.
_update_causal_mask
(
attention_mask
,
inputs_embeds
,
cache_position
,
past_key_values
,
output_attentions
)
else
:
causal_mask
=
None
# embed positions
hidden_states
=
inputs_embeds
...
...
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
View file @
b443c7df
...
...
@@ -60,6 +60,7 @@
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
absorb_for_prefill
:
False
# change this to True to enable long context(prefill may slower).
-
match
:
name
:
"
^model$"
replace
:
...
...
ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml
0 → 100644
View file @
b443c7df
-
match
:
class
:
ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace
:
class
:
ktransformers.operators.RoPE.RotaryEmbeddingV3
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
-
match
:
name
:
"
^lm_head$"
# regular expression
class
:
torch.nn.Linear
# only match modules matching name and class simultaneously
replace
:
class
:
ktransformers.operators.linear.KTransformersLinear
# optimized Kernel on quantized data types
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
generate_op
:
"
KLinearMarlin"
prefill_op
:
"
KLinearTorch"
-
match
:
name
:
"
^model
\\
.layers
\\
.(?!.*self_attn
\\
.kv_b_proj).*$"
# regular expression
class
:
torch.nn.Linear
# only match modules matching name and class simultaneously
replace
:
class
:
ktransformers.operators.linear.KTransformersLinear
# optimized Kernel on quantized data types
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
generate_op
:
"
KLinearMarlin"
prefill_op
:
"
KLinearTorch"
-
match
:
name
:
"
^model
\\
.layers
\\
..*
\\
.mlp$"
class
:
ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace
:
class
:
ktransformers.operators.experts.KDeepseekV3MoE
# mlp module with custom forward function
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
-
match
:
class
:
ktransformers.models.modeling_deepseek_v3.MoEGate
replace
:
class
:
ktransformers.operators.gate.KMoEGate
kwargs
:
generate_device
:
"
cuda:0"
prefill_device
:
"
cuda:0"
-
match
:
name
:
"
^model
\\
.layers
\\
..*
\\
.mlp
\\
.experts$"
replace
:
class
:
ktransformers.operators.experts.KTransformersExperts
# custom MoE Kernel with expert paralleism
kwargs
:
prefill_device
:
"
cuda"
prefill_op
:
"
KExpertsTorch"
generate_device
:
"
cpu"
generate_op
:
"
KExpertsCPU"
out_device
:
"
cuda"
recursive
:
False
# don't recursively inject submodules of this module
# if want to use more VRAM, use experts Marlin and disable CUDA Graph(disable CUDA Graph may cause low performance)
#- match:
# name: "^model\\.layers\\..*\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
# kwargs:
# prefill_device: "cuda"
# prefill_op: "KExpertsTorch"
# generate_device: "cuda"
# generate_op: "KExpertsMarlin"
# recursive: False # don't recursively inject submodules of this module
-
match
:
name
:
"
^model
\\
.layers
\\
..*
\\
.self_attn$"
replace
:
class
:
ktransformers.operators.attention.KDeepseekV2Attention
# optimized MLA implementation
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
-
match
:
name
:
"
^model$"
replace
:
class
:
"
ktransformers.operators.models.KDeepseekV2Model"
kwargs
:
per_layer_prefill_intput_threshold
:
0
# 0 is close layer wise prefill
-
match
:
name
:
"
^model.embed_tokens"
replace
:
class
:
"
default"
kwargs
:
generate_device
:
"
cpu"
prefill_device
:
"
cpu"
\ No newline at end of file
ktransformers/server/backend/interfaces/ktransformers.py
View file @
b443c7df
...
...
@@ -14,6 +14,7 @@ from ktransformers.models.custom_cache import StaticCache
from
ktransformers.util.cuda_graph_runner
import
CUDAGraphRunner
from
ktransformers.local_chat
import
custom_models
,
default_optimize_rules
from
ktransformers.util.utils
import
get_device
from
ktransformers.operators.flashinfer_wrapper
import
flashinfer_enabled
,
MLAWrapperSingleton
warm_uped
=
False
...
...
@@ -186,6 +187,8 @@ class KTransformersInterface(TransformersInterface):
input_ids
=
input_ids
.
to
(
"cpu"
)
inputs_embeds
=
self
.
model
.
model
.
embed_tokens
(
input_ids
).
to
(
device
)
torch
.
cuda
.
set_device
(
device
)
if
flashinfer_enabled
:
MLAWrapperSingleton
.
need_plan_all
()
if
self
.
use_static_cache
:
logits
=
self
.
model
(
inputs_embeds
=
inputs_embeds
,
...
...
@@ -198,6 +201,8 @@ class KTransformersInterface(TransformersInterface):
else
:
logits
=
self
.
model
(
inputs_embeds
=
inputs_embeds
,
return_dict
=
False
)[
0
]
if
flashinfer_enabled
:
MLAWrapperSingleton
.
reset_buffer
()
self
.
prepare_logits_wrapper
(
input_ids
,
device
)
next_token
=
self
.
logits_to_token
(
logits
[
0
,
-
1
,
:])
yield
self
.
append_new_tokens
(
next_token
)
...
...
ktransformers/server/backend/interfaces/transformers.py
View file @
b443c7df
...
...
@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
for
i
in
range
(
1
,
self
.
args
.
max_new_tokens
):
with
torch
.
backends
.
cuda
.
sdp_kernel
(
enable_flash
=
False
,
enable_mem_efficient
=
False
,
enable_math
=
True
):
if
i
>
1
and
flashinfer_enabled
:
if
flashinfer_enabled
:
MLAWrapperSingleton
.
plan_all
(
None
,
None
,
None
,
self
.
active_cache_position
.
to
(
torch
.
int32
)
+
1
,
num_heads
=
self
.
model
.
config
.
num_attention_heads
,
head_dim_ckv
=
self
.
model
.
config
.
kv_lora_rank
,
head_dim_kpe
=
self
.
model
.
config
.
qk_rope_head_dim
,
page_size
=
self
.
cache
.
page_size
,
...
...
ktransformers/util/custom_gguf.py
View file @
b443c7df
...
...
@@ -330,6 +330,8 @@ class GGUFLoader:
values
=
GGML_DEQUANTIZE
[
ggml_name
](
data
)
values
=
torch
.
from_numpy
(
values
.
copy
())
if
ggml_name
==
"BF16"
:
values
=
values
.
view
(
torch
.
bfloat16
)
values
=
values
.
view
(
shape
[
-
2
::
-
1
])
return
values
...
...
ktransformers/util/utils.py
View file @
b443c7df
...
...
@@ -21,6 +21,18 @@ from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
warm_uped
=
False
def
get_compute_capability
(
device
:
torch
.
device
=
None
):
if
torch
.
cuda
.
is_available
():
if
device
is
None
:
num_gpus
=
torch
.
cuda
.
device_count
()
min_compute_capability_major
=
100
for
gpu_id
in
range
(
num_gpus
):
gpu_props
=
torch
.
cuda
.
get_device_properties
(
gpu_id
)
min_compute_capability_major
=
min
(
min_compute_capability_major
,
gpu_props
.
major
)
return
min_compute_capability_major
else
:
return
torch
.
cuda
.
get_device_properties
(
device
)
def
set_module
(
model
,
submodule_key
,
module
):
tokens
=
submodule_key
.
split
(
'.'
)
sub_tokens
=
tokens
[:
-
1
]
...
...
@@ -164,6 +176,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
inputs_embeds
=
model
.
model
.
embed_tokens
(
inputs
.
to
(
"cpu"
))
else
:
inputs_embeds
=
model
.
model
.
embed_tokens
(
inputs
.
to
(
"cpu"
)).
to
(
torch_device
)
if
use_flashinfer_mla
:
MLAWrapperSingleton
.
need_plan_all
()
logits
=
model
(
inputs_embeds
=
inputs_embeds
,
cache_position
=
cache_position
,
past_key_values
=
past_key_values
,
return_dict
=
False
,
use_cache
=
True
)[
0
][:,
-
1
,:].
unsqueeze
(
0
).
clone
().
to
(
torch_device
)
...
...
@@ -186,6 +201,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
else
:
next_token
=
torch
.
argmax
(
next_token_scores
,
dim
=-
1
)
first_token_time
=
time
.
time
()
-
start_time
if
use_flashinfer_mla
:
MLAWrapperSingleton
.
reset_buffer
()
prefill_count
=
seq_length
prefill_time
=
first_token_time
...
...
@@ -203,22 +221,22 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
start_time
=
time
.
time
()
for
i
in
range
(
1
,
max_new_tokens
):
if
use_flashinfer_mla
:
MLAWrapperSingleton
.
plan_all
(
None
,
None
,
None
,
position_ids
.
squeeze
(
1
)
+
1
,
num_heads
,
head_dim_ckv
,
head_dim_kpe
,
past_key_values
.
page_size
,
q_head_dim
**
(
-
0.5
),
torch
.
bfloat16
,
torch
.
bfloat16
)
global
warm_uped
if
use_cuda_graph
and
(
(
warm_uped
==
True
and
int
(
i
)
==
1
)
or
(
warm_uped
==
False
and
int
(
i
)
==
2
)
):
warm_uped
=
True
cuda_graph_runner
=
CUDAGraphRunner
()
cuda_graph_runner
.
capture
(
model
,
next_token
.
unsqueeze
(
0
),
position_ids
,
cache_position
,
past_key_values
,
torch_device
,
return_dict
=
False
,
use_cache
=
True
)
if
i
>
1
and
use_flashinfer_mla
:
MLAWrapperSingleton
.
plan_all
(
None
,
None
,
None
,
position_ids
.
squeeze
(
1
)
+
1
,
num_heads
,
head_dim_ckv
,
head_dim_kpe
,
past_key_values
.
page_size
,
q_head_dim
**
(
-
0.5
),
torch
.
bfloat16
,
torch
.
bfloat16
)
next_token
=
decode_one_tokens
(
cuda_graph_runner
,
next_token
.
unsqueeze
(
0
),
position_ids
,
cache_position
,
past_key_values
,
use_cuda_graph
).
to
(
torch_device
)
inputs
=
torch
.
cat
((
inputs
,
next_token
.
unsqueeze
(
0
)),
dim
=-
1
)
generated_ids
[:,
cache_position
]
=
next_token
.
int
()
tokens
.
append
(
int
(
next_token
))
seq_length
+=
1
if
next_token
[
0
].
item
()
==
tokenizer
.
eos_token_id
or
tokenizer
.
decode
(
next_token
)
==
'<|im_end|>'
:
if
next_token
[
0
].
item
()
==
tokenizer
.
eos_token_id
or
tokenizer
.
decode
(
next_token
.
tolist
()
)
==
'<|im_end|>'
:
print
(
stream
.
end
(),
end
=
""
,
flush
=
True
)
break
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment