Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2ee8d3ba
Unverified
Commit
2ee8d3ba
authored
Jul 31, 2024
by
Avshalom Manevich
Committed by
GitHub
Jul 31, 2024
Browse files
[Model] use FusedMoE layer in Jamba (#6935)
parent
daed30c4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
49 additions
and
108 deletions
+49
-108
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+49
-108
No files found.
vllm/model_executor/models/jamba.py
View file @
2ee8d3ba
# coding=utf-8
"""Inference-only J
urassic
model."""
"""Inference-only J
amba
model."""
from
dataclasses
import
dataclass
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Tuple
...
...
@@ -15,10 +15,9 @@ from vllm.attention.backends.abstract import AttentionMetadata
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
f
used
_moe
from
vllm.model_executor.layers.fused_moe
import
F
used
MoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
...
...
@@ -282,108 +281,50 @@ class JambaMLP(nn.Module):
class
JambaMoE
(
nn
.
Module
):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def
__init__
(
self
,
config
:
JambaConfig
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
def
__init__
(
self
,
config
:
JambaConfig
,
num_experts
:
Optional
[
int
]
=
None
,
top_k
:
Optional
[
int
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
tp_size
=
tp_size
or
get_tensor_model_parallel_world_size
()
self
.
num_total_experts
=
config
.
num_experts
self
.
top_k
=
config
.
num_experts_per_tok
self
.
num_total_experts
=
num_experts
or
config
.
num_experts
self
.
top_k
=
top_k
or
config
.
num_experts_per_tok
self
.
hidden_size
=
config
.
hidden_size
self
.
intermediate_size
=
config
.
intermediate_size
//
self
.
tp_size
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
self
.
router
=
ReplicatedLinear
(
self
.
hidden_size
,
self
.
num_total_experts
,
bias
=
False
,
params_dtype
=
self
.
params_dtype
)
self
.
ws
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_total_experts
,
2
*
self
.
intermediate_size
,
self
.
hidden_size
,
device
=
"cuda"
,
dtype
=
self
.
params_dtype
,
))
self
.
w2s
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_total_experts
,
self
.
hidden_size
,
self
.
intermediate_size
,
device
=
"cuda"
,
dtype
=
self
.
params_dtype
,
))
self
.
intermediate_size
=
config
.
intermediate_size
set_weight_attrs
(
self
.
ws
,
{
"weight_loader"
:
self
.
weight_loader
,
},
)
set_weight_attrs
(
self
.
w2s
,
{
"weight_loader"
:
self
.
weight_loader
,
},
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
expert_id
:
int
,
):
tp_rank
=
get_tensor_model_parallel_rank
()
param_data
=
param
.
data
shard_size
=
self
.
intermediate_size
shard
=
slice
(
tp_rank
*
shard_size
,
(
tp_rank
+
1
)
*
shard_size
)
if
weight_name
.
endswith
(
"gate_proj.weight"
):
param_data
[
expert_id
,
0
:
shard_size
,
:]
=
loaded_weight
[
shard
,
:]
if
weight_name
.
endswith
(
"up_proj.weight"
):
param_data
[
expert_id
,
shard_size
:
2
*
shard_size
,
:]
=
loaded_weight
[
shard
,
:]
if
weight_name
.
endswith
(
"down_proj.weight"
):
param_data
[
expert_id
,
:,
:]
=
loaded_weight
[:,
shard
]
if
self
.
num_total_experts
>
1
:
self
.
router
=
ReplicatedLinear
(
self
.
hidden_size
,
self
.
num_total_experts
,
bias
=
False
,
quant_config
=
None
,
params_dtype
=
params_dtype
)
self
.
experts
=
FusedMoE
(
self
.
num_total_experts
,
self
.
top_k
,
self
.
hidden_size
,
self
.
intermediate_size
,
tp_size
=
tp_size
,
params_dtype
=
params_dtype
,
reduce_results
=
True
,
renormalize
=
False
,
use_grouped_topk
=
False
,
quant_config
=
quant_config
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_siz
e
=
hidden_states
.
shape
orig_shap
e
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_size
)
# router_logits: (batch * sequence_length, n_experts)
router_logits
,
_
=
self
.
router
(
hidden_states
)
final_hidden_states
=
fused_moe
(
hidden_states
,
self
.
ws
,
self
.
w2s
,
router_logits
,
self
.
top_k
,
renormalize
=
False
,
# Mixtral normalize the expert probs to 1. We don't!
inplace
=
True
,
)
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_size
)
if
self
.
num_total_experts
>
1
:
router_logits
,
_
=
self
.
router
(
hidden_states
)
else
:
router_logits
=
torch
.
ones
((
hidden_states
.
shape
[
0
],
1
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
hidden_states
=
self
.
experts
(
hidden_states
,
router_logits
)
return
hidden_states
.
view
(
orig_shape
)
class
JambaMambaDecoderLayer
(
nn
.
Module
):
...
...
@@ -917,15 +858,13 @@ class JambaForCausalLM(nn.Module, HasInnerState):
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
expert_params_mapping
=
[
# (param_name, weight_name, expert_id)
(
"ws"
if
weight_name
in
[
"gate_proj"
,
"up_proj"
]
else
"w2s"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight"
,
expert_id
,
)
for
expert_id
in
range
(
self
.
config
.
num_experts
)
for
weight_name
in
[
"down_proj"
,
"up_proj"
,
"gate_proj"
]
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
...
...
@@ -952,7 +891,8 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
for
param_name
,
weight_name
,
expert_id
in
expert_params_mapping
:
for
mapping
in
expert_params_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
...
...
@@ -961,6 +901,7 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader
(
param
,
loaded_weight
,
weight_name
,
shard_id
=
shard_id
,
expert_id
=
expert_id
)
break
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment