OpenDAS / text-generation-inference · Commits · 3f14cd14

Unverified commit 3f14cd14, authored Sep 24, 2024 by Daniël de Kok, committed by GitHub on Sep 24, 2024. Parent: c29dc89c.

Add `DenseMoELayer` and wire it up in Mixtral/Deepseek V2 (#2537)

This replaces the custom layers in both models.

Showing 3 changed files with 211 additions and 220 deletions.
+164  -1    server/text_generation_server/layers/moe/__init__.py
+22   -105  server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
+25   -114  server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
server/text_generation_server/layers/moe/__init__.py  (view file @ 3f14cd14)

-from typing import Optional
+from typing import Optional, Protocol, runtime_checkable

 import torch
 import torch.nn as nn
+from loguru import logger
+from transformers.activations import ACT2FN

+from text_generation_server.layers import (
+    TensorParallelColumnLinear,
+    TensorParallelRowLinear,
+)
 from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
 from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
     UnquantizedWeight,
     Weights,
 )

+if SYSTEM != "ipex":
+    from moe_kernels.fused_moe import fused_topk, grouped_topk
+
+# NOTE: we are using a protocol here, because multiple inheritance is not nice.
+# We need `Module`, and `Module` -> some abstract class -> some concrete
+# class inheritance is whacky.
+
+
+@runtime_checkable
+class MoELayer(Protocol):
+    def __init__(
+        self,
+        *,
+        n_expert_group: Optional[int],
+        n_experts: int,
+        prefix: str,
+        renormalize: bool,
+        topk: int,
+        topk_group: Optional[int],
+        weights: Weights,
+        gate_proj_name: str = "gate_proj",
+        up_proj_name: str = "up_proj",
+        down_proj_name: str = "down_proj",
+        hidden_act: str = "silu",
+    ): ...
+
+    def forward(
+        self, x: torch.Tensor, *, gating_output: torch.Tensor
+    ) -> torch.Tensor: ...
+
+
+class DenseMoELayer(nn.Module):
+    """
+    Layer for MoE that applies *all* experts to each token and then weights
+    their outputs based on the calculated routing. This layer is much slower
+    than `SparseMoELayer` and should only be used when no fused kernels are
+    available (e.g. for unsupported quantizers).
+    """
+
+    def __init__(
+        self,
+        *,
+        n_expert_group: Optional[int],
+        n_experts: int,
+        prefix: str,
+        renormalize: bool,
+        topk: int,
+        topk_group: Optional[int],
+        weights: Weights,
+        gate_proj_name: str = "gate_proj",
+        up_proj_name: str = "up_proj",
+        down_proj_name: str = "down_proj",
+        hidden_act: str = "silu",
+    ):
+        super().__init__()
+
+        log_once(
+            logger.info,
+            "No fused layers are available for this model type, using (slower) dense MoE layer",
+        )
+
+        assert (n_expert_group is None) == (
+            topk_group is None
+        ), "n_expert_group and topk_group must both be None or have some value"
+
+        self.n_expert_group = n_expert_group
+        self.n_experts = n_experts
+        self.renormalize = renormalize
+        self.topk = topk
+        self.topk_group = topk_group
+
+        if "gelu" in hidden_act:
+            self.act = lambda x: torch.nn.functional.gelu(
+                x,
+                approximate=(
+                    "tanh" if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
+            )
+        elif "silu" in hidden_act:
+            self.act = torch.nn.functional.silu
+        else:
+            self.act = ACT2FN[hidden_act]
+
+        self.gate_proj = [
+            TensorParallelColumnLinear.load(
+                None,
+                prefix=f"{prefix}.{i}.{gate_proj_name}",
+                weights=weights,
+                bias=False,
+            )
+            for i in range(self.n_experts)
+        ]
+        self.up_proj = [
+            TensorParallelColumnLinear.load(
+                None,
+                prefix=f"{prefix}.{i}.{up_proj_name}",
+                weights=weights,
+                bias=False,
+            )
+            for i in range(self.n_experts)
+        ]
+        self.down_proj = [
+            TensorParallelRowLinear.load(
+                None,
+                prefix=f"{prefix}.{i}.{down_proj_name}",
+                weights=weights,
+                bias=False,
+            )
+            for i in range(self.n_experts)
+        ]
+
+        self.process_group = weights.process_group
+
+    def forward(
+        self, x: torch.Tensor, *, gating_output: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        x: (sequence_length, model_dim)
+        gating_output: (sequence_length, n_experts)
+        """
+        # optional reshape
+        input_shape = x.shape
+        x = x.view(-1, input_shape[-1])
+
+        if self.n_expert_group is not None and self.topk_group is not None:
+            topk_weights, topk_ids = grouped_topk(
+                x,
+                gating_output,
+                self.topk,
+                renormalize=self.renormalize,
+                num_expert_group=self.n_expert_group,
+                topk_group=self.topk_group,
+            )
+        else:
+            topk_weights, topk_ids = fused_topk(
+                x, gating_output, self.topk, self.renormalize
+            )
+        topk_weights = topk_weights.to(x.dtype)
+
+        weights = torch.zeros(
+            topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device
+        )
+        weights.scatter_(1, topk_ids.long(), topk_weights.to(weights.dtype))
+
+        out = torch.zeros_like(x)
+        for i in range(self.n_experts):
+            h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)
+            h = self.down_proj[i](h, reduce=False)
+            out += h * weights[:, i].view(-1, 1)
+
+        return out

 class SparseMoELayer(nn.Module):
     """
 ...
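Since `MoELayer` is declared as a `runtime_checkable` `Protocol`, the model files can assert at construction time that whichever class they were handed (sparse or dense) exposes the expected interface, without forcing both implementations into a common base class. A self-contained toy sketch of that pattern; the protocol below is a simplified stand-in with only `forward`, and `MoELayerProtocol`/`ToyDenseMoE` are hypothetical names, not part of TGI:

from typing import Protocol, runtime_checkable

import torch
import torch.nn as nn

torch.set_grad_enabled(False)  # inference-only sketch, no autograd needed


@runtime_checkable
class MoELayerProtocol(Protocol):
    def forward(
        self, x: torch.Tensor, *, gating_output: torch.Tensor
    ) -> torch.Tensor: ...


class ToyDenseMoE(nn.Module):
    """Runs every expert on every token, then mixes by the routing weights."""

    def __init__(self, n_experts: int, dim: int):
        super().__init__()
        self.experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_experts)])

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        weights = torch.softmax(gating_output, dim=-1)  # (tokens, n_experts)
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            out += expert(x) * weights[:, i].view(-1, 1)
        return out


layer = ToyDenseMoE(n_experts=2, dim=8)
# runtime_checkable protocols only verify that the method exists, not its
# signature; this mirrors the `assert isinstance(..., MoELayer)` calls added below.
assert isinstance(layer, MoELayerProtocol)
out = layer(torch.randn(4, 8), gating_output=torch.randn(4, 2))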
server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py  (view file @ 3f14cd14)

@@ -13,10 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Type

 import torch
 import torch.distributed
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig

 from text_generation_server.layers import (
     FastLinear,
     SpeculativeHead,
@@ -26,22 +30,16 @@ from text_generation_server.layers import (
     get_linear,
 )
 from text_generation_server.layers.attention import (
+    Seqlen,
     attention,
     paged_attention,
     reshape_and_cache,
-    Seqlen,
 )
 from text_generation_server.layers.layernorm import FastRMSNorm
-from text_generation_server.layers.moe import SparseMoELayer
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import Weights
-from torch import nn
-from transformers.activations import ACT2FN
-from transformers.configuration_utils import PretrainedConfig
-
-if SYSTEM != "ipex":
-    from moe_kernels.fused_moe import grouped_topk

 if SYSTEM == "rocm":
     try:
@@ -410,8 +408,14 @@ class DeepseekV2MLP(nn.Module):
         )


-class BlockSparseMoE(nn.Module):
-    def __init__(self, prefix, config: DeepseekV2Config, weights):
+class DeepseekV2MoE(nn.Module):
+    def __init__(
+        self,
+        prefix,
+        config: DeepseekV2Config,
+        moe_layer_cls: Type[MoELayer],
+        weights,
+    ):
         super().__init__()

         self.hidden_dim = config.hidden_size
@@ -423,7 +427,7 @@ class BlockSparseMoE(nn.Module):
         # Gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

-        self.moe_layer = SparseMoELayer(
+        self.moe_layer = moe_layer_cls(
             prefix=f"{prefix}.experts",
             n_experts=config.n_routed_experts,
             n_expert_group=config.n_group,
@@ -432,6 +436,7 @@ class BlockSparseMoE(nn.Module):
             topk_group=config.topk_group,
             weights=weights,
         )
+        assert isinstance(self.moe_layer, MoELayer)

         if config.n_shared_experts is not None:
             self.shared_experts = DeepseekV2MLP(
@@ -466,96 +471,6 @@ class BlockSparseMoE(nn.Module):
         return out.view(*x.shape)


-class DenseMoE(nn.Module):
-    def __init__(self, prefix: str, config: DeepseekV2Config, weights: Weights):
-        super().__init__()
-
-        self.hidden_dim = config.hidden_size
-        self.moe_intermediate_size = config.moe_intermediate_size
-        self.n_routed_experts = config.n_routed_experts
-        self.n_expert_group = config.n_group
-        self.topk_group = config.topk_group
-        self.top_k = config.num_experts_per_tok
-        self.norm_topk_prob = config.norm_topk_prob
-        self.routed_scaling_factor = config.routed_scaling_factor
-
-        # Gating
-        #
-        # Seems like no one quantizes the gate.
-        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
-
-        self.experts = [
-            DeepseekV2MLP(
-                f"{prefix}.experts.{i}", config, weights, self.moe_intermediate_size
-            )
-            for i in range(self.n_routed_experts)
-        ]
-
-        if config.n_shared_experts is not None:
-            self.shared_experts = DeepseekV2MLP(
-                prefix=f"{prefix}.shared_experts",
-                config=config,
-                weights=weights,
-                intermediate_size=config.moe_intermediate_size
-                * config.n_shared_experts,
-            )
-        else:
-            self.shared_experts = None
-
-        self.process_group = weights.process_group
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        if self.shared_experts is not None:
-            shared_output = self.shared_experts(x, reduce=False)
-        else:
-            shared_output = None
-
-        # gate_logits: (sequence_length, n_experts)
-        router_logits = self.gate(x)
-
-        topk_weights, topk_ids = grouped_topk(
-            x,
-            router_logits,
-            self.top_k,
-            renormalize=self.norm_topk_prob,
-            num_expert_group=self.n_expert_group,
-            topk_group=self.topk_group,
-        )
-
-        out = self.moe_infer_gpu(x, topk_ids, topk_weights) * self.routed_scaling_factor
-
-        if shared_output is not None:
-            out = out + shared_output
-
-        # Reduce sum
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
-
-        return out
-
-    def moe_infer_gpu(
-        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
-    ):
-        weights = torch.zeros(
-            topk_ids.shape[0], len(self.experts), dtype=x.dtype, device=x.device
-        )
-        weights.scatter_(1, topk_ids, topk_weight)
-
-        out = x.new_zeros(x.shape[0], self.hidden_dim)
-        for i, expert in enumerate(self.experts):
-            # Add expert output to out with masking
-            out += expert(x, reduce=False) * weights[:, i].view(-1, 1)
-        return out


 class DeepseekV2Layer(nn.Module):
     def __init__(self, prefix, layer_id, config, weights):
         super().__init__()
@@ -572,10 +487,12 @@ class DeepseekV2Layer(nn.Module):
             and layer_id >= config.first_k_dense_replace
             and layer_id % config.moe_layer_freq == 0
         ):
-            moe_cls = (
-                BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
+            moe_layer_cls = (
+                SparseMoELayer
+                if SparseMoELayer.is_supported(weights)
+                else DenseMoELayer
             )
-            self.mlp = moe_cls(f"{prefix}.mlp", config, weights)
+            self.mlp = DeepseekV2MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
         else:
             self.mlp = DeepseekV2MLP(
                 prefix=f"{prefix}.mlp",
 ...
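In both models the dense fallback has to reproduce what the fused sparse kernels compute: per token, a weighted sum over only the top-k experts. `DenseMoELayer` does this by scattering the top-k routing weights into a full (tokens x experts) matrix, so unselected experts contribute zero even though every expert is evaluated. A small self-contained check of that equivalence in plain torch; the sizes and the toy `nn.Linear` experts are arbitrary placeholders, not model values:

import torch

torch.manual_seed(0)
torch.set_grad_enabled(False)  # inference-only check, no autograd needed

n_tokens, dim, n_experts, topk = 5, 4, 8, 2

x = torch.randn(n_tokens, dim)
gating_output = torch.randn(n_tokens, n_experts)
experts = [torch.nn.Linear(dim, dim) for _ in range(n_experts)]

# Top-k routing weights (a stand-in for what fused_topk/grouped_topk return).
probs = torch.softmax(gating_output, dim=-1)
topk_weights, topk_ids = probs.topk(topk, dim=-1)

# Dense path, mirroring DenseMoELayer.forward: scatter the top-k weights into a
# full (tokens, experts) matrix, run *all* experts, and mix; unselected experts
# keep weight 0.
dense_weights = torch.zeros(n_tokens, n_experts)
dense_weights.scatter_(1, topk_ids, topk_weights)
dense_out = torch.zeros_like(x)
for i, expert in enumerate(experts):
    dense_out += expert(x) * dense_weights[:, i].view(-1, 1)

# Sparse path: evaluate only the selected experts for each token.
sparse_out = torch.zeros_like(x)
for t in range(n_tokens):
    for w, e in zip(topk_weights[t], topk_ids[t]):
        sparse_out[t] += w * experts[int(e)](x[t])

torch.testing.assert_close(dense_out, sparse_out, rtol=1e-4, atol=1e-5)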
server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py  (view file @ 3f14cd14)

@@ -18,38 +18,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import List, Optional, Tuple, Type
+
 import torch
 import torch.distributed
 from torch import nn
-from text_generation_server.utils.import_utils import SYSTEM
-from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
-from typing import Optional, List, Tuple

-from text_generation_server.layers.attention import (
-    paged_attention,
-    attention,
-    reshape_and_cache,
-    Seqlen,
-)
 from text_generation_server.layers import (
     FastLinear,
-    TensorParallelRowLinear,
+    SpeculativeHead,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    SpeculativeHead,
+    TensorParallelRowLinear,
     get_linear,
 )
-from text_generation_server.layers.moe import SparseMoELayer
-from text_generation_server.layers.layernorm import (
-    FastRMSNorm,
-)
-from text_generation_server.layers.rotary import (
-    PositionRotaryEmbedding,
-)
+from text_generation_server.layers.attention import (
+    Seqlen,
+    attention,
+    paged_attention,
+    reshape_and_cache,
+)
+from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import UnquantizedWeight
@@ -315,14 +308,16 @@ def round_up(x: torch.Tensor, value: int):
     return torch.div(x + (value - 1), value, rounding_mode="trunc") * value


-class BlockSparseMoE(nn.Module):
-    def __init__(self, prefix, config: MixtralConfig, weights):
+class MixtralMoE(nn.Module):
+    def __init__(
+        self, prefix, config: MixtralConfig, moe_layer_cls: Type[MoELayer], weights
+    ):
         super().__init__()

         # gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

-        self.moe = SparseMoELayer(
+        self.moe = moe_layer_cls(
             n_expert_group=None,
             n_experts=config.num_local_experts,
             prefix=f"{prefix}.experts",
@@ -334,6 +329,7 @@ class BlockSparseMoE(nn.Module):
             up_proj_name="w3",
             down_proj_name="w2",
         )
+        assert isinstance(self.moe, MoELayer)

         self.process_group = weights.process_group
@@ -349,95 +345,6 @@ class BlockSparseMoE(nn.Module):
         return out.view(*x.shape)


-class DenseMoE(nn.Module):
-    def __init__(self, prefix, config: MixtralConfig, weights):
-        super().__init__()
-        self.hidden_dim = config.hidden_size
-        self.ffn_dim = config.intermediate_size // weights.process_group.size()
-        self.num_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-
-        act = config.hidden_act
-        if "gelu" in act:
-            self.act = lambda x: torch.nn.functional.gelu(
-                x,
-                approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
-                ),
-            )
-        elif "silu" in act:
-            self.act = torch.nn.functional.silu
-        else:
-            self.act = ACT2FN[act]
-
-        # gating
-        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
-
-        self.w1 = [
-            TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-        self.w3 = [
-            TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-        self.w2 = [
-            TensorParallelRowLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-
-        self.process_group = weights.process_group
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        # all_probs: (sequence_length, n_experts) and upcast for softmax
-        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
-
-        if self.top_k < self.num_experts:
-            _, not_selected_experts = torch.topk(
-                all_probs,
-                self.num_experts - self.top_k,
-                largest=False,
-                sorted=False,
-                dim=1,
-            )
-            # Mask not selected experts
-            all_probs.scatter_(1, not_selected_experts, 0)
-
-        # Re-normalize
-        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
-        weights = weights.to(x.dtype)
-
-        # Final output tensor
-        out = x.new_zeros(x.shape[0], self.hidden_dim)
-        for i in range(self.num_experts):
-            h = self.act(self.w1[i](x)) * self.w3[i](x)
-            h = self.w2[i](h, reduce=False)
-            # Add expert output to out with masking
-            out += h * weights[:, i].view(-1, 1)
-
-        # Reduce sum
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
-
-        return out


 class MixtralLayer(nn.Module):
     def __init__(self, prefix: str, layer_id, config, weights):
         super().__init__()
@@ -447,8 +354,12 @@ class MixtralLayer(nn.Module):
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )

-        moe_cls = BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
-        self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
+        moe_layer_cls = (
+            SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
+        )
+        self.moe = MixtralMoE(
+            f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
+        )

         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
 ...
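The removed Mixtral `DenseMoE.forward` also spells out the routing-weight math that now lives behind the shared layer: softmax over the gate logits in float32, zero out the experts outside the top-k, then renormalize so the surviving weights sum to one per token. A minimal standalone sketch of just that step in plain torch, with toy shapes rather than Mixtral's configuration; it is not the TGI code itself:

import torch

torch.manual_seed(0)
n_tokens, n_experts, top_k = 3, 8, 2

gate_logits = torch.randn(n_tokens, n_experts)

# Upcast for softmax, as the removed Mixtral DenseMoE did.
all_probs = torch.softmax(gate_logits, dim=1, dtype=torch.float)

# Mask out the experts that are *not* in the top-k by scattering zeros.
_, not_selected_experts = torch.topk(
    all_probs, n_experts - top_k, largest=False, sorted=False, dim=1
)
all_probs.scatter_(1, not_selected_experts, 0.0)

# Re-normalize so the remaining top-k weights sum to 1 for every token.
weights = all_probs / all_probs.sum(dim=1, keepdim=True)

assert torch.allclose(weights.sum(dim=1), torch.ones(n_tokens))
assert (weights > 0).sum(dim=1).eq(top_k).all()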