norm / vllm · Commits · f0d4e145

Unverified commit f0d4e145, authored Feb 05, 2024 by Woosuk Kwon, committed by GitHub on Feb 05, 2024.

Add fused top-K softmax kernel for MoE (#2769)
Parent: 2ccee3de
Showing 9 changed files with 591 additions and 50 deletions (+591 −50).
csrc/moe/moe_ops.cpp                        +7    −0
csrc/moe/moe_ops.h                          +9    −0
csrc/moe/topk_softmax_kernels.cu            +499  −0
csrc/pybind.cpp                             +1    −1
setup.py                                    +11   −0
tests/kernels/test_moe.py                   +10   −16
vllm/model_executor/layers/fused_moe.py     +48   −10
vllm/model_executor/models/deepseek.py      +3    −12
vllm/model_executor/models/mixtral.py       +3    −11
csrc/moe/moe_ops.cpp (new file, mode 100644)

```cpp
#include "moe_ops.h"

#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("topk_softmax", &topk_softmax,
        "Apply topk softmax to the gating outputs.");
}
```
csrc/moe/moe_ops.h (new file, mode 100644)

```cpp
#pragma once

#include <torch/extension.h>

void topk_softmax(
    torch::Tensor& topk_weights,
    torch::Tensor& topk_indices,
    torch::Tensor& token_expert_indices,
    torch::Tensor& gating_output);
```
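For reference, the fused_moe.py change later in this commit drives this binding roughly as follows. This is a minimal sketch of the expected buffer shapes and dtypes, assuming the extension has been built as vllm._moe_C (names taken from the diffs below); it is not part of the commit itself.

```python
import torch
import vllm._moe_C as moe_kernels  # built from csrc/moe/* per the setup.py change below

num_tokens, num_experts, topk = 4, 8, 2
gating_output = torch.randn(num_tokens, num_experts, device="cuda")

# Output buffers the kernel writes into (shapes/dtypes mirror fused_moe.py below).
topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32, device="cuda")
topk_ids = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")
token_expert_indices = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")

moe_kernels.topk_softmax(topk_weights, topk_ids, token_expert_indices,
                         gating_output.float())
```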
csrc/moe/topk_softmax_kernels.cu (new file, mode 100644, +499 lines)

This diff is collapsed on the commit page; the full CUDA kernel source is not reproduced here.
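Although the kernel body is not shown, the ROCm fallback added in fused_moe.py below computes the same result with plain PyTorch ops, so a reference version of what the fused kernel produces per token is roughly:

```python
import torch

def topk_softmax_reference(gating_output: torch.Tensor, topk: int):
    # Softmax over the expert dimension, then select the top-k experts per token.
    routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
    return topk_weights, topk_ids
```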
csrc/pybind.cpp

```diff
@@ -48,8 +48,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     &rotary_embedding,
     "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
 
-#ifndef USE_ROCM
   // Quantization ops
+#ifndef USE_ROCM
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
   ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
 #endif
```
setup.py

```diff
@@ -339,6 +339,17 @@ if _is_cuda():
     vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
     vllm_extension_sources.append("csrc/custom_all_reduce.cu")
 
+    # Add MoE kernels.
+    ext_modules.append(
+        CUDAExtension(
+            name="vllm._moe_C",
+            sources=glob("csrc/moe/*.cu") + glob("csrc/moe/*.cpp"),
+            extra_compile_args={
+                "cxx": CXX_FLAGS,
+                "nvcc": NVCC_FLAGS,
+            },
+        ))
+
 if not _is_neuron():
     vllm_extension = CUDAExtension(
         name="vllm._C",
```
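A quick way to confirm the new extension was built (a sketch, assuming a from-source CUDA build; not part of the commit):

```python
import vllm._moe_C as moe_kernels

# The only op registered by csrc/moe/moe_ops.cpp is topk_softmax.
assert hasattr(moe_kernels, "topk_softmax")
```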
tests/kernels/test_moe.py

```diff
@@ -2,10 +2,8 @@
 Run `pytest tests/kernels/test_moe.py`.
 """
 import pytest
 import torch
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
@@ -14,22 +12,21 @@ from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.models.mixtral import MixtralMoE
 
 
-def torch_moe(a, w1, w2, topk_weight, topk_ids):
+def torch_moe(a, w1, w2, score, topk):
     B, D = a.shape
-    a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
-    out = torch.zeros(B * topk_ids.shape[1], w2.shape[1], dtype=a.dtype, device=a.device)
-    topk_ids = topk_ids.view(-1)
+    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+    score = torch.softmax(score, dim=-1, dtype=torch.float32)
+    topk_weight, topk_ids = torch.topk(score, topk)
+    topk_weight = topk_weight.view(-1)
+    topk_ids = topk_ids.view(-1)
     for i in range(w1.shape[0]):
         mask = topk_ids == i
         if mask.sum():
             out[mask] = SiluAndMul()(
                 a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
     return (out.view(B, -1, w2.shape[1]) *
-            topk_weight.view(B, -1, 1)).sum(dim=1)
+            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
 
 
 @pytest.mark.parametrize("m", [512, 222, 33, 1])
@@ -51,11 +48,8 @@ def test_fused_moe(
     w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
     score = torch.randn((m, e), device='cuda', dtype=dtype)
-    score = torch.softmax(score, dim=-1)
-    topk_weight, topk_ids = torch.topk(score, topk)
-    triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
-    torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
+    triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
+    torch_output = torch_moe(a, w1, w2, score, topk)
     assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
@@ -75,7 +69,7 @@ def test_mixtral_moe(dtype: torch.dtype):
         intermediate_size=config.intermediate_size,
         params_dtype=dtype,
         tp_size=1,
-    )
+    ).cuda()
 
     # Load the weights
     vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
```
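As the module docstring says, these tests run with `pytest tests/kernels/test_moe.py`; an equivalent programmatic invocation (a sketch, requiring a CUDA device) is:

```python
import pytest

# Runs both test_fused_moe and test_mixtral_moe.
pytest.main(["tests/kernels/test_moe.py", "-q"])
```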
vllm/model_executor/layers/fused_moe.py

```diff
@@ -4,6 +4,7 @@ import triton
 import triton.language as tl
 
 from vllm._C import ops
+from vllm.utils import is_hip
 
 
 @triton.jit
@@ -177,7 +178,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
                             mul_routed_weight: bool, top_k: int,
                             config: dict):
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -210,12 +210,15 @@
     )
 
 
-def fused_moe(hidden_states: torch.Tensor,
-              w1: torch.Tensor,
-              w2: torch.Tensor,
-              topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor,
-              inplace=False):
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.
@@ -223,15 +226,19 @@ def fused_moe(hidden_states: torch.Tensor,
     - hidden_states (torch.Tensor): The input tensor to the MoE layer.
     - w1 (torch.Tensor): The first set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
-    - topk_weights (torch.Tensor): The weights for the top-k selected experts.
-    - topk_ids (torch.Tensor): The indices of the top-k selected experts.
+    - gating_output (torch.Tensor): The output of the gating operation (before softmax).
+    - topk (int): The number of top-k experts to select.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - inplace (bool): If True, perform the operation in-place. Defaults to False.
 
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions"
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
+    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
@@ -241,6 +248,37 @@ def fused_moe(hidden_states: torch.Tensor,
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
 
+    if is_hip():
+        # The MoE kernels are not yet supported on ROCm.
+        routing_weights = torch.softmax(gating_output,
+                                        dim=-1,
+                                        dtype=torch.float32)
+        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
+    else:
+        import vllm._moe_C as moe_kernels
+
+        topk_weights = torch.empty(M,
+                                   topk,
+                                   dtype=torch.float32,
+                                   device=hidden_states.device)
+        topk_ids = torch.empty(M,
+                               topk,
+                               dtype=torch.int32,
+                               device=hidden_states.device)
+        token_expert_indicies = torch.empty(M,
+                                            topk,
+                                            dtype=torch.int32,
+                                            device=hidden_states.device)
+        moe_kernels.topk_softmax(
+            topk_weights,
+            topk_ids,
+            token_expert_indicies,
+            gating_output.float(),  # TODO(woosuk): Optimize this.
+        )
+        del token_expert_indicies  # Not used. Will be used in the future.
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
     config = {
         'BLOCK_SIZE_M': 64,
         'BLOCK_SIZE_N': 64,
```
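To make the new calling convention concrete, here is a minimal usage sketch (not from the commit). Shapes follow the asserts above and the unit test; the 2 * n intermediate size is an assumption reflecting the SiluAndMul gate/up split, and the concrete numbers are arbitrary.

```python
import torch
from vllm.model_executor.layers.fused_moe import fused_moe

m, k, e, n, topk = 33, 128, 8, 256, 2
hidden_states = torch.randn(m, k, device="cuda", dtype=torch.float16)
w1 = torch.randn(e, 2 * n, k, device="cuda", dtype=torch.float16) / 10  # gate + up proj
w2 = torch.randn(e, k, n, device="cuda", dtype=torch.float16) / 10      # down proj
gating_output = torch.randn(m, e, device="cuda", dtype=torch.float16)   # raw router logits

out = fused_moe(hidden_states, w1, w2, gating_output, topk, renormalize=True)
print(out.shape)  # torch.Size([33, 128])
```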
vllm/model_executor/models/deepseek.py

```diff
@@ -25,7 +25,6 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 from torch import nn
-import torch.nn.functional as F
 from transformers import PretrainedConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
@@ -155,20 +154,12 @@ class DeepseekMoE(nn.Module):
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights,
-                                                       self.top_k,
-                                                       dim=-1)
-        if self.config.norm_topk_prob:
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         final_hidden_states = fused_moe(hidden_states,
                                         self.w1,
                                         self.w2,
-                                        routing_weights,
-                                        selected_experts,
+                                        router_logits,
+                                        self.top_k,
+                                        renormalize=self.config.norm_topk_prob,
                                         inplace=True)
 
         if self.config.n_shared_experts is not None:
```
vllm/model_executor/models/mixtral.py

```diff
@@ -24,8 +24,6 @@
 from typing import List, Optional, Tuple
 
 import torch
-import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
@@ -128,18 +126,12 @@ class MixtralMoE(nn.Module):
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights,
-                                                       self.top_k,
-                                                       dim=-1)
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         final_hidden_states = fused_moe(hidden_states,
                                         self.ws,
                                         self.w2s,
-                                        routing_weights,
-                                        selected_experts,
+                                        router_logits,
+                                        self.top_k,
+                                        renormalize=True,
                                         inplace=True)
 
         if self.tp_size > 1:
```
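For comparison, the routing that MixtralMoE and DeepseekMoE previously did in Python (the removed lines above), and that fused_moe(..., renormalize=True) now performs internally via the fused kernel, is equivalent to this reference reconstruction:

```python
import torch

def reference_routing(router_logits: torch.Tensor, top_k: int):
    # Softmax over experts, pick the top-k per token, then renormalize the selected weights.
    routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    return routing_weights, selected_experts
```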