Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ktransformers
Commits
6735beb5
Commit
6735beb5
authored
Aug 29, 2024
by
TangJingqi
Browse files
Fix cannot offload whole layer in cpu
parent
35d7aed2
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
14 additions
and
11 deletions
+14
-11
ktransformers/local_chat.py
ktransformers/local_chat.py
+4
-2
ktransformers/operators/experts.py
ktransformers/operators/experts.py
+3
-3
ktransformers/operators/linear.py
ktransformers/operators/linear.py
+2
-2
ktransformers/operators/models.py
ktransformers/operators/models.py
+5
-4
No files found.
ktransformers/local_chat.py
View file @
6735beb5
...
@@ -67,6 +67,7 @@ def local_chat(
...
@@ -67,6 +67,7 @@ def local_chat(
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_path
,
trust_remote_code
=
True
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_path
,
trust_remote_code
=
True
)
config
=
AutoConfig
.
from_pretrained
(
model_path
,
trust_remote_code
=
True
)
config
=
AutoConfig
.
from_pretrained
(
model_path
,
trust_remote_code
=
True
)
if
mode
==
'long_context'
:
if
mode
==
'long_context'
:
assert
config
.
architectures
[
0
]
==
"LlamaForCausalLM"
,
"only LlamaForCausalLM support long_context mode"
torch
.
set_default_dtype
(
torch
.
float16
)
torch
.
set_default_dtype
(
torch
.
float16
)
else
:
else
:
torch
.
set_default_dtype
(
config
.
torch_dtype
)
torch
.
set_default_dtype
(
config
.
torch_dtype
)
...
@@ -143,8 +144,9 @@ def local_chat(
...
@@ -143,8 +144,9 @@ def local_chat(
input_tensor
=
tokenizer
.
apply_chat_template
(
input_tensor
=
tokenizer
.
apply_chat_template
(
messages
,
add_generation_prompt
=
True
,
return_tensors
=
"pt"
messages
,
add_generation_prompt
=
True
,
return_tensors
=
"pt"
)
)
assert
Config
().
long_context_config
[
'max_seq_len'
]
>
input_tensor
.
shape
[
1
]
+
max_new_tokens
,
\
if
mode
==
'long_context'
:
"please change max_seq_len in ~/.ktransformers/config.yaml"
assert
Config
().
long_context_config
[
'max_seq_len'
]
>
input_tensor
.
shape
[
1
]
+
max_new_tokens
,
\
"please change max_seq_len in ~/.ktransformers/config.yaml"
torch
.
set_default_dtype
(
torch
.
set_default_dtype
(
torch
.
bfloat16
torch
.
bfloat16
)
# TODO: Remove this, replace dtype using config
)
# TODO: Remove this, replace dtype using config
...
...
ktransformers/operators/experts.py
View file @
6735beb5
...
@@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang, chenht2022
...
@@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang, chenht2022
Date : 2024-07-25 11:25:24
Date : 2024-07-25 11:25:24
Version : 0.1.0
Version : 0.1.0
LastEditors : Azure
LastEditors : Azure
LastEditTime : 2024-08-2
7
0
3:50:23
LastEditTime : 2024-08-2
9
0
9:41:10
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
'''
...
@@ -202,7 +202,7 @@ class KExpertsCPU(KExpertsBase):
...
@@ -202,7 +202,7 @@ class KExpertsCPU(KExpertsBase):
def
forward
(
self
,
input_tensor
,
expert_ids
,
weights
):
def
forward
(
self
,
input_tensor
,
expert_ids
,
weights
):
# generate, capture and run cuda graph
# generate, capture and run cuda graph
# print(expert_ids)
# print(expert_ids)
if
input_tensor
.
size
(
0
)
==
1
:
if
input_tensor
.
size
(
0
)
==
1
and
torch
.
cuda
.
is_current_stream_capturing
()
:
# TODO: this branch is unreachable, but the shape of input_tensor([1,hidden_size]) and input_tensor_cpu([hidden_size]) is not compatible
# TODO: this branch is unreachable, but the shape of input_tensor([1,hidden_size]) and input_tensor_cpu([hidden_size]) is not compatible
#print("capturing experts")
#print("capturing experts")
KExpertsCPU
.
input_tensor_cpu
.
copy_
(
input_tensor
,
non_blocking
=
True
)
KExpertsCPU
.
input_tensor_cpu
.
copy_
(
input_tensor
,
non_blocking
=
True
)
...
@@ -636,7 +636,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
...
@@ -636,7 +636,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
topk_idx
,
topk_weight
,
aux_loss
=
self
.
gate
(
hidden_states
)
topk_idx
,
topk_weight
,
aux_loss
=
self
.
gate
(
hidden_states
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_states
.
shape
[
-
1
])
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_states
.
shape
[
-
1
])
if
sequence_length
==
1
and
hasattr
(
self
.
experts
.
generate_experts
,
"submit_for_one_decode"
):
if
sequence_length
==
1
and
hasattr
(
self
.
experts
.
generate_experts
,
"submit_for_one_decode"
)
and
torch
.
cuda
.
is_current_stream_capturing
()
:
self
.
experts
.
generate_experts
.
submit_for_one_decode
(
hidden_states
[
0
],
topk_idx
[
0
],
topk_weight
[
0
])
self
.
experts
.
generate_experts
.
submit_for_one_decode
(
hidden_states
[
0
],
topk_idx
[
0
],
topk_weight
[
0
])
if
self
.
config
.
n_shared_experts
is
not
None
:
if
self
.
config
.
n_shared_experts
is
not
None
:
y_
=
self
.
shared_experts
(
identity
).
squeeze
(
0
)
y_
=
self
.
shared_experts
(
identity
).
squeeze
(
0
)
...
...
ktransformers/operators/linear.py
View file @
6735beb5
...
@@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang
...
@@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang
Date : 2024-07-25 11:25:24
Date : 2024-07-25 11:25:24
Version : 0.1.0
Version : 0.1.0
LastEditors : Azure
LastEditors : Azure
LastEditTime : 2024-08-
14 14:57:04
LastEditTime : 2024-08-
29 09:11:16
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
'''
...
@@ -277,7 +277,7 @@ class KLinearCPUInfer(KLinearBase):
...
@@ -277,7 +277,7 @@ class KLinearCPUInfer(KLinearBase):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
origin_shape
=
x
.
shape
# [batch_size, q_len, hidden_size]
origin_shape
=
x
.
shape
# [batch_size, q_len, hidden_size]
if
origin_shape
[
1
]
==
1
:
if
origin_shape
[
1
]
==
1
and
torch
.
cuda
.
is_current_stream_capturing
()
:
out_device
=
x
.
device
out_device
=
x
.
device
self
.
input_tensor_cpu
.
copy_
(
x
,
non_blocking
=
True
)
self
.
input_tensor_cpu
.
copy_
(
x
,
non_blocking
=
True
)
qlen
=
origin_shape
[
1
]
qlen
=
origin_shape
[
1
]
...
...
ktransformers/operators/models.py
View file @
6735beb5
...
@@ -670,11 +670,12 @@ class KDeepseekV2Model(BaseInjectedModule):
...
@@ -670,11 +670,12 @@ class KDeepseekV2Model(BaseInjectedModule):
if
self
.
transfer_map
is
not
None
and
i
in
self
.
transfer_map
:
if
self
.
transfer_map
is
not
None
and
i
in
self
.
transfer_map
:
prev_stream
=
torch
.
cuda
.
current_stream
()
prev_stream
=
torch
.
cuda
.
current_stream
()
cur_device
=
self
.
transfer_map
[
i
]
cur_device
=
self
.
transfer_map
[
i
]
if
cur_device
not
in
self
.
stream_device_map
:
if
cur_device
not
in
self
.
stream_device_map
and
cur_device
.
lower
()
!=
"cpu"
:
self
.
stream_device_map
[
cur_device
]
=
torch
.
cuda
.
Stream
(
cur_device
)
self
.
stream_device_map
[
cur_device
]
=
torch
.
cuda
.
Stream
(
cur_device
)
torch
.
cuda
.
set_device
(
cur_device
)
if
cur_device
.
lower
()
!=
"cpu"
:
self
.
stream_device_map
[
cur_device
].
wait_stream
(
prev_stream
)
torch
.
cuda
.
set_device
(
cur_device
)
torch
.
cuda
.
set_stream
(
self
.
stream_device_map
[
cur_device
])
self
.
stream_device_map
[
cur_device
].
wait_stream
(
prev_stream
)
torch
.
cuda
.
set_stream
(
self
.
stream_device_map
[
cur_device
])
hidden_states
=
hidden_states
.
to
(
hidden_states
=
hidden_states
.
to
(
self
.
transfer_map
[
i
],
non_blocking
=
True
self
.
transfer_map
[
i
],
non_blocking
=
True
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment