Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f9c069c8
Unverified
Commit
f9c069c8
authored
May 14, 2025
by
bnellnm
Committed by
GitHub
May 14, 2025
Browse files
Modularize fused experts and integrate PPLX kernels (#15956)
parent
418d2f8b
Changes
42
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2425 additions
and
539 deletions
+2425
-539
csrc/activation_kernels.cu
csrc/activation_kernels.cu
+3
-0
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+14
-0
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+4
-4
csrc/moe/topk_softmax_kernels.cu
csrc/moe/topk_softmax_kernels.cu
+45
-18
examples/offline_inference/data_parallel.py
examples/offline_inference/data_parallel.py
+16
-6
tests/kernels/moe/test_batched_moe.py
tests/kernels/moe/test_batched_moe.py
+114
-0
tests/kernels/moe/test_cutlass_moe.py
tests/kernels/moe/test_cutlass_moe.py
+21
-25
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+51
-42
tests/kernels/moe/test_pplx_moe.py
tests/kernels/moe/test_pplx_moe.py
+691
-0
tests/kernels/moe/test_triton_moe_ptpc_fp8.py
tests/kernels/moe/test_triton_moe_ptpc_fp8.py
+20
-14
tests/kernels/quantization/test_block_fp8.py
tests/kernels/quantization/test_block_fp8.py
+10
-10
tests/kernels/quantization/test_block_int8.py
tests/kernels/quantization/test_block_int8.py
+4
-1
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+51
-2
vllm/distributed/utils.py
vllm/distributed/utils.py
+6
-7
vllm/forward_context.py
vllm/forward_context.py
+4
-1
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+3
-2
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+195
-108
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+131
-198
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+755
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+287
-101
No files found.
csrc/activation_kernels.cu
View file @
f9c069c8
...
@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
...
@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
int64_t num_tokens = input.numel() / input.size(-1); \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
} \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
VLLM_DISPATCH_FLOATING_TYPES( \
...
...
csrc/dispatch_utils.h
View file @
f9c069c8
...
@@ -65,5 +65,19 @@
...
@@ -65,5 +65,19 @@
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
csrc/moe/moe_align_sum_kernels.cu
View file @
f9c069c8
...
@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
}
}
if
(
use_global_memory
)
{
if
(
use_global_memory
)
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_
AND_UNSIGNED_
TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_global_mem_kernel"
,
[
&
]
{
topk_ids
.
scalar_type
(),
"moe_align_block_size_global_mem_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
// tensors
...
@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
cumsum_buffer
.
data_ptr
<
int32_t
>
());
cumsum_buffer
.
data_ptr
<
int32_t
>
());
});
});
}
else
if
(
use_i16
)
{
}
else
if
(
use_i16
)
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_
AND_UNSIGNED_
TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
// set dynamic shared mem
// set dynamic shared mem
auto
kernel
=
auto
kernel
=
...
@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids
.
numel
());
topk_ids
.
numel
());
});
});
}
else
{
}
else
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_
AND_UNSIGNED_
TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
auto
kernel
=
auto
kernel
=
vllm
::
moe
::
moe_align_block_size_kernel
<
scalar_t
,
int32_t
>
;
vllm
::
moe
::
moe_align_block_size_kernel
<
scalar_t
,
int32_t
>
;
...
@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
TORCH_CHECK
(
num_experts
==
256
,
TORCH_CHECK
(
num_experts
==
256
,
"sgl_moe_align_block_size kernel only supports deepseek v3."
);
"sgl_moe_align_block_size kernel only supports deepseek v3."
);
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_
AND_UNSIGNED_
TYPES
(
topk_ids
.
scalar_type
(),
"sgl_moe_align_block_size_kernel"
,
[
&
]
{
topk_ids
.
scalar_type
(),
"sgl_moe_align_block_size_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `cumsum` tensors
// calc needed amount of shared mem for `cumsum` tensors
auto
options_int
=
auto
options_int
=
...
...
csrc/moe/topk_softmax_kernels.cu
View file @
f9c069c8
...
@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
...
@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
}
}
}
}
template
<
int
TPB
>
template
<
int
TPB
,
typename
IndType
>
__launch_bounds__
(
TPB
)
__global__
void
moeTopK
(
const
float
*
inputs_after_softmax
,
const
bool
*
finished
,
float
*
output
,
__launch_bounds__
(
TPB
)
__global__
void
moeTopK
(
int
*
indices
,
int
*
source_rows
,
const
int
num_experts
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
const
float
*
inputs_after_softmax
,
const
bool
*
finished
,
float
*
output
,
IndType
*
indices
,
int
*
source_rows
,
const
int
num_experts
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
{
{
using
cub_kvp
=
cub
::
KeyValuePair
<
int
,
float
>
;
using
cub_kvp
=
cub
::
KeyValuePair
<
int
,
float
>
;
...
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
...
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
2) This implementation assumes k is small, but will work for any k.
2) This implementation assumes k is small, but will work for any k.
*/
*/
template
<
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
>
template
<
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
,
typename
IndType
>
__launch_bounds__
(
WARPS_PER_CTA
*
WARP_SIZE
)
__global__
__launch_bounds__
(
WARPS_PER_CTA
*
WARP_SIZE
)
__global__
void
topkGatingSoftmax
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_rows
,
int
*
indices
,
void
topkGatingSoftmax
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_rows
,
IndType
*
indices
,
int
*
source_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
int
*
source_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
{
{
// We begin by enforcing compile time assertions and setting up compile time constants.
// We begin by enforcing compile time assertions and setting up compile time constants.
...
@@ -397,8 +405,8 @@ struct TopkConstants
...
@@ -397,8 +405,8 @@ struct TopkConstants
};
};
}
// namespace detail
}
// namespace detail
template
<
int
EXPERTS
,
int
WARPS_PER_TB
>
template
<
int
EXPERTS
,
int
WARPS_PER_TB
,
typename
IndType
>
void
topkGatingSoftmaxLauncherHelper
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
int
*
indices
,
void
topkGatingSoftmaxLauncherHelper
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
IndType
*
indices
,
int
*
source_row
,
const
int
num_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
,
cudaStream_t
stream
)
int
*
source_row
,
const
int
num_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
,
cudaStream_t
stream
)
{
{
static
constexpr
std
::
size_t
MAX_BYTES_PER_LDG
=
16
;
static
constexpr
std
::
size_t
MAX_BYTES_PER_LDG
=
16
;
...
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
...
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
token_expert_indices, num_tokens, topk, 0, num_experts, \
token_expert_indices, num_tokens, topk, 0, num_experts, \
stream);
stream);
template
<
typename
IndType
>
void
topkGatingSoftmaxKernelLauncher
(
void
topkGatingSoftmaxKernelLauncher
(
const
float
*
gating_output
,
const
float
*
gating_output
,
float
*
topk_weights
,
float
*
topk_weights
,
int
*
topk_indicies
,
IndType
*
topk_indicies
,
int
*
token_expert_indices
,
int
*
token_expert_indices
,
float
*
softmax_workspace
,
float
*
softmax_workspace
,
const
int
num_tokens
,
const
int
num_tokens
,
...
@@ -493,6 +502,9 @@ void topk_softmax(
...
@@ -493,6 +502,9 @@ void topk_softmax(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
gating_output
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
gating_output
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
torch
::
Tensor
softmax_workspace
=
torch
::
empty
({
workspace_size
},
gating_output
.
options
());
torch
::
Tensor
softmax_workspace
=
torch
::
empty
({
workspace_size
},
gating_output
.
options
());
if
(
topk_indices
.
scalar_type
()
==
at
::
ScalarType
::
Int
)
{
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
gating_output
.
data_ptr
<
float
>
(),
gating_output
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
...
@@ -503,4 +515,19 @@ void topk_softmax(
...
@@ -503,4 +515,19 @@ void topk_softmax(
num_experts
,
num_experts
,
topk
,
topk
,
stream
);
stream
);
}
else
{
assert
(
topk_indices
.
scalar_type
()
==
at
::
ScalarType
::
UInt32
);
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
gating_output
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
uint32_t
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
stream
);
}
}
}
examples/offline_inference/data_parallel.py
View file @
f9c069c8
...
@@ -65,11 +65,17 @@ def parse_args():
...
@@ -65,11 +65,17 @@ def parse_args():
type
=
int
,
type
=
int
,
default
=
0
,
default
=
0
,
help
=
"Master node port"
)
help
=
"Master node port"
)
parser
.
add_argument
(
"--enforce-eager"
,
action
=
'store_true'
,
help
=
"Enforce eager mode execution."
)
parser
.
add_argument
(
"--trust-remote-code"
,
action
=
'store_true'
,
help
=
"Trust remote code."
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
def
main
(
model
,
dp_size
,
local_dp_rank
,
global_dp_rank
,
dp_master_ip
,
def
main
(
model
,
dp_size
,
local_dp_rank
,
global_dp_rank
,
dp_master_ip
,
dp_master_port
,
GPUs_per_dp_rank
):
dp_master_port
,
GPUs_per_dp_rank
,
enforce_eager
,
trust_remote_code
):
os
.
environ
[
"VLLM_DP_RANK"
]
=
str
(
global_dp_rank
)
os
.
environ
[
"VLLM_DP_RANK"
]
=
str
(
global_dp_rank
)
os
.
environ
[
"VLLM_DP_RANK_LOCAL"
]
=
str
(
local_dp_rank
)
os
.
environ
[
"VLLM_DP_RANK_LOCAL"
]
=
str
(
local_dp_rank
)
os
.
environ
[
"VLLM_DP_SIZE"
]
=
str
(
dp_size
)
os
.
environ
[
"VLLM_DP_SIZE"
]
=
str
(
dp_size
)
...
@@ -109,10 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
...
@@ -109,10 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
max_tokens
=
[
16
,
20
][
global_dp_rank
%
2
])
max_tokens
=
[
16
,
20
][
global_dp_rank
%
2
])
# Create an LLM.
# Create an LLM.
llm
=
LLM
(
model
=
model
,
llm
=
LLM
(
model
=
model
,
tensor_parallel_size
=
GPUs_per_dp_rank
,
tensor_parallel_size
=
GPUs_per_dp_rank
,
enforce_eager
=
True
,
enforce_eager
=
enforce_eager
,
enable_expert_parallel
=
True
)
enable_expert_parallel
=
True
,
trust_remote_code
=
trust_remote_code
,
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
# Print the outputs.
# Print the outputs.
for
i
,
output
in
enumerate
(
outputs
):
for
i
,
output
in
enumerate
(
outputs
):
...
@@ -155,7 +164,8 @@ if __name__ == "__main__":
...
@@ -155,7 +164,8 @@ if __name__ == "__main__":
proc
=
Process
(
target
=
main
,
proc
=
Process
(
target
=
main
,
args
=
(
args
.
model
,
dp_size
,
local_dp_rank
,
args
=
(
args
.
model
,
dp_size
,
local_dp_rank
,
global_dp_rank
,
dp_master_ip
,
dp_master_port
,
global_dp_rank
,
dp_master_ip
,
dp_master_port
,
tp_size
))
tp_size
,
args
.
enforce_eager
,
args
.
trust_remote_code
))
proc
.
start
()
proc
.
start
()
procs
.
append
(
proc
)
procs
.
append
(
proc
)
exit_code
=
0
exit_code
=
0
...
...
tests/kernels/moe/test_batched_moe.py
0 → 100644
View file @
f9c069c8
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
import
pytest
import
torch
import
triton.language
as
tl
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
invoke_moe_batched_triton_kernel
)
@
dataclass
class
BatchedMMConfig
:
dtype
:
torch
.
dtype
num_experts
:
int
max_tokens_per_expert
:
int
K
:
int
N
:
int
@
dataclass
class
BatchedMMTensors
:
A
:
torch
.
Tensor
# [E, max_tokens, K]
B
:
torch
.
Tensor
# [E, K, N] - column major
C
:
torch
.
Tensor
# [E, max_tokens, N]
num_expert_tokens
:
torch
.
Tensor
# [E]
@
staticmethod
def
make_tensors
(
config
:
BatchedMMConfig
):
A
=
torch
.
randn
(
(
config
.
num_experts
,
config
.
max_tokens_per_expert
,
config
.
K
),
device
=
"cuda"
,
dtype
=
config
.
dtype
)
/
10
B
=
torch
.
randn
((
config
.
num_experts
,
config
.
N
,
config
.
K
),
device
=
"cuda"
,
dtype
=
config
.
dtype
)
C
=
torch
.
zeros
(
(
config
.
num_experts
,
config
.
max_tokens_per_expert
,
config
.
N
),
device
=
"cuda"
,
dtype
=
config
.
dtype
)
num_expert_tokens
=
torch
.
randint
(
low
=
0
,
high
=
config
.
max_tokens_per_expert
,
size
=
(
config
.
num_experts
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
return
BatchedMMTensors
(
A
,
B
,
C
,
num_expert_tokens
)
def
ref_impl
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
num_expert_tokens
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_expert_tokens_cpu
=
num_expert_tokens
.
clone
()
num_expert_tokens_cpu
=
num_expert_tokens_cpu
.
to
(
device
=
"cpu"
)
num_experts
=
num_expert_tokens
.
size
(
0
)
for
e
in
range
(
num_experts
):
num_tokens
=
num_expert_tokens_cpu
[
e
]
C
[
e
,
:
num_tokens
,
:]
=
A
[
e
,
:
num_tokens
,
:]
@
B
[
e
].
transpose
(
0
,
1
)
return
C
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
16
,
32
])
@
pytest
.
mark
.
parametrize
(
"max_tokens_per_expert"
,
[
32
,
64
,
128
,
192
,
224
,
256
,
512
])
@
pytest
.
mark
.
parametrize
(
"K"
,
[
128
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"N"
,
[
128
,
256
,
512
,
1024
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_batched_mm
(
num_experts
:
int
,
max_tokens_per_expert
:
int
,
K
:
int
,
N
:
int
,
dtype
:
torch
.
dtype
):
config
=
BatchedMMConfig
(
dtype
,
num_experts
,
max_tokens_per_expert
,
K
,
N
)
tensors
=
BatchedMMTensors
.
make_tensors
(
config
)
test_output
=
tensors
.
C
ref_output
=
test_output
.
clone
()
compute_tl_dtype
=
{
torch
.
float16
:
tl
.
float16
,
torch
.
bfloat16
:
tl
.
bfloat16
,
torch
.
float32
:
tl
.
float32
}[
test_output
.
dtype
]
invoke_moe_batched_triton_kernel
(
tensors
.
A
,
tensors
.
B
,
test_output
,
tensors
.
num_expert_tokens
,
compute_tl_dtype
,
# Quantization data
None
,
None
,
None
,
# Quantization schemes
False
,
False
,
False
,
config
=
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
16
})
ref_output
=
ref_impl
(
tensors
.
A
,
tensors
.
B
,
ref_output
,
tensors
.
num_expert_tokens
)
rtol
,
atol
=
{
torch
.
float16
:
(
6e-2
,
6e-2
),
torch
.
bfloat16
:
(
6e-2
,
6e-2
),
torch
.
float32
:
(
1e-2
,
1e-2
),
}[
test_output
.
dtype
]
torch
.
testing
.
assert_close
(
test_output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
tests/kernels/moe/test_cutlass_moe.py
View file @
f9c069c8
...
@@ -30,6 +30,11 @@ MNK_FACTORS = [
...
@@ -30,6 +30,11 @@ MNK_FACTORS = [
(
224
,
3072
,
1536
),
(
224
,
3072
,
1536
),
]
]
vllm_config
=
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
MOETensors
:
class
MOETensors
:
...
@@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
...
@@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'w1_q'
:
moe_tensors
.
w1_q
.
transpose
(
1
,
2
),
# type: ignore[union-attr]
'w1_q'
:
moe_tensors
.
w1_q
.
transpose
(
1
,
2
),
# type: ignore[union-attr]
'w2_q'
:
moe_tensors
.
w2_q
.
transpose
(
1
,
2
),
# type: ignore[union-attr]
'w2_q'
:
moe_tensors
.
w2_q
.
transpose
(
1
,
2
),
# type: ignore[union-attr]
'topk_weights'
:
topk_weights
,
'topk_weights'
:
topk_weights
,
'topk_ids
_
'
:
topk_ids
,
'topk_ids'
:
topk_ids
,
'ab_strides1'
:
moe_tensors
.
ab_strides1
,
'ab_strides1'
:
moe_tensors
.
ab_strides1
,
'c_strides1'
:
moe_tensors
.
c_strides1
,
'c_strides1'
:
moe_tensors
.
c_strides1
,
'ab_strides2'
:
moe_tensors
.
ab_strides2
,
'ab_strides2'
:
moe_tensors
.
ab_strides2
,
...
@@ -231,15 +236,12 @@ def test_cutlass_moe_8_bit_no_graph(
...
@@ -231,15 +236,12 @@ def test_cutlass_moe_8_bit_no_graph(
per_out_ch
:
bool
,
per_out_ch
:
bool
,
):
):
current_platform
.
seed_everything
(
7
)
current_platform
.
seed_everything
(
7
)
with
set_current_vllm_config
(
with
set_current_vllm_config
(
vllm_config
):
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))):
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
per_out_ch
)
per_out_ch
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
half
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
half
)
topk_weights
,
topk_ids
=
fused_topk
(
mt
.
a
,
topk_weights
,
topk_ids
,
_
=
fused_topk
(
mt
.
a
,
score
,
score
,
topk
,
topk
,
renormalize
=
False
)
renormalize
=
False
)
...
@@ -276,17 +278,14 @@ def test_cutlass_moe_8_bit_cuda_graph(
...
@@ -276,17 +278,14 @@ def test_cutlass_moe_8_bit_cuda_graph(
per_out_ch
:
bool
,
per_out_ch
:
bool
,
):
):
current_platform
.
seed_everything
(
7
)
current_platform
.
seed_everything
(
7
)
with
set_current_vllm_config
(
with
set_current_vllm_config
(
vllm_config
):
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))):
dtype
=
torch
.
half
dtype
=
torch
.
half
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
per_out_ch
)
per_out_ch
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
topk_weights
,
topk_ids
=
fused_topk
(
mt
.
a
,
topk_weights
,
topk_ids
,
_
=
fused_topk
(
mt
.
a
,
score
,
score
,
topk
,
topk
,
renormalize
=
False
)
renormalize
=
False
)
...
@@ -334,15 +333,12 @@ def test_cutlass_moe_8_bit_EP(
...
@@ -334,15 +333,12 @@ def test_cutlass_moe_8_bit_EP(
ep_size
:
int
,
ep_size
:
int
,
):
):
current_platform
.
seed_everything
(
7
)
current_platform
.
seed_everything
(
7
)
with
set_current_vllm_config
(
with
set_current_vllm_config
(
vllm_config
):
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))):
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
per_out_channel
)
per_out_channel
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
half
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
half
)
topk_weights
,
topk_ids
=
fused_topk
(
mt
.
a
,
topk_weights
,
topk_ids
,
_
=
fused_topk
(
mt
.
a
,
score
,
score
,
topk
,
topk
,
renormalize
=
False
)
renormalize
=
False
)
...
...
tests/kernels/moe/test_moe.py
View file @
f9c069c8
...
@@ -12,6 +12,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
...
@@ -12,6 +12,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import
vllm.model_executor.layers.fused_moe
# noqa
import
vllm.model_executor.layers.fused_moe
# noqa
from
tests.kernels.utils
import
opcheck
,
stack_and_dev
,
torch_moe
from
tests.kernels.utils
import
opcheck
,
stack_and_dev
,
torch_moe
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.moe_torch_iterative
import
(
from
vllm.model_executor.layers.fused_moe.moe_torch_iterative
import
(
...
@@ -32,6 +33,10 @@ NUM_EXPERTS = [8, 64]
...
@@ -32,6 +33,10 @@ NUM_EXPERTS = [8, 64]
EP_SIZE
=
[
1
,
4
]
EP_SIZE
=
[
1
,
4
]
TOP_KS
=
[
2
,
6
]
TOP_KS
=
[
2
,
6
]
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
64
,
222
,
1024
*
128
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
64
,
222
,
1024
*
128
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
1024
,
2048
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
1024
,
2048
])
...
@@ -70,6 +75,7 @@ def test_fused_moe(
...
@@ -70,6 +75,7 @@ def test_fused_moe(
else
:
else
:
e_map
=
None
e_map
=
None
with
set_current_vllm_config
(
vllm_config
):
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
,
e_map
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
,
e_map
)
iterative_output
=
iterative_moe
(
a
,
iterative_output
=
iterative_moe
(
a
,
w1
,
w1
,
...
@@ -95,6 +101,7 @@ def test_fused_moe(
...
@@ -95,6 +101,7 @@ def test_fused_moe(
global_num_experts
=
e
,
global_num_experts
=
e
,
expert_map
=
e_map
,
expert_map
=
e_map
,
renormalize
=
False
)
renormalize
=
False
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
iterative_output
,
torch
.
testing
.
assert_close
(
iterative_output
,
torch_output
,
torch_output
,
...
@@ -115,7 +122,6 @@ def test_fused_moe(
...
@@ -115,7 +122,6 @@ def test_fused_moe(
def
test_fused_moe_wn16
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
def
test_fused_moe_wn16
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
group_size
:
int
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
group_size
:
int
,
has_zp
:
bool
,
weight_bits
:
int
):
has_zp
:
bool
,
weight_bits
:
int
):
print
(
m
,
n
,
k
,
e
,
topk
,
dtype
,
group_size
,
has_zp
,
weight_bits
)
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
...
@@ -194,6 +200,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
...
@@ -194,6 +200,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
else
:
else
:
e_map
=
None
e_map
=
None
with
set_current_vllm_config
(
vllm_config
):
triton_output
=
fused_moe
(
a
,
triton_output
=
fused_moe
(
a
,
w1_qweight
,
w1_qweight
,
w2_qweight
,
w2_qweight
,
...
@@ -210,6 +217,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
...
@@ -210,6 +217,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
w2_zp
=
w2_qzeros
if
has_zp
else
None
,
w2_zp
=
w2_qzeros
if
has_zp
else
None
,
block_shape
=
[
0
,
group_size
])
block_shape
=
[
0
,
group_size
])
torch_output
=
torch_moe
(
a
,
w1_ref
,
w2_ref
,
score
,
topk
,
e_map
)
torch_output
=
torch_moe
(
a
,
w1_ref
,
w2_ref
,
score
,
topk
,
e_map
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
...
@@ -515,6 +523,7 @@ def test_fused_marlin_moe(
...
@@ -515,6 +523,7 @@ def test_fused_marlin_moe(
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
with
set_current_vllm_config
(
vllm_config
):
torch_output
=
torch_moe
(
a
,
w_ref1
,
w_ref2
,
score
,
topk
,
e_map
)
torch_output
=
torch_moe
(
a
,
w_ref1
,
w_ref2
,
score
,
topk
,
e_map
)
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
...
...
tests/kernels/moe/test_pplx_moe.py
0 → 100644
View file @
f9c069c8
This diff is collapsed.
Click to expand it.
tests/kernels/moe/test_triton_moe_ptpc_fp8.py
View file @
f9c069c8
...
@@ -7,6 +7,7 @@ import pytest
...
@@ -7,6 +7,7 @@ import pytest
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -15,6 +16,10 @@ if current_platform.get_device_capability() < (9, 0):
...
@@ -15,6 +16,10 @@ if current_platform.get_device_capability() < (9, 0):
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
allow_module_level
=
True
)
allow_module_level
=
True
)
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
def
native_w8a8_per_token_matmul
(
A
,
B
,
As
,
Bs
,
output_dtype
=
torch
.
float16
):
def
native_w8a8_per_token_matmul
(
A
,
B
,
As
,
Bs
,
output_dtype
=
torch
.
float16
):
"""Matrix multiplication function that supports per-token input
"""Matrix multiplication function that supports per-token input
...
@@ -137,6 +142,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
...
@@ -137,6 +142,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
w2_s
=
torch
.
rand
(
E
,
K
,
device
=
w2_fp32
.
device
)
*
factor_for_scale
w2_s
=
torch
.
rand
(
E
,
K
,
device
=
w2_fp32
.
device
)
*
factor_for_scale
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
with
set_current_vllm_config
(
vllm_config
):
ref_out
=
torch_w8a8_per_column_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
topk
)
ref_out
=
torch_w8a8_per_column_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
topk
)
out
=
fused_moe
(
out
=
fused_moe
(
a
,
a
,
...
...
tests/kernels/quantization/test_block_fp8.py
View file @
f9c069c8
...
@@ -11,7 +11,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
...
@@ -11,7 +11,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
deep_gemm_moe_fp8
)
_valid_deep_gemm_shape
,
deep_gemm_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
moe_align_block_size
)
...
@@ -30,6 +30,10 @@ if current_platform.get_device_capability() < (9, 0):
...
@@ -30,6 +30,10 @@ if current_platform.get_device_capability() < (9, 0):
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
allow_module_level
=
True
)
allow_module_level
=
True
)
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
# Test configurations
# Test configurations
DTYPES
=
[
torch
.
bfloat16
]
# [torch.half, torch.bfloat16, torch.float32]
DTYPES
=
[
torch
.
bfloat16
]
# [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS
=
[
7
,
83
,
2048
]
NUM_TOKENS
=
[
7
,
83
,
2048
]
...
@@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
...
@@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
# Set the context to avoid lots of warning spam.
# Set the context to avoid lots of warning spam.
vllm_config
=
VllmConfig
()
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
out
=
fused_moe
(
out
=
fused_moe
(
a
,
a
,
...
@@ -258,6 +261,7 @@ def per_block_cast_to_fp8(
...
@@ -258,6 +261,7 @@ def per_block_cast_to_fp8(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"M,N,K,block_size,out_dtype,seed"
,
"M,N,K,block_size,out_dtype,seed"
,
itertools
.
product
(
M
,
N
,
K
,
BLOCK_SIZE
,
OUT_DTYPES
,
SEEDS
))
itertools
.
product
(
M
,
N
,
K
,
BLOCK_SIZE
,
OUT_DTYPES
,
SEEDS
))
@
pytest
.
mark
.
skipif
(
not
dg_available
,
reason
=
"DeepGemm kernels not available."
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_w8a8_block_fp8_deep_gemm_matmul
(
M
,
N
,
K
,
block_size
,
out_dtype
,
seed
):
def
test_w8a8_block_fp8_deep_gemm_matmul
(
M
,
N
,
K
,
block_size
,
out_dtype
,
seed
):
# only aligned sizes
# only aligned sizes
...
@@ -381,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
...
@@ -381,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
block_size
=
[
block_m
,
block_m
]
block_size
=
[
block_m
,
block_m
]
dtype
=
torch
.
bfloat16
dtype
=
torch
.
bfloat16
# only aligned sizes
if
topk
>
E
:
if
(
N
%
block_m
!=
0
or
K
%
block_m
!=
0
or
topk
>
E
):
pytest
.
skip
(
f
"Skipping test: topk=
{
topk
}
> E=
{
E
}
"
)
pytest
.
skip
(
f
"Skipping test; bad size m=
{
M
}
, n=
{
N
}
, k=
{
K
}
, topk=
{
topk
}
, E=
{
E
}
"
)
if
N
<=
512
:
pytest
.
skip
(
"Skipping N <= 512 until performance issues solved."
)
vllm_config
=
VllmConfig
()
if
not
_valid_deep_gemm_shape
(
M
,
N
,
K
):
pytest
.
skip
(
f
"Skipping test: invalid size m=
{
M
}
, n=
{
N
}
, k=
{
K
}
"
)
torch
.
manual_seed
(
seed
)
torch
.
manual_seed
(
seed
)
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
...
...
tests/kernels/quantization/test_block_int8.py
View file @
f9c069c8
...
@@ -18,6 +18,10 @@ if current_platform.get_device_capability() < (7, 0):
...
@@ -18,6 +18,10 @@ if current_platform.get_device_capability() < (7, 0):
pytest
.
skip
(
"INT8 Triton requires CUDA 7.0 or higher"
,
pytest
.
skip
(
"INT8 Triton requires CUDA 7.0 or higher"
,
allow_module_level
=
True
)
allow_module_level
=
True
)
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
# For test
# For test
def
native_per_token_group_quant_int8
(
x
,
def
native_per_token_group_quant_int8
(
x
,
...
@@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
...
@@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
# Set the context to avoid lots of warning spam.
# Set the context to avoid lots of warning spam.
vllm_config
=
VllmConfig
()
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
out
=
fused_moe
(
out
=
fused_moe
(
a
,
a
,
...
...
vllm/distributed/parallel_state.py
View file @
f9c069c8
...
@@ -23,6 +23,7 @@ If you only need to use the distributed environment without model/pipeline
...
@@ -23,6 +23,7 @@ If you only need to use the distributed environment without model/pipeline
"""
"""
import
contextlib
import
contextlib
import
gc
import
gc
import
importlib.util
import
pickle
import
pickle
import
weakref
import
weakref
from
collections
import
namedtuple
from
collections
import
namedtuple
...
@@ -42,7 +43,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
...
@@ -42,7 +43,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
direct_register_custom_op
,
resolve_obj_by_qualname
,
from
vllm.utils
import
(
direct_register_custom_op
,
resolve_obj_by_qualname
,
supports_custom_op
)
run_once
,
supports_custom_op
)
@
dataclass
@
dataclass
...
@@ -936,9 +937,49 @@ def init_distributed_environment(
...
@@ -936,9 +937,49 @@ def init_distributed_environment(
"world group already initialized with a different world size"
)
"world group already initialized with a different world size"
)
PPLX_DID_INIT
:
bool
=
False
@
run_once
def
pplx_init
(
rank
,
world_size
):
has_pplx
=
importlib
.
util
.
find_spec
(
"pplx_kernels"
)
is
not
None
if
has_pplx
and
world_size
>
1
:
from
pplx_kernels.nvshmem
import
(
nvshmem_alloc_empty_unique_id
,
nvshmem_get_unique_id
,
nvshmem_init
)
try
:
global
PPLX_DID_INIT
logger
.
debug
(
"Initialize NVSHMEM for PPLX kernels: rank=%d, "
"world size=%d"
,
rank
,
world_size
)
uid
=
nvshmem_get_unique_id
(
)
if
rank
==
0
else
nvshmem_alloc_empty_unique_id
()
uid_gpu
=
uid
.
cuda
()
get_world_group
().
broadcast
(
uid_gpu
,
src
=
0
)
uid
=
uid_gpu
.
to
(
device
=
'cpu'
)
logger
.
debug
(
"PPLX NVSHMEM UID = %s"
,
uid
)
nvshmem_init
(
uid
,
rank
,
world_size
)
PPLX_DID_INIT
=
True
except
Exception
as
ex
:
logger
.
error
(
"Failed to initialize NVSHMEM for PPLX: %s"
,
ex
)
@
run_once
def
pplx_finalize
():
global
PPLX_DID_INIT
if
PPLX_DID_INIT
:
from
pplx_kernels.nvshmem
import
nvshmem_finalize
logger
.
debug
(
"PPLX NVSHMEM finalize"
)
from
vllm.model_executor.layers.fused_moe.layer
import
(
_all_to_all_cache
)
_all_to_all_cache
.
destroy
()
nvshmem_finalize
()
def
initialize_model_parallel
(
def
initialize_model_parallel
(
tensor_model_parallel_size
:
int
=
1
,
tensor_model_parallel_size
:
int
=
1
,
pipeline_model_parallel_size
:
int
=
1
,
pipeline_model_parallel_size
:
int
=
1
,
enable_expert_parallel
:
bool
=
False
,
backend
:
Optional
[
str
]
=
None
,
backend
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
"""
"""
...
@@ -1041,10 +1082,14 @@ def initialize_model_parallel(
...
@@ -1041,10 +1082,14 @@ def initialize_model_parallel(
_DP
.
rank_in_group
,
_PP
.
rank_in_group
,
_TP
.
rank_in_group
,
_DP
.
rank_in_group
,
_PP
.
rank_in_group
,
_TP
.
rank_in_group
,
_EP
.
rank_in_group
)
_EP
.
rank_in_group
)
if
enable_expert_parallel
:
pplx_init
(
rank
,
world_size
)
def
ensure_model_parallel_initialized
(
def
ensure_model_parallel_initialized
(
tensor_model_parallel_size
:
int
,
tensor_model_parallel_size
:
int
,
pipeline_model_parallel_size
:
int
,
pipeline_model_parallel_size
:
int
,
enable_expert_parallel
:
bool
=
False
,
backend
:
Optional
[
str
]
=
None
,
backend
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
"""Helper to initialize model parallel groups if they are not initialized,
"""Helper to initialize model parallel groups if they are not initialized,
...
@@ -1055,7 +1100,8 @@ def ensure_model_parallel_initialized(
...
@@ -1055,7 +1100,8 @@ def ensure_model_parallel_initialized(
get_world_group
().
device_group
)
get_world_group
().
device_group
)
if
not
model_parallel_is_initialized
():
if
not
model_parallel_is_initialized
():
initialize_model_parallel
(
tensor_model_parallel_size
,
initialize_model_parallel
(
tensor_model_parallel_size
,
pipeline_model_parallel_size
,
backend
)
pipeline_model_parallel_size
,
enable_expert_parallel
,
backend
)
return
return
assert
(
assert
(
...
@@ -1133,6 +1179,9 @@ def get_tensor_model_parallel_rank():
...
@@ -1133,6 +1179,9 @@ def get_tensor_model_parallel_rank():
def
destroy_model_parallel
():
def
destroy_model_parallel
():
"""Set the groups to none and destroy them."""
"""Set the groups to none and destroy them."""
global
_TP
global
_TP
pplx_finalize
()
if
_TP
:
if
_TP
:
_TP
.
destroy
()
_TP
.
destroy
()
_TP
=
None
_TP
=
None
...
...
vllm/distributed/utils.py
View file @
f9c069c8
...
@@ -23,7 +23,7 @@ from torch.distributed.rendezvous import rendezvous
...
@@ -23,7 +23,7 @@ from torch.distributed.rendezvous import rendezvous
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_tcp_uri
from
vllm.utils
import
get_tcp_uri
,
is_torch_equal_or_newer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -362,12 +362,11 @@ def stateless_destroy_torch_distributed_process_group(
...
@@ -362,12 +362,11 @@ def stateless_destroy_torch_distributed_process_group(
Destroy ProcessGroup returned by
Destroy ProcessGroup returned by
stateless_init_torch_distributed_process_group().
stateless_init_torch_distributed_process_group().
"""
"""
if
is_torch_equal_or_newer
(
"2.7"
):
pg
.
shutdown
()
else
:
# Lazy import for non-CUDA backends.
# Lazy import for non-CUDA backends.
try
:
# pytorch <= 2.6
from
torch.distributed.distributed_c10d
import
_shutdown_backend
from
torch.distributed.distributed_c10d
import
_shutdown_backend
_shutdown_backend
(
pg
)
_shutdown_backend
(
pg
)
except
ImportError
:
# pytorch >= 2.7
pg
.
shutdown
()
_unregister_process_group
(
pg
.
group_name
)
_unregister_process_group
(
pg
.
group_name
)
vllm/forward_context.py
View file @
f9c069c8
...
@@ -27,6 +27,7 @@ batchsize_forward_time: defaultdict = defaultdict(list)
...
@@ -27,6 +27,7 @@ batchsize_forward_time: defaultdict = defaultdict(list)
@
dataclass
@
dataclass
class
DPMetadata
:
class
DPMetadata
:
max_tokens_across_dp_cpu
:
torch
.
Tensor
cu_tokens_across_dp_cpu
:
torch
.
Tensor
cu_tokens_across_dp_cpu
:
torch
.
Tensor
...
@@ -90,8 +91,10 @@ def set_forward_context(attn_metadata: Any,
...
@@ -90,8 +91,10 @@ def set_forward_context(attn_metadata: Any,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
from
vllm.distributed.parallel_state
import
get_dp_group
from
vllm.distributed.parallel_state
import
get_dp_group
dist
.
all_reduce
(
num_tokens_tensor
,
group
=
get_dp_group
().
cpu_group
)
dist
.
all_reduce
(
num_tokens_tensor
,
group
=
get_dp_group
().
cpu_group
)
max_tokens_across_dp_cpu
=
torch
.
max
(
num_tokens_tensor
)
cu_tokens_across_dp_cpu
=
torch
.
cumsum
(
num_tokens_tensor
,
dim
=
0
)
cu_tokens_across_dp_cpu
=
torch
.
cumsum
(
num_tokens_tensor
,
dim
=
0
)
dp_metadata
=
DPMetadata
(
cu_tokens_across_dp_cpu
)
dp_metadata
=
DPMetadata
(
max_tokens_across_dp_cpu
,
cu_tokens_across_dp_cpu
)
global
_forward_context
global
_forward_context
prev_context
=
_forward_context
prev_context
=
_forward_context
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
f9c069c8
...
@@ -38,8 +38,8 @@ if HAS_TRITON:
...
@@ -38,8 +38,8 @@ if HAS_TRITON:
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
cutlass_moe_fp4
,
cutlass_moe_fp8
)
cutlass_moe_fp4
,
cutlass_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_moe
,
fused_topk
,
get_config_file_name
,
TritonExperts
,
fused_experts
,
fused_moe
,
fused_topk
,
grouped_topk
)
get_config_file_name
,
grouped_topk
)
__all__
+=
[
__all__
+=
[
"fused_moe"
,
"fused_moe"
,
...
@@ -49,4 +49,5 @@ if HAS_TRITON:
...
@@ -49,4 +49,5 @@ if HAS_TRITON:
"grouped_topk"
,
"grouped_topk"
,
"cutlass_moe_fp8"
,
"cutlass_moe_fp8"
,
"cutlass_moe_fp4"
,
"cutlass_moe_fp4"
,
"TritonExperts"
,
]
]
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
f9c069c8
...
@@ -5,10 +5,176 @@ from typing import Optional
...
@@ -5,10 +5,176 @@ from typing import Optional
import
torch
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
)
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_perm
,
_resize_cache
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
class
CutlassExpertsFp8
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
ab_strides1
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
):
super
().
__init__
()
self
.
ab_strides1
=
ab_strides1
self
.
c_strides1
=
c_strides1
self
.
ab_strides2
=
ab_strides2
self
.
c_strides2
=
c_strides2
self
.
out_dtype
=
out_dtype
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
num_experts
:
int
,
)
->
tuple
[
int
,
int
,
torch
.
dtype
]:
# Note that K, N are transposed
N
,
K
=
K
,
N
workspace1
=
M
*
topk
*
max
(
2
*
N
,
K
)
workspace2
=
M
*
topk
*
N
return
(
workspace1
,
workspace2
,
self
.
out_dtype
)
def
apply
(
self
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
w1_zp
:
Optional
[
torch
.
Tensor
],
w2_zp
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
a1q
=
hidden_states
assert
w1_scale
is
not
None
assert
w2_scale
is
not
None
assert
w1
.
dtype
==
torch
.
float8_e4m3fn
assert
w2
.
dtype
==
torch
.
float8_e4m3fn
assert
a1q
.
shape
[
1
]
==
w1
.
shape
[
1
],
"Hidden size mismatch w1"
assert
w1
.
shape
[
2
]
==
w2
.
shape
[
1
]
*
2
,
"Hidden size mismatch w2"
assert
w1
.
shape
[
0
]
==
w2
.
shape
[
0
],
"Expert number mismatch"
assert
a1q_scale
is
None
or
a1q_scale
.
dim
(
)
==
0
or
a1q_scale
.
shape
[
0
]
==
1
or
a1q_scale
.
shape
[
0
]
==
a1q
.
shape
[
0
],
"Input scale shape mismatch"
assert
w1_scale
.
dim
()
==
1
or
w1_scale
.
shape
[
1
]
==
1
or
w1_scale
.
shape
[
1
]
==
w1
.
shape
[
2
],
"W1 scale shape mismatch"
assert
w2_scale
.
dim
()
==
1
or
w2_scale
.
shape
[
1
]
==
1
or
w2_scale
.
shape
[
1
]
==
w2
.
shape
[
2
],
"W2 scale shape mismatch"
assert
w1
.
shape
[
0
]
==
w2
.
shape
[
0
],
"Weights expert number mismatch"
assert
w1
.
shape
[
0
]
==
w1_scale
.
shape
[
0
],
"w1 scales expert number mismatch"
assert
w1
.
shape
[
0
]
==
w2_scale
.
shape
[
0
],
"w2 scales expert number mismatch"
assert
a2_scale
is
None
or
a1q_scale
is
None
or
a2_scale
.
shape
==
a1q_scale
.
shape
,
"Intermediate scale shape mismatch"
# noqa: E501
assert
self
.
ab_strides1
.
shape
[
0
]
==
w1
.
shape
[
0
],
"AB Strides 1 expert number mismatch"
assert
self
.
c_strides1
.
shape
[
0
]
==
w1
.
shape
[
0
],
"C Strides 1 expert number mismatch"
assert
self
.
ab_strides2
.
shape
[
0
]
==
w2
.
shape
[
0
],
"AB Strides 2 expert number mismatch"
assert
self
.
c_strides2
.
shape
[
0
]
==
w2
.
shape
[
0
],
"C Strides 2 expert number mismatch"
assert
self
.
out_dtype
in
[
torch
.
half
,
torch
.
bfloat16
],
"Invalid output dtype"
M
=
a1q
.
shape
[
0
]
_
,
N
,
K
=
w2
.
shape
# because w1 + w2 are transposed
device
=
a1q
.
device
assert
w1
.
shape
[
1
]
==
K
assert
global_num_experts
!=
-
1
assert
a1q_scale
is
not
None
if
expert_map
is
not
None
:
"Translate info from expert_map to topk_ids"
local_topk_ids
=
torch
.
where
(
expert_map
[
topk_ids
]
!=
-
1
,
expert_map
[
topk_ids
],
-
1
)
else
:
local_topk_ids
=
topk_ids
topk
=
local_topk_ids
.
shape
[
1
]
per_act_token
=
a1q_scale
.
numel
()
!=
1
if
a1q_scale
is
not
None
else
(
a2_scale
.
numel
()
!=
1
if
a2_scale
is
not
None
else
False
)
expert_offsets
=
torch
.
empty
((
global_num_experts
+
1
),
dtype
=
torch
.
int32
,
device
=
device
)
problem_sizes1
=
torch
.
empty
((
global_num_experts
,
3
),
dtype
=
torch
.
int32
,
device
=
device
)
problem_sizes2
=
torch
.
empty
((
global_num_experts
,
3
),
dtype
=
torch
.
int32
,
device
=
device
)
# With expert_map each Rank processes only a subset of experts. As
# a result not all of a_map and c2 tensors are filled. We fill it
# zeros for correctness.
if
expert_map
is
not
None
:
a_map
=
torch
.
zeros
((
local_topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
else
:
a_map
=
torch
.
empty
((
local_topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
c_map
=
torch
.
empty
((
local_topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
ops
.
get_cutlass_moe_mm_data
(
local_topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
a_map
,
c_map
,
global_num_experts
,
N
,
K
)
a1q
=
_fp8_perm
(
a1q
,
a_map
)
a1q_scale
=
a1q_scale
[
a_map
]
if
per_act_token
else
a1q_scale
c1
=
_resize_cache
(
workspace13
,
(
M
*
topk
,
N
*
2
))
c2
=
_resize_cache
(
workspace2
,
(
M
*
topk
,
N
))
c3
=
_resize_cache
(
workspace13
,
(
M
*
topk
,
K
))
ops
.
cutlass_moe_mm
(
c1
,
a1q
,
w1
,
a1q_scale
,
w1_scale
,
expert_offsets
[:
-
1
],
problem_sizes1
,
self
.
ab_strides1
,
self
.
ab_strides1
,
self
.
c_strides1
)
self
.
activation
(
activation
,
c2
,
c1
)
a2q
,
a2q_scale
=
ops
.
scaled_fp8_quant
(
c2
,
a2_scale
,
use_per_token_if_dynamic
=
per_act_token
)
if
expert_map
is
not
None
:
c3
.
fill_
(
0
)
ops
.
cutlass_moe_mm
(
c3
,
a2q
,
w2
,
a2q_scale
,
w2_scale
,
expert_offsets
[:
-
1
],
problem_sizes2
,
self
.
ab_strides2
,
self
.
ab_strides2
,
self
.
c_strides2
)
c3
=
c3
[
c_map
]
return
c3
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
def
cutlass_moe_fp8
(
def
cutlass_moe_fp8
(
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
...
@@ -17,7 +183,7 @@ def cutlass_moe_fp8(
...
@@ -17,7 +183,7 @@ def cutlass_moe_fp8(
w1_scale
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
_
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
...
@@ -59,7 +225,7 @@ def cutlass_moe_fp8(
...
@@ -59,7 +225,7 @@ def cutlass_moe_fp8(
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
quantize the intermediate result between the gemms.
Shape: scalar or [M]
Shape: scalar or [M]
- out_dtype (torch.
Tensor
): The output tensor type.
- out_dtype (torch.
dtype
): The output tensor type.
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
mapping from global expert-id to local expert-id. When expert_map[i]
...
@@ -71,115 +237,36 @@ def cutlass_moe_fp8(
...
@@ -71,115 +237,36 @@ def cutlass_moe_fp8(
Returns:
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
"""
assert
topk_weights
.
shape
==
topk_ids_
.
shape
,
"topk shape mismatch"
assert
w1_q
.
dtype
==
torch
.
float8_e4m3fn
assert
w2_q
.
dtype
==
torch
.
float8_e4m3fn
assert
a
.
shape
[
1
]
==
w1_q
.
shape
[
1
],
"Hidden size mismatch w1"
assert
w1_q
.
shape
[
2
]
==
w2_q
.
shape
[
1
]
*
2
,
"Hidden size mismatch w2"
assert
w1_q
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"Expert number mismatch"
assert
a1_scale
is
None
or
a1_scale
.
dim
(
)
==
0
or
a1_scale
.
shape
[
0
]
==
1
or
a1_scale
.
shape
[
0
]
==
a
.
shape
[
0
],
"Input scale shape mismatch"
assert
w1_scale
.
dim
()
==
1
or
w1_scale
.
shape
[
1
]
==
1
or
w1_scale
.
shape
[
1
]
==
w1_q
.
shape
[
2
],
"W1 scale shape mismatch"
assert
w2_scale
.
dim
()
==
1
or
w2_scale
.
shape
[
1
]
==
1
or
w2_scale
.
shape
[
1
]
==
w2_q
.
shape
[
2
],
"W2 scale shape mismatch"
assert
w1_q
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"Weights expert number mismatch"
assert
w1_q
.
shape
[
0
]
==
w1_scale
.
shape
[
0
],
"w1 scales expert number mismatch"
assert
w1_q
.
shape
[
0
]
==
w2_scale
.
shape
[
0
],
"w2 scales expert number mismatch"
assert
a2_scale
is
None
or
a1_scale
is
None
or
a2_scale
.
shape
==
a1_scale
.
shape
,
"Intermediate scale shape mismatch"
# noqa: E501
assert
ab_strides1
.
shape
[
0
]
==
w1_q
.
shape
[
0
],
"AB Strides 1 expert number mismatch"
assert
c_strides1
.
shape
[
0
]
==
w1_q
.
shape
[
0
],
"C Strides 1 expert number mismatch"
assert
ab_strides2
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"AB Strides 2 expert number mismatch"
assert
c_strides2
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"C Strides 2 expert number mismatch"
assert
out_dtype
in
[
torch
.
half
,
torch
.
bfloat16
],
"Invalid output dtype"
num_experts
=
w1_q
.
size
(
0
)
m
=
a
.
size
(
0
)
k
=
w1_q
.
size
(
1
)
n
=
w2_q
.
size
(
1
)
local_topk_ids
=
topk_ids_
if
expert_map
is
not
None
:
"Translate info from expert_map to topk_ids"
local_topk_ids
=
torch
.
where
(
expert_map
[
topk_ids_
]
!=
-
1
,
expert_map
[
topk_ids_
],
-
1
)
topk
=
local_topk_ids
.
size
(
1
)
per_act_token
=
a1_scale
.
numel
()
!=
1
if
a1_scale
is
not
None
else
(
per_act_token
=
a1_scale
.
numel
()
!=
1
if
a1_scale
is
not
None
else
(
a2_scale
.
numel
()
!=
1
if
a2_scale
is
not
None
else
False
)
a2_scale
.
numel
()
!=
1
if
a2_scale
is
not
None
else
False
)
if
apply_router_weight_on_input
:
assert
topk
==
1
,
\
"apply_router_weight_on_input is only implemented for topk=1"
# TODO: this only works for topK=1, will need to update for topK>1
a
=
a
*
topk_weights
.
to
(
out_dtype
)
a_q
,
a1_scale
=
ops
.
scaled_fp8_quant
(
a
,
a1_scale
,
use_per_token_if_dynamic
=
per_act_token
)
device
=
a_q
.
device
expert_offsets
=
torch
.
empty
((
num_experts
+
1
),
fn
=
mk
.
FusedMoEModularKernel
(
dtype
=
torch
.
int32
,
MoEPrepareAndFinalizeNoEP
(
device
=
device
)
per_channel_quant
=
per_act_token
,
problem_sizes1
=
torch
.
empty
((
num_experts
,
3
),
quant_dtype
=
torch
.
float8_e4m3fn
,
dtype
=
torch
.
int32
,
),
device
=
device
)
CutlassExpertsFp8
(
problem_sizes2
=
torch
.
empty
((
num_experts
,
3
),
ab_strides1
,
dtype
=
torch
.
int32
,
c_strides1
,
device
=
device
)
ab_strides2
,
c_strides2
,
a_map_initializer
=
torch
.
empty
out_dtype
,
c2_initializer
=
torch
.
empty
),
if
expert_map
is
not
None
:
)
# With expert_map each Rank processes only a subset of experts. As
# a result not all of a_map and c2 tensors are filled. We fill it
return
fn
(
# zeros for correctness.
a
,
a_map_initializer
=
torch
.
zeros
w1_q
,
c2_initializer
=
torch
.
zeros
w2_q
,
topk_weights
,
a_map
=
a_map_initializer
((
local_topk_ids
.
numel
()),
topk_ids
,
dtype
=
torch
.
int32
,
expert_map
=
expert_map
,
device
=
device
)
w1_scale
=
w1_scale
,
c_map
=
torch
.
empty
((
local_topk_ids
.
numel
()),
w2_scale
=
w2_scale
,
dtype
=
torch
.
int32
,
a1_scale
=
a1_scale
,
device
=
device
)
a2_scale
=
a2_scale
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
ops
.
get_cutlass_moe_mm_data
(
local_topk_ids
,
expert_offsets
,
problem_sizes1
,
)
problem_sizes2
,
a_map
,
c_map
,
num_experts
,
n
,
k
)
rep_a_q
=
a_q
.
view
(
dtype
=
torch
.
uint8
)[
a_map
].
view
(
dtype
=
a_q
.
dtype
)
rep_a1_scales
=
a1_scale
[
a_map
]
if
per_act_token
else
a1_scale
c1
=
torch
.
empty
((
m
*
topk
,
n
*
2
),
device
=
device
,
dtype
=
out_dtype
)
c2
=
c2_initializer
((
m
*
topk
,
k
),
device
=
device
,
dtype
=
out_dtype
)
ops
.
cutlass_moe_mm
(
c1
,
rep_a_q
,
w1_q
,
rep_a1_scales
,
w1_scale
,
expert_offsets
[:
-
1
],
problem_sizes1
,
ab_strides1
,
ab_strides1
,
c_strides1
)
intermediate
=
torch
.
empty
((
m
*
topk
,
n
),
device
=
device
,
dtype
=
out_dtype
)
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate
,
c1
)
intemediate_q
,
a2_scale
=
ops
.
scaled_fp8_quant
(
intermediate
,
a2_scale
,
use_per_token_if_dynamic
=
per_act_token
)
ops
.
cutlass_moe_mm
(
c2
,
intemediate_q
,
w2_q
,
a2_scale
,
w2_scale
,
expert_offsets
[:
-
1
],
problem_sizes2
,
ab_strides2
,
ab_strides2
,
c_strides2
)
# Gather tokens
c2
=
c2
[
c_map
].
view
(
m
,
topk
,
k
)
if
not
apply_router_weight_on_input
:
c2
=
c2
*
topk_weights
.
view
(
m
,
topk
,
1
).
to
(
out_dtype
)
return
c2
.
sum
(
dim
=
1
)
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1f
.
max
()
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1f
.
max
()
...
...
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
View file @
f9c069c8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
functools
import
importlib.util
import
importlib.util
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
(
moe_align_block_size
)
_moe_permute
)
from
vllm.model_executor.layers.fused_moe.utils
import
(
_fp8_perm
,
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
_fp8_quantize
,
MoEPrepareAndFinalizeNoEP
)
from
vllm.model_executor.layers.fused_moe.utils
import
(
_fp8_quantize
,
_resize_cache
)
_resize_cache
)
from
vllm.utils
import
round_up
from
vllm.utils
import
round_up
...
@@ -19,6 +20,19 @@ logger = init_logger(__name__)
...
@@ -19,6 +20,19 @@ logger = init_logger(__name__)
has_deep_gemm
=
importlib
.
util
.
find_spec
(
"deep_gemm"
)
is
not
None
has_deep_gemm
=
importlib
.
util
.
find_spec
(
"deep_gemm"
)
is
not
None
@
functools
.
cache
def
deep_gemm_block_shape
()
->
list
[
int
]:
# Lazy import to avoid CUDA initialization problems.
import
deep_gemm
as
dg
block
=
dg
.
get_m_alignment_for_contiguous_layout
()
return
[
block
,
block
]
def
_valid_deep_gemm_shape
(
M
:
int
,
N
:
int
,
K
:
int
):
align
=
deep_gemm_block_shape
()[
0
]
return
align
<=
M
and
N
%
align
==
0
and
K
%
align
==
0
def
_valid_deep_gemm
(
hidden_states
:
torch
.
Tensor
,
def
_valid_deep_gemm
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
@@ -29,89 +43,112 @@ def _valid_deep_gemm(hidden_states: torch.Tensor,
...
@@ -29,89 +43,112 @@ def _valid_deep_gemm(hidden_states: torch.Tensor,
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
"""
"""
if
not
has_deep_gemm
:
if
not
has_deep_gemm
:
logger
.
debug
(
"DeepGemm disabled: deep_gemm not available."
)
return
False
return
False
# Lazy import to avoid CUDA initialization problems.
import
deep_gemm
as
dg
# Expert maps not supported yet.
if
expert_map
is
not
None
:
if
expert_map
is
not
None
:
logger
.
debug
(
"DeepGemm disabled: expert map NYI."
)
return
False
return
False
align
=
dg
.
get_m_alignment_for_contiguous_layout
()
M
=
hidden_states
.
size
(
0
)
M
=
hidden_states
.
shape
[
0
]
_
,
K
,
N
=
w2
.
size
()
_
,
K
,
N
=
w2
.
shape
if
not
_valid_deep_gemm_shape
(
M
,
N
,
K
):
logger
.
debug
(
"DeepGemm disabled: unalinged problem size."
)
return
False
# For now, disable DeepGemm for small N until better permute/unpermute
if
(
w1
.
dtype
!=
torch
.
float8_e4m3fn
or
w2
.
dtype
!=
torch
.
float8_e4m3fn
):
# ops are available.
logger
.
debug
(
"DeepGemm disabled: invalid weight dtype(s)."
)
if
N
<=
512
:
return
False
return
False
if
align
>
M
or
N
%
align
!=
0
or
K
%
align
!=
0
:
if
(
not
hidden_states
.
is_contiguous
()
or
not
w1
.
is_contiguous
()
or
not
w2
.
is_contiguous
()):
logger
.
debug
(
"DeepGemm disabled: weights or activations not contiguous."
)
return
False
return
False
return
(
hidden_states
.
is_contiguous
()
and
w1
.
is_contiguous
()
return
True
and
w2
.
is_contiguous
())
def
_moe_permute
(
class
DeepGemmExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
curr_hidden_states
:
torch
.
Tensor
,
a1q_scale
:
Optional
[
torch
.
Tensor
],
def
__init__
(
self
):
curr_topk_ids
:
torch
.
Tensor
,
super
().
__init__
()
self
.
block_shape
=
deep_gemm_block_shape
()
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
num_experts
:
int
,
)
->
tuple
[
int
,
int
,
torch
.
dtype
]:
block_m
=
self
.
block_shape
[
0
]
M_sum
=
(
M
*
topk
)
+
num_experts
*
(
block_m
-
1
)
M_sum
=
round_up
(
M_sum
,
block_m
)
workspace1
=
M_sum
*
max
(
N
*
2
,
K
)
workspace2
=
M_sum
*
N
return
(
workspace1
,
workspace2
,
a
.
dtype
)
def
apply
(
self
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
expert_map
:
Optional
[
torch
.
Tensor
],
block_m
:
int
,
w1_scale
:
Optional
[
torch
.
Tensor
],
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
w2_scale
:
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
w1_zp
:
Optional
[
torch
.
Tensor
],
"""
w2_zp
:
Optional
[
torch
.
Tensor
],
Determine the sorted_token_ids, expert_ids for the given problem size.
a1q_scale
:
Optional
[
torch
.
Tensor
],
Permute the hidden states and scales according to `sorted_token_ids`.
a2_scale
:
Optional
[
torch
.
Tensor
],
"""
workspace13
:
torch
.
Tensor
,
top_k_num
=
curr_topk_ids
.
shape
[
1
]
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
import
deep_gemm
as
dg
tokens_in_chunk
,
_
=
curr_hidden_states
.
shape
a1q
=
hidden_states
_
,
N
,
K
=
w1
.
size
()
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
assert
global_num_experts
!=
-
1
moe_align_block_size
(
curr_topk_ids
,
assert
w2
.
size
(
1
)
==
K
block_m
,
a1q
,
a1q_scale
,
_
,
expert_ids
,
inv_perm
=
_moe_permute
(
a1q
,
a1q_scale
,
topk_ids
,
global_num_experts
,
global_num_experts
,
expert_map
,
expert_map
,
pad_sorted_ids
=
True
))
self
.
block_shape
[
0
],
)
inv_perm
:
Optional
[
torch
.
Tensor
]
=
None
# Note: M_sum is different than the pre-permuted shape of a1q.
M_sum
=
a1q
.
size
(
0
)
workspace1
=
_resize_cache
(
workspace13
,
(
M_sum
,
N
))
workspace2
=
_resize_cache
(
workspace2
,
(
M_sum
,
N
//
2
))
workspace3
=
_resize_cache
(
workspace13
,
(
M_sum
,
K
))
num_tokens
=
top_k_num
*
tokens_in_chunk
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
(
sorted_token_ids
=
sorted_token_ids
.
clamp
(
max
=
num_tokens
-
1
)
(
a1q
,
a1q_scale
),
(
w1
,
w1_scale
),
workspace1
,
expert_ids
)
expert_ids
=
torch
.
repeat_interleave
(
expert_ids
,
block_m
,
dim
=
0
)
inv_perm
=
torch
.
argsort
(
sorted_token_ids
)[:
num_tokens
]
# Permute according to sorted token ids.
self
.
activation
(
activation
,
workspace2
,
workspace1
.
view
(
-
1
,
N
))
curr_hidden_states
=
_fp8_perm
(
curr_hidden_states
,
sorted_token_ids
//
top_k_num
)
if
a1q_scale
is
not
None
:
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
a1q_scale
=
a1q_scale
[
sorted_token_ids
//
top_k_num
]
return
(
curr_hidden_states
,
a
1
q_scale
,
sorted_token_ids
,
expert_ids
,
a2q
,
a
2
q_scale
=
_fp8_quantize
(
workspace2
,
a2_scale
,
False
,
inv_perm
)
self
.
block_shape
)
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
(
(
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
workspace3
,
expert_ids
)
def
_moe_unpermute_and_reduce
(
workspace3
=
workspace3
[
inv_perm
,
...]
out
:
torch
.
Tensor
,
curr_hidden
:
torch
.
Tensor
,
return
workspace3
inv_perm
:
Optional
[
torch
.
Tensor
],
topk_weight
:
torch
.
Tensor
,
)
->
None
:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M
,
topk
=
topk_weight
.
shape
K
=
curr_hidden
.
shape
[
1
]
curr_hidden
=
curr_hidden
[
inv_perm
,
...]
curr_hidden
=
curr_hidden
.
view
(
-
1
,
topk
,
K
)
curr_hidden
.
mul_
(
topk_weight
.
view
(
M
,
-
1
,
1
))
ops
.
moe_sum
(
curr_hidden
,
out
)
def
deep_gemm_moe_fp8
(
def
deep_gemm_moe_fp8
(
...
@@ -128,6 +165,7 @@ def deep_gemm_moe_fp8(
...
@@ -128,6 +165,7 @@ def deep_gemm_moe_fp8(
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
...
@@ -166,129 +204,24 @@ def deep_gemm_moe_fp8(
...
@@ -166,129 +204,24 @@ def deep_gemm_moe_fp8(
Returns:
Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
"""
# Lazy import to avoid CUDA initialization problems.
fn
=
mk
.
FusedMoEModularKernel
(
import
deep_gemm
as
dg
MoEPrepareAndFinalizeNoEP
(
quant_dtype
=
torch
.
float8_e4m3fn
,
block_shape
=
deep_gemm_block_shape
()),
assert
expert_map
is
None
,
"Expert maps not supported yet"
DeepGemmExperts
(),
)
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
"Hidden size mismatch"
return
fn
(
hidden_states
,
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
w1
,
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
w2
,
assert
w1
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
topk_weights
,
assert
w2
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
topk_ids
,
assert
hidden_states
.
dtype
in
[
inplace
,
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
activation
,
]
global_num_experts
,
assert
w1
.
dtype
==
torch
.
float8_e4m3fn
expert_map
,
assert
w2
.
dtype
==
torch
.
float8_e4m3fn
w1_scale
=
w1_scale
,
assert
w1
.
shape
[
0
]
==
w2
.
shape
[
0
],
"Expert number mismatch"
w2_scale
=
w2_scale
,
assert
w1
.
shape
[
0
]
==
w1_scale
.
shape
[
0
],
"w1 scales expert number mismatch"
a1_scale
=
a1_scale
,
assert
w1
.
shape
[
0
]
==
w2_scale
.
shape
[
0
],
"w2 scales expert number mismatch"
a2_scale
=
a2_scale
,
assert
a1_scale
is
None
or
a1_scale
.
dim
(
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
==
0
or
a1_scale
.
shape
[
0
]
==
1
or
a1_scale
.
shape
[
)
0
]
==
hidden_states
.
shape
[
0
],
"Input scale shape mismatch"
assert
a2_scale
is
None
or
a1_scale
is
None
or
a2_scale
.
shape
==
a1_scale
.
shape
,
"Intermediate scale shape mismatch"
# noqa: E501
num_tokens
,
_
=
hidden_states
.
shape
E
,
N
,
_
=
w1
.
shape
K
=
w2
.
shape
[
1
]
if
global_num_experts
==
-
1
:
global_num_experts
=
E
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
assert
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
,
expert_map
)
if
inplace
:
out_hidden_states
=
hidden_states
else
:
out_hidden_states
=
torch
.
empty_like
(
hidden_states
)
block_m
=
dg
.
get_m_alignment_for_contiguous_layout
()
block_shape
=
[
block_m
,
block_m
]
assert
w1_scale
is
not
None
assert
w2_scale
is
not
None
# We attempt to transpose and align offline in Fp8MoEMethod, in which
# case these calls will be nops. Otherwise, they'll be performed every
# time the layer is executed.
w1_scale
=
dg
.
get_col_major_tma_aligned_tensor
(
w1_scale
).
contiguous
()
w2_scale
=
dg
.
get_col_major_tma_aligned_tensor
(
w2_scale
).
contiguous
()
M_sum
=
topk_ids
.
numel
()
+
global_num_experts
*
(
block_m
-
1
)
M_sum
=
round_up
(
M_sum
,
block_m
)
num_chunks
=
(
num_tokens
//
CHUNK_SIZE
)
+
1
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13
=
torch
.
empty
(
M_sum
*
max
(
N
,
K
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
workspace1
=
workspace13
[:
M_sum
*
N
].
view
(
M_sum
,
N
)
workspace2
=
torch
.
empty
((
M_sum
,
N
//
2
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
workspace3
=
workspace13
[:
M_sum
*
K
].
view
(
M_sum
,
K
)
for
chunk
in
range
(
num_chunks
):
begin_chunk_idx
,
end_chunk_idx
=
(
chunk
*
CHUNK_SIZE
,
min
((
chunk
+
1
)
*
CHUNK_SIZE
,
num_tokens
))
curr_hidden_states
=
hidden_states
[
begin_chunk_idx
:
end_chunk_idx
]
tokens_in_chunk
,
_
=
curr_hidden_states
.
shape
if
tokens_in_chunk
==
0
:
break
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
a1q_scale
:
Optional
[
torch
.
Tensor
]
=
None
qcurr_hidden_states
,
a1q_scale
=
_fp8_quantize
(
curr_hidden_states
,
a1_scale
,
block_shape
)
(
qcurr_hidden_states
,
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
)
=
_moe_permute
(
qcurr_hidden_states
,
a1q_scale
,
curr_topk_ids
,
global_num_experts
,
expert_map
,
block_m
)
# Adjust the intermediate cache size and config for the last chunk.
# Note that in most cases we only have one chunk so the cache size
# and config are already set correctly and do not need to be adjusted.
if
tokens_in_chunk
<
CHUNK_SIZE
and
chunk
>
0
:
curr_M
=
sorted_token_ids
.
numel
()
workspace1
=
_resize_cache
(
workspace1
,
(
curr_M
,
N
))
workspace2
=
_resize_cache
(
workspace2
,
(
curr_M
,
N
//
2
))
workspace3
=
_resize_cache
(
workspace3
,
(
curr_M
,
K
))
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
(
(
qcurr_hidden_states
,
a1q_scale
),
(
w1
,
w1_scale
),
workspace1
,
expert_ids
)
if
activation
==
"silu"
:
torch
.
ops
.
_C
.
silu_and_mul
(
workspace2
,
workspace1
.
view
(
-
1
,
N
))
elif
activation
==
"gelu"
:
torch
.
ops
.
_C
.
gelu_and_mul
(
workspace2
,
workspace1
.
view
(
-
1
,
N
))
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
qworkspace2
,
a2q_scale
=
_fp8_quantize
(
workspace2
,
a2_scale
,
block_shape
)
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
(
(
qworkspace2
,
a2q_scale
),
(
w2
,
w2_scale
),
workspace3
,
expert_ids
)
_moe_unpermute_and_reduce
(
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
workspace3
.
view
(
*
workspace3
.
shape
),
inv_perm
,
curr_topk_weights
)
return
out_hidden_states
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
0 → 100644
View file @
f9c069c8
This diff is collapsed.
Click to expand it.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
f9c069c8
...
@@ -8,16 +8,17 @@ from typing import Any, Callable, Optional
...
@@ -8,16 +8,17 @@ from typing import Any, Callable, Optional
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
_valid_deep_gemm
,
deep_gemm_moe_fp8
)
_valid_deep_gemm
,
deep_gemm_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
moe_align_block_size
)
from
vllm.model_executor.layers.
quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.
fused_moe.prepare_finalize
import
(
per_token_group_quant_fp8
)
MoEPrepareAndFinalizeNoEP
)
from
vllm.model_executor.layers.
quantization.utils.int8_
utils
import
(
from
vllm.model_executor.layers.
fused_moe.
utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
_resize_cache
,
moe_kernel_quantize_input
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
...
@@ -484,6 +485,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
...
@@ -484,6 +485,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
if
use_fp8_w8a8
or
use_int8_w8a8
:
assert
B_scale
is
not
None
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_shape
[
0
])
==
B_scale
.
shape
[
-
2
])
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_shape
[
1
])
==
B_scale
.
shape
[
-
1
])
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
else
:
assert
A_scale
is
None
assert
B_scale
is
None
M
=
A
.
shape
[
0
]
M
=
A
.
shape
[
0
]
num_tokens
=
M
*
top_k
num_tokens
=
M
*
top_k
...
@@ -855,6 +870,7 @@ def fused_topk(
...
@@ -855,6 +870,7 @@ def fused_topk(
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
topk
:
int
,
renormalize
:
bool
,
renormalize
:
bool
,
indices_type
:
Optional
[
torch
.
dtype
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
"Number of tokens mismatch"
)
...
@@ -865,9 +881,10 @@ def fused_topk(
...
@@ -865,9 +881,10 @@ def fused_topk(
topk
,
topk
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
hidden_states
.
device
)
device
=
hidden_states
.
device
)
topk_ids
=
torch
.
empty
(
M
,
topk_ids
=
torch
.
empty
(
M
,
topk
,
topk
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
if
indices_type
is
None
else
indices_type
,
device
=
hidden_states
.
device
)
device
=
hidden_states
.
device
)
token_expert_indices
=
torch
.
empty
(
M
,
token_expert_indices
=
torch
.
empty
(
M
,
topk
,
topk
,
...
@@ -962,6 +979,20 @@ def get_config_dtype_str(
...
@@ -962,6 +979,20 @@ def get_config_dtype_str(
return
None
return
None
# TODO (bnell): use scalar_type instead of bools?
def
get_config_qtype
(
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
)
->
Optional
[
torch
.
dtype
]:
if
use_fp8_w8a8
:
return
torch
.
float8_e4m3fn
elif
use_int8_w8a8
:
return
torch
.
int8
return
None
def
inplace_fused_experts
(
hidden_states
:
torch
.
Tensor
,
def
inplace_fused_experts
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
@@ -1128,7 +1159,10 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1128,7 +1159,10 @@ def fused_experts(hidden_states: torch.Tensor,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
allow_deep_gemm
:
bool
=
False
)
->
torch
.
Tensor
:
allow_deep_gemm
:
bool
=
False
)
->
torch
.
Tensor
:
if
(
allow_deep_gemm
and
use_fp8_w8a8
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N
=
w1
.
shape
[
1
]
if
(
allow_deep_gemm
and
use_fp8_w8a8
and
N
>
512
and
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
,
expert_map
)):
and
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
,
expert_map
)):
assert
apply_router_weight_on_input
is
False
assert
apply_router_weight_on_input
is
False
return
deep_gemm_moe_fp8
(
return
deep_gemm_moe_fp8
(
...
@@ -1145,6 +1179,7 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1145,6 +1179,7 @@ def fused_experts(hidden_states: torch.Tensor,
w2_scale
=
w2_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
a2_scale
=
a2_scale
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
)
else
:
else
:
return
dispatch_fused_experts_func
(
inplace
)(
return
dispatch_fused_experts_func
(
inplace
)(
...
@@ -1171,60 +1206,8 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1171,60 +1206,8 @@ def fused_experts(hidden_states: torch.Tensor,
block_shape
=
block_shape
)
block_shape
=
block_shape
)
def
moe_kernel_prepare_input
(
def
fused_experts_impl
(
A
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
B_scale
:
Optional
[
torch
.
Tensor
],
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
if
use_fp8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
# If weights are per-channel (per_channel_quant=True), then
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8 quantization, dynamic or static
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
,
use_per_token_if_dynamic
=
per_channel_quant
)
else
:
# activation block-wise fp8 quantization
assert
len
(
block_shape
)
==
2
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_fp8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif
use_int8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
# activation channel-wise int8 quantization
assert
(
per_channel_quant
),
"int8 quantization only supports block or channel-wise"
A
,
A_scale
=
per_token_quant_int8
(
A
)
else
:
# activation block-wise int8 quantization
assert
len
(
block_shape
)
==
2
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_int8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
else
:
assert
A_scale
is
None
assert
B_scale
is
None
return
A
,
A_scale
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
...
@@ -1245,13 +1228,15 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1245,13 +1228,15 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
):
block_shape
:
Optional
[
list
[
int
]]
=
None
,
)
->
torch
.
Tensor
:
# Check constraints.
# Check constraints.
if
use_int4_w4a16
:
if
use_int4_w4a16
:
assert
hidden_states
.
shape
[
1
]
//
2
==
w1
.
shape
[
assert
hidden_states
.
shape
[
1
]
//
2
==
w1
.
shape
[
2
],
"Hidden size mismatch"
2
],
"Hidden size mismatch"
else
:
else
:
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
"Hidden size mismatch"
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
(
f
"Hidden size mismatch
{
hidden_states
.
shape
[
1
]
}
!=
{
w1
.
shape
[
2
]
}
"
)
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
...
@@ -1261,7 +1246,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1261,7 +1246,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
]
num_tokens
,
_
=
hidden_states
.
shape
num_tokens
=
hidden_states
.
shape
[
0
]
E
,
N
,
_
=
w1
.
shape
E
,
N
,
_
=
w1
.
shape
K
=
w2
.
shape
[
1
]
K
=
w2
.
shape
[
1
]
if
global_num_experts
==
-
1
:
if
global_num_experts
==
-
1
:
...
@@ -1276,6 +1261,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1276,6 +1261,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
dtype
=
hidden_states
.
dtype
)
dtype
=
hidden_states
.
dtype
)
qtype
=
get_config_qtype
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
)
get_config_func
=
functools
.
partial
(
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
try_get_optimal_moe_config
,
w1
.
shape
,
w1
.
shape
,
...
@@ -1338,15 +1328,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1338,15 +1328,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
qcurr_hidden_states
,
q
a1_scale
=
moe_kernel_
prepar
e_input
(
qcurr_hidden_states
,
a1
q
_scale
=
moe_kernel_
quantiz
e_input
(
A
=
curr_hidden_states
,
A
=
curr_hidden_states
,
B
=
w1
,
A_scale
=
a1_scale
,
A_scale
=
a1_scale
,
B_scale
=
w1_scale
,
qtype
=
qtype
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
block_shape
=
block_shape
)
...
@@ -1357,7 +1342,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1357,7 +1342,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel
(
qcurr_hidden_states
,
invoke_fused_moe_kernel
(
qcurr_hidden_states
,
w1
,
w1
,
intermediate_cache1
,
intermediate_cache1
,
q
a1_scale
,
a1
q
_scale
,
w1_scale
,
w1_scale
,
w1_zp
,
w1_zp
,
curr_topk_weights
,
curr_topk_weights
,
...
@@ -1384,22 +1369,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1384,22 +1369,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
else
:
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
qintermediate_cache2
,
q
a2_scale
=
moe_kernel_
prepar
e_input
(
qintermediate_cache2
,
a2
q
_scale
=
moe_kernel_
quantiz
e_input
(
A
=
intermediate_cache2
,
A
=
intermediate_cache2
,
B
=
w2
,
A_scale
=
a2_scale
,
A_scale
=
a2_scale
,
B_scale
=
w2_scale
,
qtype
=
qtype
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
block_shape
=
block_shape
)
invoke_fused_moe_kernel
(
qintermediate_cache2
,
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
w2
,
intermediate_cache3
,
intermediate_cache3
,
q
a2_scale
,
a2
q
_scale
,
w2_scale
,
w2_scale
,
w2_zp
,
w2_zp
,
curr_topk_weights
,
curr_topk_weights
,
...
@@ -1534,3 +1514,209 @@ def fused_moe(
...
@@ -1534,3 +1514,209 @@ def fused_moe(
a1_scale
=
a1_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
)
block_shape
=
block_shape
)
class
TritonExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
block_m
:
Optional
[
int
]
=
None
,
):
super
().
__init__
()
self
.
use_fp8_w8a8
=
use_fp8_w8a8
self
.
use_int4_w4a16
=
use_int4_w4a16
self
.
use_int8_w8a8
=
use_int8_w8a8
self
.
use_int8_w8a16
=
use_int8_w8a16
self
.
block_shape
=
block_shape
self
.
block_m
=
block_m
self
.
qtype
=
get_config_qtype
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
)
self
.
per_channel_quant
=
per_channel_quant
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
num_experts
:
int
,
)
->
tuple
[
int
,
int
,
torch
.
dtype
]:
factor
=
num_experts
if
a
.
dim
()
==
3
else
1
workspace1
=
M
*
topk
*
max
(
N
*
2
,
K
)
*
factor
workspace2
=
M
*
topk
*
N
*
factor
return
(
workspace1
,
workspace2
,
a
.
dtype
)
def
apply
(
self
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
w1_zp
:
Optional
[
torch
.
Tensor
],
w2_zp
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Check constraints.
if
self
.
use_int4_w4a16
:
assert
hidden_states
.
size
(
-
1
)
//
2
==
w1
.
size
(
2
),
(
"Hidden size mismatch"
)
else
:
assert
hidden_states
.
size
(
-
1
)
==
w1
.
size
(
2
),
\
(
f
"Hidden size mismatch
{
hidden_states
.
size
(
-
1
)
}
"
f
"!=
{
w1
.
size
(
2
)
}
"
)
assert
hidden_states
.
is_contiguous
(
),
"Hidden_states must be contiguous"
assert
hidden_states
.
dim
()
==
2
assert
w1
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
assert
w2
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
assert
hidden_states
.
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float8_e4m3fn
]
E
,
num_tokens
,
N
,
K
,
top_k_num
=
mk
.
_moe_problem_size
(
hidden_states
,
w1
,
w2
,
topk_ids
)
if
global_num_experts
==
-
1
:
global_num_experts
=
E
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
dtype
=
hidden_states
.
dtype
)
config
=
try_get_optimal_moe_config
(
w1
.
shape
,
w2
.
shape
,
top_k_num
,
config_dtype
,
num_tokens
,
block_shape
=
self
.
block_shape
,
)
if
hidden_states
.
dtype
==
torch
.
bfloat16
:
compute_type
=
tl
.
bfloat16
elif
hidden_states
.
dtype
==
torch
.
float16
:
compute_type
=
tl
.
float16
elif
hidden_states
.
dtype
==
torch
.
float32
:
compute_type
=
tl
.
float32
elif
hidden_states
.
dtype
==
torch
.
float8_e4m3fn
:
compute_type
=
tl
.
bfloat16
else
:
raise
ValueError
(
f
"Unsupported compute_type:
{
hidden_states
.
dtype
}
"
)
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1
=
_resize_cache
(
workspace13
,
(
num_tokens
,
top_k_num
,
N
))
intermediate_cache2
=
_resize_cache
(
workspace2
,
(
num_tokens
*
top_k_num
,
N
//
2
))
intermediate_cache3
=
_resize_cache
(
workspace13
,
(
num_tokens
,
top_k_num
,
K
))
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
))
invoke_fused_moe_kernel
(
hidden_states
,
w1
,
intermediate_cache1
,
a1q_scale
,
w1_scale
,
w1_zp
,
None
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
False
,
top_k_num
,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_int8_w8a8
=
self
.
use_int8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
per_channel_quant
=
self
.
per_channel_quant
,
block_shape
=
self
.
block_shape
)
self
.
activation
(
activation
,
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
qintermediate_cache2
,
a2q_scale
=
moe_kernel_quantize_input
(
intermediate_cache2
,
a2_scale
,
self
.
qtype
,
self
.
per_channel_quant
,
self
.
block_shape
)
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
intermediate_cache3
,
a2q_scale
,
w2_scale
,
w2_zp
,
None
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
False
,
1
,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_int8_w8a8
=
self
.
use_int8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
per_channel_quant
=
self
.
per_channel_quant
,
block_shape
=
self
.
block_shape
)
return
intermediate_cache3
def
modular_triton_fused_moe
(
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
)
->
mk
.
FusedMoEModularKernel
:
qtype
=
get_config_qtype
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
)
return
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(
quant_dtype
=
qtype
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
),
TritonExperts
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
),
)
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment