OpenDAS / ktransformers · Commits

Commit 7527619f (unverified), authored Feb 10, 2025 by UnicornChan, committed by GitHub on Feb 10, 2025.

Merge pull request #122 from kvcache-ai/feat-DeepSeekV3

[Feat] add support to DeepSeekV3

Parents: f4903d54, 6f0fe953

Showing 12 changed files with 1210 additions and 59 deletions (+1210, -59).
Changed files:

  ktransformers/operators/gate.py                                                +126    -0
  ktransformers/operators/linear.py                                               +12   -12
  ktransformers/operators/models.py                                               +14    -6
  ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml   +143    -0
  ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml          +143    -0
  ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml                     +63    -0
  ktransformers/server/backend/interfaces/ktransformers.py                       +101   -31
  ktransformers/server/backend/interfaces/transformers.py                          +7    -5
  ktransformers/server/config/config.py                                            +3    -1
  ktransformers/util/modeling_rope_utils.py                                      +592    -0
  requirements-local_chat.txt                                                      +1    -1
  setup.py                                                                         +5    -3
ktransformers/operators/gate.py (new file, mode 100644)

from typing import Any, Union
import numpy as np
import numpy.typing as npt
from torch import Tensor, nn
import torch.nn.functional as F
import torch
import sys, os
from ktransformers.operators.base_operator import BaseInjectedModule

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from cpuinfer_ext.moe import MOEConfig, MOE
import ctypes
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from abc import ABC, abstractmethod
import time


# class Base(BaseInjectedModule, ABC):
class KMoEGateBase(ABC):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 device: str = "cuda",
                 **kwargs):
        # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        super().__init__()
        self.key = key
        self.gguf_loader = gguf_loader
        self.config = config
        self.device = device
        self.orig_module = orig_module

    @abstractmethod
    def forward(self, input_tensor, expert_ids, weights):
        pass

    @abstractmethod
    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu", warmup: bool = False):
        pass

    @abstractmethod
    def unload():
        pass

    def load_weights(self, override_key: str | None = None, device: str = "cpu"):
        res = {}
        if override_key is not None:
            keys = override_key
        else:
            keys = [self.key]

        gate = None
        up = None
        down = None
        gate_type = None
        up_type = None
        down_type = None

        for key in keys:
            key = ".".join(key.split(".")[:-1])
            if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
                targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
                tensors = self.load_multi(key, targets, device=device)
                weight = tensors[".ffn_gate_inp.weight"]
                e_score_correction_bias = tensors[".exp_probs_b.bias"]
                weight_type = self.gguf_loader.tensor_info[key + ".ffn_gate_inp.weight"]["ggml_type"]
                e_score_correction_bias_type = self.gguf_loader.tensor_info[key + ".exp_probs_b.bias"]["ggml_type"]
            else:
                raise ValueError(f"Experts {key} not found in gguf_loader")
            res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias,
                   "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
        return res

    def load_multi(self, key: str, keys: list[str], device: str = "cpu"):
        tensors = {}
        for k in keys:
            tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device)
        return tensors


class KMoEGate(BaseInjectedModule, KMoEGateBase):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def forward(self, hidden_states) -> torch.Tensor:
        return self.orig_module.forward(hidden_states)

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        if device is None:
            device = self.device
        if w is None:
            w = self.load_weights(device=device)

        if isinstance(w, dict):
            self.weight_type = w["weight_type"]
            self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
            self.orig_module.weight = nn.Parameter(w["weight"])
            self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
        else:
            raise ValueError("Invalid weight type")
        self.orig_module.weight = self.orig_module.weight.to(device)
        self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)

    def unload(self):
        if self.weight is not None:
            self.weight = None
        if self.e_score_correction_bias is not None:
            self.e_score_correction_bias = None
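For reference, a minimal sketch (not part of the commit) of the dictionary that KMoEGateBase.load_weights() assembles for one layer. The `loader` instance and the `layer_key` prefix are hypothetical; the tensor names and ggml_type lookups follow the code above.

layer_key = "blk.3"  # hypothetical GGUF layer prefix; loader is a GGUFLoader instance
gate_weights = {
    "weight": loader.load_gguf_tensor(layer_key + ".ffn_gate_inp.weight", device="cpu"),
    "e_score_correction_bias": loader.load_gguf_tensor(layer_key + ".exp_probs_b.bias", device="cpu"),
    "weight_type": loader.tensor_info[layer_key + ".ffn_gate_inp.weight"]["ggml_type"],
    "e_score_correction_bias_type": loader.tensor_info[layer_key + ".exp_probs_b.bias"]["ggml_type"],
}

KMoEGate.load() then wraps "weight" and "e_score_correction_bias" in nn.Parameter and moves them to the target device.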
ktransformers/operators/linear.py

@@ -54,15 +54,15 @@ class KLinearBase(ABC):
         self.has_bias = False
         self.dtype = torch.get_default_dtype()
-        if orig_module is not None:
-            self.in_features = orig_module.in_features
-            self.out_features = orig_module.out_features
-        else:
-            shape = self.gguf_loader.tensor_info[key + ".weight"]["shape"]
-            if len(shape) == 1:
-                print("Warning: orig_module is not set, but has in_features or out_features equals to 1, can't get in_features and out_features from GGUF")
-            self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0]
-            self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1]
+        # if orig_module is not None:
+        #     self.in_features = orig_module.in_features
+        #     self.out_features = orig_module.out_features
+        # else:
+        shape = self.gguf_loader.tensor_info[key + ".weight"]["shape"]
+        if len(shape) == 1:
+            print("Warning: orig_module is not set, but has in_features or out_features equals to 1, can't get in_features and out_features from GGUF")
+        self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0]
+        self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1]

     @abstractmethod
     def forward(self, x: torch.Tensor) -> torch.Tensor:

@@ -138,10 +138,10 @@ class KLinearTorch(KLinearBase):
         if w is None:
             w = self.load_weight(device=device)

         if isinstance(w, nn.Parameter):
-            self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
+            self.w = w.to(dtype=self.dtype).T
             self.has_bias = False
         elif isinstance(w, tuple):
-            self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
+            self.w = w[0].to(dtype=self.dtype).T
             self.bias = w[1].to(dtype=self.dtype)
             self.has_bias = True
         else:

@@ -222,7 +222,7 @@ class KLinearMarlin(KLinearBase):
         x = x.to(self.device)
         orig_shape = list(x.shape)
         orig_dtype = x.dtype
-        x = x.reshape(-1, x.shape[-1])
+        x = x.reshape(-1, orig_shape[-1])
         marlin_s = self.marlin_s.to(x.dtype)
         x = KTransformersOps.gptq_marlin_gemm(
             x,
ktransformers/operators/models.py

@@ -625,6 +625,13 @@ class KDeepseekV2Model(BaseInjectedModule):
         if use_legacy_cache:
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
         past_key_values_length = past_key_values.get_usable_length(seq_length)

+        if inputs_embeds is None:
+            org_device = input_ids.device
+            # TODO move to embed_tokens's device, not hard code to cpu
+            input_ids = input_ids.to("cpu")
+            inputs_embeds = self.embed_tokens(input_ids).to(org_device)
+            input_ids = input_ids.to(org_device)
+
         if cache_position is None:
             past_seen_tokens = (

@@ -639,12 +646,6 @@ class KDeepseekV2Model(BaseInjectedModule):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        if inputs_embeds is None:
-            org_device = input_ids.device
-            input_ids = input_ids.to("cpu")
-            inputs_embeds = self.embed_tokens(input_ids)
-            input_ids = input_ids.to(org_device)
-
         if per_layer_prefill_flag:
             causal_mask = None
         else:

@@ -716,6 +717,8 @@ class KDeepseekV2Model(BaseInjectedModule):
                     self.load_layer_to(decoder_layer, InferenceState.PREFILL)
                     torch.cuda.empty_cache()
                 t4 = time.time()
+                # with open("log.txt", "a") as f:
+                #     f.write(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n")
                 layer_outputs = decoder_layer(
                     hidden_states,
                     attention_mask=causal_mask,

@@ -737,6 +740,7 @@ class KDeepseekV2Model(BaseInjectedModule):
             hidden_states = layer_outputs[0]
+            # @@@@@@@ TODO open this notes, tmp close to fit deepseekv3
             if use_cache:
                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]

@@ -744,6 +748,10 @@ class KDeepseekV2Model(BaseInjectedModule):
                 all_self_attns += (layer_outputs[1],)

         hidden_states = self.norm(hidden_states)
+        # with open("log.txt", "a") as f:
+        #     f.write(f"@@@After layers\n")
+        #     f.write(f"hidden_states={hidden_states}\n")
+        #     f.write(f"hidden_states.shape={hidden_states.shape}\n")
         if per_layer_prefill_flag:
             t6 = time.time()
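The block added above computes token embeddings on CPU (where the DeepSeek-V3 rules place model.embed_tokens) and moves the result back to the original device. A minimal illustration of that pattern with toy sizes; the tensors and shapes here are hypothetical, not taken from the model:

import torch
from torch import nn

embed_tokens = nn.Embedding(1000, 64)        # toy vocab/hidden sizes; the module lives on CPU
input_ids = torch.randint(0, 1000, (1, 8))   # pretend these arrived on a GPU device
org_device = input_ids.device                # remember where the ids came from
inputs_embeds = embed_tokens(input_ids.to("cpu")).to(org_device)
input_ids = input_ids.to(org_device)         # restore the ids to their original device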
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml (new file, mode 100644)

- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda:0"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda:0"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda:1"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda:1"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
      transfer_map:
        30: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml (new file, mode 100644)

- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda:0"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda:0"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda:1"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda:1"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
      transfer_map:
        30: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml (new file, mode 100644)

- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
\ No newline at end of file
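For context, a minimal sketch of how one of these rule files gets applied. This is not part of the commit: the import path, the meta-device model construction, and the placeholder paths are assumptions, while optimize_and_load_gguf and its argument order appear in the server interface changes below.

import torch
from transformers import AutoConfig, AutoModelForCausalLM
from ktransformers.optimize.optimize import optimize_and_load_gguf  # assumed module path

model_dir = "/path/to/DeepSeek-V3"        # placeholder: HF config/tokenizer directory
gguf_path = "/path/to/DeepSeek-V3-GGUF"   # placeholder: directory holding the GGUF weights
rule_path = "ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml"

config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
with torch.device("meta"):
    # build the module tree without allocating real weights; the rules then replace
    # matched modules and load the corresponding GGUF tensors onto their target devices
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
optimize_and_load_gguf(model, rule_path, gguf_path, config)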
ktransformers/server/backend/interfaces/ktransformers.py

@@ -24,8 +24,8 @@ class KTransformersInterface(TransformersInterface):
         self.args = args
         torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
-        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device)
-        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
+        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"

@@ -46,51 +46,61 @@ class KTransformersInterface(TransformersInterface):
         )
         optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)

-        device_map = self.model.gguf_loader.tensor_device_map
-        logger.info(f"{args.model_name} loaded from {args.model_dir} to {device_map}")
+        self.device_map = self.model.gguf_loader.tensor_device_map
+        # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
         self.cache = StaticCache(
             config=self.model.config,
             max_batch_size=args.batch_size,
             max_cache_len=args.cache_lens,
-            device=device_map,
+            device=self.device_map,
             dtype=self.model.dtype,
         )
-        logger.info(f"StaticCache (length={args.cache_lens}) created at {device_map}, batch size: {args.batch_size}")
-        self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
+        # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
+        try:
+            self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
+        except:
+            gen_config = GenerationConfig(max_length=128, temperature=0.7, top_p=0.9, do_sample=True)
+            self.model.generation_config = gen_config
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.streamer = TextStreamer(self.tokenizer)

     def decode_one_tokens(self):
-        if not hasattr(self, "cuda_graph_runner"):
-            device_map = self.model.gguf_loader.tensor_device_map
-            torch_device = get_device("blk.0.self_attn", device_map)
-            torch_device = "cuda:0" if torch_device == "cuda" else torch_device
-            self.cuda_graph_runner = CUDAGraphRunner()
-            self.cuda_graph_runner.capture(
-                self.model,
-                self.current_ids,
-                self.active_cache_position.unsqueeze(0),
-                self.active_cache_position,
-                self.cache,
-                main_device=torch_device,
-                return_dict=False,
-                use_cache=True,
-            )
+        device_map = self.model.gguf_loader.tensor_device_map
+        torch_device = get_device("blk.0.self_attn", device_map)
+        torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+        if self.args.use_cuda_graph:
+            if not hasattr(self, "cuda_graph_runner"):
+                self.cuda_graph_runner = CUDAGraphRunner()
+                self.cuda_graph_runner.capture(
+                    self.model,
+                    self.current_ids,
+                    self.active_cache_position.unsqueeze(0),
+                    self.active_cache_position,
+                    self.cache,
+                    main_device=torch_device,
+                    return_dict=False,
+                    use_cache=True,
+                )

-        if hasattr(self, "cuda_graph_runner"):
-            logits = self.cuda_graph_runner(
-                self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
-            )
-            self.cache.change_seq_length(1)
-            torch.cuda.synchronize()
-            logits = logits[0, -1, :]
-            return self.logits_to_token(logits)
+        if hasattr(self, "cuda_graph_runner"):
+            logits = self.cuda_graph_runner(
+                self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
+            )
+            self.cache.change_seq_length(1)
+            torch.cuda.synchronize()
+            logits = logits[0, -1, :]
+            return self.logits_to_token(logits)

         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(torch_device)
             logits = self.model(
-                self.current_ids,
+                self.current_ids.to(torch_device),
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
                 attention_mask=mask,

@@ -102,3 +112,63 @@ class KTransformersInterface(TransformersInterface):
             logits = logits[0, -1, :]

         return self.logits_to_token(logits)
+
+    @torch.no_grad
+    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+        input_ids_length = input_ids.shape[-1]
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+
+        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
+        if is_new:
+            self.cache.reset()
+            self.ever_generated_ids.clear()
+            former_seq_length = 0
+            self.seq_length = input_ids_length
+            self.generated_ids = torch.zeros(
+                self.args.batch_size,
+                self.seq_length + self.args.max_new_tokens + 1,
+                dtype=torch.int,
+                device=self.args.device,
+            )
+        else:
+            logger.debug(f"generate_ids: {self.generated_ids.shape}")
+            former_seq_length = self.seq_length
+            self.seq_length += input_ids_length
+            expected_length = self.seq_length + self.args.max_new_tokens + 1
+            delta_length = expected_length - self.generated_ids.shape[-1]
+            if delta_length > 0:
+                new_generate_ids = torch.zeros(
+                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+                )
+                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+        logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
+        cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
+        self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
+
+        mask = torch.ones((1, self.seq_length)).to(device)
+        if not (type(self) is TransformersInterface):
+            input_ids = input_ids.to("cpu")
+        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
+        if self.use_static_cache:
+            logits = self.model(
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+                past_key_values=self.cache,
+                return_dict=False,
+                use_cache=True,
+                attention_mask=mask,
+            )[0]
+        else:
+            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+
+        next_token = self.logits_to_token(logits[0, -1, :])
+        yield self.append_new_tokens(next_token)
+
+    @property
+    def active_cache_position(self):
+        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
+        return torch.tensor([self.seq_length - 1], device=device)
\ No newline at end of file
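A worked example of the buffer-growth arithmetic in the new prefill() method; the numbers are made up for illustration:

former_seq_length = 1000      # tokens already held in generated_ids from earlier turns
input_ids_length = 300        # length of the newly arrived prompt chunk
max_new_tokens = 2000         # see the config.py change below (default raised to 2000)

seq_length = former_seq_length + input_ids_length        # 1300
expected_length = seq_length + max_new_tokens + 1        # 3301
current_buffer = former_seq_length + max_new_tokens + 1  # 3001, allocated on the previous call
delta_length = expected_length - current_buffer          # 300 -> generated_ids grows by 300 columns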
ktransformers/server/backend/interfaces/transformers.py

@@ -134,7 +134,7 @@ class TransformersInterface(BackendInterfaceBase):
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
         self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True)
-        logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}")
+        # logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}")
         self.cache = StaticCache(
             config=self.model.config,

@@ -143,7 +143,7 @@ class TransformersInterface(BackendInterfaceBase):
             device=args.device,
             dtype=self.model.dtype,
         )
-        logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size: {args.batch_size}")
+        # logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}")
         self.streamer = TextStreamer(self.tokenizer)

@@ -198,7 +198,7 @@ class TransformersInterface(BackendInterfaceBase):
         return self.streamer.put(new_tokens)

     def logits_to_token(self, logits: torch.Tensor):
-        logits = logits / self.args.temperature
+        logits = logits / self.args.temperature if self.args.temperature != 0 else logits

         for token_idx in self.ever_generated_ids:
             if logits[token_idx] < 0:

@@ -318,7 +318,9 @@ class TransformersInterface(BackendInterfaceBase):
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
         elif isinstance(local_messages, str):
+            #local_messages = local_messages[0]['content']
             input_ids = self.tokenize_prompt(local_messages)
+            #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")

@@ -327,14 +329,14 @@ class TransformersInterface(BackendInterfaceBase):
         self.profiler.create_and_start_timer("prefill")
         for t in self.prefill(input_ids, self.check_is_new(thread_id)):
             if t is not None:
-                print(t, end="")
+                print(t, end="", flush=True)
                 yield t
         self.profiler.pause_timer("prefill")

         self.profiler.create_and_start_timer("decode")
         for t in self.generate():
             if t is not None:
-                print(t, end="")
+                print(t, end="", flush=True)
                 yield t
         print("")
         self.profiler.pause_timer("decode")
ktransformers/server/config/config.py

@@ -93,6 +93,8 @@ class Config(metaclass=Singleton):
         self.model_name: str = self.model.get("name", "")
         self.model_device: str = self.model.get("device", "cuda:0")
         self.gguf_path: Optional[str] = self.model.get("gguf_path", None)
+        self.use_cuda_graph = self.model.get("use_cuda_graph", True)
+        self.trust_remote_code = self.model.get("trust_remote_code", True)
         # self.model_cache_lens = self.model.get("cache_lens")
         self.optimize_config_path: Optional[str] = self.model.get("optimize_config_path", None

@@ -102,7 +104,7 @@ class Config(metaclass=Singleton):
         self.total_context = self.model.get("total_context", 2**18)
         self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
         self.max_chunk_size = self.model.get("max_chunk_size", 2048)
-        self.max_new_tokens = self.model.get("max_new_tokens", 500)
+        self.max_new_tokens = self.model.get("max_new_tokens", 2000)
         self.json_mode = self.model.get("json_mode", False)
         self.healing = self.model.get("healing", False)
         self.ban_strings: Optional[list] = self.model.get("ban_strings", None)
ktransformers/util/modeling_rope_utils.py (new file, mode 100644)

This diff is collapsed (592 additions, not shown here).
requirements-local_chat.txt

 fire
-transformers
+transformers==4.43.2
 numpy
 torch>=2.3.0
 packaging
setup.py

@@ -278,13 +278,15 @@ class CMakeBuild(BuildExtension):
         if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
             if hasattr(self, "parallel") and self.parallel:
                 build_args += [f"-j{self.parallel}"]
         print("CMake args:", cmake_args)
         build_temp = Path(ext.sourcedir) / "build"
         if not build_temp.exists():
             build_temp.mkdir(parents=True)

-        subprocess.run(
-            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True)
+        result = subprocess.run(
+            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True, capture_output=True)
+        print("Standard output:", result.stdout)
+        print("Standard error:", result.stderr)
         subprocess.run(
             ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
         )
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment