Unverified commit 29fa5cac, authored by bnellnm and committed by GitHub
Browse files

[Kernels] Add activation chunking logic to FusedMoEModularKernel (#19168)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent b2d9be6f
...@@ -29,6 +29,7 @@ MNK_FACTORS = [ ...@@ -29,6 +29,7 @@ MNK_FACTORS = [
(224, 1024, 1536), (224, 1024, 1536),
(224, 3072, 1024), (224, 3072, 1024),
(224, 3072, 1536), (224, 3072, 1536),
(1024 * 128, 1024, 1024),
] ]
vllm_config = VllmConfig(parallel_config=ParallelConfig( vllm_config = VllmConfig(parallel_config=ParallelConfig(
......
...@@ -15,7 +15,8 @@ import vllm.model_executor.layers.fused_moe # noqa ...@@ -15,7 +15,8 @@ import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe) fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
...@@ -76,6 +77,13 @@ def test_fused_moe( ...@@ -76,6 +77,13 @@ def test_fused_moe(
else: else:
e_map = None e_map = None
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False,
block_shape=None)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w1, w2, score, topk, e_map) torch_output = torch_moe(a, w1, w2, score, topk, e_map)
iterative_output = iterative_moe(a, iterative_output = iterative_moe(a,
...@@ -103,7 +111,20 @@ def test_fused_moe( ...@@ -103,7 +111,20 @@ def test_fused_moe(
expert_map=e_map, expert_map=e_map,
renormalize=False) renormalize=False)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
m_triton_output = m_fused_moe(a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=e,
expert_map=e_map)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
torch.testing.assert_close(m_triton_output,
torch_output,
atol=2e-2,
rtol=0)
torch.testing.assert_close(iterative_output, torch.testing.assert_close(iterative_output,
torch_output, torch_output,
atol=2e-2, atol=2e-2,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest import pytest
import torch import torch
from tests.pplx_utils import ProcessGroupInfo, parallel_launch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -14,6 +15,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( ...@@ -14,6 +15,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .deepep_utils import ProcessGroupInfo, parallel_launch
try: try:
from pplx_kernels import AllToAll from pplx_kernels import AllToAll
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
...@@ -64,6 +67,7 @@ def pplx_cutlass_moe( ...@@ -64,6 +67,7 @@ def pplx_cutlass_moe(
out_dtype, out_dtype,
per_act_token: bool, per_act_token: bool,
per_out_ch: bool, per_out_ch: bool,
group_name: Optional[str],
): ):
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize) PplxPrepareAndFinalize)
...@@ -84,7 +88,7 @@ def pplx_cutlass_moe( ...@@ -84,7 +88,7 @@ def pplx_cutlass_moe(
else: else:
scale_elems = (hidden_dim + block_size - 1) // block_size scale_elems = (hidden_dim + block_size - 1) // block_size
ata = AllToAll.internode( args = dict(
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens,
num_experts=num_experts, num_experts=num_experts,
experts_per_token=topk, experts_per_token=topk,
...@@ -96,6 +100,12 @@ def pplx_cutlass_moe( ...@@ -96,6 +100,12 @@ def pplx_cutlass_moe(
hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize, hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
) )
if group_name is None:
ata = AllToAll.internode(**args)
else:
args["group_name"] = group_name
ata = AllToAll.intranode(**args)
w1 = w1.to(device) w1 = w1.to(device)
w2 = w2.to(device) w2 = w2.to(device)
w1_scale = w1_scale.to(device) w1_scale = w1_scale.to(device)
...@@ -113,7 +123,10 @@ def pplx_cutlass_moe( ...@@ -113,7 +123,10 @@ def pplx_cutlass_moe(
) )
experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size, experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
out_dtype, per_act_token, per_out_ch) out_dtype,
per_act_token,
per_out_ch,
use_batched_format=True)
fused_cutlass_experts = FusedMoEModularKernel( fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
...@@ -184,11 +197,17 @@ def _pplx_moe( ...@@ -184,11 +197,17 @@ def _pplx_moe(
w2_full: torch.Tensor, w2_full: torch.Tensor,
per_act_token: bool, per_act_token: bool,
per_out_ch: bool, per_out_ch: bool,
use_internode: bool,
): ):
if use_internode:
uid = nvshmem_get_unique_id( uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0) torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size) nvshmem_init(uid, pgi.rank, pgi.world_size)
else:
group_ranks = list(range(pgi.world_size))
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
group_name = cpu_group.group_name
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights, torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights,
...@@ -196,7 +215,7 @@ def _pplx_moe( ...@@ -196,7 +215,7 @@ def _pplx_moe(
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale, pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
w2_scale, topk_weights, topk_ids, w2_scale, topk_weights, topk_ids,
a1_scale, out_dtype, per_act_token, a1_scale, out_dtype, per_act_token,
per_out_ch) per_out_ch, group_name)
torch_output = chunk_by_rank(torch_output, pgi.rank, torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device) pgi.world_size).to(pplx_output.device)
...@@ -207,6 +226,7 @@ def _pplx_moe( ...@@ -207,6 +226,7 @@ def _pplx_moe(
torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0) torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
if use_internode:
nvshmem_finalize() nvshmem_finalize()
...@@ -218,6 +238,7 @@ def _pplx_moe( ...@@ -218,6 +238,7 @@ def _pplx_moe(
@pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.skipif( @pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()), current_platform.get_device_capability()),
...@@ -232,6 +253,7 @@ def test_cutlass_moe_pplx( ...@@ -232,6 +253,7 @@ def test_cutlass_moe_pplx(
per_act_token: bool, per_act_token: bool,
per_out_ch: bool, per_out_ch: bool,
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
use_internode: bool,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
...@@ -284,4 +306,5 @@ def test_cutlass_moe_pplx( ...@@ -284,4 +306,5 @@ def test_cutlass_moe_pplx(
parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q, parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q,
w1_scale, w2_scale, topk_weights, topk_ids, a_scale1, w1_scale, w2_scale, topk_weights, topk_ids, a_scale1,
dtype, a, w1_d, w2_d, per_act_token, per_out_ch) dtype, a, w1_d, w2_d, per_act_token, per_out_ch,
use_internode)
...@@ -18,7 +18,6 @@ try: ...@@ -18,7 +18,6 @@ try:
except ImportError: except ImportError:
has_pplx = False has_pplx = False
from tests.pplx_utils import ProcessGroupInfo, parallel_launch
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import override_config from vllm.model_executor.layers.fused_moe import override_config
...@@ -30,6 +29,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( ...@@ -30,6 +29,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .deepep_utils import ProcessGroupInfo, parallel_launch
requires_pplx = pytest.mark.skipif( requires_pplx = pytest.mark.skipif(
not has_pplx, not has_pplx,
reason="Requires PPLX kernels", reason="Requires PPLX kernels",
...@@ -153,7 +154,10 @@ def batched_moe( ...@@ -153,7 +154,10 @@ def batched_moe(
num_experts = w1.shape[0] num_experts = w1.shape[0]
fused_experts = FusedMoEModularKernel( fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0), BatchedPrepareAndFinalize(max_num_tokens=a.shape[0],
world_size=1,
dp_size=1,
rank=0),
BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1)) BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts) return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
...@@ -229,9 +233,15 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor: ...@@ -229,9 +233,15 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
return t[(r * chunk):(r + 1) * chunk] return t[(r * chunk):(r + 1) * chunk]
def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, def pplx_prepare_finalize(
topk_weight: torch.Tensor, topk_ids: torch.Tensor, pgi: ProcessGroupInfo,
num_experts: int) -> torch.Tensor: dp_size: int,
a: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
group_name: Optional[str],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize) PplxPrepareAndFinalize)
...@@ -245,7 +255,7 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, ...@@ -245,7 +255,7 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
world_size = pgi.world_size world_size = pgi.world_size
max_num_tokens = rank_chunk(num_tokens, 0, world_size) max_num_tokens = rank_chunk(num_tokens, 0, world_size)
ata = AllToAll.internode( args = dict(
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens,
num_experts=num_experts, num_experts=num_experts,
experts_per_token=topk, experts_per_token=topk,
...@@ -259,6 +269,12 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, ...@@ -259,6 +269,12 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
torch.float32.itemsize)), torch.float32.itemsize)),
) )
if group_name is None:
ata = AllToAll.internode(**args)
else:
args["group_name"] = group_name
ata = AllToAll.intranode(**args)
topk_ids = topk_ids.to(dtype=torch.uint32) topk_ids = topk_ids.to(dtype=torch.uint32)
prepare_finalize = PplxPrepareAndFinalize( prepare_finalize = PplxPrepareAndFinalize(
...@@ -318,11 +334,19 @@ def _pplx_prepare_finalize( ...@@ -318,11 +334,19 @@ def _pplx_prepare_finalize(
score: torch.Tensor, score: torch.Tensor,
topk: torch.Tensor, topk: torch.Tensor,
num_experts: int, num_experts: int,
use_internode: bool,
): ):
if use_internode:
uid = nvshmem_get_unique_id( uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0) torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size) nvshmem_init(uid, pgi.rank, pgi.world_size)
group_name = None
else:
group_ranks = list(range(pgi.world_size))
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
group_name = cpu_group.group_name
device = pgi.device device = pgi.device
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
...@@ -335,13 +359,14 @@ def _pplx_prepare_finalize( ...@@ -335,13 +359,14 @@ def _pplx_prepare_finalize(
a.dtype) a.dtype)
pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids, pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids,
num_experts) num_experts, group_name)
torch_output = chunk_by_rank(torch_output, pgi.rank, torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device) pgi.world_size).to(pplx_output.device)
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
if use_internode:
nvshmem_finalize() nvshmem_finalize()
...@@ -353,6 +378,7 @@ def _pplx_prepare_finalize( ...@@ -353,6 +378,7 @@ def _pplx_prepare_finalize(
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx @requires_pplx
def test_pplx_prepare_finalize( def test_pplx_prepare_finalize(
mnk: tuple[int, int, int], mnk: tuple[int, int, int],
...@@ -360,6 +386,7 @@ def test_pplx_prepare_finalize( ...@@ -360,6 +386,7 @@ def test_pplx_prepare_finalize(
topk: int, topk: int,
dtype: torch.dtype, dtype: torch.dtype,
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
use_internode: bool,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
m, n, k = mnk m, n, k = mnk
...@@ -369,10 +396,11 @@ def test_pplx_prepare_finalize( ...@@ -369,10 +396,11 @@ def test_pplx_prepare_finalize(
score = torch.randn((m, e), device=device, dtype=dtype) score = torch.randn((m, e), device=device, dtype=dtype)
parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score, parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
topk, e) topk, e, use_internode)
def pplx_moe( def pplx_moe(
group_name: Optional[str],
rank: int, rank: int,
world_size: int, world_size: int,
dp_size: int, dp_size: int,
...@@ -394,7 +422,7 @@ def pplx_moe( ...@@ -394,7 +422,7 @@ def pplx_moe(
topk = topk_ids.shape[1] topk = topk_ids.shape[1]
max_num_tokens = rank_chunk(a.shape[0], 0, world_size) max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
ata = AllToAll.internode( args = dict(
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens,
num_experts=num_experts, num_experts=num_experts,
experts_per_token=topk, experts_per_token=topk,
...@@ -408,6 +436,12 @@ def pplx_moe( ...@@ -408,6 +436,12 @@ def pplx_moe(
torch.float32.itemsize)), torch.float32.itemsize)),
) )
if group_name is None:
ata = AllToAll.internode(**args)
else:
args["group_name"] = group_name
ata = AllToAll.intranode(**args)
topk_ids = topk_ids.to(dtype=torch.uint32) topk_ids = topk_ids.to(dtype=torch.uint32)
prepare_finalize = PplxPrepareAndFinalize( prepare_finalize = PplxPrepareAndFinalize(
...@@ -522,11 +556,18 @@ def _pplx_moe( ...@@ -522,11 +556,18 @@ def _pplx_moe(
w2: torch.Tensor, w2: torch.Tensor,
score: torch.Tensor, score: torch.Tensor,
topk: int, topk: int,
use_internode: bool,
): ):
if use_internode:
uid = nvshmem_get_unique_id( uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0) torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size) nvshmem_init(uid, pgi.rank, pgi.world_size)
group_name = None
else:
group_ranks = list(range(pgi.world_size))
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
group_name = cpu_group.group_name
m, k = a.shape m, k = a.shape
e, _, n = w2.shape e, _, n = w2.shape
...@@ -536,8 +577,8 @@ def _pplx_moe( ...@@ -536,8 +577,8 @@ def _pplx_moe(
with set_current_vllm_config(vllm_config), override_config(moe_config): with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2, pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
topk_weight, topk_ids) a, w1, w2, topk_weight, topk_ids)
# TODO (bnell): fix + re-enable # TODO (bnell): fix + re-enable
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
# topk_ids) # topk_ids)
...@@ -548,6 +589,7 @@ def _pplx_moe( ...@@ -548,6 +589,7 @@ def _pplx_moe(
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
#torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0) #torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)
if use_internode:
nvshmem_finalize() nvshmem_finalize()
...@@ -556,6 +598,7 @@ def _pplx_moe( ...@@ -556,6 +598,7 @@ def _pplx_moe(
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx @requires_pplx
def test_pplx_moe( def test_pplx_moe(
mnk: tuple[int, int, int], mnk: tuple[int, int, int],
...@@ -563,6 +606,7 @@ def test_pplx_moe( ...@@ -563,6 +606,7 @@ def test_pplx_moe(
topk: int, topk: int,
dtype: torch.dtype, dtype: torch.dtype,
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
use_internode: bool,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
m, n, k = mnk m, n, k = mnk
...@@ -572,4 +616,5 @@ def test_pplx_moe( ...@@ -572,4 +616,5 @@ def test_pplx_moe(
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk) parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
use_internode)
...@@ -13,7 +13,8 @@ from vllm.model_executor.layers.activation import SiluAndMul ...@@ -13,7 +13,8 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape, deep_gemm_moe_fp8) _valid_deep_gemm_shape, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size) moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...@@ -45,7 +46,7 @@ N = [128, 512, 7168, 7748, 13824] ...@@ -45,7 +46,7 @@ N = [128, 512, 7168, 7748, 13824]
K = [256, 3884, 4096, 13824, 16384] K = [256, 3884, 4096, 13824, 16384]
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
# and its hidden size is 7168. # and its hidden size is 7168.
M_moe = [1, 2, 7, 83, 128, 2048] M_moe = [1, 2, 7, 83, 128, 2048, 1024 * 128]
M_moe_dg = [128, 192, 1335, 2048] M_moe_dg = [128, 192, 1335, 2048]
N_moe = [128, 256, 1024, 4608] # [13824] N_moe = [128, 256, 1024, 4608] # [13824]
K_moe = [256, 512, 7168] # [13824] K_moe = [256, 512, 7168] # [13824]
...@@ -214,6 +215,13 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ...@@ -214,6 +215,13 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False,
block_shape=block_size)
# Set the context to avoid lots of warning spam. # Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
out = fused_moe( out = fused_moe(
...@@ -231,6 +239,16 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ...@@ -231,6 +239,16 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
block_size) block_size)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
m_out = m_fused_moe(a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=E,
w1_scale=w1_s,
w2_scale=w2_s)
#print(f"{out.sum()=}") #print(f"{out.sum()=}")
#print(f"{ref_out.sum()=}") #print(f"{ref_out.sum()=}")
...@@ -239,6 +257,11 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ...@@ -239,6 +257,11 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
torch.mean(torch.abs(ref_out.to(torch.float32)))) torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.03 assert rel_diff < 0.03
rel_diff = (torch.mean(
torch.abs(m_out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.03
def per_block_cast_to_fp8( def per_block_cast_to_fp8(
x: torch.Tensor, x: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import os
import traceback
from typing import Callable
import torch
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
P = ParamSpec("P")
@dataclasses.dataclass
class ProcessGroupInfo:
    """Rank and device coordinates passed to a worker as its first argument."""

    # Total number of processes in the job (the world size given to
    # torch.distributed.init_process_group).
    world_size: int
    # Number of processes launched on this node.
    world_local_size: int
    # Global rank, computed by the launcher as
    # node_rank * world_local_size + local_rank.
    rank: int
    # Index of the node this process runs on.
    node_rank: int
    # Index of this process within its node; also used as the CUDA device
    # index by the launcher.
    local_rank: int
    # Device the launcher selected for this process.
    device: torch.device
def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Per-process entry point: set up torch.distributed, then run ``worker``.

    Args:
        local_rank: Index of this process on its node. It comes first because
            the launch helpers hand this function to ``spawn``, which
            supplies the process index as the leading argument.
        world_size: Total number of processes in the job.
        world_local_size: Number of processes on this node.
        node_rank: Index of this node.
        init_method: Rendezvous URL passed to ``init_process_group``.
        worker: Callable invoked as ``worker(pgi, *args, **kwargs)`` where
            ``pgi`` is the ``ProcessGroupInfo`` for this process.
        *args: Extra positional arguments forwarded to ``worker``.
        **kwargs: Extra keyword arguments forwarded to ``worker``.

    Raises:
        Exception: Anything raised by the start-up all-reduce or by
            ``worker`` is printed and then re-raised.
    """
    rank = node_rank * world_local_size + local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=rank,
        world_size=world_size,
        device_id=device,
    )

    try:
        # A throwaway all-reduce used as a start-up barrier. It lives inside
        # the try block so the process group is still destroyed if this
        # first collective fails.
        barrier = torch.tensor([rank], device=device)
        torch.distributed.all_reduce(barrier)

        worker(
            ProcessGroupInfo(
                world_size=world_size,
                world_local_size=world_local_size,
                rank=rank,
                node_rank=node_rank,
                local_rank=local_rank,
                device=device,
            ),
            *args,
            **kwargs,
        )
    except Exception as ex:
        # Print from inside the child process before re-raising.
        print(ex)
        traceback.print_exc()
        raise
    finally:
        torch.distributed.destroy_process_group()
def parallel_launch(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Run ``worker`` in ``world_size`` spawned processes on this machine.

    Every process is treated as belonging to a single node: the local world
    size equals ``world_size`` and the node rank is 0.

    Args:
        world_size: Number of processes to spawn.
        worker: Callable invoked in each process with a ``ProcessGroupInfo``
            followed by ``*args``.
        *args: Positional arguments forwarded to ``worker``.
        **kwargs: Must be empty; keyword arguments are not forwarded.
    """
    # Only positional arguments can be appended to the spawn args tuple.
    assert not kwargs
    # NOTE(review): the rendezvous address is fixed, so two launches running
    # on the same host at the same time would presumably contend for this
    # port -- confirm whether that matters for the test setup.
    spawn(
        _worker_parallel_launch,
        args=(
            world_size,  # world_size
            world_size,  # world_local_size: all processes are local
            0,  # node_rank
            "tcp://localhost:29500",
            worker,
        ) + args,
        nprocs=world_size,
        join=True,
    )
def parallel_launch_from_env(
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """
    Launches a worker function in parallel across the processes of the
    current node, using settings from the environment. The environment must
    have the following variables set:
    - WORLD_SIZE: The total number of processes.
    - WORLD_LOCAL_SIZE: The number of processes on the current node.
    - NODE_RANK: The rank of the current node.
    - MASTER_ADDR: The address of the master process.
    - MASTER_PORT: The port of the master process.

    Only positional ``args`` are forwarded to ``worker``; ``kwargs`` must be
    empty.
    """
    assert not kwargs
    world_size = int(os.environ["WORLD_SIZE"])
    world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
    node_rank = int(os.environ["NODE_RANK"])
    # Not read here, only checked for presence before spawning; the "env://"
    # init method below is expected to pick them up.
    assert "MASTER_ADDR" in os.environ
    assert "MASTER_PORT" in os.environ
    spawn(
        _worker_parallel_launch,
        args=(
            world_size,
            world_local_size,
            node_rank,
            "env://",
            worker,
        ) + args,
        # Only this node's share of the processes is started here.
        nprocs=world_local_size,
        join=True,
    )
...@@ -36,6 +36,9 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -36,6 +36,9 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert (len(self.block_shape) == 2 and all( assert (len(self.block_shape) == 2 and all(
[v == self.DEEPGEMM_BLOCK_SHAPE for v in self.block_shape])) [v == self.DEEPGEMM_BLOCK_SHAPE for v in self.block_shape]))
def supports_chunking(self) -> bool:
return False
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
...@@ -45,17 +48,19 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -45,17 +48,19 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
K: int, K: int,
topk: int, topk: int,
num_experts: int, num_experts: int,
) -> tuple[int, int, torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2 assert a.dim() == 2
num_dp = self.world_size // self.dp_size num_dp = self.world_size // self.dp_size
max_num_tokens = a.size( max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens 0) if self.max_num_tokens is None else self.max_num_tokens
workspace13 = num_experts * max_num_tokens * num_dp * max(K, N) workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
workspace2 = num_experts * max_num_tokens * num_dp * (N // 2) workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
return (workspace13, workspace2, a.dtype) output = (num_experts, max_num_tokens * num_dp, K)
return (workspace13, workspace2, output, a.dtype)
def apply( def apply(
self, self,
output: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -72,7 +77,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -72,7 +77,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor: ):
import deep_gemm as dg import deep_gemm as dg
assert hidden_states.ndim == 3 assert hidden_states.ndim == 3
...@@ -89,7 +94,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -89,7 +94,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N)) workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2)) workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))
workspace3 = _resize_cache(workspace13, (E, max_num_tokens, K))
# (from deepgemm docs) : A value hint (which is a value on CPU) # (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value # for the M expectation of each batch, correctly setting this value
...@@ -118,8 +122,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -118,8 +122,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale), dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
(w2, w2_scale), (w2, w2_scale),
out=workspace3, out=output,
masked_m=expert_num_tokens, masked_m=expert_num_tokens,
expected_m=expected_m) expected_m=expected_m)
return workspace3
...@@ -64,6 +64,15 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -64,6 +64,15 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
block_shape=self.block_shape, # type: ignore[arg-type] block_shape=self.block_shape, # type: ignore[arg-type]
) if (self.allow_deep_gemm and is_fp8_128_block_quantized) else None ) if (self.allow_deep_gemm and is_fp8_128_block_quantized) else None
assert (self.batched_deep_gemm_experts is not None
or self.batched_triton_experts is not None)
def supports_chunking(self) -> bool:
bdge = self.batched_deep_gemm_experts
bte = self.batched_triton_experts
return ((bdge is None or bdge.supports_chunking())
and (bte is None or bte.supports_chunking()))
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
...@@ -73,7 +82,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -73,7 +82,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
K: int, K: int,
topk: int, topk: int,
num_experts: int, num_experts: int,
) -> tuple[int, int, torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# Note: the deep gemm workspaces are strictly larger than the triton # Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm # workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set. # even if we fall back to triton later, e.g. if expert maps are set.
...@@ -87,6 +96,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -87,6 +96,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def apply( def apply(
self, self,
output: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -103,7 +113,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -103,7 +113,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor: ):
use_batched_deep_gemm_experts = (self.allow_deep_gemm use_batched_deep_gemm_experts = (self.allow_deep_gemm
and self.batched_deep_gemm_experts and self.batched_deep_gemm_experts
is not None) is not None)
...@@ -111,7 +121,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -111,7 +121,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
if use_batched_deep_gemm_experts else if use_batched_deep_gemm_experts else
self.batched_triton_experts) self.batched_triton_experts)
assert experts is not None assert experts is not None
return experts.apply(hidden_states, w1, w2, topk_ids, activation, experts.apply(output, hidden_states, w1, w2, topk_ids, activation,
global_num_experts, expert_map, w1_scale, global_num_experts, expert_map, w1_scale, w2_scale,
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace13, workspace2, expert_num_tokens) workspace2, expert_num_tokens)
...@@ -14,6 +14,7 @@ from vllm.scalar_type import scalar_types ...@@ -14,6 +14,7 @@ from vllm.scalar_type import scalar_types
def run_cutlass_moe_fp8( def run_cutlass_moe_fp8(
output: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -31,7 +32,8 @@ def run_cutlass_moe_fp8( ...@@ -31,7 +32,8 @@ def run_cutlass_moe_fp8(
out_dtype: torch.dtype, out_dtype: torch.dtype,
per_act_token: bool, per_act_token: bool,
per_out_ch: bool, per_out_ch: bool,
) -> torch.Tensor: use_batched_format: bool,
):
a1q = hidden_states a1q = hidden_states
assert w1_scale is not None assert w1_scale is not None
...@@ -61,23 +63,20 @@ def run_cutlass_moe_fp8( ...@@ -61,23 +63,20 @@ def run_cutlass_moe_fp8(
if expert_map is not None: if expert_map is not None:
assert expert_num_tokens is None assert expert_num_tokens is None
# We have two modes: PPLX and non-PPLX. We differentiate them by checking # We have two modes: batched experts and non-batched experts.
# if expert_num_tokens is None (expert_num_tokens is a tensor which PPLX # In the non-batched mode, the input tokens are not padded: thus, the shape
# uses to track the number of tokens per expert).
# In the non-PPLX mode, the input tokens are not padded: thus, the shape
# of the input is [total_num_tokens, hidden_size]. The input and output # of the input is [total_num_tokens, hidden_size]. The input and output
# require shuffling by a_map and c_map such that the tokens assigned to # require shuffling by a_map and c_map such that the tokens assigned to
# each expert are contiguous. # each expert are contiguous.
# In the PPLX mode, the input tokens are padded per expert to ensure that # In the batched mode, the input tokens are padded per expert to ensure that
# the PPLX dispatch and combine functions work correctly: thus, the shape # the batched dispatch and combine functions work correctly: thus, the shape
# of the input is [num_experts, max_num_tokens_per_expert, hidden_size]. # of the input is [num_experts, max_num_tokens_per_expert, hidden_size].
# The PPLX input and output require no shuffling by a_map and c_map since # The batched input and output require no shuffling by a_map and c_map since
# their tokens are already contiguous for each expert as a result of # their tokens are already contiguous for each expert as a result of
# the dispatch function. # the dispatch function.
is_pplx = expert_num_tokens is not None
M = a1q.shape[0] # no pplx M = a1q.shape[0] # non batched expert M
padded_M = a1q.shape[1] # pplx padded_M = a1q.shape[1] # batched expert M
_, K, N = w2.shape _, K, N = w2.shape
device = a1q.device device = a1q.device
...@@ -95,7 +94,9 @@ def run_cutlass_moe_fp8( ...@@ -95,7 +94,9 @@ def run_cutlass_moe_fp8(
topk = local_topk_ids.shape[1] topk = local_topk_ids.shape[1]
local_E = w1.shape[0] local_E = w1.shape[0]
if is_pplx: if use_batched_format:
assert expert_num_tokens is not None
expert_offsets = torch.empty((local_E), expert_offsets = torch.empty((local_E),
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
...@@ -167,7 +168,7 @@ def run_cutlass_moe_fp8( ...@@ -167,7 +168,7 @@ def run_cutlass_moe_fp8(
device=device, device=device,
dtype=torch.int64) dtype=torch.int64)
if is_pplx: if use_batched_format:
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2)) c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
c2 = _resize_cache(workspace2, (local_E * padded_M, N)) c2 = _resize_cache(workspace2, (local_E * padded_M, N))
c3 = _resize_cache(workspace13, (local_E * padded_M, K)) c3 = _resize_cache(workspace13, (local_E * padded_M, K))
...@@ -192,12 +193,15 @@ def run_cutlass_moe_fp8( ...@@ -192,12 +193,15 @@ def run_cutlass_moe_fp8(
problem_sizes2, ab_strides2, ab_strides2, c_strides2, problem_sizes2, ab_strides2, ab_strides2, c_strides2,
per_act_token, per_out_ch) per_act_token, per_out_ch)
if is_pplx: if use_batched_format:
return c3.reshape(local_E, padded_M, K) output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True)
else: else:
return c3[c_map].view(M, topk, K) # We can't do this inplace because output may point to the same tensor
# as c3.
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
# TODO (bnell): split class batched vs. non-batched?
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
...@@ -206,12 +210,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -206,12 +210,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
out_dtype: torch.dtype, out_dtype: torch.dtype,
per_act_token: bool, per_act_token: bool,
per_out_ch: bool, per_out_ch: bool,
use_batched_format: bool = False,
): ):
super().__init__() super().__init__()
self.max_experts_per_worker = max_experts_per_worker self.max_experts_per_worker = max_experts_per_worker
self.out_dtype = out_dtype self.out_dtype = out_dtype
self.per_act_token = per_act_token self.per_act_token = per_act_token
self.per_out_ch = per_out_ch self.per_out_ch = per_out_ch
self.use_batched_format = use_batched_format
def supports_chunking(self) -> bool:
    """
    Report whether activation chunking can be used with this instance.

    Chunking is available only when `use_batched_format` is falsy, i.e.
    when the experts are not run in the batched format.
    """
    if self.use_batched_format:
        return False
    return True
def workspace_shapes( def workspace_shapes(
self, self,
...@@ -222,14 +231,24 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -222,14 +231,24 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
K: int, K: int,
topk: int, topk: int,
num_experts: int, num_experts: int,
) -> tuple[int, int, torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = ()
output: tuple[int, ...] = ()
if self.use_batched_format:
padded_M = aq.shape[1] padded_M = aq.shape[1]
workspace1 = self.max_experts_per_worker * padded_M * max(N, K) workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
workspace2 = self.max_experts_per_worker * padded_M * (N // 2) workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
return (workspace1, workspace2, self.out_dtype) output = (self.max_experts_per_worker, padded_M, K)
else:
workspace1 = (M * topk, max(2 * N, K))
workspace2 = (M * topk, N)
output = (M * topk, K)
return (workspace1, workspace2, output, self.out_dtype)
def apply( def apply(
self, self,
output: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -246,16 +265,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -246,16 +265,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor: ):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
activation_callable = lambda i, o: self.activation(activation, i, o) activation_callable = lambda i, o: self.activation(activation, i, o)
return run_cutlass_moe_fp8(hidden_states, w1, w2, topk_ids, run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids,
activation_callable, global_num_experts, activation_callable, global_num_experts,
expert_map, w1_scale, w2_scale, a1q_scale, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2, a2_scale, workspace13, workspace2,
expert_num_tokens, self.out_dtype, expert_num_tokens, self.out_dtype,
self.per_act_token, self.per_out_ch) self.per_act_token, self.per_out_ch,
self.use_batched_format)
def cutlass_moe_fp8( def cutlass_moe_fp8(
...@@ -325,6 +345,7 @@ def cutlass_moe_fp8( ...@@ -325,6 +345,7 @@ def cutlass_moe_fp8(
out_dtype=out_dtype, out_dtype=out_dtype,
per_act_token=per_act_token, per_act_token=per_act_token,
per_out_ch=per_out_ch, per_out_ch=per_out_ch,
use_batched_format=False,
), ),
) )
......
...@@ -70,6 +70,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -70,6 +70,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
super().__init__() super().__init__()
self.block_shape = deep_gemm_block_shape() self.block_shape = deep_gemm_block_shape()
def supports_chunking(self) -> bool:
    """Activation chunking is supported by this implementation."""
    return True
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
...@@ -79,18 +82,18 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -79,18 +82,18 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
K: int, K: int,
topk: int, topk: int,
num_experts: int, num_experts: int,
) -> tuple[int, int, torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
block_m = self.block_shape[0] block_m = self.block_shape[0]
M_sum = (M * topk) + num_experts * (block_m - 1) M_sum = (M * topk) + num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m) M_sum = round_up(M_sum, block_m)
workspace1 = M_sum * max(N * 2, K) workspace1 = (M_sum, max(N * 2, K))
workspace2 = M_sum * max(N, K) workspace2 = (M_sum, max(N, K))
output = (M * topk, K)
return (workspace1, workspace2, a.dtype) return (workspace1, workspace2, output, a.dtype)
def apply( def apply(
self, self,
output: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -107,7 +110,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -107,7 +110,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor: ):
import deep_gemm as dg import deep_gemm as dg
a1q = hidden_states a1q = hidden_states
...@@ -143,7 +146,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -143,7 +146,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
(M_sum, N // 2)) (M_sum, N // 2))
mm2_out = _resize_cache(workspace2, (M_sum, K)) mm2_out = _resize_cache(workspace2, (M_sum, K))
out = _resize_cache(workspace13, (inv_perm.size(0), K))
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids) (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)
...@@ -159,9 +161,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -159,9 +161,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids) (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
torch.index_select(mm2_out, 0, inv_perm, out=out) torch.index_select(mm2_out, 0, inv_perm, out=output)
return out
def deep_gemm_moe_fp8( def deep_gemm_moe_fp8(
......
...@@ -335,9 +335,6 @@ def invoke_moe_batched_triton_kernel( ...@@ -335,9 +335,6 @@ def invoke_moe_batched_triton_kernel(
BLOCK_M = config['BLOCK_SIZE_M'] BLOCK_M = config['BLOCK_SIZE_M']
BLOCK_N = config['BLOCK_SIZE_N'] BLOCK_N = config['BLOCK_SIZE_N']
BLOCK_K = config['BLOCK_SIZE_K'] BLOCK_K = config['BLOCK_SIZE_K']
assert (torch.compiler.is_compiling()
or torch.cuda.is_current_stream_capturing()
or max_num_tokens % BLOCK_M == 0)
grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
triton.cdiv(B.size(1), BLOCK_N)) triton.cdiv(B.size(1), BLOCK_N))
...@@ -390,8 +387,8 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -390,8 +387,8 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
that the PPLX dispatch/combine kernels use. that the PPLX dispatch/combine kernels use.
""" """
def __init__(self, max_num_tokens: Optional[int], world_size: int, def __init__(self, max_num_tokens: int, world_size: int, dp_size: int,
dp_size: int, rank: int): rank: int):
super().__init__() super().__init__()
self.world_size = world_size self.world_size = world_size
self.dp_size = dp_size self.dp_size = dp_size
...@@ -430,11 +427,6 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -430,11 +427,6 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_tokens, hidden_dim = a1.size() num_tokens, hidden_dim = a1.size()
topk = topk_ids.size(1) topk = topk_ids.size(1)
if self.max_num_tokens is None:
tokens_per_expert = torch.bincount(topk_ids.view(-1),
minlength=num_experts)
self.max_num_tokens = int(tokens_per_expert.max().item())
else:
tokens_per_expert = torch.zeros(num_experts, tokens_per_expert = torch.zeros(num_experts,
dtype=torch.int, dtype=torch.int,
device=a1.device) device=a1.device)
...@@ -497,9 +489,9 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -497,9 +489,9 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
max_num_tokens: int,
world_size: int, world_size: int,
dp_size: int, dp_size: int,
max_num_tokens: Optional[int] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
...@@ -518,6 +510,9 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -518,6 +510,9 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.world_size = world_size self.world_size = world_size
self.dp_size = dp_size self.dp_size = dp_size
def supports_chunking(self) -> bool:
    """Activation chunking is not supported by this implementation."""
    return False
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
...@@ -527,18 +522,16 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -527,18 +522,16 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
K: int, K: int,
topk: int, topk: int,
num_experts: int, num_experts: int,
) -> tuple[int, int, torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2 assert a.dim() == 2
num_dp = self.world_size // self.dp_size num_dp = self.world_size // self.dp_size
max_num_tokens = a.size( workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
0) if self.max_num_tokens is None else self.max_num_tokens workspace2 = (self.max_num_tokens * num_dp, N)
#print(f"WORKSPACE {max_num_tokens} {num_dp}") return (workspace13, workspace2, workspace13, a.dtype)
workspace13 = num_experts * max_num_tokens * num_dp * K
workspace2 = max_num_tokens * num_dp * N
return (workspace13, workspace2, a.dtype)
def apply( def apply(
self, self,
output: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -555,20 +548,12 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -555,20 +548,12 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor: ):
assert hidden_states.dim() == 3 assert hidden_states.dim() == 3
assert expert_num_tokens is not None assert expert_num_tokens is not None
hidden_dim = hidden_states.size(-1)
if self.max_num_tokens is None:
max_num_tokens = hidden_states.size(1)
else:
max_num_tokens = self.max_num_tokens max_num_tokens = self.max_num_tokens
num_dp = self.world_size // self.dp_size num_dp = self.world_size // self.dp_size
num_experts = global_num_experts
out = _resize_cache(workspace13,
(num_experts, max_num_tokens * num_dp, hidden_dim))
num_local_experts = w1.size(0) num_local_experts = w1.size(0)
assert num_local_experts == w1.size(0), ( assert num_local_experts == w1.size(0), (
f"{num_local_experts} == {w1.size(0)}") f"{num_local_experts} == {w1.size(0)}")
...@@ -585,15 +570,13 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -585,15 +570,13 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Indexing expert_num_tokens doesn't work w/cudagraphs or inductor # Indexing expert_num_tokens doesn't work w/cudagraphs or inductor
if (torch.compiler.is_compiling() if (torch.compiler.is_compiling()
or torch.cuda.is_current_stream_capturing()): or torch.cuda.is_current_stream_capturing()):
num = max_num_tokens * num_dp num = hidden_states.shape[1]
else: else:
num = int(expert_num_tokens[expert].item()) num = int(expert_num_tokens[expert].item())
tmp = _resize_cache(workspace2, (num, N)) tmp = _resize_cache(workspace2, (num, N))
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
self.activation(activation, tmp, input) self.activation(activation, tmp, input)
out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) output[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
return out
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...@@ -630,6 +613,9 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -630,6 +613,9 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert not use_int4_w4a16, "NYI" assert not use_int4_w4a16, "NYI"
assert self.block_shape is None, "NYI" assert self.block_shape is None, "NYI"
def supports_chunking(self) -> bool:
    """Activation chunking is not supported by this implementation."""
    return False
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
...@@ -639,17 +625,19 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -639,17 +625,19 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
K: int, K: int,
topk: int, topk: int,
num_experts: int, num_experts: int,
) -> tuple[int, int, torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2 assert a.dim() == 2
num_dp = self.world_size // self.dp_size num_dp = self.world_size // self.dp_size
max_num_tokens = a.size( max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens 0) if self.max_num_tokens is None else self.max_num_tokens
workspace13 = num_experts * max_num_tokens * num_dp * max(K, N) workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
workspace2 = num_experts * max_num_tokens * num_dp * (N // 2) workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
return (workspace13, workspace2, a.dtype) output = (num_experts, max_num_tokens * num_dp, K)
return (workspace13, workspace2, output, a.dtype)
def apply( def apply(
self, self,
output: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -666,7 +654,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -666,7 +654,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor: ):
# Check constraints. # Check constraints.
if self.use_int4_w4a16: if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), ( assert hidden_states.size(-1) // 2 == w1.size(2), (
...@@ -723,8 +711,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -723,8 +711,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
(E, max_num_tokens, N)) (E, max_num_tokens, N))
intermediate_cache2 = _resize_cache(workspace2, intermediate_cache2 = _resize_cache(workspace2,
(E, max_num_tokens, N // 2)) (E, max_num_tokens, N // 2))
intermediate_cache3 = _resize_cache(workspace13,
(E, max_num_tokens, K))
# MM1 # MM1
invoke_moe_batched_triton_kernel(A=hidden_states, invoke_moe_batched_triton_kernel(A=hidden_states,
...@@ -761,7 +747,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -761,7 +747,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
invoke_moe_batched_triton_kernel(A=qintermediate_cache2, invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
B=w2, B=w2,
C=intermediate_cache3, C=output,
expert_num_tokens=expert_num_tokens, expert_num_tokens=expert_num_tokens,
compute_type=compute_type, compute_type=compute_type,
A_scale=a2q_scale, A_scale=a2q_scale,
...@@ -772,4 +758,3 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -772,4 +758,3 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int4_w4a16=self.use_int4_w4a16, use_int4_w4a16=self.use_int4_w4a16,
config=config, config=config,
block_shape=self.block_shape) block_shape=self.block_shape)
return intermediate_cache3
...@@ -1542,6 +1542,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1542,6 +1542,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int4_w4a16=use_int4_w4a16) use_int4_w4a16=use_int4_w4a16)
self.per_channel_quant = per_channel_quant self.per_channel_quant = per_channel_quant
def supports_chunking(self) -> bool:
    """Activation chunking is supported by this implementation."""
    return True
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
...@@ -1551,14 +1554,15 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1551,14 +1554,15 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
K: int, K: int,
topk: int, topk: int,
num_experts: int, num_experts: int,
) -> tuple[int, int, torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
factor = num_experts if a.dim() == 3 else 1 workspace1 = (M, topk, max(N * 2, K))
workspace1 = M * topk * max(N * 2, K) * factor workspace2 = (M, topk, N)
workspace2 = M * topk * N * factor output = (M, topk, K)
return (workspace1, workspace2, a.dtype) return (workspace1, workspace2, output, a.dtype)
def apply( def apply(
self, self,
output: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -1575,7 +1579,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1575,7 +1579,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor: ):
# Check constraints. # Check constraints.
if self.use_int4_w4a16: if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), ( assert hidden_states.size(-1) // 2 == w1.size(2), (
...@@ -1632,8 +1636,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1632,8 +1636,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
(num_tokens, top_k_num, N)) (num_tokens, top_k_num, N))
intermediate_cache2 = _resize_cache(workspace2, intermediate_cache2 = _resize_cache(workspace2,
(num_tokens * top_k_num, N // 2)) (num_tokens * top_k_num, N // 2))
intermediate_cache3 = _resize_cache(workspace13,
(num_tokens, top_k_num, K))
sorted_token_ids, expert_ids, num_tokens_post_padded = ( sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
...@@ -1671,7 +1673,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1671,7 +1673,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
invoke_fused_moe_kernel(qintermediate_cache2, invoke_fused_moe_kernel(qintermediate_cache2,
w2, w2,
intermediate_cache3, output,
a2q_scale, a2q_scale,
w2_scale, w2_scale,
w2_zp, w2_zp,
...@@ -1690,8 +1692,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1690,8 +1692,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
per_channel_quant=self.per_channel_quant, per_channel_quant=self.per_channel_quant,
block_shape=self.block_shape) block_shape=self.block_shape)
return intermediate_cache3
def modular_triton_fused_moe( def modular_triton_fused_moe(
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from math import prod
from typing import Optional from typing import Optional
import torch import torch
import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.utils import cdiv
# #
# This file defines a set of base classes used to make MoE kernels more modular. # This file defines a set of base classes used to make MoE kernels more modular.
# The goal is to be able to utilize different communication mechanisms with # The goal is to be able to utilize different communication mechanisms with
...@@ -171,6 +176,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -171,6 +176,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
above. above.
""" """
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
@abstractmethod
def supports_chunking(self) -> bool:
    """
    Indicate whether this implementation supports activation chunking.

    Subclasses must override this; the base implementation only raises
    NotImplementedError.
    """
    raise NotImplementedError
@abstractmethod @abstractmethod
def workspace_shapes( def workspace_shapes(
self, self,
...@@ -181,19 +195,22 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -181,19 +195,22 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
K: int, K: int,
topk: int, topk: int,
num_experts: int, num_experts: int,
) -> tuple[int, int, torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
""" """
Compute the number of elements for the temporary outputs of the two Compute the shapes for the temporary and final outputs of the two gemms
gemms and activation in the fused expert function. Since the and activation in the fused expert function. Since the gemms are
gemms are independent, the workspace for the first gemm can be shared independent, the workspace for the first gemm can be shared with the
with the workspace for the last gemm. workspace for the last gemm.
Returns a tuple of: Returns a tuple of:
- Number of workspace13 elements: must be large enough to hold the - workspace13 shape tuple: must be large enough to hold the
result of either expert gemm. result of either expert gemm.
- Number of workspace2 elements: must be large enough to hold the - workspace2 shape tuple: must be large enough to hold the
result of the activation function. result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
- Workspace type: The dtype to use for the workspace tensors. - Workspace type: The dtype to use for the workspace tensors.
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -210,6 +227,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -210,6 +227,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
@abstractmethod @abstractmethod
def apply( def apply(
self, self,
output: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -226,12 +244,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -226,12 +244,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor: ):
""" """
This function computes the intermediate result of a Mixture of Experts This function computes the intermediate result of a Mixture of Experts
(MoE) layer using two sets of weights, w1 and w2. (MoE) layer using two sets of weights, w1 and w2.
Parameters: Parameters:
- output: (torch.Tensor): The unweighted, unreduced output tensor.
- hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE - hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
layer. layer.
- w1 (torch.Tensor): The first set of expert weights. - w1 (torch.Tensor): The first set of expert weights.
...@@ -259,13 +278,20 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -259,13 +278,20 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
function. function.
- expert_num_tokens: An optional tensor containing the number of tokens - expert_num_tokens: An optional tensor containing the number of tokens
assigned to each expert when using batched experts format input. assigned to each expert when using batched experts format input.
Returns:
- torch.Tensor: The unweighted, unreduced output tensor
""" """
raise NotImplementedError raise NotImplementedError
def _chunk_scales(scales: Optional[torch.Tensor], start: int,
                  end: int) -> Optional[torch.Tensor]:
    """
    Select the part of `scales` that corresponds to rows [start, end).

    A missing scale (None) stays None and a single-element scale tensor is
    returned as-is; any other tensor is sliced along its first dimension.
    """
    if scales is None:
        return None
    return scales if scales.numel() == 1 else scales[start:end]
class FusedMoEModularKernel(torch.nn.Module): class FusedMoEModularKernel(torch.nn.Module):
""" """
This class combines a FusedMoEPrepareAndFinalize instance and This class combines a FusedMoEPrepareAndFinalize instance and
...@@ -288,61 +314,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -288,61 +314,6 @@ class FusedMoEModularKernel(torch.nn.Module):
self.prepare_finalize = prepare_finalize self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts self.fused_experts = fused_experts
def _do_fused_experts(
        self,
        a1: torch.Tensor,  # input to forward fn
        a1q: torch.Tensor,  # output of prepare fn
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_ids: torch.Tensor,
        expert_num_tokens: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        w1_scale: Optional[torch.Tensor],
        w2_scale: Optional[torch.Tensor],
        w1_zp: Optional[torch.Tensor],
        w2_zp: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor]) -> torch.Tensor:
    """
    Allocate the two workspaces and run `self.fused_experts.apply` on `a1q`.

    The problem dimensions are derived from `a1q`, `w1`, `w2` and `topk_ids`
    via `_moe_problem_size`. The workspace shapes and dtype are then requested
    from `self.fused_experts.workspace_shapes`. Both workspaces are
    zero-filled and created on `a1.device`.

    Returns:
    - torch.Tensor: whatever `self.fused_experts.apply` returns.
    """
    _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)

    # `a1` is passed along with `a1q` so the correct workspace datatype can
    # be deciphered.
    ws13_shape, ws2_shape, ws_dtype = self.fused_experts.workspace_shapes(
        a1, a1q, M, N, K, top_k, global_num_experts)

    alloc_opts = dict(device=a1.device, dtype=ws_dtype)

    # One buffer serves both cache1 and cache3: by the time cache3 is needed,
    # cache1 is no longer in use.
    shared_ws13 = torch.zeros(ws13_shape, **alloc_opts)
    ws2 = torch.zeros(ws2_shape, **alloc_opts)

    return self.fused_experts.apply(
        a1q,
        w1,
        w2,
        topk_ids,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=w1_zp,
        w2_zp=w2_zp,
        a1q_scale=a1q_scale,
        a2_scale=a2_scale,
        workspace13=shared_ws13,
        workspace2=ws2,
        expert_num_tokens=expert_num_tokens,
    )
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -408,12 +379,14 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -408,12 +379,14 @@ class FusedMoEModularKernel(torch.nn.Module):
_expert_topk_weights) = self.prepare_finalize.prepare( _expert_topk_weights) = self.prepare_finalize.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids, a1, a1_scale, a2_scale, topk_weights, topk_ids,
global_num_experts, expert_map, apply_router_weight_on_input) global_num_experts, expert_map, apply_router_weight_on_input)
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks. # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
topk_weights = (topk_weights if _expert_topk_weights is None else topk_weights = (topk_weights if _expert_topk_weights is None else
_expert_topk_weights) _expert_topk_weights)
fused_out = None fused_out = None
if a1q.numel() == 0: if a1q.numel() == 0:
# This happens when none of the tokens from the all2all reach this # This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph # EP rank. Also, note that this is only relevant for CUDAGraph
...@@ -423,13 +396,47 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -423,13 +396,47 @@ class FusedMoEModularKernel(torch.nn.Module):
# and can never run into the tensor.numel() == 0 case. # and can never run into the tensor.numel() == 0 case.
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype) fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
else: else:
fused_out = self._do_fused_experts( _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
a1=a1,
a1q=a1q, if self.fused_experts.supports_chunking():
w1=w1, CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
w2=w2, num_chunks = cdiv(M, CHUNK_SIZE)
topk_ids=topk_ids, else:
expert_num_tokens=expert_num_tokens, CHUNK_SIZE = M
num_chunks = 1
if num_chunks == 1:
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts)
else:
# Use the full M to get the final output shape.
_, _, fused_out_shape, _ = (
self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts))
# Use the CHUNK_SIZE to get the workspace shapes.
workspace13_shape, workspace2_shape, _, workspace_dtype = (
self.fused_experts.workspace_shapes(
a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts))
# We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1.
workspace13 = torch.zeros(prod(workspace13_shape),
device=a1.device,
dtype=workspace_dtype)
workspace2 = torch.zeros(prod(workspace2_shape),
device=a1.device,
dtype=workspace_dtype)
if num_chunks == 1:
fused_out = _resize_cache(workspace13, fused_out_shape)
self.fused_experts.apply(
fused_out,
a1q,
w1,
w2,
topk_ids,
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
...@@ -438,7 +445,58 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -438,7 +445,58 @@ class FusedMoEModularKernel(torch.nn.Module):
w1_zp=w1_zp, w1_zp=w1_zp,
w2_zp=w2_zp, w2_zp=w2_zp,
a1q_scale=a1q_scale, a1q_scale=a1q_scale,
a2_scale=a2_scale) a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
else:
# The leading output dimension may not be equal to M, so
# we compute output indices separately.
M_out = fused_out_shape[0]
assert M_out >= M
factor = M_out // M
assert factor > 0
OUT_CHUNK_SIZE = CHUNK_SIZE * factor
fused_out = torch.empty(fused_out_shape,
device=a1q.device,
dtype=workspace_dtype)
assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, (
f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}")
for chunk in range(num_chunks):
begin_chunk_idx = chunk * CHUNK_SIZE
end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M)
begin_out_idx = chunk * OUT_CHUNK_SIZE
end_out_idx = min((chunk + 1) * OUT_CHUNK_SIZE, M_out)
curr_a1q = a1q[begin_chunk_idx:end_chunk_idx]
curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx,
end_chunk_idx)
curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx,
end_chunk_idx)
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
self.fused_experts.apply(
fused_out[begin_out_idx:end_out_idx],
curr_a1q,
w1,
w2,
curr_topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=curr_a1q_scale,
a2_scale=curr_a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
self.prepare_finalize.finalize(output, fused_out, topk_weights, self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input) topk_ids, apply_router_weight_on_input)
......
...@@ -34,6 +34,12 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -34,6 +34,12 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.deep_gemm_expert = DeepGemmExperts( self.deep_gemm_expert = DeepGemmExperts(
) if self.allow_deep_gemm else None ) if self.allow_deep_gemm else None
def supports_chunking(self) -> bool:
    """
    Report whether every configured expert implementation supports chunking.

    `self.deep_gemm_expert` is consulted first, then `self.triton_expert`.
    An implementation that is None places no restriction on the result.

    Returns:
    - bool: True when no configured implementation reports that it lacks
      chunking support, False otherwise.
    """
    candidates = (self.deep_gemm_expert, self.triton_expert)
    return all(impl is None or impl.supports_chunking()
               for impl in candidates)
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
...@@ -43,7 +49,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -43,7 +49,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
K: int, K: int,
topk: int, topk: int,
num_experts: int, num_experts: int,
) -> tuple[int, int, torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# Note: the deep gemm workspaces are strictly larger than the triton # Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm # workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set. # even if we fall back to triton later, e.g. if expert maps are set.
...@@ -57,6 +63,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -57,6 +63,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def apply( def apply(
self, self,
output: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -73,31 +80,17 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -73,31 +80,17 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor: ):
N = w1.size(1) N = w1.size(1)
if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2)): use_deep_gemm = (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
assert self.deep_gemm_expert is not None and _valid_deep_gemm(hidden_states, w1, w2))
return self.deep_gemm_expert.apply(
hidden_states, experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
w1, assert experts is not None
w2,
topk_ids, experts.apply(
activation, output,
global_num_experts,
expert_map,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_num_tokens,
)
else:
return self.triton_expert.apply(
hidden_states, hidden_states,
w1, w1,
w2, w2,
......
...@@ -562,9 +562,12 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ...@@ -562,9 +562,12 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
(moe.num_experts + prepare_finalize.world_size - 1) // (moe.num_experts + prepare_finalize.world_size - 1) //
prepare_finalize.world_size) prepare_finalize.world_size)
experts = CutlassExpertsFp8( experts = CutlassExpertsFp8(
max_experts_per_worker, moe.in_dtype, max_experts_per_worker,
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN, self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL) self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
use_batched_format=True,
)
if has_pplx and isinstance( if has_pplx and isinstance(
prepare_finalize, prepare_finalize,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment