Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f9c069c8
Unverified
Commit
f9c069c8
authored
May 14, 2025
by
bnellnm
Committed by
GitHub
May 14, 2025
Browse files
Modularize fused experts and integrate PPLX kernels (#15956)
parent
418d2f8b
Changes
42
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2425 additions
and
539 deletions
+2425
-539
csrc/activation_kernels.cu
csrc/activation_kernels.cu
+3
-0
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+14
-0
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+4
-4
csrc/moe/topk_softmax_kernels.cu
csrc/moe/topk_softmax_kernels.cu
+45
-18
examples/offline_inference/data_parallel.py
examples/offline_inference/data_parallel.py
+16
-6
tests/kernels/moe/test_batched_moe.py
tests/kernels/moe/test_batched_moe.py
+114
-0
tests/kernels/moe/test_cutlass_moe.py
tests/kernels/moe/test_cutlass_moe.py
+21
-25
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+51
-42
tests/kernels/moe/test_pplx_moe.py
tests/kernels/moe/test_pplx_moe.py
+691
-0
tests/kernels/moe/test_triton_moe_ptpc_fp8.py
tests/kernels/moe/test_triton_moe_ptpc_fp8.py
+20
-14
tests/kernels/quantization/test_block_fp8.py
tests/kernels/quantization/test_block_fp8.py
+10
-10
tests/kernels/quantization/test_block_int8.py
tests/kernels/quantization/test_block_int8.py
+4
-1
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+51
-2
vllm/distributed/utils.py
vllm/distributed/utils.py
+6
-7
vllm/forward_context.py
vllm/forward_context.py
+4
-1
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+3
-2
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+195
-108
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+131
-198
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+755
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+287
-101
No files found.
csrc/activation_kernels.cu
View file @
f9c069c8
...
@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
...
@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
int64_t num_tokens = input.numel() / input.size(-1); \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
} \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
VLLM_DISPATCH_FLOATING_TYPES( \
...
...
csrc/dispatch_utils.h
View file @
f9c069c8
...
@@ -65,5 +65,19 @@
...
@@ -65,5 +65,19 @@
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
csrc/moe/moe_align_sum_kernels.cu
View file @
f9c069c8
...
@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
}
}
if
(
use_global_memory
)
{
if
(
use_global_memory
)
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_
AND_UNSIGNED_
TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_global_mem_kernel"
,
[
&
]
{
topk_ids
.
scalar_type
(),
"moe_align_block_size_global_mem_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
// tensors
...
@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
cumsum_buffer
.
data_ptr
<
int32_t
>
());
cumsum_buffer
.
data_ptr
<
int32_t
>
());
});
});
}
else
if
(
use_i16
)
{
}
else
if
(
use_i16
)
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_
AND_UNSIGNED_
TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
// set dynamic shared mem
// set dynamic shared mem
auto
kernel
=
auto
kernel
=
...
@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids
.
numel
());
topk_ids
.
numel
());
});
});
}
else
{
}
else
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_
AND_UNSIGNED_
TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
auto
kernel
=
auto
kernel
=
vllm
::
moe
::
moe_align_block_size_kernel
<
scalar_t
,
int32_t
>
;
vllm
::
moe
::
moe_align_block_size_kernel
<
scalar_t
,
int32_t
>
;
...
@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
TORCH_CHECK
(
num_experts
==
256
,
TORCH_CHECK
(
num_experts
==
256
,
"sgl_moe_align_block_size kernel only supports deepseek v3."
);
"sgl_moe_align_block_size kernel only supports deepseek v3."
);
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_
AND_UNSIGNED_
TYPES
(
topk_ids
.
scalar_type
(),
"sgl_moe_align_block_size_kernel"
,
[
&
]
{
topk_ids
.
scalar_type
(),
"sgl_moe_align_block_size_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `cumsum` tensors
// calc needed amount of shared mem for `cumsum` tensors
auto
options_int
=
auto
options_int
=
...
...
csrc/moe/topk_softmax_kernels.cu
View file @
f9c069c8
...
@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
...
@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
}
}
}
}
template
<
int
TPB
>
template
<
int
TPB
,
typename
IndType
>
__launch_bounds__
(
TPB
)
__global__
void
moeTopK
(
const
float
*
inputs_after_softmax
,
const
bool
*
finished
,
float
*
output
,
__launch_bounds__
(
TPB
)
__global__
void
moeTopK
(
int
*
indices
,
int
*
source_rows
,
const
int
num_experts
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
const
float
*
inputs_after_softmax
,
const
bool
*
finished
,
float
*
output
,
IndType
*
indices
,
int
*
source_rows
,
const
int
num_experts
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
{
{
using
cub_kvp
=
cub
::
KeyValuePair
<
int
,
float
>
;
using
cub_kvp
=
cub
::
KeyValuePair
<
int
,
float
>
;
...
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
...
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
2) This implementation assumes k is small, but will work for any k.
2) This implementation assumes k is small, but will work for any k.
*/
*/
template
<
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
>
template
<
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
,
typename
IndType
>
__launch_bounds__
(
WARPS_PER_CTA
*
WARP_SIZE
)
__global__
__launch_bounds__
(
WARPS_PER_CTA
*
WARP_SIZE
)
__global__
void
topkGatingSoftmax
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_rows
,
int
*
indices
,
void
topkGatingSoftmax
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_rows
,
IndType
*
indices
,
int
*
source_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
int
*
source_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
{
{
// We begin by enforcing compile time assertions and setting up compile time constants.
// We begin by enforcing compile time assertions and setting up compile time constants.
...
@@ -397,8 +405,8 @@ struct TopkConstants
...
@@ -397,8 +405,8 @@ struct TopkConstants
};
};
}
// namespace detail
}
// namespace detail
template
<
int
EXPERTS
,
int
WARPS_PER_TB
>
template
<
int
EXPERTS
,
int
WARPS_PER_TB
,
typename
IndType
>
void
topkGatingSoftmaxLauncherHelper
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
int
*
indices
,
void
topkGatingSoftmaxLauncherHelper
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
IndType
*
indices
,
int
*
source_row
,
const
int
num_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
,
cudaStream_t
stream
)
int
*
source_row
,
const
int
num_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
,
cudaStream_t
stream
)
{
{
static
constexpr
std
::
size_t
MAX_BYTES_PER_LDG
=
16
;
static
constexpr
std
::
size_t
MAX_BYTES_PER_LDG
=
16
;
...
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
...
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
token_expert_indices, num_tokens, topk, 0, num_experts, \
token_expert_indices, num_tokens, topk, 0, num_experts, \
stream);
stream);
template
<
typename
IndType
>
void
topkGatingSoftmaxKernelLauncher
(
void
topkGatingSoftmaxKernelLauncher
(
const
float
*
gating_output
,
const
float
*
gating_output
,
float
*
topk_weights
,
float
*
topk_weights
,
int
*
topk_indicies
,
IndType
*
topk_indicies
,
int
*
token_expert_indices
,
int
*
token_expert_indices
,
float
*
softmax_workspace
,
float
*
softmax_workspace
,
const
int
num_tokens
,
const
int
num_tokens
,
...
@@ -493,6 +502,9 @@ void topk_softmax(
...
@@ -493,6 +502,9 @@ void topk_softmax(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
gating_output
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
gating_output
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
torch
::
Tensor
softmax_workspace
=
torch
::
empty
({
workspace_size
},
gating_output
.
options
());
torch
::
Tensor
softmax_workspace
=
torch
::
empty
({
workspace_size
},
gating_output
.
options
());
if
(
topk_indices
.
scalar_type
()
==
at
::
ScalarType
::
Int
)
{
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
gating_output
.
data_ptr
<
float
>
(),
gating_output
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
...
@@ -503,4 +515,19 @@ void topk_softmax(
...
@@ -503,4 +515,19 @@ void topk_softmax(
num_experts
,
num_experts
,
topk
,
topk
,
stream
);
stream
);
}
else
{
assert
(
topk_indices
.
scalar_type
()
==
at
::
ScalarType
::
UInt32
);
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
gating_output
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
uint32_t
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
stream
);
}
}
}
examples/offline_inference/data_parallel.py
View file @
f9c069c8
...
@@ -65,11 +65,17 @@ def parse_args():
...
@@ -65,11 +65,17 @@ def parse_args():
type
=
int
,
type
=
int
,
default
=
0
,
default
=
0
,
help
=
"Master node port"
)
help
=
"Master node port"
)
parser
.
add_argument
(
"--enforce-eager"
,
action
=
'store_true'
,
help
=
"Enforce eager mode execution."
)
parser
.
add_argument
(
"--trust-remote-code"
,
action
=
'store_true'
,
help
=
"Trust remote code."
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
def
main
(
model
,
dp_size
,
local_dp_rank
,
global_dp_rank
,
dp_master_ip
,
def
main
(
model
,
dp_size
,
local_dp_rank
,
global_dp_rank
,
dp_master_ip
,
dp_master_port
,
GPUs_per_dp_rank
):
dp_master_port
,
GPUs_per_dp_rank
,
enforce_eager
,
trust_remote_code
):
os
.
environ
[
"VLLM_DP_RANK"
]
=
str
(
global_dp_rank
)
os
.
environ
[
"VLLM_DP_RANK"
]
=
str
(
global_dp_rank
)
os
.
environ
[
"VLLM_DP_RANK_LOCAL"
]
=
str
(
local_dp_rank
)
os
.
environ
[
"VLLM_DP_RANK_LOCAL"
]
=
str
(
local_dp_rank
)
os
.
environ
[
"VLLM_DP_SIZE"
]
=
str
(
dp_size
)
os
.
environ
[
"VLLM_DP_SIZE"
]
=
str
(
dp_size
)
...
@@ -109,10 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
...
@@ -109,10 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
max_tokens
=
[
16
,
20
][
global_dp_rank
%
2
])
max_tokens
=
[
16
,
20
][
global_dp_rank
%
2
])
# Create an LLM.
# Create an LLM.
llm
=
LLM
(
model
=
model
,
llm
=
LLM
(
model
=
model
,
tensor_parallel_size
=
GPUs_per_dp_rank
,
tensor_parallel_size
=
GPUs_per_dp_rank
,
enforce_eager
=
True
,
enforce_eager
=
enforce_eager
,
enable_expert_parallel
=
True
)
enable_expert_parallel
=
True
,
trust_remote_code
=
trust_remote_code
,
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
# Print the outputs.
# Print the outputs.
for
i
,
output
in
enumerate
(
outputs
):
for
i
,
output
in
enumerate
(
outputs
):
...
@@ -155,7 +164,8 @@ if __name__ == "__main__":
...
@@ -155,7 +164,8 @@ if __name__ == "__main__":
proc
=
Process
(
target
=
main
,
proc
=
Process
(
target
=
main
,
args
=
(
args
.
model
,
dp_size
,
local_dp_rank
,
args
=
(
args
.
model
,
dp_size
,
local_dp_rank
,
global_dp_rank
,
dp_master_ip
,
dp_master_port
,
global_dp_rank
,
dp_master_ip
,
dp_master_port
,
tp_size
))
tp_size
,
args
.
enforce_eager
,
args
.
trust_remote_code
))
proc
.
start
()
proc
.
start
()
procs
.
append
(
proc
)
procs
.
append
(
proc
)
exit_code
=
0
exit_code
=
0
...
...
tests/kernels/moe/test_batched_moe.py
0 → 100644
View file @
f9c069c8
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
import
pytest
import
torch
import
triton.language
as
tl
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
invoke_moe_batched_triton_kernel
)
@
dataclass
class
BatchedMMConfig
:
dtype
:
torch
.
dtype
num_experts
:
int
max_tokens_per_expert
:
int
K
:
int
N
:
int
@
dataclass
class
BatchedMMTensors
:
A
:
torch
.
Tensor
# [E, max_tokens, K]
B
:
torch
.
Tensor
# [E, K, N] - column major
C
:
torch
.
Tensor
# [E, max_tokens, N]
num_expert_tokens
:
torch
.
Tensor
# [E]
@
staticmethod
def
make_tensors
(
config
:
BatchedMMConfig
):
A
=
torch
.
randn
(
(
config
.
num_experts
,
config
.
max_tokens_per_expert
,
config
.
K
),
device
=
"cuda"
,
dtype
=
config
.
dtype
)
/
10
B
=
torch
.
randn
((
config
.
num_experts
,
config
.
N
,
config
.
K
),
device
=
"cuda"
,
dtype
=
config
.
dtype
)
C
=
torch
.
zeros
(
(
config
.
num_experts
,
config
.
max_tokens_per_expert
,
config
.
N
),
device
=
"cuda"
,
dtype
=
config
.
dtype
)
num_expert_tokens
=
torch
.
randint
(
low
=
0
,
high
=
config
.
max_tokens_per_expert
,
size
=
(
config
.
num_experts
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
return
BatchedMMTensors
(
A
,
B
,
C
,
num_expert_tokens
)
def
ref_impl
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
num_expert_tokens
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_expert_tokens_cpu
=
num_expert_tokens
.
clone
()
num_expert_tokens_cpu
=
num_expert_tokens_cpu
.
to
(
device
=
"cpu"
)
num_experts
=
num_expert_tokens
.
size
(
0
)
for
e
in
range
(
num_experts
):
num_tokens
=
num_expert_tokens_cpu
[
e
]
C
[
e
,
:
num_tokens
,
:]
=
A
[
e
,
:
num_tokens
,
:]
@
B
[
e
].
transpose
(
0
,
1
)
return
C
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
16
,
32
])
@
pytest
.
mark
.
parametrize
(
"max_tokens_per_expert"
,
[
32
,
64
,
128
,
192
,
224
,
256
,
512
])
@
pytest
.
mark
.
parametrize
(
"K"
,
[
128
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"N"
,
[
128
,
256
,
512
,
1024
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_batched_mm
(
num_experts
:
int
,
max_tokens_per_expert
:
int
,
K
:
int
,
N
:
int
,
dtype
:
torch
.
dtype
):
config
=
BatchedMMConfig
(
dtype
,
num_experts
,
max_tokens_per_expert
,
K
,
N
)
tensors
=
BatchedMMTensors
.
make_tensors
(
config
)
test_output
=
tensors
.
C
ref_output
=
test_output
.
clone
()
compute_tl_dtype
=
{
torch
.
float16
:
tl
.
float16
,
torch
.
bfloat16
:
tl
.
bfloat16
,
torch
.
float32
:
tl
.
float32
}[
test_output
.
dtype
]
invoke_moe_batched_triton_kernel
(
tensors
.
A
,
tensors
.
B
,
test_output
,
tensors
.
num_expert_tokens
,
compute_tl_dtype
,
# Quantization data
None
,
None
,
None
,
# Quantization schemes
False
,
False
,
False
,
config
=
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
16
})
ref_output
=
ref_impl
(
tensors
.
A
,
tensors
.
B
,
ref_output
,
tensors
.
num_expert_tokens
)
rtol
,
atol
=
{
torch
.
float16
:
(
6e-2
,
6e-2
),
torch
.
bfloat16
:
(
6e-2
,
6e-2
),
torch
.
float32
:
(
1e-2
,
1e-2
),
}[
test_output
.
dtype
]
torch
.
testing
.
assert_close
(
test_output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
tests/kernels/moe/test_cutlass_moe.py
View file @
f9c069c8
...
@@ -30,6 +30,11 @@ MNK_FACTORS = [
...
@@ -30,6 +30,11 @@ MNK_FACTORS = [
(
224
,
3072
,
1536
),
(
224
,
3072
,
1536
),
]
]
vllm_config
=
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
MOETensors
:
class
MOETensors
:
...
@@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
...
@@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'w1_q'
:
moe_tensors
.
w1_q
.
transpose
(
1
,
2
),
# type: ignore[union-attr]
'w1_q'
:
moe_tensors
.
w1_q
.
transpose
(
1
,
2
),
# type: ignore[union-attr]
'w2_q'
:
moe_tensors
.
w2_q
.
transpose
(
1
,
2
),
# type: ignore[union-attr]
'w2_q'
:
moe_tensors
.
w2_q
.
transpose
(
1
,
2
),
# type: ignore[union-attr]
'topk_weights'
:
topk_weights
,
'topk_weights'
:
topk_weights
,
'topk_ids
_
'
:
topk_ids
,
'topk_ids'
:
topk_ids
,
'ab_strides1'
:
moe_tensors
.
ab_strides1
,
'ab_strides1'
:
moe_tensors
.
ab_strides1
,
'c_strides1'
:
moe_tensors
.
c_strides1
,
'c_strides1'
:
moe_tensors
.
c_strides1
,
'ab_strides2'
:
moe_tensors
.
ab_strides2
,
'ab_strides2'
:
moe_tensors
.
ab_strides2
,
...
@@ -231,15 +236,12 @@ def test_cutlass_moe_8_bit_no_graph(
...
@@ -231,15 +236,12 @@ def test_cutlass_moe_8_bit_no_graph(
per_out_ch
:
bool
,
per_out_ch
:
bool
,
):
):
current_platform
.
seed_everything
(
7
)
current_platform
.
seed_everything
(
7
)
with
set_current_vllm_config
(
with
set_current_vllm_config
(
vllm_config
):
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))):
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
per_out_ch
)
per_out_ch
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
half
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
half
)
topk_weights
,
topk_ids
=
fused_topk
(
mt
.
a
,
topk_weights
,
topk_ids
,
_
=
fused_topk
(
mt
.
a
,
score
,
score
,
topk
,
topk
,
renormalize
=
False
)
renormalize
=
False
)
...
@@ -276,17 +278,14 @@ def test_cutlass_moe_8_bit_cuda_graph(
...
@@ -276,17 +278,14 @@ def test_cutlass_moe_8_bit_cuda_graph(
per_out_ch
:
bool
,
per_out_ch
:
bool
,
):
):
current_platform
.
seed_everything
(
7
)
current_platform
.
seed_everything
(
7
)
with
set_current_vllm_config
(
with
set_current_vllm_config
(
vllm_config
):
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))):
dtype
=
torch
.
half
dtype
=
torch
.
half
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
per_out_ch
)
per_out_ch
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
topk_weights
,
topk_ids
=
fused_topk
(
mt
.
a
,
topk_weights
,
topk_ids
,
_
=
fused_topk
(
mt
.
a
,
score
,
score
,
topk
,
topk
,
renormalize
=
False
)
renormalize
=
False
)
...
@@ -334,15 +333,12 @@ def test_cutlass_moe_8_bit_EP(
...
@@ -334,15 +333,12 @@ def test_cutlass_moe_8_bit_EP(
ep_size
:
int
,
ep_size
:
int
,
):
):
current_platform
.
seed_everything
(
7
)
current_platform
.
seed_everything
(
7
)
with
set_current_vllm_config
(
with
set_current_vllm_config
(
vllm_config
):
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))):
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
per_out_channel
)
per_out_channel
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
half
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
half
)
topk_weights
,
topk_ids
=
fused_topk
(
mt
.
a
,
topk_weights
,
topk_ids
,
_
=
fused_topk
(
mt
.
a
,
score
,
score
,
topk
,
topk
,
renormalize
=
False
)
renormalize
=
False
)
...
...
tests/kernels/moe/test_moe.py
View file @
f9c069c8
...
@@ -12,6 +12,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
...
@@ -12,6 +12,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import
vllm.model_executor.layers.fused_moe
# noqa
import
vllm.model_executor.layers.fused_moe
# noqa
from
tests.kernels.utils
import
opcheck
,
stack_and_dev
,
torch_moe
from
tests.kernels.utils
import
opcheck
,
stack_and_dev
,
torch_moe
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.moe_torch_iterative
import
(
from
vllm.model_executor.layers.fused_moe.moe_torch_iterative
import
(
...
@@ -32,6 +33,10 @@ NUM_EXPERTS = [8, 64]
...
@@ -32,6 +33,10 @@ NUM_EXPERTS = [8, 64]
EP_SIZE
=
[
1
,
4
]
EP_SIZE
=
[
1
,
4
]
TOP_KS
=
[
2
,
6
]
TOP_KS
=
[
2
,
6
]
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
64
,
222
,
1024
*
128
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
64
,
222
,
1024
*
128
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
1024
,
2048
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
1024
,
2048
])
...
@@ -70,6 +75,7 @@ def test_fused_moe(
...
@@ -70,6 +75,7 @@ def test_fused_moe(
else
:
else
:
e_map
=
None
e_map
=
None
with
set_current_vllm_config
(
vllm_config
):
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
,
e_map
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
,
e_map
)
iterative_output
=
iterative_moe
(
a
,
iterative_output
=
iterative_moe
(
a
,
w1
,
w1
,
...
@@ -95,6 +101,7 @@ def test_fused_moe(
...
@@ -95,6 +101,7 @@ def test_fused_moe(
global_num_experts
=
e
,
global_num_experts
=
e
,
expert_map
=
e_map
,
expert_map
=
e_map
,
renormalize
=
False
)
renormalize
=
False
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
iterative_output
,
torch
.
testing
.
assert_close
(
iterative_output
,
torch_output
,
torch_output
,
...
@@ -115,7 +122,6 @@ def test_fused_moe(
...
@@ -115,7 +122,6 @@ def test_fused_moe(
def
test_fused_moe_wn16
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
def
test_fused_moe_wn16
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
group_size
:
int
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
group_size
:
int
,
has_zp
:
bool
,
weight_bits
:
int
):
has_zp
:
bool
,
weight_bits
:
int
):
print
(
m
,
n
,
k
,
e
,
topk
,
dtype
,
group_size
,
has_zp
,
weight_bits
)
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
...
@@ -194,6 +200,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
...
@@ -194,6 +200,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
else
:
else
:
e_map
=
None
e_map
=
None
with
set_current_vllm_config
(
vllm_config
):
triton_output
=
fused_moe
(
a
,
triton_output
=
fused_moe
(
a
,
w1_qweight
,
w1_qweight
,
w2_qweight
,
w2_qweight
,
...
@@ -210,6 +217,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
...
@@ -210,6 +217,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
w2_zp
=
w2_qzeros
if
has_zp
else
None
,
w2_zp
=
w2_qzeros
if
has_zp
else
None
,
block_shape
=
[
0
,
group_size
])
block_shape
=
[
0
,
group_size
])
torch_output
=
torch_moe
(
a
,
w1_ref
,
w2_ref
,
score
,
topk
,
e_map
)
torch_output
=
torch_moe
(
a
,
w1_ref
,
w2_ref
,
score
,
topk
,
e_map
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
...
@@ -515,6 +523,7 @@ def test_fused_marlin_moe(
...
@@ -515,6 +523,7 @@ def test_fused_marlin_moe(
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
with
set_current_vllm_config
(
vllm_config
):
torch_output
=
torch_moe
(
a
,
w_ref1
,
w_ref2
,
score
,
topk
,
e_map
)
torch_output
=
torch_moe
(
a
,
w_ref1
,
w_ref2
,
score
,
topk
,
e_map
)
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
...
...
tests/kernels/moe/test_pplx_moe.py
0 → 100644
View file @
f9c069c8
# SPDX-License-Identifier: Apache-2.0
"""Tests for the MOE layers.
Run `pytest tests/kernels/test_pplx_moe.py`.
"""
import
dataclasses
import
os
import
traceback
from
typing
import
Callable
,
Optional
import
pytest
import
torch
try
:
from
pplx_kernels
import
AllToAll
from
pplx_kernels.nvshmem
import
(
nvshmem_alloc_empty_unique_id
,
nvshmem_finalize
,
nvshmem_get_unique_id
,
nvshmem_init
)
has_pplx
=
True
except
ImportError
:
has_pplx
=
False
from
torch.multiprocessing
import
(
spawn
)
# pyright: ignore[reportPrivateImportUsage]
from
typing_extensions
import
Concatenate
,
ParamSpec
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
override_config
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
BatchedExperts
,
BatchedPrepareAndFinalize
,
BatchedTritonExperts
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
get_default_config
)
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEModularKernel
)
from
vllm.platforms
import
current_platform
PPLX_PREPARE_COMBOS
=
[(
4
,
128
,
128
),
(
32
,
1024
,
512
),
(
64
,
1024
,
512
),
(
222
,
2048
,
1024
)]
PPLX_MOE_COMBOS
=
[
(
1
,
128
,
128
),
(
2
,
128
,
512
),
(
3
,
1024
,
2048
),
(
32
,
128
,
1024
),
(
45
,
512
,
2048
),
(
64
,
1024
,
1024
),
(
222
,
1024
,
2048
),
]
NUM_EXPERTS
=
[
8
,
64
]
EP_SIZE
=
[
1
,
4
]
TOP_KS
=
[
1
,
2
,
6
]
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
P
=
ParamSpec
(
"P"
)
requires_pplx
=
pytest
.
mark
.
skipif
(
not
has_pplx
,
reason
=
"Requires PPLX kernels"
,
)
@
dataclasses
.
dataclass
class
ProcessGroupInfo
:
world_size
:
int
world_local_size
:
int
rank
:
int
node_rank
:
int
local_rank
:
int
device
:
torch
.
device
def
_worker_parallel_launch
(
local_rank
:
int
,
world_size
:
int
,
world_local_size
:
int
,
node_rank
:
int
,
init_method
:
str
,
worker
:
Callable
[
Concatenate
[
ProcessGroupInfo
,
P
],
None
],
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
,
)
->
None
:
rank
=
node_rank
*
world_local_size
+
local_rank
torch
.
cuda
.
set_device
(
local_rank
)
device
=
torch
.
device
(
"cuda"
,
local_rank
)
torch
.
distributed
.
init_process_group
(
backend
=
"cpu:gloo,cuda:nccl"
,
init_method
=
init_method
,
rank
=
rank
,
world_size
=
world_size
,
device_id
=
device
,
)
barrier
=
torch
.
tensor
([
rank
],
device
=
device
)
torch
.
distributed
.
all_reduce
(
barrier
)
try
:
worker
(
ProcessGroupInfo
(
world_size
=
world_size
,
world_local_size
=
world_local_size
,
rank
=
rank
,
node_rank
=
node_rank
,
local_rank
=
local_rank
,
device
=
device
,
),
*
args
,
**
kwargs
,
)
except
Exception
as
ex
:
print
(
ex
)
traceback
.
print_exc
()
raise
finally
:
torch
.
distributed
.
destroy_process_group
()
def
parallel_launch
(
world_size
:
int
,
worker
:
Callable
[
Concatenate
[
ProcessGroupInfo
,
P
],
None
],
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
,
)
->
None
:
assert
not
kwargs
spawn
(
_worker_parallel_launch
,
args
=
(
world_size
,
world_size
,
0
,
"tcp://localhost:29500"
,
worker
,
)
+
args
,
nprocs
=
world_size
,
join
=
True
,
)
def
parallel_launch_from_env
(
worker
:
Callable
[
Concatenate
[
ProcessGroupInfo
,
P
],
None
],
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
,
)
->
None
:
"""
Launches a worker function in parallel across all processes in the current
environment. The environment must have the following variables set:
- WORLD_SIZE: The total number of processes.
- WORLD_LOCAL_SIZE: The number of processes on the current node.
- NODE_RANK: The rank of the current
- MASTER_ADDR: The address of the master process.
- MASTER_PORT: The port of the master process.
"""
assert
not
kwargs
world_size
=
int
(
os
.
environ
[
"WORLD_SIZE"
])
world_local_size
=
int
(
os
.
environ
[
"WORLD_LOCAL_SIZE"
])
node_rank
=
int
(
os
.
environ
[
"NODE_RANK"
])
assert
"MASTER_ADDR"
in
os
.
environ
assert
"MASTER_PORT"
in
os
.
environ
spawn
(
_worker_parallel_launch
,
args
=
(
world_size
,
world_local_size
,
node_rank
,
"env://"
,
worker
,
)
+
args
,
nprocs
=
world_local_size
,
join
=
True
,
)
def
torch_prepare
(
a
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
max_num_tokens
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
topk_ids
.
dim
()
==
2
assert
topk_ids
.
shape
[
0
]
==
a
.
shape
[
0
]
num_tokens
,
hidden_dim
=
a
.
shape
topk
=
topk_ids
.
shape
[
1
]
tokens_per_expert
=
torch
.
bincount
(
topk_ids
.
view
(
-
1
),
minlength
=
num_experts
)
assert
tokens_per_expert
.
numel
()
==
num_experts
if
max_num_tokens
is
None
:
max_num_tokens
=
int
(
tokens_per_expert
.
max
().
item
())
b_a
=
torch
.
zeros
((
num_experts
,
max_num_tokens
,
hidden_dim
),
dtype
=
a
.
dtype
,
device
=
a
.
device
)
token_counts
=
torch
.
zeros
(
num_experts
,
dtype
=
torch
.
int
,
device
=
a
.
device
)
for
token
in
range
(
num_tokens
):
for
j
in
range
(
topk
):
expert_id
=
topk_ids
[
token
,
j
]
idx
=
token_counts
[
expert_id
]
b_a
[
expert_id
,
idx
:
idx
+
1
,
:]
=
a
[
token
,
:]
token_counts
[
expert_id
]
=
token_counts
[
expert_id
]
+
1
return
b_a
,
tokens_per_expert
def
torch_finalize
(
b_out
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
=
topk_ids
.
shape
[
0
]
num_experts
=
b_out
.
shape
[
0
]
K
=
b_out
.
shape
[
-
1
]
out
=
torch
.
zeros
((
num_tokens
,
K
),
dtype
=
b_out
.
dtype
,
device
=
b_out
.
device
)
expert_counts
=
torch
.
zeros
(
num_experts
,
dtype
=
torch
.
int
,
device
=
b_out
.
device
)
for
token
in
range
(
num_tokens
):
expert_ids
=
topk_ids
[
token
]
for
i
in
range
(
expert_ids
.
numel
()):
expert_id
=
expert_ids
[
i
]
idx
=
expert_counts
[
expert_id
]
out
[
token
,
:]
=
out
[
token
,
:]
+
b_out
[
expert_id
,
idx
:
idx
+
1
,
:]
*
topk_weight
[
token
,
i
]
expert_counts
[
expert_id
]
=
expert_counts
[
expert_id
]
+
1
return
out
def
torch_batched_moe
(
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
num_experts
=
w1
.
shape
[
0
]
b_a
,
tokens_per_expert
=
torch_prepare
(
a
,
topk_ids
,
num_experts
)
assert
b_a
.
dim
()
==
3
num_tokens
,
topk
=
topk_ids
.
shape
_
,
max_num_tokens
,
K
=
b_a
.
shape
assert
num_experts
==
b_a
.
shape
[
0
]
and
w2
.
shape
[
1
]
==
K
out
=
torch
.
zeros
((
num_experts
,
max_num_tokens
,
K
),
dtype
=
b_a
.
dtype
,
device
=
b_a
.
device
)
tmp
=
torch
.
empty
((
max_num_tokens
,
w1
.
shape
[
1
]
//
2
),
dtype
=
b_a
.
dtype
,
device
=
b_a
.
device
)
for
expert
in
range
(
num_experts
):
num
=
tokens_per_expert
[
expert
]
if
num
>
0
:
torch
.
ops
.
_C
.
silu_and_mul
(
tmp
[:
num
],
b_a
[
expert
,
:
num
,
:]
@
w1
[
expert
].
transpose
(
0
,
1
))
out
[
expert
,
:
num
,
:]
=
tmp
[:
num
]
@
w2
[
expert
].
transpose
(
0
,
1
)
return
torch_finalize
(
out
,
topk_weight
,
topk_ids
)
def
batched_moe
(
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
num_experts
=
w1
.
shape
[
0
]
fused_experts
=
FusedMoEModularKernel
(
BatchedPrepareAndFinalize
(
a
.
shape
[
0
],
world_size
=
1
,
dp_size
=
1
,
rank
=
0
),
BatchedExperts
(
max_num_tokens
=
a
.
shape
[
0
],
dp_size
=
1
,
world_size
=
1
))
return
fused_experts
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
,
num_experts
)
# Note: same as torch_moe but with fused_topk factored out.
def
torch_moe2
(
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
M
,
K
=
a
.
shape
topk
=
topk_ids
.
shape
[
1
]
a
=
a
.
view
(
M
,
-
1
,
K
).
repeat
(
1
,
topk
,
1
).
reshape
(
-
1
,
K
)
out
=
torch
.
zeros
(
M
*
topk
,
w2
.
shape
[
1
],
dtype
=
a
.
dtype
,
device
=
a
.
device
)
num_experts
=
w1
.
shape
[
0
]
for
i
in
range
(
num_experts
):
mask
=
(
topk_ids
==
i
).
view
(
-
1
)
if
mask
.
sum
():
out
[
mask
]
=
SiluAndMul
()(
a
[
mask
]
@
w1
[
i
].
transpose
(
0
,
1
))
@
w2
[
i
].
transpose
(
0
,
1
)
return
(
out
.
view
(
M
,
-
1
,
w2
.
shape
[
1
])
*
topk_weight
.
view
(
M
,
-
1
,
1
).
to
(
out
.
dtype
)).
sum
(
dim
=
1
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
64
,
222
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
1024
,
2048
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
512
,
1024
])
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
def
test_fused_moe_batched_experts
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
):
current_platform
.
seed_everything
(
7
)
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
with
set_current_vllm_config
(
vllm_config
):
topk_weight
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
baseline_output
=
torch_moe2
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
)
torch_output
=
torch_batched_moe
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
)
batched_output
=
batched_moe
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
)
torch
.
testing
.
assert_close
(
baseline_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
baseline_output
,
batched_output
,
atol
=
2e-2
,
rtol
=
0
)
def
rank_chunk
(
num
:
int
,
r
:
int
,
w
:
int
)
->
int
:
rem
=
num
%
w
return
(
num
//
w
)
+
(
1
if
r
<
rem
else
0
)
def
chunk_by_rank
(
t
:
torch
.
Tensor
,
r
:
int
,
w
:
int
)
->
torch
.
Tensor
:
chunk
=
rank_chunk
(
t
.
shape
[
0
],
r
,
w
)
return
t
[(
r
*
chunk
):(
r
+
1
)
*
chunk
]
def
pplx_prepare_finalize
(
pgi
:
ProcessGroupInfo
,
dp_size
:
int
,
a
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.pplx_prepare_finalize
import
(
PplxPrepareAndFinalize
)
assert
torch
.
cuda
.
current_device
()
==
pgi
.
local_rank
topk
=
topk_ids
.
shape
[
1
]
num_tokens
,
hidden_dim
=
a
.
shape
block_size
=
128
device
=
pgi
.
device
rank
=
pgi
.
rank
world_size
=
pgi
.
world_size
max_num_tokens
=
rank_chunk
(
num_tokens
,
0
,
world_size
)
ata
=
AllToAll
.
internode
(
max_num_tokens
=
max_num_tokens
,
num_experts
=
num_experts
,
experts_per_token
=
topk
,
rank
=
rank
,
world_size
=
world_size
,
dp_size
=
dp_size
,
hidden_dim
=
hidden_dim
,
hidden_dim_bytes
=
hidden_dim
*
a
.
dtype
.
itemsize
,
hidden_dim_scale_bytes
=
(
0
if
a
.
dtype
.
itemsize
!=
1
else
((
hidden_dim
+
block_size
-
1
)
//
block_size
*
torch
.
float32
.
itemsize
)),
)
topk_ids
=
topk_ids
.
to
(
dtype
=
torch
.
uint32
)
prepare_finalize
=
PplxPrepareAndFinalize
(
ata
,
max_num_tokens
,
world_size
,
rank
,
dp_size
,
a
.
dtype
,
)
a_chunk
=
chunk_by_rank
(
a
,
rank
,
world_size
).
to
(
device
)
chunk_topk_weight
=
chunk_by_rank
(
topk_weight
,
rank
,
world_size
).
to
(
device
)
chunk_topk_ids
=
chunk_by_rank
(
topk_ids
,
rank
,
world_size
).
to
(
device
)
b_a
,
b_a_scale
,
expert_num_tokens
=
prepare_finalize
.
prepare
(
a_chunk
,
None
,
None
,
chunk_topk_weight
,
chunk_topk_ids
,
num_experts
,
None
,
False
,
)
b_a
=
b_a
*
1.5
out
=
torch
.
full
(
(
max_num_tokens
,
hidden_dim
),
torch
.
nan
,
dtype
=
a
.
dtype
,
device
=
device
,
)
prepare_finalize
.
finalize
(
out
,
b_a
,
chunk_topk_weight
,
chunk_topk_ids
,
False
,
)
torch
.
cuda
.
synchronize
()
ata
.
destroy
()
num_tokens
=
a_chunk
.
shape
[
0
]
return
out
[:
num_tokens
]
def
_pplx_prepare_finalize
(
pgi
:
ProcessGroupInfo
,
dp_size
:
int
,
a
:
torch
.
Tensor
,
score
:
torch
.
Tensor
,
topk
:
torch
.
Tensor
,
num_experts
:
int
,
):
uid
=
nvshmem_get_unique_id
(
)
if
pgi
.
rank
==
0
else
nvshmem_alloc_empty_unique_id
()
torch
.
distributed
.
broadcast
(
uid
,
src
=
0
)
nvshmem_init
(
uid
,
pgi
.
rank
,
pgi
.
world_size
)
device
=
pgi
.
device
topk_weight
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
k
=
a
.
shape
[
1
]
a_rep
=
torch
.
repeat_interleave
(
a
,
topk
,
dim
=
0
).
to
(
device
)
torch_output
=
(
a_rep
.
view
(
-
1
,
topk
,
k
)
*
1.5
*
topk_weight
.
view
(
-
1
,
topk
,
1
).
to
(
device
)).
sum
(
dim
=
1
).
to
(
a
.
dtype
)
pplx_output
=
pplx_prepare_finalize
(
pgi
,
dp_size
,
a
,
topk_weight
,
topk_ids
,
num_experts
)
torch_output
=
chunk_by_rank
(
torch_output
,
pgi
.
rank
,
pgi
.
world_size
).
to
(
pplx_output
.
device
)
torch
.
testing
.
assert_close
(
pplx_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
nvshmem_finalize
()
# TODO (bnell): this test point does not work for odd M due to how the test is
# written, not due to limitations of the pplx kernels. The pplx_moe
# test below is able to deal with odd M.
@
pytest
.
mark
.
parametrize
(
"mnk"
,
PPLX_PREPARE_COMBOS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[[
2
,
1
]])
@
requires_pplx
def
test_pplx_prepare_finalize
(
mnk
:
tuple
[
int
,
int
,
int
],
e
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
world_dp_size
:
tuple
[
int
,
int
],
):
current_platform
.
seed_everything
(
7
)
m
,
n
,
k
=
mnk
world_size
,
dp_size
=
world_dp_size
device
=
"cuda"
a
=
torch
.
randn
((
m
,
k
),
device
=
device
,
dtype
=
dtype
)
/
10
score
=
torch
.
randn
((
m
,
e
),
device
=
device
,
dtype
=
dtype
)
parallel_launch
(
world_size
,
_pplx_prepare_finalize
,
dp_size
,
a
,
score
,
topk
,
e
)
def
pplx_moe
(
rank
:
int
,
world_size
:
int
,
dp_size
:
int
,
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_compile
:
bool
=
True
,
use_cudagraphs
:
bool
=
True
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.pplx_prepare_finalize
import
(
PplxPrepareAndFinalize
)
device
=
torch
.
device
(
"cuda"
,
rank
)
hidden_dim
=
a
.
shape
[
1
]
num_experts
=
w1
.
shape
[
0
]
block_size
=
128
topk
=
topk_ids
.
shape
[
1
]
max_num_tokens
=
rank_chunk
(
a
.
shape
[
0
],
0
,
world_size
)
ata
=
AllToAll
.
internode
(
max_num_tokens
=
max_num_tokens
,
num_experts
=
num_experts
,
experts_per_token
=
topk
,
rank
=
rank
,
world_size
=
world_size
,
dp_size
=
dp_size
,
hidden_dim
=
hidden_dim
,
hidden_dim_bytes
=
hidden_dim
*
a
.
dtype
.
itemsize
,
hidden_dim_scale_bytes
=
(
0
if
a
.
dtype
.
itemsize
!=
1
else
((
hidden_dim
+
block_size
-
1
)
//
block_size
*
torch
.
float32
.
itemsize
)),
)
topk_ids
=
topk_ids
.
to
(
dtype
=
torch
.
uint32
)
prepare_finalize
=
PplxPrepareAndFinalize
(
ata
,
max_num_tokens
,
world_size
,
rank
,
dp_size
,
)
experts
=
BatchedTritonExperts
(
max_num_tokens
=
a
.
shape
[
0
],
world_size
=
world_size
,
dp_size
=
dp_size
)
fused_experts
=
FusedMoEModularKernel
(
prepare_finalize
,
experts
,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk
=
chunk_by_rank
(
a
,
rank
,
world_size
).
to
(
device
)
chunk_topk_weight
=
chunk_by_rank
(
topk_weight
,
rank
,
world_size
).
to
(
device
)
chunk_topk_ids
=
chunk_by_rank
(
topk_ids
,
rank
,
world_size
).
to
(
device
)
# Chunking weights like this only works for batched format
w1_chunk
=
chunk_by_rank
(
w1
,
rank
,
world_size
).
to
(
device
)
w2_chunk
=
chunk_by_rank
(
w2
,
rank
,
world_size
).
to
(
device
)
if
use_compile
:
_fused_experts
=
torch
.
compile
(
fused_experts
,
backend
=
'inductor'
,
fullgraph
=
True
)
else
:
_fused_experts
=
fused_experts
out
=
_fused_experts
(
a_chunk
,
w1_chunk
,
w2_chunk
,
chunk_topk_weight
,
chunk_topk_ids
,
global_num_experts
=
num_experts
)
if
use_cudagraphs
:
out
.
fill_
(
0
)
stream
=
torch
.
cuda
.
Stream
()
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
,
stream
=
stream
):
out
=
_fused_experts
(
a_chunk
,
w1_chunk
,
w2_chunk
,
chunk_topk_weight
,
chunk_topk_ids
,
global_num_experts
=
num_experts
)
torch
.
cuda
.
synchronize
()
graph
.
replay
()
torch
.
cuda
.
synchronize
()
ata
.
destroy
()
return
out
def
_batched_moe
(
pgi
,
dp_size
,
a
,
w1
,
w2
,
topk_weight
,
topk_ids
):
assert
torch
.
cuda
.
current_device
()
==
pgi
.
local_rank
num_experts
=
w1
.
shape
[
0
]
device
=
pgi
.
device
rank
=
pgi
.
rank
world_size
=
pgi
.
world_size
max_num_tokens
=
rank_chunk
(
a
.
shape
[
0
],
0
,
world_size
)
prepare_finalize
=
BatchedPrepareAndFinalize
(
max_num_tokens
=
max_num_tokens
,
world_size
=
world_size
,
dp_size
=
dp_size
,
rank
=
rank
,
)
experts
=
BatchedExperts
(
max_num_tokens
=
a
.
shape
[
0
],
world_size
=
1
,
dp_size
=
1
)
fused_experts
=
FusedMoEModularKernel
(
prepare_finalize
,
experts
,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk
=
chunk_by_rank
(
a
,
rank
,
world_size
).
to
(
device
)
chunk_topk_weight
=
chunk_by_rank
(
topk_weight
,
rank
,
world_size
).
to
(
device
)
chunk_topk_ids
=
chunk_by_rank
(
topk_ids
,
rank
,
world_size
).
to
(
device
)
out
=
fused_experts
(
a_chunk
,
# Chunking weights like this only works for batched format
chunk_by_rank
(
w1
,
rank
,
world_size
).
to
(
device
),
chunk_by_rank
(
w2
,
rank
,
world_size
).
to
(
device
),
chunk_topk_weight
,
chunk_topk_ids
,
global_num_experts
=
num_experts
)
return
out
def
_pplx_moe
(
pgi
:
ProcessGroupInfo
,
dp_size
:
int
,
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
score
:
torch
.
Tensor
,
topk
:
int
,
):
uid
=
nvshmem_get_unique_id
(
)
if
pgi
.
rank
==
0
else
nvshmem_alloc_empty_unique_id
()
torch
.
distributed
.
broadcast
(
uid
,
src
=
0
)
nvshmem_init
(
uid
,
pgi
.
rank
,
pgi
.
world_size
)
m
,
k
=
a
.
shape
e
,
_
,
n
=
w2
.
shape
moe_config
=
get_default_config
(
m
,
e
,
n
,
k
,
topk
,
a
.
dtype
,
False
)
with
set_current_vllm_config
(
vllm_config
),
override_config
(
moe_config
):
topk_weight
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
torch_output
=
torch_moe2
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
)
pplx_output
=
pplx_moe
(
pgi
.
rank
,
pgi
.
world_size
,
dp_size
,
a
,
w1
,
w2
,
topk_weight
,
topk_ids
)
# TODO (bnell): fix + re-enable
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
# topk_ids)
torch_output
=
chunk_by_rank
(
torch_output
,
pgi
.
rank
,
pgi
.
world_size
).
to
(
pplx_output
.
device
)
torch
.
testing
.
assert_close
(
pplx_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
#torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)
nvshmem_finalize
()
@
pytest
.
mark
.
parametrize
(
"mnk"
,
PPLX_MOE_COMBOS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[[
2
,
1
]])
@
requires_pplx
def
test_pplx_moe
(
mnk
:
tuple
[
int
,
int
,
int
],
e
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
world_dp_size
:
tuple
[
int
,
int
],
):
current_platform
.
seed_everything
(
7
)
m
,
n
,
k
=
mnk
world_size
,
dp_size
=
world_dp_size
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
parallel_launch
(
world_size
,
_pplx_moe
,
dp_size
,
a
,
w1
,
w2
,
score
,
topk
)
tests/kernels/moe/test_triton_moe_ptpc_fp8.py
View file @
f9c069c8
...
@@ -7,6 +7,7 @@ import pytest
...
@@ -7,6 +7,7 @@ import pytest
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -15,6 +16,10 @@ if current_platform.get_device_capability() < (9, 0):
...
@@ -15,6 +16,10 @@ if current_platform.get_device_capability() < (9, 0):
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
allow_module_level
=
True
)
allow_module_level
=
True
)
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
def
native_w8a8_per_token_matmul
(
A
,
B
,
As
,
Bs
,
output_dtype
=
torch
.
float16
):
def
native_w8a8_per_token_matmul
(
A
,
B
,
As
,
Bs
,
output_dtype
=
torch
.
float16
):
"""Matrix multiplication function that supports per-token input
"""Matrix multiplication function that supports per-token input
...
@@ -137,6 +142,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
...
@@ -137,6 +142,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
w2_s
=
torch
.
rand
(
E
,
K
,
device
=
w2_fp32
.
device
)
*
factor_for_scale
w2_s
=
torch
.
rand
(
E
,
K
,
device
=
w2_fp32
.
device
)
*
factor_for_scale
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
with
set_current_vllm_config
(
vllm_config
):
ref_out
=
torch_w8a8_per_column_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
topk
)
ref_out
=
torch_w8a8_per_column_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
topk
)
out
=
fused_moe
(
out
=
fused_moe
(
a
,
a
,
...
...
tests/kernels/quantization/test_block_fp8.py
View file @
f9c069c8
...
@@ -11,7 +11,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
...
@@ -11,7 +11,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
deep_gemm_moe_fp8
)
_valid_deep_gemm_shape
,
deep_gemm_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
moe_align_block_size
)
...
@@ -30,6 +30,10 @@ if current_platform.get_device_capability() < (9, 0):
...
@@ -30,6 +30,10 @@ if current_platform.get_device_capability() < (9, 0):
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
allow_module_level
=
True
)
allow_module_level
=
True
)
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
# Test configurations
# Test configurations
DTYPES
=
[
torch
.
bfloat16
]
# [torch.half, torch.bfloat16, torch.float32]
DTYPES
=
[
torch
.
bfloat16
]
# [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS
=
[
7
,
83
,
2048
]
NUM_TOKENS
=
[
7
,
83
,
2048
]
...
@@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
...
@@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
# Set the context to avoid lots of warning spam.
# Set the context to avoid lots of warning spam.
vllm_config
=
VllmConfig
()
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
out
=
fused_moe
(
out
=
fused_moe
(
a
,
a
,
...
@@ -258,6 +261,7 @@ def per_block_cast_to_fp8(
...
@@ -258,6 +261,7 @@ def per_block_cast_to_fp8(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"M,N,K,block_size,out_dtype,seed"
,
"M,N,K,block_size,out_dtype,seed"
,
itertools
.
product
(
M
,
N
,
K
,
BLOCK_SIZE
,
OUT_DTYPES
,
SEEDS
))
itertools
.
product
(
M
,
N
,
K
,
BLOCK_SIZE
,
OUT_DTYPES
,
SEEDS
))
@
pytest
.
mark
.
skipif
(
not
dg_available
,
reason
=
"DeepGemm kernels not available."
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_w8a8_block_fp8_deep_gemm_matmul
(
M
,
N
,
K
,
block_size
,
out_dtype
,
seed
):
def
test_w8a8_block_fp8_deep_gemm_matmul
(
M
,
N
,
K
,
block_size
,
out_dtype
,
seed
):
# only aligned sizes
# only aligned sizes
...
@@ -381,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
...
@@ -381,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
block_size
=
[
block_m
,
block_m
]
block_size
=
[
block_m
,
block_m
]
dtype
=
torch
.
bfloat16
dtype
=
torch
.
bfloat16
# only aligned sizes
if
topk
>
E
:
if
(
N
%
block_m
!=
0
or
K
%
block_m
!=
0
or
topk
>
E
):
pytest
.
skip
(
f
"Skipping test: topk=
{
topk
}
> E=
{
E
}
"
)
pytest
.
skip
(
f
"Skipping test; bad size m=
{
M
}
, n=
{
N
}
, k=
{
K
}
, topk=
{
topk
}
, E=
{
E
}
"
)
if
N
<=
512
:
pytest
.
skip
(
"Skipping N <= 512 until performance issues solved."
)
vllm_config
=
VllmConfig
()
if
not
_valid_deep_gemm_shape
(
M
,
N
,
K
):
pytest
.
skip
(
f
"Skipping test: invalid size m=
{
M
}
, n=
{
N
}
, k=
{
K
}
"
)
torch
.
manual_seed
(
seed
)
torch
.
manual_seed
(
seed
)
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
...
...
tests/kernels/quantization/test_block_int8.py
View file @
f9c069c8
...
@@ -18,6 +18,10 @@ if current_platform.get_device_capability() < (7, 0):
...
@@ -18,6 +18,10 @@ if current_platform.get_device_capability() < (7, 0):
pytest
.
skip
(
"INT8 Triton requires CUDA 7.0 or higher"
,
pytest
.
skip
(
"INT8 Triton requires CUDA 7.0 or higher"
,
allow_module_level
=
True
)
allow_module_level
=
True
)
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
# For test
# For test
def
native_per_token_group_quant_int8
(
x
,
def
native_per_token_group_quant_int8
(
x
,
...
@@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
...
@@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
# Set the context to avoid lots of warning spam.
# Set the context to avoid lots of warning spam.
vllm_config
=
VllmConfig
()
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
out
=
fused_moe
(
out
=
fused_moe
(
a
,
a
,
...
...
vllm/distributed/parallel_state.py
View file @
f9c069c8
...
@@ -23,6 +23,7 @@ If you only need to use the distributed environment without model/pipeline
...
@@ -23,6 +23,7 @@ If you only need to use the distributed environment without model/pipeline
"""
"""
import
contextlib
import
contextlib
import
gc
import
gc
import
importlib.util
import
pickle
import
pickle
import
weakref
import
weakref
from
collections
import
namedtuple
from
collections
import
namedtuple
...
@@ -42,7 +43,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
...
@@ -42,7 +43,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
direct_register_custom_op
,
resolve_obj_by_qualname
,
from
vllm.utils
import
(
direct_register_custom_op
,
resolve_obj_by_qualname
,
supports_custom_op
)
run_once
,
supports_custom_op
)
@
dataclass
@
dataclass
...
@@ -936,9 +937,49 @@ def init_distributed_environment(
...
@@ -936,9 +937,49 @@ def init_distributed_environment(
"world group already initialized with a different world size"
)
"world group already initialized with a different world size"
)
PPLX_DID_INIT
:
bool
=
False
@
run_once
def
pplx_init
(
rank
,
world_size
):
has_pplx
=
importlib
.
util
.
find_spec
(
"pplx_kernels"
)
is
not
None
if
has_pplx
and
world_size
>
1
:
from
pplx_kernels.nvshmem
import
(
nvshmem_alloc_empty_unique_id
,
nvshmem_get_unique_id
,
nvshmem_init
)
try
:
global
PPLX_DID_INIT
logger
.
debug
(
"Initialize NVSHMEM for PPLX kernels: rank=%d, "
"world size=%d"
,
rank
,
world_size
)
uid
=
nvshmem_get_unique_id
(
)
if
rank
==
0
else
nvshmem_alloc_empty_unique_id
()
uid_gpu
=
uid
.
cuda
()
get_world_group
().
broadcast
(
uid_gpu
,
src
=
0
)
uid
=
uid_gpu
.
to
(
device
=
'cpu'
)
logger
.
debug
(
"PPLX NVSHMEM UID = %s"
,
uid
)
nvshmem_init
(
uid
,
rank
,
world_size
)
PPLX_DID_INIT
=
True
except
Exception
as
ex
:
logger
.
error
(
"Failed to initialize NVSHMEM for PPLX: %s"
,
ex
)
@
run_once
def
pplx_finalize
():
global
PPLX_DID_INIT
if
PPLX_DID_INIT
:
from
pplx_kernels.nvshmem
import
nvshmem_finalize
logger
.
debug
(
"PPLX NVSHMEM finalize"
)
from
vllm.model_executor.layers.fused_moe.layer
import
(
_all_to_all_cache
)
_all_to_all_cache
.
destroy
()
nvshmem_finalize
()
def
initialize_model_parallel
(
def
initialize_model_parallel
(
tensor_model_parallel_size
:
int
=
1
,
tensor_model_parallel_size
:
int
=
1
,
pipeline_model_parallel_size
:
int
=
1
,
pipeline_model_parallel_size
:
int
=
1
,
enable_expert_parallel
:
bool
=
False
,
backend
:
Optional
[
str
]
=
None
,
backend
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
"""
"""
...
@@ -1041,10 +1082,14 @@ def initialize_model_parallel(
...
@@ -1041,10 +1082,14 @@ def initialize_model_parallel(
_DP
.
rank_in_group
,
_PP
.
rank_in_group
,
_TP
.
rank_in_group
,
_DP
.
rank_in_group
,
_PP
.
rank_in_group
,
_TP
.
rank_in_group
,
_EP
.
rank_in_group
)
_EP
.
rank_in_group
)
if
enable_expert_parallel
:
pplx_init
(
rank
,
world_size
)
def
ensure_model_parallel_initialized
(
def
ensure_model_parallel_initialized
(
tensor_model_parallel_size
:
int
,
tensor_model_parallel_size
:
int
,
pipeline_model_parallel_size
:
int
,
pipeline_model_parallel_size
:
int
,
enable_expert_parallel
:
bool
=
False
,
backend
:
Optional
[
str
]
=
None
,
backend
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
"""Helper to initialize model parallel groups if they are not initialized,
"""Helper to initialize model parallel groups if they are not initialized,
...
@@ -1055,7 +1100,8 @@ def ensure_model_parallel_initialized(
...
@@ -1055,7 +1100,8 @@ def ensure_model_parallel_initialized(
get_world_group
().
device_group
)
get_world_group
().
device_group
)
if
not
model_parallel_is_initialized
():
if
not
model_parallel_is_initialized
():
initialize_model_parallel
(
tensor_model_parallel_size
,
initialize_model_parallel
(
tensor_model_parallel_size
,
pipeline_model_parallel_size
,
backend
)
pipeline_model_parallel_size
,
enable_expert_parallel
,
backend
)
return
return
assert
(
assert
(
...
@@ -1133,6 +1179,9 @@ def get_tensor_model_parallel_rank():
...
@@ -1133,6 +1179,9 @@ def get_tensor_model_parallel_rank():
def
destroy_model_parallel
():
def
destroy_model_parallel
():
"""Set the groups to none and destroy them."""
"""Set the groups to none and destroy them."""
global
_TP
global
_TP
pplx_finalize
()
if
_TP
:
if
_TP
:
_TP
.
destroy
()
_TP
.
destroy
()
_TP
=
None
_TP
=
None
...
...
vllm/distributed/utils.py
View file @
f9c069c8
...
@@ -23,7 +23,7 @@ from torch.distributed.rendezvous import rendezvous
...
@@ -23,7 +23,7 @@ from torch.distributed.rendezvous import rendezvous
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_tcp_uri
from
vllm.utils
import
get_tcp_uri
,
is_torch_equal_or_newer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -362,12 +362,11 @@ def stateless_destroy_torch_distributed_process_group(
...
@@ -362,12 +362,11 @@ def stateless_destroy_torch_distributed_process_group(
Destroy ProcessGroup returned by
Destroy ProcessGroup returned by
stateless_init_torch_distributed_process_group().
stateless_init_torch_distributed_process_group().
"""
"""
if
is_torch_equal_or_newer
(
"2.7"
):
pg
.
shutdown
()
else
:
# Lazy import for non-CUDA backends.
# Lazy import for non-CUDA backends.
try
:
# pytorch <= 2.6
from
torch.distributed.distributed_c10d
import
_shutdown_backend
from
torch.distributed.distributed_c10d
import
_shutdown_backend
_shutdown_backend
(
pg
)
_shutdown_backend
(
pg
)
except
ImportError
:
# pytorch >= 2.7
pg
.
shutdown
()
_unregister_process_group
(
pg
.
group_name
)
_unregister_process_group
(
pg
.
group_name
)
vllm/forward_context.py
View file @
f9c069c8
...
@@ -27,6 +27,7 @@ batchsize_forward_time: defaultdict = defaultdict(list)
...
@@ -27,6 +27,7 @@ batchsize_forward_time: defaultdict = defaultdict(list)
@
dataclass
@
dataclass
class
DPMetadata
:
class
DPMetadata
:
max_tokens_across_dp_cpu
:
torch
.
Tensor
cu_tokens_across_dp_cpu
:
torch
.
Tensor
cu_tokens_across_dp_cpu
:
torch
.
Tensor
...
@@ -90,8 +91,10 @@ def set_forward_context(attn_metadata: Any,
...
@@ -90,8 +91,10 @@ def set_forward_context(attn_metadata: Any,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
from
vllm.distributed.parallel_state
import
get_dp_group
from
vllm.distributed.parallel_state
import
get_dp_group
dist
.
all_reduce
(
num_tokens_tensor
,
group
=
get_dp_group
().
cpu_group
)
dist
.
all_reduce
(
num_tokens_tensor
,
group
=
get_dp_group
().
cpu_group
)
max_tokens_across_dp_cpu
=
torch
.
max
(
num_tokens_tensor
)
cu_tokens_across_dp_cpu
=
torch
.
cumsum
(
num_tokens_tensor
,
dim
=
0
)
cu_tokens_across_dp_cpu
=
torch
.
cumsum
(
num_tokens_tensor
,
dim
=
0
)
dp_metadata
=
DPMetadata
(
cu_tokens_across_dp_cpu
)
dp_metadata
=
DPMetadata
(
max_tokens_across_dp_cpu
,
cu_tokens_across_dp_cpu
)
global
_forward_context
global
_forward_context
prev_context
=
_forward_context
prev_context
=
_forward_context
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
f9c069c8
...
@@ -38,8 +38,8 @@ if HAS_TRITON:
...
@@ -38,8 +38,8 @@ if HAS_TRITON:
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
cutlass_moe_fp4
,
cutlass_moe_fp8
)
cutlass_moe_fp4
,
cutlass_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_moe
,
fused_topk
,
get_config_file_name
,
TritonExperts
,
fused_experts
,
fused_moe
,
fused_topk
,
grouped_topk
)
get_config_file_name
,
grouped_topk
)
__all__
+=
[
__all__
+=
[
"fused_moe"
,
"fused_moe"
,
...
@@ -49,4 +49,5 @@ if HAS_TRITON:
...
@@ -49,4 +49,5 @@ if HAS_TRITON:
"grouped_topk"
,
"grouped_topk"
,
"cutlass_moe_fp8"
,
"cutlass_moe_fp8"
,
"cutlass_moe_fp4"
,
"cutlass_moe_fp4"
,
"TritonExperts"
,
]
]
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
f9c069c8
...
@@ -5,10 +5,176 @@ from typing import Optional
...
@@ -5,10 +5,176 @@ from typing import Optional
import
torch
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
)
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_perm
,
_resize_cache
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
class
CutlassExpertsFp8
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
ab_strides1
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
):
super
().
__init__
()
self
.
ab_strides1
=
ab_strides1
self
.
c_strides1
=
c_strides1
self
.
ab_strides2
=
ab_strides2
self
.
c_strides2
=
c_strides2
self
.
out_dtype
=
out_dtype
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
num_experts
:
int
,
)
->
tuple
[
int
,
int
,
torch
.
dtype
]:
# Note that K, N are transposed
N
,
K
=
K
,
N
workspace1
=
M
*
topk
*
max
(
2
*
N
,
K
)
workspace2
=
M
*
topk
*
N
return
(
workspace1
,
workspace2
,
self
.
out_dtype
)
def
apply
(
self
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
w1_zp
:
Optional
[
torch
.
Tensor
],
w2_zp
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
a1q
=
hidden_states
assert
w1_scale
is
not
None
assert
w2_scale
is
not
None
assert
w1
.
dtype
==
torch
.
float8_e4m3fn
assert
w2
.
dtype
==
torch
.
float8_e4m3fn
assert
a1q
.
shape
[
1
]
==
w1
.
shape
[
1
],
"Hidden size mismatch w1"
assert
w1
.
shape
[
2
]
==
w2
.
shape
[
1
]
*
2
,
"Hidden size mismatch w2"
assert
w1
.
shape
[
0
]
==
w2
.
shape
[
0
],
"Expert number mismatch"
assert
a1q_scale
is
None
or
a1q_scale
.
dim
(
)
==
0
or
a1q_scale
.
shape
[
0
]
==
1
or
a1q_scale
.
shape
[
0
]
==
a1q
.
shape
[
0
],
"Input scale shape mismatch"
assert
w1_scale
.
dim
()
==
1
or
w1_scale
.
shape
[
1
]
==
1
or
w1_scale
.
shape
[
1
]
==
w1
.
shape
[
2
],
"W1 scale shape mismatch"
assert
w2_scale
.
dim
()
==
1
or
w2_scale
.
shape
[
1
]
==
1
or
w2_scale
.
shape
[
1
]
==
w2
.
shape
[
2
],
"W2 scale shape mismatch"
assert
w1
.
shape
[
0
]
==
w2
.
shape
[
0
],
"Weights expert number mismatch"
assert
w1
.
shape
[
0
]
==
w1_scale
.
shape
[
0
],
"w1 scales expert number mismatch"
assert
w1
.
shape
[
0
]
==
w2_scale
.
shape
[
0
],
"w2 scales expert number mismatch"
assert
a2_scale
is
None
or
a1q_scale
is
None
or
a2_scale
.
shape
==
a1q_scale
.
shape
,
"Intermediate scale shape mismatch"
# noqa: E501
assert
self
.
ab_strides1
.
shape
[
0
]
==
w1
.
shape
[
0
],
"AB Strides 1 expert number mismatch"
assert
self
.
c_strides1
.
shape
[
0
]
==
w1
.
shape
[
0
],
"C Strides 1 expert number mismatch"
assert
self
.
ab_strides2
.
shape
[
0
]
==
w2
.
shape
[
0
],
"AB Strides 2 expert number mismatch"
assert
self
.
c_strides2
.
shape
[
0
]
==
w2
.
shape
[
0
],
"C Strides 2 expert number mismatch"
assert
self
.
out_dtype
in
[
torch
.
half
,
torch
.
bfloat16
],
"Invalid output dtype"
M
=
a1q
.
shape
[
0
]
_
,
N
,
K
=
w2
.
shape
# because w1 + w2 are transposed
device
=
a1q
.
device
assert
w1
.
shape
[
1
]
==
K
assert
global_num_experts
!=
-
1
assert
a1q_scale
is
not
None
if
expert_map
is
not
None
:
"Translate info from expert_map to topk_ids"
local_topk_ids
=
torch
.
where
(
expert_map
[
topk_ids
]
!=
-
1
,
expert_map
[
topk_ids
],
-
1
)
else
:
local_topk_ids
=
topk_ids
topk
=
local_topk_ids
.
shape
[
1
]
per_act_token
=
a1q_scale
.
numel
()
!=
1
if
a1q_scale
is
not
None
else
(
a2_scale
.
numel
()
!=
1
if
a2_scale
is
not
None
else
False
)
expert_offsets
=
torch
.
empty
((
global_num_experts
+
1
),
dtype
=
torch
.
int32
,
device
=
device
)
problem_sizes1
=
torch
.
empty
((
global_num_experts
,
3
),
dtype
=
torch
.
int32
,
device
=
device
)
problem_sizes2
=
torch
.
empty
((
global_num_experts
,
3
),
dtype
=
torch
.
int32
,
device
=
device
)
# With expert_map each Rank processes only a subset of experts. As
# a result not all of a_map and c2 tensors are filled. We fill it
# zeros for correctness.
if
expert_map
is
not
None
:
a_map
=
torch
.
zeros
((
local_topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
else
:
a_map
=
torch
.
empty
((
local_topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
c_map
=
torch
.
empty
((
local_topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
ops
.
get_cutlass_moe_mm_data
(
local_topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
a_map
,
c_map
,
global_num_experts
,
N
,
K
)
a1q
=
_fp8_perm
(
a1q
,
a_map
)
a1q_scale
=
a1q_scale
[
a_map
]
if
per_act_token
else
a1q_scale
c1
=
_resize_cache
(
workspace13
,
(
M
*
topk
,
N
*
2
))
c2
=
_resize_cache
(
workspace2
,
(
M
*
topk
,
N
))
c3
=
_resize_cache
(
workspace13
,
(
M
*
topk
,
K
))
ops
.
cutlass_moe_mm
(
c1
,
a1q
,
w1
,
a1q_scale
,
w1_scale
,
expert_offsets
[:
-
1
],
problem_sizes1
,
self
.
ab_strides1
,
self
.
ab_strides1
,
self
.
c_strides1
)
self
.
activation
(
activation
,
c2
,
c1
)
a2q
,
a2q_scale
=
ops
.
scaled_fp8_quant
(
c2
,
a2_scale
,
use_per_token_if_dynamic
=
per_act_token
)
if
expert_map
is
not
None
:
c3
.
fill_
(
0
)
ops
.
cutlass_moe_mm
(
c3
,
a2q
,
w2
,
a2q_scale
,
w2_scale
,
expert_offsets
[:
-
1
],
problem_sizes2
,
self
.
ab_strides2
,
self
.
ab_strides2
,
self
.
c_strides2
)
c3
=
c3
[
c_map
]
return
c3
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
def
cutlass_moe_fp8
(
def
cutlass_moe_fp8
(
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
...
@@ -17,7 +183,7 @@ def cutlass_moe_fp8(
...
@@ -17,7 +183,7 @@ def cutlass_moe_fp8(
w1_scale
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
_
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
...
@@ -59,7 +225,7 @@ def cutlass_moe_fp8(
...
@@ -59,7 +225,7 @@ def cutlass_moe_fp8(
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
quantize the intermediate result between the gemms.
Shape: scalar or [M]
Shape: scalar or [M]
- out_dtype (torch.
Tensor
): The output tensor type.
- out_dtype (torch.
dtype
): The output tensor type.
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
mapping from global expert-id to local expert-id. When expert_map[i]
...
@@ -71,115 +237,36 @@ def cutlass_moe_fp8(
...
@@ -71,115 +237,36 @@ def cutlass_moe_fp8(
Returns:
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
"""
assert
topk_weights
.
shape
==
topk_ids_
.
shape
,
"topk shape mismatch"
assert
w1_q
.
dtype
==
torch
.
float8_e4m3fn
assert
w2_q
.
dtype
==
torch
.
float8_e4m3fn
assert
a
.
shape
[
1
]
==
w1_q
.
shape
[
1
],
"Hidden size mismatch w1"
assert
w1_q
.
shape
[
2
]
==
w2_q
.
shape
[
1
]
*
2
,
"Hidden size mismatch w2"
assert
w1_q
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"Expert number mismatch"
assert
a1_scale
is
None
or
a1_scale
.
dim
(
)
==
0
or
a1_scale
.
shape
[
0
]
==
1
or
a1_scale
.
shape
[
0
]
==
a
.
shape
[
0
],
"Input scale shape mismatch"
assert
w1_scale
.
dim
()
==
1
or
w1_scale
.
shape
[
1
]
==
1
or
w1_scale
.
shape
[
1
]
==
w1_q
.
shape
[
2
],
"W1 scale shape mismatch"
assert
w2_scale
.
dim
()
==
1
or
w2_scale
.
shape
[
1
]
==
1
or
w2_scale
.
shape
[
1
]
==
w2_q
.
shape
[
2
],
"W2 scale shape mismatch"
assert
w1_q
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"Weights expert number mismatch"
assert
w1_q
.
shape
[
0
]
==
w1_scale
.
shape
[
0
],
"w1 scales expert number mismatch"
assert
w1_q
.
shape
[
0
]
==
w2_scale
.
shape
[
0
],
"w2 scales expert number mismatch"
assert
a2_scale
is
None
or
a1_scale
is
None
or
a2_scale
.
shape
==
a1_scale
.
shape
,
"Intermediate scale shape mismatch"
# noqa: E501
assert
ab_strides1
.
shape
[
0
]
==
w1_q
.
shape
[
0
],
"AB Strides 1 expert number mismatch"
assert
c_strides1
.
shape
[
0
]
==
w1_q
.
shape
[
0
],
"C Strides 1 expert number mismatch"
assert
ab_strides2
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"AB Strides 2 expert number mismatch"
assert
c_strides2
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"C Strides 2 expert number mismatch"
assert
out_dtype
in
[
torch
.
half
,
torch
.
bfloat16
],
"Invalid output dtype"
num_experts
=
w1_q
.
size
(
0
)
m
=
a
.
size
(
0
)
k
=
w1_q
.
size
(
1
)
n
=
w2_q
.
size
(
1
)
local_topk_ids
=
topk_ids_
if
expert_map
is
not
None
:
"Translate info from expert_map to topk_ids"
local_topk_ids
=
torch
.
where
(
expert_map
[
topk_ids_
]
!=
-
1
,
expert_map
[
topk_ids_
],
-
1
)
topk
=
local_topk_ids
.
size
(
1
)
per_act_token
=
a1_scale
.
numel
()
!=
1
if
a1_scale
is
not
None
else
(
per_act_token
=
a1_scale
.
numel
()
!=
1
if
a1_scale
is
not
None
else
(
a2_scale
.
numel
()
!=
1
if
a2_scale
is
not
None
else
False
)
a2_scale
.
numel
()
!=
1
if
a2_scale
is
not
None
else
False
)
if
apply_router_weight_on_input
:
assert
topk
==
1
,
\
"apply_router_weight_on_input is only implemented for topk=1"
# TODO: this only works for topK=1, will need to update for topK>1
a
=
a
*
topk_weights
.
to
(
out_dtype
)
a_q
,
a1_scale
=
ops
.
scaled_fp8_quant
(
a
,
a1_scale
,
use_per_token_if_dynamic
=
per_act_token
)
device
=
a_q
.
device
expert_offsets
=
torch
.
empty
((
num_experts
+
1
),
fn
=
mk
.
FusedMoEModularKernel
(
dtype
=
torch
.
int32
,
MoEPrepareAndFinalizeNoEP
(
device
=
device
)
per_channel_quant
=
per_act_token
,
problem_sizes1
=
torch
.
empty
((
num_experts
,
3
),
quant_dtype
=
torch
.
float8_e4m3fn
,
dtype
=
torch
.
int32
,
),
device
=
device
)
CutlassExpertsFp8
(
problem_sizes2
=
torch
.
empty
((
num_experts
,
3
),
ab_strides1
,
dtype
=
torch
.
int32
,
c_strides1
,
device
=
device
)
ab_strides2
,
c_strides2
,
a_map_initializer
=
torch
.
empty
out_dtype
,
c2_initializer
=
torch
.
empty
),
if
expert_map
is
not
None
:
)
# With expert_map each Rank processes only a subset of experts. As
# a result not all of a_map and c2 tensors are filled. We fill it
return
fn
(
# zeros for correctness.
a
,
a_map_initializer
=
torch
.
zeros
w1_q
,
c2_initializer
=
torch
.
zeros
w2_q
,
topk_weights
,
a_map
=
a_map_initializer
((
local_topk_ids
.
numel
()),
topk_ids
,
dtype
=
torch
.
int32
,
expert_map
=
expert_map
,
device
=
device
)
w1_scale
=
w1_scale
,
c_map
=
torch
.
empty
((
local_topk_ids
.
numel
()),
w2_scale
=
w2_scale
,
dtype
=
torch
.
int32
,
a1_scale
=
a1_scale
,
device
=
device
)
a2_scale
=
a2_scale
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
ops
.
get_cutlass_moe_mm_data
(
local_topk_ids
,
expert_offsets
,
problem_sizes1
,
)
problem_sizes2
,
a_map
,
c_map
,
num_experts
,
n
,
k
)
rep_a_q
=
a_q
.
view
(
dtype
=
torch
.
uint8
)[
a_map
].
view
(
dtype
=
a_q
.
dtype
)
rep_a1_scales
=
a1_scale
[
a_map
]
if
per_act_token
else
a1_scale
c1
=
torch
.
empty
((
m
*
topk
,
n
*
2
),
device
=
device
,
dtype
=
out_dtype
)
c2
=
c2_initializer
((
m
*
topk
,
k
),
device
=
device
,
dtype
=
out_dtype
)
ops
.
cutlass_moe_mm
(
c1
,
rep_a_q
,
w1_q
,
rep_a1_scales
,
w1_scale
,
expert_offsets
[:
-
1
],
problem_sizes1
,
ab_strides1
,
ab_strides1
,
c_strides1
)
intermediate
=
torch
.
empty
((
m
*
topk
,
n
),
device
=
device
,
dtype
=
out_dtype
)
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate
,
c1
)
intemediate_q
,
a2_scale
=
ops
.
scaled_fp8_quant
(
intermediate
,
a2_scale
,
use_per_token_if_dynamic
=
per_act_token
)
ops
.
cutlass_moe_mm
(
c2
,
intemediate_q
,
w2_q
,
a2_scale
,
w2_scale
,
expert_offsets
[:
-
1
],
problem_sizes2
,
ab_strides2
,
ab_strides2
,
c_strides2
)
# Gather tokens
c2
=
c2
[
c_map
].
view
(
m
,
topk
,
k
)
if
not
apply_router_weight_on_input
:
c2
=
c2
*
topk_weights
.
view
(
m
,
topk
,
1
).
to
(
out_dtype
)
return
c2
.
sum
(
dim
=
1
)
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1f
.
max
()
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1f
.
max
()
...
...
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
View file @
f9c069c8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
functools
import
importlib.util
import
importlib.util
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
(
moe_align_block_size
)
_moe_permute
)
from
vllm.model_executor.layers.fused_moe.utils
import
(
_fp8_perm
,
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
_fp8_quantize
,
MoEPrepareAndFinalizeNoEP
)
from
vllm.model_executor.layers.fused_moe.utils
import
(
_fp8_quantize
,
_resize_cache
)
_resize_cache
)
from
vllm.utils
import
round_up
from
vllm.utils
import
round_up
...
@@ -19,6 +20,19 @@ logger = init_logger(__name__)
...
@@ -19,6 +20,19 @@ logger = init_logger(__name__)
has_deep_gemm
=
importlib
.
util
.
find_spec
(
"deep_gemm"
)
is
not
None
has_deep_gemm
=
importlib
.
util
.
find_spec
(
"deep_gemm"
)
is
not
None
@
functools
.
cache
def
deep_gemm_block_shape
()
->
list
[
int
]:
# Lazy import to avoid CUDA initialization problems.
import
deep_gemm
as
dg
block
=
dg
.
get_m_alignment_for_contiguous_layout
()
return
[
block
,
block
]
def
_valid_deep_gemm_shape
(
M
:
int
,
N
:
int
,
K
:
int
):
align
=
deep_gemm_block_shape
()[
0
]
return
align
<=
M
and
N
%
align
==
0
and
K
%
align
==
0
def
_valid_deep_gemm
(
hidden_states
:
torch
.
Tensor
,
def
_valid_deep_gemm
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
@@ -29,89 +43,112 @@ def _valid_deep_gemm(hidden_states: torch.Tensor,
...
@@ -29,89 +43,112 @@ def _valid_deep_gemm(hidden_states: torch.Tensor,
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
"""
"""
if
not
has_deep_gemm
:
if
not
has_deep_gemm
:
logger
.
debug
(
"DeepGemm disabled: deep_gemm not available."
)
return
False
return
False
# Lazy import to avoid CUDA initialization problems.
import
deep_gemm
as
dg
# Expert maps not supported yet.
if
expert_map
is
not
None
:
if
expert_map
is
not
None
:
logger
.
debug
(
"DeepGemm disabled: expert map NYI."
)
return
False
return
False
align
=
dg
.
get_m_alignment_for_contiguous_layout
()
M
=
hidden_states
.
size
(
0
)
M
=
hidden_states
.
shape
[
0
]
_
,
K
,
N
=
w2
.
size
()
_
,
K
,
N
=
w2
.
shape
if
not
_valid_deep_gemm_shape
(
M
,
N
,
K
):
logger
.
debug
(
"DeepGemm disabled: unalinged problem size."
)
return
False
# For now, disable DeepGemm for small N until better permute/unpermute
if
(
w1
.
dtype
!=
torch
.
float8_e4m3fn
or
w2
.
dtype
!=
torch
.
float8_e4m3fn
):
# ops are available.
logger
.
debug
(
"DeepGemm disabled: invalid weight dtype(s)."
)
if
N
<=
512
:
return
False
return
False
if
align
>
M
or
N
%
align
!=
0
or
K
%
align
!=
0
:
if
(
not
hidden_states
.
is_contiguous
()
or
not
w1
.
is_contiguous
()
or
not
w2
.
is_contiguous
()):
logger
.
debug
(
"DeepGemm disabled: weights or activations not contiguous."
)
return
False
return
False
return
(
hidden_states
.
is_contiguous
()
and
w1
.
is_contiguous
()
return
True
and
w2
.
is_contiguous
())
def
_moe_permute
(
class
DeepGemmExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
curr_hidden_states
:
torch
.
Tensor
,
a1q_scale
:
Optional
[
torch
.
Tensor
],
def
__init__
(
self
):
curr_topk_ids
:
torch
.
Tensor
,
super
().
__init__
()
self
.
block_shape
=
deep_gemm_block_shape
()
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
num_experts
:
int
,
)
->
tuple
[
int
,
int
,
torch
.
dtype
]:
block_m
=
self
.
block_shape
[
0
]
M_sum
=
(
M
*
topk
)
+
num_experts
*
(
block_m
-
1
)
M_sum
=
round_up
(
M_sum
,
block_m
)
workspace1
=
M_sum
*
max
(
N
*
2
,
K
)
workspace2
=
M_sum
*
N
return
(
workspace1
,
workspace2
,
a
.
dtype
)
def
apply
(
self
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
expert_map
:
Optional
[
torch
.
Tensor
],
block_m
:
int
,
w1_scale
:
Optional
[
torch
.
Tensor
],
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
w2_scale
:
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
w1_zp
:
Optional
[
torch
.
Tensor
],
"""
w2_zp
:
Optional
[
torch
.
Tensor
],
Determine the sorted_token_ids, expert_ids for the given problem size.
a1q_scale
:
Optional
[
torch
.
Tensor
],
Permute the hidden states and scales according to `sorted_token_ids`.
a2_scale
:
Optional
[
torch
.
Tensor
],
"""
workspace13
:
torch
.
Tensor
,
top_k_num
=
curr_topk_ids
.
shape
[
1
]
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
import
deep_gemm
as
dg
tokens_in_chunk
,
_
=
curr_hidden_states
.
shape
a1q
=
hidden_states
_
,
N
,
K
=
w1
.
size
()
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
assert
global_num_experts
!=
-
1
moe_align_block_size
(
curr_topk_ids
,
assert
w2
.
size
(
1
)
==
K
block_m
,
a1q
,
a1q_scale
,
_
,
expert_ids
,
inv_perm
=
_moe_permute
(
a1q
,
a1q_scale
,
topk_ids
,
global_num_experts
,
global_num_experts
,
expert_map
,
expert_map
,
pad_sorted_ids
=
True
))
self
.
block_shape
[
0
],
)
inv_perm
:
Optional
[
torch
.
Tensor
]
=
None
# Note: M_sum is different than the pre-permuted shape of a1q.
M_sum
=
a1q
.
size
(
0
)
workspace1
=
_resize_cache
(
workspace13
,
(
M_sum
,
N
))
workspace2
=
_resize_cache
(
workspace2
,
(
M_sum
,
N
//
2
))
workspace3
=
_resize_cache
(
workspace13
,
(
M_sum
,
K
))
num_tokens
=
top_k_num
*
tokens_in_chunk
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
(
sorted_token_ids
=
sorted_token_ids
.
clamp
(
max
=
num_tokens
-
1
)
(
a1q
,
a1q_scale
),
(
w1
,
w1_scale
),
workspace1
,
expert_ids
)
expert_ids
=
torch
.
repeat_interleave
(
expert_ids
,
block_m
,
dim
=
0
)
inv_perm
=
torch
.
argsort
(
sorted_token_ids
)[:
num_tokens
]
# Permute according to sorted token ids.
self
.
activation
(
activation
,
workspace2
,
workspace1
.
view
(
-
1
,
N
))
curr_hidden_states
=
_fp8_perm
(
curr_hidden_states
,
sorted_token_ids
//
top_k_num
)
if
a1q_scale
is
not
None
:
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
a1q_scale
=
a1q_scale
[
sorted_token_ids
//
top_k_num
]
return
(
curr_hidden_states
,
a
1
q_scale
,
sorted_token_ids
,
expert_ids
,
a2q
,
a
2
q_scale
=
_fp8_quantize
(
workspace2
,
a2_scale
,
False
,
inv_perm
)
self
.
block_shape
)
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
(
(
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
workspace3
,
expert_ids
)
def
_moe_unpermute_and_reduce
(
workspace3
=
workspace3
[
inv_perm
,
...]
out
:
torch
.
Tensor
,
curr_hidden
:
torch
.
Tensor
,
return
workspace3
inv_perm
:
Optional
[
torch
.
Tensor
],
topk_weight
:
torch
.
Tensor
,
)
->
None
:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M
,
topk
=
topk_weight
.
shape
K
=
curr_hidden
.
shape
[
1
]
curr_hidden
=
curr_hidden
[
inv_perm
,
...]
curr_hidden
=
curr_hidden
.
view
(
-
1
,
topk
,
K
)
curr_hidden
.
mul_
(
topk_weight
.
view
(
M
,
-
1
,
1
))
ops
.
moe_sum
(
curr_hidden
,
out
)
def
deep_gemm_moe_fp8
(
def
deep_gemm_moe_fp8
(
...
@@ -128,6 +165,7 @@ def deep_gemm_moe_fp8(
...
@@ -128,6 +165,7 @@ def deep_gemm_moe_fp8(
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
...
@@ -166,129 +204,24 @@ def deep_gemm_moe_fp8(
...
@@ -166,129 +204,24 @@ def deep_gemm_moe_fp8(
Returns:
Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
"""
# Lazy import to avoid CUDA initialization problems.
fn
=
mk
.
FusedMoEModularKernel
(
import
deep_gemm
as
dg
MoEPrepareAndFinalizeNoEP
(
quant_dtype
=
torch
.
float8_e4m3fn
,
block_shape
=
deep_gemm_block_shape
()),
assert
expert_map
is
None
,
"Expert maps not supported yet"
DeepGemmExperts
(),
)
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
"Hidden size mismatch"
return
fn
(
hidden_states
,
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
w1
,
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
w2
,
assert
w1
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
topk_weights
,
assert
w2
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
topk_ids
,
assert
hidden_states
.
dtype
in
[
inplace
,
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
activation
,
]
global_num_experts
,
assert
w1
.
dtype
==
torch
.
float8_e4m3fn
expert_map
,
assert
w2
.
dtype
==
torch
.
float8_e4m3fn
w1_scale
=
w1_scale
,
assert
w1
.
shape
[
0
]
==
w2
.
shape
[
0
],
"Expert number mismatch"
w2_scale
=
w2_scale
,
assert
w1
.
shape
[
0
]
==
w1_scale
.
shape
[
0
],
"w1 scales expert number mismatch"
a1_scale
=
a1_scale
,
assert
w1
.
shape
[
0
]
==
w2_scale
.
shape
[
0
],
"w2 scales expert number mismatch"
a2_scale
=
a2_scale
,
assert
a1_scale
is
None
or
a1_scale
.
dim
(
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
==
0
or
a1_scale
.
shape
[
0
]
==
1
or
a1_scale
.
shape
[
)
0
]
==
hidden_states
.
shape
[
0
],
"Input scale shape mismatch"
assert
a2_scale
is
None
or
a1_scale
is
None
or
a2_scale
.
shape
==
a1_scale
.
shape
,
"Intermediate scale shape mismatch"
# noqa: E501
num_tokens
,
_
=
hidden_states
.
shape
E
,
N
,
_
=
w1
.
shape
K
=
w2
.
shape
[
1
]
if
global_num_experts
==
-
1
:
global_num_experts
=
E
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
assert
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
,
expert_map
)
if
inplace
:
out_hidden_states
=
hidden_states
else
:
out_hidden_states
=
torch
.
empty_like
(
hidden_states
)
block_m
=
dg
.
get_m_alignment_for_contiguous_layout
()
block_shape
=
[
block_m
,
block_m
]
assert
w1_scale
is
not
None
assert
w2_scale
is
not
None
# We attempt to transpose and align offline in Fp8MoEMethod, in which
# case these calls will be nops. Otherwise, they'll be performed every
# time the layer is executed.
w1_scale
=
dg
.
get_col_major_tma_aligned_tensor
(
w1_scale
).
contiguous
()
w2_scale
=
dg
.
get_col_major_tma_aligned_tensor
(
w2_scale
).
contiguous
()
M_sum
=
topk_ids
.
numel
()
+
global_num_experts
*
(
block_m
-
1
)
M_sum
=
round_up
(
M_sum
,
block_m
)
num_chunks
=
(
num_tokens
//
CHUNK_SIZE
)
+
1
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13
=
torch
.
empty
(
M_sum
*
max
(
N
,
K
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
workspace1
=
workspace13
[:
M_sum
*
N
].
view
(
M_sum
,
N
)
workspace2
=
torch
.
empty
((
M_sum
,
N
//
2
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
workspace3
=
workspace13
[:
M_sum
*
K
].
view
(
M_sum
,
K
)
for
chunk
in
range
(
num_chunks
):
begin_chunk_idx
,
end_chunk_idx
=
(
chunk
*
CHUNK_SIZE
,
min
((
chunk
+
1
)
*
CHUNK_SIZE
,
num_tokens
))
curr_hidden_states
=
hidden_states
[
begin_chunk_idx
:
end_chunk_idx
]
tokens_in_chunk
,
_
=
curr_hidden_states
.
shape
if
tokens_in_chunk
==
0
:
break
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
a1q_scale
:
Optional
[
torch
.
Tensor
]
=
None
qcurr_hidden_states
,
a1q_scale
=
_fp8_quantize
(
curr_hidden_states
,
a1_scale
,
block_shape
)
(
qcurr_hidden_states
,
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
)
=
_moe_permute
(
qcurr_hidden_states
,
a1q_scale
,
curr_topk_ids
,
global_num_experts
,
expert_map
,
block_m
)
# Adjust the intermediate cache size and config for the last chunk.
# Note that in most cases we only have one chunk so the cache size
# and config are already set correctly and do not need to be adjusted.
if
tokens_in_chunk
<
CHUNK_SIZE
and
chunk
>
0
:
curr_M
=
sorted_token_ids
.
numel
()
workspace1
=
_resize_cache
(
workspace1
,
(
curr_M
,
N
))
workspace2
=
_resize_cache
(
workspace2
,
(
curr_M
,
N
//
2
))
workspace3
=
_resize_cache
(
workspace3
,
(
curr_M
,
K
))
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
(
(
qcurr_hidden_states
,
a1q_scale
),
(
w1
,
w1_scale
),
workspace1
,
expert_ids
)
if
activation
==
"silu"
:
torch
.
ops
.
_C
.
silu_and_mul
(
workspace2
,
workspace1
.
view
(
-
1
,
N
))
elif
activation
==
"gelu"
:
torch
.
ops
.
_C
.
gelu_and_mul
(
workspace2
,
workspace1
.
view
(
-
1
,
N
))
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
qworkspace2
,
a2q_scale
=
_fp8_quantize
(
workspace2
,
a2_scale
,
block_shape
)
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
(
(
qworkspace2
,
a2q_scale
),
(
w2
,
w2_scale
),
workspace3
,
expert_ids
)
_moe_unpermute_and_reduce
(
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
workspace3
.
view
(
*
workspace3
.
shape
),
inv_perm
,
curr_topk_weights
)
return
out_hidden_states
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
0 → 100644
View file @
f9c069c8
# SPDX-License-Identifier: Apache-2.0
"""Fused batched MoE kernel."""
from
typing
import
Optional
import
torch
import
triton
import
triton.language
as
tl
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
get_config_dtype_str
,
try_get_optimal_moe_config
)
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
@
triton
.
jit
def
moe_mmk
(
a_ptrs
,
b_ptrs
,
K
,
expert_id
,
a_scale_ptr
,
b_scale_ptr
,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_ak
,
stride_bk
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bsk
,
stride_bsn
,
# Offsets and masks
offs_m
,
offs_n
,
mask_m
,
# Block size for block-wise quantization
group_n
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
# Meta-parameters
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
use_w8a8
:
tl
.
constexpr
,
use_w8a16
:
tl
.
constexpr
):
offs_k
=
tl
.
arange
(
0
,
BLOCK_K
)
if
use_w8a16
:
b_scale_ptrs
=
b_scale_ptr
+
expert_id
*
stride_bse
+
offs_n
[
None
,
:]
*
stride_bsn
b_scale
=
tl
.
load
(
b_scale_ptrs
)
if
use_w8a8
:
# block-wise
if
group_k
>
0
and
group_n
>
0
:
a_scale_ptrs
=
a_scale_ptr
+
offs_m
*
stride_asm
offs_bsn
=
offs_n
//
group_n
b_scale_ptrs
=
(
b_scale_ptr
+
expert_id
*
stride_bse
+
offs_bsn
*
stride_bsn
)
# tensor-wise
else
:
a_scale
=
tl
.
load
(
a_scale_ptr
)
b_scale
=
tl
.
load
(
b_scale_ptr
+
expert_id
)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator
=
tl
.
zeros
((
BLOCK_M
,
BLOCK_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_K
)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a
=
tl
.
load
(
a_ptrs
,
mask
=
mask_m
[:,
None
]
&
(
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_K
),
other
=
0.0
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_K
,
other
=
0.0
)
# We accumulate along the K dimension.
if
use_w8a16
:
accumulator
=
tl
.
dot
(
a
,
b
.
to
(
compute_type
),
acc
=
accumulator
)
elif
use_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
k_start
=
k
*
BLOCK_K
offs_ks
=
k_start
//
group_k
a_scale
=
tl
.
load
(
a_scale_ptrs
+
offs_ks
*
stride_ask
,
mask
=
mask_m
,
other
=
0.0
)
b_scale
=
tl
.
load
(
b_scale_ptrs
+
offs_ks
*
stride_bsk
)
accumulator
+=
tl
.
dot
(
a
,
b
)
*
a_scale
[:,
None
]
*
b_scale
[
None
,
:]
else
:
if
use_w8a8
:
# acc used to enable fp8_fast_accum
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
# Advance the ptrs to the next K block.
a_ptrs
+=
BLOCK_K
*
stride_ak
b_ptrs
+=
BLOCK_K
*
stride_bk
if
use_w8a16
:
accumulator
=
(
accumulator
*
b_scale
).
to
(
compute_type
)
elif
use_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
accumulator
=
accumulator
.
to
(
compute_type
)
else
:
accumulator
=
(
accumulator
*
a_scale
*
b_scale
).
to
(
compute_type
)
else
:
accumulator
=
accumulator
.
to
(
compute_type
)
return
accumulator
@
triton
.
jit
def
expert_triton_kernel
(
a_ptr
,
#[max_tokens, K]
b_ptr
,
#[K, N]
c_ptr
,
#[max_tokens, N]
expert_id
,
compute_type
:
tl
.
constexpr
,
# Dimensions
M
,
N
,
K
,
# Quantization data
a_scale_ptr
,
b_scale_ptr
,
b_zp_ptr
,
# strides
stride_am
,
stride_ak
,
stride_bk
,
stride_bn
,
stride_cm
,
stride_cn
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bsk
,
stride_bsn
,
# Blockwise quantization data
group_n
,
group_k
,
# Quantization schemes
use_fp8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
,
# Kernel config
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
):
offs_m
=
tl
.
arange
(
0
,
BLOCK_M
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
%
N
offs_k
=
tl
.
arange
(
0
,
BLOCK_K
)
mask_m
=
offs_m
<
M
a_ptrs
=
a_ptr
+
offs_m
[:,
None
]
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
b_ptrs
=
b_ptr
+
offs_k
[:,
None
]
*
stride_bk
+
offs_n
[
None
,
:]
*
stride_bn
accumulator
=
moe_mmk
(
a_ptrs
,
b_ptrs
,
K
,
expert_id
,
a_scale_ptr
,
b_scale_ptr
,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_ak
,
stride_bk
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bsk
,
stride_bsn
,
# Offsets and masks
offs_m
,
offs_n
,
mask_m
,
# Block size for block-wise quantization
group_n
,
group_k
,
# Meta-parameters
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
compute_type
,
use_fp8_w8a8
,
use_int8_w8a16
)
# store in C
offs_cn
=
tl
.
arange
(
0
,
BLOCK_N
)
c_ptrs
=
c_ptr
+
offs_m
[:,
None
]
*
stride_cm
+
offs_cn
[
None
,
:]
*
stride_cn
c_mask
=
mask_m
[:,
None
]
&
(
offs_cn
[
None
,
:]
<
N
)
tl
.
store
(
c_ptrs
,
accumulator
,
mask
=
c_mask
)
@
triton
.
jit
def
batched_triton_kernel
(
a_ptr
,
# [E, max_num_tokens, K]
b_ptr
,
# [E, K, N]
c_ptr
,
# [E, max_num_tokens, N]
expert_num_tokens
,
# [E]
compute_type
:
tl
.
constexpr
,
# Dimensions
max_num_tokens
,
K
,
N
,
# Quantization data
a_scale_ptr
,
b_scale_ptr
,
b_zp_ptr
,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_ae
,
stride_am
,
stride_ak
,
stride_be
,
stride_bk
,
stride_bn
,
stride_ce
,
stride_cm
,
stride_cn
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bsk
,
stride_bsn
,
# Blockwise quantization data
group_n
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
# Quantization schemes
use_fp8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
,
# Kernel config
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
):
expert_id
=
tl
.
program_id
(
axis
=
0
)
e_num_tokens
=
tl
.
load
(
expert_num_tokens
+
expert_id
)
if
e_num_tokens
==
0
:
# Early exit
return
pid_mn
=
tl
.
program_id
(
axis
=
1
)
#num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M)
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_N
)
pid_m
=
pid_mn
//
num_pid_n
pid_n
=
pid_mn
%
num_pid_n
cta_m_start
=
pid_m
*
BLOCK_M
cta_n_start
=
pid_n
*
BLOCK_N
if
cta_m_start
>=
e_num_tokens
:
# Early exit
return
cta_m_size
=
min
(
BLOCK_M
,
e_num_tokens
-
cta_m_start
)
cta_n_size
=
min
(
BLOCK_N
,
N
-
cta_n_start
)
a_ptr
=
a_ptr
+
expert_id
*
stride_ae
+
cta_m_start
*
stride_am
b_ptr
=
b_ptr
+
expert_id
*
stride_be
+
cta_n_start
*
stride_bn
c_ptr
=
(
c_ptr
+
expert_id
*
stride_ce
+
cta_m_start
*
stride_cm
+
cta_n_start
*
stride_cn
)
expert_triton_kernel
(
a_ptr
,
b_ptr
,
c_ptr
,
expert_id
,
compute_type
,
cta_m_size
,
# M
cta_n_size
,
# N
K
,
# K
a_scale_ptr
,
b_scale_ptr
,
b_zp_ptr
,
# Strides
stride_am
,
stride_ak
,
stride_bk
,
stride_bn
,
stride_cm
,
stride_cn
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bsk
,
stride_bsn
,
# Blockwise quantization data
group_n
,
group_k
,
# Quantization schemes
use_fp8_w8a8
,
use_int8_w8a16
,
# Kernel config
BLOCK_M
,
BLOCK_N
,
BLOCK_K
)
def
invoke_moe_batched_triton_kernel
(
A
:
torch
.
Tensor
,
# [E, max_tokens, K]
B
:
torch
.
Tensor
,
# [E, K, N]
C
:
torch
.
Tensor
,
# [E, max_tokens, N]
expert_num_tokens
:
torch
.
Tensor
,
# [E]
compute_type
:
tl
.
dtype
,
# Quantization data
A_scale
:
torch
.
Tensor
,
B_scale
:
torch
.
Tensor
,
B_zp
:
torch
.
Tensor
,
# Quantization schemes
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
config
:
dict
[
str
,
int
],
block_shape
:
Optional
[
list
[
int
]]
=
None
):
assert
not
use_int4_w4a16
max_num_tokens
=
A
.
size
(
1
)
K
=
A
.
size
(
2
)
N
=
C
.
size
(
2
)
BLOCK_M
=
config
[
'BLOCK_SIZE_M'
]
BLOCK_N
=
config
[
'BLOCK_SIZE_N'
]
BLOCK_K
=
config
[
'BLOCK_SIZE_K'
]
assert
(
torch
.
compiler
.
is_compiling
()
or
torch
.
cuda
.
is_current_stream_capturing
()
or
max_num_tokens
%
BLOCK_M
==
0
)
grid
=
(
expert_num_tokens
.
size
(
0
),
triton
.
cdiv
(
max_num_tokens
,
BLOCK_M
)
*
triton
.
cdiv
(
B
.
size
(
1
),
BLOCK_N
))
batched_triton_kernel
[
grid
](
A
,
B
,
C
,
expert_num_tokens
,
compute_type
,
# Dimensions
max_num_tokens
,
K
,
N
,
# Quantization data
A_scale
,
B_scale
,
B_zp
,
# Strides
A
.
stride
(
0
),
A
.
stride
(
1
),
A
.
stride
(
2
),
B
.
stride
(
0
),
B
.
stride
(
2
),
B
.
stride
(
1
),
C
.
stride
(
0
),
C
.
stride
(
1
),
C
.
stride
(
2
),
A_scale
.
stride
(
0
)
if
A_scale
is
not
None
and
A_scale
.
ndim
==
2
else
0
,
A_scale
.
stride
(
1
)
if
A_scale
is
not
None
and
A_scale
.
ndim
==
2
else
0
,
B_scale
.
stride
(
0
)
if
B_scale
is
not
None
and
B_scale
.
ndim
>=
2
else
0
,
B_scale
.
stride
(
2
)
if
B_scale
is
not
None
and
B_scale
.
ndim
==
3
else
0
,
B_scale
.
stride
(
1
)
if
B_scale
is
not
None
and
B_scale
.
ndim
>=
2
else
0
,
# Blockwise quantization data
0
if
block_shape
is
None
else
block_shape
[
0
],
0
if
block_shape
is
None
else
block_shape
[
1
],
# Quantization schemes
use_fp8_w8a8
,
use_int8_w8a16
,
# Kernel config
BLOCK_M
=
BLOCK_M
,
BLOCK_N
=
BLOCK_N
,
BLOCK_K
=
BLOCK_K
)
class
BatchedPrepareAndFinalize
(
mk
.
FusedMoEPrepareAndFinalize
):
"""
A reference prepare/finalize class that reorganizes the tokens into
expert batched format, i.e. E x max_num_tokens x K. This is the format
that the PPLX dispatch/combine kernels use.
"""
def
__init__
(
self
,
max_num_tokens
:
Optional
[
int
],
world_size
:
int
,
dp_size
:
int
,
rank
:
int
):
super
().
__init__
()
self
.
world_size
=
world_size
self
.
dp_size
=
dp_size
self
.
rank
=
rank
self
.
max_num_tokens
=
max_num_tokens
def
prepare
(
self
,
a1
:
torch
.
Tensor
,
a1_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
assert
a1
.
dim
()
==
2
assert
topk_ids
.
dim
()
==
2
assert
topk_ids
.
size
(
0
)
==
a1
.
size
(
0
)
if
apply_router_weight_on_input
:
topk
=
topk_ids
.
size
(
1
)
# TODO: this only works for topK=1, will need to update for topK>1
assert
topk
==
1
,
\
"apply_router_weight_on_input is only implemented for topk=1"
a1
.
mul_
(
topk_weights
.
to
(
a1
.
dtype
))
num_tokens
,
hidden_dim
=
a1
.
size
()
topk
=
topk_ids
.
size
(
1
)
if
self
.
max_num_tokens
is
None
:
tokens_per_expert
=
torch
.
bincount
(
topk_ids
.
view
(
-
1
),
minlength
=
num_experts
)
self
.
max_num_tokens
=
int
(
tokens_per_expert
.
max
().
item
())
else
:
tokens_per_expert
=
torch
.
zeros
(
num_experts
,
dtype
=
torch
.
int
,
device
=
a1
.
device
)
assert
num_experts
%
self
.
world_size
==
0
num_local_experts
=
num_experts
//
self
.
world_size
b_a1
=
torch
.
zeros
(
(
num_local_experts
,
self
.
max_num_tokens
,
hidden_dim
),
dtype
=
a1
.
dtype
,
device
=
a1
.
device
)
first_expert
=
num_local_experts
*
self
.
rank
last_expert
=
first_expert
+
num_local_experts
for
expert_id
in
range
(
first_expert
,
last_expert
):
topks
=
torch
.
any
(
topk_ids
==
expert_id
,
dim
=
1
).
flatten
()
rows
=
torch
.
count_nonzero
(
topks
.
flatten
())
b_a1
[
expert_id
-
first_expert
,
:
rows
,
:]
=
a1
[:
topks
.
numel
()][
topks
]
tokens_per_expert
[
expert_id
-
first_expert
]
=
rows
return
b_a1
,
a1_scale
,
tokens_per_expert
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
)
->
None
:
num_tokens
=
topk_ids
.
size
(
0
)
num_local_experts
=
fused_expert_output
.
size
(
0
)
K
=
fused_expert_output
.
size
(
-
1
)
assert
output
.
size
(
0
)
==
num_tokens
and
output
.
size
(
1
)
==
K
output
.
fill_
(
0
)
first_expert
=
num_local_experts
*
self
.
rank
last_expert
=
first_expert
+
num_local_experts
for
expert_id
in
range
(
first_expert
,
last_expert
):
matching_tokens
=
topk_ids
==
expert_id
topks
=
torch
.
any
(
matching_tokens
,
dim
=
1
).
flatten
()
rows
=
torch
.
count_nonzero
(
topks
)
rhs
=
fused_expert_output
[
expert_id
-
first_expert
,
:
rows
,
:]
if
not
apply_router_weight_on_input
:
rhs
.
mul_
(
topk_weights
[
matching_tokens
].
view
(
rhs
.
size
(
0
),
1
))
output
[
topks
]
=
output
[
topks
]
+
rhs
class
BatchedExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
"""
A reference MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the pplx
dispatch/combine kernels use.
"""
def
__init__
(
self
,
world_size
:
int
,
dp_size
:
int
,
max_num_tokens
:
Optional
[
int
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
block_m
:
Optional
[
int
]
=
None
,
):
super
().
__init__
()
assert
block_shape
is
None
assert
block_m
is
None
assert
not
use_fp8_w8a8
,
"NYI"
assert
not
use_int8_w8a8
,
"NYI"
assert
not
use_int8_w8a16
,
"NYI"
assert
not
use_int4_w4a16
,
"NYI"
self
.
max_num_tokens
=
max_num_tokens
self
.
world_size
=
world_size
self
.
dp_size
=
dp_size
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
num_experts
:
int
,
)
->
tuple
[
int
,
int
,
torch
.
dtype
]:
assert
a
.
dim
()
==
2
num_dp
=
self
.
world_size
//
self
.
dp_size
max_num_tokens
=
a
.
size
(
0
)
if
self
.
max_num_tokens
is
None
else
self
.
max_num_tokens
#print(f"WORKSPACE {max_num_tokens} {num_dp}")
workspace13
=
num_experts
*
max_num_tokens
*
num_dp
*
K
workspace2
=
max_num_tokens
*
num_dp
*
N
return
(
workspace13
,
workspace2
,
a
.
dtype
)
def
apply
(
self
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
w1_zp
:
Optional
[
torch
.
Tensor
],
w2_zp
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
assert
hidden_states
.
dim
()
==
3
assert
expert_num_tokens
is
not
None
hidden_dim
=
hidden_states
.
size
(
-
1
)
if
self
.
max_num_tokens
is
None
:
max_num_tokens
=
hidden_states
.
size
(
1
)
else
:
max_num_tokens
=
self
.
max_num_tokens
num_dp
=
self
.
world_size
//
self
.
dp_size
num_experts
=
global_num_experts
out
=
_resize_cache
(
workspace13
,
(
num_experts
,
max_num_tokens
*
num_dp
,
hidden_dim
))
num_local_experts
=
w1
.
size
(
0
)
assert
num_local_experts
==
w1
.
size
(
0
),
(
f
"
{
num_local_experts
}
==
{
w1
.
size
(
0
)
}
"
)
N
=
w1
.
size
(
1
)
//
2
# Not cudagraph friendly
assert
(
torch
.
compiler
.
is_compiling
()
or
torch
.
cuda
.
is_current_stream_capturing
()
or
torch
.
all
(
expert_num_tokens
<=
max_num_tokens
*
num_dp
)),
(
f
"
{
expert_num_tokens
}
<=
{
max_num_tokens
*
num_dp
}
"
)
for
expert
in
range
(
num_local_experts
):
# Indexing expert_num_tokens doesn't work w/cudagraphs or inductor
if
(
torch
.
compiler
.
is_compiling
()
or
torch
.
cuda
.
is_current_stream_capturing
()):
num
=
max_num_tokens
*
num_dp
else
:
num
=
int
(
expert_num_tokens
[
expert
].
item
())
tmp
=
_resize_cache
(
workspace2
,
(
num
,
N
))
input
=
hidden_states
[
expert
,
:
num
,
:]
@
w1
[
expert
].
transpose
(
0
,
1
)
self
.
activation
(
activation
,
tmp
,
input
)
out
[
expert
,
:
num
,
:]
=
tmp
@
w2
[
expert
].
transpose
(
0
,
1
)
return
out
class
BatchedTritonExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
"""
A Triton based MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the pplx
dispatch/combine kernels use.
"""
def
__init__
(
self
,
max_num_tokens
:
Optional
[
int
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
world_size
:
int
=
1
,
dp_size
:
int
=
1
,
):
super
().
__init__
()
self
.
use_fp8_w8a8
=
use_fp8_w8a8
self
.
use_int8_w8a8
=
use_int8_w8a8
self
.
use_int4_w4a16
=
use_int4_w4a16
self
.
use_int8_w8a16
=
use_int8_w8a16
self
.
block_shape
=
block_shape
self
.
max_num_tokens
=
max_num_tokens
assert
not
use_int8_w8a8
,
"NYI"
assert
not
use_int4_w4a16
,
"NYI"
self
.
world_size
=
world_size
self
.
dp_size
=
dp_size
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
num_experts
:
int
,
)
->
tuple
[
int
,
int
,
torch
.
dtype
]:
assert
a
.
dim
()
==
2
num_dp
=
self
.
world_size
//
self
.
dp_size
max_num_tokens
=
a
.
size
(
0
)
if
self
.
max_num_tokens
is
None
else
self
.
max_num_tokens
workspace13
=
num_experts
*
max_num_tokens
*
num_dp
*
max
(
K
,
N
)
workspace2
=
num_experts
*
max_num_tokens
*
num_dp
*
(
N
//
2
)
return
(
workspace13
,
workspace2
,
a
.
dtype
)
def
apply
(
self
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
w1_zp
:
Optional
[
torch
.
Tensor
],
w2_zp
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Check constraints.
if
self
.
use_int4_w4a16
:
assert
hidden_states
.
size
(
-
1
)
//
2
==
w1
.
size
(
2
),
(
"Hidden size mismatch"
)
else
:
assert
hidden_states
.
size
(
-
1
)
==
w1
.
size
(
2
),
(
f
"Hidden size mismatch
{
hidden_states
.
size
(
-
1
)
}
"
f
"!=
{
w1
.
size
(
2
)
}
"
)
assert
hidden_states
.
is_contiguous
(
),
"Hidden_states must be contiguous"
assert
w1
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
assert
w2
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
assert
hidden_states
.
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float8_e4m3fn
]
# TODO: num_tokens -> max_num_tokens?
E
,
num_tokens
,
N
,
K
,
top_k_num
=
mk
.
_moe_problem_size
(
hidden_states
,
w1
,
w2
,
topk_ids
)
assert
w1
.
size
(
0
)
==
E
assert
w2
.
size
(
0
)
==
E
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
dtype
=
hidden_states
.
dtype
)
config
=
try_get_optimal_moe_config
(
w1
.
size
(),
w2
.
size
(),
top_k_num
,
config_dtype
,
num_tokens
,
block_shape
=
self
.
block_shape
,
)
if
hidden_states
.
dtype
==
torch
.
bfloat16
:
compute_type
=
tl
.
bfloat16
elif
hidden_states
.
dtype
==
torch
.
float16
:
compute_type
=
tl
.
float16
elif
hidden_states
.
dtype
==
torch
.
float32
:
compute_type
=
tl
.
float32
elif
hidden_states
.
dtype
==
torch
.
float8_e4m3fn
:
compute_type
=
tl
.
bfloat16
else
:
raise
ValueError
(
f
"Unsupported compute_type:
{
hidden_states
.
dtype
}
"
)
#print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1
=
_resize_cache
(
workspace13
,
(
E
,
num_tokens
,
N
))
intermediate_cache2
=
_resize_cache
(
workspace2
,
(
E
,
num_tokens
,
N
//
2
))
intermediate_cache3
=
_resize_cache
(
workspace13
,
(
E
,
num_tokens
,
K
))
# MM1
invoke_moe_batched_triton_kernel
(
A
=
hidden_states
,
B
=
w1
,
C
=
intermediate_cache1
,
expert_num_tokens
=
expert_num_tokens
,
compute_type
=
compute_type
,
A_scale
=
a1q_scale
,
B_scale
=
w1_scale
,
B_zp
=
w1_zp
,
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
config
=
config
,
block_shape
=
self
.
block_shape
)
# TODO: would be nice to use expert_num_tokens here to reduce
# garbage compute
self
.
activation
(
activation
,
intermediate_cache2
.
view
(
-
1
,
N
//
2
),
intermediate_cache1
.
view
(
-
1
,
N
))
#qintermediate_cache2 = intermediate_cache2
a2q_scale
=
a2_scale
# TODO (varun) : support w8a8
assert
not
self
.
use_fp8_w8a8
#if self.use_fp8_w8a8:
# qintermediate_cache2, a2q_scale = _fp8_quantize(
# intermediate_cache2, a2_scale, self.block_shape)
invoke_moe_batched_triton_kernel
(
A
=
intermediate_cache2
,
B
=
w2
,
C
=
intermediate_cache3
,
expert_num_tokens
=
expert_num_tokens
,
compute_type
=
compute_type
,
A_scale
=
a2q_scale
,
B_scale
=
w2_scale
,
B_zp
=
w2_zp
,
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
config
=
config
,
block_shape
=
self
.
block_shape
)
return
intermediate_cache3
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
f9c069c8
...
@@ -8,16 +8,17 @@ from typing import Any, Callable, Optional
...
@@ -8,16 +8,17 @@ from typing import Any, Callable, Optional
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
_valid_deep_gemm
,
deep_gemm_moe_fp8
)
_valid_deep_gemm
,
deep_gemm_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
moe_align_block_size
)
from
vllm.model_executor.layers.
quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.
fused_moe.prepare_finalize
import
(
per_token_group_quant_fp8
)
MoEPrepareAndFinalizeNoEP
)
from
vllm.model_executor.layers.
quantization.utils.int8_
utils
import
(
from
vllm.model_executor.layers.
fused_moe.
utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
_resize_cache
,
moe_kernel_quantize_input
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
...
@@ -484,6 +485,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
...
@@ -484,6 +485,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
if
use_fp8_w8a8
or
use_int8_w8a8
:
assert
B_scale
is
not
None
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_shape
[
0
])
==
B_scale
.
shape
[
-
2
])
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_shape
[
1
])
==
B_scale
.
shape
[
-
1
])
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
else
:
assert
A_scale
is
None
assert
B_scale
is
None
M
=
A
.
shape
[
0
]
M
=
A
.
shape
[
0
]
num_tokens
=
M
*
top_k
num_tokens
=
M
*
top_k
...
@@ -855,6 +870,7 @@ def fused_topk(
...
@@ -855,6 +870,7 @@ def fused_topk(
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
topk
:
int
,
renormalize
:
bool
,
renormalize
:
bool
,
indices_type
:
Optional
[
torch
.
dtype
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
"Number of tokens mismatch"
)
...
@@ -865,9 +881,10 @@ def fused_topk(
...
@@ -865,9 +881,10 @@ def fused_topk(
topk
,
topk
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
hidden_states
.
device
)
device
=
hidden_states
.
device
)
topk_ids
=
torch
.
empty
(
M
,
topk_ids
=
torch
.
empty
(
M
,
topk
,
topk
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
if
indices_type
is
None
else
indices_type
,
device
=
hidden_states
.
device
)
device
=
hidden_states
.
device
)
token_expert_indices
=
torch
.
empty
(
M
,
token_expert_indices
=
torch
.
empty
(
M
,
topk
,
topk
,
...
@@ -962,6 +979,20 @@ def get_config_dtype_str(
...
@@ -962,6 +979,20 @@ def get_config_dtype_str(
return
None
return
None
# TODO (bnell): use scalar_type instead of bools?
def
get_config_qtype
(
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
)
->
Optional
[
torch
.
dtype
]:
if
use_fp8_w8a8
:
return
torch
.
float8_e4m3fn
elif
use_int8_w8a8
:
return
torch
.
int8
return
None
def
inplace_fused_experts
(
hidden_states
:
torch
.
Tensor
,
def
inplace_fused_experts
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
@@ -1128,7 +1159,10 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1128,7 +1159,10 @@ def fused_experts(hidden_states: torch.Tensor,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
allow_deep_gemm
:
bool
=
False
)
->
torch
.
Tensor
:
allow_deep_gemm
:
bool
=
False
)
->
torch
.
Tensor
:
if
(
allow_deep_gemm
and
use_fp8_w8a8
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N
=
w1
.
shape
[
1
]
if
(
allow_deep_gemm
and
use_fp8_w8a8
and
N
>
512
and
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
,
expert_map
)):
and
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
,
expert_map
)):
assert
apply_router_weight_on_input
is
False
assert
apply_router_weight_on_input
is
False
return
deep_gemm_moe_fp8
(
return
deep_gemm_moe_fp8
(
...
@@ -1145,6 +1179,7 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1145,6 +1179,7 @@ def fused_experts(hidden_states: torch.Tensor,
w2_scale
=
w2_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
a2_scale
=
a2_scale
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
)
else
:
else
:
return
dispatch_fused_experts_func
(
inplace
)(
return
dispatch_fused_experts_func
(
inplace
)(
...
@@ -1171,60 +1206,8 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1171,60 +1206,8 @@ def fused_experts(hidden_states: torch.Tensor,
block_shape
=
block_shape
)
block_shape
=
block_shape
)
def
moe_kernel_prepare_input
(
def
fused_experts_impl
(
A
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
B_scale
:
Optional
[
torch
.
Tensor
],
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
if
use_fp8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
# If weights are per-channel (per_channel_quant=True), then
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8 quantization, dynamic or static
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
,
use_per_token_if_dynamic
=
per_channel_quant
)
else
:
# activation block-wise fp8 quantization
assert
len
(
block_shape
)
==
2
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_fp8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif
use_int8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
# activation channel-wise int8 quantization
assert
(
per_channel_quant
),
"int8 quantization only supports block or channel-wise"
A
,
A_scale
=
per_token_quant_int8
(
A
)
else
:
# activation block-wise int8 quantization
assert
len
(
block_shape
)
==
2
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_int8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
else
:
assert
A_scale
is
None
assert
B_scale
is
None
return
A
,
A_scale
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
...
@@ -1245,13 +1228,15 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1245,13 +1228,15 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
):
block_shape
:
Optional
[
list
[
int
]]
=
None
,
)
->
torch
.
Tensor
:
# Check constraints.
# Check constraints.
if
use_int4_w4a16
:
if
use_int4_w4a16
:
assert
hidden_states
.
shape
[
1
]
//
2
==
w1
.
shape
[
assert
hidden_states
.
shape
[
1
]
//
2
==
w1
.
shape
[
2
],
"Hidden size mismatch"
2
],
"Hidden size mismatch"
else
:
else
:
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
"Hidden size mismatch"
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
(
f
"Hidden size mismatch
{
hidden_states
.
shape
[
1
]
}
!=
{
w1
.
shape
[
2
]
}
"
)
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
...
@@ -1261,7 +1246,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1261,7 +1246,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
]
num_tokens
,
_
=
hidden_states
.
shape
num_tokens
=
hidden_states
.
shape
[
0
]
E
,
N
,
_
=
w1
.
shape
E
,
N
,
_
=
w1
.
shape
K
=
w2
.
shape
[
1
]
K
=
w2
.
shape
[
1
]
if
global_num_experts
==
-
1
:
if
global_num_experts
==
-
1
:
...
@@ -1276,6 +1261,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1276,6 +1261,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
dtype
=
hidden_states
.
dtype
)
dtype
=
hidden_states
.
dtype
)
qtype
=
get_config_qtype
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
)
get_config_func
=
functools
.
partial
(
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
try_get_optimal_moe_config
,
w1
.
shape
,
w1
.
shape
,
...
@@ -1338,15 +1328,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1338,15 +1328,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
qcurr_hidden_states
,
q
a1_scale
=
moe_kernel_
prepar
e_input
(
qcurr_hidden_states
,
a1
q
_scale
=
moe_kernel_
quantiz
e_input
(
A
=
curr_hidden_states
,
A
=
curr_hidden_states
,
B
=
w1
,
A_scale
=
a1_scale
,
A_scale
=
a1_scale
,
B_scale
=
w1_scale
,
qtype
=
qtype
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
block_shape
=
block_shape
)
...
@@ -1357,7 +1342,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1357,7 +1342,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel
(
qcurr_hidden_states
,
invoke_fused_moe_kernel
(
qcurr_hidden_states
,
w1
,
w1
,
intermediate_cache1
,
intermediate_cache1
,
q
a1_scale
,
a1
q
_scale
,
w1_scale
,
w1_scale
,
w1_zp
,
w1_zp
,
curr_topk_weights
,
curr_topk_weights
,
...
@@ -1384,22 +1369,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1384,22 +1369,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
else
:
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
qintermediate_cache2
,
q
a2_scale
=
moe_kernel_
prepar
e_input
(
qintermediate_cache2
,
a2
q
_scale
=
moe_kernel_
quantiz
e_input
(
A
=
intermediate_cache2
,
A
=
intermediate_cache2
,
B
=
w2
,
A_scale
=
a2_scale
,
A_scale
=
a2_scale
,
B_scale
=
w2_scale
,
qtype
=
qtype
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
block_shape
=
block_shape
)
invoke_fused_moe_kernel
(
qintermediate_cache2
,
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
w2
,
intermediate_cache3
,
intermediate_cache3
,
q
a2_scale
,
a2
q
_scale
,
w2_scale
,
w2_scale
,
w2_zp
,
w2_zp
,
curr_topk_weights
,
curr_topk_weights
,
...
@@ -1534,3 +1514,209 @@ def fused_moe(
...
@@ -1534,3 +1514,209 @@ def fused_moe(
a1_scale
=
a1_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
)
block_shape
=
block_shape
)
class
TritonExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
block_m
:
Optional
[
int
]
=
None
,
):
super
().
__init__
()
self
.
use_fp8_w8a8
=
use_fp8_w8a8
self
.
use_int4_w4a16
=
use_int4_w4a16
self
.
use_int8_w8a8
=
use_int8_w8a8
self
.
use_int8_w8a16
=
use_int8_w8a16
self
.
block_shape
=
block_shape
self
.
block_m
=
block_m
self
.
qtype
=
get_config_qtype
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
)
self
.
per_channel_quant
=
per_channel_quant
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
num_experts
:
int
,
)
->
tuple
[
int
,
int
,
torch
.
dtype
]:
factor
=
num_experts
if
a
.
dim
()
==
3
else
1
workspace1
=
M
*
topk
*
max
(
N
*
2
,
K
)
*
factor
workspace2
=
M
*
topk
*
N
*
factor
return
(
workspace1
,
workspace2
,
a
.
dtype
)
def
apply
(
self
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
w1_zp
:
Optional
[
torch
.
Tensor
],
w2_zp
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Check constraints.
if
self
.
use_int4_w4a16
:
assert
hidden_states
.
size
(
-
1
)
//
2
==
w1
.
size
(
2
),
(
"Hidden size mismatch"
)
else
:
assert
hidden_states
.
size
(
-
1
)
==
w1
.
size
(
2
),
\
(
f
"Hidden size mismatch
{
hidden_states
.
size
(
-
1
)
}
"
f
"!=
{
w1
.
size
(
2
)
}
"
)
assert
hidden_states
.
is_contiguous
(
),
"Hidden_states must be contiguous"
assert
hidden_states
.
dim
()
==
2
assert
w1
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
assert
w2
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
assert
hidden_states
.
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float8_e4m3fn
]
E
,
num_tokens
,
N
,
K
,
top_k_num
=
mk
.
_moe_problem_size
(
hidden_states
,
w1
,
w2
,
topk_ids
)
if
global_num_experts
==
-
1
:
global_num_experts
=
E
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
dtype
=
hidden_states
.
dtype
)
config
=
try_get_optimal_moe_config
(
w1
.
shape
,
w2
.
shape
,
top_k_num
,
config_dtype
,
num_tokens
,
block_shape
=
self
.
block_shape
,
)
if
hidden_states
.
dtype
==
torch
.
bfloat16
:
compute_type
=
tl
.
bfloat16
elif
hidden_states
.
dtype
==
torch
.
float16
:
compute_type
=
tl
.
float16
elif
hidden_states
.
dtype
==
torch
.
float32
:
compute_type
=
tl
.
float32
elif
hidden_states
.
dtype
==
torch
.
float8_e4m3fn
:
compute_type
=
tl
.
bfloat16
else
:
raise
ValueError
(
f
"Unsupported compute_type:
{
hidden_states
.
dtype
}
"
)
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1
=
_resize_cache
(
workspace13
,
(
num_tokens
,
top_k_num
,
N
))
intermediate_cache2
=
_resize_cache
(
workspace2
,
(
num_tokens
*
top_k_num
,
N
//
2
))
intermediate_cache3
=
_resize_cache
(
workspace13
,
(
num_tokens
,
top_k_num
,
K
))
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
))
invoke_fused_moe_kernel
(
hidden_states
,
w1
,
intermediate_cache1
,
a1q_scale
,
w1_scale
,
w1_zp
,
None
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
False
,
top_k_num
,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_int8_w8a8
=
self
.
use_int8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
per_channel_quant
=
self
.
per_channel_quant
,
block_shape
=
self
.
block_shape
)
self
.
activation
(
activation
,
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
qintermediate_cache2
,
a2q_scale
=
moe_kernel_quantize_input
(
intermediate_cache2
,
a2_scale
,
self
.
qtype
,
self
.
per_channel_quant
,
self
.
block_shape
)
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
intermediate_cache3
,
a2q_scale
,
w2_scale
,
w2_zp
,
None
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
False
,
1
,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_int8_w8a8
=
self
.
use_int8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
per_channel_quant
=
self
.
per_channel_quant
,
block_shape
=
self
.
block_shape
)
return
intermediate_cache3
def
modular_triton_fused_moe
(
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
)
->
mk
.
FusedMoEModularKernel
:
qtype
=
get_config_qtype
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
)
return
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(
quant_dtype
=
qtype
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
),
TritonExperts
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
),
)
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment