OpenDAS / ktransformers

Commit d9b2895b
Authored Feb 25, 2025 by Atream

Merge branch 'fix-update-flashinfer_wrapper_local_chat' into develop-0.2.2

Parents: 7e5962af, 477ac28a
Showing 4 changed files with 15 additions and 4 deletions (+15 −4):

  ktransformers/operators/attention.py                                      +1 −2
  ktransformers/operators/flashinfer_wrapper.py                             +9 −2
  ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml   +4 −0
  ktransformers/util/utils.py                                               +1 −0
ktransformers/operators/attention.py (view file @ d9b2895b)

@@ -435,7 +435,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                                               q_nope.dtype, compressed_kv.dtype)
            attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)
            """
            k = (
                torch.cat([compressed_kv, k_pe], dim=-1)

@@ -465,7 +464,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)  # [bsz, q_len, self.num_heads * self.v_head_dim]
            attn_output = self.o_proj(attn_output)
            return attn_output, None, past_key_value
        else:
            if past_key_value is not None:
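For orientation, a minimal shape sketch of the changed `.view` call above. Only the names come from the diff; the tensor sizes, and the assumption that `mla_wrapper.run` returns one `kv_lora_rank`-sized row per (token, head) pair, are illustrative:

    import torch

    # Hypothetical sizes, for illustration only.
    bsz, q_len, num_heads, kv_lora_rank, v_head_dim = 1, 8, 128, 512, 128

    # Stand-in for mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe):
    # assumed flat output, one row per (token, head) pair.
    flat = torch.zeros(bsz * q_len * num_heads, kv_lora_rank)

    # The new line reinterprets that as [bsz, q_len, num_heads, kv_lora_rank] ...
    attn_output = flat.view(bsz, q_len, num_heads, kv_lora_rank)
    assert attn_output.shape == (1, 8, 128, 512)

    # ... and the second hunk later flattens heads for o_proj:
    # [bsz, q_len, num_heads * v_head_dim].
    per_head = torch.zeros(bsz, q_len, num_heads, v_head_dim)  # placeholder after up-projection
    assert per_head.reshape(bsz, q_len, num_heads * v_head_dim).shape == (1, 8, 128 * 128)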
ktransformers/operators/flashinfer_wrapper.py (view file @ d9b2895b)

@@ -122,7 +122,7 @@ class MLAWrapper():
        if kv_indices is None:
            assert self.max_batch_size == 1
            kv_indices = self.kv_indices_buf
        self.wrapper.plan(qo_indptr,
                          kv_indptr,

@@ -189,7 +189,14 @@ class MLAWrapperSingleton():
    @classmethod
    def reset_buffer(cls):
        for device, wrapper in cls.wrappers.items():
            wrapper.qo_indptr_buf[1] = 1  # assert max_batch_size=1 here.

    @classmethod
    def update_buffer(cls, max_pages):
        for device, wrapper in cls.wrappers.items():
            wrapper.kv_indptr_buf[1] = max_pages  # assert max_batch_size=1 here.
            wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
            wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf

if __name__ == "__main__":
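The core of the fix is the new `update_buffer` classmethod: when the paged KV cache grows, the per-device index buffers must grow with it, and the inner flashinfer wrapper has to be rebound to the replacement tensor. A standalone sketch of that logic under the same single-request assumption (`grow_kv_buffers` is a hypothetical name; `wrapper` stands in for one entry of `cls.wrappers`):

    import torch

    def grow_kv_buffers(wrapper, max_pages: int, device: str) -> None:
        # With max_batch_size == 1 the single request owns every page,
        # so the indptr array is simply [0, max_pages].
        wrapper.kv_indptr_buf[1] = max_pages
        # Identity page table covering the whole grown cache.
        wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
        # Rebind the inner wrapper's buffer: kv_indices_buf was replaced,
        # not resized in place, so the old reference would otherwise keep
        # pointing at the smaller tensor.
        wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf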
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml (view file @ d9b2895b)

@@ -293,6 +293,7 @@
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      absorb_for_prefill: False

# GPU 1: layers 15–29
- match:

@@ -302,6 +303,7 @@
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
      absorb_for_prefill: False

# GPU 2: layers 30–44
- match:

@@ -311,6 +313,7 @@
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"
      absorb_for_prefill: False

# GPU 3: layers 45–60
- match:

@@ -320,6 +323,7 @@
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"
      absorb_for_prefill: False

# === Overall Model Replacement with Transfer Map ===
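All four per-GPU attention rules gain the same flag. An abridged, hypothetical rule entry showing where it sits (the `match`/`replace` layout follows ktransformers' optimize-rule convention; the regex and class path are illustrative, not copied from this file):

    - match:
        name: "^model\\.layers\\.(1[5-9]|2[0-9])\\."   # illustrative range for GPU 1 (layers 15–29)
      replace:
        class: ktransformers.operators.attention.KDeepseekV2Attention
        kwargs:
          generate_device: "cuda:1"
          prefill_device: "cuda:1"
          absorb_for_prefill: False   # new in this commit; the name suggests matrix absorption is skipped during prefill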
ktransformers/util/utils.py (view file @ d9b2895b)

@@ -177,6 +177,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
    else:
        inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
    if use_flashinfer_mla:
        MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
        MLAWrapperSingleton.need_plan_all()
    logits = model(
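Taken together with the wrapper change, the generation path now resizes and re-plans before the forward call whenever the flashinfer MLA backend is active. A hedged sketch of the call order (the function name, locals, and the `model(...)` keyword signature are stand-ins for the real `prefill_and_generate` internals):

    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton

    def forward_step(model, inputs_embeds, past_key_values, use_flashinfer_mla: bool):
        if use_flashinfer_mla:
            # Grow the singleton's per-device index buffers to the cache's
            # current page count (the line added in this diff) ...
            MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
            # ... and force every wrapper to re-plan before its next run().
            MLAWrapperSingleton.need_plan_all()
        return model(inputs_embeds=inputs_embeds, past_key_values=past_key_values)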