Commit 99324e25 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.2' into v0.9.2-ori

parents cc7f22a8 a5dd03c1
...@@ -10,6 +10,7 @@ import torch ...@@ -10,6 +10,7 @@ import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes from vllm.utils import get_max_shared_memory_bytes
...@@ -449,7 +450,8 @@ def test_multi_query_kv_attention( ...@@ -449,7 +450,8 @@ def test_multi_query_kv_attention(
start += seq_len start += seq_len
# xformers.AttentionBias to Tensor for use in reference impl. # xformers.AttentionBias to Tensor for use in reference impl.
alibi_bias = [ alibi_bias = [
b.materialize(b.shape, device=device).squeeze() for b in attn_bias b.materialize((1, num_query_heads, i, i), device=device).squeeze()
for b, i in zip(attn_bias, seq_lens)
] ]
else: else:
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
...@@ -506,3 +508,18 @@ def test_multi_query_kv_attention_with_alibi( ...@@ -506,3 +508,18 @@ def test_multi_query_kv_attention_with_alibi(
device, device,
use_alibi=True, use_alibi=True,
) )
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
    """Expect an AssertionError when building the layer with 16 heads and 5 KV heads.

    16 is not a multiple of 5; the test only checks that construction of
    ``attention_cls`` with these arguments raises ``AssertionError``.
    """
    head_size = 64
    layer_kwargs = {
        "num_heads": 16,
        "head_size": head_size,
        # Conventional 1/sqrt(head_size) softmax scale.
        "scale": float(1.0 / (head_size**0.5)),
        "num_kv_heads": 5,
    }

    with pytest.raises(AssertionError):
        attention_cls(**layer_kwargs)
...@@ -106,10 +106,8 @@ def test_env( ...@@ -106,10 +106,8 @@ def test_env(
block_size, block_size,
False, False,
use_mla=use_mla) use_mla=use_mla)
if use_v1 and name != "TRITON_MLA": expected = f"{name}_VLLM_V1" if use_v1 else name
assert backend.get_name() == f"{name}_VLLM_V1" assert backend.get_name() == expected
else:
assert backend.get_name() == name
else: else:
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
get_attn_backend(16, get_attn_backend(16,
...@@ -173,7 +171,7 @@ def test_env( ...@@ -173,7 +171,7 @@ def test_env(
expected = "FLASHINFER_VLLM_V1" if use_v1 else name expected = "FLASHINFER_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected assert backend.get_name() == expected
else: else:
backend = get_attn_backend(16, backend = get_attn_backend(32,
torch.float16, torch.float16,
torch.float16, torch.float16,
block_size, block_size,
...@@ -182,6 +180,45 @@ def test_env( ...@@ -182,6 +180,45 @@ def test_env(
expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected assert backend.get_name() == expected
if use_v1:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
assert backend.get_name() == "FLEX_ATTENTION", (
"Should fallback to FlexAttention if head size is "
"not supported by FlashAttention")
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("use_v1", [True, False])
def test_fp32_fallback(
    device: str,
    use_v1: bool,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test attention backend selection with fp32.

    For each platform (patched into ``vllm.attention.selector``) and each
    value of ``VLLM_USE_V1``, request a backend for head size 16 with
    float32 dtypes and check the name of the backend that is returned.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")

        if device == "cpu":
            with patch("vllm.attention.selector.current_platform",
                       CpuPlatform()):
                backend = get_attn_backend(16, torch.float32, torch.float32,
                                           16, False)
            # Choose the expected name first, then compare. Writing
            # `assert (name == "A" if use_v1 else "B")` parses as
            # `(name == "A") if use_v1 else "B"`, which asserts a non-empty
            # string (always truthy) whenever use_v1 is False.
            expected = "TORCH_SDPA_VLLM_V1" if use_v1 else "TORCH_SDPA"
            assert backend.get_name() == expected

        elif device == "cuda":
            with patch("vllm.attention.selector.current_platform",
                       CudaPlatform()):
                backend = get_attn_backend(16, torch.float32, torch.float32,
                                           16, False)
            # Same precedence fix as the CPU branch above.
            expected = "FLEX_ATTENTION" if use_v1 else "XFORMERS"
            assert backend.get_name() == expected
def test_flash_attn(monkeypatch: pytest.MonkeyPatch): def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
"""Test FlashAttn validation.""" """Test FlashAttn validation."""
......
...@@ -72,8 +72,8 @@ def test_copy_blocks( ...@@ -72,8 +72,8 @@ def test_copy_blocks(
# destination blocks. # destination blocks.
assert 2 * num_mappings <= num_blocks assert 2 * num_mappings <= num_blocks
src_blocks = random.sample(range(num_blocks), num_mappings) src_blocks = random.sample(range(num_blocks), num_mappings)
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
block_mapping: list[tuple[int, int]] = [] block_mapping: list[tuple[int, int]] = []
for i in range(num_mappings): for i in range(num_mappings):
src = src_blocks[i] src = src_blocks[i]
...@@ -189,12 +189,12 @@ def test_reshape_and_cache( ...@@ -189,12 +189,12 @@ def test_reshape_and_cache(
# Run the reference implementation. # Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor") block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_indicies_lst = block_indicies.cpu().tolist() block_indices_lst = block_indices.cpu().tolist()
block_offsets = slot_mapping % block_size block_offsets = slot_mapping % block_size
block_offsets_lst = block_offsets.cpu().tolist() block_offsets_lst = block_offsets.cpu().tolist()
for i in range(num_tokens): for i in range(num_tokens):
block_idx = block_indicies_lst[i] block_idx = block_indices_lst[i]
block_offset = block_offsets_lst[i] block_offset = block_offsets_lst[i]
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i] cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
cloned_value_cache[block_idx, :, :, block_offset] = value[i] cloned_value_cache[block_idx, :, :, block_offset] = value[i]
...@@ -322,12 +322,12 @@ def test_reshape_and_cache_flash( ...@@ -322,12 +322,12 @@ def test_reshape_and_cache_flash(
kv_dtype=kv_cache_dtype) kv_dtype=kv_cache_dtype)
# Run the reference implementation. # Run the reference implementation.
block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor") block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_indicies_lst = block_indicies.cpu().tolist() block_indices_lst = block_indices.cpu().tolist()
block_offsets = slot_mapping % block_size block_offsets = slot_mapping % block_size
block_offsets_lst = block_offsets.cpu().tolist() block_offsets_lst = block_offsets.cpu().tolist()
for i in range(num_tokens): for i in range(num_tokens):
block_idx = block_indicies_lst[i] block_idx = block_indices_lst[i]
block_offset = block_offsets_lst[i] block_offset = block_offsets_lst[i]
if kv_cache_layout == "NHD": if kv_cache_layout == "NHD":
cloned_key_cache[block_idx, block_offset, :, :] = key[i] cloned_key_cache[block_idx, block_offset, :, :] = key[i]
......
...@@ -46,7 +46,7 @@ CUDA_DEVICE = "cuda:0" ...@@ -46,7 +46,7 @@ CUDA_DEVICE = "cuda:0"
MAX_DEC_SEQ_LENS = [128] MAX_DEC_SEQ_LENS = [128]
MAX_ENC_SEQ_LENS = [128] MAX_ENC_SEQ_LENS = [128]
# Narrow teest-cases for unsupported-scenario # Narrow test-cases for unsupported-scenario
# tests # tests
HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]] HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]]
...@@ -99,7 +99,7 @@ class TestResources(NamedTuple): ...@@ -99,7 +99,7 @@ class TestResources(NamedTuple):
Attributes: Attributes:
* scale: 1/sqrt(d) scale factor for attn * scale: 1/sqrt(d) scale factor for attn
* attn_backend: implementatino of abstraction * attn_backend: implementations of abstraction
attention interface using attention interface using
a particular kernel library a particular kernel library
i.e. XFormers i.e. XFormers
......
...@@ -7,10 +7,7 @@ from torch import Tensor ...@@ -7,10 +7,7 @@ from torch import Tensor
import vllm._custom_ops as ops import vllm._custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv
def cdiv(a, b):
return (a + b - 1) // b
def ref_mla( def ref_mla(
......
...@@ -35,7 +35,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): ...@@ -35,7 +35,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA") m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
False, True) False, True)
assert backend.get_name() == "TRITON_MLA" assert (backend.get_name() == "TRITON_MLA"
or backend.get_name() == "TRITON_MLA_VLLM_V1")
# If attention backend is None # If attention backend is None
# If use_mla is true # If use_mla is true
...@@ -43,7 +44,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): ...@@ -43,7 +44,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m.setenv(STR_BACKEND_ENV_VAR, None) m.setenv(STR_BACKEND_ENV_VAR, None)
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
False, True) False, True)
assert backend.get_name() == "TRITON_MLA" assert (backend.get_name() == "TRITON_MLA"
or backend.get_name() == "TRITON_MLA_VLLM_V1")
# change the attention backend to AITER MLA # change the attention backend to AITER MLA
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA") m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
......
...@@ -5,10 +5,7 @@ import pytest ...@@ -5,10 +5,7 @@ import pytest
import torch import torch
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.utils import cdiv
def cdiv(a, b):
return (a + b - 1) // b
@pytest.mark.parametrize("B", [3, 5]) @pytest.mark.parametrize("B", [3, 5])
......
...@@ -39,10 +39,10 @@ def rotary_embedding_opcheck(rot, ...@@ -39,10 +39,10 @@ def rotary_embedding_opcheck(rot,
@pytest.mark.parametrize("head_size", [32, 108]) @pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024]) @pytest.mark.parametrize("seq_len", [11, 1024])
@pytest.mark.parametrize("use_key", [True, False]) @pytest.mark.parametrize("use_key", [True, False])
@pytest.mark.parametrize("head_stride_is_contingous", [True, False]) @pytest.mark.parametrize("head_stride_is_contiguous", [True, False])
def test_rotary_embedding_opcheck(dist_init, device, max_position, def test_rotary_embedding_opcheck(dist_init, device, max_position,
is_neox_style, rotary_dim, head_size, is_neox_style, rotary_dim, head_size,
seq_len, use_key, head_stride_is_contingous): seq_len, use_key, head_stride_is_contiguous):
batch_size = 1 batch_size = 1
base = 10000 base = 10000
num_heads = 7 num_heads = 7
...@@ -52,7 +52,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position, ...@@ -52,7 +52,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
positions = torch.randint(0, positions = torch.randint(0,
max_position, (batch_size, seq_len), max_position, (batch_size, seq_len),
device=device) device=device)
head_stride = head_size + (64 if head_stride_is_contingous else 0) head_stride = head_size + (64 if head_stride_is_contiguous else 0)
query = torch.randn(batch_size, query = torch.randn(batch_size,
seq_len, seq_len,
...@@ -72,7 +72,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position, ...@@ -72,7 +72,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
# if we have a contiguous head stride, test the alternate # if we have a contiguous head stride, test the alternate
# [..., num_heads * head_dim] shape/layout # [..., num_heads * head_dim] shape/layout
if head_stride_is_contingous: if head_stride_is_contiguous:
rotary_embedding_opcheck( rotary_embedding_opcheck(
rot, positions, query.flatten(start_dim=-2), rot, positions, query.flatten(start_dim=-2),
key.flatten(start_dim=-2) if use_key else None) key.flatten(start_dim=-2) if use_key else None)
...@@ -107,15 +107,15 @@ def generate_random_inputs(batch_size, ...@@ -107,15 +107,15 @@ def generate_random_inputs(batch_size,
return A, dt, X, B, C return A, dt, X, B, C
def generate_continous_batched_examples(example_lens_by_batch, def generate_continuous_batched_examples(example_lens_by_batch,
num_examples, num_examples,
full_length, full_length,
last_taken, last_taken,
exhausted, exhausted,
n_heads, n_heads,
d_head, d_head,
itype, itype,
device='cuda'): device='cuda'):
# this function generates a random examples of certain length # this function generates a random examples of certain length
# and then cut according to "example_lens_by_batch" and feed # and then cut according to "example_lens_by_batch" and feed
...@@ -269,11 +269,10 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, ...@@ -269,11 +269,10 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
states = None states = None
for Y_min, cu_seqlens, seq_idx, (A, dt, X, B, for Y_min, cu_seqlens, seq_idx, (
C) in generate_continous_batched_examples( A, dt, X, B, C) in generate_continuous_batched_examples(
cases, num_examples, seqlen, cases, num_examples, seqlen, last_taken, exhausted, n_heads,
last_taken, exhausted, n_heads, d_head, itype):
d_head, itype):
chunk_indices, chunk_offsets = \ chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets( _query_start_loc_to_chunk_indices_offsets(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" """
DeepEP test utilities DeepEP test utilities
""" """
import dataclasses import dataclasses
import importlib import importlib
import os
import traceback import traceback
from typing import Callable, Optional from typing import Callable, Optional
...@@ -13,6 +15,8 @@ from torch.multiprocessing import ( ...@@ -13,6 +15,8 @@ from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage] spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec from typing_extensions import Concatenate, ParamSpec
from vllm.utils import get_open_port
has_deep_ep = importlib.util.find_spec("deep_ep") is not None has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep: if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
...@@ -92,7 +96,7 @@ def parallel_launch( ...@@ -92,7 +96,7 @@ def parallel_launch(
world_size, world_size,
world_size, world_size,
0, 0,
"tcp://localhost:29500", f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
worker, worker,
) + args, ) + args,
nprocs=world_size, nprocs=world_size,
...@@ -134,18 +138,14 @@ def make_deepep_ht_a2a(pg: ProcessGroup, ...@@ -134,18 +138,14 @@ def make_deepep_ht_a2a(pg: ProcessGroup,
low_latency_mode=low_latency_mode, low_latency_mode=low_latency_mode,
num_qps_per_rank=num_qps_per_rank) num_qps_per_rank=num_qps_per_rank)
return DeepEPHTPrepareAndFinalize(buffer=buffer, return DeepEPHTPrepareAndFinalize(buffer=buffer,
world_size=pgi.world_size, num_dispatchers=pgi.world_size,
rank=pgi.rank,
dp_size=dp_size, dp_size=dp_size,
rank_expert_offset=pgi.rank * rank_expert_offset=pgi.rank *
ht_args.num_local_experts, ht_args.num_local_experts)
quant_dtype=q_dtype,
block_shape=block_shape)
def make_deepep_ll_a2a(pg: ProcessGroup, def make_deepep_ll_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo, pgi: ProcessGroupInfo,
dp_size: int,
deepep_ll_args: DeepEPLLArgs, deepep_ll_args: DeepEPLLArgs,
q_dtype: Optional[torch.dtype] = None, q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None): block_shape: Optional[list[int]] = None):
...@@ -165,11 +165,8 @@ def make_deepep_ll_a2a(pg: ProcessGroup, ...@@ -165,11 +165,8 @@ def make_deepep_ll_a2a(pg: ProcessGroup,
return DeepEPLLPrepareAndFinalize( return DeepEPLLPrepareAndFinalize(
buffer=buffer, buffer=buffer,
world_size=pgi.world_size, num_dispatchers=pgi.world_size,
dp_size=dp_size,
max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank, max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
quant_dtype=q_dtype,
block_shape=block_shape,
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch, use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
) )
...@@ -187,5 +184,4 @@ def make_deepep_a2a(pg: ProcessGroup, ...@@ -187,5 +184,4 @@ def make_deepep_a2a(pg: ProcessGroup,
block_shape) block_shape)
assert deepep_ll_args is not None assert deepep_ll_args is not None
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype, return make_deepep_ll_a2a(pg, pgi, deepep_ll_args, q_dtype, block_shape)
block_shape)
...@@ -2,18 +2,57 @@ ...@@ -2,18 +2,57 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
import pytest import pytest
import torch import torch
import triton.language as tl import triton.language as tl
from tests.kernels.moe.utils import (batched_moe,
make_quantized_test_activations,
make_test_weights, naive_batched_moe)
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
invoke_moe_batched_triton_kernel) invoke_moe_batched_triton_kernel)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
# (m, n, k) triples used to parametrize test_fused_moe_batched_experts.
MNK_FACTORS = [
(1, 128, 128),
(1, 128, 2048),
(1, 512, 512),
(1, 1024, 128),
(1, 1024, 2048),
(32, 128, 128),
(32, 512, 512),
(32, 1024, 2048),
(45, 128, 128),
(45, 128, 2048),
(45, 512, 512),
(45, 1024, 128),
(45, 1024, 2048),
(64, 512, 512),
(64, 1024, 2048),
(222, 128, 128),
(222, 128, 2048),
(222, 1024, 128),
(222, 1024, 2048),
]
# Values for the "e" parameter of test_fused_moe_batched_experts
# (presumably the number of experts; cases with topk > e are skipped).
NUM_EXPERTS = [8, 64]
# Values for the "topk" parameter of test_fused_moe_batched_experts.
TOP_KS = [1, 2, 6]
# Module-level config installed via set_current_vllm_config() inside the test.
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@dataclass @dataclass
class BatchedMMConfig: class BatchedMMConfig:
dtype: torch.dtype in_dtype: torch.dtype
quant_dtype: Optional[torch.dtype]
out_dtype: torch.dtype
num_experts: int num_experts: int
max_tokens_per_expert: int max_tokens_per_expert: int
K: int K: int
...@@ -32,79 +71,129 @@ class BatchedMMTensors: ...@@ -32,79 +71,129 @@ class BatchedMMTensors:
A = torch.randn( A = torch.randn(
(config.num_experts, config.max_tokens_per_expert, config.K), (config.num_experts, config.max_tokens_per_expert, config.K),
device="cuda", device="cuda",
dtype=config.dtype) / 10 dtype=config.in_dtype) / 10
B = torch.randn((config.num_experts, config.N, config.K), B = torch.randn((config.num_experts, config.N, config.K),
device="cuda", device="cuda",
dtype=config.dtype) dtype=config.in_dtype)
C = torch.zeros( C = torch.zeros(
(config.num_experts, config.max_tokens_per_expert, config.N), (config.num_experts, config.max_tokens_per_expert, config.N),
device="cuda", device="cuda",
dtype=config.dtype) dtype=config.out_dtype)
num_expert_tokens = torch.randint(low=0, num_expert_tokens = torch.randint(low=0,
high=config.max_tokens_per_expert, high=config.max_tokens_per_expert,
size=(config.num_experts, ), size=(config.num_experts, ),
device="cuda", device="cuda",
dtype=torch.int32) dtype=torch.int32)
return BatchedMMTensors(A, B, C, num_expert_tokens) return BatchedMMTensors(A, B, C, num_expert_tokens)
def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, @pytest.mark.parametrize("num_experts", [8, 16, 32])
num_expert_tokens: torch.Tensor) -> torch.Tensor: @pytest.mark.parametrize("max_tokens_per_expert",
[32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 1024])
@pytest.mark.parametrize(
"dtype",
[torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype,
block_shape: Optional[list[int]],
per_act_token_quant: bool):
current_platform.seed_everything(7)
num_expert_tokens_cpu = num_expert_tokens.clone() use_fp8_w8a8 = dtype == torch.float8_e4m3fn
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
num_experts = num_expert_tokens.size(0)
for e in range(num_experts): if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
num_tokens = num_expert_tokens_cpu[e] pytest.skip("Don't test blocking for non-quantized types.")
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
return C if per_act_token_quant and block_shape is not None:
pytest.skip("Skip illegal quantization test.")
if dtype.itemsize == 1:
act_dtype = torch.bfloat16
quant_dtype = dtype
else:
act_dtype = dtype
quant_dtype = None
@pytest.mark.parametrize("num_experts", [16, 32]) num_expert_tokens = torch.randint(low=0,
@pytest.mark.parametrize("max_tokens_per_expert", high=max_tokens_per_expert,
[32, 64, 128, 192, 224, 256, 512]) size=(num_experts, ),
@pytest.mark.parametrize("K", [128, 256, 1024]) device="cuda",
@pytest.mark.parametrize("N", [128, 256, 512, 1024]) dtype=torch.int32)
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype):
config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N) A, A_q, A_scale = make_quantized_test_activations(
tensors = BatchedMMTensors.make_tensors(config) num_experts,
max_tokens_per_expert,
K,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
)
test_output = tensors.C B, B_q, B_scale, _, _, _ = make_test_weights(
ref_output = test_output.clone() num_experts,
N // 2,
K,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
)
out_shape = (num_experts, max_tokens_per_expert, N)
test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
compute_tl_dtype = { compute_tl_dtype = {
torch.float16: tl.float16, torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16, torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32 torch.float32: tl.float32
}[test_output.dtype] }[test_output.dtype]
assert A_q.dtype == B_q.dtype
invoke_moe_batched_triton_kernel( invoke_moe_batched_triton_kernel(
tensors.A, A_q,
tensors.B, B_q,
test_output, test_output,
tensors.num_expert_tokens, num_expert_tokens,
compute_tl_dtype, compute_tl_dtype,
# Quantization data # Quantization data
None, A_scale,
None, B_scale,
None, None,
# Quantization schemes # Quantization schemes
False, use_fp8_w8a8,
False, False,
False, False,
config={ config={
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16, "BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 16 "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
}) },
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
ref_output = ref_impl(tensors.A, tensors.B, ref_output, ref_output = native_batched_masked_quant_matmul(
tensors.num_expert_tokens) A,
B,
ref_output,
num_expert_tokens,
)
q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
num_expert_tokens,
A_scale, B_scale,
block_shape,
per_act_token_quant)
rtol, atol = { rtol, atol = {
torch.float16: (6e-2, 6e-2), torch.float16: (6e-2, 6e-2),
...@@ -112,4 +201,122 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, ...@@ -112,4 +201,122 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
torch.float32: (1e-2, 1e-2), torch.float32: (1e-2, 1e-2),
}[test_output.dtype] }[test_output.dtype]
torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol) torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("input_scales", [False])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
    input_scales: bool,
):
    """Cross-check three MoE implementations on the same inputs.

    The outputs of ``torch_experts`` (baseline), ``naive_batched_moe`` and
    ``batched_moe`` are compared pairwise with ``torch.testing.assert_close``.
    Parameter combinations that request quantization options for a
    non-fp8 dtype, or both per-token and block quantization, are skipped.
    """
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
    wants_block_quant = block_shape is not None

    # Guard clauses: bail out of unsupported parameter combinations early.
    if topk > e:
        pytest.skip("topk > e")

    if not use_fp8_w8a8 and (per_act_token_quant or wants_block_quant):
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and wants_block_quant:
        pytest.skip("Skip illegal quantization test.")

    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    # One-byte dtypes are treated as the quantized type on top of bf16
    # activations; anything else runs unquantized in its own dtype.
    act_dtype, quant_dtype = ((torch.bfloat16, dtype) if dtype.itemsize == 1
                              else (dtype, None))

    # The 16-bit reference weights returned in slots 0 and 3 are not used.
    _, w1, w1_s, _, w2, w2_s = make_test_weights(
        e,
        n,
        k,
        block_shape=block_shape,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        per_act_token_quant=per_act_token_quant,
    )

    a1_scale = None
    a2_scale = None
    if input_scales and quant_dtype is not None:
        a1_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
        a2_scale = torch.tensor(1, device="cuda", dtype=torch.float32)

    # Keyword arguments shared verbatim by all three implementations.
    moe_kwargs = dict(
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        quant_dtype=quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )

    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

        baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids,
                                        **moe_kwargs)

        batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids,
                                           **moe_kwargs)

        triton_output = batched_moe(a, w1, w2, topk_weight, topk_ids,
                                    **moe_kwargs)

    torch.testing.assert_close(batched_output,
                               baseline_output,
                               atol=3e-2,
                               rtol=2e-2)

    torch.testing.assert_close(triton_output,
                               batched_output,
                               atol=2e-2,
                               rtol=2e-2)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.platforms import current_platform
# DeepGEMM is an optional dependency: record whether it can be imported so
# the tests that need it can be skipped when it is missing.
try:
    import deep_gemm
except ImportError:
    dg_available = False
else:
    dg_available = True

# Skip the whole module when the device capability is below (9, 0).
if current_platform.get_device_capability() < (9, 0):
    pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
                allow_module_level=True)

# Config installed with set_current_vllm_config() around the MoE calls.
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
# Test configurations
# Activation dtypes to parametrize over; the trailing comment lists other
# candidate dtypes that are currently not enabled.
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
# and its hidden size is 7168.
# (M, N, K) triples for test_w8a8_block_fp8_fused_moe.
MNK_FACTORS = [
(1, 128, 128),
(1, 512, 512),
(1, 128, 7168),
(1, 1024, 7168),
(1, 4608, 128),
(1, 4608, 512),
(1, 4608, 7168),
(83, 128, 128),
(83, 512, 512),
(83, 1024, 7168),
(83, 4608, 512),
(83, 4608, 7168),
(128, 128, 128),
(128, 512, 512),
(128, 1024, 7168),
(128, 4608, 512),
(128, 4608, 7168),
(2048, 128, 128),
(2048, 1024, 7168),
(2048, 4608, 512),
(2048, 4608, 7168),
(8192, 128, 128),
(8192, 512, 512),
(8192, 128, 7168),
(8192, 1024, 7168),
(8192, 4608, 512),
(8192, 4608, 7168),
]
# (M, N, K) triples for test_w8a8_block_fp8_deep_gemm_fused_moe; shapes that
# _valid_deep_gemm_shape rejects are skipped by the test at run time.
MNK_FACTORS_DG = [
(128, 128, 128),
(128, 512, 512),
(128, 128, 7168),
(128, 1024, 7168),
(128, 4608, 128),
(128, 4608, 512),
(128, 4608, 7168),
(192, 128, 128),
(192, 512, 512),
(192, 1024, 7168),
(192, 4608, 512),
(192, 4608, 7168),
(1335, 128, 128),
(1335, 1024, 7168),
(1335, 4608, 512),
(1335, 4608, 7168),
(2048, 128, 128),
(2048, 512, 512),
(2048, 128, 7168),
(2048, 1024, 7168),
(2048, 4608, 128),
(2048, 4608, 512),
(2048, 4608, 7168),
]
# Values for the "block_size" test parameter, forwarded as block_shape.
BLOCK_SIZE = [[128, 128]]
# Values for the "E" test parameter (presumably expert counts; cases with
# topk > E are skipped). The trailing comment lists larger, disabled values.
E = [2, 8, 16] # [128, 256]
# Values for the "topk" test parameter.
TOP_KS = [1, 2, 6]
# Seeds passed to torch.manual_seed().
SEEDS = [0]
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
                             block_shape):
    """Fused moe with block-wise quantization using native torch.

    Every row of ``a`` is repeated once per selected expert, quantized per
    token group of width ``block_shape[1]``, run through each expert's two
    block-quantized matmuls with a SiLU-and-mul activation in between, and
    the per-expert results are combined using ``topk_weight``.
    """
    num_tokens, hidden = a.shape
    topk = topk_ids.size(1)
    out_features = w2.shape[1]

    # Replicate each token `topk` times so each copy can go to one expert.
    a = a.view(num_tokens, -1, hidden).repeat(1, topk, 1).reshape(-1, hidden)
    out = torch.zeros(num_tokens * topk,
                      out_features,
                      dtype=a.dtype,
                      device=a.device)

    flat_weight = topk_weight.view(-1)
    flat_ids = topk_ids.view(-1)

    # Only the second entry of block_shape is used as the quantization
    # group size for activations.
    block_k = block_shape[1]
    a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
    a_q = a_q.to(torch.float32)

    for expert in range(w1.shape[0]):
        mask = flat_ids == expert
        if not mask.any():
            # No token copy was routed to this expert.
            continue

        inter_out = native_w8a8_block_matmul(a_q[mask],
                                             w1[expert],
                                             a_s[mask],
                                             w1_s[expert],
                                             block_shape,
                                             output_dtype=a.dtype)
        act_out = SiluAndMul().forward_native(inter_out)
        act_out_q, act_out_s = native_per_token_group_quant_fp8(
            act_out, block_k)
        out[mask] = native_w8a8_block_matmul(act_out_q,
                                             w2[expert],
                                             act_out_s,
                                             w2_s[expert],
                                             block_shape,
                                             output_dtype=a.dtype)

    # Weight each expert's contribution and sum over the topk axis.
    weights = flat_weight.view(num_tokens, -1, 1).to(out.dtype)
    return (out.view(num_tokens, -1, out_features) * weights).sum(dim=1)
# Skip all tests if CUDA is not available
pytest.importorskip("torch.cuda")
@pytest.fixture(autouse=True)
def setup_cuda():
    """Autouse fixture that sets torch's default device to "cuda"."""
    torch.set_default_device("cuda")
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
                                  monkeypatch):
    """Compare fp8 block-quantized fused MoE against the native torch reference.

    Both ``fused_experts`` and the modular Triton kernel built by
    ``modular_triton_fused_moe`` must match ``torch_w8a8_block_fp8_moe``
    within the same absolute/relative tolerance.
    """
    if topk > E:
        pytest.skip(f"Skipping test; topk={topk} > E={E}")

    torch.manual_seed(seed)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")

    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    # Quantization settings shared by weight creation and the modular kernel.
    quant_kwargs = dict(per_act_token_quant=False, block_shape=block_size)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(E, N, K, dtype,
                                                 torch.float8_e4m3fn,
                                                 **quant_kwargs)

    m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
                                           use_int8_w8a8=False,
                                           use_int8_w8a16=False,
                                           use_int4_w4a16=False,
                                           **quant_kwargs)

    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    weight_scales = dict(w1_scale=w1_s, w2_scale=w2_s)

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s,
                                           topk_weights, topk_ids,
                                           block_size)

        out = fused_experts(a,
                            w1,
                            w2,
                            topk_weights,
                            topk_ids,
                            use_fp8_w8a8=True,
                            block_shape=block_size,
                            **weight_scales)

        m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids,
                            **weight_scales)

    # 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0]
    if M < 40000:
        tol = 0.035
    else:
        tol = 0.039

    for actual in (out, m_out):
        torch.testing.assert_close(actual, ref_out, atol=tol, rtol=tol)
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS_DG)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
                                            monkeypatch):
    """Compare ``deep_gemm_moe_fp8`` with the native torch reference.

    The kernel is run eagerly and, for problems larger than the chunk size
    with N and K of at least 1024 on a CUDA-alike platform, re-run through a
    captured CUDA graph; the final output must match the reference.
    """
    if topk > E:
        pytest.skip(f"Skipping test: topk={topk} > E={E}")
    if not _valid_deep_gemm_shape(M, N, K):
        pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")

    chunk_size = 1024
    torch.manual_seed(seed)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))

    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
    block_size = [block_m, block_m]
    dtype = torch.bfloat16

    hidden_states = torch.randn((M, K), dtype=dtype) / 10
    router_logits = torch.randn((M, E), dtype=dtype)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(
        E,
        N,
        K,
        dtype,
        torch.float8_e4m3fn,
        per_act_token_quant=False,
        block_shape=block_size,
    )

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    use_cudagraph = (M > chunk_size and N >= 1024 and K >= 1024
                     and current_platform.is_cuda_alike())

    topk_weights, topk_ids, _ = fused_topk(hidden_states,
                                           router_logits.float(), topk, False)

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(hidden_states, w1, w2, w1_s, w2_s,
                                           topk_weights, topk_ids, block_size)

        moe_fn = deep_gemm_moe_fp8
        if use_compile:
            moe_fn = torch.compile(deep_gemm_moe_fp8,
                                   backend="inductor",
                                   fullgraph=True)
            for dynamic_input in (hidden_states, topk_weights, topk_ids):
                torch._dynamo.mark_dynamic(dynamic_input, 0)

        kernel_args = (hidden_states, w1, w2, w1_s, w2_s, topk_weights,
                       topk_ids)
        out = moe_fn(*kernel_args)

        if use_cudagraph:
            # Clear the eager result so the checked values come from replay.
            out.fill_(0)
            stream = torch.cuda.Stream()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, stream=stream):
                out = moe_fn(*kernel_args)
            torch.cuda.synchronize()
            graph.replay()
            torch.cuda.synchronize()

    torch.testing.assert_close(out, ref_out, atol=0.035, rtol=0.035)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.platforms import current_platform
# INT8 Triton kernels need compute capability 7.0 or newer. The platform may
# report no capability at all (None); comparing None with a tuple would raise
# TypeError while the module is being collected, so treat that case as
# unsupported and skip as well.
_device_capability = current_platform.get_device_capability()
if _device_capability is None or _device_capability < (7, 0):
    pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
                allow_module_level=True)
# Config installed with set_current_vllm_config around the kernel call in the
# test below (per the comment there, to avoid warning spam).
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
# Activation dtypes to test.
DTYPES = [torch.half, torch.bfloat16]
# (M, N, K) problem sizes swept by the test. M is the number of activation
# rows and K the activation width (the test builds `a` as an (M, K) tensor);
# N is forwarded to make_test_weights (presumably the expert intermediate
# size -- TODO confirm).
MNK_FACTORS = [
(1, 128, 128),
(1, 512, 512),
(1, 128, 7168),
(1, 1024, 7168),
(1, 4096, 128),
(1, 4096, 512),
(1, 4096, 7168),
(33, 128, 128),
(33, 512, 512),
(33, 128, 7168),
(33, 1024, 7168),
(33, 4096, 128),
(33, 4096, 512),
(33, 4096, 7168),
(128, 128, 128),
(128, 512, 512),
(128, 1024, 7168),
(128, 4096, 512),
(128, 4096, 7168),
(222, 128, 128),
(222, 512, 512),
(222, 1024, 7168),
(222, 4096, 512),
(222, 4096, 7168),
(2048, 128, 128),
(2048, 1024, 7168),
(2048, 4096, 512),
(2048, 4096, 7168),
]
# Expert counts to test.
E = [8, 24]
# Experts selected per token.
TOP_KS = [2, 6]
# Quantization block shapes to test; only [128, 128] is enabled, the wider
# sweep below is kept commented out.
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]]
# Seeds passed to torch.manual_seed at the start of each test.
SEEDS = [0]
# Reference implementation used only by the test in this module.
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """Reference fused MoE with block-wise INT8 quantization in native torch.

    Routing is computed here: ``score`` is softmaxed in float32 and the
    ``topk`` largest entries per row choose the experts and their weights.
    Every token row is replicated once per selected expert, pushed through
    that expert's two block-quantized matmuls (with a SiluAndMul activation
    in between), and the results are combined using the routing weights.

    Args:
        a: 2-D activation tensor.
        w1: Stacked first-layer expert weights, indexed by expert on dim 0.
        w2: Stacked second-layer expert weights; ``w2.shape[1]`` is the
            output width.
        w1_s: Per-expert block scales for ``w1``.
        w2_s: Per-expert block scales for ``w2``.
        score: Router scores, one row per token.
        topk: Number of experts selected per token.
        block_shape: Two-element quantization block shape; the second entry
            is used as the activation group size.

    Returns:
        The routing-weighted sum of expert outputs, one row per input token.
    """
    B, D = a.shape
    # One copy of every token for each expert it is routed to.
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    block_k = block_shape[1]
    a_q, a_s = native_per_token_group_quant_int8(a, block_k)

    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_matmul(a_q[mask],
                                                 w1[i],
                                                 a_s[mask],
                                                 w1_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
            act_out = SiluAndMul().forward_native(inter_out)
            # Only the quantized activation and its scales feed the second
            # matmul; the unquantized act_out is not needed afterwards.
            act_out_q, act_out_s = native_per_token_group_quant_int8(
                act_out, block_k)
            out[mask] = native_w8a8_block_matmul(act_out_q,
                                                 w2[i],
                                                 act_out_s,
                                                 w2_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)

    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.fixture(autouse=True, scope="module")
def setup_cuda():
    """Make "cuda" the default torch device once for this test module."""
    default_device = "cuda"
    torch.set_default_device(default_device)
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    """Compare ``fused_moe`` with W8A8 INT8 block quantization against the
    native torch reference implementation."""
    torch.manual_seed(seed)

    hidden_states = torch.randn((M, K), dtype=dtype) / 10
    router_logits = torch.randn((M, E), dtype=dtype)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(
        E,
        N,
        K,
        dtype,
        torch.int8,
        per_act_token_quant=False,
        block_shape=block_size,
    )

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        out = fused_moe(hidden_states,
                        w1,
                        w2,
                        router_logits,
                        topk,
                        renormalize=False,
                        use_int8_w8a8=True,
                        w1_scale=w1_s,
                        w2_scale=w2_s,
                        block_shape=block_size)
        ref_out = torch_w8a8_block_int8_moe(hidden_states, w1, w2, w1_s, w2_s,
                                            router_logits, topk, block_size)

    # The kernel output must agree with the reference within tolerance.
    tolerance = 0.065
    torch.testing.assert_close(out, ref_out, atol=tolerance, rtol=tolerance)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# DeepGEMM Style Cutlass Grouped GEMM Test
# See https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py
import random
import pytest
import torch
from tests.kernels.utils import baseline_scaled_mm
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
def cdiv(a, b):
    """Return ``a / b`` rounded up, computed as ``(a + b - 1) // b``."""
    bumped = a + b - 1
    return bumped // b
def per_token_cast_to_fp8(
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (128 - (n % 128)) % 128
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view *
(448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros((cdiv(m, 128) * 128, cdiv(n, 128) * 128),
device=x.device,
dtype=x.dtype)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(dtype=torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
x_amax / 448.0).view(x_view.size(0), x_view.size(2))
@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [
    (4, 8192, 7168, 4096),
    (4, 8192, 2048, 7168),
    (8, 4096, 7168, 4096),
    (8, 4096, 2048, 7168),
    (32, 1024, 7168, 4096),
    (32, 1024, 2048, 7168),
])
@pytest.mark.parametrize("out_dtype", [torch.float16])
@pytest.mark.skipif(
    (lambda x: x is None or x.to_int() != 100)(
        current_platform.get_device_capability()),
    reason="Block Scaled Grouped GEMM is only supported on SM100.")
def test_cutlass_grouped_gemm(
    num_groups: int,
    expected_m_per_group: int,
    k: int,
    n: int,
    out_dtype: torch.dtype,
):
    """Compare ``cutlass_blockwise_scaled_grouped_mm`` with a per-group
    ``baseline_scaled_mm`` reference on randomly sized groups."""
    device = "cuda"
    alignment = 128

    # Each group gets a random row count within +/-30% of the expected size.
    group_ms = [
        int(expected_m_per_group * random.uniform(0.7, 1.3))
        for _ in range(num_groups)
    ]
    # Total rows: every group's size rounded up to the alignment.
    m = sum([cdiv(rows, alignment) * alignment for rows in group_ms])

    x = torch.randn((m, k), device=device, dtype=out_dtype)
    y = torch.randn((num_groups, n, k), device=device, dtype=out_dtype)
    out = torch.empty((m, n), device=device, dtype=out_dtype)
    ref_out = torch.randn((m, n), device=device, dtype=out_dtype)

    # Group boundaries use the unaligned sizes; the last boundary is m, so
    # the final group also covers all rows added by the alignment rounding.
    ep_offset = [0]
    for rows in group_ms[:-1]:
        ep_offset.append(ep_offset[-1] + rows)
    ep_offset.append(m)

    pb_size = [[end - start, n, k]
               for start, end in zip(ep_offset, ep_offset[1:])]
    problem_sizes = torch.tensor(pb_size, device=device, dtype=torch.int32)
    expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32)

    x_fp8 = per_token_cast_to_fp8(x)
    y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn),
             torch.empty((num_groups, cdiv(n, 128), k // 128),
                         device=device,
                         dtype=torch.float))
    for group in range(num_groups):
        y_fp8[0][group], y_fp8[1][group] = per_block_cast_to_fp8(y[group])

    # Build the reference one group at a time.
    for group in range(num_groups):
        start, end = ep_offset[group], ep_offset[group + 1]
        a = x_fp8[0][start:end]
        a_scale = x_fp8[1][start:end]
        b = y_fp8[0][group].t()
        b_scale = y_fp8[1][group].t()
        ref_out[start:end] = baseline_scaled_mm(a, b, a_scale, b_scale,
                                                out_dtype)

    ops.cutlass_blockwise_scaled_grouped_mm(
        out,
        x_fp8[0],
        y_fp8[0],
        x_fp8[1],
        y_fp8[1],
        problem_sizes,
        expert_offsets[:-1],
    )

    torch.testing.assert_close(ref_out, out, atol=5e-1, rtol=1e-3)
...@@ -29,6 +29,10 @@ MNK_FACTORS = [ ...@@ -29,6 +29,10 @@ MNK_FACTORS = [
(224, 1024, 1536), (224, 1024, 1536),
(224, 3072, 1024), (224, 3072, 1024),
(224, 3072, 1536), (224, 3072, 1536),
(32768, 1024, 1024),
# These sizes trigger wrong answers.
#(7232, 2048, 5120),
#(40000, 2048, 5120),
] ]
vllm_config = VllmConfig(parallel_config=ParallelConfig( vllm_config = VllmConfig(parallel_config=ParallelConfig(
...@@ -93,11 +97,9 @@ class MOETensors8Bit(MOETensors): ...@@ -93,11 +97,9 @@ class MOETensors8Bit(MOETensors):
n_b_scales = 2 * n if per_out_channel else 1 n_b_scales = 2 * n if per_out_channel else 1
k_b_scales = k if per_out_channel else 1 k_b_scales = k if per_out_channel else 1
# Get the right scale for tests. # Get the right scale for tests.
_, a_scale = ops.scaled_fp8_quant( a_q, a_scale = ops.scaled_fp8_quant(
moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token) moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token)
a_q, _ = ops.scaled_fp8_quant(moe_tensors_fp16.a,
a_scale,
use_per_token_if_dynamic=per_act_token)
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype) w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype) w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
...@@ -183,6 +185,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int, ...@@ -183,6 +185,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
def run_8_bit(moe_tensors: MOETensors8Bit, def run_8_bit(moe_tensors: MOETensors8Bit,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
per_act_token: bool,
num_local_experts: Optional[int] = None) -> torch.Tensor: num_local_experts: Optional[int] = None) -> torch.Tensor:
assert not any([ assert not any([
t is None for t in [ t is None for t in [
...@@ -199,7 +202,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit, ...@@ -199,7 +202,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids': topk_ids, 'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale, 'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale, 'w2_scale': moe_tensors.w2_scale,
'a1_scale': moe_tensors.a_scale 'per_act_token': per_act_token,
'a1_scale': None #moe_tensors.a_scale
} }
num_experts = moe_tensors.w1.size(0) num_experts = moe_tensors.w1.size(0)
...@@ -231,8 +235,10 @@ def test_cutlass_moe_8_bit_no_graph( ...@@ -231,8 +235,10 @@ def test_cutlass_moe_8_bit_no_graph(
topk: int, topk: int,
per_act_token: bool, per_act_token: bool,
per_out_ch: bool, per_out_ch: bool,
monkeypatch,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_ch) per_out_ch)
...@@ -248,11 +254,13 @@ def test_cutlass_moe_8_bit_no_graph( ...@@ -248,11 +254,13 @@ def test_cutlass_moe_8_bit_no_graph(
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
topk_ids) topk_ids)
cutlass_output = run_8_bit(mt, topk_weights, topk_ids) cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token)
# Note 5.5 only needed for larger problem sizes, 5 works ok for
# the rest.
torch.testing.assert_close(triton_output, torch.testing.assert_close(triton_output,
cutlass_output, cutlass_output,
atol=5e-2, atol=5.5e-2,
rtol=1e-2) rtol=1e-2)
...@@ -273,8 +281,10 @@ def test_cutlass_moe_8_bit_cuda_graph( ...@@ -273,8 +281,10 @@ def test_cutlass_moe_8_bit_cuda_graph(
topk: int, topk: int,
per_act_token: bool, per_act_token: bool,
per_out_ch: bool, per_out_ch: bool,
monkeypatch,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
dtype = torch.half dtype = torch.half
...@@ -295,7 +305,8 @@ def test_cutlass_moe_8_bit_cuda_graph( ...@@ -295,7 +305,8 @@ def test_cutlass_moe_8_bit_cuda_graph(
stream = torch.cuda.Stream() stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream): with torch.cuda.graph(graph, stream=stream):
cutlass_output = run_8_bit(mt, topk_weights, topk_ids) cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
per_act_token)
torch.cuda.synchronize() torch.cuda.synchronize()
graph.replay() graph.replay()
...@@ -328,8 +339,10 @@ def test_cutlass_moe_8_bit_EP( ...@@ -328,8 +339,10 @@ def test_cutlass_moe_8_bit_EP(
per_act_token: bool, per_act_token: bool,
per_out_channel: bool, per_out_channel: bool,
ep_size: int, ep_size: int,
monkeypatch,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_channel) per_out_channel)
...@@ -349,6 +362,7 @@ def test_cutlass_moe_8_bit_EP( ...@@ -349,6 +362,7 @@ def test_cutlass_moe_8_bit_EP(
cutlass_output = run_8_bit(mt, cutlass_output = run_8_bit(mt,
topk_weights, topk_weights,
topk_ids, topk_ids,
per_act_token,
num_local_experts=e // ep_size) num_local_experts=e // ep_size)
torch.testing.assert_close(triton_output, torch.testing.assert_close(triton_output,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" """
Test DeepEP + DeepGEMM integration Test DeepEP + DeepGEMM integration
DeepGEMM are gemm kernels specialized for the DeepGEMM are gemm kernels specialized for the
fp8 block-quantized case. fp8 block-quantized case.
""" """
import dataclasses import dataclasses
import importlib
from typing import Optional from typing import Optional
import pytest import pytest
...@@ -18,41 +18,34 @@ from vllm.config import VllmConfig, set_current_vllm_config ...@@ -18,41 +18,34 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm
from .deepep_utils import ProcessGroupInfo, parallel_launch from .parallel_utils import ProcessGroupInfo, parallel_launch
from .utils import make_test_weights
has_deep_ep = importlib.util.find_spec("deep_ep") is not None if has_deep_ep():
try:
import deep_gemm
has_deep_gemm = True
except ImportError:
has_deep_gemm = False
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize) DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize) DeepEPLLPrepareAndFinalize)
from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
if has_deep_gemm():
if has_deep_gemm:
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts) BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts) DeepGemmExperts)
requires_deep_ep = pytest.mark.skipif( requires_deep_ep = pytest.mark.skipif(
not has_deep_ep, not has_deep_ep(),
reason="Requires deep_ep kernels", reason="Requires deep_ep kernels",
) )
requires_deep_gemm = pytest.mark.skipif( requires_deep_gemm = pytest.mark.skipif(
not has_deep_gemm, not has_deep_gemm(),
reason="Requires deep_gemm kernels", reason="Requires deep_gemm kernels",
) )
...@@ -66,25 +59,6 @@ def next_power_of_2(x): ...@@ -66,25 +59,6 @@ def next_power_of_2(x):
return 2**math.ceil(math.log2(x)) return 2**math.ceil(math.log2(x))
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(deep_gemm.ceil_div(m, 128) * 128,
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def make_block_quant_fp8_weights( def make_block_quant_fp8_weights(
e: int, e: int,
n: int, n: int,
...@@ -92,43 +66,11 @@ def make_block_quant_fp8_weights( ...@@ -92,43 +66,11 @@ def make_block_quant_fp8_weights(
block_size: list[int], block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Return weights w1, w2, w1q, w2q, w1_scale, w2_scale Return weights w1q, w2q, w1_scale, w2_scale
""" """
dtype = torch.bfloat16 w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights(
e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size)
fp8_info = torch.finfo(torch.float8_e4m3fn) return w1q, w2q, w1_scale, w2_scale
fp8_max, fp8_min = fp8_info.max, fp8_info.min
w1_bf16 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
w2_bf16 = torch.randn((e, k, n), dtype=dtype) / 10
w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w2 = (n + block_k - 1) // block_k
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
device="cuda",
dtype=torch.float32)
w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
device="cuda",
dtype=torch.float32)
assert w1_s.shape == (e, (2 * n + 127) // 128, (k + 127) // 128)
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(e):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
return w1, w2, w1_s, w2_s
@dataclasses.dataclass @dataclasses.dataclass
...@@ -138,6 +80,7 @@ class TestConfig: ...@@ -138,6 +80,7 @@ class TestConfig:
k: int k: int
n: int n: int
num_experts: int num_experts: int
per_act_token_quant: bool
block_size: list[int] block_size: list[int]
# configs for testing low-latency kernels # configs for testing low-latency kernels
low_latency: bool low_latency: bool
...@@ -156,8 +99,7 @@ class TestTensors: ...@@ -156,8 +99,7 @@ class TestTensors:
def make(config: TestConfig, rank) -> "TestTensors": def make(config: TestConfig, rank) -> "TestTensors":
dtype = torch.bfloat16 dtype = torch.bfloat16
topk, m, k, block_size = (config.topk, config.m, config.k, topk, m, k = (config.topk, config.m, config.k)
config.block_size)
fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min fp8_max, fp8_min = fp8_info.max, fp8_info.min
...@@ -165,9 +107,7 @@ class TestTensors: ...@@ -165,9 +107,7 @@ class TestTensors:
rank_tokens = torch.randn( rank_tokens = torch.randn(
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0 (m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max) rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
rank_token_scales = None
block_k = block_size[1]
_, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k)
topk_ids = torch.randint( topk_ids = torch.randint(
low=0, low=0,
...@@ -207,10 +147,11 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, ...@@ -207,10 +147,11 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
q_dtype=q_dtype, q_dtype=q_dtype,
block_shape=test_config.block_size) block_shape=test_config.block_size)
fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank, fused_experts = BatchedDeepGemmExperts(
world_size=pgi.world_size, max_num_tokens=max_tokens_per_rank,
dp_size=dp_size, num_dispatchers=pgi.world_size // dp_size,
block_shape=test_config.block_size) block_shape=test_config.block_size,
per_act_token_quant=test_config.per_act_token_quant)
mk = FusedMoEModularKernel(prepare_finalize=a2a, mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts) fused_experts=fused_experts)
return mk return mk
...@@ -432,6 +373,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, ...@@ -432,6 +373,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
""" """
Tests for High-Throughput DeepEP + DeepGemm integration. Tests for High-Throughput DeepEP + DeepGemm integration.
""" """
import deep_gemm
m, n, k = mnk m, n, k = mnk
current_platform.seed_everything(7) current_platform.seed_everything(7)
...@@ -448,6 +390,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, ...@@ -448,6 +390,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
k=k, k=k,
n=n, n=n,
num_experts=num_experts, num_experts=num_experts,
per_act_token_quant=False,
block_size=block_size, block_size=block_size,
low_latency=False, low_latency=False,
use_fp8_dispatch=None) use_fp8_dispatch=None)
...@@ -480,10 +423,14 @@ USE_FP8_DISPATCH = [False] ...@@ -480,10 +423,14 @@ USE_FP8_DISPATCH = [False]
@pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep @requires_deep_ep
@requires_deep_gemm @requires_deep_gemm
def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int, def test_ll_deepep_deepgemm_moe(
int], num_experts: int, topk: int, mnk: tuple[int, int, int],
use_fp8_dispatch: bool, block_size: list[int], num_experts: int,
world_dp_size: tuple[int, int]): topk: int,
use_fp8_dispatch: bool,
block_size: list[int],
world_dp_size: tuple[int, int],
):
""" """
Tests for Low-Latency DeepEP + DeepGemm integration. Tests for Low-Latency DeepEP + DeepGemm integration.
""" """
...@@ -501,6 +448,7 @@ def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int, ...@@ -501,6 +448,7 @@ def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
k=k, k=k,
n=n, n=n,
num_experts=num_experts, num_experts=num_experts,
per_act_token_quant=False,
block_size=block_size, block_size=block_size,
low_latency=True, low_latency=True,
use_fp8_dispatch=use_fp8_dispatch, use_fp8_dispatch=use_fp8_dispatch,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" """
Test deepep dispatch-combine logic Test deepep dispatch-combine logic
""" """
import dataclasses import dataclasses
import importlib
from typing import Optional, Union from typing import Optional, Union
import pytest import pytest
...@@ -22,21 +22,20 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( ...@@ -22,21 +22,20 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8) per_token_group_quant_fp8)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_ep
from .deepep_utils import ProcessGroupInfo, parallel_launch from .parallel_utils import ProcessGroupInfo, parallel_launch
has_deep_ep = importlib.util.find_spec("deep_ep") is not None if has_deep_ep():
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize) DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize) DeepEPLLPrepareAndFinalize)
from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
requires_deep_ep = pytest.mark.skipif( requires_deep_ep = pytest.mark.skipif(
not has_deep_ep, not has_deep_ep(),
reason="Requires deep_ep kernels", reason="Requires deep_ep kernels",
) )
...@@ -104,10 +103,6 @@ class TestTensors: ...@@ -104,10 +103,6 @@ class TestTensors:
rank_tokens = torch.randn( rank_tokens = torch.randn(
(config.m, config.k), device="cuda", dtype=token_dtype) / 10 (config.m, config.k), device="cuda", dtype=token_dtype) / 10
rank_token_scales = None rank_token_scales = None
if config.dtype == torch.float8_e4m3fn:
# low_latency_mode kernels dont support per-token quant.
_, rank_token_scales = ops.scaled_fp8_quant(
rank_tokens, use_per_token_if_dynamic=not low_latency_mode)
topk = torch.randint(low=0, topk = torch.randint(low=0,
high=config.num_experts, high=config.num_experts,
...@@ -123,11 +118,18 @@ class TestTensors: ...@@ -123,11 +118,18 @@ class TestTensors:
config=config) config=config)
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, def make_modular_kernel(
low_latency_mode: bool, hidden_size: int, dp_size: int, pg: ProcessGroup,
num_experts: int, num_local_experts: int, pgi: ProcessGroupInfo,
q_dtype: Optional[torch.dtype], low_latency_mode: bool,
use_fp8_dispatch: bool) -> FusedMoEModularKernel: hidden_size: int,
dp_size: int,
num_experts: int,
num_local_experts: int,
q_dtype: Optional[torch.dtype],
use_fp8_dispatch: bool,
per_act_token_quant: bool,
) -> FusedMoEModularKernel:
is_quantized = q_dtype is not None is_quantized = q_dtype is not None
...@@ -153,33 +155,47 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, ...@@ -153,33 +155,47 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
deepep_ht_args = ht_args, deepep_ht_args = ht_args,
deepep_ll_args = ll_args) deepep_ll_args = ll_args)
num_dispatchers = pgi.world_size // dp_size
if low_latency_mode: if low_latency_mode:
assert not per_act_token_quant, "not supported in ll mode"
fused_experts = BatchedTritonExperts( fused_experts = BatchedTritonExperts(
max_num_tokens=MAX_TOKENS_PER_RANK, max_num_tokens=MAX_TOKENS_PER_RANK,
world_size=pgi.world_size, num_dispatchers=num_dispatchers,
dp_size=dp_size,
use_fp8_w8a8=is_quantized, use_fp8_w8a8=is_quantized,
use_int8_w8a8=False, use_int8_w8a8=False,
use_int8_w8a16=False, use_int8_w8a16=False,
use_int4_w4a16=False) use_int4_w4a16=False,
per_act_token_quant=False,
)
else: else:
fused_experts = TritonExperts(use_fp8_w8a8=is_quantized, fused_experts = TritonExperts(
use_int8_w8a8=False, use_fp8_w8a8=is_quantized,
use_int8_w8a16=False, use_int8_w8a8=False,
use_int4_w4a16=False, use_int8_w8a16=False,
per_channel_quant=False) use_int4_w4a16=False,
per_act_token_quant=per_act_token_quant,
)
mk = FusedMoEModularKernel(prepare_finalize=a2a, mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts) fused_experts=fused_experts)
return mk return mk
def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, def deep_ep_moe_impl(
low_latency_mode: bool, dp_size: int, pg: ProcessGroup,
test_tensors: TestTensors, w1: torch.Tensor, pgi: ProcessGroupInfo,
w2: torch.Tensor, w1_scale: Optional[torch.Tensor], low_latency_mode: bool,
w2_scale: Optional[torch.Tensor], num_experts: int, dp_size: int,
use_fp8_dispatch: bool) -> torch.Tensor: test_tensors: TestTensors,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
num_experts: int,
use_fp8_dispatch: bool,
per_act_token_quant: bool,
) -> torch.Tensor:
num_local_experts = w1.size(0) num_local_experts = w1.size(0)
...@@ -201,11 +217,9 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, ...@@ -201,11 +217,9 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
q_dtype = torch.float8_e4m3fn q_dtype = torch.float8_e4m3fn
# Make modular kernel # Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(pg, pgi, low_latency_mode, mk: FusedMoEModularKernel = make_modular_kernel(
hidden_size, dp_size, pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
num_experts, num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant)
num_local_experts, q_dtype,
use_fp8_dispatch)
out_hidden_states = torch.empty_like(test_tensors.rank_tokens) out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
total_num_tokens = test_tensors.rank_tokens.size(0) total_num_tokens = test_tensors.rank_tokens.size(0)
...@@ -259,9 +273,15 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, ...@@ -259,9 +273,15 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
return out_hidden_states return out_hidden_states
def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor, def torch_moe_impl(
w2: torch.Tensor, w1_scale: Optional[torch.Tensor], test_tensors: TestTensors,
w2_scale: Optional[torch.Tensor], using_fp8_dispatch: bool): w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
using_fp8_dispatch: bool,
per_act_token_quant: bool,
):
a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk, a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
test_tensors.topk_weights) test_tensors.topk_weights)
...@@ -269,6 +289,7 @@ def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor, ...@@ -269,6 +289,7 @@ def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
# The DeepEP implementation is requested to dispatch using FP8. # The DeepEP implementation is requested to dispatch using FP8.
# For numerical stability for testing, emulate the fp8 dispatch by # For numerical stability for testing, emulate the fp8 dispatch by
# blockwise quant and de-quant. # blockwise quant and de-quant.
assert not per_act_token_quant
a = test_tensors.rank_tokens a = test_tensors.rank_tokens
aq, aq_scale = per_token_group_quant_fp8(a, 128) aq, aq_scale = per_token_group_quant_fp8(a, 128)
a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view( a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
...@@ -312,6 +333,7 @@ def _deep_ep_moe( ...@@ -312,6 +333,7 @@ def _deep_ep_moe(
w1_scale: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
use_fp8_dispatch: bool, use_fp8_dispatch: bool,
per_act_token_quant: bool,
): ):
if not low_latency_mode: if not low_latency_mode:
...@@ -333,7 +355,8 @@ def _deep_ep_moe( ...@@ -333,7 +355,8 @@ def _deep_ep_moe(
with set_current_vllm_config(VllmConfig()): with set_current_vllm_config(VllmConfig()):
# Reference # Reference
torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale, torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale,
w2_scale, use_fp8_dispatch) w2_scale, use_fp8_dispatch,
per_act_token_quant)
# Splice experts for this rank. # Splice experts for this rank.
num_local_experts = config.num_experts // pgi.world_size num_local_experts = config.num_experts // pgi.world_size
...@@ -358,6 +381,7 @@ def _deep_ep_moe( ...@@ -358,6 +381,7 @@ def _deep_ep_moe(
w2_scale_ep, w2_scale_ep,
config.num_experts, config.num_experts,
use_fp8_dispatch, use_fp8_dispatch,
per_act_token_quant,
) )
torch.testing.assert_close( torch.testing.assert_close(
...@@ -386,10 +410,16 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn] ...@@ -386,10 +410,16 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@requires_deep_ep @requires_deep_ep
def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], def test_deep_ep_moe(
num_experts: int, topk: int, world_dp_size: tuple[int, dtype: torch.dtype,
int]): mnk: tuple[int, int, int],
num_experts: int,
topk: int,
world_dp_size: tuple[int, int],
per_act_token_quant: bool,
):
low_latency_mode = False low_latency_mode = False
use_fp8_dispatch = False use_fp8_dispatch = False
m, n, k = mnk m, n, k = mnk
...@@ -406,7 +436,8 @@ def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], ...@@ -406,7 +436,8 @@ def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype) w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size, parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch) config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
per_act_token_quant)
MNKs = [ MNKs = [
...@@ -456,4 +487,5 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], ...@@ -456,4 +487,5 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype) w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size, parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch) config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
False)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit-test DeepGEMM FP8 kernels (no DeepEP).
Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts.
"""
import importlib
import math
import pytest
import torch
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import cdiv
# Probe for the optional DeepGEMM package; it is imported only when the
# probe finds it, so this module can still be collected without it.
# NOTE(review): only `import importlib` is visible in this file, and that
# alone does not guarantee the `importlib.util` submodule is loaded -
# confirm something else imports it first.
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None

if has_deep_gemm:
    import deep_gemm

    # The M alignment reported by DeepGEMM for the contiguous layout is
    # reused for both entries of the block shape.
    BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout()
    BLOCK_SIZE = [BLOCK_M] * 2

# Marker that skips a test when DeepGEMM cannot be found.
requires_deep_gemm = pytest.mark.skipif(not has_deep_gemm,
                                        reason="Requires deep_gemm kernels")
def calc_diff(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the scalar dissimilarity ``1 - 2*sum(x*y) / sum(x*x + y*y)``.

    Both inputs are promoted to float64 before the reduction. Identical
    tensors give 0 and exactly opposite tensors give 2.

    When the denominator is zero (both inputs all-zero, or empty) the
    inputs are identical, so 0 is returned instead of the NaN that a plain
    0/0 division would produce.

    Args:
        x: First tensor.
        y: Second tensor; must be broadcast-compatible with ``x``.

    Returns:
        A 0-dim float64 tensor holding the dissimilarity.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    degenerate = denominator == 0
    # Swap a zero denominator for 1 so the division itself never yields NaN;
    # the degenerate case is overwritten with 0 below. torch.where keeps
    # this free of a host-side branch on a (possibly GPU) scalar.
    safe_denominator = torch.where(degenerate, torch.ones_like(denominator),
                                   denominator)
    sim = 2 * (x * y).sum() / safe_denominator
    return torch.where(degenerate, torch.zeros_like(denominator), 1 - sim)
def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to ``float8_e4m3fn`` with one scale per tile.

    The input is zero-padded up to a multiple of 128 rows and
    ``block_size_n`` columns and split into ``128 x block_size_n`` tiles.
    Each tile is scaled so that its absolute maximum (clamped to at least
    1e-4) maps to 448.0 before the cast.

    Args:
        x: 2-D tensor to quantize.
        block_size_n: Tile width along the last dimension.

    Returns:
        ``(quantized, scales)``. ``quantized`` has the shape of ``x`` and
        dtype ``float8_e4m3fn``. ``scales`` is float32 with shape
        ``(row_tiles, col_tiles)`` and holds ``amax / 448.0`` per tile.
    """
    assert x.dim() == 2
    m, n = x.shape
    block_size_m = 128  # row tile height is fixed
    x_padded = torch.zeros(
        (cdiv(m, block_size_m) * block_size_m,
         cdiv(n, block_size_n) * block_size_n),
        dtype=x.dtype,
        device=x.device)
    x_padded[:m, :n] = x
    # The number of column tiles must come from block_size_n. It used to be
    # hard-coded as `// 128`, which is only right for the default width.
    x_view = x_padded.view(-1, block_size_m,
                           x_padded.size(1) // block_size_n, block_size_n)
    # Reduce over the in-tile row (1) and in-tile column (3) axes.
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    # Drop the padding again so the result matches the input shape.
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales
def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
):
    """
    Build random expert weights quantized to FP8 plus their block scales.

    bf16 reference weights are drawn on the CUDA device, divided by 10 and
    clamped to the float8_e4m3fn range. Each expert is then quantized with
    per_block_cast_to_fp8.

    Shapes:
        w1:   (E, 2N, K)
        w2:   (E, K, N)
        w1_s: (E, ceil(2N / block_n), ceil(K / block_k))
        w2_s: (E, ceil(K / block_n), ceil(N / block_k))

    Returns the tuple (w1, w2, w1_s, w2_s).

    NOTE(review): per_block_cast_to_fp8 is called with its default tile
    size, so the scale buffers sized from `block_size` presumably only line
    up when block_size == [128, 128] - confirm with the callers.
    """
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    block_n, block_k = block_size

    # bf16 reference weights; w1 is drawn before w2 so the random stream is
    # consumed in a fixed order.
    ref_w1 = torch.randn(e, 2 * n, k, device="cuda",
                         dtype=torch.bfloat16) / 10
    ref_w2 = torch.randn(e, k, n, device="cuda", dtype=torch.bfloat16) / 10
    for ref in (ref_w1, ref_w2):
        ref.clamp_(fp8_info.min, fp8_info.max)

    def _empty_scales(rows: int, cols: int) -> torch.Tensor:
        # One float32 scale per (block_n x block_k) tile of every expert.
        return torch.empty(e,
                           math.ceil(rows / block_n),
                           math.ceil(cols / block_k),
                           device="cuda",
                           dtype=torch.float32)

    w1 = torch.empty_like(ref_w1, dtype=torch.float8_e4m3fn)
    w2 = torch.empty_like(ref_w2, dtype=torch.float8_e4m3fn)
    w1_s = _empty_scales(2 * n, k)
    w2_s = _empty_scales(k, n)

    # Quantize expert by expert into the preallocated buffers.
    for idx in range(e):
        w1[idx], w1_s[idx] = per_block_cast_to_fp8(ref_w1[idx])
        w2[idx], w2_s[idx] = per_block_cast_to_fp8(ref_w2[idx])

    return w1, w2, w1_s, w2_s
def run_single_case(m, n, k, topk, num_experts, block_size):
    """
    Check one (M, N, K) configuration on a single GPU.

    fused_experts is run twice on identical inputs, once with DeepGEMM
    disallowed (the baseline) and once with it allowed. The two outputs
    must agree within tolerance.
    """
    activations = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
    activations = activations.clamp_min_(-1).clamp_max_(1)
    # Only the activation scales are kept; the quantized tokens are dropped.
    _, a1_scale = per_token_group_quant_fp8(activations, block_size[1])

    # Expert weight tensors and their per-block scales.
    w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
                                                      block_size)

    # Random routing: top-k over random logits, softmax over the selection.
    router_logits = torch.randn(m,
                                num_experts,
                                device="cuda",
                                dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    # Everything except `allow_deep_gemm` is shared by the two runs.
    shared_kwargs = dict(
        hidden_states=activations,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        use_fp8_w8a8=True,
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_scale=a1_scale,
        block_shape=block_size,
    )

    # Baseline with DeepGEMM disallowed.
    out_triton = fused_experts(allow_deep_gemm=False, **shared_kwargs)
    # Same problem with DeepGEMM allowed.
    out_deepgemm = fused_experts(allow_deep_gemm=True, **shared_kwargs)

    # Absolute tolerance: 10% of the baseline mean magnitude, with a floor
    # of 1e-3 (0.1 * 1e-2).
    mean_magnitude = out_triton.abs().mean()
    atol = 0.1 * mean_magnitude.clamp(min=1e-2)

    torch.testing.assert_close(
        out_deepgemm.to(torch.float32),
        out_triton.to(torch.float32),
        rtol=0.05,
        atol=float(atol),
    )
# (M, N, K) problem sizes for the comparison test.
# Note: W1 has shape (E, 2N, K), so N = 512
# can trigger the deepgemm path.
MNKs = [
    (1024, 512, 128),
    (1024, 512, 512),
    (2048, 512, 512),
    (512, 1024, 1024),
    (512, 2048, 2048),
    (4096, 4096, 1024),
]

# Number of experts selected per token.
TOPKS = [2, 6]

# Total number of experts.
NUM_EXPERTS = [32]
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@requires_deep_gemm
def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
    """Compare the DeepGEMM path with the baseline for one configuration.

    A spy wraps ``deep_gemm_moe_fp8`` in the fused_moe module so the test
    can also verify that the DeepGEMM path ran exactly once.
    """
    # Nothing below is useful for an impossible routing, so skip up front.
    if topk > num_experts:
        pytest.skip(f"topk={topk} > num_experts={num_experts}")

    m, n, k = mnk

    # The context handle is named `mp` so that it is not clobbered by the
    # `m, n, k` unpacking, and every patch goes through it so all of them
    # are undone together when the context exits.
    with monkeypatch.context() as mp:
        # Set the env var before the module is imported below.
        mp.setenv("VLLM_USE_DEEP_GEMM", "1")

        _fused_moe_mod = importlib.import_module(
            "vllm.model_executor.layers.fused_moe.fused_moe")

        call_counter = {"cnt": 0}
        orig_fn = _fused_moe_mod.deep_gemm_moe_fp8

        def _spy_deep_gemm_moe_fp8(*args, **kwargs):
            # Count the call, then defer to the real implementation.
            call_counter["cnt"] += 1
            return orig_fn(*args, **kwargs)

        mp.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
                   _spy_deep_gemm_moe_fp8)

        run_single_case(
            m=m,
            n=n,
            k=k,
            topk=topk,
            num_experts=num_experts,
            block_size=BLOCK_SIZE,
        )

        # ensure that the DeepGEMM path was indeed taken.
        assert call_counter["cnt"] == 1, \
            "DeepGEMM path was not executed during the test. " \
            f"Call counter: {call_counter['cnt']}"
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
Run `pytest tests/kernels/test_moe.py`. Run `pytest tests/kernels/test_moe.py`.
""" """
import functools
from typing import Callable, Optional, Union
import pytest import pytest
import torch import torch
from torch.nn import Parameter from torch.nn import Parameter
...@@ -14,8 +17,11 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock ...@@ -14,8 +17,11 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe) fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
...@@ -39,7 +45,76 @@ vllm_config.scheduler_config.max_num_seqs = 128 ...@@ -39,7 +45,76 @@ vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192 vllm_config.scheduler_config.max_model_len = 8192
def run_moe_test(
    baseline: Union[Callable, torch.Tensor],
    moe_fn: Callable,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    padding: bool = False,
    use_compile: bool = False,
    use_cudagraph: bool = False,
    atol: float = 2e-2,
    rtol: float = 0,
) -> torch.Tensor:
    """Run ``moe_fn`` and compare its output against a baseline.

    Args:
        baseline: Either a precomputed reference tensor, or a callable that
            is invoked with the same arguments as ``moe_fn`` (on the
            unpadded weights) to produce the reference.
        moe_fn: The MoE implementation under test.
        a, w1, w2, score, topk: Positional inputs forwarded to both
            callables.
        global_num_experts, expert_map: Forwarded as keyword arguments.
        padding: If true, ``w1``/``w2`` are padded by 128 along the last
            dimension and sliced back before calling ``moe_fn``.
        use_compile: If true, ``moe_fn`` is wrapped with ``torch.compile``
            and dim 0 of ``a`` and ``score`` is marked dynamic.
        use_cudagraph: If true, ``moe_fn`` is additionally captured into a
            CUDA graph and replayed; the replayed output is what gets
            compared.
        atol, rtol: Tolerances for ``torch.testing.assert_close``.

    Returns:
        The baseline output, so callers can reuse it for further runs.
    """
    routing_kwargs = dict(global_num_experts=global_num_experts,
                          expert_map=expert_map)

    # The reference is taken first, before any padding is applied.
    if isinstance(baseline, torch.Tensor):
        expected = baseline
    else:
        expected = baseline(a, w1, w2, score, topk, **routing_kwargs)

    # Pad the weight if moe padding is enabled: same shape and values
    # afterwards, but the tensors are slices of the padded storage.
    if padding:
        w1, w2 = (F.pad(w, (0, 128), "constant", 0)[..., 0:-128]
                  for w in (w1, w2))

    if use_compile:
        moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(score, 0)

    def _invoke() -> torch.Tensor:
        return moe_fn(a, w1, w2, score, topk, **routing_kwargs)

    # The eager run always happens, including before a graph capture.
    actual = _invoke()

    if use_cudagraph:
        # Zero the eager result so only the replay can produce a match.
        actual.fill_(0)
        capture_stream = torch.cuda.Stream()
        cuda_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(cuda_graph, stream=capture_stream):
            actual = _invoke()
        torch.cuda.synchronize()
        cuda_graph.replay()
        torch.cuda.synchronize()

    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)

    return expected
@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000])
@pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("e", NUM_EXPERTS)
...@@ -47,6 +122,7 @@ vllm_config.scheduler_config.max_model_len = 8192 ...@@ -47,6 +122,7 @@ vllm_config.scheduler_config.max_model_len = 8192
@pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False]) @pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize("chunk_size", [8192])
def test_fused_moe( def test_fused_moe(
m: int, m: int,
n: int, n: int,
...@@ -56,7 +132,21 @@ def test_fused_moe( ...@@ -56,7 +132,21 @@ def test_fused_moe(
ep_size: int, ep_size: int,
dtype: torch.dtype, dtype: torch.dtype,
padding: bool, padding: bool,
chunk_size: int,
monkeypatch,
): ):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
#
# Setup test data
#
#
# Setup test data
#
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
...@@ -76,38 +166,70 @@ def test_fused_moe( ...@@ -76,38 +166,70 @@ def test_fused_moe(
else: else:
e_map = None e_map = None
with set_current_vllm_config(vllm_config): #
torch_output = torch_moe(a, w1, w2, score, topk, e_map) # Setup test functions
iterative_output = iterative_moe(a, #
w1,
w2, m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False,
score, use_int8_w8a8=False,
topk, use_int8_w8a16=False,
global_num_experts=e, use_int4_w4a16=False,
expert_map=e_map, per_act_token_quant=False,
renormalize=False) block_shape=None)
def m_fused_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
return m_fused_moe_fn(a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map)
fused_moe_fn = functools.partial(fused_moe, renormalize=False)
#
# Run tests
#
runner = functools.partial(
run_moe_test,
a=a,
w1=w1,
w2=w2,
score=score,
topk=topk,
global_num_experts=e,
expert_map=e_map,
padding=padding,
)
# Pad the weight if moe padding is enabled # Note: for now use_compile will error out if the problem size is
if padding: # large enough to trigger chunking. I'm leaving the flag and
w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] # setup code in case we are able to revisit this later.
torch.cuda.empty_cache() use_compile = False
w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
torch.cuda.empty_cache()
triton_output = fused_moe(a, use_cudagraph = (n >= 1024 and k >= 1024
w1, and current_platform.is_cuda_alike())
w2,
score,
topk,
global_num_experts=e,
expert_map=e_map,
renormalize=False)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) with set_current_vllm_config(vllm_config):
torch.testing.assert_close(iterative_output, baseline_output = runner(torch_moe, iterative_moe)
torch_output, runner(baseline_output,
atol=2e-2, fused_moe_fn,
rtol=0) use_compile=use_compile,
use_cudagraph=use_cudagraph)
runner(baseline_output,
m_fused_moe,
use_compile=use_compile,
use_cudagraph=use_cudagraph)
@pytest.mark.parametrize("m", [1, 32, 222]) @pytest.mark.parametrize("m", [1, 32, 222])
...@@ -217,7 +339,12 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, ...@@ -217,7 +339,12 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
w1_zp=w1_qzeros if has_zp else None, w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None, w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size]) block_shape=[0, group_size])
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) torch_output = torch_moe(a,
w1_ref,
w2_ref,
score,
topk,
expert_map=e_map)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
...@@ -243,46 +370,59 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, ...@@ -243,46 +370,59 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
if dtype == torch.float32: if dtype == torch.float32:
pytest.skip("AITER ROCm test skip for float32") pytest.skip("AITER ROCm test skip for float32")
monkeypatch.setenv('RANK', "0")
monkeypatch.setenv('LOCAL_RANK', "0")
monkeypatch.setenv('WORLD_SIZE', "1")
monkeypatch.setenv('MASTER_ADDR', 'localhost')
monkeypatch.setenv('MASTER_PORT', '12345')
init_distributed_environment()
# Instantiate our and huggingface's MoE blocks # Instantiate our and huggingface's MoE blocks
config = MixtralConfig() vllm_config.compilation_config.static_forward_context = dict()
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") with (set_current_vllm_config(vllm_config),
vllm_moe = MixtralMoE( set_forward_context(None, vllm_config)):
num_experts=config.num_local_experts, config = MixtralConfig()
top_k=config.num_experts_per_tok, hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
hidden_size=config.hidden_size, vllm_moe = MixtralMoE(
intermediate_size=config.intermediate_size, num_experts=config.num_local_experts,
params_dtype=dtype, top_k=config.num_experts_per_tok,
tp_size=1, hidden_size=config.hidden_size,
dp_size=1, intermediate_size=config.intermediate_size,
).cuda() params_dtype=dtype,
tp_size=1,
# Load the weights dp_size=1,
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data ).cuda()
for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data, # Load the weights
hf_moe.experts[i].w3.weight.data) vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) for i in range(config.num_local_experts):
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data)
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim] vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
# vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs = hf_inputs.flatten(0, 1) # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs = torch.randn(
(1, 64, config.hidden_size)).to(dtype).to("cuda")
# vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs = hf_inputs.flatten(0, 1)
# Pad the weight if moe padding is enabled # Pad the weight if moe padding is enabled
if padding: if padding:
vllm_moe.experts.w13_weight = Parameter(F.pad( vllm_moe.experts.w13_weight = Parameter(F.pad(
vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128], vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[...,
requires_grad=False) 0:-128],
torch.cuda.empty_cache() requires_grad=False)
vllm_moe.experts.w2_weight = Parameter(F.pad( torch.cuda.empty_cache()
vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128], vllm_moe.experts.w2_weight = Parameter(F.pad(
requires_grad=False) vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[...,
torch.cuda.empty_cache() 0:-128],
requires_grad=False)
# Run forward passes for both MoE blocks torch.cuda.empty_cache()
hf_states, _ = hf_moe.forward(hf_inputs)
vllm_states = vllm_moe.forward(vllm_inputs) # Run forward passes for both MoE blocks
hf_states, _ = hf_moe.forward(hf_inputs)
vllm_states = vllm_moe.forward(vllm_inputs)
mixtral_moe_tol = { mixtral_moe_tol = {
torch.float32: 1e-3, torch.float32: 1e-3,
...@@ -525,7 +665,12 @@ def test_fused_marlin_moe( ...@@ -525,7 +665,12 @@ def test_fused_marlin_moe(
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) torch_output = torch_moe(a,
w_ref1,
w_ref2,
score,
topk,
expert_map=e_map)
marlin_output = torch.ops.vllm.fused_marlin_moe( marlin_output = torch.ops.vllm.fused_marlin_moe(
a, a,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size_triton)
@pytest.mark.parametrize(
    "block_size,num_tokens,topk,num_experts",
    list(
        itertools.product(
            [32, 64, 128, 256],  # block_size
            [1, 3, 7, 16, 256, 2256, 4096],  # num_tokens
            [1, 4, 16, 64],  # topk
            [64, 160, 256, 257, 260, 264],  # num_experts
        )),
)
def test_moe_align_block_size_compare_implementations(block_size, num_tokens,
                                                      topk, num_experts):
    """Compare the custom-op and Triton versions of moe_align_block_size.

    Both are run on the same random routing table. The expert ids and the
    post-padding token count must agree; the sorted token ids are filled in
    by both implementations but are not compared.
    """
    # Each token selects `topk` distinct experts from a random permutation.
    per_token_experts = [
        torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
        for _ in range(num_tokens)
    ]
    topk_ids = torch.stack(per_token_experts)

    num_slots = topk_ids.numel()
    device = topk_ids.device
    # Buffer length: all routed slots plus (block_size - 1) extra slots per
    # expert.
    max_num_tokens_padded = num_slots + num_experts * (block_size - 1)
    max_num_m_blocks = max_num_tokens_padded // block_size

    def _make_outputs():
        # Sorted ids start out filled with `num_slots`, expert ids with 0.
        sorted_ids = torch.empty((max_num_tokens_padded, ),
                                 dtype=torch.int32,
                                 device=device)
        sorted_ids.fill_(num_slots)
        expert_ids = torch.zeros((max_num_m_blocks, ),
                                 dtype=torch.int32,
                                 device=device)
        num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device=device)
        return sorted_ids, expert_ids, num_tokens_post_pad

    sorted_ids_cuda, expert_ids_cuda, num_tokens_post_pad_cuda = (
        _make_outputs())
    sorted_ids_triton, expert_ids_triton, num_tokens_post_pad_triton = (
        _make_outputs())

    ops.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids_cuda,
        expert_ids_cuda,
        num_tokens_post_pad_cuda,
    )

    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids_triton,
        expert_ids_triton,
        num_tokens_post_pad_triton,
    )

    expert_ids_match = torch.allclose(expert_ids_cuda, expert_ids_triton)
    assert expert_ids_match, (
        f"Expert IDs mismatch for block_size={block_size}, "
        f"num_tokens={num_tokens}, topk={topk}\n"
        f"CUDA expert_ids: {expert_ids_cuda}\n"
        f"Triton expert_ids: {expert_ids_triton}")

    post_pad_match = torch.allclose(num_tokens_post_pad_cuda,
                                    num_tokens_post_pad_triton)
    assert post_pad_match, (
        f"Num tokens post pad mismatch for block_size={block_size}, "
        f"num_tokens={num_tokens}, topk={topk}\n"
        f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n"
        f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}")
if __name__ == "__main__":
    # pytest.main() only returns the exit code. Propagate it so that running
    # this file directly exits non-zero when a test fails.
    raise SystemExit(pytest.main([__file__]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment