OpenDAS / ktransformers / Commits

Commit 85e2cc7b (Unverified), authored Feb 27, 2025 by Atream, committed by GitHub on Feb 27, 2025

Merge pull request #719 from kvcache-ai/fix-use-generation-json

use generation config from json file in official repo

Parents: 5e3c6b4f, e645d847
Showing 4 changed files with 57 additions and 21 deletions (+57 -21):

    ktransformers/local_chat.py                    +9  -9
    ktransformers/operators/attention.py           +30 -1
    ktransformers/operators/flashinfer_wrapper.py  +15 -9
    ktransformers/util/utils.py                    +3  -2
ktransformers/local_chat.py

@@ -111,11 +111,11 @@ def local_chat(
     try:
         model.generation_config = GenerationConfig.from_pretrained(model_path)
-    except:
+    except Exception as e:
+        print(f"generation config can't auto create, make default. Message: {e}")
         gen_config = GenerationConfig(
-            max_length=128,
-            temperature=0.6,
-            top_p=0.95,
+            temperature=0.7,
+            top_p=0.9,
             do_sample=True
         )
         model.generation_config = gen_config
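The fix makes local_chat prefer the generation_config.json shipped with the official model repo; the hard-coded GenerationConfig is only a fallback when that file cannot be loaded. A minimal standalone sketch of the same pattern follows (the model path is illustrative and not part of the commit):

    # Minimal sketch of the fallback pattern above, not the ktransformers code itself.
    # "deepseek-ai/DeepSeek-R1" is only an illustrative model path.
    from transformers import GenerationConfig

    def load_generation_config(model_path: str) -> GenerationConfig:
        try:
            # Reads generation_config.json from the official model repo/directory.
            return GenerationConfig.from_pretrained(model_path)
        except Exception as e:
            print(f"generation config can't auto create, make default. Message: {e}")
            # Fallback defaults, mirroring the values added in this commit.
            return GenerationConfig(temperature=0.7, top_p=0.9, do_sample=True)

    gen_config = load_generation_config("deepseek-ai/DeepSeek-R1")
    print(gen_config)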
ktransformers/operators/attention.py

@@ -338,7 +338,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value

-    def forward_linux_flashinfer(
+    def forward_linux_flashinfer_chunk(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,

@@ -512,6 +512,35 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value

+    def forward_linux_flashinfer(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+        if q_len <= self.chunck_size or not self.absorb_for_prefill:
+            return self.forward_linux_flashinfer_chunk(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position,
+                **kwargs,
+            )
+        assert False
+
     def forward_windows(
         self,
         hidden_states: torch.Tensor,
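The renamed forward_linux_flashinfer_chunk keeps the existing single-chunk FlashInfer path, while the new forward_linux_flashinfer only decides which path to take: short inputs, or runs without matrix-absorbed prefill, are delegated unchanged, and the long-prompt absorbed-prefill branch is left unimplemented (assert False) in this commit. A hedged sketch of that dispatch shape, with dummy tensor work and the attribute names taken from the diff:

    # Illustrative sketch of the dispatch pattern introduced above; not the real attention module.
    import torch

    class ChunkedAttentionStub:
        def __init__(self, chunck_size: int = 256, absorb_for_prefill: bool = False):
            self.chunck_size = chunck_size            # spelling follows the attribute in the diff
            self.absorb_for_prefill = absorb_for_prefill

        def forward_linux_flashinfer_chunk(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Stand-in for the existing single-chunk FlashInfer path.
            return hidden_states

        def forward_linux_flashinfer(self, hidden_states: torch.Tensor) -> torch.Tensor:
            bsz, q_len, _ = hidden_states.size()
            # Short inputs, or runs without matrix-absorbed prefill, reuse the chunk path unchanged.
            if q_len <= self.chunck_size or not self.absorb_for_prefill:
                return self.forward_linux_flashinfer_chunk(hidden_states)
            # The long-prompt absorbed-prefill branch is not implemented in this commit.
            raise NotImplementedError("absorbed multi-chunk prefill path")

    attn = ChunkedAttentionStub()
    out = attn.forward_linux_flashinfer(torch.zeros(1, 16, 32))
    print(out.shape)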
ktransformers/operators/flashinfer_wrapper.py

@@ -139,6 +139,11 @@ class MLAWrapper():
         )

     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
+        #print("run")
+        #print(self.wrapper._qo_indptr_buf)
+        #print(self.wrapper._kv_indptr_buf)
+        #print(self.wrapper._kv_indices_buf)
+        #print(self.wrapper._kv_len_arr_buf)
         return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)

 class MLAWrapperSingleton():

@@ -201,11 +206,12 @@ class MLAWrapperSingleton():
 if __name__ == "__main__":
     max_batch_size = 1
-    max_pages = 1
+    max_pages = 128
     page_size = 64
     num_heads = 128

-    q_len = 10
+    kv_len = 2069
+    q_len = 1
     q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
     q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
     ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")

@@ -218,7 +224,7 @@ if __name__ == "__main__":
         max_pages,
     )
-    kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
+    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
     qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
     wrapper.plan(
         qo_indptr,

@@ -244,15 +250,15 @@ if __name__ == "__main__":
     )
     v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)

-    print(k[:10].shape)
-    print(v[:10].shape)
+    print(k[:kv_len].shape)
+    print(v[:kv_len].shape)

     attn_ref, lse_ref = attention_ref(
         max_batch_size,
         torch.cat([q_nope, q_pe], dim=-1),
-        k[:10],
-        v[:10],
-        False,
+        k[:kv_len],
+        v[:kv_len],
+        True,
         192 ** (-0.5)
     )
     print(attn_ref.shape)
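The updated self-test now exercises a decode-like case: one query token (q_len = 1) attending to a 2069-token paged KV cache (kv_len = 2069 spread over 128 pages of 64), and the reference comparison slices the cache with kv_len instead of a hard-coded 10, with the flag that was False now True (presumably causal masking). attention_ref itself is not shown in this diff, so the following is only an assumed sketch of what such a reference check computes:

    # Hedged sketch of a reference attention check in the spirit of the test above.
    # attention_ref is not part of this diff, so its exact signature and semantics are assumed.
    import torch

    def reference_attention(q, k, v, causal, sm_scale):
        # q: (q_len, num_heads, head_dim), k: (kv_len, num_heads, head_dim), v: (kv_len, num_heads, v_dim)
        q_len, num_heads, _ = q.shape
        kv_len = k.shape[0]
        scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * sm_scale
        if causal:
            # Query i (aligned to the end of the cache) may attend to keys 0..kv_len-q_len+i.
            q_pos = torch.arange(kv_len - q_len, kv_len, device=q.device)[:, None]
            k_pos = torch.arange(kv_len, device=q.device)[None, :]
            scores = scores.masked_fill(k_pos > q_pos, float("-inf"))
        probs = torch.softmax(scores, dim=-1)
        return torch.einsum("hqk,khd->qhd", probs, v.float()).to(v.dtype)

    # Shapes mirroring the updated test: one new query token against a 2069-token cache.
    q = torch.randn(1, 128, 576, dtype=torch.bfloat16)
    k = torch.randn(2069, 128, 576, dtype=torch.bfloat16)
    v = torch.randn(2069, 128, 512, dtype=torch.bfloat16)
    out = reference_attention(q, k, v, causal=True, sm_scale=192 ** (-0.5))
    print(out.shape)  # (1, 128, 512)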
ktransformers/util/utils.py

@@ -183,8 +183,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
         )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
     generation_config, model_kwargs = model._prepare_generation_config(
-        None, max_length=max_new_tokens,
-        do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config
+        None, do_sample=True
+        # change this to modify generate config
+        #top_k=5, top_p=0.85, temperature=0.1
     )
     try: # transformers==4.43
         logits_warper = (
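Dropping the hard-coded top_k/top_p/temperature means _prepare_generation_config(None, do_sample=True) now starts from model.generation_config, so the sampling settings loaded from the official generation_config.json are the ones that actually drive decoding, with only do_sample forced on. A small sketch of that precedence using the public GenerationConfig API (illustrative values; not the private transformers helper itself):

    # Hedged sketch of the precedence the change relies on: explicit kwargs override the
    # config loaded from generation_config.json, and everything else is kept as-is.
    import copy
    from transformers import GenerationConfig

    # Stand-in for what GenerationConfig.from_pretrained(model_path) would return
    # from an official repo's generation_config.json (values are illustrative).
    model_generation_config = GenerationConfig(temperature=0.6, top_p=0.95, do_sample=False)

    # Mirrors _prepare_generation_config(None, do_sample=True): keep the json-derived
    # sampling parameters, override only do_sample.
    effective = copy.deepcopy(model_generation_config)
    effective.update(do_sample=True)

    print(effective.do_sample, effective.temperature, effective.top_p)  # True 0.6 0.95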