OpenDAS / ktransformers · Commit f536a708

Unverified commit f536a708, authored Aug 29, 2024 by UnicornChan, committed via GitHub on Aug 29, 2024.

Merge pull request #62 from Azure-Tang/main

[Fix] Fix problem that ktransformers cannot offload whole layer in cpu

Parents: 35d7aed2, 8747c099
Showing 8 changed files, with 48 additions and 45 deletions (+48, -45):

  Dockerfile (+1, -1)
  ktransformers/__init__.py (+3, -3)
  ktransformers/local_chat.py (+4, -2)
  ktransformers/operators/experts.py (+3, -3)
  ktransformers/operators/linear.py (+2, -2)
  ktransformers/operators/models.py (+5, -4)
  ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml (+28, -28)
  ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml (+2, -2)
Dockerfile

@@ -25,7 +25,7 @@ rm -rf /var/lib/apt/lists/* &&
     cd ktransformers &&
     git submodule init &&
     git submodule update &&
-    pip install ninja pyproject numpy &&
+    pip install ninja pyproject numpy cpufeature &&
     pip install flash-attn &&
     CPU_INSTRUCT=NATIVE KTRANSFORMERS_FORCE_BUILD=TRUE TORCH_CUDA_ARCH_LIST="8.0;8.6;8.7;8.9" pip install . --no-build-isolation --verbose &&
     pip cache purge
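
The Docker image now installs cpufeature, a small package for detecting the host CPU's SIMD capabilities, presumably consulted when the build picks instruction-set flags. A minimal sketch of querying it (the cpufeature.CPUFeature dict is the package's entry point; the specific flag names checked here are illustrative):

    # Sketch: query host CPU SIMD flags via the cpufeature package
    # added to the Dockerfile above. Flag names are illustrative.
    import cpufeature

    features = cpufeature.CPUFeature
    for flag in ("AVX2", "AVX512f"):
        print(f"{flag}: {features.get(flag, False)}")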
ktransformers/__init__.py

@@ -5,7 +5,7 @@ Description  :
 Author       : kkk1nak0
 Date         : 2024-08-15 07:34:46
 Version      : 1.0.0
-LastEditors  : chenxl
-LastEditTime : 2024-08-28 15:19:03
+LastEditors  : Azure-Tang
+LastEditTime : 2024-08-29 22:35:51
 '''
-__version__ = "0.1.3"
+__version__ = "0.1.4"
\ No newline at end of file
ktransformers/local_chat.py

@@ -67,6 +67,7 @@ def local_chat(
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
     if mode == 'long_context':
+        assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
         torch.set_default_dtype(torch.float16)
     else:
         torch.set_default_dtype(config.torch_dtype)

@@ -143,6 +144,7 @@ def local_chat(
     input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
     if mode == 'long_context':
+        assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
+            "please change max_seq_len in ~/.ktransformers/config.yaml"
     torch.set_default_dtype(
     ...
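
The second assertion enforces a simple bound: the prompt plus the requested new tokens must fit inside the configured context window. A standalone sketch of the same arithmetic (all values are illustrative; only the config path comes from the diff):

    # Bound enforced by the new assert in local_chat.py, in isolation.
    max_seq_len = 32768      # illustrative; set in ~/.ktransformers/config.yaml
    prompt_len = 30000       # stands in for input_tensor.shape[1]
    max_new_tokens = 4096    # requested generation budget

    if max_seq_len <= prompt_len + max_new_tokens:
        raise ValueError("please change max_seq_len in ~/.ktransformers/config.yaml")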
ktransformers/operators/experts.py

@@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang, chenht2022
 Date         : 2024-07-25 11:25:24
 Version      : 0.1.0
 LastEditors  : Azure
-LastEditTime : 2024-08-27 03:50:23
+LastEditTime : 2024-08-29 09:41:10
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 '''

@@ -202,7 +202,7 @@ class KExpertsCPU(KExpertsBase):
     def forward(self, input_tensor, expert_ids, weights):
         # generate, capture and run cuda graph
         # print(expert_ids)
-        if input_tensor.size(0)==1:
+        if input_tensor.size(0)==1 and torch.cuda.is_current_stream_capturing():
             # TODO: this branch is unreachable, but the shape of input_tensor([1,hidden_size]) and input_tensor_cpu([hidden_size]) is not compatible
             #print("capturing experts")
             KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)

@@ -636,7 +636,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
         topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode"):
+        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
             self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
             if self.config.n_shared_experts is not None:
                 y_ = self.shared_experts(identity).squeeze(0)
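
The recurring pattern in this commit is gating the CUDA-graph fast path on torch.cuda.is_current_stream_capturing(), so the staging-copy branch only runs while a graph is actually being recorded; a plain eager call (for example when a layer runs entirely on CPU) falls through instead of hitting it. A minimal, self-contained sketch of that pattern (buffer names and shapes are illustrative, not the ktransformers internals):

    # Sketch of gating work on CUDA-graph capture, as the changes above do.
    import torch

    def submit_one_token(x, staging_cpu):
        # Only take the capture-time path while a CUDA graph is recording.
        if x.size(0) == 1 and torch.cuda.is_current_stream_capturing():
            staging_cpu.copy_(x, non_blocking=True)  # device -> pinned host buffer
            return True
        return False  # eager path: nothing staged

    if torch.cuda.is_available():
        x = torch.randn(1, 16, device="cuda")
        staging = torch.empty(1, 16, pin_memory=True)
        print(submit_one_token(x, staging))  # False: no capture in progress
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):  # inside, is_current_stream_capturing() is True
            submit_one_token(x, staging)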
ktransformers/operators/linear.py

@@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang
 Date         : 2024-07-25 11:25:24
 Version      : 0.1.0
 LastEditors  : Azure
-LastEditTime : 2024-08-14 14:57:04
+LastEditTime : 2024-08-29 09:11:16
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 '''

@@ -277,7 +277,7 @@ class KLinearCPUInfer(KLinearBase):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         origin_shape = x.shape # [batch_size, q_len, hidden_size]
-        if origin_shape[1] == 1:
+        if origin_shape[1] == 1 and torch.cuda.is_current_stream_capturing():
             out_device = x.device
             self.input_tensor_cpu.copy_(x, non_blocking=True)
             qlen = origin_shape[1]
ktransformers/operators/models.py

@@ -670,8 +670,9 @@ class KDeepseekV2Model(BaseInjectedModule):
             if self.transfer_map is not None and i in self.transfer_map:
                 prev_stream = torch.cuda.current_stream()
                 cur_device = self.transfer_map[i]
-                if cur_device not in self.stream_device_map:
+                if cur_device not in self.stream_device_map and cur_device.lower() != "cpu":
                     self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
-                torch.cuda.set_device(cur_device)
-                self.stream_device_map[cur_device].wait_stream(prev_stream)
-                torch.cuda.set_stream(self.stream_device_map[cur_device])
+                if cur_device.lower() != "cpu":
+                    torch.cuda.set_device(cur_device)
+                    self.stream_device_map[cur_device].wait_stream(prev_stream)
+                    torch.cuda.set_stream(self.stream_device_map[cur_device])
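
This is the change that makes a whole-layer CPU offload work: when a transfer_map entry points at "cpu", no CUDA stream is created or switched for it. A hedged standalone sketch of the same dispatch logic (the transfer_map/stream_device_map names mirror the code above; the mapping values are illustrative):

    # Per-layer device-transfer logic after the fix: CUDA streams are only
    # created/switched for CUDA targets, so an entry like {30: "cpu"} is valid.
    import torch

    transfer_map = {15: "cuda:1", 30: "cpu"}  # illustrative mapping
    stream_device_map = {}

    def maybe_transfer(i):
        if transfer_map is not None and i in transfer_map:
            prev_stream = torch.cuda.current_stream()
            cur_device = transfer_map[i]
            if cur_device not in stream_device_map and cur_device.lower() != "cpu":
                stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
            if cur_device.lower() != "cpu":
                torch.cuda.set_device(cur_device)
                stream_device_map[cur_device].wait_stream(prev_stream)
                torch.cuda.set_stream(stream_device_map[cur_device])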
ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml

@@ -7,7 +7,7 @@
     prefill_device: "cpu"
 - match:
-    name: "^model\\.layers\\.([0-9])\\."
+    name: "^model\\.layers\\.([0-9]|[1][0-4])\\."
     class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
   replace:
     class: ktransformers.operators.RoPE.YarnRotaryEmbedding

@@ -15,7 +15,7 @@
     generate_device: "cuda:0"
     prefill_device: "cuda:0"
 - match:
-    name: "^model\\.layers\\.([1][0-9])\\."
+    name: "^model\\.layers\\.([2][0-9]|[1][5-9])\\."
     class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
   replace:
     class: ktransformers.operators.RoPE.YarnRotaryEmbedding

@@ -23,7 +23,7 @@
     generate_device: "cuda:1"
     prefill_device: "cuda:1"
 - match:
-    name: "^model\\.layers\\.([2][0-9])\\."
+    name: "^model\\.layers\\.([3][0-9]|[4][0-4])\\."
     class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
   replace:
     class: ktransformers.operators.RoPE.YarnRotaryEmbedding

@@ -31,7 +31,7 @@
     generate_device: "cuda:2"
     prefill_device: "cuda:2"
 - match:
-    name: "^model\\.layers\\.([345][0-9])\\."
+    name: "^model\\.layers\\.([5][0-9]|[4][5-9])\\."
     class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
   replace:
     class: ktransformers.operators.RoPE.YarnRotaryEmbedding

@@ -40,7 +40,7 @@
     prefill_device: "cuda:3"
 - match:
-    name: "^model\\.layers\\.([0-9])\\.(?!self_attn).*$"  # regular expression
+    name: "^model\\.layers\\.([0-9]|[1][0-4])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
     class: torch.nn.Linear  # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types

@@ -50,7 +50,7 @@
     generate_op: "KLinearMarlin"
     prefill_op: "KLinearTorch"
 - match:
-    name: "^model\\.layers\\.([1][0-9])\\.(?!self_attn).*$"  # regular expression
+    name: "^model\\.layers\\.([2][0-9]|[1][5-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
     class: torch.nn.Linear  # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types

@@ -60,7 +60,7 @@
     generate_op: "KLinearMarlin"
     prefill_op: "KLinearTorch"
 - match:
-    name: "^model\\.layers\\.([2][0-9])\\.(?!self_attn).*$"  # regular expression
+    name: "^model\\.layers\\.([3][0-9]|[4][0-4])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
     class: torch.nn.Linear  # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types

@@ -70,7 +70,7 @@
     generate_op: "KLinearMarlin"
     prefill_op: "KLinearTorch"
 - match:
-    name: "^model\\.layers\\.([345][0-9])\\.(?!self_attn).*$"  # regular expression
+    name: "^model\\.layers\\.([5][0-9]|[4][5-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
     class: torch.nn.Linear  # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types

@@ -81,7 +81,7 @@
     prefill_op: "KLinearTorch"
 - match:
-    name: "^model\\.layers\\.([0-9])\\.mlp$"
+    name: "^model\\.layers\\.([0-9]|[1][0-4])\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
   replace:
     class: ktransformers.operators.experts.KDeepseekV2MoE  # mlp module with custom forward function

@@ -89,7 +89,7 @@
     generate_device: "cuda:0"
     prefill_device: "cuda:0"
 - match:
-    name: "^model\\.layers\\.([1][0-9])\\.mlp$"
+    name: "^model\\.layers\\.([2][0-9]|[1][5-9])\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
   replace:
     class: ktransformers.operators.experts.KDeepseekV2MoE  # mlp module with custom forward function

@@ -97,7 +97,7 @@
     generate_device: "cuda:1"
     prefill_device: "cuda:1"
 - match:
-    name: "^model\\.layers\\.([2][0-9])\\.mlp$"
+    name: "^model\\.layers\\.([3][0-9]|[4][0-4])\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
   replace:
     class: ktransformers.operators.experts.KDeepseekV2MoE  # mlp module with custom forward function

@@ -105,7 +105,7 @@
     generate_device: "cuda:2"
     prefill_device: "cuda:2"
 - match:
-    name: "^model\\.layers\\.([345][0-9])\\.mlp$"
+    name: "^model\\.layers\\.([5][0-9]|[4][5-9])\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
   replace:
     class: ktransformers.operators.experts.KDeepseekV2MoE  # mlp module with custom forward function

@@ -114,7 +114,7 @@
     prefill_device: "cuda:3"
 - match:
-    name: "^model\\.layers\\.([0-9])\\.mlp\\.experts$"
+    name: "^model\\.layers\\.([0-9]|[1][0-4])\\.mlp\\.experts$"
   replace:
     class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
     kwargs:

@@ -125,7 +125,7 @@
       out_device: "cuda:0"
   recursive: False  # don't recursively inject submodules of this module
 - match:
-    name: "^model\\.layers\\.([1][0-9])\\.mlp\\.experts$"
+    name: "^model\\.layers\\.([2][0-9]|[1][5-9])\\.mlp\\.experts$"
   replace:
     class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
     kwargs:

@@ -136,7 +136,7 @@
       out_device: "cuda:1"
   recursive: False  # don't recursively inject submodules of this module
 - match:
-    name: "^model\\.layers\\.([2][0-9])\\.mlp\\.experts$"
+    name: "^model\\.layers\\.([3][0-9]|[4][0-4])\\.mlp\\.experts$"
   replace:
     class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
     kwargs:

@@ -147,7 +147,7 @@
       out_device: "cuda:2"
   recursive: False  # don't recursively inject submodules of this module
 - match:
-    name: "^model\\.layers\\.([345][0-9])\\.mlp\\.experts$"
+    name: "^model\\.layers\\.([5][0-9]|[4][5-9])\\.mlp\\.experts$"
   replace:
     class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
     kwargs:

@@ -159,28 +159,28 @@
   recursive: False  # don't recursively inject submodules of this module
 - match:
-    name: "^model\\.layers\\.([0-9])\\.self_attn$"
+    name: "^model\\.layers\\.([0-9]|[1][0-4])\\.self_attn$"
   replace:
     class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
 - match:
-    name: "^model\\.layers\\.([1][0-9])\\.self_attn$"
+    name: "^model\\.layers\\.([2][0-9]|[1][5-9])\\.self_attn$"
   replace:
     class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
 - match:
-    name: "^model\\.layers\\.([2][0-9])\\.self_attn$"
+    name: "^model\\.layers\\.([3][0-9]|[4][0-4])\\.self_attn$"
   replace:
     class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
     kwargs:
       generate_device: "cuda:2"
       prefill_device: "cuda:2"
 - match:
-    name: "^model\\.layers\\.([345][0-9])\\.self_attn$"
+    name: "^model\\.layers\\.([5][0-9]|[4][5-9])\\.self_attn$"
   replace:
     class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
     kwargs:

@@ -194,33 +194,33 @@
     kwargs:
       per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
       transfer_map:
-        10: "cuda:1"
-        20: "cuda:2"
-        30: "cuda:3"
+        15: "cuda:1"
+        30: "cuda:2"
+        45: "cuda:3"
 - match:
-    name: "^model\\.layers\\.([0-9])\\."
+    name: "^model\\.layers\\.([0-9]|[1][0-4])\\."
   replace:
     class: "default"
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
 - match:
-    name: "(^model\\.layers\\.([1][0-9])\\.)"
+    name: "(^model\\.layers\\.([2][0-9]|[1][5-9])\\.)"
   replace:
     class: "default"
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
 - match:
-    name: "(^model\\.layers\\.([2][0-9])\\.)"
+    name: "(^model\\.layers\\.([3][0-9]|[4][0-4])\\.)"
   replace:
     class: "default"
     kwargs:
       generate_device: "cuda:2"
       prefill_device: "cuda:2"
 - match:
-    name: "(^model\\.layers\\.([345][0-9])\\.)|(^model.norm)|(^lm_head)"
+    name: "(^model\\.layers\\.([5][0-9]|[4][5-9])\\.)|(^model.norm)|(^lm_head)"
   replace:
     class: "default"
     kwargs:
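
The rebalanced rules split DeepSeek-V2's 60 layers evenly, 15 per GPU (0-14, 15-29, 30-44, 45-59), in place of the old 10/10/10/30 split, and the transfer_map boundaries (15, 30, 45) move to match. A quick check that the new alternations cover exactly those ranges (plain Python with re; not part of the repo):

    # Verify the new layer-range regexes partition layers 0..59 as
    # 0-14 / 15-29 / 30-44 / 45-59, one pattern per GPU.
    import re

    patterns = [
        r"^model\.layers\.([0-9]|[1][0-4])\.",
        r"^model\.layers\.([2][0-9]|[1][5-9])\.",
        r"^model\.layers\.([3][0-9]|[4][0-4])\.",
        r"^model\.layers\.([5][0-9]|[4][5-9])\.",
    ]

    for gpu, pat in enumerate(patterns):
        layers = [i for i in range(60) if re.match(pat, f"model.layers.{i}.")]
        print(f"cuda:{gpu} -> layers {layers[0]}..{layers[-1]} ({len(layers)} layers)")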
ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml

@@ -24,7 +24,7 @@
     prefill_device: "cuda:1"
 - match:
-    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn).*$"  # regular expression
+    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
     class: torch.nn.Linear  # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types

@@ -35,7 +35,7 @@
     prefill_op: "KLinearTorch"
 - match:
-    name: "^model\\.layers\\.([345][0-9])\\.(?!self_attn).*$"  # regular expression
+    name: "^model\\.layers\\.([345][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
     class: torch.nn.Linear  # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
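
Both YAML files also narrow the negative lookahead from excluding all of self_attn to excluding only self_attn.kv_b_proj, which leaves the other attention submodules eligible for the KTransformersLinear replacement while kv_b_proj keeps its original class. A hedged illustration of the difference (plain re; the module names are examples, not necessarily DeepSeek-V2's exact projection names):

    # Old vs. new lookahead: the new pattern matches attention submodules
    # again, excluding only kv_b_proj.
    import re

    old = r"^model\.layers\.(0|[1-9]|[12][0-9])\.(?!self_attn).*$"
    new = r"^model\.layers\.(0|[1-9]|[12][0-9])\.(?!self_attn\.kv_b_proj).*$"

    for name in ("model.layers.3.self_attn.q_proj",
                 "model.layers.3.self_attn.kv_b_proj",
                 "model.layers.3.mlp.gate_proj"):
        print(name, bool(re.match(old, name)), bool(re.match(new, name)))
    # q_proj:    old False, new True  (now replaceable)
    # kv_b_proj: old False, new False (still excluded)
    # gate_proj: old True,  new True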