Unverified Commit f9c069c8 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

Modularize fused experts and integrate PPLX kernels (#15956)

parent 418d2f8b
......@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
} \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
......
......@@ -65,5 +65,19 @@
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
......@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
}
if (use_global_memory) {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
......@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
cumsum_buffer.data_ptr<int32_t>());
});
} else if (use_i16) {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// set dynamic shared mem
auto kernel =
......@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids.numel());
});
} else {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
auto kernel =
vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
......@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
TORCH_CHECK(num_experts == 256,
"sgl_moe_align_block_size kernel only supports deepseek v3.");
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `cumsum` tensors
auto options_int =
......
......@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
}
}
template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
template <int TPB, typename IndType>
__launch_bounds__(TPB) __global__ void moeTopK(
const float* inputs_after_softmax,
const bool* finished,
float* output,
IndType* indices,
int* source_rows,
const int num_experts,
const int k,
const int start_expert,
const int end_expert)
{
using cub_kvp = cub::KeyValuePair<int, float>;
......@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
2) This implementation assumes k is small, but will work for any k.
*/
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert)
{
// We begin by enforcing compile time assertions and setting up compile time constants.
......@@ -397,8 +405,8 @@ struct TopkConstants
};
} // namespace detail
template <int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
template <int EXPERTS, int WARPS_PER_TB, typename IndType>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
......@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
token_expert_indices, num_tokens, topk, 0, num_experts, \
stream);
template <typename IndType>
void topkGatingSoftmaxKernelLauncher(
const float* gating_output,
float* topk_weights,
int* topk_indicies,
IndType* topk_indicies,
int* token_expert_indices,
float* softmax_workspace,
const int num_tokens,
......@@ -493,6 +502,9 @@ void topk_softmax(
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
if(topk_indices.scalar_type() == at::ScalarType::Int)
{
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
......@@ -503,4 +515,19 @@ void topk_softmax(
num_experts,
topk,
stream);
}
else
{
assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
}
......@@ -65,11 +65,17 @@ def parse_args():
type=int,
default=0,
help="Master node port")
parser.add_argument("--enforce-eager",
action='store_true',
help="Enforce eager mode execution.")
parser.add_argument("--trust-remote-code",
action='store_true',
help="Trust remote code.")
return parser.parse_args()
def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
dp_master_port, GPUs_per_dp_rank):
dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
......@@ -109,10 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
max_tokens=[16, 20][global_dp_rank % 2])
# Create an LLM.
llm = LLM(model=model,
llm = LLM(
model=model,
tensor_parallel_size=GPUs_per_dp_rank,
enforce_eager=True,
enable_expert_parallel=True)
enforce_eager=enforce_eager,
enable_expert_parallel=True,
trust_remote_code=trust_remote_code,
)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for i, output in enumerate(outputs):
......@@ -155,7 +164,8 @@ if __name__ == "__main__":
proc = Process(target=main,
args=(args.model, dp_size, local_dp_rank,
global_dp_rank, dp_master_ip, dp_master_port,
tp_size))
tp_size, args.enforce_eager,
args.trust_remote_code))
proc.start()
procs.append(proc)
exit_code = 0
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import pytest
import torch
import triton.language as tl
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
invoke_moe_batched_triton_kernel)
@dataclass
class BatchedMMConfig:
dtype: torch.dtype
num_experts: int
max_tokens_per_expert: int
K: int
N: int
@dataclass
class BatchedMMTensors:
A: torch.Tensor # [E, max_tokens, K]
B: torch.Tensor # [E, K, N] - column major
C: torch.Tensor # [E, max_tokens, N]
num_expert_tokens: torch.Tensor # [E]
@staticmethod
def make_tensors(config: BatchedMMConfig):
A = torch.randn(
(config.num_experts, config.max_tokens_per_expert, config.K),
device="cuda",
dtype=config.dtype) / 10
B = torch.randn((config.num_experts, config.N, config.K),
device="cuda",
dtype=config.dtype)
C = torch.zeros(
(config.num_experts, config.max_tokens_per_expert, config.N),
device="cuda",
dtype=config.dtype)
num_expert_tokens = torch.randint(low=0,
high=config.max_tokens_per_expert,
size=(config.num_experts, ),
device="cuda",
dtype=torch.int32)
return BatchedMMTensors(A, B, C, num_expert_tokens)
def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
num_expert_tokens: torch.Tensor) -> torch.Tensor:
num_expert_tokens_cpu = num_expert_tokens.clone()
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
num_experts = num_expert_tokens.size(0)
for e in range(num_experts):
num_tokens = num_expert_tokens_cpu[e]
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
return C
@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
[32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype):
config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
tensors = BatchedMMTensors.make_tensors(config)
test_output = tensors.C
ref_output = test_output.clone()
compute_tl_dtype = {
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32
}[test_output.dtype]
invoke_moe_batched_triton_kernel(
tensors.A,
tensors.B,
test_output,
tensors.num_expert_tokens,
compute_tl_dtype,
# Quantization data
None,
None,
None,
# Quantization schemes
False,
False,
False,
config={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 16
})
ref_output = ref_impl(tensors.A, tensors.B, ref_output,
tensors.num_expert_tokens)
rtol, atol = {
torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2),
}[test_output.dtype]
torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
......@@ -30,6 +30,11 @@ MNK_FACTORS = [
(224, 3072, 1536),
]
vllm_config = VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@dataclasses.dataclass
class MOETensors:
......@@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr]
'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr]
'topk_weights': topk_weights,
'topk_ids_': topk_ids,
'topk_ids': topk_ids,
'ab_strides1': moe_tensors.ab_strides1,
'c_strides1': moe_tensors.c_strides1,
'ab_strides2': moe_tensors.ab_strides2,
......@@ -231,15 +236,12 @@ def test_cutlass_moe_8_bit_no_graph(
per_out_ch: bool,
):
current_platform.seed_everything(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_ch)
score = torch.randn((m, e), device="cuda", dtype=torch.half)
topk_weights, topk_ids = fused_topk(mt.a,
topk_weights, topk_ids, _ = fused_topk(mt.a,
score,
topk,
renormalize=False)
......@@ -276,17 +278,14 @@ def test_cutlass_moe_8_bit_cuda_graph(
per_out_ch: bool,
):
current_platform.seed_everything(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
with set_current_vllm_config(vllm_config):
dtype = torch.half
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_ch)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(mt.a,
topk_weights, topk_ids, _ = fused_topk(mt.a,
score,
topk,
renormalize=False)
......@@ -334,15 +333,12 @@ def test_cutlass_moe_8_bit_EP(
ep_size: int,
):
current_platform.seed_everything(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_channel)
score = torch.randn((m, e), device="cuda", dtype=torch.half)
topk_weights, topk_ids = fused_topk(mt.a,
topk_weights, topk_ids, _ = fused_topk(mt.a,
score,
topk,
renormalize=False)
......
......@@ -12,6 +12,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
......@@ -32,6 +33,10 @@ NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048])
......@@ -70,6 +75,7 @@ def test_fused_moe(
else:
e_map = None
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w1, w2, score, topk, e_map)
iterative_output = iterative_moe(a,
w1,
......@@ -95,6 +101,7 @@ def test_fused_moe(
global_num_experts=e,
expert_map=e_map,
renormalize=False)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
torch.testing.assert_close(iterative_output,
torch_output,
......@@ -115,7 +122,6 @@ def test_fused_moe(
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
ep_size: int, dtype: torch.dtype, group_size: int,
has_zp: bool, weight_bits: int):
print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
......@@ -194,6 +200,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
else:
e_map = None
with set_current_vllm_config(vllm_config):
triton_output = fused_moe(a,
w1_qweight,
w2_qweight,
......@@ -210,6 +217,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size])
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
......@@ -515,6 +523,7 @@ def test_fused_marlin_moe(
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
marlin_output = torch.ops.vllm.fused_marlin_moe(
......
This diff is collapsed.
......@@ -7,6 +7,7 @@ import pytest
import torch
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.platforms import current_platform
......@@ -15,6 +16,10 @@ if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
"""Matrix multiplication function that supports per-token input
......@@ -137,6 +142,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
score = torch.randn((M, E), dtype=dtype)
with set_current_vllm_config(vllm_config):
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
out = fused_moe(
a,
......
......@@ -11,7 +11,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
deep_gemm_moe_fp8)
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
......@@ -30,6 +30,10 @@ if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
# Test configurations
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048]
......@@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score = torch.randn((M, E), dtype=dtype)
# Set the context to avoid lots of warning spam.
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
out = fused_moe(
a,
......@@ -258,6 +261,7 @@ def per_block_cast_to_fp8(
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
# only aligned sizes
......@@ -381,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
block_size = [block_m, block_m]
dtype = torch.bfloat16
# only aligned sizes
if (N % block_m != 0 or K % block_m != 0 or topk > E):
pytest.skip(
f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}")
if N <= 512:
pytest.skip("Skipping N <= 512 until performance issues solved.")
if topk > E:
pytest.skip(f"Skipping test: topk={topk} > E={E}")
vllm_config = VllmConfig()
if not _valid_deep_gemm_shape(M, N, K):
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
torch.manual_seed(seed)
fp8_info = torch.finfo(torch.float8_e4m3fn)
......
......@@ -18,6 +18,10 @@ if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
# For test
def native_per_token_group_quant_int8(x,
......@@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score = torch.randn((M, E), dtype=dtype)
# Set the context to avoid lots of warning spam.
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
out = fused_moe(
a,
......
......@@ -23,6 +23,7 @@ If you only need to use the distributed environment without model/pipeline
"""
import contextlib
import gc
import importlib.util
import pickle
import weakref
from collections import namedtuple
......@@ -42,7 +43,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
supports_custom_op)
run_once, supports_custom_op)
@dataclass
......@@ -936,9 +937,49 @@ def init_distributed_environment(
"world group already initialized with a different world size")
PPLX_DID_INIT: bool = False
@run_once
def pplx_init(rank, world_size):
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
if has_pplx and world_size > 1:
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id, nvshmem_init)
try:
global PPLX_DID_INIT
logger.debug(
"Initialize NVSHMEM for PPLX kernels: rank=%d, "
"world size=%d", rank, world_size)
uid = nvshmem_get_unique_id(
) if rank == 0 else nvshmem_alloc_empty_unique_id()
uid_gpu = uid.cuda()
get_world_group().broadcast(uid_gpu, src=0)
uid = uid_gpu.to(device='cpu')
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, rank, world_size)
PPLX_DID_INIT = True
except Exception as ex:
logger.error("Failed to initialize NVSHMEM for PPLX: %s", ex)
@run_once
def pplx_finalize():
global PPLX_DID_INIT
if PPLX_DID_INIT:
from pplx_kernels.nvshmem import nvshmem_finalize
logger.debug("PPLX NVSHMEM finalize")
from vllm.model_executor.layers.fused_moe.layer import (
_all_to_all_cache)
_all_to_all_cache.destroy()
nvshmem_finalize()
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
enable_expert_parallel: bool = False,
backend: Optional[str] = None,
) -> None:
"""
......@@ -1041,10 +1082,14 @@ def initialize_model_parallel(
_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
_EP.rank_in_group)
if enable_expert_parallel:
pplx_init(rank, world_size)
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
enable_expert_parallel: bool = False,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
......@@ -1055,7 +1100,8 @@ def ensure_model_parallel_initialized(
get_world_group().device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size, backend)
pipeline_model_parallel_size,
enable_expert_parallel, backend)
return
assert (
......@@ -1133,6 +1179,9 @@ def get_tensor_model_parallel_rank():
def destroy_model_parallel():
"""Set the groups to none and destroy them."""
global _TP
pplx_finalize()
if _TP:
_TP.destroy()
_TP = None
......
......@@ -23,7 +23,7 @@ from torch.distributed.rendezvous import rendezvous
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import get_tcp_uri
from vllm.utils import get_tcp_uri, is_torch_equal_or_newer
logger = init_logger(__name__)
......@@ -362,12 +362,11 @@ def stateless_destroy_torch_distributed_process_group(
Destroy ProcessGroup returned by
stateless_init_torch_distributed_process_group().
"""
if is_torch_equal_or_newer("2.7"):
pg.shutdown()
else:
# Lazy import for non-CUDA backends.
try:
# pytorch <= 2.6
from torch.distributed.distributed_c10d import _shutdown_backend
_shutdown_backend(pg)
except ImportError:
# pytorch >= 2.7
pg.shutdown()
_unregister_process_group(pg.group_name)
......@@ -27,6 +27,7 @@ batchsize_forward_time: defaultdict = defaultdict(list)
@dataclass
class DPMetadata:
max_tokens_across_dp_cpu: torch.Tensor
cu_tokens_across_dp_cpu: torch.Tensor
......@@ -90,8 +91,10 @@ def set_forward_context(attn_metadata: Any,
dtype=torch.int32)
from vllm.distributed.parallel_state import get_dp_group
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
dp_metadata = DPMetadata(cu_tokens_across_dp_cpu)
dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
cu_tokens_across_dp_cpu)
global _forward_context
prev_context = _forward_context
......
......@@ -38,8 +38,8 @@ if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4, cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)
TritonExperts, fused_experts, fused_moe, fused_topk,
get_config_file_name, grouped_topk)
__all__ += [
"fused_moe",
......@@ -49,4 +49,5 @@ if HAS_TRITON:
"grouped_topk",
"cutlass_moe_fp8",
"cutlass_moe_fp4",
"TritonExperts",
]
......@@ -5,10 +5,176 @@ from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
from vllm.scalar_type import scalar_types
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
out_dtype: torch.dtype,
):
super().__init__()
self.ab_strides1 = ab_strides1
self.c_strides1 = c_strides1
self.ab_strides2 = ab_strides2
self.c_strides2 = c_strides2
self.out_dtype = out_dtype
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
# Note that K, N are transposed
N, K = K, N
workspace1 = M * topk * max(2 * N, K)
workspace2 = M * topk * N
return (workspace1, workspace2, self.out_dtype)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
a1q = hidden_states
assert w1_scale is not None
assert w2_scale is not None
assert w1.dtype == torch.float8_e4m3fn
assert w2.dtype == torch.float8_e4m3fn
assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1"
assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2"
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
assert a1q_scale is None or a1q_scale.dim(
) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[
0], "Input scale shape mismatch"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1.shape[2], "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
1] == w2.shape[2], "W2 scale shape mismatch"
assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
assert w1.shape[0] == w1_scale.shape[
0], "w1 scales expert number mismatch"
assert w1.shape[0] == w2_scale.shape[
0], "w2 scales expert number mismatch"
assert a2_scale is None or a1q_scale is None or a2_scale.shape == a1q_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
assert self.ab_strides1.shape[0] == w1.shape[
0], "AB Strides 1 expert number mismatch"
assert self.c_strides1.shape[0] == w1.shape[
0], "C Strides 1 expert number mismatch"
assert self.ab_strides2.shape[0] == w2.shape[
0], "AB Strides 2 expert number mismatch"
assert self.c_strides2.shape[0] == w2.shape[
0], "C Strides 2 expert number mismatch"
assert self.out_dtype in [torch.half,
torch.bfloat16], "Invalid output dtype"
M = a1q.shape[0]
_, N, K = w2.shape # because w1 + w2 are transposed
device = a1q.device
assert w1.shape[1] == K
assert global_num_experts != -1
assert a1q_scale is not None
if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(expert_map[topk_ids] != -1,
expert_map[topk_ids], -1)
else:
local_topk_ids = topk_ids
topk = local_topk_ids.shape[1]
per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
expert_offsets = torch.empty((global_num_experts + 1),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((global_num_experts, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((global_num_experts, 3),
dtype=torch.int32,
device=device)
# With expert_map each Rank processes only a subset of experts. As
# a result not all of a_map and c2 tensors are filled. We fill it
# zeros for correctness.
if expert_map is not None:
a_map = torch.zeros((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
else:
a_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
c_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets,
problem_sizes1, problem_sizes2, a_map,
c_map, global_num_experts, N, K)
a1q = _fp8_perm(a1q, a_map)
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
c1 = _resize_cache(workspace13, (M * topk, N * 2))
c2 = _resize_cache(workspace2, (M * topk, N))
c3 = _resize_cache(workspace13, (M * topk, K))
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale,
expert_offsets[:-1], problem_sizes1,
self.ab_strides1, self.ab_strides1, self.c_strides1)
self.activation(activation, c2, c1)
a2q, a2q_scale = ops.scaled_fp8_quant(
c2, a2_scale, use_per_token_if_dynamic=per_act_token)
if expert_map is not None:
c3.fill_(0)
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale,
expert_offsets[:-1], problem_sizes2,
self.ab_strides2, self.ab_strides2, self.c_strides2)
c3 = c3[c_map]
return c3
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
def cutlass_moe_fp8(
a: torch.Tensor,
......@@ -17,7 +183,7 @@ def cutlass_moe_fp8(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids_: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
......@@ -59,7 +225,7 @@ def cutlass_moe_fp8(
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
- out_dtype (torch.Tensor): The output tensor type.
- out_dtype (torch.dtype): The output tensor type.
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
......@@ -71,115 +237,36 @@ def cutlass_moe_fp8(
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
assert w1_q.dtype == torch.float8_e4m3fn
assert w2_q.dtype == torch.float8_e4m3fn
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert a1_scale is None or a1_scale.dim(
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[
0], "Input scale shape mismatch"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1_q.shape[2], "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
1] == w2_q.shape[2], "W2 scale shape mismatch"
assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[
0], "w1 scales expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[
0], "w2 scales expert number mismatch"
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
assert ab_strides1.shape[0] == w1_q.shape[
0], "AB Strides 1 expert number mismatch"
assert c_strides1.shape[0] == w1_q.shape[
0], "C Strides 1 expert number mismatch"
assert ab_strides2.shape[0] == w2_q.shape[
0], "AB Strides 2 expert number mismatch"
assert c_strides2.shape[0] == w2_q.shape[
0], "C Strides 2 expert number mismatch"
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
num_experts = w1_q.size(0)
m = a.size(0)
k = w1_q.size(1)
n = w2_q.size(1)
local_topk_ids = topk_ids_
if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(expert_map[topk_ids_] != -1,
expert_map[topk_ids_], -1)
topk = local_topk_ids.size(1)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
if apply_router_weight_on_input:
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
# TODO: this only works for topK=1, will need to update for topK>1
a = a * topk_weights.to(out_dtype)
a_q, a1_scale = ops.scaled_fp8_quant(
a, a1_scale, use_per_token_if_dynamic=per_act_token)
device = a_q.device
expert_offsets = torch.empty((num_experts + 1),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
a_map_initializer = torch.empty
c2_initializer = torch.empty
if expert_map is not None:
# With expert_map each Rank processes only a subset of experts. As
# a result not all of a_map and c2 tensors are filled. We fill it
# zeros for correctness.
a_map_initializer = torch.zeros
c2_initializer = torch.zeros
a_map = a_map_initializer((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
c_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, a_map, c_map, num_experts, n,
k)
rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype)
ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
expert_offsets[:-1], problem_sizes1, ab_strides1,
ab_strides1, c_strides1)
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
torch.ops._C.silu_and_mul(intermediate, c1)
intemediate_q, a2_scale = ops.scaled_fp8_quant(
intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)
ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
expert_offsets[:-1], problem_sizes2, ab_strides2,
ab_strides2, c_strides2)
# Gather tokens
c2 = c2[c_map].view(m, topk, k)
if not apply_router_weight_on_input:
c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
return c2.sum(dim=1)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
per_channel_quant=per_act_token,
quant_dtype=torch.float8_e4m3fn,
),
CutlassExpertsFp8(
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
out_dtype,
),
)
return fn(
a,
w1_q,
w2_q,
topk_weights,
topk_ids,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
......
# SPDX-License-Identifier: Apache-2.0
import functools
import importlib.util
from typing import Optional
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
_resize_cache)
from vllm.utils import round_up
......@@ -19,6 +20,19 @@ logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
@functools.cache
def deep_gemm_block_shape() -> list[int]:
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
block = dg.get_m_alignment_for_contiguous_layout()
return [block, block]
def _valid_deep_gemm_shape(M: int, N: int, K: int):
align = deep_gemm_block_shape()[0]
return align <= M and N % align == 0 and K % align == 0
def _valid_deep_gemm(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
......@@ -29,89 +43,112 @@ def _valid_deep_gemm(hidden_states: torch.Tensor,
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
"""
if not has_deep_gemm:
logger.debug("DeepGemm disabled: deep_gemm not available.")
return False
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
# Expert maps not supported yet.
if expert_map is not None:
logger.debug("DeepGemm disabled: expert map NYI.")
return False
align = dg.get_m_alignment_for_contiguous_layout()
M = hidden_states.shape[0]
_, K, N = w2.shape
M = hidden_states.size(0)
_, K, N = w2.size()
if not _valid_deep_gemm_shape(M, N, K):
logger.debug("DeepGemm disabled: unalinged problem size.")
return False
# For now, disable DeepGemm for small N until better permute/unpermute
# ops are available.
if N <= 512:
if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn):
logger.debug("DeepGemm disabled: invalid weight dtype(s).")
return False
if align > M or N % align != 0 or K % align != 0:
if (not hidden_states.is_contiguous() or not w1.is_contiguous()
or not w2.is_contiguous()):
logger.debug(
"DeepGemm disabled: weights or activations not contiguous.")
return False
return (hidden_states.is_contiguous() and w1.is_contiguous()
and w2.is_contiguous())
return True
def _moe_permute(
curr_hidden_states: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
curr_topk_ids: torch.Tensor,
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self):
super().__init__()
self.block_shape = deep_gemm_block_shape()
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
block_m = self.block_shape[0]
M_sum = (M * topk) + num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m)
workspace1 = M_sum * max(N * 2, K)
workspace2 = M_sum * N
return (workspace1, workspace2, a.dtype)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
block_m: int,
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
"""
top_k_num = curr_topk_ids.shape[1]
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
import deep_gemm as dg
tokens_in_chunk, _ = curr_hidden_states.shape
a1q = hidden_states
_, N, K = w1.size()
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids,
block_m,
assert global_num_experts != -1
assert w2.size(1) == K
a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute(
a1q,
a1q_scale,
topk_ids,
global_num_experts,
expert_map,
pad_sorted_ids=True))
self.block_shape[0],
)
inv_perm: Optional[torch.Tensor] = None
# Note: M_sum is different than the pre-permuted shape of a1q.
M_sum = a1q.size(0)
workspace1 = _resize_cache(workspace13, (M_sum, N))
workspace2 = _resize_cache(workspace2, (M_sum, N // 2))
workspace3 = _resize_cache(workspace13, (M_sum, K))
num_tokens = top_k_num * tokens_in_chunk
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids)
# Permute according to sorted token ids.
curr_hidden_states = _fp8_perm(curr_hidden_states,
sorted_token_ids // top_k_num)
self.activation(activation, workspace2, workspace1.view(-1, N))
if a1q_scale is not None:
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
a2q_scale: Optional[torch.Tensor] = None
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm)
a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, False,
self.block_shape)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
def _moe_unpermute_and_reduce(
out: torch.Tensor,
curr_hidden: torch.Tensor,
inv_perm: Optional[torch.Tensor],
topk_weight: torch.Tensor,
) -> None:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M, topk = topk_weight.shape
K = curr_hidden.shape[1]
curr_hidden = curr_hidden[inv_perm, ...]
curr_hidden = curr_hidden.view(-1, topk, K)
curr_hidden.mul_(topk_weight.view(M, -1, 1))
ops.moe_sum(curr_hidden, out)
workspace3 = workspace3[inv_perm, ...]
return workspace3
def deep_gemm_moe_fp8(
......@@ -128,6 +165,7 @@ def deep_gemm_moe_fp8(
expert_map: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input=False,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
......@@ -166,129 +204,24 @@ def deep_gemm_moe_fp8(
Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
assert expert_map is None, "Expert maps not supported yet"
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
assert w1.dtype == torch.float8_e4m3fn
assert w2.dtype == torch.float8_e4m3fn
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
assert a1_scale is None or a1_scale.dim(
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
0] == hidden_states.shape[0], "Input scale shape mismatch"
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
num_tokens, _ = hidden_states.shape
E, N, _ = w1.shape
K = w2.shape[1]
if global_num_experts == -1:
global_num_experts = E
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)
if inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
block_m = dg.get_m_alignment_for_contiguous_layout()
block_shape = [block_m, block_m]
assert w1_scale is not None
assert w2_scale is not None
# We attempt to transpose and align offline in Fp8MoEMethod, in which
# case these calls will be nops. Otherwise, they'll be performed every
# time the layer is executed.
w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()
M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m)
num_chunks = (num_tokens // CHUNK_SIZE) + 1
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13 = torch.empty(M_sum * max(N, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
workspace1 = workspace13[:M_sum * N].view(M_sum, N)
workspace2 = torch.empty((M_sum, N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
workspace3 = workspace13[:M_sum * K].view(M_sum, K)
for chunk in range(num_chunks):
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE,
num_tokens))
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.shape
if tokens_in_chunk == 0:
break
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
a1q_scale: Optional[torch.Tensor] = None
qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
a1_scale, block_shape)
(qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
curr_topk_ids, global_num_experts,
expert_map, block_m)
# Adjust the intermediate cache size and config for the last chunk.
# Note that in most cases we only have one chunk so the cache size
# and config are already set correctly and do not need to be adjusted.
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
curr_M = sorted_token_ids.numel()
workspace1 = _resize_cache(workspace1, (curr_M, N))
workspace2 = _resize_cache(workspace2, (curr_M, N // 2))
workspace3 = _resize_cache(workspace3, (curr_M, K))
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1,
expert_ids)
if activation == "silu":
torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N))
elif activation == "gelu":
torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
a2q_scale: Optional[torch.Tensor] = None
qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale,
block_shape)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
_moe_unpermute_and_reduce(
out_hidden_states[begin_chunk_idx:end_chunk_idx],
workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights)
return out_hidden_states
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(quant_dtype=torch.float8_e4m3fn,
block_shape=deep_gemm_block_shape()),
DeepGemmExperts(),
)
return fn(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
inplace,
activation,
global_num_experts,
expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
This diff is collapsed.
......@@ -8,16 +8,17 @@ from typing import Any, Callable, Optional
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
......@@ -484,6 +485,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8 or use_int8_w8a8:
assert B_scale is not None
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
== B_scale.shape[-2])
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
== B_scale.shape[-1])
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
M = A.shape[0]
num_tokens = M * top_k
......@@ -855,6 +870,7 @@ def fused_topk(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
indices_type: Optional[torch.dtype] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
......@@ -865,9 +881,10 @@ def fused_topk(
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk_ids = torch.empty(
M,
topk,
dtype=torch.int32,
dtype=torch.int32 if indices_type is None else indices_type,
device=hidden_states.device)
token_expert_indices = torch.empty(M,
topk,
......@@ -962,6 +979,20 @@ def get_config_dtype_str(
return None
# TODO (bnell): use scalar_type instead of bools?
def get_config_qtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
) -> Optional[torch.dtype]:
if use_fp8_w8a8:
return torch.float8_e4m3fn
elif use_int8_w8a8:
return torch.int8
return None
def inplace_fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
......@@ -1128,7 +1159,10 @@ def fused_experts(hidden_states: torch.Tensor,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
if (allow_deep_gemm and use_fp8_w8a8
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N = w1.shape[1]
if (allow_deep_gemm and use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
assert apply_router_weight_on_input is False
return deep_gemm_moe_fp8(
......@@ -1145,6 +1179,7 @@ def fused_experts(hidden_states: torch.Tensor,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
return dispatch_fused_experts_func(inplace)(
......@@ -1171,60 +1206,8 @@ def fused_experts(hidden_states: torch.Tensor,
block_shape=block_shape)
def moe_kernel_prepare_input(
A: torch.Tensor,
B: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if use_fp8_w8a8:
assert B_scale is not None
if block_shape is None:
# If weights are per-channel (per_channel_quant=True), then
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8 quantization, dynamic or static
A, A_scale = ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_channel_quant)
else:
# activation block-wise fp8 quantization
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a8:
assert B_scale is not None
if block_shape is None:
# activation channel-wise int8 quantization
assert (per_channel_quant
), "int8 quantization only supports block or channel-wise"
A, A_scale = per_token_quant_int8(A)
else:
# activation block-wise int8 quantization
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
return A, A_scale
def fused_experts_impl(hidden_states: torch.Tensor,
def fused_experts_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
......@@ -1245,13 +1228,15 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None):
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
# Check constraints.
if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[
2], "Hidden size mismatch"
else:
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert hidden_states.shape[1] == w1.shape[2], (
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
......@@ -1261,7 +1246,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
torch.float32, torch.float16, torch.bfloat16
]
num_tokens, _ = hidden_states.shape
num_tokens = hidden_states.shape[0]
E, N, _ = w1.shape
K = w2.shape[1]
if global_num_experts == -1:
......@@ -1276,6 +1261,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16=use_int4_w4a16,
dtype=hidden_states.dtype)
qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16)
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
......@@ -1338,15 +1328,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input(
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
A=curr_hidden_states,
B=w1,
A_scale=a1_scale,
B_scale=w1_scale,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
qtype=qtype,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
......@@ -1357,7 +1342,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel(qcurr_hidden_states,
w1,
intermediate_cache1,
qa1_scale,
a1q_scale,
w1_scale,
w1_zp,
curr_topk_weights,
......@@ -1384,22 +1369,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
qintermediate_cache2, qa2_scale = moe_kernel_prepare_input(
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
A=intermediate_cache2,
B=w2,
A_scale=a2_scale,
B_scale=w2_scale,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
qtype=qtype,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
intermediate_cache3,
qa2_scale,
a2q_scale,
w2_scale,
w2_zp,
curr_topk_weights,
......@@ -1534,3 +1514,209 @@ def fused_moe(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[list[int]] = None,
block_m: Optional[int] = None,
):
super().__init__()
self.use_fp8_w8a8 = use_fp8_w8a8
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a8 = use_int8_w8a8
self.use_int8_w8a16 = use_int8_w8a16
self.block_shape = block_shape
self.block_m = block_m
self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16)
self.per_channel_quant = per_channel_quant
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
factor = num_experts if a.dim() == 3 else 1
workspace1 = M * topk * max(N * 2, K) * factor
workspace2 = M * topk * N * factor
return (workspace1, workspace2, a.dtype)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
# Check constraints.
if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
"Hidden size mismatch")
else:
assert hidden_states.size(-1) == w1.size(2), \
(f"Hidden size mismatch {hidden_states.size(-1)} "
f"!= {w1.size(2)}")
assert hidden_states.is_contiguous(
), "Hidden_states must be contiguous"
assert hidden_states.dim() == 2
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
]
E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
hidden_states, w1, w2, topk_ids)
if global_num_experts == -1:
global_num_experts = E
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
dtype=hidden_states.dtype)
config = try_get_optimal_moe_config(
w1.shape,
w2.shape,
top_k_num,
config_dtype,
num_tokens,
block_shape=self.block_shape,
)
if hidden_states.dtype == torch.bfloat16:
compute_type = tl.bfloat16
elif hidden_states.dtype == torch.float16:
compute_type = tl.float16
elif hidden_states.dtype == torch.float32:
compute_type = tl.float32
elif hidden_states.dtype == torch.float8_e4m3fn:
compute_type = tl.bfloat16
else:
raise ValueError(
f"Unsupported compute_type: {hidden_states.dtype}")
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13,
(num_tokens, top_k_num, N))
intermediate_cache2 = _resize_cache(workspace2,
(num_tokens * top_k_num, N // 2))
intermediate_cache3 = _resize_cache(workspace13,
(num_tokens, top_k_num, K))
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
global_num_experts, expert_map))
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
a1q_scale,
w1_scale,
w1_zp,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_channel_quant,
block_shape=self.block_shape)
self.activation(activation, intermediate_cache2,
intermediate_cache1.view(-1, N))
a2q_scale: Optional[torch.Tensor] = None
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant,
self.block_shape)
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
intermediate_cache3,
a2q_scale,
w2_scale,
w2_zp,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
1,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_channel_quant,
block_shape=self.block_shape)
return intermediate_cache3
def modular_triton_fused_moe(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[list[int]] = None,
) -> mk.FusedMoEModularKernel:
qtype = get_config_qtype(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
)
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
quant_dtype=qtype,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
),
TritonExperts(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment