vllm · Commit ab406446 (unverified)

Authored Jan 29, 2024 by Philipp Moritz; committed by GitHub on Jan 29, 2024
Parent: 5d60def0

Fused MOE for Mixtral (#2542)

Co-authored-by: chen shen <scv119@gmail.com>
Showing 4 changed files with 115 additions and 109 deletions (+115 -109):
- csrc/moe_align_block_size_kernels.cu (+1 -1)
- csrc/ops.h (+7 -9)
- csrc/pybind.cpp (+3 -3)
- vllm/model_executor/models/mixtral.py (+104 -96)
csrc/moe_align_block_size_kernels.cu (+1 -1)

```diff
@@ -95,7 +95,7 @@ void moe_align_block_size(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   assert(num_experts <= NUM_MAX_EXPERTS);
   VLLM_DISPATCH_INTEGRAL_TYPES(
-    topk_ids.scalar_type(), "moe_alig_block_size_kernel", [&] {
+    topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
     vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
       topk_ids.data_ptr<scalar_t>(),
       sorted_token_ids.data_ptr<int32_t>(),
```
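For readers unfamiliar with the kernel being dispatched here, the pure-PyTorch sketch below approximates what moe_align_block_size computes: it groups the flattened (token, expert) assignments by expert and pads each group up to a multiple of block_size, so every block handed to the fused MoE kernel belongs to exactly one expert. The function name and the CPU implementation are illustrative only; the real op is the CUDA kernel above.

```python
import torch

def moe_align_block_size_ref(topk_ids: torch.Tensor, num_experts: int,
                             block_size: int):
    """Illustrative CPU reference, not the vLLM API."""
    flat = topk_ids.flatten().to(torch.int64)
    pad_id = flat.numel()  # padding slots point one past the last valid index

    counts = torch.bincount(flat, minlength=num_experts)
    padded = ((counts + block_size - 1) // block_size) * block_size
    num_tokens_post_pad = int(padded.sum())

    sorted_token_ids = torch.full((num_tokens_post_pad,), pad_id,
                                  dtype=torch.int32)
    # One expert id per block of block_size slots.
    expert_ids = torch.repeat_interleave(torch.arange(num_experts),
                                         padded // block_size)

    offset = 0
    for e in range(num_experts):
        idx = (flat == e).nonzero(as_tuple=True)[0]
        sorted_token_ids[offset:offset + idx.numel()] = idx.to(torch.int32)
        offset += int(padded[e])
    return sorted_token_ids, expert_ids, num_tokens_post_pad
```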
csrc/ops.h (+7 -9)

```diff
@@ -100,6 +100,13 @@ void gptq_shuffle(
   torch::Tensor q_weight,
   torch::Tensor q_perm);
 
+void moe_align_block_size(
+  torch::Tensor topk_ids,
+  int num_experts,
+  int block_size,
+  torch::Tensor sorted_token_ids,
+  torch::Tensor experts_ids,
+  torch::Tensor num_tokens_post_pad);
+
 #ifndef USE_ROCM
 using fptr_t = uint64_t;
@@ -121,12 +128,3 @@ std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
 void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                             const std::vector<std::vector<int64_t>>& offsets);
 #endif
-
-void moe_align_block_size(
-  torch::Tensor topk_ids,
-  int num_experts,
-  int block_size,
-  torch::Tensor sorted_token_ids,
-  torch::Tensor experts_ids,
-  torch::Tensor num_tokens_post_pad);
```
csrc/pybind.cpp (+3 -3)

```diff
@@ -57,9 +57,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
   ops.def(
     "moe_align_block_size",
     &moe_align_block_size,
     "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
 
   // Cache ops
   pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
```
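This registration makes the kernel callable from Python through the compiled extension. A rough usage sketch follows; the argument order comes from the declaration in csrc/ops.h, while the module path (vllm._C) and the output-buffer sizes are assumptions about the surrounding vLLM code rather than part of this diff.

```python
import torch
from vllm._C import ops  # extension module path assumed

topk_ids = torch.randint(0, 8, (4, 2), dtype=torch.int32, device="cuda")
num_experts, block_size = 8, 16

# The caller allocates the outputs; each expert can add up to
# block_size - 1 padding slots, hence the generous sizes (assumed).
max_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_token_ids = torch.full((max_padded,), topk_ids.numel(),
                              dtype=torch.int32, device="cuda")
expert_ids = torch.empty(((max_padded + block_size - 1) // block_size,),
                         dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")

ops.moe_align_block_size(topk_ids, num_experts, block_size,
                         sorted_token_ids, expert_ids, num_tokens_post_pad)
```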
vllm/model_executor/models/mixtral.py (+104 -96)

```diff
@@ -23,8 +23,6 @@
 """Inference-only Mixtral model."""
 from typing import List, Optional, Tuple
 
-import numpy as np
 import torch
 import torch.nn.functional as F
```
```diff
@@ -33,10 +31,11 @@ from transformers import MixtralConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               ReplicatedLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
```
```diff
@@ -47,6 +46,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
```
```diff
@@ -54,85 +54,77 @@ from vllm.sequence import SamplerOutput
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class MixtralMLP(nn.Module):
+class MixtralMoE(nn.Module):
+    """A tensor-parallel MoE implementation for Mixtral that shards each expert
+    across all ranks.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
 
     def __init__(
         self,
         num_experts: int,
+        top_k: int,
         hidden_size: int,
         intermediate_size: int,
-        linear_method: Optional[LinearMethodBase] = None,
-    ) -> None:
+        params_dtype: Optional[torch.dtype] = None,
+    ):
         super().__init__()
-        self.num_experts = num_experts
-        self.ffn_dim = intermediate_size
-        self.hidden_dim = hidden_size
-
-        self.w1 = ReplicatedLinear(self.hidden_dim,
-                                   self.ffn_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-        self.w2 = ReplicatedLinear(self.ffn_dim,
-                                   self.hidden_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-        self.w3 = ReplicatedLinear(self.hidden_dim,
-                                   self.ffn_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-
-        # TODO: Use vllm's SiluAndMul
-        self.act_fn = nn.SiLU()
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        w1_out, _ = self.w1(hidden_states)
-        w1_out = self.act_fn(w1_out)
-        w3_out, _ = self.w3(hidden_states)
-        current_hidden_states = w1_out * w3_out
-        current_hidden_states, _ = self.w2(current_hidden_states)
-        return current_hidden_states
-
-
-class MixtralMoE(nn.Module):
-
-    def __init__(
-        self,
-        config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
-    ):
-        super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.num_total_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.num_total_experts}.")
-        # Split experts equally between ranks
-        self.expert_indicies = np.array_split(range(
-            self.num_total_experts), self.tp_size)[self.rank].tolist()
-        if not self.expert_indicies:
-            raise ValueError(
-                f"Rank {self.rank} has no experts assigned to it.")
-
-        self.experts = nn.ModuleList([
-            MixtralMLP(self.num_total_experts,
-                       config.hidden_size,
-                       config.intermediate_size,
-                       linear_method=linear_method)
-            if idx in self.expert_indicies else None
-            for idx in range(self.num_total_experts)
-        ])
-        self.gate = ReplicatedLinear(config.hidden_size,
+        tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = num_experts
+        self.top_k = top_k
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size // tp_size
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        self.gate = ReplicatedLinear(self.hidden_size,
                                      self.num_total_experts,
                                      bias=False,
+                                     params_dtype=self.params_dtype,
                                      linear_method=None)
 
+        self.ws = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        2 * self.intermediate_size,
+                        self.hidden_size,
+                        device="cuda",
+                        dtype=self.params_dtype))
+        self.w2s = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        self.hidden_size,
+                        self.intermediate_size,
+                        device="cuda",
+                        dtype=self.params_dtype))
+
+        set_weight_attrs(self.ws, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2s, {
+            "weight_loader": self.weight_loader,
+        })
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
+                      weight_name: str, expert_id: int):
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        if weight_name.endswith("w1.weight"):
+            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w3.weight"):
+            param_data[expert_id,
+                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w2.weight"):
+            param_data[expert_id, :, :] = loaded_weight[:, shard]
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+        batch_size, sequence_length, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
```
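The weight_loader above packs each expert's w1 and w3 checkpoint shards into one stacked ws tensor and the w2 shard into w2s. A toy CPU-only illustration of that layout is below; sizes and variable names are made up for the example.

```python
import torch

hidden_size, intermediate_size, tp_size, tp_rank = 8, 4, 2, 0
shard_size = intermediate_size // tp_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

# Full checkpoint weights for one expert (Mixtral's gate/up/down projections).
w1 = torch.randn(intermediate_size, hidden_size)  # gate proj
w3 = torch.randn(intermediate_size, hidden_size)  # up proj
w2 = torch.randn(hidden_size, intermediate_size)  # down proj

# Per-expert slice of ws: w1 shard in the first half, w3 shard in the second.
ws_expert = torch.empty(2 * shard_size, hidden_size)
ws_expert[0:shard_size, :] = w1[shard, :]
ws_expert[shard_size:2 * shard_size, :] = w3[shard, :]

# Per-expert slice of w2s: w2 sharded along its input (intermediate) dim.
w2s_expert = w2[:, shard]
```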
```diff
@@ -142,22 +134,18 @@ class MixtralMoE(nn.Module):
                                                        dim=-1)
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
 
-        final_hidden_states = None
-        for expert_idx in self.expert_indicies:
-            expert_layer = self.experts[expert_idx]
-            expert_mask = (selected_experts == expert_idx)
-            expert_weights = (routing_weights * expert_mask).sum(dim=-1,
-                                                                 keepdim=True)
-
-            current_hidden_states = expert_layer(hidden_states).mul_(
-                expert_weights)
-            if final_hidden_states is None:
-                final_hidden_states = current_hidden_states
-            else:
-                final_hidden_states.add_(current_hidden_states)
-
-        return tensor_model_parallel_all_reduce(final_hidden_states).view(
-            batch_size, sequence_length, hidden_dim)
+        final_hidden_states = fused_moe(hidden_states,
+                                        self.ws,
+                                        self.w2s,
+                                        routing_weights,
+                                        selected_experts,
+                                        inplace=True)
+
+        final_hidden_states = tensor_model_parallel_all_reduce(
+            final_hidden_states)
+
+        return final_hidden_states.view(batch_size, sequence_length,
+                                        hidden_size)
 
 
 class MixtralAttention(nn.Module):
```
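For intuition, here is an unfused PyTorch reference of what the fused_moe call computes for each token: a routing-weighted sum over its top-k experts of the SiLU-gated expert MLP, using the ws/w2s layout built by weight_loader above. This helper is illustrative only, not vLLM's fused kernel, and it ignores the inplace optimization.

```python
import torch
import torch.nn.functional as F

def fused_moe_ref(hidden_states, ws, w2s, routing_weights, selected_experts):
    # hidden_states: [tokens, hidden], ws: [experts, 2 * inter, hidden],
    # w2s: [experts, hidden, inter], routing_weights/selected_experts: [tokens, top_k]
    out = torch.zeros_like(hidden_states)
    inter = ws.shape[1] // 2
    for tok in range(hidden_states.shape[0]):
        for k in range(selected_experts.shape[1]):
            e = int(selected_experts[tok, k])
            gate_up = ws[e] @ hidden_states[tok]             # [2 * inter]
            act = F.silu(gate_up[:inter]) * gate_up[inter:]  # SiLU-gated MLP
            w = routing_weights[tok, k].to(hidden_states.dtype)
            out[tok] += w * (w2s[e] @ act)
    return out
```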
```diff
@@ -257,8 +245,11 @@ class MixtralDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             sliding_window=config.sliding_window,
             linear_method=linear_method)
-        self.block_sparse_moe = MixtralMoE(config=config,
-                                           linear_method=linear_method)
+        self.block_sparse_moe = MixtralMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size)
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
```
```diff
@@ -378,6 +369,14 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]
 
+        expert_params_mapping = [
+            # (param_name, weight_name, expert_id)
+            ("ws" if weight_name in ["w1", "w3"] else "w2s",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path,
```
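Concretely, for a hypothetical config with two local experts, the expert_params_mapping comprehension above expands to triples like these, which load_weights uses to route each checkpoint tensor into ws or w2s with the right expert_id:

```python
expert_params_mapping = [
    ("ws",  "experts.0.w1.weight", 0),
    ("w2s", "experts.0.w2.weight", 0),
    ("ws",  "experts.0.w3.weight", 0),
    ("ws",  "experts.1.w1.weight", 1),
    ("w2s", "experts.1.w2.weight", 1),
    ("ws",  "experts.1.w3.weight", 1),
]
```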
```diff
@@ -387,6 +386,7 @@ class MixtralForCausalLM(nn.Module):
                 fall_back_to_pt=False):
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
```
```diff
@@ -399,14 +399,22 @@ class MixtralForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if ("block_sparse_moe.experts." in name
-                        and name not in params_dict):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for param_name, weight_name, expert_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
```