Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ox696c
ktransformers
Commits
89f8218a
Unverified
Commit
89f8218a
authored
Feb 19, 2025
by
Atream
Committed by
GitHub
Feb 19, 2025
Browse files
Merge pull request #487 from kvcache-ai/clean_pr
clean PR code and disable flashinfer
parents
cf4da5fd
a5295183
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
23 deletions
+13
-23
ktransformers/operators/attention.py
ktransformers/operators/attention.py
+7
-17
ktransformers/operators/flashinfer_wrapper.py
ktransformers/operators/flashinfer_wrapper.py
+1
-1
ktransformers/server/backend/interfaces/transformers.py
ktransformers/server/backend/interfaces/transformers.py
+5
-5
No files found.
ktransformers/operators/attention.py
View file @
89f8218a
...
@@ -58,18 +58,10 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
...
@@ -58,18 +58,10 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
def
get_absorbed
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
get_absorbed
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
not
(
hasattr
(
self
,
'q_absorb'
)
and
hasattr
(
self
,
'out_absorb'
)):
if
not
(
hasattr
(
self
,
'q_absorb'
)
and
hasattr
(
self
,
'out_absorb'
)):
kv_b_proj
=
self
.
kv_b_proj
.
weight
.
view
(
self
.
num_heads
,
-
1
,
self
.
kv_lora_rank
)
kv_b_proj
=
self
.
kv_b_proj
.
weight
.
view
(
self
.
num_heads
,
-
1
,
self
.
kv_lora_rank
)
q_absorb
=
kv_b_proj
[:,
:
self
.
qk_nope_head_dim
,
:].
reshape
(
-
1
,
self
.
kv_lora_rank
)
self
.
q_absorb
=
kv_b_proj
[:,
:
self
.
qk_nope_head_dim
,
:].
view
(
self
.
num_heads
,
self
.
qk_nope_head_dim
,
self
.
kv_lora_rank
)
out_absorb
=
kv_b_proj
[:,
self
.
qk_nope_head_dim
:,
:].
reshape
(
-
1
,
self
.
kv_lora_rank
)
self
.
out_absorb
=
kv_b_proj
[:,
self
.
qk_nope_head_dim
:,
:].
view
(
self
.
num_heads
,
self
.
v_head_dim
,
self
.
kv_lora_rank
)
self
.
q_absorb
=
nn
.
Linear
(
self
.
kv_lora_rank
,
self
.
num_heads
*
self
.
qk_nope_head_dim
,
bias
=
False
,
dtype
=
q_absorb
.
dtype
,
device
=
q_absorb
.
device
)
return
self
.
q_absorb
,
self
.
out_absorb
self
.
q_absorb
.
weight
.
data
=
q_absorb
self
.
out_absorb
=
nn
.
Linear
(
self
.
kv_lora_rank
,
self
.
num_heads
*
self
.
v_head_dim
,
bias
=
False
,
dtype
=
out_absorb
.
dtype
,
device
=
out_absorb
.
device
)
self
.
out_absorb
.
weight
.
data
=
out_absorb
#del self.orig_module.kv_b_proj
q_absorb
=
self
.
q_absorb
.
weight
.
view
(
self
.
num_heads
,
self
.
qk_nope_head_dim
,
self
.
kv_lora_rank
)
out_absorb
=
self
.
out_absorb
.
weight
.
view
(
self
.
num_heads
,
self
.
v_head_dim
,
self
.
kv_lora_rank
)
return
q_absorb
,
out_absorb
def
forward_chunck
(
def
forward_chunck
(
self
,
self
,
...
@@ -105,7 +97,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
...
@@ -105,7 +97,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
if
self
.
layer_idx
is
None
:
if
self
.
layer_idx
is
None
:
raise
ValueError
(
raise
ValueError
(
f
"The cache structure has changed since version v4.36. If you are using
{
self
.
__class__
.
__name__
}
"
f
"The cache structure has changed since
transformer
version v4.36. If you are using
{
self
.
__class__
.
__name__
}
"
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
"with a layer index."
)
)
...
@@ -129,8 +121,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
...
@@ -129,8 +121,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
# compressed_kv [pages, page_size, 1, self.kv_lora_rank]
# compressed_kv [pages, page_size, 1, self.kv_lora_rank]
q_absorb
,
out_absorb
=
self
.
get_absorbed
()
q_absorb
,
out_absorb
=
self
.
get_absorbed
()
# if hasattr(self.orig_module, 'kv_b_proj'):
# del self.orig_module.kv_b_proj
# q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
# q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
# q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
# q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
...
@@ -227,7 +217,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
...
@@ -227,7 +217,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
if
self
.
layer_idx
is
None
:
if
self
.
layer_idx
is
None
:
raise
ValueError
(
raise
ValueError
(
f
"The cache structure has changed since version v4.36. If you are using
{
self
.
__class__
.
__name__
}
"
f
"The cache structure has changed since
transformer
version v4.36. If you are using
{
self
.
__class__
.
__name__
}
"
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
"with a layer index."
)
)
...
@@ -379,7 +369,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
...
@@ -379,7 +369,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
if
self
.
layer_idx
is
None
:
if
self
.
layer_idx
is
None
:
raise
ValueError
(
raise
ValueError
(
f
"The cache structure has changed since version v4.36. If you are using
{
self
.
__class__
.
__name__
}
"
f
"The cache structure has changed since version
transformer verision
v4.36. If you are using
{
self
.
__class__
.
__name__
}
"
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
"with a layer index."
)
)
...
...
ktransformers/operators/flashinfer_wrapper.py
View file @
89f8218a
...
@@ -9,7 +9,7 @@ flashinfer_enabled = False
...
@@ -9,7 +9,7 @@ flashinfer_enabled = False
try
:
try
:
import
flashinfer
import
flashinfer
flashinfer_enabled
=
Tru
e
flashinfer_enabled
=
False
# disabled now, TODO:use new version of flashinfer and enabl
e
print
(
"found flashinfer"
)
print
(
"found flashinfer"
)
except
ImportError
:
except
ImportError
:
...
...
ktransformers/server/backend/interfaces/transformers.py
View file @
89f8218a
...
@@ -381,13 +381,13 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -381,13 +381,13 @@ class TransformersInterface(BackendInterfaceBase):
self
.
profiler
.
create_and_start_timer
(
"prefill"
)
self
.
profiler
.
create_and_start_timer
(
"prefill"
)
if
Config
().
user_force_think
:
think
=
'<think>
\n
'
print
(
think
,
end
=
""
,
flush
=
True
)
yield
think
for
t
in
self
.
prefill
(
input_ids
,
self
.
check_is_new
(
thread_id
)):
for
t
in
self
.
prefill
(
input_ids
,
self
.
check_is_new
(
thread_id
)):
# output think token after prefill done
# output think token after prefill done
if
Config
().
user_force_think
:
think
=
'<think>
\n
'
print
(
think
,
end
=
""
,
flush
=
True
)
yield
think
if
t
is
not
None
:
if
t
is
not
None
:
print
(
t
,
end
=
""
,
flush
=
True
)
print
(
t
,
end
=
""
,
flush
=
True
)
yield
t
yield
t
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment