Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0b229519
Commit
0b229519
authored
May 27, 2025
by
王敏
Browse files
[feat]适配sgl moe_fused_gate kernel
parent
1150b65c
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
626 additions
and
16 deletions
+626
-16
CMakeLists.txt
CMakeLists.txt
+2
-1
csrc/moe/moe_fused_gate.cu
csrc/moe/moe_fused_gate.cu
+539
-0
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+10
-1
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+6
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+28
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+5
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+34
-13
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+2
-1
No files found.
CMakeLists.txt
View file @
0b229519
...
@@ -621,7 +621,8 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
...
@@ -621,7 +621,8 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set
(
VLLM_MOE_EXT_SRC
set
(
VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp"
"csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu"
)
"csrc/moe/topk_softmax_kernels.cu"
"csrc/moe/moe_fused_gate.cu"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
list
(
APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
)
list
(
APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
)
...
...
csrc/moe/moe_fused_gate.cu
0 → 100644
View file @
0b229519
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include "../cuda_compat.h"
// #include <cutlass/array.h>
// #include <cutlass/cutlass.h>
// #include <cutlass/numeric_types.h>
#include <stdio.h>
#include <torch/all.h>
#include <cfloat>
#include <type_traits>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef
__hip_bfloat16
__nv_bfloat16
;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
/// Aligned array type
template
<
typename
T
,
/// Number of elements in the array
int
N
,
/// Alignment requirement in bytes
int
Alignment
=
sizeof
(
T
)
*
N
>
class
alignas
(
Alignment
)
AlignedArray
{
T
data
[
N
];
public:
__device__
T
&
operator
[](
int
index
)
{
return
data
[
index
];
}
__device__
const
T
&
operator
[](
int
index
)
const
{
return
data
[
index
];
}
};
// template <typename T, int N>
// using AlignedArray = cutlass::AlignedArray<T, N>;
// using bfloat16_t = cutlass::bfloat16_t;
// using float16_t = cutlass::half_t;
using
float32_t
=
float
;
// QQ NOTE: to handle the case for at::Half, error: more than one operator ">" matches these operands: built-in operator
// "arithmetic > arithmetic" function "operator>(const __half &, const __half &)"
template
<
typename
T
>
__device__
inline
bool
cmp_gt
(
const
T
&
a
,
const
T
&
b
)
{
if
constexpr
(
std
::
is_same
<
T
,
at
::
Half
>::
value
)
{
// at::Half (or float16_t in our native case) causes ambiguity, so we cast to float.
return
static_cast
<
float
>
(
a
)
>
static_cast
<
float
>
(
b
);
}
else
{
// For types like float, at::BFloat16, or cutlass::half_t / cutlass::bfloat16_t, assume operator> works as expected.
return
a
>
b
;
}
}
template
<
typename
T
>
__device__
inline
bool
cmp_eq
(
const
T
&
a
,
const
T
&
b
)
{
if
constexpr
(
std
::
is_same
<
T
,
at
::
Half
>::
value
)
{
return
static_cast
<
float
>
(
a
)
==
static_cast
<
float
>
(
b
);
}
else
{
return
a
==
b
;
}
}
// Fixed constants common to both dynamic and static template versions:
//static constexpr int WARP_SIZE = 32;
static
constexpr
int
WARPS_PER_CTA
=
6
;
static
constexpr
int
MAX_VPT
=
32
;
// maximum VPT we support, > params.VPT = num_expert / num_expert_group
// Create an alias for Array using AlignedArray
template
<
typename
T
,
int
N
>
using
Array
=
AlignedArray
<
T
,
N
>
;
// QQ: NOTE expression must have a constant value, this has to be > params.VPT
template
<
typename
T
>
using
AccessType
=
AlignedArray
<
T
,
MAX_VPT
>
;
template
<
typename
T
,
typename
Params
>
__device__
void
moe_fused_gate_impl
(
void
*
input
,
void
*
bias
,
float
*
output_ptr
,
int32_t
*
indices_ptr
,
int64_t
num_rows
,
int64_t
topk_group
,
int64_t
topk
,
int64_t
n_share_experts_fusion
,
double
routed_scaling_factor
,
Params
params
)
{
int
tidx
=
threadIdx
.
x
;
int64_t
thread_row
=
blockIdx
.
x
*
params
.
ROWS_PER_CTA
+
threadIdx
.
y
*
params
.
ROWS_PER_WARP
+
tidx
/
params
.
THREADS_PER_ROW
;
if
(
thread_row
>=
num_rows
)
{
return
;
}
// Calculate topk_excluding_share_expert_fusion from topk
int64_t
topk_excluding_share_expert_fusion
=
topk
-
(
n_share_experts_fusion
>
0
?
1
:
0
);
// Cast pointers to type T:
auto
*
input_ptr
=
reinterpret_cast
<
T
*>
(
input
);
auto
*
bias_ptr
=
reinterpret_cast
<
T
*>
(
bias
);
auto
*
thread_row_ptr
=
input_ptr
+
thread_row
*
params
.
NUM_EXPERTS
;
int
thread_group_idx
=
tidx
%
params
.
THREADS_PER_ROW
;
int
first_elt_read_by_thread
=
thread_group_idx
*
params
.
VPT
;
// Create local arrays for the row chunk and bias chunk and then reinterpret the address of row_chunk as a pointer to
// AccessType.
T
*
thread_read_ptr
=
thread_row_ptr
+
first_elt_read_by_thread
;
Array
<
T
,
MAX_VPT
>
row_chunk
;
// T row_chunk[params.VPT];
AccessType
<
T
>
const
*
vec_thread_read_ptr
=
reinterpret_cast
<
AccessType
<
T
>
const
*>
(
thread_read_ptr
);
T
*
bias_thread_read_ptr
=
bias_ptr
+
first_elt_read_by_thread
;
Array
<
T
,
MAX_VPT
>
bias_chunk
;
// T bias_chunk[params.VPT];
AccessType
<
T
>
const
*
vec_bias_thread_read_ptr
=
reinterpret_cast
<
AccessType
<
T
>
const
*>
(
bias_thread_read_ptr
);
//AccessType<T>* row_chunk_vec_ptr = reinterpret_cast<AccessType<T>*>(&row_chunk);
//AccessType<T>* bias_chunk_vec_ptr = reinterpret_cast<AccessType<T>*>(&bias_chunk);
// QQ NOTE: doing the follow will be slower than loop assign and more importantly
// have misaligned address issue when params.VPT < 8 and mismatch with MAX_VPT
// AccessType<T>* row_chunk_vec_ptr = reinterpret_cast<AccessType<T>*>(&row_chunk);
// row_chunk_vec_ptr[0] = vec_thread_read_ptr[0];
#pragma unroll
for
(
int
ii
=
0
;
ii
<
params
.
VPT
;
++
ii
)
{
row_chunk
[
ii
]
=
vec_thread_read_ptr
[
0
][
ii
];
bias_chunk
[
ii
]
=
vec_bias_thread_read_ptr
[
0
][
ii
];
}
/*row_chunk_vec_ptr[0] = vec_thread_read_ptr[0];
bias_chunk_vec_ptr[0] = vec_bias_thread_read_ptr[0];*/
__syncthreads
();
////////////////////// Sigmoid //////////////////////
#pragma unroll
for
(
int
ii
=
0
;
ii
<
params
.
VPT
;
++
ii
)
{
row_chunk
[
ii
]
=
static_cast
<
T
>
(
1.0
f
/
(
1.0
f
+
expf
(
-
float
(
row_chunk
[
ii
]))));
}
__syncthreads
();
////////////////////// Add Bias //////////////////////
#pragma unroll
for
(
int
ii
=
0
;
ii
<
params
.
VPT
;
++
ii
)
{
bias_chunk
[
ii
]
=
row_chunk
[
ii
]
+
bias_chunk
[
ii
];
}
////////////////////// Exclude Groups //////////////////////
#pragma unroll
for
(
int
k_idx
=
0
;
k_idx
<
params
.
THREADS_PER_ROW
-
topk_group
;
++
k_idx
)
{
// QQ NOTE Here params.THREADS_PER_ROW = num_expert_group
int
expert
=
first_elt_read_by_thread
;
// local argmax
T
max_val
=
static_cast
<
T
>
(
-
FLT_MAX
);
T
max_val_second
=
static_cast
<
T
>
(
-
FLT_MAX
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
params
.
VPT
;
++
ii
)
{
T
val
=
bias_chunk
[
ii
];
if
(
cmp_gt
(
val
,
max_val
))
{
max_val_second
=
max_val
;
max_val
=
val
;
}
else
if
(
cmp_gt
(
val
,
max_val_second
))
{
max_val_second
=
val
;
}
}
// QQ NOTE: currently fixed to pick top2 sigmoid weight value in each expert group and sum them as the group weight
// to select expert groups
T
max_sum
=
max_val
+
max_val_second
;
// argmin reduce
#pragma unroll
for
(
int
mask
=
params
.
THREADS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
T
other_max_sum
=
static_cast
<
T
>
(
VLLM_SHFL_XOR_SYNC_WIDTH
(
static_cast
<
float
>
(
max_sum
),
mask
,
params
.
THREADS_PER_ROW
));
int
other_expert
=
VLLM_SHFL_XOR_SYNC_WIDTH
(
expert
,
mask
,
params
.
THREADS_PER_ROW
);
// higher indices win
if
(
cmp_gt
(
max_sum
,
other_max_sum
)
||
(
cmp_eq
(
other_max_sum
,
max_sum
)
&&
other_expert
>
expert
))
{
max_sum
=
other_max_sum
;
expert
=
other_expert
;
}
}
// clear the max value in the thread
if
(
k_idx
<
params
.
THREADS_PER_ROW
-
topk_group
)
{
int
const
thread_to_clear_in_group
=
expert
/
params
.
VPT
;
if
(
thread_group_idx
==
thread_to_clear_in_group
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
params
.
VPT
;
++
ii
)
{
bias_chunk
[
ii
]
=
static_cast
<
T
>
(
FLT_MAX
);
}
}
}
}
__syncthreads
();
////////////////////// Topk //////////////////////
float
output_sum
=
0.0
f
;
for
(
int
k_idx
=
0
;
k_idx
<
topk_excluding_share_expert_fusion
;
++
k_idx
)
{
// local argmax
T
max_val
=
bias_chunk
[
0
];
int
expert
=
first_elt_read_by_thread
;
if
(
!
cmp_eq
(
max_val
,
static_cast
<
T
>
(
FLT_MAX
)))
{
#pragma unroll
for
(
int
ii
=
1
;
ii
<
params
.
VPT
;
++
ii
)
{
T
val
=
bias_chunk
[
ii
];
if
(
cmp_gt
(
val
,
max_val
))
{
max_val
=
val
;
expert
=
first_elt_read_by_thread
+
ii
;
}
}
}
else
{
max_val
=
static_cast
<
T
>
(
-
FLT_MAX
);
}
// argmax reduce
#pragma unroll
for
(
int
mask
=
params
.
THREADS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
T
other_max
=
static_cast
<
T
>
(
VLLM_SHFL_XOR_SYNC_WIDTH
(
static_cast
<
float
>
(
max_val
),
mask
,
params
.
THREADS_PER_ROW
));
int
other_expert
=
VLLM_SHFL_XOR_SYNC_WIDTH
(
expert
,
mask
,
params
.
THREADS_PER_ROW
);
// lower indices to win
if
(
cmp_gt
(
other_max
,
max_val
)
||
(
cmp_eq
(
other_max
,
max_val
)
&&
other_expert
<
expert
))
{
max_val
=
other_max
;
expert
=
other_expert
;
}
}
int
thread_to_clear_in_group
=
expert
/
params
.
VPT
;
int64_t
idx
=
topk
*
thread_row
+
k_idx
;
if
(
thread_group_idx
==
thread_to_clear_in_group
)
{
int
expert_to_clear_in_thread
=
expert
%
params
.
VPT
;
// clear the max value in the thread
bias_chunk
[
expert_to_clear_in_thread
]
=
static_cast
<
T
>
(
-
FLT_MAX
);
// store output
output_ptr
[
idx
]
=
static_cast
<
float
>
(
row_chunk
[
expert_to_clear_in_thread
]);
indices_ptr
[
idx
]
=
static_cast
<
int32_t
>
(
expert
);
}
// accumulate sum for all elements
if
(
thread_group_idx
==
0
)
{
output_sum
+=
output_ptr
[
idx
];
}
__syncthreads
();
}
if
(
thread_group_idx
==
0
&&
n_share_experts_fusion
>
0
)
{
int64_t
last_idx
=
topk
*
thread_row
+
topk_excluding_share_expert_fusion
;
// Use round-robin to select expert
int64_t
expert_offset
=
thread_row
%
n_share_experts_fusion
;
indices_ptr
[
last_idx
]
=
static_cast
<
int32_t
>
(
params
.
NUM_EXPERTS
+
expert_offset
);
// Set the weight to the sum of all weights divided by routed_scaling_factor
output_ptr
[
last_idx
]
=
output_sum
/
routed_scaling_factor
;
}
__syncthreads
();
////////////////////// Rescale Output //////////////////////
if
(
thread_group_idx
==
0
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
topk
;
++
ii
)
{
int64_t
const
idx
=
topk
*
thread_row
+
ii
;
output_ptr
[
idx
]
=
output_ptr
[
idx
]
/
output_sum
;
}
}
}
//------------------------------------------------------------------------------
// Templated Kernel Version (using compile-time constants)
//------------------------------------------------------------------------------
template
<
int
VPT_
,
int
NUM_EXPERTS_
,
int
THREADS_PER_ROW_
,
int
ROWS_PER_WARP_
,
int
ROWS_PER_CTA_
,
int
WARPS_PER_CTA_
>
struct
KernelParams
{
static
constexpr
int
VPT
=
VPT_
;
static
constexpr
int
NUM_EXPERTS
=
NUM_EXPERTS_
;
static
constexpr
int
THREADS_PER_ROW
=
THREADS_PER_ROW_
;
static
constexpr
int
ROWS_PER_WARP
=
ROWS_PER_WARP_
;
static
constexpr
int
ROWS_PER_CTA
=
ROWS_PER_CTA_
;
static
constexpr
int
WARPS_PER_CTA
=
WARPS_PER_CTA_
;
};
template
<
typename
T
,
int
VPT
,
int
NUM_EXPERTS
,
int
THREADS_PER_ROW
,
int
ROWS_PER_WARP
,
int
ROWS_PER_CTA
,
int
WARPS_PER_CTA
>
__global__
void
moe_fused_gate_kernel
(
void
*
input
,
void
*
bias
,
float
*
output_ptr
,
int32_t
*
indices_ptr
,
int64_t
num_rows
,
int64_t
topk_group
,
int64_t
topk
,
int64_t
n_share_experts_fusion
,
double
routed_scaling_factor
)
{
KernelParams
<
VPT
,
NUM_EXPERTS
,
THREADS_PER_ROW
,
ROWS_PER_WARP
,
ROWS_PER_CTA
,
WARPS_PER_CTA
>
params
;
moe_fused_gate_impl
<
T
>
(
input
,
bias
,
output_ptr
,
indices_ptr
,
num_rows
,
topk_group
,
topk
,
n_share_experts_fusion
,
routed_scaling_factor
,
params
);
}
// Macro to compute compile-time constants and launch the kernel.
#define LAUNCH_MOE_GATE_CONFIG(T, EXPERTS, EXPERT_GROUP) \
do { \
constexpr int VPT = (EXPERTS) / (EXPERT_GROUP); \
/* If EXPERT_GROUP > WARP_SIZE, fall back to 1 row per warp */
\
constexpr int ROWS_PER_WARP = ((EXPERT_GROUP) <= WARP_SIZE) ? (WARP_SIZE / (EXPERT_GROUP)) : 1; \
constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; \
moe_fused_gate_kernel<T, VPT, (EXPERTS), (EXPERT_GROUP), ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> \
<<<num_blocks, block_dim, 0, stream>>>( \
input.data_ptr(), \
bias.data_ptr(), \
output.data_ptr<float>(), \
indices.data_ptr<int32_t>(), \
num_rows, \
topk_group, \
topk, \
n_share_experts_fusion, \
routed_scaling_factor); \
dispatched = true; \
} while (0)
//------------------------------------------------------------------------------
// Dynamic Kernel Version (parameters computed at runtime)
//------------------------------------------------------------------------------
struct
KernelParamsDynamic
{
int
VPT
;
int
NUM_EXPERTS
;
int
THREADS_PER_ROW
;
int
ROWS_PER_WARP
;
int
ROWS_PER_CTA
;
int
WARPS_PER_CTA
;
};
template
<
typename
T
>
__global__
void
moe_fused_gate_kernel_dynamic
(
void
*
input
,
void
*
bias
,
float
*
output_ptr
,
int32_t
*
indices_ptr
,
int64_t
num_rows
,
int64_t
num_experts
,
int64_t
num_expert_group
,
int64_t
topk_group
,
int64_t
topk
,
int64_t
n_share_experts_fusion
,
double
routed_scaling_factor
)
{
KernelParamsDynamic
params
;
params
.
NUM_EXPERTS
=
num_experts
;
// e.g, for deepseek v3, this is 256
params
.
VPT
=
num_experts
/
num_expert_group
;
// e.g., for deepseek v3, this is 256 / 8 = 32
params
.
THREADS_PER_ROW
=
num_expert_group
;
// fixed as num_expert_group, e.g., for deepseek v3, this is 8
params
.
WARPS_PER_CTA
=
WARPS_PER_CTA
;
// fixed as 6
params
.
ROWS_PER_WARP
=
std
::
max
<
int64_t
>
(
1
,
WARP_SIZE
/
num_expert_group
);
// WARP_SIZE is fixed as 32
params
.
ROWS_PER_CTA
=
params
.
WARPS_PER_CTA
*
params
.
ROWS_PER_WARP
;
moe_fused_gate_impl
<
T
>
(
input
,
bias
,
output_ptr
,
indices_ptr
,
num_rows
,
topk_group
,
topk
,
n_share_experts_fusion
,
routed_scaling_factor
,
params
);
}
//------------------------------------------------------------------------------
// Host Launcher Function
//------------------------------------------------------------------------------
std
::
vector
<
at
::
Tensor
>
moe_fused_gate
(
at
::
Tensor
&
input
,
at
::
Tensor
&
bias
,
int64_t
num_expert_group
,
int64_t
topk_group
,
int64_t
topk
,
int64_t
n_share_experts_fusion
,
double
routed_scaling_factor
)
{
int64_t
num_rows
=
input
.
size
(
0
);
int32_t
num_experts
=
input
.
size
(
1
);
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
);
auto
output
=
torch
::
empty
({
num_rows
,
topk
},
options
);
auto
indices
=
torch
::
empty
({
num_rows
,
topk
},
options
.
dtype
(
torch
::
kInt32
));
// Compute grid dimensions based on runtime value for num_expert_group.
int64_t
rows_per_warp
=
std
::
max
<
int64_t
>
(
1
,
WARP_SIZE
/
num_expert_group
);
int64_t
num_warps
=
(
num_rows
+
rows_per_warp
-
1
)
/
rows_per_warp
;
int64_t
num_blocks
=
(
num_warps
+
WARPS_PER_CTA
-
1
)
/
WARPS_PER_CTA
;
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
dim3
block_dim
(
WARP_SIZE
,
WARPS_PER_CTA
);
// Check 1: Ensure that num_experts is a power of 2.
TORCH_CHECK
((
num_experts
&
(
num_experts
-
1
))
==
0
,
"num_experts must be a power of 2, but got "
,
num_experts
);
// Check 2: Ensure that num_experts is divisible by num_expert_group. (this also means num_expert_group is power of 2)
TORCH_CHECK
(
num_experts
%
num_expert_group
==
0
,
"num_experts must be divisible by num_expert_group, but got "
,
num_experts
,
" / "
,
num_expert_group
);
int
computed_vpt
=
num_experts
/
num_expert_group
;
// Check 3: Ensure that num_experts/num_expert_group does not exceed MAX_VPT=32. Maximum VPT indicate max value per
// threads we can process.
TORCH_CHECK
(
computed_vpt
<=
MAX_VPT
,
"Per group experts: num_experts / num_expert_group = ("
,
computed_vpt
,
") exceeds the maximum supported ("
,
MAX_VPT
,
")"
);
// Dispatch to templated kernel for known compile-time configurations.
// We currently only support for:
// Case 1: 256 experts, with 8 or 16 groups.
// Case 2: 128 experts, with 4 or 8 groups.
// Case 3: other cases, require 8 <= num_experts / num_expert_group <= 32
bool
dispatched
=
false
;
switch
(
num_experts
)
{
case
256
:
if
(
num_expert_group
==
8
)
// This is deepseek v3 case. Here VPT = 256/8 = 32, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
if
(
input
.
scalar_type
()
==
at
::
kBFloat16
)
{
LAUNCH_MOE_GATE_CONFIG
(
__nv_bfloat16
,
256
,
8
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kHalf
)
{
LAUNCH_MOE_GATE_CONFIG
(
half
,
256
,
8
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kFloat
)
{
LAUNCH_MOE_GATE_CONFIG
(
float
,
256
,
8
);
}
else
if
(
num_expert_group
==
16
)
// Here VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12.
if
(
input
.
scalar_type
()
==
at
::
kBFloat16
)
{
LAUNCH_MOE_GATE_CONFIG
(
__nv_bfloat16
,
256
,
16
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kHalf
)
{
LAUNCH_MOE_GATE_CONFIG
(
half
,
256
,
16
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kFloat
)
{
LAUNCH_MOE_GATE_CONFIG
(
float
,
256
,
16
);
}
break
;
case
128
:
if
(
num_expert_group
==
4
)
// VPT = 128/4 = 32, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12.
if
(
input
.
scalar_type
()
==
at
::
kBFloat16
)
{
LAUNCH_MOE_GATE_CONFIG
(
__nv_bfloat16
,
128
,
4
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kHalf
)
{
LAUNCH_MOE_GATE_CONFIG
(
half
,
128
,
4
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kFloat
)
{
LAUNCH_MOE_GATE_CONFIG
(
float
,
128
,
4
);
}
else
if
(
num_expert_group
==
8
)
// VPT = 128/8 = 16, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
if
(
input
.
scalar_type
()
==
at
::
kBFloat16
)
{
LAUNCH_MOE_GATE_CONFIG
(
__nv_bfloat16
,
128
,
8
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kHalf
)
{
LAUNCH_MOE_GATE_CONFIG
(
half
,
128
,
8
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kFloat
)
{
LAUNCH_MOE_GATE_CONFIG
(
float
,
128
,
8
);
}
break
;
default:
break
;
}
if
(
!
dispatched
)
{
// Fallback to the dynamic kernel if none of the supported combinations match.
// currently only support num_experts / num_expert_group <= 32 for dynamic kernels
if
(
input
.
scalar_type
()
==
at
::
kBFloat16
)
{
moe_fused_gate_kernel_dynamic
<
__nv_bfloat16
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
input
.
data_ptr
(),
bias
.
data_ptr
(),
output
.
data_ptr
<
float
>
(),
indices
.
data_ptr
<
int32_t
>
(),
num_rows
,
num_experts
,
num_expert_group
,
topk_group
,
topk
,
n_share_experts_fusion
,
routed_scaling_factor
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kHalf
)
{
moe_fused_gate_kernel_dynamic
<
half
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
input
.
data_ptr
(),
bias
.
data_ptr
(),
output
.
data_ptr
<
float
>
(),
indices
.
data_ptr
<
int32_t
>
(),
num_rows
,
num_experts
,
num_expert_group
,
topk_group
,
topk
,
n_share_experts_fusion
,
routed_scaling_factor
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kFloat
)
{
moe_fused_gate_kernel_dynamic
<
float
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
input
.
data_ptr
(),
bias
.
data_ptr
(),
output
.
data_ptr
<
float
>
(),
indices
.
data_ptr
<
int32_t
>
(),
num_rows
,
num_experts
,
num_expert_group
,
topk_group
,
topk
,
n_share_experts_fusion
,
routed_scaling_factor
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type for moe_fused_gate"
);
}
}
return
{
output
,
indices
};
}
csrc/moe/moe_ops.h
View file @
0b229519
...
@@ -28,4 +28,13 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
...
@@ -28,4 +28,13 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch
::
Tensor
num_tokens_post_pad
,
int64_t
top_k
,
torch
::
Tensor
num_tokens_post_pad
,
int64_t
top_k
,
int64_t
BLOCK_SIZE_M
,
int64_t
BLOCK_SIZE_N
,
int64_t
BLOCK_SIZE_M
,
int64_t
BLOCK_SIZE_N
,
int64_t
BLOCK_SIZE_K
,
int64_t
bit
);
int64_t
BLOCK_SIZE_K
,
int64_t
bit
);
#endif
#endif
\ No newline at end of file
std
::
vector
<
torch
::
Tensor
>
moe_fused_gate
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
bias
,
int64_t
num_expert_group
,
int64_t
topk_group
,
int64_t
topk
,
int64_t
n_share_experts_fusion
,
double
routed_scaling_factor
);
\ No newline at end of file
csrc/moe/torch_bindings.cpp
View file @
0b229519
...
@@ -31,6 +31,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
...
@@ -31,6 +31,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()"
);
" Tensor! num_tokens_post_pad) -> ()"
);
m
.
impl
(
"sgl_moe_align_block_size"
,
torch
::
kCUDA
,
&
sgl_moe_align_block_size
);
m
.
impl
(
"sgl_moe_align_block_size"
,
torch
::
kCUDA
,
&
sgl_moe_align_block_size
);
m
.
def
(
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
"n_share_experts_fusion, float routed_scaling_factor) -> "
"(Tensor[])"
);
m
.
impl
(
"moe_fused_gate"
,
torch
::
kCUDA
,
&
moe_fused_gate
);
#ifndef USE_ROCM
#ifndef USE_ROCM
m
.
def
(
m
.
def
(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
...
...
vllm/_custom_ops.py
View file @
0b229519
...
@@ -1979,3 +1979,31 @@ def flash_mla_with_kvcache(
...
@@ -1979,3 +1979,31 @@ def flash_mla_with_kvcache(
# torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
# torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
# seq_lens, page_table, scale)
# seq_lens, page_table, scale)
# return out
# return out
def
moe_fused_gate
(
input_tensor
,
bias
,
num_expert_group
,
topk_group
,
topk
,
n_share_experts_fusion
=
0
,
routed_scaling_factor
=
0
,
):
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
# it split group of expert into num_expert_group, and use top2 expert weight sum in each group
# as the group weight to select exerpt groups and then select topk experts within the selected groups
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limitted for now.
# for non-supported case, we suggestion to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
# n_share_experts_fusion: if > 0, the last expert will be replaced with a round-robin shared expert
# routed_scaling_factor: if > 0, the last expert will be scaled by this factor
return
torch
.
ops
.
_moe_C
.
moe_fused_gate
(
input_tensor
,
bias
,
num_expert_group
,
topk_group
,
topk
,
n_share_experts_fusion
,
routed_scaling_factor
,
)
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
0b229519
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
functools
import
functools
import
json
import
json
import
os
import
os
import
math
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -1182,6 +1183,10 @@ def fused_topk(
...
@@ -1182,6 +1183,10 @@ def fused_topk(
return
topk_weights
,
topk_ids
return
topk_weights
,
topk_ids
def
is_power_of_two
(
n
):
return
n
>
0
and
math
.
log2
(
n
).
is_integer
()
# This is used by the Deepseek-V2 and Deepseek-V3 model
# This is used by the Deepseek-V2 and Deepseek-V3 model
@
torch
.
compile
(
dynamic
=
True
,
backend
=
current_platform
.
simple_compile_backend
)
@
torch
.
compile
(
dynamic
=
True
,
backend
=
current_platform
.
simple_compile_backend
)
def
grouped_topk
(
def
grouped_topk
(
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
0b229519
...
@@ -23,6 +23,7 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -23,6 +23,7 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
CpuArchEnum
from
vllm.platforms.interface
import
CpuArchEnum
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
vllm
import
_custom_ops
as
ops
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
():
from
.fused_moe
import
fused_experts
from
.fused_moe
import
fused_experts
...
@@ -222,7 +223,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -222,7 +223,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
self
.
routed_scaling_factor
if
hasattr
(
self
,
"routed_scaling_factor"
)
else
None
)
return
fused_experts
(
return
fused_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
@@ -436,6 +438,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -436,6 +438,7 @@ class FusedMoE(torch.nn.Module):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -505,6 +508,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -505,6 +508,7 @@ class FusedMoE(torch.nn.Module):
self
.
scoring_func
=
scoring_func
self
.
scoring_func
=
scoring_func
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
activation
=
activation
self
.
activation
=
activation
self
.
routed_scaling_factor
=
routed_scaling_factor
if
self
.
scoring_func
!=
"softmax"
and
not
self
.
use_grouped_topk
:
if
self
.
scoring_func
!=
"softmax"
and
not
self
.
use_grouped_topk
:
raise
ValueError
(
"Only softmax scoring function is supported for "
raise
ValueError
(
"Only softmax scoring function is supported for "
...
@@ -554,6 +558,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -554,6 +558,7 @@ class FusedMoE(torch.nn.Module):
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
setattr
(
self
.
quant_method
,
"routed_scaling_factor"
,
self
.
routed_scaling_factor
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
...
@@ -839,23 +844,39 @@ class FusedMoE(torch.nn.Module):
...
@@ -839,23 +844,39 @@ class FusedMoE(torch.nn.Module):
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,):
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
grouped_topk
)
fused_topk
,
grouped_topk
,
is_power_of_two
)
# DeekSeekv2 uses grouped_top_k
# DeekSeekv2 uses grouped_top_k
if
use_grouped_topk
:
if
use_grouped_topk
:
assert
topk_group
is
not
None
assert
topk_group
is
not
None
assert
num_expert_group
is
not
None
assert
num_expert_group
is
not
None
topk_weights
,
topk_ids
=
grouped_topk
(
if
e_score_correction_bias
is
not
None
\
hidden_states
=
hidden_states
,
and
router_logits
.
shape
[
1
]
//
num_expert_group
<=
32
\
gating_output
=
router_logits
,
and
is_power_of_two
(
e_score_correction_bias
.
shape
[
0
]):
topk
=
top_k
,
renormalize
=
renormalize
,
# moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
num_expert_group
=
num_expert_group
,
topk_weights
,
topk_ids
=
ops
.
moe_fused_gate
(
topk_group
=
topk_group
,
router_logits
,
scoring_func
=
scoring_func
,
e_score_correction_bias
,
e_score_correction_bias
=
e_score_correction_bias
)
num_expert_group
,
topk_group
,
top_k
,
routed_scaling_factor
=
routed_scaling_factor
,
n_share_experts_fusion
=
0
,
)
else
:
topk_weights
,
topk_ids
=
grouped_topk
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
topk
=
top_k
,
renormalize
=
renormalize
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
elif
custom_routing_function
is
None
:
elif
custom_routing_function
is
None
:
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
=
hidden_states
,
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
gating_output
=
router_logits
,
...
@@ -926,7 +947,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -926,7 +947,7 @@ class FusedMoE(torch.nn.Module):
e_score_correction_bias
=
self
.
e_score_correction_bias
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
activation
=
self
.
activation
,
activation
=
self
.
activation
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
use_nn_moe
=
self
.
use_nn_moe
,
use_nn_moe
=
self
.
use_nn_moe
)
)
if
self
.
dp_size
>
1
:
if
self
.
dp_size
>
1
:
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
0b229519
...
@@ -142,7 +142,8 @@ class DeepseekV2MoE(nn.Module):
...
@@ -142,7 +142,8 @@ class DeepseekV2MoE(nn.Module):
topk_group
=
config
.
topk_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
scoring_func
=
config
.
scoring_func
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,)
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
routed_scaling_factor
=
self
.
routed_scaling_factor
)
if
config
.
n_shared_experts
is
not
None
:
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
intermediate_size
=
(
config
.
moe_intermediate_size
*
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment