Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9dc7c6c7
Unverified
Commit
9dc7c6c7
authored
Sep 21, 2024
by
Divakar Verma
Committed by
GitHub
Sep 21, 2024
Browse files
[dbrx] refactor dbrx experts to extend FusedMoe class (#8518)
parent
ec4aaad8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
69 deletions
+51
-69
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+51
-69
No files found.
vllm/model_executor/models/dbrx.py
View file @
9dc7c6c7
...
...
@@ -7,9 +7,8 @@ import torch.nn as nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.fused_moe
import
fused_moe
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
...
...
@@ -22,7 +21,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
...
...
@@ -54,13 +52,7 @@ class DbrxRouter(nn.Module):
return
router_logits
class
DbrxExperts
(
nn
.
Module
):
"""A tensor-parallel MoE implementation for DBRX.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
class
DbrxExperts
(
FusedMoE
):
def
__init__
(
self
,
...
...
@@ -68,49 +60,24 @@ class DbrxExperts(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
):
super
().
__init__
()
super
().
__init__
(
num_experts
=
config
.
ffn_config
.
moe_num_experts
,
top_k
=
config
.
ffn_config
.
moe_top_k
,
hidden_size
=
config
.
d_model
,
intermediate_size
=
config
.
ffn_config
.
ffn_hidden_size
,
params_dtype
=
params_dtype
,
reduce_results
=
True
,
renormalize
=
True
,
quant_config
=
quant_config
,
tp_size
=
get_tensor_model_parallel_world_size
(),
)
self
.
config
=
config
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_total_experts
=
config
.
ffn_config
.
moe_num_experts
self
.
top_k
=
config
.
ffn_config
.
moe_top_k
self
.
d_model
=
config
.
d_model
self
.
intermediate_size
=
(
config
.
ffn_config
.
ffn_hidden_size
//
self
.
intermediate_size
=
(
self
.
config
.
ffn_config
.
ffn_hidden_size
//
self
.
tp_size
)
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
self
.
router
=
DbrxRouter
(
config
,
self
.
params_dtype
)
self
.
ws
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_total_experts
,
2
*
self
.
intermediate_size
,
self
.
d_model
,
device
=
"cuda"
,
dtype
=
self
.
params_dtype
,
))
self
.
w2s
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_total_experts
,
self
.
d_model
,
self
.
intermediate_size
,
device
=
"cuda"
,
dtype
=
self
.
params_dtype
,
))
set_weight_attrs
(
self
.
ws
,
{
"weight_loader"
:
self
.
weight_loader
,
},
)
set_weight_attrs
(
self
.
w2s
,
{
"weight_loader"
:
self
.
weight_loader
,
},
)
# Define custom weight loader for dbrx model
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
):
tp_rank
=
get_tensor_model_parallel_rank
()
...
...
@@ -140,26 +107,40 @@ class DbrxExperts(nn.Module):
).
transpose
(
1
,
2
)
param_data
[:]
=
loaded_weight
[:,
:,
shard
]
class
DbrxMoE
(
nn
.
Module
):
"""A tensor-parallel MoE implementation for DBRX.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def
__init__
(
self
,
config
:
DbrxConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
):
super
().
__init__
()
self
.
d_model
=
config
.
d_model
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
self
.
router
=
DbrxRouter
(
config
,
self
.
params_dtype
)
self
.
experts
=
DbrxExperts
(
config
=
config
,
quant_config
=
quant_config
,
params_dtype
=
self
.
params_dtype
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_siz
e
=
hidden_states
.
shape
orig_shap
e
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
d_model
)
# router_logits: (num_tokens, n_experts)
router_logits
=
self
.
router
(
hidden_states
)
final_hidden_states
=
fused_moe
(
hidden_states
,
self
.
ws
,
self
.
w2s
,
router_logits
,
self
.
top_k
,
renormalize
=
True
,
inplace
=
True
,
)
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_size
)
final_hidden_states
=
self
.
experts
(
hidden_states
,
router_logits
)
return
final_hidden_states
.
view
(
orig_shape
)
class
DbrxAttention
(
nn
.
Module
):
...
...
@@ -288,7 +269,7 @@ class DbrxBlock(nn.Module):
super
().
__init__
()
self
.
norm_attn_norm
=
DbrxFusedNormAttention
(
config
,
cache_config
,
quant_config
)
self
.
ffn
=
Dbrx
Experts
(
config
,
quant_config
)
self
.
ffn
=
Dbrx
MoE
(
config
,
quant_config
)
def
forward
(
self
,
...
...
@@ -409,9 +390,10 @@ class DbrxForCausalLM(nn.Module):
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
expert_params_mapping
=
[(
"w
s
"
if
weight_name
in
[
"w1"
,
"v1"
]
else
"w2
s
"
,
f
"
experts.
mlp.
{
weight_name
}
"
,
"w
13_weight
"
if
weight_name
in
[
"w1"
,
"v1"
]
else
"w2
_weight
"
,
f
"mlp.
{
weight_name
}
"
,
)
for
weight_name
in
[
"w1"
,
"v1"
,
"w2"
]]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
weights
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment