Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3e887d2e
Unverified
Commit
3e887d2e
authored
May 03, 2025
by
Caleb_Du
Committed by
GitHub
May 02, 2025
Browse files
permute/unpermute kernel for moe optimization (#14568)
Signed-off-by:
Caleb_Du
<
Caleb_Du@zju.edu.cn
>
parent
0f87d8f7
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
1474 additions
and
28 deletions
+1474
-28
CMakeLists.txt
CMakeLists.txt
+13
-1
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
+2
-1
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+2
-2
benchmarks/kernels/benchmark_moe_permute_unpermute.py
benchmarks/kernels/benchmark_moe_permute_unpermute.py
+349
-0
csrc/moe/moe_permute_unpermute_op.cu
csrc/moe/moe_permute_unpermute_op.cu
+133
-0
csrc/moe/permute_unpermute_kernels/dispatch.h
csrc/moe/permute_unpermute_kernels/dispatch.h
+53
-0
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
...permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
+229
-0
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
.../permute_unpermute_kernels/moe_permute_unpermute_kernel.h
+95
-0
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
...ermute_unpermute_kernels/moe_permute_unpermute_kernel.inl
+211
-0
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+22
-0
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+2
-1
tests/kernels/moe/test_moe_permute_unpermute.py
tests/kernels/moe/test_moe_permute_unpermute.py
+223
-0
tests/kernels/quantization/test_awq_marlin.py
tests/kernels/quantization/test_awq_marlin.py
+2
-1
tests/kernels/quantization/test_block_fp8.py
tests/kernels/quantization/test_block_fp8.py
+4
-2
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+2
-2
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+9
-10
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+5
-4
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
.../model_executor/layers/fused_moe/moe_permute_unpermute.py
+116
-0
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+2
-4
No files found.
CMakeLists.txt
View file @
3e887d2e
...
@@ -15,7 +15,6 @@ project(vllm_extensions LANGUAGES CXX)
...
@@ -15,7 +15,6 @@ project(vllm_extensions LANGUAGES CXX)
# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
set
(
VLLM_TARGET_DEVICE
"cuda"
CACHE STRING
"Target device backend for vLLM"
)
set
(
VLLM_TARGET_DEVICE
"cuda"
CACHE STRING
"Target device backend for vLLM"
)
message
(
STATUS
"Build type:
${
CMAKE_BUILD_TYPE
}
"
)
message
(
STATUS
"Build type:
${
CMAKE_BUILD_TYPE
}
"
)
message
(
STATUS
"Target device:
${
VLLM_TARGET_DEVICE
}
"
)
message
(
STATUS
"Target device:
${
VLLM_TARGET_DEVICE
}
"
)
...
@@ -682,6 +681,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -682,6 +681,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif
()
endif
()
endif
()
endif
()
if(VLLM_GPU_LANG STREQUAL "CUDA")
  # Sources of the MoE permute/unpermute kernels; only built for CUDA.
  set(MOE_PERMUTE_SRC
      "csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu"
      "csrc/moe/moe_permute_unpermute_op.cu")

  # Apply the gencode flags to the list defined above. The previous code
  # passed MARLIN_PERMUTE_SRC here, which is a different variable name from
  # the one this block sets, so these two files were not the ones flagged.
  # NOTE(review): MOE_PERMUTE_ARCHS is not set in the visible part of this
  # file -- confirm it is defined before this point.
  set_gencode_flags_for_srcs(
    SRCS "${MOE_PERMUTE_SRC}"
    CUDA_ARCHS "${MOE_PERMUTE_ARCHS}")

  list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}")
endif()
message
(
STATUS
"Enabling moe extension."
)
message
(
STATUS
"Enabling moe extension."
)
define_gpu_extension_target
(
define_gpu_extension_target
(
_moe_C
_moe_C
...
@@ -690,6 +700,8 @@ define_gpu_extension_target(
...
@@ -690,6 +700,8 @@ define_gpu_extension_target(
SOURCES
${
VLLM_MOE_EXT_SRC
}
SOURCES
${
VLLM_MOE_EXT_SRC
}
COMPILE_FLAGS
${
VLLM_GPU_FLAGS
}
COMPILE_FLAGS
${
VLLM_GPU_FLAGS
}
ARCHITECTURES
${
VLLM_GPU_ARCHES
}
ARCHITECTURES
${
VLLM_GPU_ARCHES
}
INCLUDE_DIRECTORIES
${
CUTLASS_INCLUDE_DIR
}
INCLUDE_DIRECTORIES
${
CUTLASS_TOOLS_UTIL_INCLUDE_DIR
}
USE_SABI 3
USE_SABI 3
WITH_SOABI
)
WITH_SOABI
)
...
...
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
View file @
3e887d2e
...
@@ -90,7 +90,8 @@ def bench_run(results: list[benchmark.Measurement], model: str,
...
@@ -90,7 +90,8 @@ def bench_run(results: list[benchmark.Measurement], model: str,
score
=
torch
.
randn
((
m
,
num_experts
),
device
=
"cuda"
,
dtype
=
dtype
)
score
=
torch
.
randn
((
m
,
num_experts
),
device
=
"cuda"
,
dtype
=
dtype
)
topk_weights
,
topk_ids
=
fused_topk
(
a
,
score
,
topk
,
renormalize
=
False
)
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
a
,
score
,
topk
,
renormalize
=
False
)
def
run_triton_moe
(
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
def
run_triton_moe
(
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
...
...
benchmarks/kernels/benchmark_moe.py
View file @
3e887d2e
...
@@ -115,8 +115,8 @@ def benchmark_config(config: BenchmarkConfig,
...
@@ -115,8 +115,8 @@ def benchmark_config(config: BenchmarkConfig,
from
vllm.model_executor.layers.fused_moe
import
override_config
from
vllm.model_executor.layers.fused_moe
import
override_config
with
override_config
(
config
):
with
override_config
(
config
):
if
use_deep_gemm
:
if
use_deep_gemm
:
topk_weights
,
topk_ids
=
fused_topk
(
x
,
input_gating
,
topk
,
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_
topk
(
False
)
x
,
input_gating
,
topk
,
False
)
return
fused_experts
(
return
fused_experts
(
x
,
x
,
w1
,
w1
,
...
...
benchmarks/kernels/benchmark_moe_permute_unpermute.py
0 → 100644
View file @
3e887d2e
# SPDX-License-Identifier: Apache-2.0
import
argparse
from
typing
import
Any
,
TypedDict
import
ray
import
torch
from
transformers
import
AutoConfig
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
_moe_permute
,
_moe_unpermute_and_reduce
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
*
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_quantize
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
class BenchmarkConfig(TypedDict):
    """Typed dictionary of integer kernel-configuration fields."""

    # Tile / grouping sizes.
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    # Launch parameters.
    num_warps: int
    num_stages: int
def benchmark_permute(num_tokens: int,
                      num_experts: int,
                      hidden_size: int,
                      topk: int,
                      dtype: torch.dtype,
                      use_fp8_w8a8: bool,
                      use_int8_w8a16: bool,
                      num_iters: int = 100,
                      use_customized_permute: bool = False) -> float:
    """Time the MoE permute step and return the mean latency in microseconds.

    Ten permute invocations are captured in one CUDA graph; the graph is
    replayed ``num_iters`` times and the summed event time is divided by
    ``num_iters * 10``.

    Args:
        num_tokens: Number of rows of the random hidden-state tensor.
        num_experts: Number of experts (columns of the gating tensor).
        hidden_size: Number of columns of the hidden-state tensor.
        topk: Passed to ``fused_topk`` / ``moe_permute``.
        dtype: dtype of the hidden-state tensor.
        use_fp8_w8a8: Quantize the input with ``_fp8_quantize`` and request
            a 128-row alignment.
        use_int8_w8a16: Not read by this function.
        num_iters: Number of timed graph replays.
        use_customized_permute: Time ``moe_permute`` instead of
            ``_moe_permute``.
    """
    calls_per_graph = 10
    warmup_replays = 5

    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    # deepgemm needs 128 m aligned block
    align_block_size = 128 if use_fp8_w8a8 else None
    if use_fp8_w8a8:
        qhidden_states, _ = _fp8_quantize(hidden_states, None, None)
    else:
        qhidden_states = hidden_states

    gating_output = torch.randn(num_iters,
                                num_tokens,
                                num_experts,
                                dtype=torch.float32)
    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
    topk_weights, topk_ids, token_expert_indices = fused_topk(
        qhidden_states, input_gating, topk, False)

    def run():
        # Results are unpacked (checking the tuple arity) and then dropped.
        if not use_customized_permute:
            (permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
             inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
                                      num_experts, None, align_block_size)
            return
        (permuted_hidden_states, first_token_off, inv_perm_idx,
         m_indices) = moe_permute(
             qhidden_states,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             token_expert_indices=token_expert_indices,
             topk=topk,
             n_expert=num_experts,
             n_local_expert=num_experts,
             expert_map=None,
             align_block_size=align_block_size,
         )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture the invocations with a CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(calls_per_graph):
            run()
    torch.cuda.synchronize()

    # Warmup replays
    for _ in range(warmup_replays):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    total_ms = 0
    for step in range(num_iters):
        # Refresh the gating buffer for this iteration.
        input_gating.copy_(gating_output[step])
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        total_ms += start_event.elapsed_time(end_event)

    avg = total_ms / (num_iters * calls_per_graph) * 1000  # us
    graph.reset()
    return avg
def benchmark_unpermute(num_tokens: int,
                        num_experts: int,
                        hidden_size: int,
                        topk: int,
                        dtype: torch.dtype,
                        use_fp8_w8a8: bool,
                        use_int8_w8a16: bool,
                        num_iters: int = 100,
                        use_customized_permute: bool = False) -> float:
    """Time the MoE unpermute step and return the mean latency in microseconds.

    The permute step runs once, outside the timed region, to build the
    unpermute inputs. Ten unpermute invocations are captured in one CUDA
    graph; the graph is replayed ``num_iters`` times and the summed event
    time is divided by ``num_iters * 10``.

    Args:
        num_tokens: Number of rows of the random hidden-state tensor.
        num_experts: Number of experts (columns of the gating tensor).
        hidden_size: Number of columns of the hidden-state tensor.
        topk: Passed to ``fused_topk`` and the permute/unpermute calls.
        dtype: dtype of the hidden states and of the permuted tensor handed
            to the unpermute call.
        use_fp8_w8a8: Quantize the input with ``_fp8_quantize`` and request
            a 128-row alignment.
        use_int8_w8a16: Not read by this function.
        num_iters: Number of timed graph replays.
        use_customized_permute: Time ``moe_unpermute`` instead of
            ``_moe_unpermute_and_reduce``.
    """
    calls_per_graph = 10
    warmup_replays = 5

    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    output_hidden_states = torch.empty_like(hidden_states)
    # deepgemm needs 128 m aligned block
    align_block_size = 128 if use_fp8_w8a8 else None
    if use_fp8_w8a8:
        qhidden_states, _ = _fp8_quantize(hidden_states, None, None)
    else:
        qhidden_states = hidden_states

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
    topk_weights, topk_ids, token_expert_indices = fused_topk(
        qhidden_states, input_gating, topk, False)

    def prepare():
        if not use_customized_permute:
            (permuted_qhidden_states, a1q_scale, sorted_token_ids, expert_ids,
             inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
                                      num_experts, None, align_block_size)
            # convert to fp16/bf16 as gemm output
            return (permuted_qhidden_states.to(dtype), a1q_scale,
                    sorted_token_ids, expert_ids, inv_perm)

        (permuted_hidden_states, first_token_off, inv_perm_idx,
         m_indices) = moe_permute(
             qhidden_states,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             token_expert_indices=token_expert_indices,
             topk=topk,
             n_expert=num_experts,
             n_local_expert=num_experts,
             expert_map=None,
             align_block_size=align_block_size,
         )
        # convert to fp16/bf16 as gemm output
        return (permuted_hidden_states.to(dtype), first_token_off,
                inv_perm_idx, m_indices)

    def run(permuted: tuple):
        if use_customized_permute:
            (permuted_hidden_states, first_token_off, inv_perm_idx,
             m_indices) = permuted
            moe_unpermute(permuted_hidden_states, topk_weights, topk_ids,
                          inv_perm_idx, first_token_off, topk, num_experts,
                          num_experts)
        else:
            (permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
             inv_perm) = permuted
            _moe_unpermute_and_reduce(output_hidden_states,
                                      permuted_hidden_states, inv_perm,
                                      topk_weights)

    # JIT compilation & warmup
    permute_outputs = prepare()
    run(permute_outputs)
    torch.cuda.synchronize()

    # Capture the invocations with a CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(calls_per_graph):
            run(permute_outputs)
    torch.cuda.synchronize()

    # Warmup replays
    for _ in range(warmup_replays):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    total_ms = 0
    for _ in range(num_iters):
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        total_ms += start_event.elapsed_time(end_event)

    avg = total_ms / (num_iters * calls_per_graph) * 1000  # us
    graph.reset()
    return avg
@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor (one GPU each) that runs the permute/unpermute benchmarks."""

    def __init__(self, seed: int) -> None:
        """Make CUDA the default torch device and seed the RNGs with ``seed``."""
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        use_customized_permute: bool = False,
    ) -> tuple[float, float]:
        """Run both benchmarks for one problem size.

        The RNGs are re-seeded with the worker's seed before measuring.

        Returns:
            ``(permute_time, unpermute_time)``: the values returned by
            ``benchmark_permute`` and ``benchmark_unpermute`` (both floats;
            the previous annotation wrongly declared a dict as the first
            element).
        """
        current_platform.seed_everything(self.seed)

        permute_time = benchmark_permute(
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            use_customized_permute=use_customized_permute)
        unpermute_time = benchmark_unpermute(
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            use_customized_permute=use_customized_permute)
        return permute_time, unpermute_time
def get_weight_block_size_safety(config, default_value=None):
    """Return ``config.quantization_config['weight_block_size']`` if available.

    ``default_value`` is returned when ``config`` has a
    ``quantization_config`` attribute that is not a dict, or when the dict
    (an empty one if the attribute is missing) has no
    ``'weight_block_size'`` key.
    """
    quant_cfg = getattr(config, 'quantization_config', {})
    if not isinstance(quant_cfg, dict):
        return default_value
    return quant_cfg.get('weight_block_size', default_value)
def main(args: argparse.Namespace):
    """Benchmark permute/unpermute for the model named in ``args`` and print timings.

    Reads the expert count, top-k and hidden size from the model config,
    spreads one benchmark call per batch size over one Ray worker per
    available GPU (round-robin), and prints the results per batch size.
    """
    print(args)

    config = AutoConfig.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code)
    arch = config.architectures[0]
    if arch == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
    elif arch == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
    elif arch in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
    elif arch in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
        E = config.num_experts
        topk = config.num_experts_per_tok
    else:
        # Support for llama4
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok

    hidden_size = config.hidden_size
    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    use_customized_permute = args.use_customized_permute

    if args.batch_size is not None:
        batch_sizes = [args.batch_size]
    else:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
        ]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        # Submit every call to the next worker in turn, then wait for all.
        futures = []
        slot = 0
        for call_args in inputs:
            remote_method = getattr(workers[slot], method)
            futures.append(remote_method.remote(*call_args))
            slot = (slot + 1) % num_gpus
        return ray.get(futures)

    job_args = [(batch_size, E, hidden_size, topk, dtype, use_fp8_w8a8,
                 use_int8_w8a16, use_customized_permute)
                for batch_size in batch_sizes]
    outputs = _distribute("benchmark", job_args)

    for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
        print(f"Batch size: {batch_size}")
        print(f"Permute time: {permute:.2f} us")
        print(f"Unpermute time: {unpermute:.2f} us")
if __name__ == "__main__":
    # Command-line entry point: build the parser, then hand off to main().
    parser = FlexibleArgumentParser()
    parser.add_argument("--model",
                        type=str,
                        default="mistralai/Mixtral-8x7B-Instruct-v0.1")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["auto", "fp8_w8a8", "int8_w8a16"],
                        default="auto")
    parser.add_argument("--use-customized-permute", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--trust-remote-code", action="store_true")

    main(parser.parse_args())
csrc/moe/moe_permute_unpermute_op.cu
0 → 100644
View file @
3e887d2e
#include <c10/core/ScalarType.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "permute_unpermute_kernels/moe_permute_unpermute_kernel.h"
#include "permute_unpermute_kernels/dispatch.h"
#include "core/registration.h"
// Permute the rows of `input` so that rows routed to the same expert become
// contiguous, and fill the bookkeeping tensors needed to undo the permutation.
//
// Outputs written by this function: permuted_input,
// expert_first_token_offset, src_row_id2dst_row_id_map and m_indices.
// topk_ids is rewritten in place when expert_map is provided.
// All kernels are enqueued on the current CUDA stream.
void moe_permute(
    const torch::Tensor& input,                      // [n_token, hidden]
    const torch::Tensor& topk_weights,               // [n_token, topk]
    torch::Tensor& topk_ids,                         // [n_token, topk]
    const torch::Tensor& token_expert_indicies,      // [n_token, topk]
    const std::optional<torch::Tensor>& expert_map,  // [n_expert]
    int64_t n_expert, int64_t n_local_expert, int64_t topk,
    const std::optional<int64_t>& align_block_size,
    torch::Tensor&
        permuted_input,  // [topk * n_token/align_block_size_m, hidden]
    torch::Tensor& expert_first_token_offset,  // [n_local_expert + 1]
    torch::Tensor& src_row_id2dst_row_id_map,  // [n_token, topk]
    torch::Tensor& m_indices) {                // [align_expand_m]
  // Validate dtypes and shapes before any pointer is reinterpreted.
  TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
              "topk_weights must be float32");
  TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
              "expert_first_token_offset must be int64");
  TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
              "topk_ids must be int32");
  TORCH_CHECK(token_expert_indicies.scalar_type() == at::ScalarType::Int,
              "token_expert_indicies must be int32");
  TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int,
              "src_row_id2dst_row_id_map must be int32");
  TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
              "expert_first_token_offset shape != n_local_expert+1");
  TORCH_CHECK(
      src_row_id2dst_row_id_map.sizes() == token_expert_indicies.sizes(),
      "token_expert_indicies shape must be same as src_row_id2dst_row_id_map");

  auto n_token = input.sizes()[0];
  auto n_hidden = input.sizes()[1];
  // -1 is the "no alignment requested" sentinel understood by the launchers.
  auto align_block_size_value =
      align_block_size.has_value() ? align_block_size.value() : -1;
  auto stream = at::cuda::getCurrentCUDAStream().stream();

  // Scratch buffers: radix-sort workspace plus sort outputs.
  const long sorter_size =
      CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert);
  auto sort_workspace = torch::empty(
      {sorter_size},
      torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
  auto permuted_experts_id = torch::empty_like(topk_ids);
  auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map);
  auto align_expert_first_token_offset =
      torch::zeros_like(expert_first_token_offset);

  CubKeyValueSorter sorter{};
  int64_t* valid_num_ptr = nullptr;
  // pre-process kernel for expert-parallelism:
  // no local expert id plus "n_expert" offset for priority to local expert
  // map local expert id [n, .., n+n_local_expert-1] to [0, n_local_expert -1]
  // For example, 4 expert with ep_size=2. ep_rank=1 owns global expert id
  // [2,3] with expert_map[-1, -1, 0, 1], preprocess_topk_id process topk_ids
  // and map global expert id [2, 3] to local_expert id [0, 1] and map global
  // expert id [0, 1] ( not in ep rank=1) to [4, 5] by plus n_expert. This map
  // operation is to make local expert high priority in following sort topk_ids
  // and scan local expert_first_token_offset for each ep rank for next group
  // gemm.
  if (expert_map.has_value()) {
    const int* expert_map_ptr = get_ptr<int>(expert_map.value());
    // Points at the last entry of expert_first_token_offset.
    valid_num_ptr =
        get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
    preprocessTopkIdLauncher(get_ptr<int>(topk_ids), n_token * topk,
                             expert_map_ptr, n_expert, stream);
  }
  // expert sort topk expert id and scan expert id get expert_first_token_offset
  sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indicies),
                    get_ptr<int>(permuted_experts_id),
                    get_ptr<int>(dst_row_id2src_row_id_map),
                    get_ptr<int64_t>(expert_first_token_offset), n_token,
                    n_expert, n_local_expert, topk, sorter,
                    get_ptr<int>(sort_workspace), stream);

  // dispatch expandInputRowsKernelLauncher
  MOE_DISPATCH(input.scalar_type(), [&] {
    expandInputRowsKernelLauncher<scalar_t>(
        get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
        get_ptr<float>(topk_weights), get_ptr<int>(permuted_experts_id),
        get_ptr<int>(dst_row_id2src_row_id_map),
        get_ptr<int>(src_row_id2dst_row_id_map),
        get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
        n_hidden, topk, n_local_expert, align_block_size_value, stream);
  });

  // get m_indices and update expert_first_token_offset with align block
  getMIndices(get_ptr<int64_t>(expert_first_token_offset),
              get_ptr<int64_t>(align_expert_first_token_offset),
              get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
              stream);
  if (align_block_size.has_value()) {
    // update align_expert_first_token_offset
    expert_first_token_offset.copy_(align_expert_first_token_offset);
  }
}
// Undo moe_permute: write the result into `hidden_states` by launching
// finalizeMoeRoutingKernelLauncher on the current CUDA stream with the
// permuted rows, the top-k weights and the source->destination row map.
void moe_unpermute(
    const torch::Tensor& permuted_hidden_states,     // [n_token * topk, hidden]
    const torch::Tensor& topk_weights,               // [n_token, topk]
    const torch::Tensor& topk_ids,                   // [n_token, topk]
    const torch::Tensor& src_row_id2dst_row_id_map,  // [n_token, topk]
    const torch::Tensor& expert_first_token_offset,  // [n_local_expert+1]
    int64_t n_expert, int64_t n_local_expert, int64_t topk,
    torch::Tensor& hidden_states  // [n_token, hidden]
) {
  TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(),
              "topk_ids shape must be same as src_row_id2dst_row_id_map");
  TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
              "topk_ids must be int32");
  // The message now names the tensors that are actually compared; it used to
  // talk about topk_ids / src_row_id2dst_row_id_map.
  TORCH_CHECK(
      permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
      "permuted_hidden_states dtype must be same as hidden_states");

  auto n_token = hidden_states.size(0);
  auto n_hidden = hidden_states.size(1);
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  // Points at the last entry of expert_first_token_offset.
  const int64_t* valid_ptr =
      get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;

  MOE_DISPATCH(hidden_states.scalar_type(), [&] {
    finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
        get_ptr<scalar_t>(permuted_hidden_states),
        get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
        get_ptr<int>(src_row_id2dst_row_id_map), get_ptr<int>(topk_ids),
        n_token, n_hidden, topk, valid_ptr, stream);
  });
}
// Bind the CUDA implementations of the "moe_permute" and "moe_unpermute"
// ops for the CUDA dispatch key.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("moe_permute", &moe_permute);
  m.impl("moe_unpermute", &moe_unpermute);
}
\ No newline at end of file
csrc/moe/permute_unpermute_kernels/dispatch.h
0 → 100644
View file @
3e887d2e
#pragma once
#include <cuda_fp8.h>
// Switch over the at::ScalarType of TYPE and run the matching case.
// The body is wrapped in its own block so that the helper variable `_st`
// does not leak into the caller's scope (two dispatches in one scope would
// otherwise redeclare it). Unsupported types raise via TORCH_CHECK.
#define MOE_SWITCH(TYPE, ...)                                        \
  {                                                                  \
    at::ScalarType _st = ::detail::scalar_type(TYPE);                \
    switch (_st) {                                                   \
      __VA_ARGS__                                                    \
      default:                                                       \
        TORCH_CHECK(false, "[moe permute]data type dispatch fail!"); \
    }                                                                \
  }

// One `case` of the switch: defines `scalar_t` as the CUDA type mapped from
// `enum_type` and invokes the callable passed as the remaining arguments.
#define MOE_DISPATCH_CASE(enum_type, ...)                  \
  case enum_type: {                                        \
    using scalar_t = ScalarType2CudaType<enum_type>::type; \
    __VA_ARGS__();                                         \
    break;                                                 \
  }

// Cases for every floating-point type handled by the permute kernels.
#define MOE_DISPATCH_FLOAT_CASE(...)                           \
  MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)        \
  MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)         \
  MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)     \
  MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)  \
  MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)

// Entry point: MOE_DISPATCH(tensor.scalar_type(), [&] { ... scalar_t ... });
#define MOE_DISPATCH(TYPE, ...) \
  MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__))
// Compile-time map from an at::ScalarType enumerator to the CUDA element
// type used by the kernels. Only the specializations below are defined.
template <at::ScalarType type>
struct ScalarType2CudaType;

// float32
template <>
struct ScalarType2CudaType<at::ScalarType::Float> {
  using type = float;
};

// float16
template <>
struct ScalarType2CudaType<at::ScalarType::Half> {
  using type = half;
};

// bfloat16
template <>
struct ScalarType2CudaType<at::ScalarType::BFloat16> {
  using type = __nv_bfloat16;
};

// #if __CUDA_ARCH__ >= 890
// fp8
template <>
struct ScalarType2CudaType<at::ScalarType::Float8_e5m2> {
  using type = __nv_fp8_e5m2;
};

template <>
struct ScalarType2CudaType<at::ScalarType::Float8_e4m3fn> {
  using type = __nv_fp8_e4m3;
};
// #endif
\ No newline at end of file
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
0 → 100644
View file @
3e887d2e
#include "moe_permute_unpermute_kernel.h"
// CubKeyValueSorter definition begin
// Default construction: no experts configured yet, so sort on every bit of
// an int key.
CubKeyValueSorter::CubKeyValueSorter()
    : num_experts_(0), num_bits_(sizeof(int) * 8) {}
// Number of key bits the radix sort has to look at for `num_experts`.
int CubKeyValueSorter::expertsToBits(int num_experts) {
  // Max value we represent is V = num_experts + (num_experts - 1) = 2 *
  // num_experts - 1. The maximum number of bits is therefore
  // floor(log2(V)) + 1.
  int const largest_key = 2 * num_experts - 1;
  int const floor_log2 = static_cast<int>(log2(largest_key));
  return floor_log2 + 1;
}
// Construct a sorter sized for `num_experts`; the sort bit width is derived
// with expertsToBits().
CubKeyValueSorter::CubKeyValueSorter(int const num_experts)
    : num_experts_(num_experts), num_bits_(expertsToBits(num_experts)) {}
// Reconfigure the sorter for a new expert count and recompute the number of
// key bits to sort on.
void CubKeyValueSorter::updateNumExperts(int const num_experts) {
  num_bits_ = expertsToBits(num_experts);
  num_experts_ = num_experts;
}
// Query CUB for the temporary storage (in bytes) needed to radix-sort
// `num_key_value_pairs` int key/value pairs for `num_experts` experts.
// The result is never 0.
size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs,
                                           int const num_experts) {
  int const sort_bits = expertsToBits(num_experts);
  size_t storage_bytes = 0;
  int* no_data = nullptr;
  // A null workspace pointer makes SortPairs only report the required size.
  cub::DeviceRadixSort::SortPairs(nullptr, storage_bytes, no_data, no_data,
                                  no_data, no_data, num_key_value_pairs, 0,
                                  sort_bits);

  // when num_key_value_pairs, num_experts, num_bits, required_storage = 64,
  // 4, 3, 0 The required_storage seems to vary between 0 and 1 for the same
  // inputs
  return storage_bytes == 0 ? size_t{1} : storage_bytes;
}
// Sort (keys_in, values_in) by key into (keys_out, values_out) on `stream`,
// looking only at the low num_bits_ bits of each key. Raises if the supplied
// workspace is smaller than getWorkspaceSize() reports for this problem.
void CubKeyValueSorter::run(void* workspace, size_t const workspace_size,
                            int const* keys_in, int* keys_out,
                            int const* values_in, int* values_out,
                            size_t const num_key_value_pairs,
                            cudaStream_t stream) {
  size_t const needed_bytes =
      getWorkspaceSize(num_key_value_pairs, num_experts_);
  TORCH_CHECK(needed_bytes <= workspace_size,
              "[CubKeyValueSorter::run] The allocated workspace is too small "
              "to run this problem.");

  // SortPairs takes the size by non-const reference, so pass a mutable copy.
  size_t provided_bytes = workspace_size;
  cub::DeviceRadixSort::SortPairs(workspace, provided_bytes, keys_in, keys_out,
                                  values_in, values_out, num_key_value_pairs, 0,
                                  num_bits_, stream);
}
// CubKeyValueSorter definition end
// Round `input` up to the next multiple of 16 (values that are already a
// multiple of 16 are returned unchanged).
static inline size_t pad_to_multiple_of_16(size_t const& input) {
  static constexpr int ALIGNMENT = 16;
  size_t const full_chunks = (input + ALIGNMENT - 1) / ALIGNMENT;
  return full_chunks * ALIGNMENT;
}
// Binary search over the first `arr_length` entries of `sorted_indices`.
// Returns one past the index of the last probed element that compared
// strictly less than `target` (0 if none did); for an ascending array that
// is the number of elements < target.
template <class T>
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
                                                      int64_t const arr_length,
                                                      T const target) {
  int64_t lo = 0;
  int64_t hi = arr_length - 1;
  // Invariant: `lo` is always (index of last element seen < target) + 1.
  while (lo <= hi) {
    int64_t const probe = (lo + hi) / 2;
    if (sorted_indices[probe] < target) {
      lo = probe + 1;
    } else {
      hi = probe - 1;
    }
  }
  return lo;
}
// Calculates the start offset of the tokens for a given expert. The last
// element is the total number of valid tokens.
// One thread handles one entry of expert_first_token_offset.
__global__ void computeExpertFirstTokenOffsetKernel(
    int const* sorted_experts, int64_t const sorted_experts_len,
    int const num_experts, int64_t* expert_first_token_offset) {
  // Global thread id doubles as the expert id.
  int const expert = blockIdx.x * blockDim.x + threadIdx.x;

  // Note that expert goes [0, num_experts] (inclusive) because we want a count
  // for the total number of active tokens at the end of the scan.
  if (expert <= num_experts) {
    expert_first_token_offset[expert] = findTotalEltsLessThanTarget(
        sorted_experts, sorted_experts_len, expert);
  }
}
// Host launcher for computeExpertFirstTokenOffsetKernel: one thread per
// output entry (num_experts + 1 entries), at most 1024 threads per block.
void computeExpertFirstTokenOffset(int const* sorted_indices,
                                   int const total_indices,
                                   int const num_experts,
                                   int64_t* expert_first_token_offset,
                                   cudaStream_t stream) {
  int const entries = num_experts + 1;
  int const block_dim = std::min(1024, entries);
  int const grid_dim = (entries + block_dim - 1) / block_dim;

  computeExpertFirstTokenOffsetKernel<<<grid_dim, block_dim, 0, stream>>>(
      sorted_indices, total_indices, num_experts, expert_first_token_offset);
}
// Sort the (expert id, source row) pairs by expert id, then compute the
// first-token offset of each of the `num_experts_per_node` experts from the
// sorted expert ids. Both steps are enqueued on `stream`.
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
                       int* permuted_experts, int* permuted_rows,
                       int64_t* expert_first_token_offset, int num_rows,
                       int num_experts, int num_experts_per_node, int k,
                       CubKeyValueSorter& sorter, void* sorter_ws,
                       cudaStream_t stream) {
  // k entries per input row; widen before multiplying.
  int64_t const total_pairs = static_cast<int64_t>(k) * num_rows;

  // We need to use the full num_experts because that is the sentinel value used
  // by topk for disabled experts
  sorter.updateNumExperts(num_experts);
  size_t const ws_bytes = pad_to_multiple_of_16(
      sorter.getWorkspaceSize(total_pairs, num_experts));

  sorter.run((void*)sorter_ws, ws_bytes, expert_for_source_row,
             permuted_experts, source_rows, permuted_rows, total_pairs,
             stream);

  computeExpertFirstTokenOffset(permuted_experts, total_pairs,
                                num_experts_per_node,
                                expert_first_token_offset, stream);
}
// Rewrite each of the `size` entries of topk_id_ptr in place using the
// expert map:
//   - expert_map[id] == -1 : id += num_experts
//   - otherwise            : id  = expert_map[id]
// One thread handles one entry. Requires num_experts * sizeof(int) bytes of
// dynamic shared memory.
__global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size,
                                       const int* expert_map_ptr,
                                       int num_experts) {
  auto tidx = threadIdx.x;
  auto offset = blockIdx.x * blockDim.x;
  auto bound = min(offset + blockDim.x, size);

  extern __shared__ int smem_expert_map[];
  // store expert_map in smem
  for (int i = tidx; i < num_experts; i += blockDim.x) {
    smem_expert_map[i] = expert_map_ptr[i];
  }
  __syncthreads();

  // query global expert id in expert map.
  // if global expert id = -1 in expert map, plus n_expert
  // else set global expert id = expert map[global expert id]
  if (offset + tidx < bound) {
    auto topk_id = topk_id_ptr[offset + tidx];
    auto local_expert_idx = smem_expert_map[topk_id];
    if (local_expert_idx == -1) {
      topk_id += num_experts;
    } else {
      topk_id = local_expert_idx;
    }
    // Each thread only touches its own element, so no warp-level
    // synchronisation is needed before the store.
    topk_id_ptr[offset + tidx] = topk_id;
  }
}
// Host launcher for preprocessTopkIdKernel: one thread per top-k id, at most
// 1024 threads per block, with the expert map staged in dynamic shared
// memory. Does nothing when there are no ids to process.
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
                              const int* expert_map_ptr, int num_experts,
                              cudaStream_t stream) {
  // With size == 0 the block size would be 0 and the grid computation below
  // would divide by zero; there is nothing to launch in that case.
  if (size <= 0) {
    return;
  }
  int block = std::min(size, 1024);
  int grid = (size + block - 1) / block;
  int smem_size = (num_experts) * sizeof(int);
  preprocessTopkIdKernel<<<grid, block, smem_size, stream>>>(
      topk_id_ptr, size, expert_map_ptr, num_experts);
}
// One block per local expert (blockIdx.x == expert id). Fills
// m_indices[first .. first + n) with the expert id, where [first, first + n)
// is the expert's row range.
//
// ALIGN_BLOCK_SIZE == false: the range comes straight from
//   expert_first_token_offset.
// ALIGN_BLOCK_SIZE == true: every expert's row count is rounded up to a
//   multiple of align_block_size, the ranges are re-accumulated, and the
//   last block (thread 0) also writes the aligned offsets to
//   align_expert_first_token_offset[1..num_local_expert].
//
// Requires (num_local_expert + 1) * sizeof(int64_t) bytes of dynamic shared
// memory.
template <bool ALIGN_BLOCK_SIZE>
__global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
                                  int64_t* align_expert_first_token_offset,
                                  int* m_indices, const int num_local_expert,
                                  const int align_block_size) {
  int eidx = blockIdx.x;
  int tidx = threadIdx.x;
  extern __shared__ int64_t smem_expert_first_token_offset[];
  // Stage all offsets in shared memory. The destination index must be `i`
  // (not `tidx`): with more than blockDim.x entries each thread loads several
  // of them, and indexing by `tidx` would overwrite the same slot and leave
  // the tail of the array uninitialised.
  for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
    smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i);
  }
  __syncthreads();
  auto last_token_offset = smem_expert_first_token_offset[eidx + 1];
  auto first_token_offset = smem_expert_first_token_offset[eidx];
  int n_token_in_expert = last_token_offset - first_token_offset;

  if constexpr (ALIGN_BLOCK_SIZE) {
    // round up to ALIGN_BLOCK_SIZE
    n_token_in_expert = (n_token_in_expert + align_block_size - 1) /
                        align_block_size * align_block_size;
    int64_t accumulate_align_offset = 0;
    for (int i = 1; i <= eidx + 1; i++) {
      int n_token = smem_expert_first_token_offset[i] -
                    smem_expert_first_token_offset[i - 1];
      accumulate_align_offset =
          accumulate_align_offset + (n_token + align_block_size - 1) /
                                        align_block_size * align_block_size;
      // Aligned end of expert i-1 == aligned start of expert i.
      if (i == eidx) {
        first_token_offset = accumulate_align_offset;
      }
      // last block store align_expert_first_token_offset
      if (eidx == num_local_expert - 1 && threadIdx.x == 0) {
        align_expert_first_token_offset[i] = accumulate_align_offset;
      }
    }
  }
  for (int idx = tidx; idx < n_token_in_expert; idx += blockDim.x) {
    // update m_indice with expert id
    m_indices[first_token_offset + idx] = eidx;
  }
}
// Host-side launcher for getMIndicesKernel.
//
// Launches one 256-thread block per local expert with
// (num_local_expert + 1) * sizeof(int64_t) bytes of dynamic shared memory.
// align_block_size == -1 selects the unaligned kernel variant; any other
// value selects the aligned variant.
void getMIndices(int64_t* expert_first_token_offset,
                 int64_t* align_expert_first_token_offset, int* m_indices,
                 int num_local_expert, const int align_block_size,
                 cudaStream_t stream) {
  // A grid of zero blocks is an invalid launch configuration, and there is
  // nothing to fill when no expert is local.
  if (num_local_expert <= 0) {
    return;
  }
  int block = 256;
  int grid = num_local_expert;
  int smem_size = sizeof(int64_t) * (num_local_expert + 1);
  if (align_block_size == -1) {
    getMIndicesKernel<false><<<grid, block, smem_size, stream>>>(
        expert_first_token_offset, align_expert_first_token_offset, m_indices,
        num_local_expert, align_block_size);
  } else {
    getMIndicesKernel<true><<<grid, block, smem_size, stream>>>(
        expert_first_token_offset, align_expert_first_token_offset, m_indices,
        num_local_expert, align_block_size);
  }
}
\ No newline at end of file
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
0 → 100644
View file @
3e887d2e
#pragma once
// reference from tensorrt_llm moe kernel implementation archive in
// https://github.com/BBuf/tensorrt-llm-moe/tree/master
#include <c10/core/ScalarType.h>
#include <torch/all.h>
#include "dispatch.h"
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <cub/util_type.cuh>
#include "cutlass/numeric_size.h"
#include "cutlass/array.h"
// Returns the tensor's storage pointer reinterpreted as a mutable T*.
template <typename T>
inline T* get_ptr(torch::Tensor& t) {
  void* raw = t.data_ptr();
  return reinterpret_cast<T*>(raw);
}
// Returns the tensor's storage pointer reinterpreted as a read-only const T*.
template <typename T>
inline const T* get_ptr(const torch::Tensor& t) {
  void* raw = t.data_ptr();
  return reinterpret_cast<const T*>(raw);
}
// Sorts (key, value) pairs of ints using a caller-provided workspace.
// NOTE(review): the member definitions are not part of this header; judging by
// the name and the cub radix-sort include above, the sort is presumably a CUB
// device radix sort limited to the bits needed for num_experts - confirm
// against the .cu file.
class CubKeyValueSorter {
 public:
  CubKeyValueSorter();

  CubKeyValueSorter(int const num_experts);

  // Updates the expert count this sorter is configured for.
  void updateNumExperts(int const num_experts);

  // Workspace size (presumably in bytes - TODO confirm) that run() needs for
  // the given number of pairs and experts.
  static size_t getWorkspaceSize(size_t const num_key_value_pairs,
                                 int const num_experts);

  // Sorts num_key_value_pairs pairs from keys_in/values_in into
  // keys_out/values_out on `stream`, using `workspace` as scratch space.
  void run(void* workspace, size_t const workspace_size, int const* keys_in,
           int* keys_out, int const* values_in, int* values_out,
           size_t const num_key_value_pairs, cudaStream_t stream);

 private:
  // NOTE(review): presumably the number of key bits needed to represent
  // `experts` distinct ids - confirm against the definition.
  static int expertsToBits(int experts);
  int num_experts_;
  int num_bits_;
};
// Declaration only; the definition lives outside this header.
// NOTE(review): presumably derives expert_first_token_offset (one int64 per
// expert boundary) from the expert ids in sorted_indices - confirm against the
// definition.
void computeExpertFirstTokenOffset(int const* sorted_indices,
                                   int const total_indices,
                                   int const num_experts,
                                   int64_t* expert_first_token_offset,
                                   cudaStream_t stream);
// Declaration only; the definition lives outside this header.
// NOTE(review): presumably sorts (expert_for_source_row, source_rows) pairs
// with `sorter` into permuted_experts / permuted_rows and then fills
// expert_first_token_offset - confirm against the definition.
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
                       int* permuted_experts, int* permuted_rows,
                       int64_t* expert_first_token_offset, int num_rows,
                       int num_experts, int num_experts_per_node, int k,
                       CubKeyValueSorter& sorter, void* sorter_ws,
                       cudaStream_t stream);
// Launches expandInputRowsKernel (defined in moe_permute_unpermute_kernel.inl)
// with one thread block per expanded row (num_rows * k blocks).
// A non-null num_valid_tokens_ptr selects the CHECK_SKIPPED kernel variant and
// align_block_size != -1 selects the ALIGN_BLOCK_SIZE variant.
template <typename T>
void expandInputRowsKernelLauncher(
    T const* unpermuted_input, T* permuted_output,
    const float* unpermuted_scales, int* sorted_experts,
    int const* expanded_dest_row_to_expanded_source_row,
    int* expanded_source_row_to_expanded_dest_row,
    int64_t* expert_first_token_offset, int64_t const num_rows,
    int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
    int num_local_experts, const int& align_block_size, cudaStream_t stream);
// Final kernel to unpermute and scale.
// For every original row (one thread block per row) this kernel gathers the k
// permuted expert rows through expanded_source_row_to_expanded_dest_row,
// multiplies each by its entry in `scales`, and writes the float-accumulated
// sum to reduced_unpermuted_output (the k-way reduction).
// With CHECK_SKIPPED, permuted rows at or beyond *num_valid_ptr are skipped.
template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(
    T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
    float const* scales, int const* expanded_source_row_to_expanded_dest_row,
    int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
    int64_t const* num_valid_ptr);
// Launches finalizeMoeRoutingKernel (defined in
// moe_permute_unpermute_kernel.inl) with one thread block per original row
// (num_rows blocks). A non-null num_valid_ptr selects the CHECK_SKIPPED kernel
// variant.
template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher(
    T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
    float const* scales, int const* expanded_source_row_to_expanded_dest_row,
    int const* expert_for_source_row, int64_t const num_rows,
    int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
    cudaStream_t stream);
// Launches preprocessTopkIdKernel over the `size` ids in topk_id_ptr. The
// kernel rewrites each id in place: an id whose expert-map entry is -1 has
// num_experts added to it, any other id is replaced by its expert-map entry.
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
                              const int* expert_map_ptr, int num_experts,
                              cudaStream_t stream);
// Launches getMIndicesKernel with one thread block per local expert to fill
// m_indices with the expert id of each permuted row. align_block_size == -1
// selects the unaligned variant; otherwise per-expert row counts are rounded
// up to align_block_size and the aligned offsets are stored in
// align_expert_first_token_offset.
void getMIndices(int64_t* expert_first_token_offset,
                 int64_t* align_expert_first_token_offset, int* m_indices,
                 int num_local_expert, const int align_block_size,
                 cudaStream_t stream);
#include "moe_permute_unpermute_kernel.inl"
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
0 → 100644
View file @
3e887d2e
#pragma once
// Copies one source row into its permuted (expanded) destination row and
// records the source -> destination row mapping.
//
// Each thread block handles the expanded destination row blockIdx.x. Its
// source row comes from expanded_dest_row_to_expanded_source_row and its
// expert id from sorted_experts.
//
// ALIGN_BLOCK_SIZE: the destination row is recomputed so that every expert's
// rows start at a multiple of align_block_size (see the branch below).
// CHECK_SKIPPED: the row copy is skipped when blockIdx.x >= *num_dest_rows;
// the reverse-map entry is still written.
//
// unpermuted_scales and k are not read by this kernel body.
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
__global__ void expandInputRowsKernel(
    T const* unpermuted_input, T* permuted_output,
    const float* unpermuted_scales, int* sorted_experts,
    int const* expanded_dest_row_to_expanded_source_row,
    int* expanded_source_row_to_expanded_dest_row,
    int64_t* expert_first_token_offset, int64_t const num_rows,
    int64_t const* num_dest_rows, int64_t const cols, int64_t k,
    int num_local_experts, int align_block_size) {
  // Reverse permutation map.
  // I do this so that later, we can use the source -> dest map to do the k-way
  // reduction and unpermuting. I need the reverse map for that reduction to
  // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
  // thread block will be responsible for all k summations.
  int64_t expanded_dest_row = blockIdx.x;
  int64_t const expanded_source_row =
      expanded_dest_row_to_expanded_source_row[expanded_dest_row];
  int expert_id = sorted_experts[expanded_dest_row];

  extern __shared__ int64_t smem_expert_first_token_offset[];
  int64_t align_expanded_row_accumulate = 0;  // not used below
  if constexpr (ALIGN_BLOCK_SIZE) {
    // load g2s
    for (int idx = threadIdx.x; idx < num_local_experts + 1;
         idx += blockDim.x) {
      smem_expert_first_token_offset[idx] =
          __ldg(expert_first_token_offset + idx);
    }
    __syncthreads();
    // Lane 0 of every warp computes the aligned destination row; the other
    // lanes receive it through the shuffle below.
    int lane_idx = threadIdx.x & 31;

    if (lane_idx == 0) {
      // set token_offset_in_expert = 0 if this expert is not local expert
      int token_offset_in_expert =
          expert_id >= num_local_experts
              ? 0
              : expanded_dest_row - smem_expert_first_token_offset[expert_id];
      // Sum of the rounded-up row counts of all local experts that precede
      // expert_id (all local experts when expert_id is not local).
      int64_t accumulate_align_offset = 0;
#pragma unroll 1
      for (int eidx = 1; eidx <= min(expert_id, num_local_experts); eidx++) {
        auto n_token_in_expert = smem_expert_first_token_offset[eidx] -
                                 smem_expert_first_token_offset[eidx - 1];
        accumulate_align_offset += (n_token_in_expert + align_block_size - 1) /
                                   align_block_size * align_block_size;
      }
      expanded_dest_row = accumulate_align_offset + token_offset_in_expert;
    }
    // lane0 shuffle broadcast align_expanded_dest_row
    expanded_dest_row = __shfl_sync(0xffffffff, expanded_dest_row, 0);
  }

  // One thread per block publishes source row -> (possibly aligned) dest row.
  if (threadIdx.x == 0) {
    assert(expanded_dest_row <= INT32_MAX);
    expanded_source_row_to_expanded_dest_row[expanded_source_row] =
        static_cast<int>(expanded_dest_row);
  }

  // Note: the skip test uses blockIdx.x, i.e. the unaligned destination row.
  if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
    // Load 128-bits per thread
    constexpr int64_t ELEM_PER_THREAD = 128 / cutlass::sizeof_bits<T>::value;
    using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;

    // Duplicate and permute rows
    // source_k_rank is computed but not used below.
    int64_t const source_k_rank = expanded_source_row / num_rows;
    int64_t const source_row = expanded_source_row % num_rows;

    auto const* source_row_ptr =
        reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols);
    auto* dest_row_ptr =
        reinterpret_cast<DataElem*>(permuted_output + expanded_dest_row * cols);

    int64_t const start_offset = threadIdx.x;
    int64_t const stride = blockDim.x;
    // Integer division: only whole 128-bit chunks of the row are copied.
    int64_t const num_elems_in_col = cols / ELEM_PER_THREAD;

    for (int elem_index = start_offset; elem_index < num_elems_in_col;
         elem_index += stride) {
      dest_row_ptr[elem_index] = source_row_ptr[elem_index];
    }
  }
}
// Host-side launcher for expandInputRowsKernel.
//
// Launches one 256-thread block per expanded row (num_rows * k blocks) with
// (num_local_experts + 1) * sizeof(int64_t) bytes of dynamic shared memory.
// The kernel variant is chosen at run time:
//   - CHECK_SKIPPED    when num_valid_tokens_ptr is non-null
//   - ALIGN_BLOCK_SIZE when align_block_size != -1
template <typename T>
void expandInputRowsKernelLauncher(
    T const* unpermuted_input, T* permuted_output,
    const float* unpermuted_scales, int* sorted_experts,
    int const* expanded_dest_row_to_expanded_source_row,
    int* expanded_source_row_to_expanded_dest_row,
    int64_t* expert_first_token_offset, int64_t const num_rows,
    int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
    int num_local_experts, const int& align_block_size, cudaStream_t stream) {
  int64_t const blocks = num_rows * k;
  // A grid of zero blocks is an invalid launch configuration; with no rows to
  // expand there is nothing to do.
  if (blocks <= 0) {
    return;
  }
  int64_t const threads = 256;
  using FuncPtr = decltype(&expandInputRowsKernel<T, true, true>);
  // Indexed as func_map[CHECK_SKIPPED][ALIGN_BLOCK_SIZE].
  FuncPtr func_map[2][2] = {
      {&expandInputRowsKernel<T, false, false>,
       &expandInputRowsKernel<T, false, true>},
      {&expandInputRowsKernel<T, true, false>,
       &expandInputRowsKernel<T, true, true>},
  };
  bool is_check_skip = num_valid_tokens_ptr != nullptr;
  bool is_align_block_size = align_block_size != -1;
  auto func = func_map[is_check_skip][is_align_block_size];

  int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);

  func<<<blocks, threads, smem_size, stream>>>(
      unpermuted_input, permuted_output, unpermuted_scales, sorted_experts,
      expanded_dest_row_to_expanded_source_row,
      expanded_source_row_to_expanded_dest_row, expert_first_token_offset,
      num_rows, num_valid_tokens_ptr, cols, k, num_local_experts,
      align_block_size);
}
// Element-wise conversion between two array types of equal length: every
// element of `input` is static_cast to U::Element and stored in the result.
template <class T, class U>
__host__ __device__ constexpr static U arrayConvert(T const& input) {
  static_assert(T::kElements == U::kElements);
  using Type = typename U::Element;
  U converted;
#pragma unroll
  for (int lane = 0; lane < U::kElements; ++lane) {
    converted[lane] = static_cast<Type>(input[lane]);
  }
  return converted;
}
// Unpermutes expert outputs and performs the weighted k-way reduction.
//
// Each thread block handles the original row blockIdx.x (gridDim.x is taken
// as the number of original rows). For each of the k routed copies of the row
// the block looks up the permuted row through
// expanded_source_row_to_expanded_dest_row[row + k_idx * num_rows], scales it
// by scales[row * k + k_idx], accumulates in float, and finally converts the
// sum to OutputType. Rows are processed in 128-bit vectors.
//
// With CHECK_SKIPPED, permuted rows with index >= *num_valid_ptr contribute
// nothing. Without it num_valid_ptr is never dereferenced, so it may be null.
// expert_for_source_row is accepted for interface compatibility but not read.
template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(
    T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
    float const* scales, int const* expanded_source_row_to_expanded_dest_row,
    int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
    int64_t const* num_valid_ptr) {
  // Load 128-bits per thread, according to the smallest data type we read/write
  constexpr int64_t FINALIZE_ELEM_PER_THREAD =
      128 / std::min(cutlass::sizeof_bits<OutputType>::value,
                     cutlass::sizeof_bits<T>::value);
  // Rows are addressed in whole vectors, so the column count must be a
  // multiple of the vector width (8 for 16-bit types), not merely of 4.
  assert(orig_cols % FINALIZE_ELEM_PER_THREAD == 0);
  int64_t const original_row = blockIdx.x;
  int64_t const num_rows = gridDim.x;
  auto const offset = original_row * orig_cols;
  OutputType* reduced_row_ptr = reduced_unpermuted_output + offset;

  // Only read the valid-row count when it is actually used: the launcher
  // selects CHECK_SKIPPED == false exactly when num_valid_ptr is null.
  int64_t num_valid = 0;
  if constexpr (CHECK_SKIPPED) {
    num_valid = *num_valid_ptr;
  }

  int64_t const start_offset = threadIdx.x;
  int64_t const stride = blockDim.x;
  int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD;

  using InputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
  using OutputElem = cutlass::Array<OutputType, FINALIZE_ELEM_PER_THREAD>;
  using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;
  auto const* expanded_permuted_rows_v =
      reinterpret_cast<InputElem const*>(expanded_permuted_rows);
  auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);

#pragma unroll
  for (int elem_index = start_offset; elem_index < num_elems_in_col;
       elem_index += stride) {
    ComputeElem thread_output;
    thread_output.fill(0);
    for (int k_idx = 0; k_idx < k; ++k_idx) {
      int64_t const expanded_original_row = original_row + k_idx * num_rows;
      int64_t const expanded_permuted_row =
          expanded_source_row_to_expanded_dest_row[expanded_original_row];

      int64_t const k_offset = original_row * k + k_idx;
      float const row_scale = scales[k_offset];

      // Skip rows that fall outside the valid permuted range.
      if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) {
        continue;
      }

      auto const* expanded_permuted_rows_row_ptr =
          expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;

      ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>(
          expanded_permuted_rows_row_ptr[elem_index]);
      thread_output = thread_output + row_scale * (expert_result);
    }

    OutputElem output_elem =
        arrayConvert<ComputeElem, OutputElem>(thread_output);
    reduced_row_ptr_v[elem_index] = output_elem;
  }
}
// Host-side launcher for finalizeMoeRoutingKernel.
//
// Launches one 256-thread block per original row (num_rows blocks) with no
// dynamic shared memory. A non-null num_valid_ptr selects the CHECK_SKIPPED
// kernel variant.
template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher(
    T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
    float const* scales, int const* expanded_source_row_to_expanded_dest_row,
    int const* expert_for_source_row, int64_t const num_rows,
    int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
    cudaStream_t stream) {
  int64_t const blocks = num_rows;
  // A grid of zero blocks is an invalid launch configuration; with no rows to
  // reduce there is nothing to do.
  if (blocks <= 0) {
    return;
  }
  int64_t const threads = 256;
  bool const check_finished = num_valid_ptr != nullptr;
  using FuncPtr = decltype(&finalizeMoeRoutingKernel<T, OutputType, false>);
  // Indexed by CHECK_SKIPPED.
  FuncPtr func_map[2] = {&finalizeMoeRoutingKernel<T, OutputType, false>,
                         &finalizeMoeRoutingKernel<T, OutputType, true>};
  auto* const kernel = func_map[check_finished];
  kernel<<<blocks, threads, 0, stream>>>(
      expanded_permuted_rows, reduced_unpermuted_output, scales,
      expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k,
      num_valid_ptr);
}
csrc/moe/torch_bindings.cpp
View file @
3e887d2e
...
@@ -53,7 +53,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
...
@@ -53,7 +53,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"int size_m, int size_n, int size_k,"
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor"
);
"bool use_fp32_reduce, bool is_zp_float) -> Tensor"
);
m
.
def
(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"int b_q_type, SymInt size_m, "
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
"topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor"
);
m
.
def
(
"moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids,"
"Tensor token_expert_indicies, Tensor? expert_map, int n_expert,"
"int n_local_expert,"
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
"expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! "
"m_indices)->()"
);
m
.
def
(
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
"expert_first_token_offset, int n_expert, int n_local_expert,int "
"topk, Tensor! hidden_states)->()"
);
// conditionally compiled so impl registration is in source file
// conditionally compiled so impl registration is in source file
#endif
#endif
...
...
tests/kernels/moe/test_moe.py
View file @
3e887d2e
...
@@ -420,7 +420,8 @@ def test_fused_marlin_moe(
...
@@ -420,7 +420,8 @@ def test_fused_marlin_moe(
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
topk_weights
,
topk_ids
=
fused_topk
(
a
,
score
,
topk
,
False
)
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
a
,
score
,
topk
,
False
)
torch_output
=
torch_moe
(
a
,
w_ref1
,
w_ref2
,
score
,
topk
,
e_map
)
torch_output
=
torch_moe
(
a
,
w_ref1
,
w_ref2
,
score
,
topk
,
e_map
)
...
...
tests/kernels/moe/test_moe_permute_unpermute.py
0 → 100644
View file @
3e887d2e
# SPDX-License-Identifier: Apache-2.0
"""Tests for the MOE permute/unpermute kernel
Run `pytest tests/kernels/moe/test_moe_permute_unpermute.py`.
"""
from
typing
import
Optional
import
numpy
as
np
import
pytest
import
torch
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.layer
import
determine_expert_map
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
(
moe_permute
,
moe_unpermute
)
from
vllm.platforms
import
current_platform
# Parametrization values used by test_moe_permute_unpermute below.
NUM_EXPERTS = [16, 64]  # values for the `n_expert` parameter
TOP_KS = [2, 4, 6, 8]  # values for the `topk` parameter
EP_SIZE = [1, 4, 16]  # values for the `ep_size` parameter
# Seed once at import time; the test seeds again before creating its inputs.
current_platform.seed_everything(0)
def torch_permute(hidden_states: torch.Tensor,
                  topk_ids: torch.Tensor,
                  token_expert_indices: torch.Tensor,
                  topk: int,
                  n_expert: int,
                  n_local_expert: int,
                  start_expert: int,
                  expert_map: Optional[torch.Tensor] = None,
                  align_block_size: Optional[int] = None,
                  fill_invalid_expert: int = -1) -> list[torch.Tensor]:
    """Reference (pure PyTorch) implementation of the MoE permute step.

    All intermediate tensors are created on ``hidden_states.device``.

    Args:
        hidden_states: Activations; the first two dims are read as
            ``(n_token, n_hidden)``.
        topk_ids: Expert id chosen for every (token, k) slot.
        token_expert_indices: Expanded source-row index for every
            (token, k) slot; ``index % n_token`` is used as the token row.
        topk: Number of experts per token.
        n_expert: Total number of experts.
        n_local_expert: Number of experts handled locally.
        start_expert: Global id of the first local expert.
        expert_map: Optional global -> local expert map with ``-1`` for
            non-local experts. When given, local ids become
            ``topk_ids - start_expert`` and non-local ids become
            ``topk_ids + n_expert`` so that they sort after all local ones.
        align_block_size: When set, every expert's rows start at a multiple
            of this value in the permuted output.
        fill_invalid_expert: Value stored in ``m_indices`` for rows that do
            not belong to a local expert.

    Returns:
        A list ``[permuted_hidden_states, expert_first_token_offset,
        src_row_id2dst_row_id_map, m_indices, valid_row_idx]``. The first
        four entries are tensors; ``valid_row_idx`` is a plain list of the
        permuted row indices that hold real tokens. In the aligned mode the
        offsets and the row map are the aligned ones.
    """
    n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
    device = hidden_states.device
    if expert_map is not None:
        mapped_ids = expert_map[topk_ids]
        is_local_expert = (mapped_ids != -1)
        not_local_expert = (mapped_ids == -1)
        topk_ids = is_local_expert * (
            topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert)

    # Stable sort keeps the original slot order inside every expert.
    sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(),
                                                 stable=True)
    dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices]

    # Count the sorted ids expert by expert to build the prefix offsets.
    expert_first_token_offset = torch.zeros(n_local_expert + 1,
                                            dtype=torch.int64,
                                            device=device)
    idx = 0
    for i in range(0, n_local_expert):
        cnt = 0
        while idx < sorted_topk_ids.numel() and sorted_topk_ids[idx] == i:
            cnt += 1
            idx += 1
        expert_first_token_offset[i + 1] = expert_first_token_offset[i] + cnt

    # Inverse permutation: position of every source row in the sorted order.
    _, src2dst_idx = torch.sort(dst_row_id2src_row_id_map)
    valid_row_idx = []
    if align_block_size is None:
        permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map %
                                               n_token, ...]
        permuted_row_size = permuted_hidden_states.shape[0]
        m_indices = torch.empty(permuted_row_size,
                                device=device,
                                dtype=torch.int32).fill_(fill_invalid_expert)
        for i in range(1, n_local_expert + 1):
            first_token_offset = expert_first_token_offset[i - 1]
            last_token_offset = expert_first_token_offset[i]
            m_indices[first_token_offset:last_token_offset] = i - 1
        src_row_id2dst_row_id_map = torch.arange(
            0, n_token * topk, device=device,
            dtype=torch.int32)[src2dst_idx].reshape((n_token, topk))
        valid_row_idx += list(range(int(expert_first_token_offset[-1])))
        return [
            permuted_hidden_states, expert_first_token_offset,
            src_row_id2dst_row_id_map, m_indices, valid_row_idx
        ]
    else:
        # Worst-case padded size, rounded up to a multiple of the block size.
        permuted_row_size = (topk * n_token + n_expert *
                             (align_block_size - 1) + align_block_size -
                             1) // align_block_size * align_block_size
        permuted_hidden_states = torch.empty((permuted_row_size, n_hidden),
                                             device=device,
                                             dtype=hidden_states.dtype)
        align_src_row_id2dst_row_id = torch.empty(n_token * topk,
                                                  device=device,
                                                  dtype=torch.int32)
        align_expert_first_token_offset = torch.zeros_like(
            expert_first_token_offset)
        m_indices = torch.empty(permuted_row_size,
                                device=device,
                                dtype=torch.int32).fill_(fill_invalid_expert)
        # get align_permuted_hidden_states,
        # valid row_idx and align_expert_first_token_offset
        for i in range(1, n_local_expert + 1):
            first_token_offset = expert_first_token_offset[i - 1]
            last_token_offset = expert_first_token_offset[i]
            n_token_in_expert = last_token_offset - first_token_offset
            align_expert_first_token_offset[
                i] = align_expert_first_token_offset[
                    i - 1] + (n_token_in_expert + align_block_size -
                              1) // align_block_size * align_block_size
            align_first_token_offset = align_expert_first_token_offset[i - 1]
            align_last_token_offset = align_expert_first_token_offset[i]
            dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[
                first_token_offset:first_token_offset +
                n_token_in_expert] % n_token
            # store token in current expert with align_first_token_offset
            permuted_hidden_states[
                align_first_token_offset:align_first_token_offset +
                n_token_in_expert, ...] = hidden_states[
                    dst_row_id2src_row_id_in_expert, ...]
            # set current expert m_indices
            m_indices[align_first_token_offset:align_last_token_offset] = i - 1
            valid_row_idx += list(
                range(int(align_first_token_offset),
                      int(align_first_token_offset + n_token_in_expert)))
        # get align_src_row_id2dst_row_id
        for i in range(n_token * topk):
            eid = sorted_topk_ids[i]
            if (eid >= n_local_expert):
                # check token not in local expert
                align_src_row_id2dst_row_id[
                    i] = align_expert_first_token_offset[-1]
                continue
            first_token_offset = expert_first_token_offset[eid]
            align_first_token_offset = align_expert_first_token_offset[eid]
            token_offset = i - first_token_offset
            align_src_row_id2dst_row_id[
                i] = align_first_token_offset + token_offset
        align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[
            src2dst_idx].reshape((n_token, topk))
        return [
            permuted_hidden_states, align_expert_first_token_offset,
            align_src_row_id2dst_row_id, m_indices, valid_row_idx
        ]
def torch_unpermute(permuted_hidden_states: torch.Tensor,
                    topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                    token_expert_indices: torch.Tensor,
                    src_row_id2dst_row_id_map: torch.Tensor,
                    valid_row_idx: torch.Tensor, topk: int,
                    n_expert: int) -> torch.Tensor:
    """Reference (pure PyTorch) implementation of the MoE unpermute step.

    Gathers every token's ``topk`` permuted rows, scales each by its routing
    weight and sums them. The row mask is created on
    ``permuted_hidden_states.device``.

    Note: ``permuted_hidden_states`` is modified in place - rows that are not
    listed in ``valid_row_idx`` are zeroed before the gather.

    Args:
        permuted_hidden_states: Permuted activations, one row per slot.
        topk_weights: Routing weight per (token, k) slot.
        topk_ids: Not used by this reference implementation.
        token_expert_indices: Expanded source-row index per (token, k) slot.
        src_row_id2dst_row_id_map: Source row -> permuted row map.
        valid_row_idx: Indices of the permuted rows that hold real tokens.
        topk: Not used by this reference implementation.
        n_expert: Not used by this reference implementation.

    Returns:
        Tensor with one reduced row per token, in the dtype of
        ``permuted_hidden_states``.
    """
    # ignore invalid row
    mask = torch.zeros((permuted_hidden_states.shape[0]),
                       dtype=bool,
                       device=permuted_hidden_states.device)
    mask[valid_row_idx] = True
    permuted_hidden_states[~mask] = 0
    # Permuted row for every (token, k) slot.
    idx = src_row_id2dst_row_id_map.flatten()[
        token_expert_indices.flatten()].reshape(token_expert_indices.shape)
    output = permuted_hidden_states[idx, ...] * topk_weights[..., None]
    output = output.sum(dim=1).to(permuted_hidden_states.dtype)
    return output
@pytest.mark.parametrize("n_token", [1, 33, 64, 222, 1024, 2048, 3000, 5000])
@pytest.mark.parametrize("n_hidden", [2048, 4096, 7168])
@pytest.mark.parametrize("n_expert", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("align_block_size", [None, 128])
def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
                               n_expert: int, ep_size: int, dtype: torch.dtype,
                               align_block_size: Optional[int]):
    """Compare moe_permute / moe_unpermute against the torch references."""
    fill_invalid_expert = 0
    ep_rank = np.random.randint(0, ep_size)
    n_local_expert, expert_map = n_expert, None
    if ep_size != 1:
        n_local_expert, expert_map = determine_expert_map(
            ep_size, ep_rank, n_expert)
        expert_map = expert_map.cuda()
    start_expert = n_local_expert * ep_rank

    # Re-seed so the random inputs are identical for every parametrization.
    current_platform.seed_everything(0)
    hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype)
    gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
    topk_weights, topk_ids, token_expert_indices = fused_topk(
        hidden_states, gating_output, topk, False)

    # Reference result first, then the kernel under test.
    (ref_hidden, ref_offsets, ref_row_map, ref_m_indices,
     valid_row_idx) = torch_permute(hidden_states,
                                    topk_ids,
                                    token_expert_indices,
                                    topk,
                                    n_expert,
                                    n_local_expert,
                                    start_expert,
                                    expert_map=expert_map,
                                    align_block_size=align_block_size,
                                    fill_invalid_expert=fill_invalid_expert)
    out_hidden, out_offsets, out_row_map, out_m_indices = moe_permute(
        hidden_states, topk_weights, topk_ids, token_expert_indices, topk,
        n_expert, n_local_expert, expert_map, align_block_size,
        fill_invalid_expert)

    # Exact comparisons, in order: expert_first_token_offset,
    # src_row_id2dst_row_id_map, m_indices.
    for expected, actual in ((ref_offsets, out_offsets),
                             (ref_row_map, out_row_map),
                             (ref_m_indices, out_m_indices)):
        torch.testing.assert_close(expected, actual, atol=0, rtol=0)
    # check permuted_hidden_states, only valid token
    torch.testing.assert_close(ref_hidden[valid_row_idx],
                               out_hidden[valid_row_idx],
                               atol=0,
                               rtol=0)

    # add a random tensor to simulate group gemm
    out_hidden = 0.5 * out_hidden + torch.randn_like(out_hidden)

    unpermuted = moe_unpermute(out_hidden, topk_weights, topk_ids,
                               out_row_map, out_offsets, topk, n_expert,
                               n_local_expert)
    ref_unpermuted = torch_unpermute(out_hidden, topk_weights, topk_ids,
                                     token_expert_indices, out_row_map,
                                     valid_row_idx, topk, n_local_expert)

    # check unpermuted hidden
    torch.testing.assert_close(unpermuted, ref_unpermuted, atol=2e-2, rtol=0)
tests/kernels/quantization/test_awq_marlin.py
View file @
3e887d2e
...
@@ -84,7 +84,8 @@ def test_fused_marlin_moe_awq(
...
@@ -84,7 +84,8 @@ def test_fused_marlin_moe_awq(
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
topk_weights
,
topk_ids
=
fused_topk
(
a
,
score
,
topk
,
False
)
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
a
,
score
,
topk
,
False
)
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
a
,
a
,
qweight1
,
qweight1
,
...
...
tests/kernels/quantization/test_block_fp8.py
View file @
3e887d2e
...
@@ -338,7 +338,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
...
@@ -338,7 +338,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
M
,
K
=
a
.
shape
M
,
K
=
a
.
shape
N
=
w2
.
shape
[
-
1
]
N
=
w2
.
shape
[
-
1
]
topk_weight
,
topk_ids
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
topk_weight
,
topk_ids
,
token_expert_indices
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
block_m
=
deep_gemm
.
get_m_alignment_for_contiguous_layout
()
block_m
=
deep_gemm
.
get_m_alignment_for_contiguous_layout
()
...
@@ -435,7 +436,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
...
@@ -435,7 +436,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
ref_out
=
torch_w8a8_block_fp8_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
ref_out
=
torch_w8a8_block_fp8_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
topk
,
block_size
)
topk
,
block_size
)
topk_weights
,
topk_ids
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
out
=
deep_gemm_moe_fp8
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
topk_weights
,
topk_ids
)
out
=
deep_gemm_moe_fp8
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
topk_weights
,
topk_ids
)
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
3e887d2e
...
@@ -71,8 +71,8 @@ def single_marlin_moe(
...
@@ -71,8 +71,8 @@ def single_marlin_moe(
E
=
w
.
shape
[
0
]
E
=
w
.
shape
[
0
]
N
=
w
.
shape
[
2
]
//
(
num_bits
//
2
)
N
=
w
.
shape
[
2
]
//
(
num_bits
//
2
)
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_
topk
(
renormalize
)
hidden_states
,
gating_output
,
topk
,
renormalize
)
# This might not be an optimal config for a single MMM
# This might not be an optimal config for a single MMM
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
3e887d2e
...
@@ -854,7 +854,7 @@ def fused_topk(
...
@@ -854,7 +854,7 @@ def fused_topk(
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
topk
:
int
,
renormalize
:
bool
,
renormalize
:
bool
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
"Number of tokens mismatch"
)
...
@@ -868,20 +868,19 @@ def fused_topk(
...
@@ -868,20 +868,19 @@ def fused_topk(
topk
,
topk
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
device
=
hidden_states
.
device
)
token_expert_indic
i
es
=
torch
.
empty
(
M
,
token_expert_indices
=
torch
.
empty
(
M
,
topk
,
topk
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
device
=
hidden_states
.
device
)
gating_output_float
=
gating_output
.
float
()
# TODO(woosuk): Optimize this.
gating_output_float
=
gating_output
.
float
()
# TODO(woosuk): Optimize this.
topk_func
=
dispatch_topk_func
()
topk_func
=
dispatch_topk_func
()
topk_weights
,
topk_ids
=
topk_func
(
topk_weights
,
topk_ids
,
topk_weights
,
topk_ids
=
topk_func
(
topk_weights
,
topk_ids
,
token_expert_indic
i
es
,
token_expert_indices
,
gating_output_float
,
renormalize
)
gating_output_float
,
renormalize
)
del
token_expert_indicies
# Not used. Will be used in the future.
return
topk_weights
,
topk_ids
,
token_expert_indices
return
topk_weights
,
topk_ids
# This is used by the Deepseek-V2 and Deepseek-V3 model
# This is used by the Deepseek-V2 and Deepseek-V3 model
...
@@ -1510,8 +1509,8 @@ def fused_moe(
...
@@ -1510,8 +1509,8 @@ def fused_moe(
topk
,
renormalize
,
topk
,
renormalize
,
num_expert_group
,
topk_group
)
num_expert_group
,
topk_group
)
elif
custom_routing_function
is
None
:
elif
custom_routing_function
is
None
:
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_
topk
(
renormalize
)
hidden_states
,
gating_output
,
topk
,
renormalize
)
else
:
else
:
topk_weights
,
topk_ids
=
custom_routing_function
(
topk_weights
,
topk_ids
=
custom_routing_function
(
hidden_states
,
gating_output
,
topk
,
renormalize
)
hidden_states
,
gating_output
,
topk
,
renormalize
)
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
3e887d2e
...
@@ -801,10 +801,11 @@ class FusedMoE(torch.nn.Module):
...
@@ -801,10 +801,11 @@ class FusedMoE(torch.nn.Module):
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
)
elif
custom_routing_function
is
None
:
elif
custom_routing_function
is
None
:
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
=
hidden_states
,
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
gating_output
=
router_logits
,
hidden_states
=
hidden_states
,
topk
=
top_k
,
gating_output
=
router_logits
,
renormalize
=
renormalize
)
topk
=
top_k
,
renormalize
=
renormalize
)
else
:
else
:
topk_weights
,
topk_ids
=
custom_routing_function
(
topk_weights
,
topk_ids
=
custom_routing_function
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
...
...
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
0 → 100644
View file @
3e887d2e
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
,
Tuple
import
torch
def moe_permute(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indices: torch.Tensor,
    topk: int,
    n_expert: int,
    n_local_expert: int,
    expert_map: Optional[torch.Tensor] = None,
    align_block_size: Optional[int] = None,
    fill_invalid_expert: int = -1
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Expand and permute activations so that the non-contiguous tokens routed
    to each expert become contiguous rows.

    The four output tensors are allocated here on `hidden_states.device` and
    filled by the `_moe_C.moe_permute` custom op.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer, shape
        (n_token, n_hidden). One row must span a multiple of 16 bytes.
    - topk_weights (torch.Tensor): topk expert route weight for each token.
    - topk_ids (torch.Tensor): topk expert route id for each token.
    - token_expert_indices (torch.Tensor): index of each (token, k) slot in
        the expanded hidden states.
    - topk (int): The number of top-k experts to select.
    - n_expert (int): The number of experts.
    - n_local_expert (int): The number of experts in the current EP rank.
    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
        from the global expert space to the local expert space of the expert
        parallel shard.
    - align_block_size (Optional[int]): grouped-gemm block size (DeepGemm) to
        align each expert's rows to; None disables alignment.
    - fill_invalid_expert (int): expert id stored in m_indices for rows that
        belong to no valid expert, to work around DeepGemm not supporting -1
        in m_indices.
    Returns:
    - permuted_hidden_states (torch.Tensor): permuted activation.
    - expert_first_token_offset (torch.Tensor): offset of the first token
        of each expert for standard grouped gemm. If 'align_block_size' is
        set, the offsets are aligned up to 'align_block_size'.
    - src_row_id2dst_row_id_map (torch.Tensor): index map for moe_unpermute.
    - m_indices (torch.Tensor): m_indices for grouped gemm in DeepGemm;
        `m_indices[i]` records the group which the i-th row of the LHS
        belongs to.
    """
    n_token, n_hidden = hidden_states.shape
    row_bytes = n_hidden * hidden_states.element_size()
    assert row_bytes % 16 == 0, "permue kernel need hidden dim align to 16B"

    device = hidden_states.device
    if align_block_size is None:
        permuted_row_size = n_token * topk
    else:
        # Worst case: every expert wastes (align_block_size - 1) rows; then
        # round the total up to a multiple of align_block_size.
        padded_rows = n_token * topk + n_expert * (align_block_size - 1)
        permuted_row_size = ((padded_rows + align_block_size - 1) //
                             align_block_size * align_block_size)

    permuted_hidden_states = torch.empty(
        (permuted_row_size, n_hidden),
        dtype=hidden_states.dtype,
        device=device,
    )
    m_indices = torch.full((permuted_row_size, ),
                           fill_invalid_expert,
                           dtype=torch.int32,
                           device=device)
    expert_first_token_offset = torch.empty(n_local_expert + 1,
                                            dtype=torch.int64,
                                            device=device)
    src_row_id2dst_row_id_map = torch.empty((n_token, topk),
                                            dtype=torch.int32,
                                            device=device)

    torch.ops._moe_C.moe_permute(hidden_states, topk_weights, topk_ids,
                                 token_expert_indices, expert_map, n_expert,
                                 n_local_expert, topk, align_block_size,
                                 permuted_hidden_states,
                                 expert_first_token_offset,
                                 src_row_id2dst_row_id_map, m_indices)
    return (permuted_hidden_states, expert_first_token_offset,
            src_row_id2dst_row_id_map, m_indices)
def moe_unpermute(
    permuted_hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    src_row_id2dst_row_id_map: torch.Tensor,
    expert_first_token_offset: torch.Tensor,
    topk: int,
    n_expert: int,
    n_local_expert: int,
) -> torch.Tensor:
    """
    This function unpermutes the expert-grouped activation back to the
    original token order and reduces it to one row per token. It is the
    counterpart of `moe_permute`; the actual work is done by the
    `_moe_C.moe_unpermute` custom op, which writes into a freshly allocated
    output tensor.

    Parameters:
    - permuted_hidden_states (torch.Tensor): permuted activation.
    - topk_weights (torch.Tensor): topk expert route weight for each token.
        Its first dimension is used as the number of output tokens.
    - topk_ids (torch.Tensor): topk expert route id for each token.
    - src_row_id2dst_row_id_map (torch.Tensor): idx map produced by
        `moe_permute`.
    - expert_first_token_offset (torch.Tensor): offset of the first token
        of each expert for grouped gemm.
    - topk (int): The number of top-k experts to select.
    - n_expert (int): The number of expert.
    - n_local_expert (int): The number of expert in current EP rank.
    Returns:
    - hidden_states (torch.Tensor): The reduced and unpermuted activation
        tensor of shape (n_token, n_hidden), with the dtype and device of
        `permuted_hidden_states`.
    Raises:
    - AssertionError: if one row of `permuted_hidden_states` is not a
        multiple of 16 bytes.
    """
    # Token count comes from the routing weights; the hidden size comes from
    # the last dimension of the permuted activation.
    n_token = topk_weights.shape[0]
    n_hidden = permuted_hidden_states.shape[-1]

    # The check is on the row size in bytes, so it depends on dtype as well
    # as on the hidden dimension.
    row_bytes = n_hidden * permuted_hidden_states.element_size()
    assert row_bytes % 16 == 0, (
        "unpermute kernel need hidden dim align to 16B")

    # Uninitialised output buffer; the custom op receives it as its last
    # argument and is expected to fill it.
    hidden_states = torch.empty(
        (n_token, n_hidden),
        dtype=permuted_hidden_states.dtype,
        device=permuted_hidden_states.device,
    )

    # NOTE: the op takes (n_expert, n_local_expert, topk) in a different
    # order from this function's signature; keep the positional order below.
    torch.ops._moe_C.moe_unpermute(
        permuted_hidden_states,
        topk_weights,
        topk_ids,
        src_row_id2dst_row_id_map,
        expert_first_token_offset,
        n_expert,
        n_local_expert,
        topk,
        hidden_states,
    )
    return hidden_states
vllm/model_executor/models/arctic.py
View file @
3e887d2e
...
@@ -175,10 +175,8 @@ class ArcticMoE(nn.Module):
...
@@ -175,10 +175,8 @@ class ArcticMoE(nn.Module):
# router_logits: (num_tokens, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
do_normalize
=
self
.
top_k
>
1
do_normalize
=
self
.
top_k
>
1
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
router_logits
,
hidden_states
,
router_logits
,
self
.
top_k
,
renormalize
=
do_normalize
)
self
.
top_k
,
renormalize
=
do_normalize
)
# topk_ids: (num_tokens, k)
# topk_ids: (num_tokens, k)
if
self
.
is_quant
:
if
self
.
is_quant
:
if
2
*
num_tokens
<=
self
.
num_experts
:
if
2
*
num_tokens
<=
self
.
num_experts
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment