Commit 99324e25 authored by zhuwenwen
Browse files

Merge tag 'v0.9.2' into v0.9.2-ori

parents cc7f22a8 a5dd03c1
...@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk ...@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform from vllm.platforms import current_platform
if not current_platform.has_device_capability(100): if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", pytest.skip("Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True) allow_module_level=True)
MNK_FACTORS = [ MNK_FACTORS = [
...@@ -136,7 +136,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ...@@ -136,7 +136,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
device=w2.device, device=w2.device,
block_size=quant_blocksize) block_size=quant_blocksize)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None) torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
torch.testing.assert_close(torch_output, torch.testing.assert_close(torch_output,
cutlass_output, cutlass_output,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest import pytest
import torch import torch
from tests.pplx_utils import ProcessGroupInfo, parallel_launch from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv
from .parallel_utils import ProcessGroupInfo, parallel_launch
try: try:
from pplx_kernels import AllToAll from pplx_kernels import AllToAll
...@@ -64,6 +68,7 @@ def pplx_cutlass_moe( ...@@ -64,6 +68,7 @@ def pplx_cutlass_moe(
out_dtype, out_dtype,
per_act_token: bool, per_act_token: bool,
per_out_ch: bool, per_out_ch: bool,
group_name: Optional[str],
): ):
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize) PplxPrepareAndFinalize)
...@@ -84,36 +89,46 @@ def pplx_cutlass_moe( ...@@ -84,36 +89,46 @@ def pplx_cutlass_moe(
else: else:
scale_elems = (hidden_dim + block_size - 1) // block_size scale_elems = (hidden_dim + block_size - 1) // block_size
ata = AllToAll.internode( args = dict(
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens,
num_experts=num_experts, num_experts=num_experts,
experts_per_token=topk, experts_per_token=topk,
rank=rank, rank=rank,
world_size=pgi.world_size, world_size=world_size,
dp_size=dp_size, dp_size=dp_size,
hidden_dim=hidden_dim, hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim, # because a.dtype.itemsize == 1 hidden_dim_bytes=hidden_dim, # because a.dtype.itemsize == 1
hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize, hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
) )
if group_name is None:
ata = AllToAll.internode(**args)
else:
args["group_name"] = group_name
ata = AllToAll.intranode(**args)
w1 = w1.to(device) w1 = w1.to(device)
w2 = w2.to(device) w2 = w2.to(device)
w1_scale = w1_scale.to(device) w1_scale = w1_scale.to(device)
w2_scale = w2_scale.to(device) w2_scale = w2_scale.to(device)
a1_scale = a1_scale.to(device) a1_scale = a1_scale.to(device)
assert num_experts % world_size == 0
num_local_experts = cdiv(num_experts, world_size)
num_dispatchers = pgi.world_size // dp_size
prepare_finalize = PplxPrepareAndFinalize( prepare_finalize = PplxPrepareAndFinalize(
ata, ata,
max_num_tokens, max_num_tokens=max_num_tokens,
pgi.world_size, num_local_experts=num_local_experts,
rank, num_dispatchers=num_dispatchers)
dp_size,
quant_dtype=torch.float8_e4m3fn,
per_act_token=per_act_token,
)
experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size, experts = CutlassExpertsFp8(num_local_experts,
out_dtype, per_act_token, per_out_ch) out_dtype,
per_act_token,
per_out_ch,
num_dispatchers=num_dispatchers,
use_batched_format=True)
fused_cutlass_experts = FusedMoEModularKernel( fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
...@@ -151,22 +166,6 @@ vllm_config.scheduler_config.max_num_seqs = 128 ...@@ -151,22 +166,6 @@ vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192 vllm_config.scheduler_config.max_model_len = 8192
def torch_moe2(a, w1, w2, topk_weight, topk_ids):
    """Reference MoE computed expert-by-expert with plain torch ops.

    Every token row of ``a`` is replicated once per top-k slot. For each
    expert index, the replicated rows routed to it are multiplied by the
    transposed ``w1[i]``, passed through ``SiluAndMul`` and multiplied by
    the transposed ``w2[i]``. The per-slot results are then scaled by
    ``topk_weight`` and summed over the top-k dimension.

    Returns a tensor with ``a.shape[0]`` rows and ``w2.shape[1]`` columns.
    """
    num_tokens, hidden = a.shape
    slots = topk_ids.shape[1]

    # One row per (token, top-k slot) pair.
    rows = a.view(num_tokens, -1, hidden).repeat(1, slots, 1)
    rows = rows.reshape(-1, hidden)

    out_dim = w2.shape[1]
    out = torch.zeros(num_tokens * slots,
                      out_dim,
                      dtype=rows.dtype,
                      device=rows.device)

    for expert in range(w1.shape[0]):
        selected = (topk_ids == expert).view(-1)
        if not selected.sum():
            # No row was routed to this expert.
            continue
        activated = SiluAndMul()(rows[selected] @ w1[expert].transpose(0, 1))
        out[selected] = activated @ w2[expert].transpose(0, 1)

    per_slot = out.view(num_tokens, -1, out_dim)
    weights = topk_weight.view(num_tokens, -1, 1).to(out.dtype)
    return (per_slot * weights).sum(dim=1)
def _pplx_moe( def _pplx_moe(
pgi: ProcessGroupInfo, pgi: ProcessGroupInfo,
dp_size: int, dp_size: int,
...@@ -184,30 +183,42 @@ def _pplx_moe( ...@@ -184,30 +183,42 @@ def _pplx_moe(
w2_full: torch.Tensor, w2_full: torch.Tensor,
per_act_token: bool, per_act_token: bool,
per_out_ch: bool, per_out_ch: bool,
use_internode: bool,
): ):
uid = nvshmem_get_unique_id( try:
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() if use_internode:
torch.distributed.broadcast(uid, src=0) uid = nvshmem_get_unique_id(
nvshmem_init(uid, pgi.rank, pgi.world_size) ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
with set_current_vllm_config(vllm_config): nvshmem_init(uid, pgi.rank, pgi.world_size)
torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights, else:
topk_ids) group_ranks = list(range(pgi.world_size))
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale, cpu_group = torch.distributed.new_group(group_ranks,
w2_scale, topk_weights, topk_ids, backend="gloo")
a1_scale, out_dtype, per_act_token, group_name = cpu_group.group_name
per_out_ch)
with set_current_vllm_config(vllm_config):
torch_output = chunk_by_rank(torch_output, pgi.rank, torch_output = torch_experts(a_full, w1_full, w2_full,
pgi.world_size).to(pplx_output.device) topk_weights, topk_ids)
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
# Uncomment if more debugging is needed w2_scale, topk_weights, topk_ids,
# print("PPLX OUT:", pplx_output) a1_scale, out_dtype, per_act_token,
# print("TORCH OUT:", torch_output) per_out_ch, group_name)
torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0) torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device)
nvshmem_finalize()
# Uncomment if more debugging is needed
# print("PPLX OUT:", pplx_output)
# print("TORCH OUT:", torch_output)
torch.testing.assert_close(pplx_output,
torch_output,
atol=0.05,
rtol=0)
finally:
if use_internode:
nvshmem_finalize()
@pytest.mark.parametrize("m", [2, 224]) @pytest.mark.parametrize("m", [2, 224])
...@@ -218,6 +229,7 @@ def _pplx_moe( ...@@ -218,6 +229,7 @@ def _pplx_moe(
@pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.skipif( @pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()), current_platform.get_device_capability()),
...@@ -232,6 +244,7 @@ def test_cutlass_moe_pplx( ...@@ -232,6 +244,7 @@ def test_cutlass_moe_pplx(
per_act_token: bool, per_act_token: bool,
per_out_ch: bool, per_out_ch: bool,
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
use_internode: bool,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
...@@ -284,4 +297,5 @@ def test_cutlass_moe_pplx( ...@@ -284,4 +297,5 @@ def test_cutlass_moe_pplx(
parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q, parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q,
w1_scale, w2_scale, topk_weights, topk_ids, a_scale1, w1_scale, w2_scale, topk_weights, topk_ids, a_scale1,
dtype, a, w1_d, w2_d, per_act_token, per_out_ch) dtype, a, w1_d, w2_d, per_act_token, per_out_ch,
use_internode)
...@@ -4,7 +4,10 @@ ...@@ -4,7 +4,10 @@
Run `pytest tests/kernels/test_pplx_moe.py`. Run `pytest tests/kernels/test_pplx_moe.py`.
""" """
from typing import Optional import itertools
import textwrap
import traceback
from typing import Callable, Optional
import pytest import pytest
import torch import torch
...@@ -18,39 +21,43 @@ try: ...@@ -18,39 +21,43 @@ try:
except ImportError: except ImportError:
has_pplx = False has_pplx = False
from tests.pplx_utils import ProcessGroupInfo, parallel_launch from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
from tests.kernels.quant_utils import dequant
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_topk, override_config
from vllm.model_executor.layers.fused_moe import override_config from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts) BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk, from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
get_default_config)
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import round_up
from .parallel_utils import ProcessGroupInfo, parallel_launch
requires_pplx = pytest.mark.skipif( requires_pplx = pytest.mark.skipif(
not has_pplx, not has_pplx,
reason="Requires PPLX kernels", reason="Requires PPLX kernels",
) )
PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512), PPLX_COMBOS = [
(222, 2048, 1024)] # TODO: figure out why this fails, seems to be test problem
#(1, 128, 128),
PPLX_MOE_COMBOS = [
(1, 128, 128),
(2, 128, 512), (2, 128, 512),
(3, 1024, 2048), (3, 1024, 2048),
(32, 128, 1024), (4, 128, 128),
(32, 1024, 512),
(45, 512, 2048), (45, 512, 2048),
(64, 1024, 1024), (64, 1024, 512),
(222, 1024, 2048), (222, 2048, 1024),
(256, 1408, 2048),
] ]
NUM_EXPERTS = [8, 64] NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
TOP_KS = [1, 2, 6] TOP_KS = [1, 2, 6]
DTYPES = [torch.float8_e4m3fn, torch.bfloat16]
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_num_seqs = 128
...@@ -143,45 +150,6 @@ def torch_batched_moe( ...@@ -143,45 +150,6 @@ def torch_batched_moe(
return torch_finalize(out, topk_weight, topk_ids) return torch_finalize(out, topk_weight, topk_ids)
def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    """Run the batched-experts modular kernel in a single-rank setup.

    Both the prepare/finalize stage and the experts stage are sized for
    all ``a.shape[0]`` tokens and configured with ``world_size=1``,
    ``dp_size=1`` (and ``rank=0`` for prepare/finalize). The number of
    experts passed to the kernel is ``w1.shape[0]``.
    """
    num_experts = w1.shape[0]
    num_tokens = a.shape[0]

    prepare_finalize = BatchedPrepareAndFinalize(num_tokens,
                                                 world_size=1,
                                                 dp_size=1,
                                                 rank=0)
    experts = BatchedExperts(max_num_tokens=num_tokens,
                             dp_size=1,
                             world_size=1)
    kernel = FusedMoEModularKernel(prepare_finalize, experts)

    return kernel(a, w1, w2, topk_weight, topk_ids, num_experts)
# Note: same as torch_moe but with fused_topk factored out.
def torch_moe2(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    """Reference MoE that takes precomputed top-k weights and ids.

    Token rows of ``a`` are repeated once per top-k slot. Rows assigned
    to expert ``i`` go through ``w1[i]`` (transposed), ``SiluAndMul`` and
    ``w2[i]`` (transposed). The slot outputs are weighted by
    ``topk_weight`` and reduced over the top-k dimension.
    """
    num_tokens, hidden = a.shape
    slots = topk_ids.shape[1]
    out_dim = w2.shape[1]

    # Expand to one row per (token, slot) so a flat mask can index it.
    expanded = a.view(num_tokens, -1, hidden).repeat(1, slots, 1).reshape(
        -1, hidden)
    result = torch.zeros(num_tokens * slots,
                         out_dim,
                         dtype=expanded.dtype,
                         device=expanded.device)

    for expert_idx in range(w1.shape[0]):
        routed = (topk_ids == expert_idx).view(-1)
        if not routed.sum():
            continue
        up = expanded[routed] @ w1[expert_idx].transpose(0, 1)
        result[routed] = SiluAndMul()(up) @ w2[expert_idx].transpose(0, 1)

    scale = topk_weight.view(num_tokens, -1, 1).to(result.dtype)
    return (result.view(num_tokens, -1, out_dim) * scale).sum(dim=1)
@pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("k", [128, 512, 1024])
...@@ -205,9 +173,11 @@ def test_fused_moe_batched_experts( ...@@ -205,9 +173,11 @@ def test_fused_moe_batched_experts(
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) baseline_output = torch_experts(a, w1, w2, topk_weight,
topk_ids) # only for baseline
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids) batched_output = naive_batched_moe(
a, w1, w2, topk_weight, topk_ids) # pick torch_experts or this
torch.testing.assert_close(baseline_output, torch.testing.assert_close(baseline_output,
torch_output, torch_output,
...@@ -219,6 +189,63 @@ def test_fused_moe_batched_experts( ...@@ -219,6 +189,63 @@ def test_fused_moe_batched_experts(
rtol=0) rtol=0)
def create_pplx_prepare_finalize(
    num_tokens: int,
    hidden_dim: int,
    topk: int,
    num_experts: int,
    rank: int,
    dp_size: int,
    world_size: int,
    in_dtype: torch.dtype,
    quant_dtype: Optional[torch.dtype],
    block_shape: Optional[list[int]],
    per_act_token_quant: bool,
    group_name: Optional[str],
):
    """Build an ``AllToAll`` handle and its ``PplxPrepareAndFinalize``.

    ``AllToAll.internode`` is used when ``group_name`` is None; otherwise
    ``AllToAll.intranode`` is used and ``group_name`` is forwarded to it.

    Returns:
        A ``(prepare_finalize, ata)`` tuple.
    """
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)

    # Size buffers from rank 0's share (rank_chunk gives the lowest ranks
    # the remainder), but never let the token capacity drop below one.
    max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1)
    num_local_experts = rank_chunk(num_experts, 0, world_size)

    hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
        max_num_tokens,
        hidden_dim,
        in_dtype,
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )

    ata_kwargs = {
        "max_num_tokens": max_num_tokens,
        "num_experts": num_experts,
        "experts_per_token": topk,
        "rank": rank,
        "world_size": world_size,
        "dp_size": dp_size,
        "hidden_dim": hidden_dim,
        "hidden_dim_bytes": hidden_dim_bytes,
        "hidden_dim_scale_bytes": scale_bytes,
    }

    if group_name is not None:
        ata_kwargs["group_name"] = group_name
        ata = AllToAll.intranode(**ata_kwargs)
    else:
        ata = AllToAll.internode(**ata_kwargs)

    prepare_finalize = PplxPrepareAndFinalize(
        ata,
        max_num_tokens=max_num_tokens,
        num_local_experts=num_local_experts,
        num_dispatchers=world_size // dp_size,
    )

    return prepare_finalize, ata
def rank_chunk(num: int, r: int, w: int) -> int: def rank_chunk(num: int, r: int, w: int) -> int:
rem = num % w rem = num % w
return (num // w) + (1 if r < rem else 0) return (num // w) + (1 if r < rem else 0)
...@@ -229,70 +256,114 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor: ...@@ -229,70 +256,114 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
return t[(r * chunk):(r + 1) * chunk] return t[(r * chunk):(r + 1) * chunk]
def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, def maybe_chunk_by_rank(t: Optional[torch.Tensor], r: int,
topk_weight: torch.Tensor, topk_ids: torch.Tensor, w: int) -> Optional[torch.Tensor]:
num_experts: int) -> torch.Tensor: if t is not None:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( return chunk_by_rank(t, r, w)
PplxPrepareAndFinalize) else:
return t
def chunk_scales_by_rank(t: Optional[torch.Tensor], r: int,
                         w: int) -> Optional[torch.Tensor]:
    """Slice a scale tensor along dim 0 for rank ``r`` of ``w``.

    ``None`` and tensors with at most one element are returned unchanged.
    Otherwise the slice length is ``rank_chunk(t.shape[0], r, w)`` and the
    slice starts at ``r`` times that length.
    """
    if t is None or t.numel() <= 1:
        return t

    rows = rank_chunk(t.shape[0], r, w)
    begin = r * rows
    return t[begin:begin + rows]
def chunk_scales(t: Optional[torch.Tensor], start: int,
                 end: int) -> Optional[torch.Tensor]:
    """Return ``t[start:end]`` for a multi-element scale tensor.

    ``None`` and tensors holding at most one element are handed back
    untouched (the very same object, not a copy).
    """
    if t is None or t.numel() <= 1:
        return t
    return t[start:end]
def dummy_work(a: torch.Tensor) -> torch.Tensor:
    """Return ``a`` multiplied by the constant 1.1."""
    factor = 1.1
    return a * factor
def pplx_prepare_finalize(
pgi: ProcessGroupInfo,
dp_size: int,
a: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
quant_dtype: Optional[torch.dtype],
block_shape: Optional[list[int]],
per_act_token_quant: bool,
group_name: Optional[str],
) -> torch.Tensor:
assert torch.cuda.current_device() == pgi.local_rank assert torch.cuda.current_device() == pgi.local_rank
topk = topk_ids.shape[1] topk = topk_ids.shape[1]
num_tokens, hidden_dim = a.shape num_tokens, hidden_dim = a.shape
block_size = 128
device = pgi.device device = pgi.device
rank = pgi.rank rank = pgi.rank
world_size = pgi.world_size world_size = pgi.world_size
max_num_tokens = rank_chunk(num_tokens, 0, world_size)
ata = AllToAll.internode(
max_num_tokens=max_num_tokens,
num_experts=num_experts,
experts_per_token=topk,
rank=rank,
world_size=world_size,
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
((hidden_dim + block_size - 1) // block_size *
torch.float32.itemsize)),
)
topk_ids = topk_ids.to(dtype=torch.uint32) topk_ids = topk_ids.to(dtype=torch.uint32)
prepare_finalize = PplxPrepareAndFinalize( prepare_finalize, ata = create_pplx_prepare_finalize(
ata, num_tokens,
max_num_tokens, hidden_dim,
world_size, topk,
num_experts,
rank, rank,
dp_size, dp_size,
world_size,
a.dtype, a.dtype,
quant_dtype,
block_shape,
per_act_token_quant,
group_name,
) )
assert a.shape[0] == topk_ids.shape[0]
a_chunk = chunk_by_rank(a, rank, world_size).to(device) a_chunk = chunk_by_rank(a, rank, world_size).to(device)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
assert a_chunk.shape[0] == chunk_topk_ids.shape[0]
out = torch.full(
a_chunk.shape,
torch.nan,
dtype=a.dtype,
device=device,
)
if (quant_dtype is not None and not per_act_token_quant
and block_shape is None):
a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
else:
a1_scale = None
a2_scale = None
b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare( b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
a_chunk, a_chunk,
None, a1_scale,
None, a2_scale,
chunk_topk_weight, chunk_topk_weight,
chunk_topk_ids, chunk_topk_ids,
num_experts, num_experts,
None, None,
False, False,
FusedMoEQuantConfig(
quant_dtype,
per_act_token_quant,
False,
block_shape,
),
) )
b_a = b_a * 1.5 b_a = dummy_work(
dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype))
out = torch.full(
(max_num_tokens, hidden_dim),
torch.nan,
dtype=a.dtype,
device=device,
)
prepare_finalize.finalize( prepare_finalize.finalize(
out, out,
...@@ -318,61 +389,100 @@ def _pplx_prepare_finalize( ...@@ -318,61 +389,100 @@ def _pplx_prepare_finalize(
score: torch.Tensor, score: torch.Tensor,
topk: torch.Tensor, topk: torch.Tensor,
num_experts: int, num_experts: int,
quant_dtype: Optional[torch.dtype],
block_shape: Optional[list[int]],
per_act_token_quant: bool,
use_internode: bool,
): ):
uid = nvshmem_get_unique_id( try:
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() if use_internode:
torch.distributed.broadcast(uid, src=0) uid = nvshmem_get_unique_id(
nvshmem_init(uid, pgi.rank, pgi.world_size) ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
device = pgi.device torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
group_name = None
else:
group_ranks = list(range(pgi.world_size))
cpu_group = torch.distributed.new_group(group_ranks,
backend="gloo")
group_name = cpu_group.group_name
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
k = a.shape[1] m, k = a.shape
a_rep = torch.repeat_interleave(a, topk, dim=0).to(device)
torch_output = (a_rep.view(-1, topk, k) * 1.5 * a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0)
topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(
a.dtype)
pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids, torch_output = (a_rep.view(m, topk, k) *
num_experts) topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum(
dim=1)
torch_output = chunk_by_rank(torch_output, pgi.rank, pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight,
pgi.world_size).to(pplx_output.device) topk_ids, num_experts, quant_dtype,
block_shape, per_act_token_quant,
group_name)
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pgi.device)
nvshmem_finalize() torch.testing.assert_close(pplx_output,
torch_output,
atol=3e-2,
rtol=3e-2)
finally:
if use_internode:
nvshmem_finalize()
# TODO (bnell): this test point does not work for odd M due to how the test is @pytest.mark.parametrize("mnk", PPLX_COMBOS)
# written, not due to limitations of the pplx kernels. The pplx_moe
# test below is able to deal with odd M.
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.optional
@requires_pplx @requires_pplx
def test_pplx_prepare_finalize( def test_pplx_prepare_finalize_slow(
mnk: tuple[int, int, int], mnk: tuple[int, int, int],
e: int, e: int,
topk: int, topk: int,
dtype: torch.dtype, dtype: torch.dtype,
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
use_internode: bool,
): ):
if dtype == torch.float8_e4m3fn:
use_fp8_w8a8 = True
act_dtype = torch.bfloat16
quant_dtype = dtype
else:
use_fp8_w8a8 = False
act_dtype = dtype
quant_dtype = None
if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
pytest.skip("Skip quantization test for non-quantized type")
if per_act_token_quant and block_shape is not None:
pytest.skip("Skip illegal quantization combination")
current_platform.seed_everything(7) current_platform.seed_everything(7)
m, n, k = mnk m, n, k = mnk
world_size, dp_size = world_dp_size world_size, dp_size = world_dp_size
device = "cuda" device = "cuda"
a = torch.randn((m, k), device=device, dtype=dtype) / 10
score = torch.randn((m, e), device=device, dtype=dtype) a = torch.randn((m, k), device=device, dtype=act_dtype) / 10
score = torch.randn((m, e), device=device, dtype=act_dtype)
parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score, parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
topk, e) topk, e, quant_dtype, block_shape, per_act_token_quant,
use_internode)
def pplx_moe( def pplx_moe(
group_name: Optional[str],
rank: int, rank: int,
world_size: int, world_size: int,
dp_size: int, dp_size: int,
...@@ -381,65 +491,75 @@ def pplx_moe( ...@@ -381,65 +491,75 @@ def pplx_moe(
w2: torch.Tensor, w2: torch.Tensor,
topk_weight: torch.Tensor, topk_weight: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_compile: bool = True, w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
use_compile: bool = False,
use_cudagraphs: bool = True, use_cudagraphs: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
device = torch.device("cuda", rank) num_tokens, hidden_dim = a.shape
hidden_dim = a.shape[1]
num_experts = w1.shape[0] num_experts = w1.shape[0]
block_size = 128
topk = topk_ids.shape[1] topk = topk_ids.shape[1]
max_num_tokens = rank_chunk(a.shape[0], 0, world_size) max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 16)
ata = AllToAll.internode( prepare_finalize, ata = create_pplx_prepare_finalize(
max_num_tokens=max_num_tokens, num_tokens,
num_experts=num_experts, hidden_dim,
experts_per_token=topk, topk,
rank=rank, num_experts,
world_size=world_size, rank,
dp_size=dp_size, dp_size,
hidden_dim=hidden_dim, world_size,
hidden_dim_bytes=hidden_dim * a.dtype.itemsize, a.dtype,
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else quant_dtype,
((hidden_dim + block_size - 1) // block_size * block_shape,
torch.float32.itemsize)), per_act_token_quant,
group_name,
) )
topk_ids = topk_ids.to(dtype=torch.uint32) topk_ids = topk_ids.to(dtype=torch.uint32)
prepare_finalize = PplxPrepareAndFinalize( experts = BatchedTritonExperts(
ata, max_num_tokens=max_num_tokens,
max_num_tokens, num_dispatchers=prepare_finalize.num_dispatchers(),
world_size, use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
rank, block_shape=block_shape,
dp_size, per_act_token_quant=per_act_token_quant,
) )
experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
world_size=world_size,
dp_size=dp_size)
fused_experts = FusedMoEModularKernel( fused_experts = FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
) )
# Note: workers with the same dp_rank must use the exact same inputs. # Note: workers with the same dp_rank must use the exact same inputs.
a_chunk = chunk_by_rank(a, rank, world_size).to(device) a_chunk = chunk_by_rank(a, rank, world_size)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size)
# Chunking weights like this only works for batched format # Chunking weights like this only works for batched format
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device) w1_chunk = chunk_by_rank(w1, rank, world_size)
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device) w2_chunk = chunk_by_rank(w2, rank, world_size)
w1_scale_chunk = maybe_chunk_by_rank(w1_scale, rank, world_size)
w2_scale_chunk = maybe_chunk_by_rank(w2_scale, rank, world_size)
a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
if use_compile: if use_compile:
_fused_experts = torch.compile(fused_experts, _fused_experts = torch.compile(fused_experts,
backend='inductor', backend='inductor',
fullgraph=True) fullgraph=True)
torch._dynamo.mark_dynamic(a_chunk, 0)
torch._dynamo.mark_dynamic(chunk_topk_weight, 0)
torch._dynamo.mark_dynamic(chunk_topk_ids, 0)
else: else:
_fused_experts = fused_experts _fused_experts = fused_experts
...@@ -448,6 +568,10 @@ def pplx_moe( ...@@ -448,6 +568,10 @@ def pplx_moe(
w2_chunk, w2_chunk,
chunk_topk_weight, chunk_topk_weight,
chunk_topk_ids, chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
global_num_experts=num_experts) global_num_experts=num_experts)
if use_cudagraphs: if use_cudagraphs:
...@@ -460,6 +584,10 @@ def pplx_moe( ...@@ -460,6 +584,10 @@ def pplx_moe(
w2_chunk, w2_chunk,
chunk_topk_weight, chunk_topk_weight,
chunk_topk_ids, chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
global_num_experts=num_experts) global_num_experts=num_experts)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -472,48 +600,6 @@ def pplx_moe( ...@@ -472,48 +600,6 @@ def pplx_moe(
return out return out
def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
    """Run the batched-experts modular kernel on this rank's shard.

    The activations, top-k weights, top-k ids and both weight tensors are
    each sliced with ``chunk_by_rank`` for ``pgi.rank`` and moved to
    ``pgi.device`` before the kernel is invoked. The prepare/finalize
    stage is sized from rank 0's token share, while the experts stage is
    sized for all ``a.shape[0]`` tokens with ``world_size=1, dp_size=1``.
    """
    assert torch.cuda.current_device() == pgi.local_rank

    num_experts = w1.shape[0]
    device, rank, world_size = pgi.device, pgi.rank, pgi.world_size
    total_tokens = a.shape[0]

    def shard(t):
        # This rank's slice of ``t``, placed on the rank's device.
        return chunk_by_rank(t, rank, world_size).to(device)

    prepare_finalize = BatchedPrepareAndFinalize(
        max_num_tokens=rank_chunk(total_tokens, 0, world_size),
        world_size=world_size,
        dp_size=dp_size,
        rank=rank,
    )
    experts = BatchedExperts(max_num_tokens=total_tokens,
                             world_size=1,
                             dp_size=1)
    fused_experts = FusedMoEModularKernel(prepare_finalize, experts)

    # Note: workers with the same dp_rank must use the exact same inputs.
    a_chunk = shard(a)
    chunk_topk_weight = shard(topk_weight)
    chunk_topk_ids = shard(topk_ids)

    # Chunking weights like this only works for batched format
    return fused_experts(a_chunk,
                         shard(w1),
                         shard(w2),
                         chunk_topk_weight,
                         chunk_topk_ids,
                         global_num_experts=num_experts)
def _pplx_moe( def _pplx_moe(
pgi: ProcessGroupInfo, pgi: ProcessGroupInfo,
dp_size: int, dp_size: int,
...@@ -522,54 +608,287 @@ def _pplx_moe( ...@@ -522,54 +608,287 @@ def _pplx_moe(
w2: torch.Tensor, w2: torch.Tensor,
score: torch.Tensor, score: torch.Tensor,
topk: int, topk: int,
num_experts: int,
w1_s: Optional[torch.Tensor] = None,
w2_s: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
use_internode: bool = False,
): ):
uid = nvshmem_get_unique_id( try:
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() if use_internode:
torch.distributed.broadcast(uid, src=0) uid = nvshmem_get_unique_id(
nvshmem_init(uid, pgi.rank, pgi.world_size) ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
group_name = None
else:
group_ranks = list(range(pgi.world_size))
cpu_group = torch.distributed.new_group(group_ranks,
backend="gloo")
group_name = cpu_group.group_name
m, k = a.shape
e, _, n = w2.shape
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
device = torch.device("cuda", pgi.rank)
rank = pgi.rank
world_size = pgi.world_size
a = a.to(device)
w1 = w1.to(device)
w2 = w2.to(device)
w1_s = w1_s.to(device) if w1_s is not None else None
w2_s = w2_s.to(device) if w2_s is not None else None
if (quant_dtype is not None and not per_act_token_quant
and block_shape is None):
a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
else:
a1_scale = None
a2_scale = None
with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
torch_output = torch_experts(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
a2_scale=a2_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
batched_output = naive_batched_moe(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
a2_scale=a2_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
pplx_output = pplx_moe(
group_name,
rank,
world_size,
dp_size,
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
a2_scale=a2_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
chunked_batch_output = chunk_by_rank(
batched_output, pgi.rank, pgi.world_size).to(pplx_output.device)
torch.testing.assert_close(batched_output,
torch_output,
atol=3e-2,
rtol=3e-2)
torch.testing.assert_close(pplx_output,
chunked_batch_output,
atol=3e-2,
rtol=3e-2)
finally:
if use_internode:
nvshmem_finalize()
@pytest.mark.parametrize("mnk", PPLX_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.optional
@requires_pplx
def test_pplx_moe_slow(
mnk: tuple[int, int, int],
e: int,
topk: int,
dtype: torch.dtype,
world_dp_size: tuple[int, int],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
use_internode: bool,
):
current_platform.seed_everything(7)
m, n, k = mnk
world_size, dp_size = world_dp_size
m, k = a.shape if dtype == torch.float8_e4m3fn:
e, _, n = w2.shape use_fp8_w8a8 = True
quant_dtype = dtype
else:
use_fp8_w8a8 = False
quant_dtype = None
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False) if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
pytest.skip("Skip quantization test for non-quantized type")
with set_current_vllm_config(vllm_config), override_config(moe_config): if per_act_token_quant and block_shape is not None:
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) pytest.skip("Skip illegal quantization combination")
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
topk_weight, topk_ids)
# TODO (bnell): fix + re-enable
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
# topk_ids)
torch_output = chunk_by_rank(torch_output, pgi.rank, a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
pgi.world_size).to(pplx_output.device) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) _, w1, w1_s, _, w2, w2_s = make_test_weights(
#torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0) e,
n,
k,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
)
nvshmem_finalize() parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,
use_internode)
def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
                    make_weights: bool, test_fn: Callable):
    """Run `test_fn` over every parameter combination inside one worker.

    The combinations are the product of `PPLX_COMBOS`, `NUM_EXPERTS`,
    `TOP_KS`, `DTYPES`, per-token quantization on/off, and block shape
    `None` / `[128, 128]`. Combinations that request quantization options
    for a non-fp8 dtype, or that combine per-token and block quantization,
    are reported and skipped.

    Args:
        pgi: Process-group info for this worker; forwarded to `test_fn`.
        dp_size: Forwarded to `test_fn`.
        use_internode: Forwarded to `test_fn`.
        make_weights: When true, `w1`/`w2` and their scales are generated
            with `make_test_weights` and passed to `test_fn` as keywords.
        test_fn: Callable invoked with keyword arguments for each
            combination that is not skipped.

    Raises:
        RuntimeError: If at least one invocation of `test_fn` raised.
    """

    def format_result(msg, ex=None):
        # Print a pytest-like PASSED/FAILED line; on failure also dump the
        # active traceback (this is only called from inside `except`).
        if ex is not None:
            x = str(ex)
            newx = x.strip(" \n\t")[:16]
            if len(newx) < len(x):
                newx = newx + " ..."

            prefix = "E\t"
            print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
            print(f"FAILED {msg} - {newx}\n")
        else:
            print(f"PASSED {msg}")

    current_platform.seed_everything(7)
    combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES,
                               [False, True], [None, [128, 128]])
    exceptions = []
    count = 0
    for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos:
        m, n, k = mnk

        use_fp8_w8a8 = dtype == torch.float8_e4m3fn
        quant_dtype = dtype if use_fp8_w8a8 else None

        # The closing bracket keeps the id balanced, like a pytest node id.
        test_desc = (f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
                     f"dtype={dtype}, per_act_token={per_act_token_quant}, "
                     f"block_shape={block_shape}]")

        if not use_fp8_w8a8 and (per_act_token_quant
                                 or block_shape is not None):
            print(
                f"{test_desc} - Skip quantization test for non-quantized type."
            )
            continue

        if per_act_token_quant and block_shape is not None:
            print(f"{test_desc} - Skip illegal quantization combination.")
            continue

        # Count only combinations that actually run, so the summary below
        # does not report skipped combinations as passed.
        count += 1

        a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

        args = {}
        if make_weights:
            _, w1, w1_s, _, w2, w2_s = make_test_weights(
                e,
                n,
                k,
                quant_dtype=quant_dtype,
                block_shape=block_shape,
                per_act_token_quant=per_act_token_quant,
            )
            args["w1"] = w1
            args["w2"] = w2
            args["w1_s"] = w1_s
            args["w2_s"] = w2_s

        try:
            test_fn(
                pgi=pgi,
                dp_size=dp_size,
                a=a,
                score=score,
                topk=topk,
                num_experts=e,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
                use_internode=use_internode,
                **args,
            )
            format_result(test_desc)
        except Exception as ex:
            # Deliberately broad: keep running the remaining combinations
            # and report the aggregate failure at the end.
            format_result(test_desc, ex)
            exceptions.append(ex)

    if exceptions:
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}.")
    else:
        print(f"{count} of {count} tests passed in child process, "
              f"rank={pgi.rank}.")
@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx @requires_pplx
def test_pplx_moe( def test_pplx_prepare_finalize(
mnk: tuple[int, int, int],
e: int,
topk: int,
dtype: torch.dtype,
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
use_internode: bool,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
m, n, k = mnk
world_size, dp_size = world_dp_size world_size, dp_size = world_dp_size
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size,
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 use_internode, False, _pplx_prepare_finalize)
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk) @pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx
def test_pplx_moe(
world_dp_size: tuple[int, int],
use_internode: bool,
):
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, True,
_pplx_moe)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
silu_mul_fp8_quant_deep_gemm)
from vllm.platforms import current_platform
# (E, T, H, group_size, seed)
CASES = [
(1, 1, 128, 64, 0),
(1, 4, 128, 128, 0),
(2, 4, 256, 128, 0),
(32, 64, 256, 128, 0),
(17, 31, 768, 128, 0),
]
@pytest.mark.parametrize("E,T,H,group_size,seed", CASES)
@torch.inference_mode()
def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
    """Compare `silu_mul_fp8_quant_deep_gemm` with a per-token torch reference.

    For every expert, only the first `tokens_per_expert[e]` rows are
    compared; the remaining rows are treated as padding and ignored.
    """
    current_platform.seed_everything(seed)

    quant_eps = 1e-10
    num_groups = H // group_size

    # Input tensor of shape (E, T, 2*H)
    y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda")
    # `high` is exclusive for torch.randint, so each expert gets 0..T-1
    # valid tokens.
    tokens_per_expert = torch.randint(
        low=0,
        high=T,
        size=(E, ),
        dtype=torch.int32,
        device="cuda",
    )

    # Kernel under test
    y_q, y_s = silu_mul_fp8_quant_deep_gemm(y,
                                            tokens_per_expert,
                                            group_size=group_size,
                                            eps=quant_eps)

    # Reference: silu(first half) * second half, then per-group fp8 quant
    fp8_limits = torch.finfo(torch.float8_e4m3fn)
    gate = y[..., :H]
    up = y[..., H:]
    merged = (gate * torch.sigmoid(gate)) * up

    for expert in range(E):
        valid = tokens_per_expert[expert].item()
        ref_s = torch.empty((T, num_groups),
                            dtype=torch.float32,
                            device="cuda")
        ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda")

        for tok in range(valid):
            row = merged[expert, tok]
            group_amax = row.view(num_groups,
                                  group_size).abs().amax(dim=1).clamp(
                                      min=quant_eps)
            row_scale = group_amax / fp8_limits.max
            scaled_row = row / row_scale.repeat_interleave(group_size)
            ref_s[tok] = row_scale
            ref_q[tok] = scaled_row.clamp(fp8_limits.min, fp8_limits.max).to(
                torch.float8_e4m3fn)

        torch.testing.assert_close(y_s[expert][:valid], ref_s[:valid])
        torch.testing.assert_close(
            y_q[expert][:valid].to(torch.float32),
            ref_q[:valid].to(torch.float32),
            atol=2,
            rtol=2e-1,
        )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import vllm._custom_ops as ops
from tests.kernels.quant_utils import (per_block_cast_to_fp8,
per_block_cast_to_int8)
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.utils import round_up
def triton_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant=False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    """Call `fused_experts` with the test-suite's common argument set.

    `quant_dtype` is translated into the `use_fp8_w8a8` flag (true only for
    `torch.float8_e4m3fn`) and `per_act_token_quant` is forwarded as
    `per_channel_quant`; every other argument is passed through unchanged.
    """
    is_fp8 = quant_dtype == torch.float8_e4m3fn
    return fused_experts(
        a,
        w1,
        w2,
        topk_weight,
        topk_ids,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        per_channel_quant=per_act_token_quant,
        use_fp8_w8a8=is_fp8,
        block_shape=block_shape,
    )
def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    """Run MoE through a modular kernel built on `BatchedTritonExperts`.

    The kernel pairs `BatchedPrepareAndFinalize` with `BatchedTritonExperts`,
    both configured for a single dispatcher and a token capacity of
    `a.shape[0]` rounded up to a multiple of 64.
    """
    token_capacity = round_up(a.shape[0], 64)
    is_fp8 = quant_dtype == torch.float8_e4m3fn

    prepare_finalize = BatchedPrepareAndFinalize(token_capacity,
                                                 num_dispatchers=1,
                                                 num_local_experts=w1.shape[0],
                                                 rank=0)
    experts = BatchedTritonExperts(
        max_num_tokens=token_capacity,
        num_dispatchers=1,
        use_fp8_w8a8=is_fp8,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )
    kernel = FusedMoEModularKernel(prepare_finalize, experts)

    return kernel(a,
                  w1,
                  w2,
                  topk_weight,
                  topk_ids,
                  w1_scale=w1_scale,
                  w2_scale=w2_scale,
                  a1_scale=a1_scale,
                  a2_scale=a2_scale)
def naive_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    """Run MoE through a modular kernel built on `NaiveBatchedExperts`.

    The kernel pairs `BatchedPrepareAndFinalize` with `NaiveBatchedExperts`,
    both configured for a single dispatcher and a token capacity of
    `a.shape[0]` rounded up to a multiple of 64.
    """
    token_capacity = round_up(a.shape[0], 64)
    is_fp8 = quant_dtype == torch.float8_e4m3fn

    prepare_finalize = BatchedPrepareAndFinalize(token_capacity,
                                                 num_dispatchers=1,
                                                 num_local_experts=w1.shape[0],
                                                 rank=0)
    experts = NaiveBatchedExperts(
        max_num_tokens=token_capacity,
        num_dispatchers=1,
        use_fp8_w8a8=is_fp8,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )
    kernel = FusedMoEModularKernel(prepare_finalize, experts)

    return kernel(a,
                  w1,
                  w2,
                  topk_weight,
                  topk_ids,
                  w1_scale=w1_scale,
                  w2_scale=w2_scale,
                  a1_scale=a1_scale,
                  a2_scale=a2_scale)
def chunk_scales(scales: Optional[torch.Tensor], start: int,
                 end: int) -> Optional[torch.Tensor]:
    """Return the `[start:end]` slice of `scales`.

    `None` is passed through, and a single-element tensor is returned
    unchanged (the same object) instead of being sliced.
    """
    if scales is None:
        return None
    return scales if scales.numel() == 1 else scales[start:end]
def make_quantized_test_activations(
    E: int,
    m: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: Optional[torch.dtype] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Create random `(E, m, k)` CUDA activations and their quantized form.

    Returns `(a, a_q, a_scale)`. Without a `quant_dtype`, `a_q` is `a`
    itself and `a_scale` is `None`. Otherwise each expert slice is quantized
    independently with `moe_kernel_quantize_input` and the per-expert scales
    are stacked; for per-tensor quantization (no block shape, no per-token
    flag) the stacked scales are viewed as `(E, 1, 1)`.
    """
    a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
    if quant_dtype is None:
        return a, a, None

    assert quant_dtype in (torch.float8_e4m3fn,
                           torch.int8), "only fp8/int8 supported"

    a_q = torch.zeros_like(a, dtype=quant_dtype)
    expert_scales = []
    for idx in range(E):
        a_q[idx], scale = moe_kernel_quantize_input(a[idx], None, quant_dtype,
                                                    per_act_token_quant,
                                                    block_shape)
        expert_scales.append(scale)

    a_scale = torch.stack(expert_scales)
    if block_shape is None and not per_act_token_quant:
        a_scale = a_scale.view(E, 1, 1)

    return a, a_q, a_scale
def moe_quantize_weights(
    w: torch.Tensor,
    w_s: Optional[torch.Tensor],
    quant_dtype: Optional[torch.dtype],
    per_token_quant: bool,
    block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Quantize one weight matrix to fp8 or int8 and return `(w_q, scale)`.

    With a `block_shape`, the block-wise cast helpers are used (per-token
    quantization is not allowed in that mode). Otherwise the
    `ops.scaled_*_quant` routine for the requested dtype is called with
    `w_s` and the per-token flag.
    """
    assert quant_dtype in (torch.float8_e4m3fn,
                           torch.int8), "only fp8/int8 supported"
    to_int8 = quant_dtype == torch.int8

    if block_shape is not None:
        assert not per_token_quant
        block_cast = (per_block_cast_to_int8
                      if to_int8 else per_block_cast_to_fp8)
        w, w_s = block_cast(w, block_shape)
        return w, w_s

    scaled_quant = ops.scaled_int8_quant if to_int8 else ops.scaled_fp8_quant
    w, w_s = scaled_quant(w, w_s, use_per_token_if_dynamic=per_token_quant)
    return w, w_s
def make_test_weight(
    e: int,
    rows: int,
    cols: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: Optional[torch.dtype] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Create a random `(e, rows, cols)` CUDA weight and its quantized form.

    Returns `(w_16, w, w_s)`: the unquantized weight, the weight to feed the
    kernel, and its scales. Without a `quant_dtype`, `w` is `w_16` and `w_s`
    is `None`. Otherwise each expert is quantized separately with
    `moe_quantize_weights` and the results are stacked.
    """
    w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
    if quant_dtype is None:
        return w_16, w_16, None

    per_expert = [
        moe_quantize_weights(w_16[idx], None, quant_dtype,
                             per_act_token_quant, block_shape)
        for idx in range(e)
    ]
    w = torch.stack([w_q for w_q, _ in per_expert])
    w_s = torch.stack([scale for _, scale in per_expert])

    # Stacked scales that come out 2-D (trailing dim 1) become (e, 1, 1).
    if w_s.ndim == 2:
        assert w_s.shape[-1] == 1
        w_s = w_s.view(-1, 1, 1)

    if block_shape is not None:
        # Sanity-check that there is one scale per (block_n, block_k) tile.
        block_n, block_k = block_shape
        n_tiles = (rows + block_n - 1) // block_n
        k_tiles = (cols + block_k - 1) // block_k
        assert w_s.shape == (e, n_tiles, k_tiles)

    return w_16, w, w_s
def make_test_weights(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: Optional[torch.dtype] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor,
           torch.Tensor, Optional[torch.Tensor]]:
    """Create both MoE weights: `w1` as `(e, 2*n, k)` and `w2` as `(e, k, n)`.

    Returns the two `make_test_weight` triples concatenated, i.e.
    `(w1_16, w1, w1_s, w2_16, w2, w2_s)`. `w1` is generated before `w2`.
    """
    w1_triple = make_test_weight(e, 2 * n, k, in_dtype, quant_dtype,
                                 block_shape, per_act_token_quant)
    w2_triple = make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
                                 per_act_token_quant)
    return (*w1_triple, *w2_triple)
...@@ -5,7 +5,10 @@ from typing import Optional, Union ...@@ -5,7 +5,10 @@ from typing import Optional, Union
import torch import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import round_up
# Using the default value (240.0) from pytorch will cause accuracy # Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm. # issue on dynamic quantization models. Here use 224.0 for rocm.
...@@ -94,9 +97,15 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ ...@@ -94,9 +97,15 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
return ref_out, ref_scale.view((1, )) return ref_out, ref_scale.view((1, ))
def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, def native_w8a8_block_matmul(
As: torch.Tensor, Bs: torch.Tensor, block_size, A: torch.Tensor,
output_dtype): B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
compute_type: torch.dtype = torch.float32,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise """This function performs matrix multiplication with block-wise
quantization using native torch. quantization using native torch.
It is agnostic to the input data type and can be used for both int8 and It is agnostic to the input data type and can be used for both int8 and
...@@ -106,8 +115,8 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, ...@@ -106,8 +115,8 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
`Bs` (float32). `Bs` (float32).
The output is returned in the specified `output_dtype`. The output is returned in the specified `output_dtype`.
""" """
A = A.to(torch.float32) A = A.to(compute_type)
B = B.to(torch.float32) B = B.to(compute_type)
assert A.shape[-1] == B.shape[-1] assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2 assert len(block_size) == 2
...@@ -122,11 +131,11 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, ...@@ -122,11 +131,11 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
As = As.reshape(M, As.shape[-1]) As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0] assert n_tiles == Bs.shape[0], f"{n_tiles} == {Bs.shape[0]}"
assert k_tiles == Bs.shape[1] assert k_tiles == Bs.shape[1], f"{k_tiles} == {Bs.shape[1]}"
C_shape = (M, N) C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) C = torch.zeros(C_shape, dtype=compute_type, device=A.device)
A_tiles = [ A_tiles = [
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
...@@ -152,3 +161,170 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, ...@@ -152,3 +161,170 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
C = C.reshape(origin_C_shape).to(output_dtype) C = C.reshape(origin_C_shape).to(output_dtype)
return C return C
def native_per_token_group_quant_fp8(x,
group_size,
eps=1e-10,
dtype=torch.float8_e4m3fn):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch."""
assert x.shape[-1] % group_size == 0, ("the last dimension of `x` must "
"be divisible by `group_size`")
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1,
keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
return x_q, x_s
def native_per_token_group_quant_int8(x,
                                      group_size,
                                      eps=1e-10,
                                      dtype=torch.int8):
    """Quantize `x` per group of `group_size` trailing elements, in torch.

    Each group is scaled so its absolute maximum (floored at `eps`) maps to
    the largest value of the integer `dtype`; values are rounded and then
    clamped to the dtype's range. Returns `(x_q, x_s)` where `x_q` has the
    shape of `x` and `x_s` holds one float32 scale per group, shaped
    `x.shape[:-1] + (x.shape[-1] // group_size,)`.
    """
    assert (x.shape[-1] % group_size == 0
            ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    limits = torch.iinfo(dtype)
    lo, hi = limits.min, limits.max

    groups = x.reshape(x.numel() // group_size, group_size)
    # Scales are computed in float32 for stability.
    group_amax = groups.abs().max(dim=-1, keepdim=True).values
    group_amax = group_amax.clamp(min=eps).to(torch.float32)
    scales = group_amax / hi

    # Round first, then clamp into the representable range.
    rounded = (groups.to(torch.float32) / scales).round()
    quantized = rounded.clamp(min=lo, max=hi).to(dtype)

    scale_shape = x.shape[:-1] + (x.shape[-1] // group_size, )
    return quantized.reshape(x.shape), scales.reshape(scale_shape)
DEFAULT_BLOCK_SHAPE = [128, 128]
def per_block_cast_to_fp8(
    x: torch.Tensor,
    block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to float8_e4m3fn with one scale per block.

    `x` is zero-padded up to multiples of `block_shape`, every
    `(block_m, block_n)` tile is scaled so its absolute maximum (floored at
    1e-4) maps to 448.0, and the padding is cropped off again. Returns the
    quantized tensor (shape of `x`) and the per-tile scales, shaped
    `(padded_rows // block_m, padded_cols // block_n)`.
    """
    block_m, block_n = block_shape
    assert x.dim() == 2
    m, n = x.shape

    padded_m = round_up(m, block_m)
    padded_n = round_up(n, block_n)
    x_padded = torch.zeros((padded_m, padded_n),
                           dtype=x.dtype,
                           device=x.device)
    x_padded[:m, :n] = x

    # (row_tiles, block_m, col_tiles, block_n): dims 1 and 3 span one tile.
    tiles = x_padded.view(-1, block_m, padded_n // block_n, block_n)
    tile_amax = tiles.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)

    quantized = (tiles * (448.0 / tile_amax)).to(torch.float8_e4m3fn)
    cropped = quantized.view_as(x_padded)[:m, :n].contiguous()
    scales = (tile_amax / 448.0).view(tiles.size(0), tiles.size(2))
    return cropped, scales
def per_block_cast_to_int8(
    x: torch.Tensor,
    block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to int8 with one scale per block.

    `x` is zero-padded up to multiples of `block_shape`, every
    `(block_m, block_n)` tile is scaled so its absolute maximum (floored at
    1e-4) maps to the largest int8 value, and the padding is cropped off
    again.

    Returns:
        The int8 tensor (shape of `x`) and the float32 per-tile scales,
        shaped `(padded_rows // block_m, padded_cols // block_n)`, such that
        `x ~= x_q * scale` tile by tile.
    """
    block_m, block_n = block_shape
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
                           dtype=x.dtype,
                           device=x.device)
    x_padded[:m, :n] = x
    # (row_tiles, block_m, col_tiles, block_n): dims 1 and 3 span one tile.
    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)

    # Map amax to the int8 maximum (127). Scaling to 256 pushes every value
    # with |x| >= amax / 2 outside int8's range, and the float -> int8 cast
    # of out-of-range values is not well defined. Round (instead of
    # truncating) and clamp to guard against float error at the boundary.
    int8_info = torch.iinfo(torch.int8)
    x_scaled = (x_view * (int8_info.max / x_amax)).round().clamp(
        min=int8_info.min, max=int8_info.max).to(torch.int8)
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / int8_info.max).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales
def dequant(
    t: torch.Tensor,
    scale: Optional[torch.Tensor],
    block_shape: Optional[list[int]],
    per_act_token_quant: bool,
    out_dtype: Optional[torch.dtype] = torch.float32,
) -> torch.Tensor:
    """Dequantize `t` by multiplying with `scale` in float32.

    Without a scale, `t` is only converted to `out_dtype`. For block-wise
    quantization (a `block_shape` and no per-token flag) the scale is first
    expanded to `t`'s shape with `group_broadcast`; otherwise it is applied
    through ordinary broadcasting.
    """
    if scale is None:
        return t.to(out_dtype)

    t_f32 = t.to(torch.float32)
    if per_act_token_quant or block_shape is None:
        factor = scale
    else:
        factor = group_broadcast(scale, t.shape)
    return (t_f32 * factor).to(out_dtype)
def batched_dequant(
    t: torch.Tensor,
    scale: Optional[torch.Tensor],
    block_shape: Optional[list[int]],
    per_act_token_quant: bool,
    out_dtype: Optional[torch.dtype] = torch.float32,
) -> torch.Tensor:
    """Apply `dequant` independently to every slice along dim 0 of `t`.

    `scale` must have the same leading dimension as `t`. Without a scale,
    `t` is only converted to `out_dtype`.
    """
    if scale is None:
        return t.to(out_dtype)

    num_batches = t.shape[0]
    assert num_batches == scale.shape[0]

    result = torch.empty_like(t, dtype=out_dtype)
    for idx in range(num_batches):
        result[idx] = dequant(t[idx], scale[idx], block_shape,
                              per_act_token_quant, out_dtype)
    return result
def native_batched_masked_quant_matmul(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
num_expert_tokens: torch.Tensor,
A_scale: Optional[torch.Tensor] = None,
B_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> torch.Tensor:
num_expert_tokens_cpu = num_expert_tokens.clone()
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
num_experts = num_expert_tokens.size(0)
for e in range(num_experts):
num_tokens = num_expert_tokens_cpu[e]
if A.dtype.itemsize == 1 and block_shape is not None:
assert A_scale is not None and B_scale is not None
tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e],
block_shape, C.dtype)
C[e, :num_tokens, :] = tmp[:num_tokens, :]
elif A.dtype.itemsize == 1 and block_shape is None:
assert A_scale is not None and B_scale is not None
A_dq = dequant(A[e], A_scale[e], block_shape, per_act_token_quant)
B_dq = dequant(B[e], B_scale[e], block_shape, per_act_token_quant)
C[e, :num_tokens, :] = (
A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype)
else:
assert A_scale is None
assert B_scale is None
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
return C
...@@ -7,15 +7,10 @@ import itertools ...@@ -7,15 +7,10 @@ import itertools
import pytest import pytest
import torch import torch
from tests.kernels.quant_utils import native_w8a8_block_matmul from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
from vllm.config import VllmConfig, set_current_vllm_config native_w8a8_block_matmul,
from vllm.model_executor.layers.activation import SiluAndMul per_block_cast_to_fp8)
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.config import VllmConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul) per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -45,78 +40,10 @@ N = [128, 512, 7168, 7748, 13824] ...@@ -45,78 +40,10 @@ N = [128, 512, 7168, 7748, 13824]
K = [256, 3884, 4096, 13824, 16384] K = [256, 3884, 4096, 13824, 16384]
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
# and its hidden size is 7168. # and its hidden size is 7168.
M_moe = [1, 2, 7, 83, 128, 2048]
M_moe_dg = [128, 192, 1335, 2048]
N_moe = [128, 256, 1024, 4608] # [13824]
K_moe = [256, 512, 7168] # [13824]
BLOCK_SIZE = [[128, 128]] BLOCK_SIZE = [[128, 128]]
E = [2, 8, 16, 24] # [128, 256]
TOP_KS = [1, 2, 6]
OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16]
SEEDS = [0] SEEDS = [0]
def native_per_token_group_quant_fp8(x,
group_size,
eps=1e-10,
dtype=torch.float8_e4m3fn):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch."""
assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot "
"be divisible by `group_size`")
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1,
keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
return x_q, x_s
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """Fused moe with block-wise quantization using native torch.

    Every token is replicated `topk` times and routed by the top-k of
    `softmax(score)`. Per expert, the selected rows go through a block-wise
    quantized matmul with `w1`, `SiluAndMul`, per-group re-quantization, and
    a second block-wise quantized matmul with `w2`. The expert outputs are
    combined per token as a sum weighted by the top-k softmax scores.
    """
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    # Activations are quantized in groups along K, so only block_k is needed.
    block_k = block_shape[1]
    a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
    a_q = a_q.to(torch.float32)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if not mask.any():
            # No token routed to this expert.
            continue
        inter_out = native_w8a8_block_matmul(a_q[mask],
                                             w1[i],
                                             a_s[mask],
                                             w1_s[i],
                                             block_shape,
                                             output_dtype=a.dtype)
        act_out = SiluAndMul().forward_native(inter_out)
        act_out_q, act_out_s = native_per_token_group_quant_fp8(
            act_out, block_k)
        out[mask] = native_w8a8_block_matmul(act_out_q,
                                             w2[i],
                                             act_out_s,
                                             w2_s[i],
                                             block_shape,
                                             output_dtype=a.dtype)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
# Skip all tests if CUDA is not available # Skip all tests if CUDA is not available
pytest.importorskip("torch.cuda") pytest.importorskip("torch.cuda")
...@@ -176,89 +103,6 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): ...@@ -176,89 +103,6 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
assert rel_diff < 0.001 assert rel_diff < 0.001
@pytest.mark.parametrize(
    "M,N,K,E,topk,block_size,dtype,seed",
    itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES,
                      SEEDS))
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    """Check `fused_moe` with block-wise fp8 weights against the torch
    reference; the mean relative difference must stay below 3%."""
    if topk > E:
        pytest.skip(f"Skipping test; topk={topk} > E={E}")

    torch.manual_seed(seed)
    scale_factor = 1e-2
    fp8_limits = torch.finfo(torch.float8_e4m3fn)

    def random_fp8(shape):
        # Uniform values spanning the fp8 range, clamped and cast to fp8.
        wide = (torch.rand(shape, dtype=torch.bfloat16) -
                0.5) * 2 * fp8_limits.max
        return wide.clamp(min=fp8_limits.min,
                          max=fp8_limits.max).to(torch.float8_e4m3fn)

    def num_tiles(extent, block):
        return (extent + block - 1) // block

    # NOTE: tensors are created in the same order as before so the random
    # stream (and thus the test data) is unchanged.
    a = torch.randn((M, K), dtype=dtype) / 10
    w1 = random_fp8((E, 2 * N, K))
    w2 = random_fp8((E, K, N))

    block_n, block_k = block_size[0], block_size[1]
    w1_s = torch.rand(
        (E, num_tiles(2 * N, block_n), num_tiles(K, block_k)),
        dtype=torch.float32) * scale_factor
    w2_s = torch.rand(
        (E, num_tiles(K, block_n), num_tiles(N, block_k)),
        dtype=torch.float32) * scale_factor

    score = torch.randn((M, E), dtype=dtype)

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        out = fused_moe(
            a,
            w1,
            w2,
            score,
            topk,
            renormalize=False,
            use_fp8_w8a8=True,
            w1_scale=w1_s,
            w2_scale=w2_s,
            block_shape=block_size,
        )
        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
                                           block_size)

    out_f32 = out.to(torch.float32)
    ref_f32 = ref_out.to(torch.float32)
    rel_diff = (torch.mean(torch.abs(out_f32 - ref_f32)) /
                torch.mean(torch.abs(ref_f32)))
    assert rel_diff < 0.03
def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to float8_e4m3fn with one scale per block.

    `x` is zero-padded to multiples of `(128, block_size_n)`, every
    `128 x block_size_n` tile is scaled so its absolute maximum (floored at
    1e-4) maps to 448.0, and the padding is cropped off again.

    Returns:
        The quantized tensor (shape of `x`) and the float32 per-tile scales,
        shaped `(padded_rows // 128, padded_cols // block_size_n)`.
    """
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        (deep_gemm.ceil_div(m, 128) * 128,
         deep_gemm.ceil_div(n, block_size_n) * block_size_n),
        dtype=x.dtype,
        device=x.device)
    x_padded[:m, :n] = x
    # The number of column tiles must be derived from block_size_n; dividing
    # by a hard-coded 128 only yields the right tiling when block_size_n is
    # 128 and silently mis-groups elements otherwise.
    x_view = x_padded.view(-1, 128,
                           x_padded.size(1) // block_size_n, block_size_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed", "M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
...@@ -301,152 +145,3 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): ...@@ -301,152 +145,3 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32)))) torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.001 assert rel_diff < 0.001
def fp8_perm(m, idx):
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
else:
return m[idx, ...]
def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
    """Reorder activation rows (and their scales) into expert-sorted order.

    Returns ``(a, a_s, m_indices, inv_perm)``: the gathered activations, the
    gathered scales (or ``None`` when ``a_s`` is ``None``), ``m_indices``
    repeated ``block_m`` times per entry, and the first ``M * topk`` entries
    of the argsort of the sorted token ids.
    """
    num_rows, _ = a.shape
    total_slots = topk * num_rows

    sorted_token_ids, m_indices, _ = moe_align_block_size(
        topk_ids, block_m, num_groups, None, pad_sorted_ids=True)

    # Keep every index inside [0, total_slots); presumably the padding
    # entries produced by the alignment step point past the last real slot
    # -- TODO confirm against moe_align_block_size.
    sorted_token_ids = sorted_token_ids.clamp(max=total_slots - 1)
    m_indices = torch.repeat_interleave(m_indices, block_m, dim=0)
    inv_perm = torch.argsort(sorted_token_ids)[:total_slots]

    # Each token owns `topk` consecutive slots, so slot // topk is its row.
    source_rows = sorted_token_ids // topk
    a = fp8_perm(a, source_rows)
    if a_s is not None:
        a_s = a_s[source_rows]

    return a, a_s, m_indices, inv_perm
def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
M = topk_weight.shape[0]
out = out[inv_perm, ...]
tmp_out = out.view(-1, topk, K)
return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
                                 block_shape):
    """Fused moe with block-wise quantization using DeepGemm grouped gemm.

    Steps: top-k routing, FP8 quantization of the activations, permutation
    of the rows via `_moe_permute`, grouped GEMM against `w1`, SiluAndMul,
    FP8 re-quantization, grouped GEMM against `w2`, and finally
    `_moe_unpermute` to restore token order and apply the routing weights.

    NOTE(review): the `M` and `K` arguments are ignored; both are
    re-derived from `a.shape` below.
    """
    num_groups = w1.shape[0]
    M, K = a.shape
    N = w2.shape[-1]

    # `token_expert_indices` is returned by fused_topk but not used here.
    topk_weight, topk_ids, token_expert_indices = fused_topk(
        a, score.float(), topk, False)

    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()

    _, block_k = block_shape[0], block_shape[1]

    # NOTE(review): the input activations are quantized with group size
    # `block_m`, whereas the intermediate activations below use `block_k`
    # -- presumably equivalent because the caller passes
    # block_shape == [block_m, block_m]; confirm.
    a_q, a_s = per_token_group_quant_fp8(a, block_m)

    a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids,
                                                 num_groups, topk, block_m)

    # First grouped GEMM writes into a buffer of width 2 * N.
    inter_out = torch.zeros((a_q.shape[0], N * 2),
                            dtype=torch.bfloat16,
                            device=a.device)

    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
                                                        inter_out, m_indices)

    act_out = SiluAndMul().forward_native(inter_out)
    act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)

    # Second grouped GEMM writes into a buffer of width K.
    out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)

    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        (act_out_q, act_out_s), (w2, w2_s), out, m_indices)

    final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)

    return final_out
@pytest.mark.parametrize(
    "M,N,K,E,topk,seed",
    itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS))
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
    """Compare `deep_gemm_moe_fp8` against a reference implementation.

    Random bf16 weights are block-quantized to FP8 with
    `per_block_cast_to_fp8`; the reference is the DeepGemm-based helper for
    M >= 128 and the native torch helper otherwise. The test passes when the
    mean relative difference is below 3%.
    """
    # The block size in both dimensions is DeepGemm's M alignment.
    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
    block_size = [block_m, block_m]
    dtype = torch.bfloat16

    if topk > E:
        pytest.skip(f"Skipping test: topk={topk} > E={E}")

    if not _valid_deep_gemm_shape(M, N, K):
        pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")

    # NOTE: the order of the random draws below determines the test data
    # for a given seed; do not reorder.
    torch.manual_seed(seed)
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    a = torch.randn((M, K), dtype=dtype) / 10

    # Weights uniformly distributed over the representable FP8 range.
    w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 *
               fp8_max).clamp(min=fp8_min, max=fp8_max)

    w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 *
               fp8_max).clamp(min=fp8_min, max=fp8_max)

    score = torch.randn((M, E), dtype=dtype)

    # Number of scale tiles per weight matrix (ceil division).
    block_n, block_k = block_size[0], block_size[1]
    n_tiles_w1 = ((2 * N) + block_n - 1) // block_n
    k_tiles_w1 = (K + block_k - 1) // block_k
    n_tiles_w2 = (K + block_n - 1) // block_n
    k_tiles_w2 = (N + block_k - 1) // block_k

    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)

    w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
    w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)

    # The scale buffers are passed through DeepGemm's layout helper before
    # being filled per expert in the loop below.
    w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
    w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()

    # NOTE(review): this check hard-codes a 128 tile size, i.e. it assumes
    # block_m == 128 -- confirm.
    assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
    assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]

    # Quantize each expert's weights in place, block by block.
    for i in range(E):
        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        # The DeepGemm reference is only used for M >= 128; smaller batches
        # fall back to the native torch reference.
        if M >= 128:
            ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
                                                   score, topk, block_size)
        else:
            ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
                                               topk, block_size)

        topk_weights, topk_ids, token_expert_indices = fused_topk(
            a, score.float(), topk, False)

        out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)

    #print(f"{out.sum()=}")
    #print(f"{ref_out.sum()=}")

    # Mean absolute error normalized by the mean magnitude of the reference.
    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))

    assert rel_diff < 0.03
...@@ -8,9 +8,7 @@ import pytest ...@@ -8,9 +8,7 @@ import pytest
import torch import torch
from tests.kernels.quant_utils import native_w8a8_block_matmul from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.quantization.utils.int8_utils import ( from vllm.model_executor.layers.quantization.utils.int8_utils import (
w8a8_block_int8_matmul) w8a8_block_int8_matmul)
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -23,82 +21,10 @@ vllm_config = VllmConfig() ...@@ -23,82 +21,10 @@ vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192 vllm_config.scheduler_config.max_model_len = 8192
# For test
def native_per_token_group_quant_int8(x,
                                      group_size,
                                      eps=1e-10,
                                      dtype=torch.int8):
    """Function to perform per-token-group quantization on an input tensor
    `x` using native torch.

    It converts the tensor values into int8 values and returns the
    quantized tensor along with the scaling factor used for quantization.

    Args:
        x: Contiguous tensor whose last dimension is a multiple of
            `group_size`.
        group_size: Number of consecutive elements sharing one scale.
        eps: Lower bound applied to each group's absolute maximum so the
            scale is never zero.
        dtype: Integer dtype of the quantized output.

    Returns:
        `(x_q, x_s)`: `x_q` has the shape of `x` and dtype `dtype`; `x_s` is
        a float32 tensor of shape `x.shape[:-1] + (x.shape[-1] // group_size,)`.

    Raises:
        AssertionError: if the last dimension is not divisible by
            `group_size` or `x` is not contiguous.
    """
    assert x.shape[-1] % group_size == 0, (
        "the last dimension of `x` must be divisible by `group_size`")
    assert x.is_contiguous(), "`x` is not contiguous"

    iinfo = torch.iinfo(dtype)
    int8_min = iinfo.min
    int8_max = iinfo.max

    # One row per quantization group.
    x_ = x.reshape(x.numel() // group_size, group_size)
    # Use float32 for scale calculation for stability
    amax = x_.abs().max(dim=-1,
                        keepdim=True)[0].clamp(min=eps).to(torch.float32)
    x_s = amax / int8_max
    x_q = (x_.to(torch.float32) / x_s).round().clamp(
        min=int8_min, max=int8_max).to(dtype)  # Round before clamping
    x_q = x_q.reshape(x.shape)
    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))

    return x_q, x_s
# For test
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """This function performs fused moe with block-wise quantization using
    native torch.

    Each token is replicated `topk` times, routed with a softmax + top-k over
    `score`, and pushed through the selected expert's two block-quantized
    matmuls with a SiluAndMul in between. The per-expert results are combined
    with the routing weights.

    Returns a tensor of shape `(a.shape[0], w2.shape[1])`.
    """
    B, D = a.shape
    # One copy of every token per selected expert.
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    _, block_k = block_shape[0], block_shape[1]
    a_q, a_s = native_per_token_group_quant_int8(a, block_k)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        # Skip experts that received no tokens.
        if mask.sum():
            inter_out = native_w8a8_block_matmul(a_q[mask],
                                                 w1[i],
                                                 a_s[mask],
                                                 w1_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
            act_out = SiluAndMul().forward_native(inter_out)
            # Only the quantized activations feed the second matmul.
            act_out_q, act_out_s = native_per_token_group_quant_int8(
                act_out, block_k)
            out[mask] = native_w8a8_block_matmul(act_out_q,
                                                 w2[i],
                                                 act_out_s,
                                                 w2_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
DTYPES = [torch.half, torch.bfloat16] DTYPES = [torch.half, torch.bfloat16]
M = [1, 33, 64, 222] M = [1, 33, 64, 222]
N = [128, 1024] N = [128, 1024]
K = [256, 4096] K = [256, 4096]
E = [8, 24]
TOP_KS = [2, 6]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]] # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]] BLOCK_SIZE = [[128, 128]]
SEEDS = [0] SEEDS = [0]
...@@ -140,63 +66,3 @@ def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed): ...@@ -140,63 +66,3 @@ def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32)))) torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.001 assert rel_diff < 0.001
@pytest.mark.parametrize(
    "M, N, K, E, topk, block_size, dtype, seed",
    itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    """Tests the fused_moe kernel with W8A8 INT8 block quantization against a
    native torch reference.

    The test passes when the mean relative difference between the two
    outputs is below 6%.
    """
    # NOTE: the order of the random draws below determines the test data
    # for a given seed; do not reorder.
    torch.manual_seed(seed)
    # Use a smaller factor for scale initialization to prevent large
    # values/overflow especially when output dtype might be float16
    factor_for_scale = 1e-2
    int8_info = torch.iinfo(torch.int8)
    int8_max, int8_min = int8_info.max, int8_info.min

    a = torch.randn((M, K), dtype=dtype) / 10

    # Weights drawn uniformly and clamped to the int8 range before the cast.
    w1_fp32 = (torch.rand(
        (E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max
    w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)

    w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max
    w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)

    # Number of scale tiles per weight matrix (ceil division).
    block_n, block_k = block_size[0], block_size[1]
    n_tiles_w1 = (2 * N + block_n - 1) // block_n
    n_tiles_w2 = (K + block_n - 1) // block_n
    k_tiles_w1 = (K + block_k - 1) // block_k
    k_tiles_w2 = (N + block_k - 1) // block_k

    w1_s = (torch.rand(
        (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale)
    w2_s = (torch.rand(
        (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale)

    score = torch.randn((M, E), dtype=dtype)

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        out = fused_moe(
            a,
            w1,
            w2,
            score,
            topk,
            renormalize=False,
            use_int8_w8a8=True,
            w1_scale=w1_s,
            w2_scale=w2_s,
            block_shape=block_size,
        )
        ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
                                            block_size)

    # Check results: mean absolute error normalized by the mean magnitude
    # of the reference output.
    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    assert rel_diff < 0.06
...@@ -11,6 +11,7 @@ from vllm.platforms import current_platform ...@@ -11,6 +11,7 @@ from vllm.platforms import current_platform
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 5137, 8193] # Arbitrary values for testing HIDDEN_SIZES = [16, 67, 768, 5137, 8193] # Arbitrary values for testing
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
SEEDS = [0] SEEDS = [0]
SCALE = [0.1, 2.1] SCALE = [0.1, 2.1]
......
...@@ -14,6 +14,8 @@ import torch ...@@ -14,6 +14,8 @@ import torch
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
query_machete_supported_group_sizes)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_rows, quantize_weights) pack_rows, quantize_weights)
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -46,8 +48,6 @@ MNK_SHAPES = [ ...@@ -46,8 +48,6 @@ MNK_SHAPES = [
(1024, 8192, 4096), (1024, 8192, 4096),
] ]
GROUP_SIZES_TO_TEST: list[Optional[int]] = [128, -1]
@dataclass @dataclass
class TypeConfig: class TypeConfig:
...@@ -139,7 +139,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): ...@@ -139,7 +139,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
def group_size_valid(shape: tuple[int, int, int], def group_size_valid(shape: tuple[int, int, int],
group_size: Optional[int]) -> bool: group_size: Optional[int]) -> bool:
return group_size is None or group_size == -1 or group_size % shape[2] == 0 return group_size is None or group_size == -1 or shape[2] % group_size == 0
def machete_quantize_and_pack(atype: torch.dtype, def machete_quantize_and_pack(atype: torch.dtype,
...@@ -270,7 +270,7 @@ def test_machete_all_schedules(shape, types: TypeConfig): ...@@ -270,7 +270,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
if types.group_scale_type is None: if types.group_scale_type is None:
group_sizes = [None] group_sizes = [None]
else: else:
group_sizes = GROUP_SIZES_TO_TEST group_sizes = query_machete_supported_group_sizes(types.act_type)
for group_size in group_sizes: for group_size in group_sizes:
if not group_size_valid(shape, group_size): if not group_size_valid(shape, group_size):
...@@ -299,7 +299,7 @@ def test_machete_heuristic(shape, types: TypeConfig): ...@@ -299,7 +299,7 @@ def test_machete_heuristic(shape, types: TypeConfig):
if types.group_scale_type is None: if types.group_scale_type is None:
group_sizes = [None] group_sizes = [None]
else: else:
group_sizes = GROUP_SIZES_TO_TEST group_sizes = query_machete_supported_group_sizes(types.act_type)
for group_size in group_sizes: for group_size in group_sizes:
if not group_size_valid(shape, group_size): if not group_size_valid(shape, group_size):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
import torch import torch
...@@ -74,3 +75,51 @@ def test_apply_repetition_penalties( ...@@ -74,3 +75,51 @@ def test_apply_repetition_penalties(
# Test the operator by applying the opcheck utility # Test the operator by applying the opcheck utility
opcheck(torch.ops._C.apply_repetition_penalties_, opcheck(torch.ops._C.apply_repetition_penalties_,
(logits.clone(), prompt_mask, output_mask, repetition_penalties)) (logits.clone(), prompt_mask, output_mask, repetition_penalties))
@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="This test for checking CUDA kernel")
@torch.inference_mode()
def test_apply_repetition_penalties_zero_seqs() -> None:
    """
    Test the apply_repetition_penalties custom op with num_seqs=0
    against a reference implementation.

    All tensors have a leading dimension of zero, so this exercises the
    empty-batch path of both implementations and of the registered op.
    """
    num_seqs = 0
    vocab_size = 17
    repetition_penalty = 1.05
    dtype = torch.float32
    seed = 0

    current_platform.seed_everything(seed)
    torch.set_default_device("cuda:0")

    # Create test data
    logits = torch.randn(num_seqs, vocab_size, dtype=dtype)

    # Create empty (0 x vocab_size) masks
    prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
    output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)

    # No tokens to mark as repeated since num_seqs=0

    # Create repetition penalties tensor
    repetition_penalties = torch.full((num_seqs, ),
                                      repetition_penalty,
                                      dtype=dtype)

    # Run both implementations on independent copies of the logits
    logits_torch = logits.clone()
    logits_cuda = logits.clone()

    apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask,
                                     repetition_penalties)
    apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask,
                                    repetition_penalties)

    # Compare the CUDA output to the torch reference
    torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3)

    # Test the operator by applying the opcheck utility
    opcheck(torch.ops._C.apply_repetition_penalties_,
            (logits.clone(), prompt_mask, output_mask, repetition_penalties))
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for FlexAttention backend vs default backend""" """Integration tests for FlexAttention backend vs default backend"""
import random import random
...@@ -51,7 +52,6 @@ def test_flex_attention_vs_default_backend(monkeypatch): ...@@ -51,7 +52,6 @@ def test_flex_attention_vs_default_backend(monkeypatch):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
set_seed(seed) set_seed(seed)
...@@ -66,7 +66,6 @@ def test_flex_attention_vs_default_backend(monkeypatch): ...@@ -66,7 +66,6 @@ def test_flex_attention_vs_default_backend(monkeypatch):
# Run with default backend # Run with default backend
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
set_seed(seed) set_seed(seed)
llm_default = LLM( llm_default = LLM(
model_name, model_name,
......
...@@ -13,8 +13,11 @@ import pytest ...@@ -13,8 +13,11 @@ import pytest
import torch import torch
from torch._prims_common import TensorLikeType from torch._prims_common import TensorLikeType
from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.platforms.interface import _Backend from vllm.platforms.interface import _Backend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
...@@ -1054,23 +1057,92 @@ def compute_max_diff(output, output_ref): ...@@ -1054,23 +1057,92 @@ def compute_max_diff(output, output_ref):
torch.abs(output_ref)) torch.abs(output_ref))
def torch_moe(a, w1, w2, score, topk, expert_map): def torch_experts(
B, D = a.shape a: torch.Tensor,
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) w1: torch.Tensor,
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) w2: torch.Tensor,
score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight: torch.Tensor,
topk_weight, topk_ids = torch.topk(score, topk) topk_ids: torch.Tensor,
topk_weight = topk_weight.view(-1) global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
assert (global_num_experts == -1
or (global_num_experts == w1.shape[0] and expert_map is None)
or (expert_map is not None
and global_num_experts == expert_map.shape[0]))
M, K = a.shape
topk = topk_ids.shape[1]
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
a, a_scale = moe_kernel_quantize_input(a, None, quant_dtype,
per_act_token_quant, block_shape)
num_experts = w1.shape[0]
topk_ids = topk_ids.view(-1) topk_ids = topk_ids.view(-1)
if expert_map is not None: if expert_map is not None:
topk_ids = expert_map[topk_ids] topk_ids = expert_map[topk_ids]
for i in range(w1.shape[0]):
f32 = torch.float32
for i in range(num_experts):
mask = topk_ids == i mask = topk_ids == i
if mask.sum(): if mask.sum():
out[mask] = SiluAndMul()( if quant_dtype is None:
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) tmp1 = a[mask] @ w1[i].transpose(0, 1)
return (out.view(B, -1, w2.shape[1]) * tmp2 = SiluAndMul()(tmp1)
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) out[mask] = tmp2 @ w2[i].transpose(0, 1)
elif block_shape is not None:
assert (a_scale is not None and w1_scale is not None
and w2_scale is not None)
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
w1_scale[i], block_shape,
out.dtype)
tmp2 = SiluAndMul()(tmp1)
tmp2, b_scale = moe_kernel_quantize_input(
tmp2, a2_scale, quant_dtype, per_act_token_quant,
block_shape)
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
w2_scale[i], block_shape,
out.dtype)
else:
assert (a_scale is not None and w1_scale is not None
and w2_scale is not None)
scales = a_scale if a_scale.numel() == 1 else a_scale[mask]
tmp1 = a[mask].to(f32) * scales
w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
tmp1 = tmp1 @ w1_dq
tmp2 = SiluAndMul()(tmp1)
w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
out[mask] = (tmp2 @ w2_dq).to(out.dtype)
return (out.view(M, -1, w2.shape[1]).to(f32) *
topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype)
def torch_moe(a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts,
expert_map)
def torch_moe_single(a, w, score, topk): def torch_moe_single(a, w, score, topk):
......
...@@ -249,23 +249,6 @@ def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings): ...@@ -249,23 +249,6 @@ def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
model_runner.model) model_runner.model)
@pytest.fixture(params=[True, False])
def run_with_both_engines_lora(request, monkeypatch):
    """Run the requesting test once with VLLM_USE_V1=1 and once with 0.

    Tests carrying the `skip_v1` marker are skipped for the V1 run and only
    executed with V1 disabled.
    """
    v1_enabled = request.param
    skip_marker = request.node.get_closest_marker("skip_v1")

    if v1_enabled and skip_marker:
        pytest.skip("Skipping test on vllm V1")

    monkeypatch.setenv('VLLM_USE_V1', '1' if v1_enabled else '0')

    yield
