Unverified commit fa98d777, authored by Varun Sundar Rabindranath, committed by GitHub
Browse files

[Kernel] DeepEP dispatch-combine kernel integration (#18434)


Signed-off-by: Varun <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
parent 01eee405
...@@ -516,9 +516,8 @@ void topk_softmax( ...@@ -516,9 +516,8 @@ void topk_softmax(
topk, topk,
stream); stream);
} }
else else if (topk_indices.scalar_type() == at::ScalarType::UInt32)
{ {
assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
vllm::moe::topkGatingSoftmaxKernelLauncher( vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(), gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(), topk_weights.data_ptr<float>(),
...@@ -530,4 +529,17 @@ void topk_softmax( ...@@ -530,4 +529,17 @@ void topk_softmax(
topk, topk,
stream); stream);
} }
else {
assert(topk_indices.scalar_type() == at::ScalarType::Int64);
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int64_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
} }
# SPDX-License-Identifier: Apache-2.0
"""
DeepEP test utilities
"""
import dataclasses
import importlib
import traceback
from typing import Callable, Optional
import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
## Parallel Processes Utils
P = ParamSpec("P")
@dataclasses.dataclass
class ProcessGroupInfo:
    # Describes where one spawned worker sits in the distributed job.
    # An instance is built by _worker_parallel_launch and passed to the
    # worker callable as its first argument.
    world_size: int  # number of processes passed to init_process_group
    world_local_size: int  # processes per node; used to derive `rank`
    rank: int  # node_rank * world_local_size + local_rank
    node_rank: int  # index of the node this process belongs to
    local_rank: int  # index within the node; also used as CUDA device index
    device: torch.device  # the "cuda:<local_rank>" device for this process
def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """
    Entry point executed in every spawned process.

    Binds the process to its CUDA device, joins the default process group,
    runs ``worker(ProcessGroupInfo(...), *args, **kwargs)`` and always tears
    the process group down afterwards. Any exception raised by the worker is
    printed (message and traceback) and then re-raised.
    """
    global_rank = local_rank + node_rank * world_local_size
    torch.cuda.set_device(local_rank)
    cuda_device = torch.device("cuda", local_rank)

    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=global_rank,
        world_size=world_size,
        device_id=cuda_device,
    )

    # One all-reduce up front so every rank has joined before the worker runs.
    sync_token = torch.tensor([global_rank], device=cuda_device)
    torch.distributed.all_reduce(sync_token)

    try:
        pgi = ProcessGroupInfo(
            world_size=world_size,
            world_local_size=world_local_size,
            rank=global_rank,
            node_rank=node_rank,
            local_rank=local_rank,
            device=cuda_device,
        )
        worker(pgi, *args, **kwargs)
    except Exception as ex:
        # Surface the failure in the child's output before propagating it.
        print(ex)
        traceback.print_exc()
        raise
    finally:
        torch.distributed.destroy_process_group()
def parallel_launch(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """
    Spawn ``world_size`` single-node processes and run ``worker`` in each.

    Every process executes ``_worker_parallel_launch`` with
    ``world_local_size == world_size`` and ``node_rank == 0``; the spawn
    helper supplies the local rank as the first argument. The call blocks
    until all processes have finished (``join=True``).

    Keyword arguments are not supported because they cannot be forwarded
    through ``spawn``'s positional ``args`` tuple.
    """
    assert not kwargs, "parallel_launch does not support keyword arguments"

    # Ask the OS for a currently-free TCP port instead of using a fixed one,
    # so that concurrent or back-to-back launches do not fail with
    # "address already in use".
    import socket
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        port = sock.getsockname()[1]

    spawn(
        _worker_parallel_launch,
        args=(
            world_size,
            world_size,
            0,
            f"tcp://localhost:{port}",
            worker,
        ) + args,
        nprocs=world_size,
        join=True,
    )
## DeepEP specific utils
@dataclasses.dataclass
class DeepEPHTArgs:
    # Arguments specific to the high-throughput DeepEP all2all.
    # Multiplied by the rank in make_deepep_ht_a2a to get rank_expert_offset.
    num_local_experts: int
@dataclasses.dataclass
class DeepEPLLArgs:
    # Arguments specific to the low-latency DeepEP all2all.
    # The first three fields feed the RDMA buffer size hint in
    # make_deepep_ll_a2a.
    max_tokens_per_rank: int
    hidden_size: int
    num_experts: int  # also divided by world size to get num_qps_per_rank
    use_fp8_dispatch: bool  # forwarded to DeepEPLLPrepareAndFinalize
def make_deepep_ht_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       ht_args: DeepEPHTArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):
    """
    Create a DeepEP high-throughput prepare/finalize object for tests.

    Allocates a ``deep_ep.Buffer`` on ``pg`` with a 1 GiB NVLink buffer, no
    RDMA buffer and a single QP per rank, then wraps it in
    ``DeepEPHTPrepareAndFinalize``. The rank's expert offset is
    ``pgi.rank * ht_args.num_local_experts``.
    """
    import deep_ep

    # high throughput a2a
    one_gib = 1024 * 1024 * 1024
    ht_buffer = deep_ep.Buffer(group=pg,
                               num_nvl_bytes=one_gib,
                               num_rdma_bytes=0,
                               low_latency_mode=False,
                               num_qps_per_rank=1)

    first_local_expert = pgi.rank * ht_args.num_local_experts
    return DeepEPHTPrepareAndFinalize(buffer=ht_buffer,
                                      world_size=pgi.world_size,
                                      rank=pgi.rank,
                                      dp_size=dp_size,
                                      rank_expert_offset=first_local_expert,
                                      quant_dtype=q_dtype,
                                      block_shape=block_shape)
def make_deepep_ll_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       deepep_ll_args: DeepEPLLArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):
    """
    Create a DeepEP low-latency prepare/finalize object for tests.

    The RDMA buffer size comes from DeepEP's own size hint, and the number
    of QPs per rank is ``num_experts // world_size``. ``block_shape`` is
    accepted but not used by this function.
    """
    import deep_ep

    ll = deepep_ll_args
    n_ranks = pgi.world_size

    # low-latency a2a
    rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        ll.max_tokens_per_rank, ll.hidden_size, n_ranks, ll.num_experts)
    ll_buffer = deep_ep.Buffer(group=pg,
                               num_rdma_bytes=rdma_bytes,
                               low_latency_mode=True,
                               num_qps_per_rank=ll.num_experts // n_ranks)

    return DeepEPLLPrepareAndFinalize(
        buffer=ll_buffer,
        world_size=n_ranks,
        dp_size=dp_size,
        max_tokens_per_rank=ll.max_tokens_per_rank,
        quant_dtype=q_dtype,
        use_fp8_dispatch=ll.use_fp8_dispatch,
    )
def make_deepep_a2a(pg: ProcessGroup,
                    pgi: ProcessGroupInfo,
                    dp_size: int,
                    deepep_ht_args: Optional[DeepEPHTArgs],
                    deepep_ll_args: Optional[DeepEPLLArgs],
                    q_dtype: Optional[torch.dtype] = None,
                    block_shape: Optional[list[int]] = None):
    """
    Build either the high-throughput or the low-latency DeepEP all2all.

    Exactly one of ``deepep_ht_args`` / ``deepep_ll_args`` must be given;
    it selects which factory is used. ``q_dtype`` and ``block_shape`` are
    forwarded to the selected factory in both cases.

    Raises:
        AssertionError: if both or neither argument set is provided.
    """
    if deepep_ht_args is not None:
        assert deepep_ll_args is None
        return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
                                  block_shape)

    assert deepep_ll_args is not None
    # Forward block_shape here too so both branches treat the caller's
    # arguments the same way (it was previously dropped on this path).
    return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
                              block_shape)
# SPDX-License-Identifier: Apache-2.0
"""
Test DeepEP + DeepGEMM integration
"""
import dataclasses
import importlib
from typing import Optional
import pytest
import torch.distributed
from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform
from .deepep_utils import ProcessGroupInfo, parallel_launch
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
try:
import deep_gemm
has_deep_gemm = True
except ImportError:
has_deep_gemm = False
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from .deepep_utils import DeepEPHTArgs, make_deepep_a2a
if has_deep_gemm:
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts)
# Skip marker: applied to tests that need the deep_ep package installed.
requires_deep_ep = pytest.mark.skipif(
    not has_deep_ep,
    reason="Requires deep_ep kernels",
)
# Skip marker: applied to tests that need the deep_gemm package installed.
requires_deep_gemm = pytest.mark.skipif(
    not has_deep_gemm,
    reason="Requires deep_gemm kernels",
)
P = ParamSpec("P")
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(deep_gemm.ceil_div(m, 128) * 128,
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Create random expert weights and quantize them block-wise to FP8.

    Returns ``(w1, w2, w1_scale, w2_scale)``: the FP8 weights of shape
    ``(e, 2 * n, k)`` and ``(e, k, n)`` plus their per-block float32 scales.
    """
    bf16 = torch.bfloat16
    fp8_limits = torch.finfo(torch.float8_e4m3fn)
    hi, lo = fp8_limits.max, fp8_limits.min

    # Reference weights; w1 is drawn before w2 to keep the RNG sequence.
    w1_ref = torch.randn((e, 2 * n, k), dtype=bf16) / 10
    w1_ref = w1_ref.clamp(min=lo, max=hi).to(dtype=bf16)
    w2_ref = torch.randn((e, k, n), dtype=bf16) / 10
    w2_ref = w2_ref.clamp(min=lo, max=hi).to(dtype=bf16)

    block_n = block_size[0]
    block_k = block_size[1]

    def _tiles(extent: int, tile: int) -> int:
        # Number of tiles needed to cover `extent` (ceil division).
        return (extent + tile - 1) // tile

    w1_q = torch.empty_like(w1_ref, dtype=torch.float8_e4m3fn)
    w2_q = torch.empty_like(w2_ref, dtype=torch.float8_e4m3fn)

    w1_scales = torch.empty((e, _tiles(2 * n, block_n), _tiles(k, block_k)),
                            device="cuda",
                            dtype=torch.float32)
    w2_scales = torch.empty((e, _tiles(k, block_n), _tiles(n, block_k)),
                            device="cuda",
                            dtype=torch.float32)

    assert w1_scales.shape == (e, (2 * n + 127) // 128, (k + 127) // 128)
    assert (w2_q.shape[-2] + block_n - 1) // block_n == w2_scales.shape[-2]

    # Quantize expert by expert.
    for expert in range(e):
        w1_q[expert], w1_scales[expert] = per_block_cast_to_fp8(
            w1_ref[expert])
        w2_q[expert], w2_scales[expert] = per_block_cast_to_fp8(
            w2_ref[expert])

    return w1_q, w2_q, w1_scales, w2_scales
@dataclasses.dataclass
class TestConfig:
    # Problem description shared by the launcher and the per-rank worker.
    topk: int  # number of expert ids drawn per token
    m: int  # number of tokens each rank creates
    k: int  # hidden size of each token (last dim of w1)
    n: int  # w1 is (E, 2 * n, k) and w2 is (E, k, n)
    num_experts: int  # total number of experts across all ranks
    block_size: list[int]  # quantization block; index 1 is the token group
@dataclasses.dataclass
class TestTensors:
    """Per-rank random inputs for one MoE invocation."""
    rank_tokens: torch.Tensor  # all ranks make this many tokens
    rank_token_scales: Optional[torch.Tensor]
    topk: torch.Tensor
    topk_weights: torch.Tensor
    config: TestConfig

    @staticmethod
    def make(config: TestConfig, rank) -> "TestTensors":
        """
        Draw tokens, expert ids and routing weights on the current CUDA
        device. ``rank`` is accepted but not used.
        """
        device = torch.cuda.current_device()
        fp8_limits = torch.finfo(torch.float8_e4m3fn)

        tokens = torch.randn(
            (config.m, config.k), device=device, dtype=torch.bfloat16) / 10.0
        tokens = tokens.clamp(min=fp8_limits.min, max=fp8_limits.max)

        # Only the scales are kept; the quantized tokens are discarded.
        token_group = config.block_size[1]
        _, token_scales = per_token_group_quant_fp8(tokens, token_group)

        expert_ids = torch.randint(low=0,
                                   high=config.num_experts,
                                   size=(config.m, config.topk),
                                   device=device).to(dtype=torch.int64)
        routing_weights = torch.randn(expert_ids.shape,
                                      dtype=torch.float32,
                                      device=device)

        return TestTensors(rank_tokens=tokens,
                           rank_token_scales=token_scales,
                           topk=expert_ids,
                           topk_weights=routing_weights,
                           config=config)
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
                        num_local_experts: int, q_dtype: Optional[torch.dtype],
                        block_shape: list[int]) -> FusedMoEModularKernel:
    """
    Assemble a modular MoE kernel from a high-throughput DeepEP all2all
    (prepare/finalize) and DeepGEMM experts.
    """
    ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)
    prepare_finalize: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=ht_args,
        deepep_ll_args=None,
        q_dtype=q_dtype,
        block_shape=block_shape)

    experts = DeepGemmExperts()
    return FusedMoEModularKernel(prepare_finalize=prepare_finalize,
                                 fused_experts=experts)
def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
                     test_tensors: TestTensors, w1: torch.Tensor,
                     w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
                     w2_scale: Optional[torch.Tensor],
                     num_experts: int) -> torch.Tensor:
    """
    Run the DeepEP + DeepGEMM modular kernel on this rank's tokens.

    ``w1`` / ``w2`` (and their scales) hold only this rank's experts; their
    leading dimension gives the number of local experts.
    """
    experts_on_rank = w1.size(0)

    # Make modular kernel
    mk: FusedMoEModularKernel = make_modular_kernel(
        pg, pgi, dp_size, experts_on_rank, torch.float8_e4m3fn,
        test_tensors.config.block_size)

    # Global-to-local expert id map: -1 everywhere except this rank's
    # contiguous slice, which maps to 0..experts_on_rank-1.
    first = pgi.rank * experts_on_rank
    expert_map = torch.full((num_experts, ), fill_value=-1, dtype=torch.int32)
    expert_map[first:first + experts_on_rank] = torch.arange(
        experts_on_rank, dtype=torch.int32)
    expert_map = expert_map.to(device=torch.cuda.current_device(),
                               dtype=torch.int32)

    return mk.forward(hidden_states=test_tensors.rank_tokens,
                      w1=w1,
                      w2=w2,
                      topk_weights=test_tensors.topk_weights,
                      topk_ids=test_tensors.topk,
                      inplace=False,
                      activation="silu",
                      global_num_experts=num_experts,
                      expert_map=expert_map,
                      w1_scale=w1_scale,
                      w2_scale=w2_scale,
                      w1_zp=None,
                      w2_zp=None,
                      a1_scale=test_tensors.rank_token_scales,
                      a2_scale=None,
                      apply_router_weight_on_input=False)
def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
                topk_weights: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                a1_scale: torch.Tensor, block_shape: list[int]):
    """
    Reference result: FP8 w8a8 ``fused_experts`` with DeepGEMM disabled.
    """
    reference_kwargs = dict(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        use_fp8_w8a8=True,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        block_shape=block_shape,
        # Must stay False, otherwise the reference could run the very
        # implementation it is being compared against.
        allow_deep_gemm=False,
    )
    return fused_experts(**reference_kwargs)
def _deep_ep_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    config: TestConfig,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
):
    """
    Per-rank worker: compare the DeepEP + DeepGEMM output against the
    Triton reference on this rank's randomly generated tokens.
    """
    current_platform.seed_everything(pgi.rank)

    # Move the full (all-expert) weights and scales to this rank's device.
    w1 = w1.to(device=torch.cuda.current_device())
    w2 = w2.to(device=torch.cuda.current_device())
    w1_scale = w1_scale.to(device=torch.cuda.current_device())
    w2_scale = w2_scale.to(device=torch.cuda.current_device())

    pg = torch.distributed.new_group(list(range(pgi.world_size)))
    test_tensors = TestTensors.make(config, pgi.rank)

    # Recover the quantization block shape from the weight/scale shapes.
    block_shape = [w1.size(dim) // w1_scale.size(dim) for dim in (1, 2)]

    with set_current_vllm_config(VllmConfig()):
        # Reference
        expected = triton_impl(a=test_tensors.rank_tokens,
                               topk_ids=test_tensors.topk,
                               topk_weights=test_tensors.topk_weights,
                               w1=w1,
                               w2=w2,
                               w1_scale=w1_scale,
                               w2_scale=w2_scale,
                               a1_scale=test_tensors.rank_token_scales,
                               block_shape=block_shape)

        # Slice experts for this rank.
        experts_per_rank = config.num_experts // pgi.world_size
        mine = slice(experts_per_rank * pgi.rank,
                     experts_per_rank * pgi.rank + experts_per_rank)
        actual = deep_ep_moe_impl(
            pg,
            pgi,
            dp_size,
            test_tensors,
            w1[mine],
            w2[mine],
            w1_scale[mine],
            w2_scale[mine],
            config.num_experts,
        )

    torch.testing.assert_close(
        expected,
        actual,
        atol=6e-2,
        rtol=6e-2,
    )
# (m, n, k) problem sizes used to parametrize test_deep_ep_moe.
MNKs = [
    (8, 128, 128),
    (8, 128, 512),
    (8, 512, 512),
    (3, 1024, 2048),
    (32, 128, 1024),
    (45, 512, 2048),
    (64, 1024, 1024),
    (129, 128, 256),
    (129, 1024, 2048),
    (222, 1024, 2048),
]
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
                     world_dp_size: tuple[int, int]):
    """
    Launch one process per rank and check DeepEP + DeepGEMM against the
    Triton reference for block-quantized FP8 weights.
    """
    m, n, k = mnk
    world_size, dp_size = world_dp_size
    current_platform.seed_everything(7)

    if topk > num_experts:
        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")

    # Square quantization block taken from DeepGEMM's M alignment.
    alignment = deep_gemm.get_m_alignment_for_contiguous_layout()
    block_size = [alignment, alignment]

    config = TestConfig(
        topk=topk,
        m=m,
        k=k,
        n=n,
        num_experts=num_experts,
        block_size=block_size,
    )
    weights = make_block_quant_fp8_weights(num_experts, n, k, block_size)
    w1, w2, w1_scale, w2_scale = weights

    parallel_launch(world_size, _deep_ep_moe, dp_size, config, w1, w2,
                    w1_scale, w2_scale)
# SPDX-License-Identifier: Apache-2.0
"""
Test deepep dispatch-combine logic
"""
import dataclasses
import importlib
from typing import Optional, Union
import pytest
import torch.distributed
from torch.distributed import ProcessGroup
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform
from .deepep_utils import ProcessGroupInfo, parallel_launch
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
# Skip marker: applied to tests that need the deep_ep package installed.
requires_deep_ep = pytest.mark.skipif(
    not has_deep_ep,
    reason="Requires deep_ep kernels",
)
# Token budget per rank in low-latency mode; inputs are processed in chunks
# of at most this many tokens.
MAX_TOKENS_PER_RANK = 64
def make_weights(
    e, n, k, dtype
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
           Optional[torch.Tensor]]:
    """
    Create random expert weights of the requested dtype on the GPU.

    For float16 / bfloat16 the weights are returned as-is and both scales
    are ``None``. For float8_e4m3fn, random float16 weights are quantized
    expert by expert with one scale per output channel.

    Returns:
        ``(w1, w2, w1_scale, w2_scale)`` with ``w1`` of shape
        ``(e, 2 * n, k)`` and ``w2`` of shape ``(e, k, n)``.
    """
    if dtype in [torch.float16, torch.bfloat16]:
        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
        return w1, w2, None, None

    # per-out-channel weight quantization
    assert dtype == torch.float8_e4m3fn
    # Draw real values before quantizing. torch.empty would hand the
    # quantizer uninitialized memory (possibly NaN/Inf), making the test
    # data meaningless and non-deterministic.
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.float16) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=torch.float16) / 10

    n_b_scales = 2 * n
    k_b_scales = k
    w1_q = torch.empty_like(w1, dtype=dtype)
    w2_q = torch.empty_like(w2, dtype=dtype)
    w1_scale = torch.empty((e, n_b_scales, 1),
                           device="cuda",
                           dtype=torch.float32)
    w2_scale = torch.empty((e, k_b_scales, 1),
                           device="cuda",
                           dtype=torch.float32)

    for expert in range(e):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
            w1[expert], use_per_token_if_dynamic=True)
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
            w2[expert], use_per_token_if_dynamic=True)
    return w1_q, w2_q, w1_scale, w2_scale
@dataclasses.dataclass
class TestConfig:
    # Problem description shared by the launcher and the per-rank worker.
    dtype: torch.dtype  # weight dtype; float8_e4m3fn selects quantized path
    topk: int  # number of expert ids drawn per token
    m: int  # number of tokens each rank creates
    k: int  # hidden size of each token
    n: int  # w1 is (E, 2 * n, k) and w2 is (E, k, n)
    num_experts: int  # total number of experts across all ranks
@dataclasses.dataclass
class TestTensors:
    """Per-rank random inputs for one MoE invocation."""
    rank_tokens: torch.Tensor  # all ranks make this many tokens
    rank_token_scales: Optional[torch.Tensor]
    topk: torch.Tensor
    topk_weights: torch.Tensor
    config: TestConfig

    @staticmethod
    def make(config: TestConfig, low_latency_mode: bool) -> "TestTensors":
        """
        Draw tokens, expert ids and routing weights on the GPU. For FP8
        configs the tokens stay bfloat16 and only activation scales are
        computed alongside them.
        """
        # TODO (varun) - check that float16 works ?
        assert config.dtype in [torch.bfloat16, torch.float8_e4m3fn]
        wants_fp8 = config.dtype == torch.float8_e4m3fn
        token_dtype = torch.bfloat16 if wants_fp8 else config.dtype

        tokens = torch.randn(
            (config.m, config.k), device="cuda", dtype=token_dtype) / 10

        token_scales = None
        if wants_fp8:
            # low_latency_mode kernels dont support per-token quant.
            per_token = not low_latency_mode
            _, token_scales = ops.scaled_fp8_quant(
                tokens, use_per_token_if_dynamic=per_token)

        expert_ids = torch.randint(low=0,
                                   high=config.num_experts,
                                   size=(config.m, config.topk),
                                   device="cuda").to(dtype=torch.int64)
        routing_weights = torch.randn(expert_ids.shape,
                                      dtype=torch.float32,
                                      device="cuda")

        return TestTensors(rank_tokens=tokens,
                           rank_token_scales=token_scales,
                           topk=expert_ids,
                           topk_weights=routing_weights,
                           config=config)
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
                        low_latency_mode: bool, hidden_size: int, dp_size: int,
                        num_experts: int, num_local_experts: int,
                        q_dtype: Optional[torch.dtype],
                        use_fp8_dispatch: bool) -> FusedMoEModularKernel:
    """
    Assemble a modular MoE kernel around a DeepEP all2all.

    Low-latency mode pairs the low-latency all2all with batched Triton
    experts; otherwise the high-throughput all2all is paired with the
    regular Triton experts. FP8 dispatch is only allowed in low-latency
    mode.
    """
    uses_fp8 = q_dtype is not None

    ht_args: Optional[DeepEPHTArgs] = None
    ll_args: Optional[DeepEPLLArgs] = None
    if not low_latency_mode:
        assert not use_fp8_dispatch, (
            "FP8 Dispatch is valid only for low-latency kernels")
        ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)
    else:
        ll_args = DeepEPLLArgs(max_tokens_per_rank=MAX_TOKENS_PER_RANK,
                               hidden_size=hidden_size,
                               num_experts=num_experts,
                               use_fp8_dispatch=use_fp8_dispatch)

    prepare_finalize: Union[DeepEPHTPrepareAndFinalize,
                            DeepEPLLPrepareAndFinalize] = make_deepep_a2a(
                                pg=pg,
                                pgi=pgi,
                                dp_size=dp_size,
                                q_dtype=q_dtype,
                                block_shape=None,
                                deepep_ht_args=ht_args,
                                deepep_ll_args=ll_args)

    if not low_latency_mode:
        experts = TritonExperts(use_fp8_w8a8=uses_fp8,
                                use_int8_w8a8=False,
                                use_int8_w8a16=False,
                                use_int4_w4a16=False,
                                per_channel_quant=False)
    else:
        experts = BatchedTritonExperts(max_num_tokens=MAX_TOKENS_PER_RANK,
                                       world_size=pgi.world_size,
                                       dp_size=dp_size,
                                       use_fp8_w8a8=uses_fp8,
                                       use_int8_w8a8=False,
                                       use_int8_w8a16=False,
                                       use_int4_w4a16=False)

    return FusedMoEModularKernel(prepare_finalize=prepare_finalize,
                                 fused_experts=experts)
def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
                     low_latency_mode: bool, dp_size: int,
                     test_tensors: TestTensors, w1: torch.Tensor,
                     w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
                     w2_scale: Optional[torch.Tensor], num_experts: int,
                     use_fp8_dispatch: bool) -> torch.Tensor:
    """
    Run the DeepEP modular kernel on this rank's tokens.

    ``w1`` / ``w2`` (and their scales) hold only this rank's experts. In
    low-latency mode the tokens are processed in chunks of at most
    ``MAX_TOKENS_PER_RANK``; otherwise all tokens go through in one call.

    Returns:
        A tensor shaped like ``test_tensors.rank_tokens`` holding the
        combined MoE output for every token.
    """
    num_local_experts = w1.size(0)

    def build_expert_map():
        # Global-to-local expert id map: -1 for experts on other ranks.
        expert_map = torch.full((num_experts, ),
                                fill_value=-1,
                                dtype=torch.int32)
        s = pgi.rank * num_local_experts
        e = s + num_local_experts
        expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
        return expert_map.to(device=torch.cuda.current_device(),
                             dtype=torch.int32)

    hidden_size = test_tensors.rank_tokens.size(1)
    is_quantized = w1.dtype == torch.float8_e4m3fn
    q_dtype = torch.float8_e4m3fn if is_quantized else None

    # Make modular kernel
    mk: FusedMoEModularKernel = make_modular_kernel(pg, pgi, low_latency_mode,
                                                    hidden_size, dp_size,
                                                    num_experts,
                                                    num_local_experts, q_dtype,
                                                    use_fp8_dispatch)

    out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
    total_num_tokens = test_tensors.rank_tokens.size(0)

    # The map does not depend on the chunk, so build it once.
    expert_map = build_expert_map()

    def process_chunk(chunk_start, chunk_end):
        rank_tokens_chunk = test_tensors.rank_tokens[chunk_start:chunk_end]
        topk_weights_chunk = test_tensors.topk_weights[chunk_start:chunk_end]
        topk_chunk = test_tensors.topk[chunk_start:chunk_end]
        rank_token_scales_chunk = test_tensors.rank_token_scales
        if (rank_token_scales_chunk is not None
                and rank_token_scales_chunk.size(0) == total_num_tokens):
            # per act token
            rank_token_scales_chunk = rank_token_scales_chunk[
                chunk_start:chunk_end]

        out = mk.forward(hidden_states=rank_tokens_chunk,
                         w1=w1,
                         w2=w2,
                         topk_weights=topk_weights_chunk,
                         topk_ids=topk_chunk,
                         inplace=False,
                         activation="silu",
                         global_num_experts=num_experts,
                         expert_map=expert_map,
                         w1_scale=w1_scale,
                         w2_scale=w2_scale,
                         w1_zp=None,
                         w2_zp=None,
                         a1_scale=rank_token_scales_chunk,
                         a2_scale=None,
                         apply_router_weight_on_input=False)

        out_hidden_states[chunk_start:chunk_end, :].copy_(out,
                                                          non_blocking=True)

    # max(..., 1) keeps the range step valid when there are no tokens; the
    # loop is then simply empty.
    max_num_tokens_per_dp = (MAX_TOKENS_PER_RANK
                             if low_latency_mode else max(total_num_tokens, 1))

    # chunk_start is always < total_num_tokens here, so only the end of the
    # last chunk needs clamping.
    for chunk_start in range(0, total_num_tokens, max_num_tokens_per_dp):
        chunk_end = min(chunk_start + max_num_tokens_per_dp, total_num_tokens)
        process_chunk(chunk_start, chunk_end)

    return out_hidden_states
def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
                   w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
                   w2_scale: Optional[torch.Tensor], using_fp8_dispatch: bool):
    """
    Reference MoE computed token by token with plain torch ops.

    For each token and each of its top-k experts this accumulates
    ``SiluAndMul(a @ w1_e^T) @ w2_e^T * weight``. FP8 weights are
    de-quantized to float32 with their scales first, and the result is cast
    back to the activation dtype.

    Returns:
        A tensor with the same shape as ``test_tensors.rank_tokens``.
    """
    a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
                                 test_tensors.topk_weights)
    if using_fp8_dispatch:
        # The DeepEP implementation is requested to dispatch using FP8.
        # For numerical stability for testing, emulate the fp8 dispatch by
        # blockwise quant and de-quant.
        aq, aq_scale = per_token_group_quant_fp8(a, 128)
        a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
            a.shape).to(a.dtype)

    is_quantized = w1.dtype == torch.float8_e4m3fn
    a_dtype = a.dtype
    if is_quantized:
        w1 = w1.to(dtype=torch.float32) * w1_scale
        w2 = w2.to(dtype=torch.float32) * w2_scale
        a = a.to(dtype=torch.float32)

    m, _ = a.shape
    topk = topk_ids.size(1)
    out = torch.zeros_like(a)

    # Build the activation once instead of once per (token, expert) pair.
    act_fn = SiluAndMul()

    for i in range(m):
        a_i = a[i]
        # o_i is a view into `out`, so the in-place add updates the result.
        o_i = out[i]
        for j in range(topk):
            e = topk_ids[i][j]
            e_w = topk_weights[i][j]
            w1_e = w1[e]
            w2_e = w2[e]
            o_i += (act_fn(a_i @ w1_e.transpose(0, 1))
                    @ w2_e.transpose(0, 1)) * e_w

    if is_quantized:
        out = out.to(dtype=a_dtype)

    return out
def _deep_ep_moe(
    pgi: ProcessGroupInfo,
    low_latency_mode: bool,
    dp_size: int,
    config: TestConfig,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: Optional[torch.Tensor],
    w2_scale: Optional[torch.Tensor],
    use_fp8_dispatch: bool,
):
    """
    Per-rank worker: compare the DeepEP modular kernel output against the
    pure-torch reference on this rank's randomly generated tokens.
    """
    if not low_latency_mode:
        assert not use_fp8_dispatch, (
            "FP8 dispatch interface is available only in low-latency mode")

    is_quantized = w1.dtype == torch.float8_e4m3fn

    # Move the full (all-expert) weights, and scales if any, to this rank.
    w1 = w1.to(device=torch.cuda.current_device())
    w2 = w2.to(device=torch.cuda.current_device())
    if is_quantized:
        w1_scale = w1_scale.to(  # type: ignore
            device=torch.cuda.current_device())
        w2_scale = w2_scale.to(  # type: ignore
            device=torch.cuda.current_device())

    pg = torch.distributed.new_group(list(range(pgi.world_size)))
    test_tensors = TestTensors.make(config, low_latency_mode)

    with set_current_vllm_config(VllmConfig()):
        # Reference
        expected = torch_moe_impl(test_tensors, w1, w2, w1_scale, w2_scale,
                                  use_fp8_dispatch)

        # Slice experts for this rank.
        experts_per_rank = config.num_experts // pgi.world_size
        mine = slice(experts_per_rank * pgi.rank,
                     experts_per_rank * pgi.rank + experts_per_rank)

        w1_scale_ep = w1_scale[mine] if is_quantized else None  # type: ignore
        w2_scale_ep = w2_scale[mine] if is_quantized else None  # type: ignore

        actual = deep_ep_moe_impl(
            pg,
            pgi,
            low_latency_mode,
            dp_size,
            test_tensors,
            w1[mine],
            w2[mine],
            w1_scale_ep,
            w2_scale_ep,
            config.num_experts,
            use_fp8_dispatch,
        )

    torch.testing.assert_close(
        expected,
        actual,
        atol=6e-2,
        rtol=6e-2,
    )
# (m, n, k) problem sizes for the high-throughput test below.
MNKs = [
    (1, 128, 128),
    (2, 128, 512),
    (3, 1024, 2048),
    (32, 128, 1024),
    (45, 512, 2048),
    (64, 1024, 1024),
    (222, 1024, 2048),
]
# Weight dtypes exercised by the high-throughput test.
DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
                     num_experts: int, topk: int, world_dp_size: tuple[int,
                                                                       int]):
    """
    High-throughput DeepEP dispatch/combine checked against the torch
    reference, one spawned process per rank.
    """
    m, n, k = mnk
    world_size, dp_size = world_dp_size

    # High-throughput kernels: no low-latency mode, hence no FP8 dispatch.
    low_latency_mode = False
    use_fp8_dispatch = False

    current_platform.seed_everything(7)
    config = TestConfig(dtype=dtype,
                        topk=topk,
                        m=m,
                        k=k,
                        n=n,
                        num_experts=num_experts)

    weights = make_weights(num_experts, n, k, dtype)
    w1, w2, w1_scale, w2_scale = weights

    parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
                    config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
# (m, n, k) problem sizes for the low-latency test below. These rebind the
# module-level names after the earlier test has already been decorated.
MNKs = [
    (1, 128, 2560),
    (2, 128, 2560),
    (3, 1024, 2560),
    (32, 128, 2560),
    (45, 512, 2560),
    (64, 1024, 2560),
    (222, 1024, 2560),
]
# Weight dtypes exercised by the low-latency test.
DTYPES = [torch.float8_e4m3fn, torch.bfloat16]
# Whether the low-latency all2all is asked to dispatch tokens in FP8.
USE_FP8_DISPATCH = [True, False]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@requires_deep_ep
def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
                                 num_experts: int, topk: int,
                                 world_dp_size: tuple[int, int],
                                 use_fp8_dispatch: bool):
    """
    Low-latency DeepEP dispatch/combine checked against the torch
    reference, one spawned process per rank. Skipped when the hidden size
    is not supported by the low-latency prepare/finalize implementation.
    """
    low_latency_mode = True
    m, n, k = mnk

    supported = DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES
    if low_latency_mode and k not in supported:
        pytest.skip(
            f"Skipping test as hidden size {k} is not in list of supported "
            f"hidden sizes {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}"
        )

    current_platform.seed_everything(7)
    world_size, dp_size = world_dp_size
    config = TestConfig(dtype=dtype,
                        topk=topk,
                        m=m,
                        k=k,
                        n=n,
                        num_experts=num_experts)

    weights = make_weights(num_experts, n, k, dtype)
    w1, w2, w1_scale, w2_scale = weights

    parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
                    config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
...@@ -1856,6 +1856,8 @@ class ParallelConfig: ...@@ -1856,6 +1856,8 @@ class ParallelConfig:
factors.append(self.pipeline_parallel_size) factors.append(self.pipeline_parallel_size)
factors.append(self.tensor_parallel_size) factors.append(self.tensor_parallel_size)
factors.append(self.enable_expert_parallel) factors.append(self.enable_expert_parallel)
factors.append(self.data_parallel_size)
factors.append(envs.VLLM_ALL2ALL_BACKEND)
return hashlib.sha256(str(factors).encode()).hexdigest() return hashlib.sha256(str(factors).encode()).hexdigest()
def __post_init__(self) -> None: def __post_init__(self) -> None:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util import importlib.util
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -129,3 +129,147 @@ class PPLXAll2AllManager(All2AllManagerBase): ...@@ -129,3 +129,147 @@ class PPLXAll2AllManager(All2AllManagerBase):
from pplx_kernels.nvshmem import nvshmem_finalize from pplx_kernels.nvshmem import nvshmem_finalize
logger.debug("PPLX NVSHMEM finalize") logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize() nvshmem_finalize()
class DeepEPAll2AllManagerBase(All2AllManagerBase):
    """
    Common base for the DeepEP all2all managers.

    Checks that the ``deep_ep`` package is installed, creates the handle
    cache and fixes the SM count that subclasses apply to their buffers.
    ``get_handle``, ``dispatch`` and ``combine`` are not implemented here.
    """

    def __init__(self, cpu_group):
        assert importlib.util.find_spec("deep_ep") is not None, (
            "DeepEP kernels not found. Please follow "
            "https://github.com/vllm-project/vllm/blob/main/tools/"
            "ep_kernels/README.md to install DeepEP kernels.")
        super().__init__(cpu_group)
        self.handle_cache = Cache()

        # This is the DeepEP default. Stick to it till we can establish
        # reasonable defaults based on profiling.
        self.num_sms = 20

    def get_handle(self, kwargs):
        raise NotImplementedError

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        raise NotImplementedError

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        # Nothing to release at this level.
        pass
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
    """
    All2All communication based on DeepEP High-Throughput kernels.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def _make_all2all_kwargs(self) -> dict[Any, Any]:
        """
        Return the keyword arguments used to construct the
        ``deep_ep.Buffer`` for high-throughput mode.
        """
        # Defaults for internode and intranode are taken from DeepEP tests.
        one_gib = 1024 * 1024 * 1024
        if self.internode:
            rdma_bytes = one_gib
            qps_per_rank = self.num_sms // 2
        else:
            assert self.intranode
            rdma_bytes = 0
            qps_per_rank = 1

        return {
            "group": self.cpu_group,
            "num_nvl_bytes": one_gib,
            "num_rdma_bytes": rdma_bytes,
            "low_latency_mode": False,
            "num_qps_per_rank": qps_per_rank,
        }

    def get_handle(self, kwargs):
        """
        Return the (cached) ``deep_ep.Buffer`` for this manager. ``kwargs``
        must be empty; all buffer arguments are computed internally.
        """
        assert len(kwargs) == 0, (
            "DeepEPHTAll2AllManager expects no arguments. All the required "
            "args are computed in the Manager itself.")

        import deep_ep

        buffer_kwargs = self._make_all2all_kwargs()
        logger.debug("DeepEP all2all args %s", buffer_kwargs)
        buffer: deep_ep.Buffer = self.handle_cache.get_or_create(
            buffer_kwargs, deep_ep.Buffer)

        # It is dangerous to set num sms outside this function. num_sms is not
        # a part of the hash-key that identifies this object. If we are in a
        # situation where we make objects with different num_sms, the hash key
        # in get_or_create must be updated.
        buffer.set_num_sms(self.num_sms)
        return buffer
class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
    """
    All2All communication based on DeepEP Low-Latency kernels.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def _make_all2all_kwargs(
        self,
        max_num_tokens_per_dp_rank: int,
        token_hidden_size: int,
        num_ep_ranks: int,
        num_global_experts: int,
        num_local_experts: int,
    ) -> dict[Any, Any]:
        """
        Return the keyword arguments used to construct the
        ``deep_ep.Buffer`` for low-latency mode.

        max_num_tokens_per_dp_rank: the maximum number of tokens a DP rank
            can dispatch; all the ranks must hold the same value.
        token_hidden_size: the hidden dimension of each token.
        num_ep_ranks: the number of EP group ranks.
        num_global_experts: Number of experts in the model.
        num_local_experts: Number of experts in an EP rank.
        """
        import deep_ep

        # Defaults for internode and intranode are taken from DeepEP tests.
        one_gib = 1024 * 1024 * 1024
        if self.internode:
            rdma_bytes = one_gib
        else:
            assert self.intranode
            rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
                num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
                hidden=token_hidden_size,
                num_ranks=num_ep_ranks,
                num_experts=num_global_experts)
        assert rdma_bytes is not None

        return {
            "group": self.cpu_group,
            "num_nvl_bytes": one_gib,
            "num_rdma_bytes": rdma_bytes,
            "low_latency_mode": True,
            "num_qps_per_rank": num_local_experts,
        }

    def get_handle(self, kwargs):
        """
        The kwargs for DeepEPLLAll2AllManager is dictated by
        _make_all2all_kwargs.
        """
        import deep_ep

        buffer_kwargs = self._make_all2all_kwargs(**kwargs)
        logger.debug("DeepEP all2all args %s", buffer_kwargs)
        buffer: deep_ep.Buffer = self.handle_cache.get_or_create(
            buffer_kwargs, deep_ep.Buffer)

        # It is dangerous to set num sms outside this function. num_sms is not
        # a part of the hash-key that identifies this object. If we are in a
        # situation where we make objects with different num_sms, the hash key
        # in get_or_create must be updated.
        buffer.set_num_sms(self.num_sms)
        return buffer
...@@ -67,6 +67,14 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -67,6 +67,14 @@ class CudaCommunicator(DeviceCommunicatorBase):
from .all2all import PPLXAll2AllManager from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group) self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
logger.info("Using PPLX all2all manager.") logger.info("Using PPLX all2all manager.")
elif all2all_backend == "deepep_high_throughput":
from .all2all import DeepEPHTAll2AllManager
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
logger.info("Using DeepEP High-Throughput all2all manager.")
elif all2all_backend == "deepep_low_latency":
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
logger.info("Using DeepEP Low-Latency all2all manager.")
else: else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}") raise ValueError(f"Unknown all2all backend: {all2all_backend}")
......
...@@ -826,6 +826,8 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -826,6 +826,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Available options: # Available options:
# - "naive": naive all2all implementation using all-reduce # - "naive": naive all2all implementation using all-reduce
# - "pplx": use pplx kernels # - "pplx": use pplx kernels
# - "deepep_high_throughput": use DeepEP high-throughput kernels
# - "deepep_low_latency": use DeepEP low-latency kernels
"VLLM_ALL2ALL_BACKEND": "VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"), lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
......
...@@ -12,8 +12,8 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( ...@@ -12,8 +12,8 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute) _moe_permute)
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP) MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache) _resize_cache, per_token_group_quant_fp8)
from vllm.utils import round_up from vllm.utils import round_up
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -34,10 +34,8 @@ def _valid_deep_gemm_shape(M: int, N: int, K: int): ...@@ -34,10 +34,8 @@ def _valid_deep_gemm_shape(M: int, N: int, K: int):
return align <= M and N % align == 0 and K % align == 0 return align <= M and N % align == 0 and K % align == 0
def _valid_deep_gemm(hidden_states: torch.Tensor, def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor) -> bool:
w2: torch.Tensor,
expert_map: Optional[torch.Tensor] = None) -> bool:
""" """
Check if the given problem size is supported by the DeepGemm grouped Check if the given problem size is supported by the DeepGemm grouped
gemm kernel. All of M, N, K and the quantization block_shape must be gemm kernel. All of M, N, K and the quantization block_shape must be
...@@ -47,10 +45,6 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, ...@@ -47,10 +45,6 @@ def _valid_deep_gemm(hidden_states: torch.Tensor,
logger.debug("DeepGemm disabled: deep_gemm not available.") logger.debug("DeepGemm disabled: deep_gemm not available.")
return False return False
if expert_map is not None:
logger.debug("DeepGemm disabled: expert map NYI.")
return False
M = hidden_states.size(0) M = hidden_states.size(0)
_, K, N = w2.size() _, K, N = w2.size()
if not _valid_deep_gemm_shape(M, N, K): if not _valid_deep_gemm_shape(M, N, K):
...@@ -116,7 +110,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -116,7 +110,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a1q = hidden_states a1q = hidden_states
_, N, K = w1.size() _, N, K = w1.size()
assert global_num_experts != -1 if global_num_experts == -1:
global_num_experts = w1.size(0)
assert w2.size(1) == K assert w2.size(1) == K
a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute( a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute(
...@@ -128,6 +124,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -128,6 +124,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.block_shape[0], self.block_shape[0],
) )
if expert_map is not None:
# DeepGemm (Grouped Contiguous) kernel needs a valid B index
# for all rows of A. To that effect, simply compute with
# the 0th weight matrix.
# Note that this relies on the fact that corresponding topk
# weights would be 0 during weight multiplication.
expert_ids = torch.where(expert_ids == -1, 0, expert_ids)
# Note: M_sum is different than the pre-permuted shape of a1q. # Note: M_sum is different than the pre-permuted shape of a1q.
M_sum = a1q.size(0) M_sum = a1q.size(0)
workspace1 = _resize_cache(workspace13, (M_sum, N)) workspace1 = _resize_cache(workspace13, (M_sum, N))
...@@ -140,9 +144,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -140,9 +144,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.activation(activation, workspace2, workspace1.view(-1, N)) self.activation(activation, workspace2, workspace1.view(-1, N))
a2q_scale: Optional[torch.Tensor] = None a2q_scale: Optional[torch.Tensor] = None
a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, False, self.block_shape[1],
self.block_shape) column_major_scales=True)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids) (a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import deep_ep
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    """
    Prepare/Finalize stage built on the DeepEP High-Throughput kernels.

    ``prepare`` quantizes the activations (when configured) and dispatches
    them through the DeepEP buffer. ``finalize`` applies the topk weights
    locally, combines the expert outputs through the same buffer and copies
    the result into the caller-provided output tensor.
    """

    def __init__(self,
                 buffer: deep_ep.Buffer,
                 world_size: int,
                 rank: int,
                 dp_size: int,
                 rank_expert_offset: int,
                 quant_dtype: Optional[torch.dtype] = None,
                 block_shape: Optional[list[int]] = None):
        super().__init__()
        self.buffer = buffer
        self.world_size = world_size
        self.rank = rank
        self.dp_size = dp_size
        self.rank_expert_offset = rank_expert_offset
        self.quant_dtype = quant_dtype
        self.block_shape = block_shape

        # buffer.dispatch hands back a handle that buffer.combine needs
        # later on; it is parked on the instance between the two calls.
        self.handle = None

        # From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
        self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]

    def max_num_tokens_per_rank(self) -> Optional[int]:
        """This implementation reports no per-rank token limit."""
        return None

    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        """The topk indices are expected as int64."""
        return torch.int64

    def _get_dispatch_config(self) -> Optional[deep_ep.Config]:
        """Return DeepEP's dispatch config for dp_size, or None when dp_size
        is not one of the listed rank configurations."""
        if self.dp_size in self.available_rank_configs:
            return deep_ep.Buffer.get_dispatch_config(self.dp_size)
        return None

    def _get_combine_config(self) -> Optional[deep_ep.Config]:
        """Return DeepEP's combine config for dp_size, or None when dp_size
        is not one of the listed rank configurations."""
        if self.dp_size in self.available_rank_configs:
            return deep_ep.Buffer.get_combine_config(self.dp_size)
        return None

    def _do_quant(self, tokens: torch.Tensor,
                  token_scales: Optional[torch.Tensor], per_act_token: bool):
        """Quantize ``tokens`` with this instance's quant_dtype/block_shape
        and return the (tokens, scales) pair."""
        return moe_kernel_quantize_input(tokens, token_scales,
                                         self.quant_dtype, per_act_token,
                                         self.block_shape)

    def _do_dispatch(self, tokens: torch.Tensor,
                     token_scales: Optional[torch.Tensor],
                     rank_topk_ids: torch.Tensor,
                     rank_topk_weights: torch.Tensor, num_experts: int):
        """
        Dispatch the tokens (and their scales, if any) through the buffer.

        Stores the dispatch handle in ``self.handle`` and returns
        ``(expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
        expert_topk_weights)``; ``expert_x_scale`` is None when no scales
        were dispatched.
        """
        has_scales = token_scales is not None

        layout = self.buffer.get_dispatch_layout(
            topk_idx=rank_topk_ids,
            num_experts=num_experts,
            previous_event=None,
            async_finish=False,
            allocate_on_comm_stream=False)
        (num_tokens_per_rank, num_tokens_per_rdma_rank, expert_num_tokens,
         is_token_in_rank, _layout_event) = layout

        # Scales travel together with the tokens as a tuple.
        payload = (tokens, token_scales) if has_scales else tokens

        (
            dispatched, expert_topk_ids, expert_topk_weights,
            _num_tokens_per_expert_list, self.handle, _dispatch_event
        ) = self.buffer.dispatch(
            x=payload,
            handle=None,
            num_tokens_per_rank=num_tokens_per_rank,
            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
            is_token_in_rank=is_token_in_rank,
            num_tokens_per_expert=expert_num_tokens,
            topk_idx=rank_topk_ids,
            topk_weights=rank_topk_weights,
            # expert_alignment rounds the number of tokens per expert
            # to this value.
            expert_alignment=1,
            config=self._get_dispatch_config(),
            previous_event=None,
            async_finish=False,
            allocate_on_comm_stream=False)

        if has_scales:
            expert_x, expert_x_scale = dispatched
        else:
            expert_x = dispatched
            expert_x_scale = None

        # The existing MOE kernels assume every topk_ids entry is valid, so
        # the -1 entries are pointed at an expert that does not live on this
        # rank; the expert_map can then remap them to -1 when safe. Experts
        # are split sequentially across ranks under Expert Parallel, hence
        # rank 0 uses num_experts - 1 and every other rank uses 0 - both are
        # known to map to -1 for the respective ranks.
        #
        # DeepEP's topk_ids output indexes the local experts directly; adding
        # the rank's offset moves the valid entries back into the global
        # expert space that the vLLM interfaces work in.
        invalid_fill = num_experts - 1 if self.rank_expert_offset == 0 else 0
        expert_topk_ids = torch.where(
            expert_topk_ids == -1, invalid_fill,
            expert_topk_ids + self.rank_expert_offset)

        return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
                expert_topk_weights)

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        rank_topk_weights: torch.Tensor,
        rank_topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
               Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Quantize and dispatch the activations.

        With per-token (or block) quantization the activations are quantized
        first and dispatched together with their scales. Otherwise they are
        dispatched unquantized and quantized after the dispatch.

        Returns ``(expert_x, expert_x_scale, expert_num_tokens,
        expert_topk_ids, expert_topk_weights)``.
        """
        if apply_router_weight_on_input:
            topk = rank_topk_ids.size(1)
            # TODO: this only works for topK=1, will need to update for topK>1
            assert topk == 1, (
                "apply_router_weight_on_input is only implemented for topk=1")
            a1 = a1 * rank_topk_weights.to(a1.dtype)

        # Decide the quantization scheme from block_shape, or infer it from
        # the scales that were passed in.
        no_quant_hints = (self.block_shape is None and a1_scale is None
                          and a2_scale is None)
        if no_quant_hints and self.quant_dtype is not None:
            # Quantization is required although nothing in the inputs
            # suggests a scheme. Fall back to per-token dynamic quant.
            per_token_quant = True
        else:
            a1_is_per_token = a1_scale is not None and a1_scale.numel() != 1
            a2_is_per_token = a2_scale is not None and a2_scale.numel() != 1
            per_token_quant = (self.block_shape is not None
                               or a1_is_per_token or a2_is_per_token)

        if per_token_quant:
            a1q, a1q_scale = self._do_quant(a1, a1_scale, per_act_token=True)
            (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
             expert_topk_weights) = self._do_dispatch(
                 tokens=a1q,
                 token_scales=a1q_scale,
                 rank_topk_ids=rank_topk_ids,
                 rank_topk_weights=rank_topk_weights,
                 num_experts=num_experts)
            return (expert_x, expert_x_scale, expert_num_tokens,
                    expert_topk_ids, expert_topk_weights)

        # DeepEP kernels only support dispatching per-token-quant
        # quantization, so dispatch in bfloat16 and quantize afterwards.
        (expert_x, _, expert_num_tokens, expert_topk_ids,
         expert_topk_weights) = self._do_dispatch(
             tokens=a1,
             token_scales=None,
             rank_topk_ids=rank_topk_ids,
             rank_topk_weights=rank_topk_weights,
             num_experts=num_experts)

        expert_x_scale = None
        # Skip the quantization when no tokens were received.
        if expert_x.numel() != 0:
            expert_x, expert_x_scale = self._do_quant(expert_x,
                                                      a1_scale,
                                                      per_act_token=False)

        return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
                expert_topk_weights)

    def _apply_weights_and_reduce(self, num_tokens: int,
                                  fused_expert_output: torch.Tensor,
                                  topk_weights: torch.Tensor,
                                  apply_router_weight_on_input: bool,
                                  output_dtype: torch.dtype):
        """Scale the expert outputs by the topk weights (unless they were
        already applied on the input) and sum over dim 1."""
        if fused_expert_output.ndim == 2:
            hidden_dim = fused_expert_output.size(-1)
            fused_expert_output = fused_expert_output.view(
                num_tokens, -1, hidden_dim)

        if not apply_router_weight_on_input:
            # The DeepEP combine kernels don't do the topk weight
            # multiplication, so the weights are multiplied in locally
            # (in float32) before converting to the output dtype.
            weights = topk_weights.view(fused_expert_output.size(0), -1, 1)
            weighted = fused_expert_output.to(torch.float32) * weights
            fused_expert_output = weighted.to(output_dtype)

        return fused_expert_output.sum(dim=1).to(output_dtype)

    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 apply_router_weight_on_input: bool) -> None:
        """Combine the expert outputs through the DeepEP buffer and copy the
        combined result into ``output``."""
        assert self.handle is not None

        # fused_expert_output can have 0 tokens - this happens when none of
        # the tokens from the all2all reach this EP rank.
        if fused_expert_output.numel() != 0:
            fused_expert_output = self._apply_weights_and_reduce(
                num_tokens=topk_ids.size(0),
                fused_expert_output=fused_expert_output,
                topk_weights=topk_weights,
                apply_router_weight_on_input=apply_router_weight_on_input,
                output_dtype=output.dtype)

        combined_x, _, _event = self.buffer.combine(
            x=fused_expert_output,
            handle=self.handle,
            topk_weights=None,
            config=self._get_combine_config(),
            previous_event=None,
            async_finish=False,
            allocate_on_comm_stream=False)

        # Respect inplace outputs.
        output.copy_(combined_x, non_blocking=True)
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import deep_ep
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
DEEPEP_QUANT_BLOCK_SIZE = 128


def dequant_fp8(expert_x_fp8: torch.Tensor,
                expert_x_scales: torch.Tensor) -> torch.Tensor:
    """
    Dequantize ``expert_x_fp8`` and return the result in fp32.

    The input is regrouped into chunks of DEEPEP_QUANT_BLOCK_SIZE elements,
    each chunk is multiplied by its scale, and the product is returned in
    the shape of ``expert_x_fp8``.
    """
    # TODO (varun) : Optimize leverage num_tokens_per_expert counts
    assert expert_x_fp8.is_contiguous()

    leading = expert_x_fp8.size(0)
    # One row per quantization chunk.
    chunks = expert_x_fp8.to(torch.float32).view(leading, -1,
                                                 DEEPEP_QUANT_BLOCK_SIZE)
    # One scale per chunk, broadcast across the chunk's elements.
    chunk_scales = expert_x_scales.contiguous().view(leading, -1, 1)

    dequantized = chunks * chunk_scales
    return dequantized.view(expert_x_fp8.shape)
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    """
    Prepare/Finalize using DeepEP low-latency kernels.

    ``prepare`` dispatches the activations through the buffer's low-latency
    dispatch and quantizes what it receives; ``finalize`` runs the
    low-latency combine, writing directly into the caller's output tensor.
    """

    # DeepEP low-latency kernels are compiled only for certain
    # specific hidden sizes.
    SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168]

    def __init__(self,
                 buffer: deep_ep.Buffer,
                 world_size: int,
                 dp_size: int,
                 max_tokens_per_rank: int,
                 quant_dtype: Optional[torch.dtype] = None,
                 block_shape: Optional[list[int]] = None,
                 use_fp8_dispatch: bool = False):
        super().__init__()

        self.buffer = buffer
        self.world_size = world_size
        self.dp_size = dp_size
        self.quant_dtype = quant_dtype
        self.block_shape = block_shape
        self.max_tokens_per_rank = max_tokens_per_rank
        self.use_fp8_dispatch = use_fp8_dispatch
        # The dispatch function returns a handle that the combine function
        # requires. We store the handle here so it is available to the
        # combine function.
        self.handle = None

    def max_num_tokens_per_rank(self) -> Optional[int]:
        """Return the per-rank token limit given at construction."""
        return self.max_tokens_per_rank

    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        """The topk indices are expected as int64."""
        return torch.int64

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        rank_topk_weights: torch.Tensor,
        rank_topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
               Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Dispatch the activations and quantize the received tokens.

        Returns ``(expert_x, expert_x_scale, expert_num_tokens, None, None)``
        where ``expert_x`` is viewed as
        (expert_x.size(0) after dispatch, -1, hidden_dim).

        Raises AssertionError when the hidden size is unsupported, when
        per-act-token quantization is requested, or when
        apply_router_weight_on_input is used with topk != 1.
        """
        hidden_size = a1.size(1)
        # Note the trailing space: the two adjacent literals are joined
        # verbatim, so without it the list is glued onto the last word.
        assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, (
            f"Hidden Size {hidden_size} not in supported list of hidden sizes "
            f"{self.SUPPORTED_HIDDEN_SIZES}")

        if self.use_fp8_dispatch:
            assert hidden_size % 128 == 0, \
                "DeepEP kernels quantize the inputs in blocks of shape 128"

        # Quantize
        # A scale holding more than one element is treated as per-act-token;
        # a1_scale takes precedence over a2_scale when both are given.
        per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
            a2_scale.numel() != 1 if a2_scale is not None else False)
        assert not per_act_token, (
            "low_latency kernels don't support per-act-token quant")

        if apply_router_weight_on_input:
            topk = rank_topk_ids.size(1)
            # TODO: this only works for topK=1, will need to update for topK>1
            assert topk == 1, (
                "apply_router_weight_on_input is only implemented for topk=1")
            a1 = a1 * rank_topk_weights.to(a1.dtype)

        # Dispatch
        expert_x, expert_num_tokens, self.handle, _event, _hook = \
            self.buffer.low_latency_dispatch(a1,
                                             rank_topk_ids,
                                             self.max_tokens_per_rank,
                                             num_experts,
                                             use_fp8=self.use_fp8_dispatch,
                                             async_finish=False,
                                             return_recv_hook=False)

        if self.use_fp8_dispatch:
            # TODO (varun) : In the case of dynamic quantization, we could
            # probably skip the quant below and use the results directly.
            # Although note that the deepep quant is per token 128 elements.
            expert_x_fp8, expert_x_scales = expert_x
            expert_x = dequant_fp8(expert_x_fp8,
                                   expert_x_scales).to(dtype=a1.dtype)

        # From here on num_experts is the leading dimension of the
        # dispatched tensor, not the value passed in by the caller.
        num_experts = expert_x.size(0)
        hidden_dim = expert_x.size(-1)

        # Flatten to 2D for the quantization helper, then restore the
        # per-expert layout.
        expert_x = expert_x.view((-1, expert_x.size(-1)))
        expert_x, expert_x_scale = moe_kernel_quantize_input(
            expert_x, a1_scale, self.quant_dtype, per_act_token,
            self.block_shape)
        expert_x = expert_x.view((num_experts, -1, hidden_dim))

        return (expert_x, expert_x_scale, expert_num_tokens, None, None)

    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 apply_router_weight_on_input: bool) -> None:
        """Combine the expert outputs through the low-latency combine kernel,
        passing ``output`` as the kernel's ``out`` tensor."""
        assert self.handle is not None

        combine_topk_weights = topk_weights
        if apply_router_weight_on_input:
            # weights have already been applied.
            combine_topk_weights = torch.ones_like(topk_weights)

        # TODO (varun) : Enable zero copy mode
        _, _event, _hook = self.buffer.low_latency_combine(
            fused_expert_output,
            topk_ids,
            combine_topk_weights,
            self.handle,
            async_finish=False,
            zero_copy=False,
            return_recv_hook=False,
            out=output)
...@@ -10,7 +10,8 @@ import triton.language as tl ...@@ -10,7 +10,8 @@ import triton.language as tl
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config) get_config_dtype_str, try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input)
@triton.jit @triton.jit
...@@ -397,6 +398,12 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -397,6 +398,12 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.rank = rank self.rank = rank
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
def max_num_tokens_per_rank(self) -> Optional[int]:
return self.max_num_tokens
def topk_indices_dtype(self) -> Optional[torch.dtype]:
return None
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
...@@ -407,7 +414,8 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -407,7 +414,8 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts: int, num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
assert a1.dim() == 2 assert a1.dim() == 2
assert topk_ids.dim() == 2 assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0) assert topk_ids.size(0) == a1.size(0)
...@@ -450,7 +458,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -450,7 +458,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
first_expert, :rows, :] = a1[:topks.numel()][topks] first_expert, :rows, :] = a1[:topks.numel()][topks]
tokens_per_expert[expert_id - first_expert] = rows tokens_per_expert[expert_id - first_expert] = rows
return b_a1, a1_scale, tokens_per_expert return b_a1, a1_scale, tokens_per_expert, None, None
def finalize( def finalize(
self, self,
...@@ -601,6 +609,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -601,6 +609,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
world_size: int = 1, world_size: int = 1,
dp_size: int = 1, dp_size: int = 1,
...@@ -611,12 +620,15 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -611,12 +620,15 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.use_int4_w4a16 = use_int4_w4a16 self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a16 = use_int8_w8a16 self.use_int8_w8a16 = use_int8_w8a16
self.block_shape = block_shape self.block_shape = block_shape
self.per_channel_quant = per_channel_quant
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
assert not use_int8_w8a8, "NYI"
assert not use_int4_w4a16, "NYI"
self.world_size = world_size self.world_size = world_size
self.dp_size = dp_size self.dp_size = dp_size
assert not use_int8_w8a8, "NYI"
assert not use_int4_w4a16, "NYI"
assert self.block_shape is None, "NYI"
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
...@@ -670,8 +682,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -670,8 +682,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
] ]
# TODO: num_tokens -> max_num_tokens? E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
hidden_states, w1, w2, topk_ids) hidden_states, w1, w2, topk_ids)
assert w1.size(0) == E assert w1.size(0) == E
...@@ -687,7 +698,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -687,7 +698,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2.size(), w2.size(),
top_k_num, top_k_num,
config_dtype, config_dtype,
num_tokens, max_num_tokens,
block_shape=self.block_shape, block_shape=self.block_shape,
) )
...@@ -706,10 +717,12 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -706,10 +717,12 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
#print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}") #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
# We can reuse the memory between these because by the time we need # We can reuse the memory between these because by the time we need
# cache3, we're done with cache1 # cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N)) intermediate_cache1 = _resize_cache(workspace13,
(E, max_num_tokens, N))
intermediate_cache2 = _resize_cache(workspace2, intermediate_cache2 = _resize_cache(workspace2,
(E, num_tokens, N // 2)) (E, max_num_tokens, N // 2))
intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K)) intermediate_cache3 = _resize_cache(workspace13,
(E, max_num_tokens, K))
# MM1 # MM1
invoke_moe_batched_triton_kernel(A=hidden_states, invoke_moe_batched_triton_kernel(A=hidden_states,
...@@ -731,15 +744,20 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -731,15 +744,20 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.activation(activation, intermediate_cache2.view(-1, N // 2), self.activation(activation, intermediate_cache2.view(-1, N // 2),
intermediate_cache1.view(-1, N)) intermediate_cache1.view(-1, N))
#qintermediate_cache2 = intermediate_cache2 ic2_hidden_size = intermediate_cache2.size(-1)
a2q_scale = a2_scale intermediate_cache2 = intermediate_cache2.view(-1, ic2_hidden_size)
# TODO (varun) : support w8a8
assert not self.use_fp8_w8a8 qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
#if self.use_fp8_w8a8: A=intermediate_cache2,
# qintermediate_cache2, a2q_scale = _fp8_quantize( A_scale=a2_scale,
# intermediate_cache2, a2_scale, self.block_shape) qtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else None,
per_channel_quant=self.per_channel_quant,
block_shape=self.block_shape)
qintermediate_cache2 = qintermediate_cache2.view(
(E, -1, ic2_hidden_size))
invoke_moe_batched_triton_kernel(A=intermediate_cache2, invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
B=w2, B=w2,
C=intermediate_cache3, C=intermediate_cache3,
expert_num_tokens=expert_num_tokens, expert_num_tokens=expert_num_tokens,
...@@ -752,5 +770,4 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -752,5 +770,4 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int4_w4a16=self.use_int4_w4a16, use_int4_w4a16=self.use_int4_w4a16,
config=config, config=config,
block_shape=self.block_shape) block_shape=self.block_shape)
return intermediate_cache3 return intermediate_cache3
...@@ -1164,7 +1164,7 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1164,7 +1164,7 @@ def fused_experts(hidden_states: torch.Tensor,
# permute/unpermute ops are available. # permute/unpermute ops are available.
N = w1.shape[1] N = w1.shape[1]
if (allow_deep_gemm and use_fp8_w8a8 and N > 512 if (allow_deep_gemm and use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): and _valid_deep_gemm(hidden_states, w1, w2)):
assert apply_router_weight_on_input is False assert apply_router_weight_on_input is False
return deep_gemm_moe_fp8( return deep_gemm_moe_fp8(
hidden_states=hidden_states, hidden_states=hidden_states,
......
...@@ -5,7 +5,7 @@ import importlib ...@@ -5,7 +5,7 @@ import importlib
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -30,16 +30,19 @@ from vllm.platforms.interface import CpuArchEnum ...@@ -30,16 +30,19 @@ from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
has_pplx = importlib.util.find_spec("pplx_kernels") is not None has_pplx = importlib.util.find_spec("pplx_kernels") is not None
has_deepep = importlib.util.find_spec("deep_ep") is not None
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .fused_batched_moe import (BatchedPrepareAndFinalize, from .fused_batched_moe import BatchedTritonExperts
BatchedTritonExperts)
from .fused_moe import TritonExperts, fused_experts from .fused_moe import TritonExperts, fused_experts
from .modular_kernel import (FusedMoEModularKernel, from .modular_kernel import (FusedMoEModularKernel,
FusedMoEPermuteExpertsUnpermute, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize) FusedMoEPrepareAndFinalize)
if has_pplx: if has_pplx:
from .pplx_prepare_finalize import PplxPrepareAndFinalize from .pplx_prepare_finalize import PplxPrepareAndFinalize
if has_deepep:
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize
else: else:
fused_experts = None # type: ignore fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore
...@@ -71,10 +74,24 @@ class FusedMoEParallelConfig: ...@@ -71,10 +74,24 @@ class FusedMoEParallelConfig:
use_ep: bool # whether to use EP or not use_ep: bool # whether to use EP or not
@property
def use_all2all_kernels(self):
return self.dp_size > 1 and self.use_ep
@property @property
def use_pplx_kernels(self): def use_pplx_kernels(self):
return self.dp_size > 1 and self.use_ep and \ return (self.use_all2all_kernels
envs.VLLM_ALL2ALL_BACKEND == "pplx" and envs.VLLM_ALL2ALL_BACKEND == "pplx")
@property
def use_deepep_ht_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput")
@property
def use_deepep_ll_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
@staticmethod @staticmethod
def make(tp_size_: int, dp_size_: int, def make(tp_size_: int, dp_size_: int,
...@@ -231,6 +248,14 @@ class MoEConfig: ...@@ -231,6 +248,14 @@ class MoEConfig:
def use_pplx_kernels(self): def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels return self.moe_parallel_config.use_pplx_kernels
@property
def use_deepep_ht_kernels(self):
return self.moe_parallel_config.use_deepep_ht_kernels
@property
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
class FusedMoeWeightScaleSupported(Enum): class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor" TENSOR = "tensor"
...@@ -252,7 +277,16 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -252,7 +277,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
all2all_manager = get_ep_group().device_communicator.all2all_manager all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None assert all2all_manager is not None
prepare_finalize = None quant_dtype = None
act_quant_block_size = None
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
if isinstance(quant_config, Fp8Config):
act_quant_block_size = quant_config.weight_block_size
quant_dtype = torch.float8_e4m3fn
prepare_finalize: Optional[Union[PplxPrepareAndFinalize,
DeepEPHTPrepareAndFinalize,
DeepEPLLPrepareAndFinalize]] = None
if moe.use_pplx_kernels: if moe.use_pplx_kernels:
all_to_all_args = dict( all_to_all_args = dict(
max_num_tokens=moe.max_num_tokens, max_num_tokens=moe.max_num_tokens,
...@@ -288,8 +322,49 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -288,8 +322,49 @@ class FusedMoEMethodBase(QuantizeMethodBase):
dp_size=all2all_manager.tp_group.world_size, dp_size=all2all_manager.tp_group.world_size,
quant_dtype=moe.in_dtype, quant_dtype=moe.in_dtype,
) )
elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size
all_to_all_args = dict()
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = DeepEPHTPrepareAndFinalize(
handle,
world_size=all2all_manager.world_size,
rank=all2all_manager.rank,
dp_size=all2all_manager.dp_world_size,
rank_expert_offset=all2all_manager.rank *
moe.num_local_experts,
quant_dtype=quant_dtype,
block_shape=act_quant_block_size,
)
elif moe.use_deepep_ll_kernels:
assert moe.dp_size == all2all_manager.dp_world_size
all_to_all_args = dict(
max_num_tokens_per_dp_rank=moe.max_num_tokens,
token_hidden_size=moe.hidden_dim,
num_ep_ranks=all2all_manager.world_size,
num_global_experts=moe.num_experts,
num_local_experts=moe.num_experts //
all2all_manager.world_size)
handle = all2all_manager.get_handle(all_to_all_args)
# Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now.
prepare_finalize = DeepEPLLPrepareAndFinalize(
handle,
world_size=all2all_manager.world_size,
dp_size=all2all_manager.dp_world_size,
max_tokens_per_rank=moe.max_num_tokens,
quant_dtype=quant_dtype,
block_shape=act_quant_block_size,
use_fp8_dispatch=False,
)
self.topk_indices_dtype = None
if prepare_finalize is not None: if prepare_finalize is not None:
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize) experts = self.select_gemm_impl(prepare_finalize)
self.fused_experts = FusedMoEModularKernel( self.fused_experts = FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
...@@ -297,7 +372,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -297,7 +372,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
) )
def select_gemm_impl( def select_gemm_impl(
self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize] self, prepare_finalize: FusedMoEPrepareAndFinalize
) -> FusedMoEPermuteExpertsUnpermute: ) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate # based on the all2all implementation, select the appropriate
# gemm implementation # gemm implementation
...@@ -334,6 +409,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -334,6 +409,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def __init__(self, moe: MoEConfig): def __init__(self, moe: MoEConfig):
super().__init__() super().__init__()
self.fused_experts = fused_experts # type: ignore self.fused_experts = fused_experts # type: ignore
self.topk_indices_dtype = None
self.moe = moe self.moe = moe
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
...@@ -343,8 +419,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -343,8 +419,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else: else:
self.rocm_aiter_fused_experts = None # type: ignore self.rocm_aiter_fused_experts = None # type: ignore
def select_gemm_impl( def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize):
self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]):
assert self.fused_experts == fused_experts assert self.fused_experts == fused_experts
...@@ -353,11 +428,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -353,11 +428,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
if isinstance(prepare_finalize, use_batched_experts = prepare_finalize.max_num_tokens_per_rank(
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): ) is not None
if use_batched_experts:
logger.debug("BatchedTritonExperts %s", self.moe) logger.debug("BatchedTritonExperts %s", self.moe)
assert self.moe.dp_size == all2all_manager.dp_world_size
experts = BatchedTritonExperts( experts = BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE, max_num_tokens=self.moe.max_num_tokens,
world_size=all2all_manager.world_size, world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels # dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size, dp_size=all2all_manager.tp_group.world_size,
...@@ -366,6 +443,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -366,6 +443,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_int8_w8a16=False, use_int8_w8a16=False,
use_int4_w4a16=False, use_int4_w4a16=False,
block_shape=None, block_shape=None,
per_channel_quant=False,
) )
else: else:
logger.debug("TritonExperts %s", self.moe) logger.debug("TritonExperts %s", self.moe)
...@@ -494,6 +572,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -494,6 +572,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
...@@ -505,7 +584,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -505,7 +584,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=torch.uint32 if self.moe.use_pplx_kernels else None) indices_type=self.topk_indices_dtype)
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
assert expert_map is None assert expert_map is None
...@@ -806,11 +885,8 @@ class FusedMoE(torch.nn.Module): ...@@ -806,11 +885,8 @@ class FusedMoE(torch.nn.Module):
# Note: get_quant_method will look at the layer's local_num_experts # Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first. # for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
if quant_config is None: else quant_config.get_quant_method(self, prefix))
quant_method = UnquantizedFusedMoEMethod(moe)
else:
quant_method = quant_config.get_quant_method(self, prefix)
assert quant_method is not None assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase) assert isinstance(quant_method, FusedMoEMethodBase)
...@@ -836,7 +912,8 @@ class FusedMoE(torch.nn.Module): ...@@ -836,7 +912,8 @@ class FusedMoE(torch.nn.Module):
# Chunked all2all staging tensor # Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None
if self.moe_parallel_config.use_pplx_kernels: if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
act_dtype = vllm_config.model_config.dtype act_dtype = vllm_config.model_config.dtype
self.batched_hidden_states = torch.zeros( self.batched_hidden_states = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.hidden_size), (MOE_DP_CHUNK_SIZE, self.hidden_size),
...@@ -880,6 +957,14 @@ class FusedMoE(torch.nn.Module): ...@@ -880,6 +957,14 @@ class FusedMoE(torch.nn.Module):
def use_pplx_kernels(self): def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels return self.moe_parallel_config.use_pplx_kernels
@property
def use_deepep_ht_kernels(self):
    """
    Expose the `use_deepep_ht_kernels` flag of the layer's
    `moe_parallel_config` (the DeepEP high-throughput all2all kernels).
    """
    parallel_config = self.moe_parallel_config
    return parallel_config.use_deepep_ht_kernels
@property
def use_deepep_ll_kernels(self):
    """
    Expose the `use_deepep_ll_kernels` flag of the layer's
    `moe_parallel_config` (the DeepEP low-latency all2all kernels).
    """
    parallel_config = self.moe_parallel_config
    return parallel_config.use_deepep_ll_kernels
def _load_per_tensor_weight_scale(self, shard_id: str, def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
...@@ -1210,19 +1295,21 @@ class FusedMoE(torch.nn.Module): ...@@ -1210,19 +1295,21 @@ class FusedMoE(torch.nn.Module):
When just tensor-parallel is used, it is not required to reduce When just tensor-parallel is used, it is not required to reduce
the shared_experts results immediately. Instead we reduce at the the shared_experts results immediately. Instead we reduce at the
once at the end of the MoE op. (Refer to DeepSeekV2MoE module) once at the end of the MoE op. (Refer to DeepSeekV2MoE module)
With EP and the pplx kernels - this is no longer viable as all With EP and all2all kernels - this is no longer viable as all
GPU ranks in DP, produce the complete set of hidden_states. GPU ranks in DP, produce the complete set of hidden_states.
Therefore it is required that we reduce the shared_experts output Therefore it is required that we reduce the shared_experts output
early. early.
""" """
return self.use_pplx_kernels return (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels)
def maybe_all_reduce_tensor_model_parallel( def maybe_all_reduce_tensor_model_parallel(
self, final_hidden_states: torch.Tensor): self, final_hidden_states: torch.Tensor):
""" """
The pplx combine kernel reduces across GPU ranks by default. The pplx combine kernel reduces across GPU ranks by default.
""" """
if self.use_pplx_kernels: if (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels):
return final_hidden_states return final_hidden_states
else: else:
return tensor_model_parallel_all_reduce(final_hidden_states) return tensor_model_parallel_all_reduce(final_hidden_states)
...@@ -1289,7 +1376,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1289,7 +1376,7 @@ class FusedMoE(torch.nn.Module):
ctx = get_forward_context() ctx = get_forward_context()
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
num_tokens = full_hidden_states.size(0) num_tokens = full_hidden_states.size(0)
for chunk_start_ in range(0, max_tokens_across_dp, for chunk_start_ in range(0, max_tokens_across_dp,
...@@ -1310,12 +1397,17 @@ class FusedMoE(torch.nn.Module): ...@@ -1310,12 +1397,17 @@ class FusedMoE(torch.nn.Module):
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor):
assert self.quant_method is not None assert self.quant_method is not None
if self.moe_parallel_config.use_pplx_kernels: if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
return self.forward_impl_chunked(hidden_states, router_logits) return self.forward_impl_chunked(hidden_states, router_logits)
if self.dp_size > 1: do_naive_dispatch_combine: bool = (
self.dp_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels)
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits) hidden_states, router_logits)
# Matrix multiply. # Matrix multiply.
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
layer=self, layer=self,
...@@ -1335,12 +1427,12 @@ class FusedMoE(torch.nn.Module): ...@@ -1335,12 +1427,12 @@ class FusedMoE(torch.nn.Module):
apply_router_weight_on_input=self.apply_router_weight_on_input, apply_router_weight_on_input=self.apply_router_weight_on_input,
) )
if self.dp_size > 1: if do_naive_dispatch_combine:
final_hidden_states = get_ep_group().combine(final_hidden_states) final_hidden_states = get_ep_group().combine(final_hidden_states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.) # Default set to False. (May have to add shared expert outputs.
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
final_hidden_states) final_hidden_states)
return final_hidden_states return final_hidden_states
......
...@@ -94,7 +94,8 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -94,7 +94,8 @@ class FusedMoEPrepareAndFinalize(ABC):
num_experts: int, num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
""" """
Perform any quantization (and/or) dispatching needed Perform any quantization (and/or) dispatching needed
for this kernel. for this kernel.
...@@ -113,6 +114,10 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -113,6 +114,10 @@ class FusedMoEPrepareAndFinalize(ABC):
Returns a tuple of: Returns a tuple of:
- quantized + dispatched a. - quantized + dispatched a.
- quantized + dispatched a1_scales. - quantized + dispatched a1_scales.
- Optional tensor as big as number of local experts that contains the
number of tokens assigned to each local expert.
- Optional dispatched expert topk IDs
- Optional dispatched expert topk weight
""" """
raise NotImplementedError raise NotImplementedError
...@@ -138,6 +143,27 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -138,6 +143,27 @@ class FusedMoEPrepareAndFinalize(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def topk_indices_dtype(self) -> Optional[torch.dtype]:
    """
    Report the dtype that topk_ids must have for this implementation.

    PrepareFinalize All2All implementations generally accept only a
    particular topk_ids dtype. Callers query this method so that the
    expert indices they produce respect that constraint.

    Returns:
        The required topk indices dtype, or None when the implementation
        places no restriction on it.
    """
    raise NotImplementedError()
@abstractmethod
def max_num_tokens_per_rank(self) -> Optional[int]:
    """
    Report the batch size of this implementation, if it is batched.

    Some PrepareFinalize All2All implementations are batched, meaning
    they can process only a fixed-size set of tokens at a time.

    Returns:
        The maximum number of tokens the implementation can process at a
        time, or None when there is no such restriction.
    """
    raise NotImplementedError()
class FusedMoEPermuteExpertsUnpermute(ABC): class FusedMoEPermuteExpertsUnpermute(ABC):
""" """
...@@ -261,6 +287,61 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -261,6 +287,61 @@ class FusedMoEModularKernel(torch.nn.Module):
self.prepare_finalize = prepare_finalize self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts self.fused_experts = fused_experts
def _do_fused_experts(
        self,
        a1: torch.Tensor,  # input to forward fn
        a1q: torch.Tensor,  # output of prepare fn
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_ids: torch.Tensor,
        expert_num_tokens: Optional[torch.Tensor],
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        w1_scale: Optional[torch.Tensor],
        w2_scale: Optional[torch.Tensor],
        w1_zp: Optional[torch.Tensor],
        w2_zp: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor]) -> torch.Tensor:
    """
    Allocate the workspaces and run the fused-experts computation.

    The problem size is derived from `a1q` (the output of the prepare
    step), while `a1` (the original forward input) is what is handed to
    `workspace_shapes` and supplies the device the workspaces live on.

    Args:
        a1: The hidden states originally passed to `forward`.
        a1q: The (possibly quantized / dispatched) output of `prepare`.
        w1, w2: Expert weights forwarded to the experts implementation.
        topk_ids: Expert ids forwarded to the experts implementation.
        expert_num_tokens: Third element returned by `prepare`. It is
            optional there, so None is accepted here and forwarded
            unchanged.
        activation: Name of the activation function.
        global_num_experts: Expert count passed to `workspace_shapes`
            and `apply`.
        expert_map: Optional expert map forwarded to `apply`.
        w1_scale, w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale: Optional
            quantization tensors forwarded to `apply`.

    Returns:
        The tensor produced by `self.fused_experts.apply`.
    """
    _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)

    # Use a1 here to decipher the correct workspace datatype
    workspace13_shape, workspace2_shape, workspace_dtype = (
        self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
                                            global_num_experts))

    # We can reuse the memory between cache1 and cache3 because by the time
    # we need cache3, we're done with cache1
    workspace13 = torch.zeros(workspace13_shape,
                              device=a1.device,
                              dtype=workspace_dtype)
    workspace2 = torch.zeros(workspace2_shape,
                             device=a1.device,
                             dtype=workspace_dtype)

    return self.fused_experts.apply(
        a1q,
        w1,
        w2,
        topk_ids,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=w1_zp,
        w2_zp=w2_zp,
        a1q_scale=a1q_scale,
        a2_scale=a2_scale,
        workspace13=workspace13,
        workspace2=workspace2,
        expert_num_tokens=expert_num_tokens,
    )
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -315,36 +396,39 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -315,36 +396,39 @@ class FusedMoEModularKernel(torch.nn.Module):
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
""" """
a1 = hidden_states
E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids)
if global_num_experts == -1:
global_num_experts = E
a1 = hidden_states
output = a1 if inplace else torch.zeros_like(a1) output = a1 if inplace else torch.zeros_like(a1)
workspace13_shape, workspace2_shape, workspace_dtype = ( if global_num_experts == -1:
self.fused_experts.workspace_shapes(a1, M, N, K, top_k, global_num_experts = w1.size(0)
global_num_experts))
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
# We can reuse the memory between cache1 and cache3 because by the time _expert_topk_weights) = self.prepare_finalize.prepare(
# we need cache3, we're done with cache1 a1, a1_scale, a2_scale, topk_weights, topk_ids,
workspace13 = torch.zeros(workspace13_shape, global_num_experts, expert_map, apply_router_weight_on_input)
device=a1.device, # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
dtype=workspace_dtype) topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
workspace2 = torch.zeros(workspace2_shape, topk_weights = (topk_weights if _expert_topk_weights is None else
device=a1.device, _expert_topk_weights)
dtype=workspace_dtype)
fused_out = None
a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare( if a1q.numel() == 0:
a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts, # This happens when none of the tokens from the all2all reach this
expert_map, apply_router_weight_on_input) # EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput
fused_out = self.fused_experts.apply( # kernels. CUDAGraph compatible all2all kernels like the pplx
a1q, # kernels and the DeepEP low-latency kernels are always batched
w1, # and can never run into the tensor.numel() == 0 case.
w2, fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
topk_ids, else:
fused_out = self._do_fused_experts(
a1=a1,
a1q=a1q,
w1=w1,
w2=w2,
topk_ids=topk_ids,
expert_num_tokens=expert_num_tokens,
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
...@@ -353,11 +437,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -353,11 +437,7 @@ class FusedMoEModularKernel(torch.nn.Module):
w1_zp=w1_zp, w1_zp=w1_zp,
w2_zp=w2_zp, w2_zp=w2_zp,
a1q_scale=a1q_scale, a1q_scale=a1q_scale,
a2_scale=a2_scale, a2_scale=a2_scale)
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
self.prepare_finalize.finalize(output, fused_out, topk_weights, self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input) topk_ids, apply_router_weight_on_input)
......
...@@ -25,7 +25,7 @@ def _moe_permute( ...@@ -25,7 +25,7 @@ def _moe_permute(
""" """
top_k_num = curr_topk_ids.size(1) top_k_num = curr_topk_ids.size(1)
tokens_in_chunk = curr_hidden_states.sizze(0) tokens_in_chunk = curr_hidden_states.size(0)
sorted_token_ids, expert_ids, num_tokens_post_padded = ( sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, moe_align_block_size(curr_topk_ids,
...@@ -37,11 +37,12 @@ def _moe_permute( ...@@ -37,11 +37,12 @@ def _moe_permute(
inv_perm: Optional[torch.Tensor] = None inv_perm: Optional[torch.Tensor] = None
num_tokens = top_k_num * tokens_in_chunk num_tokens = top_k_num * tokens_in_chunk
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens] inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
# Permute according to sorted token ids. # Permute according to sorted token ids.
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
curr_hidden_states = _fp8_perm(curr_hidden_states, curr_hidden_states = _fp8_perm(curr_hidden_states,
sorted_token_ids // top_k_num) sorted_token_ids // top_k_num)
......
...@@ -32,6 +32,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -32,6 +32,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.dp_size = dp_size self.dp_size = dp_size
self.quant_dtype = quant_dtype self.quant_dtype = quant_dtype
def max_num_tokens_per_rank(self) -> Optional[int]:
    """
    Return the `max_num_tokens` value stored on this instance as the
    per-rank token limit.
    """
    token_limit = self.max_num_tokens
    return token_limit
def topk_indices_dtype(self) -> Optional[torch.dtype]:
return torch.uint32
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
...@@ -42,7 +48,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -42,7 +48,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts: int, num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
num_tokens = a1.size(0) # M num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K hidden_dim = a1.size(-1) # K
...@@ -115,7 +122,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -115,7 +122,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
bound_m=bound_m, bound_m=bound_m,
) )
return expert_x, expert_x_scale, expert_num_tokens return expert_x, expert_x_scale, expert_num_tokens, None, None
def finalize( def finalize(
self, self,
......
...@@ -24,6 +24,12 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -24,6 +24,12 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
self.block_shape = block_shape self.block_shape = block_shape
self.quant_dtype = quant_dtype self.quant_dtype = quant_dtype
def max_num_tokens_per_rank(self) -> Optional[int]:
    """
    Return None: this implementation reports no per-rank token limit.
    """
    no_limit = None
    return no_limit
def topk_indices_dtype(self) -> Optional[torch.dtype]:
    """
    Return None: this implementation reports no required topk indices
    dtype.
    """
    no_restriction = None
    return no_restriction
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
...@@ -34,7 +40,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -34,7 +40,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
num_experts: int, num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1 # TODO: this only works for topK=1, will need to update for topK>1
...@@ -47,7 +55,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -47,7 +55,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
self.per_channel_quant, self.per_channel_quant,
self.block_shape) self.block_shape)
return a1q, a1q_scale, None return a1q, a1q_scale, None, None, None
def finalize( def finalize(
self, self,
......
...@@ -29,9 +29,10 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -29,9 +29,10 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
block_m=block_m) block_m=block_m)
self.deep_gemm_expert = DeepGemmExperts()
self.allow_deep_gemm = allow_deep_gemm self.allow_deep_gemm = allow_deep_gemm
self.use_fp8_w8a8 = use_fp8_w8a8 self.use_fp8_w8a8 = use_fp8_w8a8
self.deep_gemm_expert = DeepGemmExperts(
) if self.allow_deep_gemm else None
def workspace_shapes( def workspace_shapes(
self, self,
...@@ -46,6 +47,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -46,6 +47,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# workspaces so we can be pessimistic here and allocate for DeepGemm # workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set. # even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes( return self.deep_gemm_expert.workspace_shapes(
a, M, N, K, topk, num_experts) a, M, N, K, topk, num_experts)
else: else:
...@@ -73,7 +75,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -73,7 +75,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
) -> torch.Tensor: ) -> torch.Tensor:
N = w1.size(1) N = w1.size(1)
if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): and _valid_deep_gemm(hidden_states, w1, w2)):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.apply( return self.deep_gemm_expert.apply(
hidden_states, hidden_states,
w1, w1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment