Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wqshmzh
ktransformers
Commits
f748cd29
You need to sign in or sign up before continuing.
Commit
f748cd29
authored
Feb 01, 2025
by
Azure
Browse files
fix rope; update moegate
parent
f873558a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
54 additions
and
21 deletions
+54
-21
ktransformers/models/modeling_deepseek_v3.py
ktransformers/models/modeling_deepseek_v3.py
+20
-15
ktransformers/operators/RoPE.py
ktransformers/operators/RoPE.py
+28
-0
ktransformers/operators/linear.py
ktransformers/operators/linear.py
+1
-1
ktransformers/operators/models.py
ktransformers/operators/models.py
+1
-1
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml
...s/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml
+4
-4
No files found.
ktransformers/models/modeling_deepseek_v3.py
View file @
f748cd29
...
...
@@ -142,37 +142,42 @@ class DeepseekV3TopkRouter(nn.Module):
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
self
.
n_group
=
config
.
n_group
self
.
topk_group
=
config
.
topk_group
self
.
norm_topk_prob
=
config
.
norm_topk_prob
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
((
self
.
n_routed_experts
,
config
.
hidden_size
)))
self
.
e_score_correction_bias
=
nn
.
Parameter
(
torch
.
empty
((
self
.
n_routed_experts
)))
def forward(self, hidden_states):
    """Compute top-k expert routing for a batch of hidden states.

    Args:
        hidden_states: activations of shape (batch, seq, hidden_size)
            — assumed from the `.shape[:-1]` unpack; confirm against caller.

    Returns:
        Tuple of (topk_indices, topk_weights, router_logits):
        expert indices per token, their (optionally renormalized and
        scaled) gate weights, and the raw router logits.
    """
    batch_size, seq_length = hidden_states.shape[:-1]
    hidden_states = hidden_states.view(-1, self.config.hidden_size)
    # Gate is evaluated in float32 for numerical stability of the sigmoid.
    router_logits = F.linear(
        hidden_states.type(torch.float32), self.weight.type(torch.float32)
    )
    scores = router_logits.sigmoid()
    topk_indices = self.get_topk_indices(scores)
    # Weights come from the *unbiased* scores, even though selection
    # (in get_topk_indices) adds e_score_correction_bias.
    topk_weights = scores.gather(1, topk_indices)
    if self.norm_topk_prob:
        # Renormalize the selected weights; epsilon guards divide-by-zero.
        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
        topk_weights /= denominator
    # Must multiply by the routed scaling factor (DeepSeek-V3 gating).
    topk_weights = topk_weights * self.routed_scaling_factor
    return topk_indices, topk_weights, router_logits
@torch.no_grad()
def get_topk_indices(self, scores):
    """Select top_k experts per token using grouped (node-limited) routing.

    Source span contained two interleaved diff copies of this method; this
    is the post-commit version (group-mask built with `.expand(-1, ...)`,
    returning only the indices).

    Args:
        scores: sigmoid router scores, flattened to (n_tokens, n_routed_experts).

    Returns:
        topk_indices: (n_tokens, top_k) indices of the chosen experts.
    """
    # Bias scores for *selection only*; forward() gathers weights from
    # the unbiased scores.
    scores_for_choice = scores.view(
        -1, self.n_routed_experts
    ) + self.e_score_correction_bias.unsqueeze(0)
    # Score each expert group by the sum of its two best experts.
    group_scores = (
        scores_for_choice.view(
            -1, self.n_group, self.n_routed_experts // self.n_group
        )
        .topk(2, dim=-1)[0]
        .sum(dim=-1)
    )  # [n, n_group]
    # Keep only the topk_group best groups per token.
    group_idx = torch.topk(
        group_scores, k=self.topk_group, dim=-1, sorted=False
    )[1]  # [n, topk_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    # Broadcast the group mask back to per-expert granularity.
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
        .reshape(-1, self.n_routed_experts)
    )  # [n, e]
    # Zero out experts belonging to masked-out groups, then take top_k
    # among the survivors.
    scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_indices = torch.topk(
        scores_for_choice, k=self.top_k, dim=-1, sorted=False
    )[1]
    return topk_indices