@pytest.fixture @pytest.fixture
def reset_default_device(): def reset_default_device():
""" """
......
...@@ -28,42 +28,49 @@ class Relu3(ReLUSquaredActivation): ...@@ -28,42 +28,49 @@ class Relu3(ReLUSquaredActivation):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env, torch_level, ops_enabled, default_on", "env, torch_level, use_inductor, ops_enabled, default_on",
[ [
# Default values based on compile level # Default values based on compile level
("", 0, [True] * 4, True), # - All by default (no Inductor compilation)
("", 1, [True] * 4, True), ("", 0, False, [True] * 4, True),
("", 2, [True] * 4, True), # All by default ("", 1, True, [True] * 4, True),
("", 3, [False] * 4, False), ("", 2, False, [True] * 4, True),
("", 4, [False] * 4, False), # None by default # - None by default (with Inductor)
("", 3, True, [False] * 4, False),
("", 4, True, [False] * 4, False),
# - All by default (without Inductor)
("", 3, False, [True] * 4, True),
("", 4, False, [True] * 4, True),
# Explicitly enabling/disabling # Explicitly enabling/disabling
# #
# Default: all # Default: all
# #
# All but SiluAndMul # All but SiluAndMul
("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True), ("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
# Only ReLU3 # Only ReLU3
("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False), ("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
# All but SiluAndMul # All but SiluAndMul
("all,-silu_and_mul", 1, [1, 0, 1, 1], True), ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
# All but ReLU3 (even if ReLU2 is on) # All but ReLU3 (even if ReLU2 is on)
("-relu3,relu2", 1, [1, 1, 1, 0], True), ("-relu3,relu2", 3, False, [1, 1, 1, 0], True),
# GeluAndMul and SiluAndMul # RMSNorm and SiluAndMul
("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False), ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
# All but RMSNorm # All but RMSNorm
("-rms_norm", 2, [0, 1, 1, 1], True), ("-rms_norm", 3, False, [0, 1, 1, 1], True),
# #
# Default: none # Default: none
# #
# Only ReLU3 # Only ReLU3
("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False), ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
# All but RMSNorm # All but RMSNorm
("all,-rms_norm", 4, [0, 1, 1, 1], True), ("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
]) ])
def test_enabled_ops(env: str, torch_level: int, ops_enabled: list[int], def test_enabled_ops(env: str, torch_level: int, use_inductor: bool,
default_on: bool): ops_enabled: list[int], default_on: bool):
vllm_config = VllmConfig(compilation_config=CompilationConfig( vllm_config = VllmConfig(
level=torch_level, custom_ops=env.split(","))) compilation_config=CompilationConfig(use_inductor=bool(use_inductor),
level=torch_level,
custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on assert CustomOp.default_on() == default_on
......
...@@ -29,8 +29,8 @@ def test_model_loading_with_params(vllm_runner): ...@@ -29,8 +29,8 @@ def test_model_loading_with_params(vllm_runner):
revision=REVISION, revision=REVISION,
dtype="float16", dtype="float16",
max_model_len=MAX_MODEL_LEN) as vllm_model: max_model_len=MAX_MODEL_LEN) as vllm_model:
output = vllm_model.encode("Write a short story about a robot that" output = vllm_model.embed("Write a short story about a robot that"
" dreams for the first time.\n") " dreams for the first time.\n")
model_config = vllm_model.model.llm_engine.model_config model_config = vllm_model.model.llm_engine.model_config
model_tokenizer = vllm_model.model.llm_engine.tokenizer model_tokenizer = vllm_model.model.llm_engine.tokenizer
...@@ -67,8 +67,8 @@ def test_roberta_model_loading_with_params(vllm_runner): ...@@ -67,8 +67,8 @@ def test_roberta_model_loading_with_params(vllm_runner):
revision=REVISION_ROBERTA, revision=REVISION_ROBERTA,
dtype="float16", dtype="float16",
max_model_len=MAX_MODEL_LEN) as vllm_model: max_model_len=MAX_MODEL_LEN) as vllm_model:
output = vllm_model.encode("Write a short story about a robot that" output = vllm_model.embed("Write a short story about a robot that"
" dreams for the first time.\n") " dreams for the first time.\n")
model_config = vllm_model.model.llm_engine.model_config model_config = vllm_model.model.llm_engine.model_config
model_tokenizer = vllm_model.model.llm_engine.tokenizer model_tokenizer = vllm_model.model.llm_engine.tokenizer
...@@ -105,8 +105,8 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner): ...@@ -105,8 +105,8 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner):
with vllm_runner(model_name=model_name, with vllm_runner(model_name=model_name,
dtype="float16", dtype="float16",
max_model_len=MAX_MODEL_LEN) as vllm_model: max_model_len=MAX_MODEL_LEN) as vllm_model:
output = vllm_model.encode("Write a short story about a robot that" output = vllm_model.embed("Write a short story about a robot that"
" dreams for the first time.\n") " dreams for the first time.\n")
model_tokenizer = vllm_model.model.llm_engine.tokenizer model_tokenizer = vllm_model.model.llm_engine.tokenizer
assert model_tokenizer.tokenizer_id == model_name assert model_tokenizer.tokenizer_id == model_name
......
...@@ -118,7 +118,7 @@ def run_test( ...@@ -118,7 +118,7 @@ def run_test(
# default to enforce_eager=True if enforce_eager # default to enforce_eager=True if enforce_eager
# is left unspecified. However, the # is left unspecified. However, the
# VllmRunner test fixture (which wraps around the LLM class) defaults to # VllmRunner test fixture (which wraps around the LLM class) defaults to
# enforce_eager=False (a behavior which a number of already-exisitng # enforce_eager=False (a behavior which a number of already-existing
# decoder-only unit tests expect), so when testing an encoder/decoder # decoder-only unit tests expect), so when testing an encoder/decoder
# model we must explicitly specify enforce_eager=True in the VllmRunner # model we must explicitly specify enforce_eager=True in the VllmRunner
# constructor. # constructor.
......
...@@ -78,7 +78,7 @@ AITER_MODEL_LIST = [ ...@@ -78,7 +78,7 @@ AITER_MODEL_LIST = [
), ),
pytest.param( pytest.param(
"Qwen/Qwen2.5-0.5B-Instruct", # qwen2 "Qwen/Qwen2.5-0.5B-Instruct", # qwen2
marks=[pytest.mark.core_model], marks=[pytest.mark.core_model, pytest.mark.cpu_model],
), ),
pytest.param( pytest.param(
"Qwen/Qwen3-8B", # qwen (text-only) "Qwen/Qwen3-8B", # qwen (text-only)
...@@ -87,6 +87,7 @@ AITER_MODEL_LIST = [ ...@@ -87,6 +87,7 @@ AITER_MODEL_LIST = [
pytest.param("bigcode/starcoder2-3b"), # starcoder2 pytest.param("bigcode/starcoder2-3b"), # starcoder2
pytest.param( pytest.param(
"TitanML/tiny-mixtral", # mixtral "TitanML/tiny-mixtral", # mixtral
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
) )
]) ])
@pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("max_tokens", [32])
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import pytest
MODELS = ["google/gemma-2b", "google/gemma-2-2b", "google/gemma-3-4b-it"]
@pytest.mark.parametrize("model", MODELS)
def test_dummy_loader(vllm_runner, monkeypatch, model: str) -> None:
    """Check the model's `normalizer` value when loading with dummy weights.

    The model is loaded with `load_format="dummy"` and the `normalizer`
    reported by each worker must be close to `sqrt(hidden_size)` from the
    HF config.
    """
    with monkeypatch.context() as m:
        # The lambdas below are sent through collective_rpc; presumably this
        # flag is what permits serializing them -- TODO confirm.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        with vllm_runner(
                model,
                load_format="dummy",
        ) as llm:
            if model == "google/gemma-3-4b-it":
                # For this model the normalizer is read from the nested
                # `language_model` and the size from `text_config`.
                normalizers = llm.model.collective_rpc(
                    lambda self: self.model_runner.model.language_model.model.
                    normalizer.cpu().item())
                config = llm.model.llm_engine.model_config.hf_config.text_config
            else:
                normalizers = llm.model.collective_rpc(
                    lambda self: self.model_runner.model.model.normalizer.cpu(
                    ).item())
                config = llm.model.llm_engine.model_config.hf_config
            # Every returned normalizer must be close to sqrt(hidden_size).
            assert np.allclose(normalizers, config.hidden_size**0.5, rtol=2e-3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment