xdb4_94051 / vllm / Commits / f0d4e145

Commit f0d4e145 (unverified)
Authored Feb 05, 2024 by Woosuk Kwon; committed via GitHub on Feb 05, 2024

Add fused top-K softmax kernel for MoE (#2769)

Parent: 2ccee3de

Showing 9 changed files with 591 additions and 50 deletions (+591, -50)
csrc/moe/moe_ops.cpp                        +7    -0
csrc/moe/moe_ops.h                          +9    -0
csrc/moe/topk_softmax_kernels.cu            +499  -0
csrc/pybind.cpp                             +1    -1
setup.py                                    +11   -0
tests/kernels/test_moe.py                   +10   -16
vllm/model_executor/layers/fused_moe.py     +48   -10
vllm/model_executor/models/deepseek.py      +3    -12
vllm/model_executor/models/mixtral.py       +3    -11
csrc/moe/moe_ops.cpp (new file, mode 100644)

#include "moe_ops.h"

#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("topk_softmax", &topk_softmax,
        "Apply topk softmax to the gating outputs.");
}
csrc/moe/moe_ops.h (new file, mode 100644)

#pragma once

#include <torch/extension.h>

void topk_softmax(
    torch::Tensor& topk_weights,
    torch::Tensor& topk_indices,
    torch::Tensor& token_expert_indices,
    torch::Tensor& gating_output);
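For orientation, here is a minimal pure-PyTorch sketch (not part of the commit) of what the fused kernel computes for its first two outputs: a softmax over the router logits followed by a top-k selection. The third output, token_expert_indices, is an auxiliary buffer that this commit allocates but does not yet consume (see the fused_moe.py change below), so the sketch omits it; the function name here is hypothetical.

import torch

def topk_softmax_reference(gating_output: torch.Tensor, topk: int):
    # gating_output: [num_tokens, num_experts] raw router logits.
    scores = torch.softmax(gating_output.float(), dim=-1)
    # topk_weights: softmax probability of each selected expert, per token.
    # topk_indices: which expert each of the top-k slots maps to, per token.
    topk_weights, topk_indices = torch.topk(scores, topk, dim=-1)
    return topk_weights, topk_indices.to(torch.int32)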
csrc/moe/topk_softmax_kernels.cu (new file, mode 100644)

(Diff collapsed on the page: 499 additions.)
csrc/pybind.cpp

@@ -48,8 +48,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     &rotary_embedding,
     "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

-#ifndef USE_ROCM
   // Quantization ops
+#ifndef USE_ROCM
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
   ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
 #endif
setup.py

@@ -339,6 +339,17 @@ if _is_cuda():
     vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
     vllm_extension_sources.append("csrc/custom_all_reduce.cu")

+    # Add MoE kernels.
+    ext_modules.append(
+        CUDAExtension(
+            name="vllm._moe_C",
+            sources=glob("csrc/moe/*.cu") + glob("csrc/moe/*.cpp"),
+            extra_compile_args={
+                "cxx": CXX_FLAGS,
+                "nvcc": NVCC_FLAGS,
+            },
+        ))
+
 if not _is_neuron():
     vllm_extension = CUDAExtension(
         name="vllm._C",
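Since the kernels are compiled into a separate extension module, one quick sanity check after building from source (a sketch, assuming a CUDA build) is to import the module and confirm the binding registered in csrc/moe/moe_ops.cpp is present:

# Sketch: verifies the new extension built and exposes the topk_softmax binding.
import importlib

moe_kernels = importlib.import_module("vllm._moe_C")
assert hasattr(moe_kernels, "topk_softmax"), "fused MoE extension not built"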
tests/kernels/test_moe.py

@@ -2,10 +2,8 @@
 Run `pytest tests/kernels/test_moe.py`.
 """
 import pytest
 import torch
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

@@ -14,22 +12,21 @@ from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.models.mixtral import MixtralMoE


-def torch_moe(a, w1, w2, topk_weight, topk_ids):
+def torch_moe(a, w1, w2, score, topk):
     B, D = a.shape
-    a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
-    out = torch.zeros(B * topk_ids.shape[1],
-                      w2.shape[1],
-                      dtype=a.dtype,
-                      device=a.device)
-    topk_ids = topk_ids.view(-1)
+    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+    score = torch.softmax(score, dim=-1, dtype=torch.float32)
+    topk_weight, topk_ids = torch.topk(score, topk)
+    topk_weight = topk_weight.view(-1)
+    topk_ids = topk_ids.view(-1)
     for i in range(w1.shape[0]):
         mask = topk_ids == i
         if mask.sum():
             out[mask] = SiluAndMul()(
                 a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
     return (out.view(B, -1, w2.shape[1]) *
-            topk_weight.view(B, -1, 1)).sum(dim=1)
+            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


 @pytest.mark.parametrize("m", [512, 222, 33, 1])

@@ -51,11 +48,8 @@ def test_fused_moe(
     w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10

     score = torch.randn((m, e), device='cuda', dtype=dtype)
-    score = torch.softmax(score, dim=-1)
-    topk_weight, topk_ids = torch.topk(score, topk)
-
-    triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
-    torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
+    triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
+    torch_output = torch_moe(a, w1, w2, score, topk)
     assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)

@@ -75,7 +69,7 @@ def test_mixtral_moe(dtype: torch.dtype):
         intermediate_size=config.intermediate_size,
         params_dtype=dtype,
         tp_size=1,
-    )
+    ).cuda()
     # Load the weights
     vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
vllm/model_executor/layers/fused_moe.py

@@ -4,6 +4,7 @@ import triton
 import triton.language as tl

 from vllm._C import ops
+from vllm.utils import is_hip


 @triton.jit

@@ -177,7 +178,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
                             mul_routed_weight: bool, top_k: int,
                             config: dict):
-
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1

@@ -210,12 +210,15 @@
 )


-def fused_moe(hidden_states: torch.Tensor,
-              w1: torch.Tensor,
-              w2: torch.Tensor,
-              topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor,
-              inplace=False):
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
     weights, w1 and w2, and top-k gating mechanism.

@@ -223,15 +226,19 @@ def fused_moe(hidden_states: torch.Tensor,
     - hidden_states (torch.Tensor): The input tensor to the MoE layer.
     - w1 (torch.Tensor): The first set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
-    - topk_weights (torch.Tensor): The weights for the top-k selected experts.
-    - topk_ids (torch.Tensor): The indices of the top-k selected experts.
+    - gating_output (torch.Tensor): The output of the gating operation
+      (before softmax).
+    - topk (int): The number of top-k experts to select.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - inplace (bool): If True, perform the operation in-place. Defaults to False.

     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions"
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
+    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"

@@ -241,6 +248,37 @@
     M, _ = hidden_states.shape
     E, N, _ = w1.shape

+    if is_hip():
+        # The MoE kernels are not yet supported on ROCm.
+        routing_weights = torch.softmax(gating_output,
+                                        dim=-1,
+                                        dtype=torch.float32)
+        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
+    else:
+        import vllm._moe_C as moe_kernels
+
+        topk_weights = torch.empty(M,
+                                   topk,
+                                   dtype=torch.float32,
+                                   device=hidden_states.device)
+        topk_ids = torch.empty(M,
+                               topk,
+                               dtype=torch.int32,
+                               device=hidden_states.device)
+        token_expert_indicies = torch.empty(M,
+                                            topk,
+                                            dtype=torch.int32,
+                                            device=hidden_states.device)
+        moe_kernels.topk_softmax(
+            topk_weights,
+            topk_ids,
+            token_expert_indicies,
+            gating_output.float(),  # TODO(woosuk): Optimize this.
+        )
+        del token_expert_indicies  # Not used. Will be used in the future.
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
     config = {
         'BLOCK_SIZE_M': 64,
         'BLOCK_SIZE_N': 64,
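To illustrate the new call signature, here is a hedged usage sketch with made-up shapes (4 tokens, 8 experts, hidden size 512, intermediate size 1024); it assumes a CUDA build with the fused kernels compiled, and follows the shape conventions checked by the assertions above (w1 holds the fused gate/up projection of shape [num_experts, 2 * intermediate, hidden]; w2 is [num_experts, hidden, intermediate]):

import torch
from vllm.model_executor.layers.fused_moe import fused_moe

# Hypothetical shapes: 4 tokens, 8 experts, hidden 512, intermediate 1024.
hidden_states = torch.randn(4, 512, dtype=torch.float16, device="cuda")
w1 = torch.randn(8, 2 * 1024, 512, dtype=torch.float16, device="cuda") / 10
w2 = torch.randn(8, 512, 1024, dtype=torch.float16, device="cuda") / 10
gating_output = torch.randn(4, 8, dtype=torch.float16, device="cuda")

# Softmax and top-2 selection now happen inside fused_moe; renormalize=True
# rescales the selected weights so they sum to 1 (Mixtral-style routing).
out = fused_moe(hidden_states, w1, w2, gating_output, topk=2, renormalize=True)
print(out.shape)  # torch.Size([4, 512])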
vllm/model_executor/models/deepseek.py

@@ -25,7 +25,6 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 from torch import nn
-import torch.nn.functional as F
 from transformers import PretrainedConfig

 from vllm.model_executor.input_metadata import InputMetadata

@@ -155,20 +154,12 @@ class DeepseekMoE(nn.Module):
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights,
-                                                       self.top_k,
-                                                       dim=-1)
-        if self.config.norm_topk_prob:
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         final_hidden_states = fused_moe(hidden_states,
                                         self.w1,
                                         self.w2,
-                                        routing_weights,
-                                        selected_experts,
+                                        router_logits,
+                                        self.top_k,
+                                        renormalize=self.config.norm_topk_prob,
                                         inplace=True)
         if self.config.n_shared_experts is not None:
vllm/model_executor/models/mixtral.py

@@ -24,8 +24,6 @@
 from typing import List, Optional, Tuple

 import torch
-import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig

@@ -128,18 +126,12 @@ class MixtralMoE(nn.Module):
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights,
-                                                       self.top_k,
-                                                       dim=-1)
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         final_hidden_states = fused_moe(hidden_states,
                                         self.ws,
                                         self.w2s,
-                                        routing_weights,
-                                        selected_experts,
+                                        router_logits,
+                                        self.top_k,
+                                        renormalize=True,
                                         inplace=True)
         if self.tp_size > 1: