ox696c / ktransformers / Commits / 8be56a01

Unverified commit 8be56a01, authored Mar 19, 2025 by Atream, committed by GitHub on Mar 19, 2025

Merge pull request #927 from kvcache-ai/fix-gate-precision

Update gate.py

Parents: 6ca233cc, b453333f
Showing 1 changed file with 11 additions and 9 deletions.

ktransformers/operators/gate.py  +11  -9
ktransformers/operators/gate.py (view file @ 8be56a01)
@@ -132,6 +132,7 @@ def grouped_topk(hidden_states: torch.Tensor,
                  renormalize: bool,
                  num_expert_group: int = 0,
                  topk_group: int = 0,
+                 routed_scaling_factor: float = 1.0,
                  scoring_func: str = "sigmoid",
                  e_score_correction_bias: Optional[torch.Tensor] = None):
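The new routed_scaling_factor parameter defaults to 1.0, so callers that do not pass it keep the previous behaviour; the factor itself is applied to the returned expert weights in the third hunk below (a short numeric sketch follows that hunk).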
@@ -163,8 +164,8 @@ def grouped_topk(hidden_states: torch.Tensor,
     score_mask = group_mask.unsqueeze(-1).expand(
         num_token, num_expert_group,
         scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(),
-                                    float("-inf"))  # [n, e]
+                                    0.0)  # float("-inf"))  # [n, e]
     if e_score_correction_bias is not None:
         topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
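One plausible reading of this mask change (the source branch is named fix-gate-precision): with the sigmoid scoring used here, every unmasked score lies strictly between 0 and 1, so filling the experts of dropped groups with 0.0 removes them from the top-k selection exactly as float("-inf") did, while keeping tmp_scores free of infinities that can leak into later low-precision arithmetic. A minimal standalone sketch, with made-up scores and group mask (none of these tensors come from gate.py):

import torch

# One token, 8 experts in 2 groups of 4; sigmoid-style scores in (0, 1).
scores = torch.tensor([[0.91, 0.10, 0.55, 0.20, 0.85, 0.30, 0.65, 0.05]])
# Pretend group selection kept only the first group of experts.
score_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0]])

masked_zero = scores.masked_fill(~score_mask.bool(), 0.0)
masked_ninf = scores.masked_fill(~score_mask.bool(), float("-inf"))

# Both fill values pick the same experts, because unmasked sigmoid scores are > 0.
print(torch.topk(masked_zero, k=2, dim=-1)[1])  # tensor([[0, 2]])
print(torch.topk(masked_ninf, k=2, dim=-1)[1])  # tensor([[0, 2]])

# But -inf stays in the masked tensor and poisons later arithmetic
# (e.g. 0 * -inf is nan), whereas 0.0 keeps everything finite.
print(torch.zeros_like(masked_ninf) * masked_ninf)  # contains nan
print(torch.zeros_like(masked_zero) * masked_zero)  # all zeros

In the patched function this is the tensor that feeds torch.topk directly on the e_score_correction_bias path shown above.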
@@ -176,9 +177,10 @@ def grouped_topk(hidden_states: torch.Tensor,
                                     dim=-1,
                                     sorted=False)
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    if topk > 1 and renormalize:
+        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+        topk_weights = topk_weights / denominator
+    topk_weights = topk_weights * routed_scaling_factor # must multiply the scaling factor
     return topk_ids.to(torch.long), topk_weights.to(torch.float32)

 class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
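Two behavioural changes land here. First, renormalization is now skipped for top-1 routing, since dividing a single weight by its own sum would always yield 1.0 and discard the gate magnitude, and the 1e-20 epsilon guards against a zero row sum. Second, the routed scaling factor is applied unconditionally, which follows the same pattern as DeepSeek-V3's reference MoE gate. A minimal numeric sketch of the new tail of grouped_topk, with made-up weights (finalize_weights is a hypothetical helper, not part of gate.py):

import torch

def finalize_weights(topk_weights: torch.Tensor, topk: int,
                     renormalize: bool, routed_scaling_factor: float) -> torch.Tensor:
    # Same post-processing as the patched tail of grouped_topk.
    if topk > 1 and renormalize:
        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
        topk_weights = topk_weights / denominator
    # The scaling factor is multiplied whether or not we renormalized.
    return topk_weights * routed_scaling_factor

print(finalize_weights(torch.tensor([[0.80, 0.20]]), topk=2,
                       renormalize=True, routed_scaling_factor=2.5))
# tensor([[2.0000, 0.5000]]): normalized to sum 1, then scaled by 2.5.

print(finalize_weights(torch.tensor([[0.80]]), topk=1,
                       renormalize=True, routed_scaling_factor=2.5))
# tensor([[2.0000]]): top-1 keeps its 0.8 magnitude and is then scaled;
# the old unconditional renormalization would have produced 2.5 here.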
@@ -204,6 +206,7 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
         self.is_windows = os.name == 'nt'
         self.use_quant = use_quant
         if not self.is_windows and use_quant:
+            print("injecting gate_linear")
             self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
             self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
                                                    gguf_loader, config, self.gate_linear, #orig_module
@@ -219,14 +222,13 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
         ### compute gating score
         hidden_states = hidden_states.view(-1, h)
         if self.use_quant:
-            logits = self.gate_linear.forward(logits)
+            logits = self.gate_linear.forward(hidden_states)
         else:
             logits = F.linear(
                 hidden_states.type(torch.float32), self.weight.type(torch.float32), None
             )
-        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob, self.n_group,
-                            self.topk_group, "sigmoid", self.e_score_correction_bias)
+        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob, self.n_group,
+                            self.topk_group, self.routed_scaling_factor, "sigmoid", self.e_score_correction_bias)

     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
         if device is None: device = self.device
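Two fixes land in this last hunk. In the quantized path, gate_linear.forward was being called on logits, a name that is only assigned by that very statement, so the quantized gate could not project the token representations as intended; it now projects hidden_states, matching the unquantized F.linear branch. The grouped_topk call also gains self.routed_scaling_factor between self.topk_group and "sigmoid", lining up with the parameter added in the first hunk. A small shape sketch of what both branches are expected to produce (the sizes here are illustrative, not taken from any model config):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_states = torch.randn(4, 16)   # (tokens, gating_dim), illustrative sizes
weight = torch.randn(8, 16)          # (n_routed_experts, gating_dim)

# Unquantized branch from the diff: project hidden states to per-expert logits.
logits = F.linear(hidden_states.type(torch.float32), weight.type(torch.float32), None)
print(logits.shape)  # torch.Size([4, 8]): one logit per routed expert

# The quantized branch is expected to produce the same (tokens, n_routed_experts)
# logits from hidden_states; feeding it `logits` instead could never do that.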