OpenDAS / ktransformers · Commits

Commit 7527619f (Unverified)
Authored Feb 10, 2025 by UnicornChan; committed by GitHub on Feb 10, 2025

Merge pull request #122 from kvcache-ai/feat-DeepSeekV3

[Feat] add support to DeepSeekV3

Parents: f4903d54, 6f0fe953

Showing 12 changed files with 1210 additions and 59 deletions (+1210 -59)
  ktransformers/operators/gate.py                                               +126   -0
  ktransformers/operators/linear.py                                              +12  -12
  ktransformers/operators/models.py                                              +14   -6
  ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml  +143   -0
  ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml         +143   -0
  ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml                    +63   -0
  ktransformers/server/backend/interfaces/ktransformers.py                      +101  -31
  ktransformers/server/backend/interfaces/transformers.py                         +7   -5
  ktransformers/server/config/config.py                                           +3   -1
  ktransformers/util/modeling_rope_utils.py                                     +592   -0
  requirements-local_chat.txt                                                     +1   -1
  setup.py                                                                        +5   -3
ktransformers/operators/gate.py (new file, mode 100644)

from typing import Any, Union
import numpy as np
import numpy.typing as npt
from torch import Tensor, nn
import torch.nn.functional as F
import torch
import sys, os
from ktransformers.operators.base_operator import BaseInjectedModule

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from cpuinfer_ext.moe import MOEConfig, MOE
import ctypes
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from abc import ABC, abstractmethod
import time


# class Base(BaseInjectedModule, ABC):
class KMoEGateBase(ABC):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        device: str = "cuda",
        **kwargs,
    ):
        # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        super().__init__()
        self.key = key
        self.gguf_loader = gguf_loader
        self.config = config
        self.device = device
        self.orig_module = orig_module

    @abstractmethod
    def forward(self, input_tensor, expert_ids, weights):
        pass

    @abstractmethod
    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu", warmup: bool = False):
        pass

    @abstractmethod
    def unload():
        pass

    def load_weights(self, override_key: str | None = None, device: str = "cpu"):
        res = {}
        if override_key is not None:
            keys = override_key
        else:
            keys = [self.key]

        gate = None
        up = None
        down = None
        gate_type = None
        up_type = None
        down_type = None

        for key in keys:
            key = ".".join(key.split(".")[:-1])
            if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
                targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
                tensors = self.load_multi(key, targets, device=device)
                weight = tensors[".ffn_gate_inp.weight"]
                e_score_correction_bias = tensors[".exp_probs_b.bias"]
                weight_type = self.gguf_loader.tensor_info[key + ".ffn_gate_inp.weight"]["ggml_type"]
                e_score_correction_bias_type = self.gguf_loader.tensor_info[key + ".exp_probs_b.bias"]["ggml_type"]
            else:
                raise ValueError(f"Experts {key} not found in gguf_loader")
            res = {
                "weight": weight,
                "e_score_correction_bias": e_score_correction_bias,
                "weight_type": weight_type,
                "e_score_correction_bias_type": e_score_correction_bias_type,
            }
        return res

    def load_multi(self, key: str, keys: list[str], device: str = "cpu"):
        tensors = {}
        for k in keys:
            tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device)
        return tensors


class KMoEGate(BaseInjectedModule, KMoEGateBase):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def forward(self, hidden_states) -> torch.Tensor:
        return self.orig_module.forward(hidden_states)

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        if device is None: device = self.device
        if w is None: w = self.load_weights(device=device)

        if isinstance(w, dict):
            self.weight_type = w["weight_type"]
            self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
            self.orig_module.weight = nn.Parameter(w["weight"])
            self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
        else:
            raise ValueError("Invalid weight type")
        self.orig_module.weight = self.orig_module.weight.to(device)
        self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)

    def unload(self):
        if self.weight is not None:
            self.weight = None
        if self.e_score_correction_bias is not None:
            self.e_score_correction_bias = None
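As a quick aside (not part of the commit), here is a minimal sketch of the key handling in KMoEGateBase.load_weights: the injected module's key has its last dotted component dropped, and the gate tensors are then looked up under that prefix in the GGUF tensor map. The module key below is hypothetical.

import re  # not needed for this snippet; plain string ops only

# Hypothetical module key; real keys come from the GGUF loader's tensor map.
module_key = "blk.3.ffn_gate_inp"

# Same trimming as load_weights(): drop the trailing component to get the prefix.
prefix = ".".join(module_key.split(".")[:-1])            # -> "blk.3"

# GGUF tensor names the gate loader reads under that prefix.
targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
print([prefix + t for t in targets])
# ['blk.3.ffn_gate_inp.weight', 'blk.3.exp_probs_b.bias']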
ktransformers/operators/linear.py

@@ -54,15 +54,15 @@ class KLinearBase(ABC):
         self.has_bias = False
         self.dtype = torch.get_default_dtype()
-        if orig_module is not None:
-            self.in_features = orig_module.in_features
-            self.out_features = orig_module.out_features
-        else:
-            shape = self.gguf_loader.tensor_info[key + ".weight"]["shape"]
-            if len(shape) == 1:
-                print("Warning: orig_module is not set, but has in_features or out_features equals to 1, can't get in_features and out_features from GGUF")
-            self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0]
-            self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1]
+        # if orig_module is not None:
+        #     self.in_features = orig_module.in_features
+        #     self.out_features = orig_module.out_features
+        # else:
+        shape = self.gguf_loader.tensor_info[key + ".weight"]["shape"]
+        if len(shape) == 1:
+            print("Warning: orig_module is not set, but has in_features or out_features equals to 1, can't get in_features and out_features from GGUF")
+        self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0]
+        self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1]

@@ -138,10 +138,10 @@ class KLinearTorch(KLinearBase):
         if w is None: w = self.load_weight(device=device)
         if isinstance(w, nn.Parameter):
-            self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
+            self.w = w.to(dtype=self.dtype).T
             self.has_bias = False
         elif isinstance(w, tuple):
-            self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
+            self.w = w[0].to(dtype=self.dtype).T
             self.bias = w[1].to(dtype=self.dtype)
             self.has_bias = True
         else:

@@ -222,7 +222,7 @@ class KLinearMarlin(KLinearBase):
         x = x.to(self.device)
         orig_shape = list(x.shape)
         orig_dtype = x.dtype
-        x = x.reshape(-1, x.shape[-1])
+        x = x.reshape(-1, orig_shape[-1])
         marlin_s = self.marlin_s.to(x.dtype)
         x = KTransformersOps.gptq_marlin_gemm(
             x,
ktransformers/operators/models.py

@@ -625,6 +625,13 @@ class KDeepseekV2Model(BaseInjectedModule):
         if use_legacy_cache:
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
         past_key_values_length = past_key_values.get_usable_length(seq_length)

+        if inputs_embeds is None:
+            org_device = input_ids.device
+            # TODO move to embed_tokens's device, not hard code to cpu
+            input_ids = input_ids.to("cpu")
+            inputs_embeds = self.embed_tokens(input_ids).to(org_device)
+            input_ids = input_ids.to(org_device)
+
         if cache_position is None:
             past_seen_tokens = (

@@ -639,12 +646,6 @@ class KDeepseekV2Model(BaseInjectedModule):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        if inputs_embeds is None:
-            org_device = input_ids.device
-            input_ids = input_ids.to("cpu")
-            inputs_embeds = self.embed_tokens(input_ids)
-            input_ids = input_ids.to(org_device)
-
         if per_layer_prefill_flag:
             causal_mask = None
         else:

@@ -716,6 +717,8 @@ class KDeepseekV2Model(BaseInjectedModule):
                     self.load_layer_to(decoder_layer, InferenceState.PREFILL)
                     torch.cuda.empty_cache()
                 t4 = time.time()
+            # with open("log.txt", "a") as f:
+            #     f.write(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n")
             layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,

@@ -737,6 +740,7 @@ class KDeepseekV2Model(BaseInjectedModule):
             hidden_states = layer_outputs[0]
+            # @@@@@@@ TODO open this notes, tmp close to fit deepseekv3
             if use_cache:
                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]

@@ -744,6 +748,10 @@ class KDeepseekV2Model(BaseInjectedModule):
                 all_self_attns += (layer_outputs[1],)

         hidden_states = self.norm(hidden_states)
+        # with open("log.txt", "a") as f:
+        #     f.write(f"@@@After layers\n")
+        #     f.write(f"hidden_states={hidden_states}\n")
+        #     f.write(f"hidden_states.shape={hidden_states.shape}\n")
         if per_layer_prefill_flag:
             t6 = time.time()
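A small self-contained sketch (toy sizes, not the real model) of the embedding round trip the first hunk introduces: token IDs are sent to the CPU-resident embed_tokens lookup, and the resulting embeddings come back to the original device.

import torch

embed_tokens = torch.nn.Embedding(10, 4)     # stand-in for the CPU-hosted embedding table
input_ids = torch.tensor([[1, 2, 3]])        # pretend these arrived on the compute device

org_device = input_ids.device
inputs_embeds = embed_tokens(input_ids.to("cpu")).to(org_device)
input_ids = input_ids.to(org_device)
print(inputs_embeds.shape)                   # torch.Size([1, 3, 4])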
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml (new file, mode 100644)

- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda:0"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda:0"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda:1"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda:1"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
      transfer_map:
        30: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml (new file, mode 100644)

- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda:0"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda:0"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda:1"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda:1"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
      transfer_map:
        30: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml (new file, mode 100644)

- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
\ No newline at end of file
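A brief, hedged illustration of the negative lookahead used in the Linear rule above: it keeps self_attn.kv_b_proj on the stock torch.nn.Linear while other Linear modules in the layers are swapped for KTransformersLinear. The module names below are illustrative, not an exhaustive list.

import re

pat = re.compile(r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$")

print(bool(pat.match("model.layers.0.self_attn.kv_b_proj")))  # False -> left untouched
print(bool(pat.match("model.layers.0.self_attn.o_proj")))     # True  -> replaced
print(bool(pat.match("model.layers.5.mlp.shared_experts")))   # True  -> replaced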
ktransformers/server/backend/interfaces/ktransformers.py

@@ -24,8 +24,8 @@ class KTransformersInterface(TransformersInterface):
         self.args = args
         torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
-        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device)
-        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
+        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"

@@ -46,51 +46,61 @@ class KTransformersInterface(TransformersInterface):
         )
         optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
-        device_map = self.model.gguf_loader.tensor_device_map
-        logger.info(f"{args.model_name} loaded from {args.model_dir} to {device_map}")
+        self.device_map = self.model.gguf_loader.tensor_device_map
+        # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
         self.cache = StaticCache(
             config=self.model.config,
             max_batch_size=args.batch_size,
             max_cache_len=args.cache_lens,
-            device=device_map,
+            device=self.device_map,
             dtype=self.model.dtype,
         )
-        logger.info(f"StaticCache (length={args.cache_lens}) created at {device_map}, batch size: {args.batch_size}")
-        self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
+        # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
+        try:
+            self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
+        except:
+            gen_config = GenerationConfig(
+                max_length=128,
+                temperature=0.7,
+                top_p=0.9,
+                do_sample=True
+            )
+            self.model.generation_config = gen_config
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.streamer = TextStreamer(self.tokenizer)

     def decode_one_tokens(self):
-        if not hasattr(self, "cuda_graph_runner"):
-            device_map = self.model.gguf_loader.tensor_device_map
-            torch_device = get_device("blk.0.self_attn", device_map)
-            torch_device = "cuda:0" if torch_device == "cuda" else torch_device
-            self.cuda_graph_runner = CUDAGraphRunner()
-            self.cuda_graph_runner.capture(
-                self.model,
-                self.current_ids,
-                self.active_cache_position.unsqueeze(0),
-                self.active_cache_position,
-                self.cache,
-                main_device=torch_device,
-                return_dict=False,
-                use_cache=True,
-            )
+        device_map = self.model.gguf_loader.tensor_device_map
+        torch_device = get_device("blk.0.self_attn", device_map)
+        torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+        if self.args.use_cuda_graph:
+            if not hasattr(self, "cuda_graph_runner"):
+                self.cuda_graph_runner = CUDAGraphRunner()
+                self.cuda_graph_runner.capture(
+                    self.model,
+                    self.current_ids,
+                    self.active_cache_position.unsqueeze(0),
+                    self.active_cache_position,
+                    self.cache,
+                    main_device=torch_device,
+                    return_dict=False,
+                    use_cache=True,
+                )

         if hasattr(self, "cuda_graph_runner"):
             logits = self.cuda_graph_runner(
                 self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
             )
             self.cache.change_seq_length(1)
             torch.cuda.synchronize()
             logits = logits[0, -1, :]
             return self.logits_to_token(logits)

         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(torch_device)
             logits = self.model(
-                self.current_ids,
+                self.current_ids.to(torch_device),
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
                 attention_mask=mask,

@@ -102,3 +112,63 @@ class KTransformersInterface(TransformersInterface):
         logits = logits[0, -1, :]
         return self.logits_to_token(logits)
+
+    @torch.no_grad
+    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+        input_ids_length = input_ids.shape[-1]
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+
+        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
+        if is_new:
+            self.cache.reset()
+            self.ever_generated_ids.clear()
+            former_seq_length = 0
+            self.seq_length = input_ids_length
+            self.generated_ids = torch.zeros(
+                self.args.batch_size,
+                self.seq_length + self.args.max_new_tokens + 1,
+                dtype=torch.int,
+                device=self.args.device,
+            )
+        else:
+            logger.debug(f"generate_ids: {self.generated_ids.shape}")
+            former_seq_length = self.seq_length
+            self.seq_length += input_ids_length
+            expected_length = self.seq_length + self.args.max_new_tokens + 1
+            delta_length = expected_length - self.generated_ids.shape[-1]
+            if delta_length > 0:
+                new_generate_ids = torch.zeros(
+                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+                )
+                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+        logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
+        cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
+        self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
+
+        mask = torch.ones((1, self.seq_length)).to(device)
+        if not (type(self) is TransformersInterface):
+            input_ids = input_ids.to("cpu")
+        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
+        if self.use_static_cache:
+            logits = self.model(
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+                past_key_values=self.cache,
+                return_dict=False,
+                use_cache=True,
+                attention_mask=mask,
+            )[0]
+        else:
+            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+
+        next_token = self.logits_to_token(logits[0, -1, :])
+        yield self.append_new_tokens(next_token)
+
+    @property
+    def active_cache_position(self):
+        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
+        return torch.tensor([self.seq_length - 1], device=device)
\ No newline at end of file
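The buffer-growth logic in the new prefill() can be hard to follow in diff form, so here is a toy re-run with made-up sizes (not actual server state) showing when generated_ids gets extended on a follow-up request.

import torch

batch_size, max_new_tokens = 1, 4
generated_ids = torch.zeros(batch_size, 8, dtype=torch.int)    # buffer left over from an earlier turn

seq_length, input_ids_length = 6, 5                   # follow-up request adds 5 prompt tokens
former_seq_length = seq_length
seq_length += input_ids_length                        # 11
expected_length = seq_length + max_new_tokens + 1     # 16
delta_length = expected_length - generated_ids.shape[-1]       # 8
if delta_length > 0:
    pad = torch.zeros(batch_size, delta_length, dtype=torch.int)
    generated_ids = torch.cat([generated_ids, pad], dim=-1)
print(generated_ids.shape, former_seq_length, seq_length)      # torch.Size([1, 16]) 6 11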
ktransformers/server/backend/interfaces/transformers.py

@@ -134,7 +134,7 @@ class TransformersInterface(BackendInterfaceBase):
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
         self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True)
-        logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}")
+        # logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}")
         self.cache = StaticCache(
             config=self.model.config,

@@ -143,7 +143,7 @@ class TransformersInterface(BackendInterfaceBase):
             device=args.device,
             dtype=self.model.dtype,
         )
-        logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size: {args.batch_size}")
+        # logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}")
         self.streamer = TextStreamer(self.tokenizer)

@@ -198,7 +198,7 @@ class TransformersInterface(BackendInterfaceBase):
         return self.streamer.put(new_tokens)

     def logits_to_token(self, logits: torch.Tensor):
-        logits = logits / self.args.temperature
+        logits = logits / self.args.temperature if self.args.temperature != 0 else logits

         for token_idx in self.ever_generated_ids:
             if logits[token_idx] < 0:

@@ -318,7 +318,9 @@ class TransformersInterface(BackendInterfaceBase):
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
         elif isinstance(local_messages, str):
+            #local_messages = local_messages[0]['content']
             input_ids = self.tokenize_prompt(local_messages)
+            #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")

@@ -327,14 +329,14 @@ class TransformersInterface(BackendInterfaceBase):
         self.profiler.create_and_start_timer("prefill")
         for t in self.prefill(input_ids, self.check_is_new(thread_id)):
             if t is not None:
-                print(t, end="")
+                print(t, end="", flush=True)
                 yield t
         self.profiler.pause_timer("prefill")

         self.profiler.create_and_start_timer("decode")
         for t in self.generate():
             if t is not None:
-                print(t, end="")
+                print(t, end="", flush=True)
                 yield t
         print("")
         self.profiler.pause_timer("decode")
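A tiny sketch of the sampling guard added in logits_to_token: when the temperature is 0 (greedy decoding), the division is skipped so the logits stay finite instead of becoming inf. Values below are arbitrary.

import torch

logits = torch.tensor([1.5, -0.3, 2.0])
temperature = 0
logits = logits / temperature if temperature != 0 else logits
print(logits)   # tensor([ 1.5000, -0.3000,  2.0000]) -- unchanged, no inf values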
ktransformers/server/config/config.py

@@ -93,6 +93,8 @@ class Config(metaclass=Singleton):
         self.model_name: str = self.model.get("name", "")
         self.model_device: str = self.model.get("device", "cuda:0")
         self.gguf_path: Optional[str] = self.model.get("gguf_path", None)
+        self.use_cuda_graph = self.model.get("use_cuda_graph", True)
+        self.trust_remote_code = self.model.get("trust_remote_code", True)
         # self.model_cache_lens = self.model.get("cache_lens")
         self.optimize_config_path: Optional[str] = self.model.get(
             "optimize_config_path", None

@@ -102,7 +104,7 @@ class Config(metaclass=Singleton):
         self.total_context = self.model.get("total_context", 2**18)
         self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
         self.max_chunk_size = self.model.get("max_chunk_size", 2048)
-        self.max_new_tokens = self.model.get("max_new_tokens", 500)
+        self.max_new_tokens = self.model.get("max_new_tokens", 2000)
         self.json_mode = self.model.get("json_mode", False)
         self.healing = self.model.get("healing", False)
         self.ban_strings: Optional[list] = self.model.get("ban_strings", None)
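For clarity, a minimal sketch (hypothetical config dict) of the two defaults this change touches: trust_remote_code now falls back to True and max_new_tokens to 2000 when the server's model config omits them.

model = {"name": "DeepSeek-V3", "device": "cuda:0"}   # hypothetical config section

trust_remote_code = model.get("trust_remote_code", True)
max_new_tokens = model.get("max_new_tokens", 2000)
print(trust_remote_code, max_new_tokens)              # True 2000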
ktransformers/util/modeling_rope_utils.py (new file, mode 100644)

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import is_torch_available, logging


logger = logging.get_logger(__name__)


if is_torch_available():
    import torch


def _compute_default_rope_parameters(
    config: Optional[PretrainedConfig] = None,
    device: Optional["torch.device"] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies according to the original RoPE implementation
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device to use for initialization of the inverse frequencies.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
        rope_kwargs (`Dict`, *optional*):
            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
    """
    if config is not None and len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
        )
    if len(rope_kwargs) > 0:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
    elif config is not None:
        base = config.rope_theta
        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        dim = int(head_dim * partial_rotary_factor)

    attention_factor = 1.0  # Unused in this type of RoPE

    # Compute the inverse frequencies
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
    return inv_freq, attention_factor


def _compute_linear_scaling_rope_parameters(
    config: Optional[PretrainedConfig] = None,
    device: Optional["torch.device"] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device to use for initialization of the inverse frequencies.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
        rope_kwargs (`Dict`, *optional*):
            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
    """
    if config is not None and len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
        )
    if len(rope_kwargs) > 0:
        factor = rope_kwargs["factor"]
    elif config is not None:
        factor = config.rope_scaling["factor"]

    # Gets the default RoPE parameters
    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)

    # Then applies linear scaling to the frequencies.
    # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
    # applying scaling to the inverse frequencies is equivalent.
    inv_freq /= factor
    return inv_freq, attention_factor


def _compute_dynamic_ntk_parameters(
    config: Optional[PretrainedConfig] = None,
    device: Optional["torch.device"] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device to use for initialization of the inverse frequencies.
        seq_len (`int`, *optional*):
            The current sequence length, used to update the dynamic RoPE at inference time.
        rope_kwargs (`Dict`, *optional*):
            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
    """
    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
    if config is not None and len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
        )
    if len(rope_kwargs) > 0:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
        max_position_embeddings = rope_kwargs["max_position_embeddings"]
        factor = rope_kwargs["factor"]
    elif config is not None:
        base = config.rope_theta
        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        dim = int(head_dim * partial_rotary_factor)
        max_position_embeddings = config.max_position_embeddings
        factor = config.rope_scaling["factor"]

    attention_factor = 1.0  # Unused in this type of RoPE

    # seq_len: default to max_position_embeddings, e.g. at init time
    seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings

    # Compute the inverse frequencies
    base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
    return inv_freq, attention_factor


def _compute_yarn_parameters(
    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies with NTK scaling. Please refer to the
    [original paper](https://arxiv.org/abs/2309.00071)
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device to use for initialization of the inverse frequencies.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
        rope_kwargs (`Dict`, *optional*):
            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin.
    """
    # No need to keep BC with yarn, unreleased when this new pattern was created.
    if len(rope_kwargs) > 0:
        raise ValueError(
            f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
        )

    base = config.rope_theta
    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
    head_dim = getattr(config, "qk_rope_head_dim", config.hidden_size // config.num_attention_heads)
    dim = int(head_dim * partial_rotary_factor)
    factor = config.rope_scaling["factor"]
    attention_factor = config.rope_scaling.get("attention_factor")
    mscale = config.rope_scaling.get("mscale")
    mscale_all_dim = config.rope_scaling.get("mscale_all_dim")

    # NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
    # values to compute the default attention scaling factor, instead of using `factor`.
    if "original_max_position_embeddings" in config.rope_scaling:
        original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
        factor = config.max_position_embeddings / original_max_position_embeddings
    else:
        original_max_position_embeddings = config.max_position_embeddings

    def get_mscale(scale, mscale=1):
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    # Sets the attention factor as suggested in the paper
    if attention_factor is None:
        if mscale and mscale_all_dim:
            attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
        else:
            attention_factor = get_mscale(factor)

    # Optional config options
    # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
    beta_fast = config.rope_scaling.get("beta_fast") or 32
    beta_slow = config.rope_scaling.get("beta_slow") or 1

    # Compute the inverse frequencies
    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
        """Inverse dimension formula to find the dimension based on the number of rotations"""
        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
        """Find dimension range bounds based on rotations"""
        low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(min, max, dim):
        if min == max:
            max += 0.001  # Prevent singularity

        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
    # to expand the possible context length. In other words, interpolation = apply scaling factor.
    pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)

    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)

    # Get n-dimensional rotational scaling corrected for extrapolation
    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
    inv_freq = (
        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
        + inv_freq_extrapolation * inv_freq_extrapolation_factor
    )

    return inv_freq, attention_factor
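Aside (not part of modeling_rope_utils.py): a small numeric sketch of how the yarn attention factor falls out of the get_mscale helper above, using a made-up context-extension factor.

import math

def get_mscale(scale, mscale=1):
    # Same helper as inside _compute_yarn_parameters.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

factor = 8.0                                                # hypothetical extension factor
print(get_mscale(factor))                                   # ~1.208, used when mscale keys are absent
print(get_mscale(factor, 1.0) / get_mscale(factor, 1.0))    # 1.0 when mscale == mscale_all_dim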
def
_compute_longrope_parameters
(
config
:
PretrainedConfig
,
device
:
"torch.device"
,
seq_len
:
Optional
[
int
]
=
None
,
**
rope_kwargs
)
->
Tuple
[
"torch.Tensor"
,
float
]:
"""
Computes the inverse frequencies with LongRoPE scaling. Please refer to the
[original implementation](https://github.com/microsoft/LongRoPE)
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin.
"""
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
# No need to keep BC with longrope, unreleased when this new pattern was created.
if
len
(
rope_kwargs
)
>
0
:
raise
ValueError
(
"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
f
"
{
rope_kwargs
}
"
)
base
=
config
.
rope_theta
partial_rotary_factor
=
config
.
partial_rotary_factor
if
hasattr
(
config
,
"partial_rotary_factor"
)
else
1.0
head_dim
=
getattr
(
config
,
"head_dim"
,
config
.
hidden_size
//
config
.
num_attention_heads
)
dim
=
int
(
head_dim
*
partial_rotary_factor
)
long_factor
=
config
.
rope_scaling
[
"long_factor"
]
short_factor
=
config
.
rope_scaling
[
"short_factor"
]
factor
=
config
.
rope_scaling
.
get
(
"factor"
)
attention_factor
=
config
.
rope_scaling
.
get
(
"attention_factor"
)
# NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
# values to compute the default attention scaling factor, instead of using `factor`.
if
hasattr
(
config
,
"original_max_position_embeddings"
):
original_max_position_embeddings
=
config
.
original_max_position_embeddings
factor
=
config
.
max_position_embeddings
/
config
.
original_max_position_embeddings
else
:
original_max_position_embeddings
=
config
.
max_position_embeddings
# Sets the attention factor as suggested in the paper
if
attention_factor
is
None
:
if
factor
<=
1.0
:
attention_factor
=
1.0
else
:
attention_factor
=
math
.
sqrt
(
1
+
math
.
log
(
factor
)
/
math
.
log
(
original_max_position_embeddings
))
# Compute the inverse frequencies -- scaled based on the target sequence length
if
seq_len
and
seq_len
>
original_max_position_embeddings
:
ext_factors
=
torch
.
tensor
(
long_factor
,
dtype
=
torch
.
float32
,
device
=
device
)
else
:
ext_factors
=
torch
.
tensor
(
short_factor
,
dtype
=
torch
.
float32
,
device
=
device
)
inv_freq_shape
=
torch
.
arange
(
0
,
dim
,
2
,
dtype
=
torch
.
int64
,
device
=
device
).
float
()
/
dim
inv_freq
=
1.0
/
(
ext_factors
*
base
**
inv_freq_shape
)
return
inv_freq
,
attention_factor
def
_compute_llama3_parameters
(
config
:
PretrainedConfig
,
device
:
"torch.device"
,
seq_len
:
Optional
[
int
]
=
None
,
**
rope_kwargs
)
->
Tuple
[
"torch.Tensor"
,
float
]:
"""
Computes the inverse frequencies for llama 3.1.
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin.
"""
# Gets the default RoPE parameters
inv_freq
,
attention_factor
=
_compute_default_rope_parameters
(
config
,
device
,
seq_len
,
**
rope_kwargs
)
factor
=
config
.
rope_scaling
[
"factor"
]
# `8` in the original implementation
low_freq_factor
=
config
.
rope_scaling
[
"low_freq_factor"
]
# `1` in the original implementation
high_freq_factor
=
config
.
rope_scaling
[
"high_freq_factor"
]
# `4` in the original implementation
old_context_len
=
config
.
rope_scaling
[
"original_max_position_embeddings"
]
# `8192` in the original implementation
low_freq_wavelen
=
old_context_len
/
low_freq_factor
high_freq_wavelen
=
old_context_len
/
high_freq_factor
wavelen
=
2
*
math
.
pi
/
inv_freq
# wavelen < high_freq_wavelen: do nothing
# wavelen > low_freq_wavelen: divide by factor
inv_freq_llama
=
torch
.
where
(
wavelen
>
low_freq_wavelen
,
inv_freq
/
factor
,
inv_freq
)
# otherwise: interpolate between the two, using a smooth factor
smooth_factor
=
(
old_context_len
/
wavelen
-
low_freq_factor
)
/
(
high_freq_factor
-
low_freq_factor
)
smoothed_inv_freq
=
(
1
-
smooth_factor
)
*
inv_freq_llama
/
factor
+
smooth_factor
*
inv_freq_llama
is_medium_freq
=
~
(
wavelen
<
high_freq_wavelen
)
*
~
(
wavelen
>
low_freq_wavelen
)
inv_freq_llama
=
torch
.
where
(
is_medium_freq
,
smoothed_inv_freq
,
inv_freq_llama
)
return
inv_freq_llama
,
attention_factor
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
# parameterizations, as long as the callable has the same signature.
ROPE_INIT_FUNCTIONS
=
{
"default"
:
_compute_default_rope_parameters
,
"linear"
:
_compute_linear_scaling_rope_parameters
,
"dynamic"
:
_compute_dynamic_ntk_parameters
,
"yarn"
:
_compute_yarn_parameters
,
"longrope"
:
_compute_longrope_parameters
,
"llama3"
:
_compute_llama3_parameters
,
}
def
_check_received_keys
(
rope_type
:
str
,
received_keys
:
set
,
required_keys
:
set
,
optional_keys
:
Optional
[
set
]
=
None
,
ignore_keys
:
Optional
[
set
]
=
None
,
):
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
# BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
if
"type"
in
received_keys
:
received_keys
-=
{
"type"
}
required_keys
.
add
(
"rope_type"
)
# Some models need to store model-specific keys, and we don't want to throw warning at them
if
ignore_keys
is
not
None
:
received_keys
-=
ignore_keys
missing_keys
=
required_keys
-
received_keys
if
missing_keys
:
raise
KeyError
(
f
"Missing required keys in `rope_scaling` for 'rope_type'='
{
rope_type
}
':
{
missing_keys
}
"
)
if
optional_keys
is
not
None
:
unused_keys
=
received_keys
-
required_keys
-
optional_keys
else
:
unused_keys
=
received_keys
-
required_keys
if
unused_keys
:
logger
.
warning
(
f
"Unrecognized keys in `rope_scaling` for 'rope_type'='
{
rope_type
}
':
{
unused_keys
}
"
)
def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "factor"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)

    factor = rope_scaling["factor"]
    if factor is None or not isinstance(factor, float) or factor < 1.0:
        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "factor"}
    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
    optional_keys = {"original_max_position_embeddings"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

    factor = rope_scaling["factor"]
    if factor is None or not isinstance(factor, float) or factor < 1.0:
        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "factor"}
    optional_keys = {
        "attention_factor",
        "beta_fast",
        "beta_slow",
        "original_max_position_embeddings",
        "mscale",
        "mscale_all_dim",
    }
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

    factor = rope_scaling["factor"]
    if factor is None or not isinstance(factor, float) or factor < 1.0:
        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

    attention_factor = rope_scaling.get("attention_factor")
    if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
        logger.warning(
            f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
        )
    beta_fast = rope_scaling.get("beta_fast")
    if beta_fast is not None and not isinstance(beta_fast, float):
        logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
    beta_slow = rope_scaling.get("beta_slow")
    if beta_slow is not None and not isinstance(beta_slow, float):
        logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")

    if (beta_fast or 32) < (beta_slow or 1):
        logger.warning(
            f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
            f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
        )
def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "short_factor", "long_factor"}
    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
    optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    dim = int(head_dim * partial_rotary_factor)

    short_factor = rope_scaling.get("short_factor")
    if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
        logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
    if not len(short_factor) == dim // 2:
        logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")

    long_factor = rope_scaling.get("long_factor")
    if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor):
        logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
    if not len(long_factor) == dim // 2:
        logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")

    # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
    # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
    # unique to longrope (= undesirable)
    if hasattr(config, "original_max_position_embeddings"):
        logger.warning_once(
            "This model has set a `original_max_position_embeddings` field, to be used together with "
            "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`"
            "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
            "as it is compatible with most model architectures."
        )
    else:
        factor = rope_scaling.get("factor")
        if factor is None:
            logger.warning("Missing required keys in `rope_scaling`: 'factor'")
        elif not isinstance(factor, float) or factor < 1.0:
            logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

        attention_factor = rope_scaling.get("attention_factor")
        if attention_factor is not None:
            if not isinstance(attention_factor, float) or attention_factor < 0.0:
                logger.warning(
                    f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
                )
def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)

    factor = rope_scaling["factor"]
    if factor is None or not isinstance(factor, float) or factor < 1.0:
        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

    low_freq_factor = rope_scaling["low_freq_factor"]
    high_freq_factor = rope_scaling["high_freq_factor"]
    if low_freq_factor is None or not isinstance(low_freq_factor, float):
        logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
    if high_freq_factor is None or not isinstance(high_freq_factor, float):
        logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
    if high_freq_factor <= low_freq_factor:
        logger.warning(
            "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
            f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
        )

    original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
    if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
        logger.warning(
            "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
            f"{original_max_position_embeddings}"
        )
    if original_max_position_embeddings >= config.max_position_embeddings:
        logger.warning(
            "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
            f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
        )
# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
ROPE_VALIDATION_FUNCTIONS = {
    "default": _validate_default_rope_parameters,
    "linear": _validate_linear_scaling_rope_parameters,
    "dynamic": _validate_dynamic_scaling_rope_parameters,
    "yarn": _validate_yarn_parameters,
    "longrope": _validate_longrope_parameters,
    "llama3": _validate_llama3_parameters,
}
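As with the init mapping, a custom RoPE type also needs an entry here. A sketch continuing the hypothetical "my_custom" example from earlier (not part of this file; the validator body is illustrative only):

def _validate_my_custom_parameters(config, ignore_keys=None):
    # Hypothetical counterpart to the "my_custom" init function sketched above; a real
    # validator would check every rope_scaling field that init function reads.
    required_keys = {"rope_type"}
    received_keys = set(config.rope_scaling.keys())
    _check_received_keys("my_custom", received_keys, required_keys, ignore_keys=ignore_keys)

ROPE_VALIDATION_FUNCTIONS["my_custom"] = _validate_my_custom_parameters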
def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    """
    Validate the RoPE config arguments, given a `PretrainedConfig` object
    """
    rope_scaling = getattr(config, "rope_scaling", None)  # not a default parameter in `PretrainedConfig`
    if rope_scaling is None:
        return

    # BC: "rope_type" was originally "type"
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
    validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
    if validation_fn is not None:
        validation_fn(config, ignore_keys=ignore_keys)
    else:
        logger.warning(
            f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
        )
\ No newline at end of file
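For orientation, a small usage sketch of `rope_config_validation` (not part of this file; the config values are illustrative, mirroring the `8` / `1` / `4` / `8192` defaults noted in `_compute_llama3_parameters`, and the functions defined above are assumed to be in scope). A config whose `rope_scaling` declares `rope_type: "llama3"` is dispatched to `_validate_llama3_parameters`, which logs warnings for malformed fields rather than raising:

from transformers import PretrainedConfig

# Illustrative values only; PretrainedConfig stores unknown kwargs as attributes,
# so rope_scaling and max_position_embeddings become plain config attributes.
config = PretrainedConfig(
    max_position_embeddings=131072,
    rope_scaling={
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
)

# Looks up "llama3" in ROPE_VALIDATION_FUNCTIONS; would log a warning if, say,
# factor were not a float >= 1 or original_max_position_embeddings were too large.
rope_config_validation(config)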
requirements-local_chat.txt
View file @ 7527619f
 fire
-transformers
+transformers==4.43.2
 numpy
 torch>=2.3.0
 packaging
...
setup.py
View file @ 7527619f
...
@@ -278,13 +278,15 @@ class CMakeBuild(BuildExtension):
         if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
             if hasattr(self, "parallel") and self.parallel:
                 build_args += [f"-j{self.parallel}"]
+        print("CMake args:", cmake_args)
         build_temp = Path(ext.sourcedir) / "build"
         if not build_temp.exists():
             build_temp.mkdir(parents=True)
-        subprocess.run(
-            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
+        result = subprocess.run(
+            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True, capture_output=True
         )
+        print("Standard output:", result.stdout)
+        print("Standard error:", result.stderr)
         subprocess.run(
             ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
         )
...