norm / vllm · Commits · ab406446

Commit ab406446 authored Jan 29, 2024 by Philipp Moritz, committed by GitHub on Jan 29, 2024

    Fused MOE for Mixtral (#2542)

    Co-authored-by: chen shen <scv119@gmail.com>

parent 5d60def0

Showing 4 changed files with 115 additions and 109 deletions:

    csrc/moe_align_block_size_kernels.cu   +1   -1
    csrc/ops.h                             +7   -9
    csrc/pybind.cpp                        +3   -3
    vllm/model_executor/models/mixtral.py  +104  -96
csrc/moe_align_block_size_kernels.cu

@@ -95,7 +95,7 @@ void moe_align_block_size(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   assert(num_experts <= NUM_MAX_EXPERTS);
   VLLM_DISPATCH_INTEGRAL_TYPES(
-    topk_ids.scalar_type(), "moe_alig_block_size_kernel", [&] {
+    topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
     vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
       topk_ids.data_ptr<scalar_t>(),
       sorted_token_ids.data_ptr<int32_t>(),
csrc/ops.h

@@ -100,6 +100,13 @@ void gptq_shuffle(
   torch::Tensor q_weight,
   torch::Tensor q_perm);

+void moe_align_block_size(
+  torch::Tensor topk_ids,
+  int num_experts,
+  int block_size,
+  torch::Tensor sorted_token_ids,
+  torch::Tensor experts_ids,
+  torch::Tensor num_tokens_post_pad);
+
 #ifndef USE_ROCM
 using fptr_t = uint64_t;

@@ -121,12 +128,3 @@ std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
 void register_graph_buffers(fptr_t _fa,
                             const std::vector<std::string> &handles,
                             const std::vector<std::vector<int64_t>> &offsets);
 #endif
-
-void moe_align_block_size(
-  torch::Tensor topk_ids,
-  int num_experts,
-  int block_size,
-  torch::Tensor sorted_token_ids,
-  torch::Tensor experts_ids,
-  torch::Tensor num_tokens_post_pad);
csrc/pybind.cpp

@@ -57,9 +57,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm,
           "Quantized GEMM for SqueezeLLM");
   ops.def(
-    "moe_align_block_size",
-    &moe_align_block_size,
-    "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
+      "moe_align_block_size",
+      &moe_align_block_size,
+      "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");

   // Cache ops
   pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
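Note on the op registered above: "aligning" means padding each expert's routed token count up to a multiple of block_size so the fused kernel can sweep whole blocks. The real work is done by the CUDA kernel bound as ops.moe_align_block_size; the helper below is a hypothetical pure-Python illustration of the padding rule only, not the actual implementation.

def align_block_size(tokens_per_expert, block_size):
    # Round each expert's token count up to the next multiple of block_size.
    return [((n + block_size - 1) // block_size) * block_size
            for n in tokens_per_expert]

# Example: with block_size=16, experts holding 5, 16, and 20 routed tokens
# are padded to 16, 16, and 32 token slots respectively.
assert align_block_size([5, 16, 20], block_size=16) == [16, 16, 32]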
vllm/model_executor/models/mixtral.py
@@ -23,8 +23,6 @@
 """Inference-only Mixtral model."""
 from typing import List, Optional, Tuple

-import numpy as np
-
 import torch
 import torch.nn.functional as F
@@ -33,10 +31,11 @@ from transformers import MixtralConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               ReplicatedLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
@@ -47,6 +46,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -54,85 +54,77 @@ from vllm.sequence import SamplerOutput
 KVCache = Tuple[torch.Tensor, torch.Tensor]


-class MixtralMLP(nn.Module):
+class MixtralMoE(nn.Module):
+    """A tensor-parallel MoE implementation for Mixtral that shards each expert
+    across all ranks.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """

     def __init__(
         self,
         num_experts: int,
+        top_k: int,
         hidden_size: int,
         intermediate_size: int,
-        linear_method: Optional[LinearMethodBase] = None,
-    ) -> None:
+        params_dtype: Optional[torch.dtype] = None,
+    ):
         super().__init__()
-        self.num_experts = num_experts
-        self.ffn_dim = intermediate_size
-        self.hidden_dim = hidden_size
-
-        self.w1 = ReplicatedLinear(self.hidden_dim,
-                                   self.ffn_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-        self.w2 = ReplicatedLinear(self.ffn_dim,
-                                   self.hidden_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-        self.w3 = ReplicatedLinear(self.hidden_dim,
-                                   self.ffn_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-
-        # TODO: Use vllm's SiluAndMul
-        self.act_fn = nn.SiLU()
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        w1_out, _ = self.w1(hidden_states)
-        w1_out = self.act_fn(w1_out)
-        w3_out, _ = self.w3(hidden_states)
-        current_hidden_states = w1_out * w3_out
-        current_hidden_states, _ = self.w2(current_hidden_states)
-        return current_hidden_states
-
-
-class MixtralMoE(nn.Module):
-
-    def __init__(
-        self,
-        config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
-    ):
-        super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.num_total_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.num_total_experts}.")
-        # Split experts equally between ranks
-        self.expert_indicies = np.array_split(range(
-            self.num_total_experts), self.tp_size)[self.rank].tolist()
-        if not self.expert_indicies:
-            raise ValueError(
-                f"Rank {self.rank} has no experts assigned to it.")
-
-        self.experts = nn.ModuleList([
-            MixtralMLP(self.num_total_experts,
-                       config.hidden_size,
-                       config.intermediate_size,
-                       linear_method=linear_method)
-            if idx in self.expert_indicies else None
-            for idx in range(self.num_total_experts)
-        ])
-        self.gate = ReplicatedLinear(config.hidden_size,
+        tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = num_experts
+        self.top_k = top_k
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size // tp_size
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        self.gate = ReplicatedLinear(self.hidden_size,
                                      self.num_total_experts,
                                      bias=False,
+                                     params_dtype=self.params_dtype,
                                      linear_method=None)

+        self.ws = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        2 * self.intermediate_size,
+                        self.hidden_size,
+                        device="cuda",
+                        dtype=self.params_dtype))
+        self.w2s = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        self.hidden_size,
+                        self.intermediate_size,
+                        device="cuda",
+                        dtype=self.params_dtype))
+
+        set_weight_attrs(self.ws, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2s, {
+            "weight_loader": self.weight_loader,
+        })
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
+                      weight_name: str, expert_id: int):
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        if weight_name.endswith("w1.weight"):
+            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w3.weight"):
+            param_data[expert_id,
+                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w2.weight"):
+            param_data[expert_id, :, :] = loaded_weight[:, shard]
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+        batch_size, sequence_length, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
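Aside on the layout introduced in the hunk above (a sketch only, using Mixtral-8x7B-like sizes as an assumption): ws stacks each expert's tensor-parallel shard of w1 and w3 along its second dimension, and w2s holds the matching shard of w2; weight_loader writes checkpoint slices into exactly these regions.

# Illustrative shapes of the fused expert parameters on one rank; the names
# follow the constructor above, the concrete numbers are assumed examples.
num_experts, hidden_size, intermediate_size, tp_size = 8, 4096, 14336, 2
shard = intermediate_size // tp_size                 # per-rank slice of the FFN dim
ws_shape = (num_experts, 2 * shard, hidden_size)     # rows [0:shard] <- w1, [shard:2*shard] <- w3
w2s_shape = (num_experts, hidden_size, shard)        # w2 is sliced along its input (column) dim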
@@ -142,22 +134,18 @@ class MixtralMoE(nn.Module):
                                                        dim=-1)
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

-        final_hidden_states = None
-        for expert_idx in self.expert_indicies:
-            expert_layer = self.experts[expert_idx]
-            expert_mask = (selected_experts == expert_idx)
-            expert_weights = (routing_weights * expert_mask).sum(dim=-1,
-                                                                 keepdim=True)
-
-            current_hidden_states = expert_layer(hidden_states).mul_(
-                expert_weights)
-            if final_hidden_states is None:
-                final_hidden_states = current_hidden_states
-            else:
-                final_hidden_states.add_(current_hidden_states)
-
-        return tensor_model_parallel_all_reduce(final_hidden_states).view(
-            batch_size, sequence_length, hidden_dim)
+        final_hidden_states = fused_moe(hidden_states,
+                                        self.ws,
+                                        self.w2s,
+                                        routing_weights,
+                                        selected_experts,
+                                        inplace=True)
+
+        final_hidden_states = tensor_model_parallel_all_reduce(
+            final_hidden_states)
+
+        return final_hidden_states.view(batch_size, sequence_length,
+                                        hidden_size)


 class MixtralAttention(nn.Module):
@@ -257,8 +245,11 @@ class MixtralDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             sliding_window=config.sliding_window,
             linear_method=linear_method)
-        self.block_sparse_moe = MixtralMoE(config=config,
-                                           linear_method=linear_method)
+        self.block_sparse_moe = MixtralMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size)
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -378,6 +369,14 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]

+        expert_params_mapping = [
+            # (param_name, weight_name, expert_id)
+            ("ws" if weight_name in ["w1", "w3"] else "w2s",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path,
@@ -387,6 +386,7 @@ class MixtralForCausalLM(nn.Module):
                 fall_back_to_pt=False):
             if "rotary_emb.inv_freq" in name:
                 continue
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -399,14 +399,22 @@ class MixtralForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if ("block_sparse_moe.experts." in name
-                        and name not in params_dict):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for param_name, weight_name, expert_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
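Taken together, the new forward path is: the replicated gate produces router logits, top-k routing weights and expert indices are derived from them, a single fused_moe call evaluates all experts against the fused ws/w2s parameters, and the result is all-reduced across tensor-parallel ranks. The routing step between the two mixtral.py hunks above is elided in this diff, so the sketch below is an assumption about that part rather than a verbatim excerpt; the fused_moe and all-reduce calls mirror the added lines shown above.

import torch
import torch.nn.functional as F

def route(router_logits: torch.Tensor, top_k: int):
    # Assumed routing step (not shown in the hunks above): softmax over
    # experts, keep the top_k weights, then renormalize over the chosen
    # experts, matching the "routing_weights /= ..." context line in the diff.
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    return routing_weights, selected_experts

# Per the added lines in the mixtral.py hunk, the module then calls:
#   final_hidden_states = fused_moe(hidden_states, self.ws, self.w2s,
#                                   routing_weights, selected_experts,
#                                   inplace=True)
#   final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)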