norm / vllm · Commits

Commit ab406446 (unverified), authored Jan 29, 2024 by Philipp Moritz, committed via GitHub on Jan 29, 2024.

Fused MOE for Mixtral (#2542)

Co-authored-by: chen shen <scv119@gmail.com>

Parent commit: 5d60def0
Changes: 4 changed files with 115 additions and 109 deletions (+115 -109)
csrc/moe_align_block_size_kernels.cu     +1   -1
csrc/ops.h                               +7   -9
csrc/pybind.cpp                          +3   -3
vllm/model_executor/models/mixtral.py  +104  -96
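At a glance, the mixtral.py diff below drops the per-expert ReplicatedLinear modules (the old MixtralMLP) and instead keeps two stacked, tensor-parallel-sharded parameters per MoE layer, which a fused kernel consumes directly. A rough shape summary, using Mixtral-8x7B's published config values as an example (illustrative only; variable names mirror the new code):

```python
# Mixtral-8x7B: 8 experts, top-2 routing, hidden_size=4096, intermediate_size=14336.
num_experts, hidden_size, intermediate_size, tp_size = 8, 4096, 14336, 2

# Old layout: one MixtralMLP per expert (full w1/w3/w2 ReplicatedLinear weights),
# whole experts assigned to ranks, and a Python loop over the local experts.
# New layout: every rank holds a 1/tp_size slice of every expert, stacked into
# two tensors that the fused kernel reads directly.
shard = intermediate_size // tp_size                 # 7168 rows per rank
ws_shape = (num_experts, 2 * shard, hidden_size)     # w1 and w3 stacked per expert
w2s_shape = (num_experts, hidden_size, shard)        # w2 per expert
print(ws_shape, w2s_shape)                           # (8, 14336, 4096) (8, 4096, 7168)
```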
csrc/moe_align_block_size_kernels.cu

@@ -95,7 +95,7 @@ void moe_align_block_size(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   assert(num_experts <= NUM_MAX_EXPERTS);
   VLLM_DISPATCH_INTEGRAL_TYPES(
-    topk_ids.scalar_type(), "moe_alig_block_size_kernel", [&] {
+    topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
     vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
       topk_ids.data_ptr<scalar_t>(),
       sorted_token_ids.data_ptr<int32_t>(),
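For context, the Python binding in csrc/pybind.cpp below documents this op as "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size." The following CPU sketch illustrates that behaviour; it is not the CUDA kernel, and the padding sentinel and exact output layout are assumptions, since they are not visible in this diff:

```python
import torch

def moe_align_block_size_ref(topk_ids: torch.Tensor, num_experts: int,
                             block_size: int):
    """Group token->expert assignments per expert and pad each group to a
    multiple of block_size (illustrative reference, details assumed)."""
    flat = topk_ids.flatten()
    pad_id = flat.numel()          # assumed sentinel value for padded slots
    sorted_token_ids, experts_ids = [], []
    for expert in range(num_experts):
        ids = torch.nonzero(flat == expert, as_tuple=False).flatten().tolist()
        padded_len = -(-len(ids) // block_size) * block_size   # round up
        ids += [pad_id] * (padded_len - len(ids))
        sorted_token_ids += ids
        experts_ids += [expert] * (padded_len // block_size)   # one id per block
    num_tokens_post_pad = torch.tensor([len(sorted_token_ids)], dtype=torch.int32)
    return (torch.tensor(sorted_token_ids, dtype=torch.int32),
            torch.tensor(experts_ids, dtype=torch.int32),
            num_tokens_post_pad)
```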
csrc/ops.h

@@ -100,6 +100,13 @@ void gptq_shuffle(
   torch::Tensor q_weight,
   torch::Tensor q_perm);
 
+void moe_align_block_size(
+  torch::Tensor topk_ids,
+  int num_experts,
+  int block_size,
+  torch::Tensor sorted_token_ids,
+  torch::Tensor experts_ids,
+  torch::Tensor num_tokens_post_pad);
+
 #ifndef USE_ROCM
 using fptr_t = uint64_t;

@@ -121,12 +128,3 @@ std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
 void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
                             const std::vector<std::vector<int64_t>> &offsets);
 #endif
-
-void moe_align_block_size(
-  torch::Tensor topk_ids,
-  int num_experts,
-  int block_size,
-  torch::Tensor sorted_token_ids,
-  torch::Tensor experts_ids,
-  torch::Tensor num_tokens_post_pad);
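The declaration pins down the op's interface: the caller allocates the three output tensors and the kernel fills them in. A hedged sketch of a call from Python; the vllm._C module path and the buffer sizes are assumptions (worst case each expert adds block_size - 1 padding slots), not part of this diff:

```python
import torch
from vllm._C import ops  # assumed binding path for the compiled extension

topk_ids = torch.randint(0, 8, (256, 2), dtype=torch.int32, device="cuda")
num_experts, block_size = 8, 16

max_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_token_ids = torch.empty(max_padded, dtype=torch.int32, device="cuda")
experts_ids = torch.empty((max_padded + block_size - 1) // block_size,
                          dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")

ops.moe_align_block_size(topk_ids, num_experts, block_size,
                         sorted_token_ids, experts_ids, num_tokens_post_pad)
```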
csrc/pybind.cpp

@@ -57,9 +57,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
   ops.def(
-    "moe_align_block_size",
-    &moe_align_block_size,
-    "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
+    "moe_align_block_size",
+    &moe_align_block_size,
+    "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
 
   // Cache ops
   pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
vllm/model_executor/models/mixtral.py

@@ -23,8 +23,6 @@
 """Inference-only Mixtral model."""
 from typing import List, Optional, Tuple
 
-import numpy as np
-
 import torch
 import torch.nn.functional as F
@@ -33,10 +31,11 @@ from transformers import MixtralConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               ReplicatedLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
@@ -47,6 +46,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -54,85 +54,77 @@ from vllm.sequence import SamplerOutput
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class MixtralMLP(nn.Module):
-
-    def __init__(
-        self,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size: int,
-        linear_method: Optional[LinearMethodBase] = None,
-    ) -> None:
-        super().__init__()
-        self.num_experts = num_experts
-        self.ffn_dim = intermediate_size
-        self.hidden_dim = hidden_size
-
-        self.w1 = ReplicatedLinear(self.hidden_dim,
-                                   self.ffn_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-        self.w2 = ReplicatedLinear(self.ffn_dim,
-                                   self.hidden_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-        self.w3 = ReplicatedLinear(self.hidden_dim,
-                                   self.ffn_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-
-        # TODO: Use vllm's SiluAndMul
-        self.act_fn = nn.SiLU()
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        w1_out, _ = self.w1(hidden_states)
-        w1_out = self.act_fn(w1_out)
-        w3_out, _ = self.w3(hidden_states)
-        current_hidden_states = w1_out * w3_out
-        current_hidden_states, _ = self.w2(current_hidden_states)
-        return current_hidden_states
-
-
 class MixtralMoE(nn.Module):
+    """A tensor-parallel MoE implementation for Mixtral that shards each expert
+    across all ranks.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
 
     def __init__(
         self,
-        config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.num_total_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.num_total_experts}.")
-        # Split experts equally between ranks
-        self.expert_indicies = np.array_split(range(
-            self.num_total_experts), self.tp_size)[self.rank].tolist()
-        if not self.expert_indicies:
-            raise ValueError(
-                f"Rank {self.rank} has no experts assigned to it.")
-
-        self.experts = nn.ModuleList([
-            MixtralMLP(self.num_total_experts,
-                       config.hidden_size,
-                       config.intermediate_size,
-                       linear_method=linear_method)
-            if idx in self.expert_indicies else None
-            for idx in range(self.num_total_experts)
-        ])
-        self.gate = ReplicatedLinear(config.hidden_size,
+        tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = num_experts
+        self.top_k = top_k
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size // tp_size
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        self.gate = ReplicatedLinear(self.hidden_size,
                                      self.num_total_experts,
                                      bias=False,
+                                     params_dtype=self.params_dtype,
                                      linear_method=None)
 
+        self.ws = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        2 * self.intermediate_size,
+                        self.hidden_size,
+                        device="cuda",
+                        dtype=self.params_dtype))
+        self.w2s = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        self.hidden_size,
+                        self.intermediate_size,
+                        device="cuda",
+                        dtype=self.params_dtype))
+
+        set_weight_attrs(self.ws, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2s, {
+            "weight_loader": self.weight_loader,
+        })
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
+                      weight_name: str, expert_id: int):
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        if weight_name.endswith("w1.weight"):
+            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w3.weight"):
+            param_data[expert_id,
+                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w2.weight"):
+            param_data[expert_id, :, :] = loaded_weight[:, shard]
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+        batch_size, sequence_length, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
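The new weight_loader packs each expert's tensor-parallel shard of w1 and w3 into the first and second halves of ws, and the column shard of w2 into w2s. A toy illustration of that slicing with made-up sizes (not vLLM code):

```python
import torch

# Toy sizes: 2 experts, hidden_size=4, intermediate_size=6, tensor parallel size 2,
# so each rank keeps shard_size = 6 // 2 = 3 rows of w1/w3 (columns of w2).
tp_rank, shard_size, hidden_size = 0, 3, 4
ws = torch.empty(2, 2 * shard_size, hidden_size)
w2s = torch.empty(2, hidden_size, shard_size)

# Full expert weights as they would appear in the checkpoint.
w1 = torch.randn(6, hidden_size)
w3 = torch.randn(6, hidden_size)
w2 = torch.randn(hidden_size, 6)

shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
expert_id = 1
ws[expert_id, 0:shard_size, :] = w1[shard, :]                # first half  <- w1 shard
ws[expert_id, shard_size:2 * shard_size, :] = w3[shard, :]   # second half <- w3 shard
w2s[expert_id, :, :] = w2[:, shard]                          # w2 sharded along columns
```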
@@ -142,22 +134,18 @@ class MixtralMoE(nn.Module):
                                                        dim=-1)
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
 
-        final_hidden_states = None
-        for expert_idx in self.expert_indicies:
-            expert_layer = self.experts[expert_idx]
-            expert_mask = (selected_experts == expert_idx)
-            expert_weights = (routing_weights * expert_mask).sum(dim=-1,
-                                                                 keepdim=True)
-
-            current_hidden_states = expert_layer(hidden_states).mul_(
-                expert_weights)
-            if final_hidden_states is None:
-                final_hidden_states = current_hidden_states
-            else:
-                final_hidden_states.add_(current_hidden_states)
-
-        return tensor_model_parallel_all_reduce(final_hidden_states).view(
-            batch_size, sequence_length, hidden_dim)
+        final_hidden_states = fused_moe(hidden_states,
+                                        self.ws,
+                                        self.w2s,
+                                        routing_weights,
+                                        selected_experts,
+                                        inplace=True)
+
+        final_hidden_states = tensor_model_parallel_all_reduce(
+            final_hidden_states)
+
+        return final_hidden_states.view(batch_size, sequence_length,
+                                        hidden_size)
 
 
 class MixtralAttention(nn.Module):
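The single fused_moe call takes the flattened activations, the stacked expert weights, and the top-k routing weights and indices, and replaces the old masked loop over local experts. As a semantic reference only, here is a dense sketch of the equivalent computation, inferred from the old per-expert code (SiLU(w1 x) * (w3 x) through w2, weighted by the routing probabilities); the actual fused kernel lives in vllm.model_executor.layers.fused_moe and is not shown in this diff:

```python
import torch
import torch.nn.functional as F

def fused_moe_reference(hidden_states, ws, w2s, routing_weights, selected_experts):
    """Dense, loop-based stand-in for fused_moe (illustration, not the kernel)."""
    inter = ws.shape[1] // 2
    out = torch.zeros_like(hidden_states)
    for expert_id in range(ws.shape[0]):
        # Which tokens (and which of their top-k slots) picked this expert.
        token_idx, slot_idx = torch.where(selected_experts == expert_id)
        if token_idx.numel() == 0:
            continue
        x = hidden_states[token_idx]                       # (n, hidden)
        gate_up = x @ ws[expert_id].t()                    # (n, 2 * inter)
        act = F.silu(gate_up[:, :inter]) * gate_up[:, inter:]
        expert_out = act @ w2s[expert_id].t()              # (n, hidden)
        weight = routing_weights[token_idx, slot_idx].unsqueeze(1)
        out.index_add_(0, token_idx, (expert_out * weight).to(out.dtype))
    return out
```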
@@ -257,8 +245,11 @@ class MixtralDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             sliding_window=config.sliding_window,
             linear_method=linear_method)
-        self.block_sparse_moe = MixtralMoE(config=config,
-                                           linear_method=linear_method)
+        self.block_sparse_moe = MixtralMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size)
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -378,6 +369,14 @@ class MixtralForCausalLM(nn.Module):
...
@@ -378,6 +369,14 @@ class MixtralForCausalLM(nn.Module):
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
]
expert_params_mapping
=
[
# (param_name, weight_name, expert_id)
(
"ws"
if
weight_name
in
[
"w1"
,
"w3"
]
else
"w2s"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight"
,
expert_id
)
for
expert_id
in
range
(
self
.
config
.
num_local_experts
)
for
weight_name
in
[
"w1"
,
"w2"
,
"w3"
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
model_name_or_path
,
model_name_or_path
,
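For concreteness, with config.num_local_experts == 2 the comprehension above would expand to (Mixtral itself uses 8 local experts; 2 is only to keep the listing short):

```python
expert_params_mapping = [
    ("ws", "experts.0.w1.weight", 0),
    ("w2s", "experts.0.w2.weight", 0),
    ("ws", "experts.0.w3.weight", 0),
    ("ws", "experts.1.w1.weight", 1),
    ("w2s", "experts.1.w2.weight", 1),
    ("ws", "experts.1.w3.weight", 1),
]
```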
@@ -387,6 +386,7 @@ class MixtralForCausalLM(nn.Module):
                 fall_back_to_pt=False):
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -399,14 +399,22 @@ class MixtralForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if ("block_sparse_moe.experts." in name
-                        and name not in params_dict):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for param_name, weight_name, expert_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
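To trace one expert weight through the new else branch, assuming the usual HF Mixtral checkpoint key layout (the specific key below is a hypothetical example):

```python
name = "model.layers.0.block_sparse_moe.experts.3.w1.weight"   # checkpoint key
param_name, weight_name, expert_id = "ws", "experts.3.w1.weight", 3

assert weight_name in name
name = name.replace(weight_name, param_name)
# name is now "model.layers.0.block_sparse_moe.ws", which is in params_dict;
# MixtralMoE.weight_loader then writes this rank's shard of w1 into
# ws[3, 0:shard_size, :].
```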