class
DeepseekV3MoE
(
nn
.
Module
):
...
...
ktransformers/operators/RoPE.py
View file @
f748cd29
...
...
@@ -12,6 +12,9 @@ from ktransformers.models.modeling_llama import (
LlamaLinearScalingRotaryEmbedding
,
LlamaDynamicNTKScalingRotaryEmbedding
,
)
from
ktransformers.models.modeling_deepseek_v3
import
(
DeepseekV3RotaryEmbedding
)
from
ktransformers.models.modeling_deepseek
import
(
DeepseekV2YarnRotaryEmbedding
,
DeepseekV2RotaryEmbedding
,
...
...
@@ -134,6 +137,31 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
self
.
orig_module
.
mscale_all_dim
,
)
class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding):
    """Injected YaRN rotary embedding for DeepSeek-V3.

    Wraps the original DeepseekV3RotaryEmbedding so it can be re-initialized
    on the device selected by the optimize rules (generate/prefill devices
    are injected via YAML kwargs).
    """

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        # device: str = "cuda",
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
        )
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def load(self):
        # TODO support perlayer prefill
        # Re-run the wrapped module's __init__ so its buffers are created
        # on the generate device.
        self.orig_module.__init__(self.config, device=self.generate_device)
        return
class
DynamicNTKScalingRotaryEmbedding
(
BaseInjectedModule
,
LlamaDynamicNTKScalingRotaryEmbedding
...
...
ktransformers/operators/linear.py
View file @
f748cd29
...
...
@@ -222,7 +222,7 @@ class KLinearMarlin(KLinearBase):
x
=
x
.
to
(
self
.
device
)
orig_shape
=
list
(
x
.
shape
)
orig_dtype
=
x
.
dtype
x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
x
=
x
.
reshape
(
-
1
,
orig_
shape
[
-
1
])
marlin_s
=
self
.
marlin_s
.
to
(
x
.
dtype
)
x
=
KTransformersOps
.
gptq_marlin_gemm
(
x
,
...
...
ktransformers/operators/models.py
View file @
f748cd29
...
...
@@ -643,7 +643,7 @@ class KDeepseekV2Model(BaseInjectedModule):
org_device
=
input_ids
.
device
# TODO move to embed_tokens's device, not hard code to cpu
input_ids
=
input_ids
.
to
(
"cpu"
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
.
to
(
org_device
)
input_ids
=
input_ids
.
to
(
org_device
)
if
per_layer_prefill_flag
:
...
...
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml
View file @
f748cd29
...
...
@@ -8,17 +8,17 @@
-
match
:
name
:
"
^model
\\
.layers
\\
.(0|[1-9]|[12][0-9])
\\
."
class
:
ktransformers.models.modeling_deepseek.DeepseekV
2Yarn
RotaryEmbedding
class
:
ktransformers.models.modeling_deepseek
_v3
.DeepseekV
3
RotaryEmbedding
replace
:
class
:
ktransformers.operators.RoPE.YarnRotaryEmbedding
class
:
ktransformers.operators.RoPE.
DeepSeekV3
YarnRotaryEmbedding
kwargs
:
generate_device
:
"
cuda:0"
prefill_device
:
"
cuda:0"
-
match
:
name
:
"
^model
\\
.layers
\\
.([3456][0-9])
\\
."
class
:
ktransformers.models.modeling_deepseek.DeepseekV
2Yarn
RotaryEmbedding
class
:
ktransformers.models.modeling_deepseek
_v3
.DeepseekV
3
RotaryEmbedding
replace
:
class
:
ktransformers.operators.RoPE.YarnRotaryEmbedding
class
:
ktransformers.operators.RoPE.
DeepSeekV3
YarnRotaryEmbedding
kwargs
:
generate_device
:
"
cuda:1"
prefill_device
:
"
cuda:1"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment