Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
f0d4e145
Unverified
Commit
f0d4e145
authored
Feb 05, 2024
by
Woosuk Kwon
Committed by
GitHub
Feb 05, 2024
Browse files
Add fused top-K softmax kernel for MoE (#2769)
parent
2ccee3de
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
591 additions
and
50 deletions
+591
-50
csrc/moe/moe_ops.cpp
csrc/moe/moe_ops.cpp
+7
-0
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+9
-0
csrc/moe/topk_softmax_kernels.cu
csrc/moe/topk_softmax_kernels.cu
+499
-0
csrc/pybind.cpp
csrc/pybind.cpp
+1
-1
setup.py
setup.py
+11
-0
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+10
-16
vllm/model_executor/layers/fused_moe.py
vllm/model_executor/layers/fused_moe.py
+48
-10
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+3
-12
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+3
-11
No files found.
csrc/moe/moe_ops.cpp
0 → 100644
View file @
f0d4e145
#include "moe_ops.h"
#include <torch/extension.h>
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"topk_softmax"
,
&
topk_softmax
,
"Apply topk softmax to the gating outputs."
);
}
csrc/moe/moe_ops.h
0 → 100644
View file @
f0d4e145
#pragma once
#include <torch/extension.h>
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
gating_output
);
csrc/moe/topk_softmax_kernels.cu
0 → 100644
View file @
f0d4e145
/*
* Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
* Copyright (c) 2024, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
namespace
vllm
{
namespace
moe
{
static
constexpr
int
WARP_SIZE
=
32
;
/// Aligned array type
template
<
typename
T
,
/// Number of elements in the array
int
N
,
/// Alignment requirement in bytes
int
Alignment
=
sizeof
(
T
)
*
N
>
class
alignas
(
Alignment
)
AlignedArray
{
float
data
[
N
];
};
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template
<
int
TPB
>
__launch_bounds__
(
TPB
)
__global__
void
moeSoftmax
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_cols
)
{
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
TPB
>
;
__shared__
typename
BlockReduce
::
TempStorage
tmpStorage
;
__shared__
float
normalizing_factor
;
__shared__
float
float_max
;
const
int
thread_row_offset
=
blockIdx
.
x
*
num_cols
;
cub
::
Sum
sum
;
float
threadData
(
-
FLT_MAX
);
// Don't touch finished rows.
if
((
finished
!=
nullptr
)
&&
finished
[
blockIdx
.
x
])
{
return
;
}
for
(
int
ii
=
threadIdx
.
x
;
ii
<
num_cols
;
ii
+=
TPB
)
{
const
int
idx
=
thread_row_offset
+
ii
;
threadData
=
max
(
static_cast
<
float
>
(
input
[
idx
]),
threadData
);
}
const
float
maxElem
=
BlockReduce
(
tmpStorage
).
Reduce
(
threadData
,
cub
::
Max
());
if
(
threadIdx
.
x
==
0
)
{
float_max
=
maxElem
;
}
__syncthreads
();
threadData
=
0
;
for
(
int
ii
=
threadIdx
.
x
;
ii
<
num_cols
;
ii
+=
TPB
)
{
const
int
idx
=
thread_row_offset
+
ii
;
threadData
+=
exp
((
static_cast
<
float
>
(
input
[
idx
])
-
float_max
));
}
const
auto
Z
=
BlockReduce
(
tmpStorage
).
Reduce
(
threadData
,
sum
);
if
(
threadIdx
.
x
==
0
)
{
normalizing_factor
=
1.
f
/
Z
;
}
__syncthreads
();
for
(
int
ii
=
threadIdx
.
x
;
ii
<
num_cols
;
ii
+=
TPB
)
{
const
int
idx
=
thread_row_offset
+
ii
;
const
float
val
=
exp
((
static_cast
<
float
>
(
input
[
idx
])
-
float_max
))
*
normalizing_factor
;
output
[
idx
]
=
val
;
}
}
template
<
int
TPB
>
__launch_bounds__
(
TPB
)
__global__
void
moeTopK
(
const
float
*
inputs_after_softmax
,
const
bool
*
finished
,
float
*
output
,
int
*
indices
,
int
*
source_rows
,
const
int
num_experts
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
{
using
cub_kvp
=
cub
::
KeyValuePair
<
int
,
float
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
cub_kvp
,
TPB
>
;
__shared__
typename
BlockReduce
::
TempStorage
tmpStorage
;
cub_kvp
thread_kvp
;
cub
::
ArgMax
arg_max
;
const
int
num_rows
=
gridDim
.
x
;
const
int
block_row
=
blockIdx
.
x
;
const
bool
row_is_active
=
finished
?
!
finished
[
block_row
]
:
true
;
const
int
thread_read_offset
=
blockIdx
.
x
*
num_experts
;
for
(
int
k_idx
=
0
;
k_idx
<
k
;
++
k_idx
)
{
thread_kvp
.
key
=
0
;
thread_kvp
.
value
=
-
1.
f
;
// This is OK because inputs are probabilities
cub_kvp
inp_kvp
;
for
(
int
expert
=
threadIdx
.
x
;
expert
<
num_experts
;
expert
+=
TPB
)
{
const
int
idx
=
thread_read_offset
+
expert
;
inp_kvp
.
key
=
expert
;
inp_kvp
.
value
=
inputs_after_softmax
[
idx
];
for
(
int
prior_k
=
0
;
prior_k
<
k_idx
;
++
prior_k
)
{
const
int
prior_winning_expert
=
indices
[
k
*
block_row
+
prior_k
];
if
(
prior_winning_expert
==
expert
)
{
inp_kvp
=
thread_kvp
;
}
}
thread_kvp
=
arg_max
(
inp_kvp
,
thread_kvp
);
}
const
cub_kvp
result_kvp
=
BlockReduce
(
tmpStorage
).
Reduce
(
thread_kvp
,
arg_max
);
if
(
threadIdx
.
x
==
0
)
{
// Ignore experts the node isn't responsible for with expert parallelism
const
int
expert
=
result_kvp
.
key
;
const
bool
node_uses_expert
=
expert
>=
start_expert
&&
expert
<
end_expert
;
const
bool
should_process_row
=
row_is_active
&&
node_uses_expert
;
const
int
idx
=
k
*
block_row
+
k_idx
;
output
[
idx
]
=
result_kvp
.
value
;
indices
[
idx
]
=
should_process_row
?
(
expert
-
start_expert
)
:
num_experts
;
assert
(
indices
[
idx
]
>=
0
);
source_rows
[
idx
]
=
k_idx
*
num_rows
+
block_row
;
}
__syncthreads
();
}
}
// ====================== TopK softmax things ===============================
/*
A Top-K gating softmax written to exploit when the number of experts in the MoE layers
are a small power of 2. This allows us to cleanly share the rows among the threads in
a single warp and eliminate communication between warps (so no need to use shared mem).
It fuses the softmax, max and argmax into a single kernel.
Limitations:
1) This implementation is intended for when the number of experts is a small power of 2.
2) This implementation assumes k is small, but will work for any k.
*/
template
<
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
>
__launch_bounds__
(
WARPS_PER_CTA
*
WARP_SIZE
)
__global__
void
topkGatingSoftmax
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_rows
,
int
*
indices
,
int
*
source_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
{
// We begin by enforcing compile time assertions and setting up compile time constants.
static_assert
(
VPT
==
(
VPT
&
-
VPT
),
"VPT must be power of 2"
);
static_assert
(
NUM_EXPERTS
==
(
NUM_EXPERTS
&
-
NUM_EXPERTS
),
"NUM_EXPERTS must be power of 2"
);
static_assert
(
BYTES_PER_LDG
==
(
BYTES_PER_LDG
&
-
BYTES_PER_LDG
),
"BYTES_PER_LDG must be power of 2"
);
static_assert
(
BYTES_PER_LDG
<=
16
,
"BYTES_PER_LDG must be leq 16"
);
// Number of bytes each thread pulls in per load
static
constexpr
int
ELTS_PER_LDG
=
BYTES_PER_LDG
/
sizeof
(
float
);
static
constexpr
int
ELTS_PER_ROW
=
NUM_EXPERTS
;
static
constexpr
int
THREADS_PER_ROW
=
ELTS_PER_ROW
/
VPT
;
static
constexpr
int
LDG_PER_THREAD
=
VPT
/
ELTS_PER_LDG
;
// Restrictions based on previous section.
static_assert
(
VPT
%
ELTS_PER_LDG
==
0
,
"The elements per thread must be a multiple of the elements per ldg"
);
static_assert
(
WARP_SIZE
%
THREADS_PER_ROW
==
0
,
"The threads per row must cleanly divide the threads per warp"
);
static_assert
(
THREADS_PER_ROW
==
(
THREADS_PER_ROW
&
-
THREADS_PER_ROW
),
"THREADS_PER_ROW must be power of 2"
);
static_assert
(
THREADS_PER_ROW
<=
WARP_SIZE
,
"THREADS_PER_ROW can be at most warp size"
);
// We have NUM_EXPERTS elements per row. We specialize for small #experts
static
constexpr
int
ELTS_PER_WARP
=
WARP_SIZE
*
VPT
;
static
constexpr
int
ROWS_PER_WARP
=
ELTS_PER_WARP
/
ELTS_PER_ROW
;
static
constexpr
int
ROWS_PER_CTA
=
WARPS_PER_CTA
*
ROWS_PER_WARP
;
// Restrictions for previous section.
static_assert
(
ELTS_PER_WARP
%
ELTS_PER_ROW
==
0
,
"The elts per row must cleanly divide the total elt per warp"
);
// ===================== From this point, we finally start computing run-time variables. ========================
// Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
// This, each block processes a chunk of rows. We start by computing the start row for each block.
const
int
cta_base_row
=
blockIdx
.
x
*
ROWS_PER_CTA
;
// Now, using the base row per thread block, we compute the base row per warp.
const
int
warp_base_row
=
cta_base_row
+
threadIdx
.
y
*
ROWS_PER_WARP
;
// The threads in a warp are split into sub-groups that will work on a row.
// We compute row offset for each thread sub-group
const
int
thread_row_in_warp
=
threadIdx
.
x
/
THREADS_PER_ROW
;
const
int
thread_row
=
warp_base_row
+
thread_row_in_warp
;
// Threads with indices out of bounds should early exit here.
if
(
thread_row
>=
num_rows
)
{
return
;
}
const
bool
row_is_active
=
finished
?
!
finished
[
thread_row
]
:
true
;
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
// row it will read.
const
float
*
thread_row_ptr
=
input
+
thread_row
*
ELTS_PER_ROW
;
// Now, we compute the group each thread belong to in order to determine the first column to start loads.
const
int
thread_group_idx
=
threadIdx
.
x
%
THREADS_PER_ROW
;
const
int
first_elt_read_by_thread
=
thread_group_idx
*
ELTS_PER_LDG
;
const
float
*
thread_read_ptr
=
thread_row_ptr
+
first_elt_read_by_thread
;
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
// this can support all powers of 2 up to 16.
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
using
AccessType
=
AlignedArray
<
float
,
ELTS_PER_LDG
>
;
// Finally, we pull in the data from global mem
float
row_chunk
[
VPT
];
AccessType
*
row_chunk_vec_ptr
=
reinterpret_cast
<
AccessType
*>
(
&
row_chunk
);
const
AccessType
*
vec_thread_read_ptr
=
reinterpret_cast
<
const
AccessType
*>
(
thread_read_ptr
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
LDG_PER_THREAD
;
++
ii
)
{
row_chunk_vec_ptr
[
ii
]
=
vec_thread_read_ptr
[
ii
*
THREADS_PER_ROW
];
}
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
// convert to float afterwards for the exp + sum reduction.
float
thread_max
=
row_chunk
[
0
];
#pragma unroll
for
(
int
ii
=
1
;
ii
<
VPT
;
++
ii
)
{
thread_max
=
max
(
thread_max
,
row_chunk
[
ii
]);
}
// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
#pragma unroll
for
(
int
mask
=
THREADS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
thread_max
=
max
(
thread_max
,
__shfl_xor_sync
(
0xFFFFFFFF
,
thread_max
,
mask
,
THREADS_PER_ROW
));
}
// From this point, thread max in all the threads have the max within the row.
// Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
float
row_sum
=
0
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VPT
;
++
ii
)
{
row_chunk
[
ii
]
=
expf
(
row_chunk
[
ii
]
-
thread_max
);
row_sum
+=
row_chunk
[
ii
];
}
// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern.
#pragma unroll
for
(
int
mask
=
THREADS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
row_sum
+=
__shfl_xor_sync
(
0xFFFFFFFF
,
row_sum
,
mask
,
THREADS_PER_ROW
);
}
// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
// respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
// compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
// However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the
// argmax after computing the softmax.
const
float
reciprocal_row_sum
=
1.
f
/
row_sum
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VPT
;
++
ii
)
{
row_chunk
[
ii
]
=
row_chunk
[
ii
]
*
reciprocal_row_sum
;
}
// Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along
// with the max index.
int
start_col
=
first_elt_read_by_thread
;
static
constexpr
int
COLS_PER_GROUP_LDG
=
ELTS_PER_LDG
*
THREADS_PER_ROW
;
for
(
int
k_idx
=
0
;
k_idx
<
k
;
++
k_idx
)
{
// First, each thread does the local argmax
float
max_val
=
row_chunk
[
0
];
int
expert
=
start_col
;
#pragma unroll
for
(
int
ldg
=
0
,
col
=
start_col
;
ldg
<
LDG_PER_THREAD
;
++
ldg
,
col
+=
COLS_PER_GROUP_LDG
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
ELTS_PER_LDG
;
++
ii
)
{
float
val
=
row_chunk
[
ldg
*
ELTS_PER_LDG
+
ii
];
// No check on the experts here since columns with the smallest index are processed first and only
// updated if > (not >=)
if
(
val
>
max_val
)
{
max_val
=
val
;
expert
=
col
+
ii
;
}
}
}
// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
// then blank out their max with -inf and the warp can run more iterations...
#pragma unroll
for
(
int
mask
=
THREADS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
float
other_max
=
__shfl_xor_sync
(
0xFFFFFFFF
,
max_val
,
mask
,
THREADS_PER_ROW
);
int
other_expert
=
__shfl_xor_sync
(
0xFFFFFFFF
,
expert
,
mask
,
THREADS_PER_ROW
);
// We want lower indices to "win" in every thread so we break ties this way
if
(
other_max
>
max_val
||
(
other_max
==
max_val
&&
other_expert
<
expert
))
{
max_val
=
other_max
;
expert
=
other_expert
;
}
}
// Write the max for this k iteration to global memory.
if
(
thread_group_idx
==
0
)
{
// Add a guard to ignore experts not included by this node
const
bool
node_uses_expert
=
expert
>=
start_expert
&&
expert
<
end_expert
;
const
bool
should_process_row
=
row_is_active
&&
node_uses_expert
;
// The lead thread from each sub-group will write out the final results to global memory. (This will be a
// single) thread per row of the input/output matrices.
const
int
idx
=
k
*
thread_row
+
k_idx
;
output
[
idx
]
=
max_val
;
indices
[
idx
]
=
should_process_row
?
(
expert
-
start_expert
)
:
NUM_EXPERTS
;
source_rows
[
idx
]
=
k_idx
*
num_rows
+
thread_row
;
}
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
if
(
k_idx
+
1
<
k
)
{
const
int
ldg_group_for_expert
=
expert
/
COLS_PER_GROUP_LDG
;
const
int
thread_to_clear_in_group
=
(
expert
/
ELTS_PER_LDG
)
%
THREADS_PER_ROW
;
// Only the thread in the group which produced the max will reset the "winning" value to -inf.
if
(
thread_group_idx
==
thread_to_clear_in_group
)
{
const
int
offset_for_expert
=
expert
%
ELTS_PER_LDG
;
// Safe to set to any negative value since row_chunk values must be between 0 and 1.
row_chunk
[
ldg_group_for_expert
*
ELTS_PER_LDG
+
offset_for_expert
]
=
-
10000.
f
;
}
}
}
}
namespace
detail
{
// Constructs some constants needed to partition the work across threads at compile time.
template
<
int
EXPERTS
,
int
BYTES_PER_LDG
>
struct
TopkConstants
{
static
constexpr
int
ELTS_PER_LDG
=
BYTES_PER_LDG
/
sizeof
(
float
);
static_assert
(
EXPERTS
/
(
ELTS_PER_LDG
*
WARP_SIZE
)
==
0
||
EXPERTS
%
(
ELTS_PER_LDG
*
WARP_SIZE
)
==
0
,
""
);
static
constexpr
int
VECs_PER_THREAD
=
std
::
max
(
1
,
EXPERTS
/
(
ELTS_PER_LDG
*
WARP_SIZE
));
static
constexpr
int
VPT
=
VECs_PER_THREAD
*
ELTS_PER_LDG
;
static
constexpr
int
THREADS_PER_ROW
=
EXPERTS
/
VPT
;
static
constexpr
int
ROWS_PER_WARP
=
WARP_SIZE
/
THREADS_PER_ROW
;
};
}
// namespace detail
template
<
int
EXPERTS
,
int
WARPS_PER_TB
>
void
topkGatingSoftmaxLauncherHelper
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
int
*
indices
,
int
*
source_row
,
const
int
num_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
,
cudaStream_t
stream
)
{
static
constexpr
std
::
size_t
MAX_BYTES_PER_LDG
=
16
;
static
constexpr
int
BYTES_PER_LDG
=
std
::
min
(
MAX_BYTES_PER_LDG
,
sizeof
(
float
)
*
EXPERTS
);
using
Constants
=
detail
::
TopkConstants
<
EXPERTS
,
BYTES_PER_LDG
>
;
static
constexpr
int
VPT
=
Constants
::
VPT
;
static
constexpr
int
ROWS_PER_WARP
=
Constants
::
ROWS_PER_WARP
;
const
int
num_warps
=
(
num_rows
+
ROWS_PER_WARP
-
1
)
/
ROWS_PER_WARP
;
const
int
num_blocks
=
(
num_warps
+
WARPS_PER_TB
-
1
)
/
WARPS_PER_TB
;
dim3
block_dim
(
WARP_SIZE
,
WARPS_PER_TB
);
topkGatingSoftmax
<
VPT
,
EXPERTS
,
WARPS_PER_TB
,
BYTES_PER_LDG
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
input
,
finished
,
output
,
num_rows
,
indices
,
source_row
,
k
,
start_expert
,
end_expert
);
}
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
gating_output, nullptr, topk_weights, topk_indicies, \
token_expert_indices, num_tokens, topk, 0, num_experts, \
stream);
void
topkGatingSoftmaxKernelLauncher
(
const
float
*
gating_output
,
float
*
topk_weights
,
int
*
topk_indicies
,
int
*
token_expert_indices
,
float
*
softmax_workspace
,
const
int
num_tokens
,
const
int
num_experts
,
const
int
topk
,
cudaStream_t
stream
)
{
static
constexpr
int
WARPS_PER_TB
=
4
;
switch
(
num_experts
)
{
case
1
:
LAUNCH_SOFTMAX
(
1
,
WARPS_PER_TB
);
break
;
case
2
:
LAUNCH_SOFTMAX
(
2
,
WARPS_PER_TB
);
break
;
case
4
:
LAUNCH_SOFTMAX
(
4
,
WARPS_PER_TB
);
break
;
case
8
:
LAUNCH_SOFTMAX
(
8
,
WARPS_PER_TB
);
break
;
case
16
:
LAUNCH_SOFTMAX
(
16
,
WARPS_PER_TB
);
break
;
case
32
:
LAUNCH_SOFTMAX
(
32
,
WARPS_PER_TB
);
break
;
case
64
:
LAUNCH_SOFTMAX
(
64
,
WARPS_PER_TB
);
break
;
case
128
:
LAUNCH_SOFTMAX
(
128
,
WARPS_PER_TB
);
break
;
case
256
:
LAUNCH_SOFTMAX
(
256
,
WARPS_PER_TB
);
break
;
default:
{
TORCH_CHECK
(
softmax_workspace
!=
nullptr
,
"softmax_workspace must be provided for num_experts that are not a power of 2."
);
static
constexpr
int
TPB
=
256
;
moeSoftmax
<
TPB
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
gating_output
,
nullptr
,
softmax_workspace
,
num_experts
);
moeTopK
<
TPB
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
softmax_workspace
,
nullptr
,
topk_weights
,
topk_indicies
,
token_expert_indices
,
num_experts
,
topk
,
0
,
num_experts
);
}
}
}
}
// namespace moe
}
// namespace vllm
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
// [num_tokens, topk]
torch
::
Tensor
&
topk_indices
,
// [num_tokens, topk]
torch
::
Tensor
&
token_expert_indices
,
// [num_tokens, topk]
torch
::
Tensor
&
gating_output
)
// [num_tokens, num_experts]
{
const
int
num_experts
=
gating_output
.
size
(
-
1
);
const
int
num_tokens
=
gating_output
.
numel
()
/
num_experts
;
const
int
topk
=
topk_weights
.
size
(
-
1
);
const
bool
is_pow_2
=
(
num_experts
!=
0
)
&&
((
num_experts
&
(
num_experts
-
1
))
==
0
);
const
bool
needs_workspace
=
!
is_pow_2
||
num_experts
>
256
;
const
int64_t
workspace_size
=
needs_workspace
?
num_tokens
*
num_experts
:
0
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
gating_output
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
torch
::
Tensor
softmax_workspace
=
torch
::
empty
({
workspace_size
},
gating_output
.
options
());
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
gating_output
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
int
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
stream
);
}
csrc/pybind.cpp
View file @
f0d4e145
...
...
@@ -48,8 +48,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&
rotary_embedding
,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key"
);
#ifndef USE_ROCM
// Quantization ops
#ifndef USE_ROCM
ops
.
def
(
"awq_gemm"
,
&
awq_gemm
,
"Quantized GEMM for AWQ"
);
ops
.
def
(
"awq_dequantize"
,
&
awq_dequantize
,
"Dequantization for AWQ"
);
#endif
...
...
setup.py
View file @
f0d4e145
...
...
@@ -339,6 +339,17 @@ if _is_cuda():
vllm_extension_sources
.
append
(
"csrc/quantization/awq/gemm_kernels.cu"
)
vllm_extension_sources
.
append
(
"csrc/custom_all_reduce.cu"
)
# Add MoE kernels.
ext_modules
.
append
(
CUDAExtension
(
name
=
"vllm._moe_C"
,
sources
=
glob
(
"csrc/moe/*.cu"
)
+
glob
(
"csrc/moe/*.cpp"
),
extra_compile_args
=
{
"cxx"
:
CXX_FLAGS
,
"nvcc"
:
NVCC_FLAGS
,
},
))
if
not
_is_neuron
():
vllm_extension
=
CUDAExtension
(
name
=
"vllm._C"
,
...
...
tests/kernels/test_moe.py
View file @
f0d4e145
...
...
@@ -2,10 +2,8 @@
Run `pytest tests/kernels/test_moe.py`.
"""
import
pytest
import
torch
from
transformers
import
MixtralConfig
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
...
...
@@ -14,22 +12,21 @@ from vllm.model_executor.layers.activation import SiluAndMul
from
vllm.model_executor.models.mixtral
import
MixtralMoE
def
torch_moe
(
a
,
w1
,
w2
,
topk_weight
,
topk
_ids
):
def
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
):
B
,
D
=
a
.
shape
a
=
a
.
view
(
B
,
-
1
,
D
).
repeat
(
1
,
topk_ids
.
shape
[
1
],
1
).
reshape
(
-
1
,
D
)
out
=
torch
.
zeros
(
B
*
topk_ids
.
shape
[
1
],
w2
.
shape
[
1
],
dtype
=
a
.
dtype
,
device
=
a
.
device
)
topk_ids
=
topk_ids
.
view
(
-
1
)
a
=
a
.
view
(
B
,
-
1
,
D
).
repeat
(
1
,
topk
,
1
).
reshape
(
-
1
,
D
)
out
=
torch
.
zeros
(
B
*
topk
,
w2
.
shape
[
1
],
dtype
=
a
.
dtype
,
device
=
a
.
device
)
score
=
torch
.
softmax
(
score
,
dim
=-
1
,
dtype
=
torch
.
float32
)
topk_weight
,
topk_ids
=
torch
.
topk
(
score
,
topk
)
topk_weight
=
topk_weight
.
view
(
-
1
)
topk_ids
=
topk_ids
.
view
(
-
1
)
for
i
in
range
(
w1
.
shape
[
0
]):
mask
=
topk_ids
==
i
if
mask
.
sum
():
out
[
mask
]
=
SiluAndMul
()(
a
[
mask
]
@
w1
[
i
].
transpose
(
0
,
1
))
@
w2
[
i
].
transpose
(
0
,
1
)
return
(
out
.
view
(
B
,
-
1
,
w2
.
shape
[
1
])
*
topk_weight
.
view
(
B
,
-
1
,
1
)).
sum
(
dim
=
1
)
topk_weight
.
view
(
B
,
-
1
,
1
)
.
to
(
out
.
dtype
)
).
sum
(
dim
=
1
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
512
,
222
,
33
,
1
])
...
...
@@ -51,11 +48,8 @@ def test_fused_moe(
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
'cuda'
,
dtype
=
dtype
)
/
10
score
=
torch
.
randn
((
m
,
e
),
device
=
'cuda'
,
dtype
=
dtype
)
score
=
torch
.
softmax
(
score
,
dim
=-
1
)
topk_weight
,
topk_ids
=
torch
.
topk
(
score
,
topk
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
,
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
assert
torch
.
allclose
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
...
...
@@ -75,7 +69,7 @@ def test_mixtral_moe(dtype: torch.dtype):
intermediate_size
=
config
.
intermediate_size
,
params_dtype
=
dtype
,
tp_size
=
1
,
)
)
.
cuda
()
# Load the weights
vllm_moe
.
gate
.
linear_weights
[
"weight"
][:]
=
hf_moe
.
gate
.
weight
.
data
...
...
vllm/model_executor/layers/fused_moe.py
View file @
f0d4e145
...
...
@@ -4,6 +4,7 @@ import triton
import
triton.language
as
tl
from
vllm._C
import
ops
from
vllm.utils
import
is_hip
@
triton
.
jit
...
...
@@ -177,7 +178,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
mul_routed_weight
:
bool
,
top_k
:
int
,
config
:
dict
):
assert
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
...
...
@@ -210,12 +210,15 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
)
def
fused_moe
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
=
False
):
def
fused_moe
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
inplace
:
bool
=
False
,
)
->
torch
.
Tensor
:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.
...
...
@@ -223,15 +226,19 @@ def fused_moe(hidden_states: torch.Tensor,
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The weights for the top-k selected experts.
- topk_ids (torch.Tensor): The indices of the top-k selected experts.
- gating_output (torch.Tensor): The output of the gating operation (before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place. Defaults to False.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
"Incompatible dimensions"
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
"Hidden size mismatch"
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
...
...
@@ -241,6 +248,37 @@ def fused_moe(hidden_states: torch.Tensor,
M
,
_
=
hidden_states
.
shape
E
,
N
,
_
=
w1
.
shape
if
is_hip
():
# The MoE kernels are not yet supported on ROCm.
routing_weights
=
torch
.
softmax
(
gating_output
,
dim
=-
1
,
dtype
=
torch
.
float32
)
topk_weights
,
topk_ids
=
torch
.
topk
(
routing_weights
,
topk
,
dim
=-
1
)
else
:
import
vllm._moe_C
as
moe_kernels
topk_weights
=
torch
.
empty
(
M
,
topk
,
dtype
=
torch
.
float32
,
device
=
hidden_states
.
device
)
topk_ids
=
torch
.
empty
(
M
,
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
token_expert_indicies
=
torch
.
empty
(
M
,
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
moe_kernels
.
topk_softmax
(
topk_weights
,
topk_ids
,
token_expert_indicies
,
gating_output
.
float
(),
# TODO(woosuk): Optimize this.
)
del
token_expert_indicies
# Not used. Will be used in the future.
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
config
=
{
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_N'
:
64
,
...
...
vllm/model_executor/models/deepseek.py
View file @
f0d4e145
...
...
@@ -25,7 +25,6 @@ from typing import Any, Dict, List, Optional, Tuple
import
torch
from
torch
import
nn
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
...
...
@@ -155,20 +154,12 @@ class DeepseekMoE(nn.Module):
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (batch * sequence_length, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
routing_weights
=
F
.
softmax
(
router_logits
,
dim
=
1
,
dtype
=
torch
.
float
)
routing_weights
,
selected_experts
=
torch
.
topk
(
routing_weights
,
self
.
top_k
,
dim
=-
1
)
if
self
.
config
.
norm_topk_prob
:
routing_weights
/=
routing_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
final_hidden_states
=
fused_moe
(
hidden_states
,
self
.
w1
,
self
.
w2
,
routing_weights
,
selected_experts
,
router_logits
,
self
.
top_k
,
renormalize
=
self
.
config
.
norm_topk_prob
,
inplace
=
True
)
if
self
.
config
.
n_shared_experts
is
not
None
:
...
...
vllm/model_executor/models/mixtral.py
View file @
f0d4e145
...
...
@@ -24,8 +24,6 @@
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
transformers
import
MixtralConfig
...
...
@@ -128,18 +126,12 @@ class MixtralMoE(nn.Module):
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_size
)
# router_logits: (batch * sequence_length, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
routing_weights
=
F
.
softmax
(
router_logits
,
dim
=
1
,
dtype
=
torch
.
float
)
routing_weights
,
selected_experts
=
torch
.
topk
(
routing_weights
,
self
.
top_k
,
dim
=-
1
)
routing_weights
/=
routing_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
final_hidden_states
=
fused_moe
(
hidden_states
,
self
.
ws
,
self
.
w2s
,
routing_weights
,
selected_experts
,
router_logits
,
self
.
top_k
,
renormalize
=
True
,
inplace
=
True
)
if
self
.
tp_size
>
1
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment