Commit 7a985548 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.0' into v0.9.0-ori

parents 45d3785c dc1440cf
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
MNK_FACTORS = [
(2, 1024, 1024),
(2, 1024, 1536),
(2, 3072, 1024),
(2, 3072, 1536),
(64, 1024, 1024),
(64, 1024, 1536),
(64, 3072, 1024),
(64, 2048, 1536),
(224, 1024, 1024),
(224, 1024, 1536),
]
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype):
current_platform.seed_everything(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16
round_up = lambda x, y: (x + y - 1) // y * y
sf_w1_2n = round_up(2 * n, 128)
sf_w1_k = round_up(k // quant_blocksize, 4)
w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k),
device="cuda",
dtype=torch.float8_e4m3fn)
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
sf_w2_k = round_up(k, 128)
sf_w2_n = round_up(n // quant_blocksize, 4)
w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n),
device="cuda",
dtype=torch.float8_e4m3fn)
w1_q = torch.empty((e, 2 * n, k // 2),
device="cuda",
dtype=torch.uint8)
w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32)
w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32)
for expert in range(e):
w1_amax = torch.abs(w1).max().to(torch.float32)
w2_amax = torch.abs(w2).max().to(torch.float32)
w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
w1_q[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
w1[expert], w1_gs[expert])
w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
w2[expert], w2_gs[expert])
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
cutlass_output = cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
w1_fp4=w1_q,
w1_blockscale=w1_blockscale,
w1_alphas=(1 / w1_gs),
a2_gscale=a2_gs,
w2_fp4=w2_q,
w2_blockscale=w2_blockscale,
w2_alphas=(1 / w2_gs),
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=e,
device=a.device,
)
# Reference check:
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a.flatten(), dim=-1)).to(torch.float32)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
_, m_k = a_fp4.shape
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=a.dtype,
device=a.device,
block_size=quant_blocksize)
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=w1.dtype,
device=w1.device,
block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=w2.dtype,
device=w2.device,
block_size=quant_blocksize)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)
torch.testing.assert_close(torch_output,
cutlass_output,
atol=1e-1,
rtol=1e-1)
if __name__ == "__main__":
test_cutlass_fp4_moe_no_graph((2, 1024, 1024), 40, 1, torch.half)
# SPDX-License-Identifier: Apache-2.0
"""Tests for the MOE layers.
Run `pytest tests/kernels/test_pplx_moe.py`.
"""
import dataclasses
import os
import traceback
from typing import Callable, Optional
import pytest
import torch
try:
from pplx_kernels import AllToAll
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_finalize, nvshmem_get_unique_id,
nvshmem_init)
has_pplx = True
except ImportError:
has_pplx = False
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import override_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
get_default_config)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
(222, 2048, 1024)]
PPLX_MOE_COMBOS = [
(1, 128, 128),
(2, 128, 512),
(3, 1024, 2048),
(32, 128, 1024),
(45, 512, 2048),
(64, 1024, 1024),
(222, 1024, 2048),
]
NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
TOP_KS = [1, 2, 6]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
P = ParamSpec("P")
requires_pplx = pytest.mark.skipif(
not has_pplx,
reason="Requires PPLX kernels",
)
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
"tcp://localhost:29500",
worker,
) + args,
nprocs=world_size,
join=True,
)
def parallel_launch_from_env(
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Launches a worker function in parallel across all processes in the current
environment. The environment must have the following variables set:
- WORLD_SIZE: The total number of processes.
- WORLD_LOCAL_SIZE: The number of processes on the current node.
- NODE_RANK: The rank of the current
- MASTER_ADDR: The address of the master process.
- MASTER_PORT: The port of the master process.
"""
assert not kwargs
world_size = int(os.environ["WORLD_SIZE"])
world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
node_rank = int(os.environ["NODE_RANK"])
assert "MASTER_ADDR" in os.environ
assert "MASTER_PORT" in os.environ
spawn(
_worker_parallel_launch,
args=(
world_size,
world_local_size,
node_rank,
"env://",
worker,
) + args,
nprocs=world_local_size,
join=True,
)
def torch_prepare(
a: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
max_num_tokens: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert topk_ids.dim() == 2
assert topk_ids.shape[0] == a.shape[0]
num_tokens, hidden_dim = a.shape
topk = topk_ids.shape[1]
tokens_per_expert = torch.bincount(topk_ids.view(-1),
minlength=num_experts)
assert tokens_per_expert.numel() == num_experts
if max_num_tokens is None:
max_num_tokens = int(tokens_per_expert.max().item())
b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim),
dtype=a.dtype,
device=a.device)
token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)
for token in range(num_tokens):
for j in range(topk):
expert_id = topk_ids[token, j]
idx = token_counts[expert_id]
b_a[expert_id, idx:idx + 1, :] = a[token, :]
token_counts[expert_id] = token_counts[expert_id] + 1
return b_a, tokens_per_expert
def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor,
topk_ids: torch.Tensor) -> torch.Tensor:
num_tokens = topk_ids.shape[0]
num_experts = b_out.shape[0]
K = b_out.shape[-1]
out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
expert_counts = torch.zeros(num_experts,
dtype=torch.int,
device=b_out.device)
for token in range(num_tokens):
expert_ids = topk_ids[token]
for i in range(expert_ids.numel()):
expert_id = expert_ids[i]
idx = expert_counts[expert_id]
out[token, :] = out[token, :] + b_out[expert_id, idx:idx +
1, :] * topk_weight[token, i]
expert_counts[expert_id] = expert_counts[expert_id] + 1
return out
def torch_batched_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
num_experts = w1.shape[0]
b_a, tokens_per_expert = torch_prepare(a, topk_ids, num_experts)
assert b_a.dim() == 3
num_tokens, topk = topk_ids.shape
_, max_num_tokens, K = b_a.shape
assert num_experts == b_a.shape[0] and w2.shape[1] == K
out = torch.zeros((num_experts, max_num_tokens, K),
dtype=b_a.dtype,
device=b_a.device)
tmp = torch.empty((max_num_tokens, w1.shape[1] // 2),
dtype=b_a.dtype,
device=b_a.device)
for expert in range(num_experts):
num = tokens_per_expert[expert]
if num > 0:
torch.ops._C.silu_and_mul(
tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1))
out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)
return torch_finalize(out, topk_weight, topk_ids)
def batched_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
num_experts = w1.shape[0]
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0),
BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
# Note: same as torch_moe but with fused_topk factored out.
def torch_moe2(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
M, K = a.shape
topk = topk_ids.shape[1]
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
num_experts = w1.shape[0]
for i in range(num_experts):
mask = (topk_ids == i).view(-1)
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(M, -1, w2.shape[1]) *
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_fused_moe_batched_experts(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
):
current_platform.seed_everything(7)
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
with set_current_vllm_config(vllm_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)
torch.testing.assert_close(baseline_output,
torch_output,
atol=2e-2,
rtol=0)
torch.testing.assert_close(baseline_output,
batched_output,
atol=2e-2,
rtol=0)
def rank_chunk(num: int, r: int, w: int) -> int:
rem = num % w
return (num // w) + (1 if r < rem else 0)
def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
chunk = rank_chunk(t.shape[0], r, w)
return t[(r * chunk):(r + 1) * chunk]
def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
topk_weight: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
assert torch.cuda.current_device() == pgi.local_rank
topk = topk_ids.shape[1]
num_tokens, hidden_dim = a.shape
block_size = 128
device = pgi.device
rank = pgi.rank
world_size = pgi.world_size
max_num_tokens = rank_chunk(num_tokens, 0, world_size)
ata = AllToAll.internode(
max_num_tokens=max_num_tokens,
num_experts=num_experts,
experts_per_token=topk,
rank=rank,
world_size=world_size,
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
((hidden_dim + block_size - 1) // block_size *
torch.float32.itemsize)),
)
topk_ids = topk_ids.to(dtype=torch.uint32)
prepare_finalize = PplxPrepareAndFinalize(
ata,
max_num_tokens,
world_size,
rank,
dp_size,
a.dtype,
)
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
b_a, b_a_scale, expert_num_tokens = prepare_finalize.prepare(
a_chunk,
None,
None,
chunk_topk_weight,
chunk_topk_ids,
num_experts,
None,
False,
)
b_a = b_a * 1.5
out = torch.full(
(max_num_tokens, hidden_dim),
torch.nan,
dtype=a.dtype,
device=device,
)
prepare_finalize.finalize(
out,
b_a,
chunk_topk_weight,
chunk_topk_ids,
False,
)
torch.cuda.synchronize()
ata.destroy()
num_tokens = a_chunk.shape[0]
return out[:num_tokens]
def _pplx_prepare_finalize(
pgi: ProcessGroupInfo,
dp_size: int,
a: torch.Tensor,
score: torch.Tensor,
topk: torch.Tensor,
num_experts: int,
):
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
device = pgi.device
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
k = a.shape[1]
a_rep = torch.repeat_interleave(a, topk, dim=0).to(device)
torch_output = (a_rep.view(-1, topk, k) * 1.5 *
topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(
a.dtype)
pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids,
num_experts)
torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device)
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
nvshmem_finalize()
# TODO (bnell): this test point does not work for odd M due to how the test is
# written, not due to limitations of the pplx kernels. The pplx_moe
# test below is able to deal with odd M.
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@requires_pplx
def test_pplx_prepare_finalize(
mnk: tuple[int, int, int],
e: int,
topk: int,
dtype: torch.dtype,
world_dp_size: tuple[int, int],
):
current_platform.seed_everything(7)
m, n, k = mnk
world_size, dp_size = world_dp_size
device = "cuda"
a = torch.randn((m, k), device=device, dtype=dtype) / 10
score = torch.randn((m, e), device=device, dtype=dtype)
parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
topk, e)
def pplx_moe(
rank: int,
world_size: int,
dp_size: int,
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
use_compile: bool = True,
use_cudagraphs: bool = True,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
device = torch.device("cuda", rank)
hidden_dim = a.shape[1]
num_experts = w1.shape[0]
block_size = 128
topk = topk_ids.shape[1]
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
ata = AllToAll.internode(
max_num_tokens=max_num_tokens,
num_experts=num_experts,
experts_per_token=topk,
rank=rank,
world_size=world_size,
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
((hidden_dim + block_size - 1) // block_size *
torch.float32.itemsize)),
)
topk_ids = topk_ids.to(dtype=torch.uint32)
prepare_finalize = PplxPrepareAndFinalize(
ata,
max_num_tokens,
world_size,
rank,
dp_size,
)
experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
world_size=world_size,
dp_size=dp_size)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
# Chunking weights like this only works for batched format
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
if use_compile:
_fused_experts = torch.compile(fused_experts,
backend='inductor',
fullgraph=True)
else:
_fused_experts = fused_experts
out = _fused_experts(a_chunk,
w1_chunk,
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts)
if use_cudagraphs:
out.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
out = _fused_experts(a_chunk,
w1_chunk,
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
ata.destroy()
return out
def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
assert torch.cuda.current_device() == pgi.local_rank
num_experts = w1.shape[0]
device = pgi.device
rank = pgi.rank
world_size = pgi.world_size
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
prepare_finalize = BatchedPrepareAndFinalize(
max_num_tokens=max_num_tokens,
world_size=world_size,
dp_size=dp_size,
rank=rank,
)
experts = BatchedExperts(max_num_tokens=a.shape[0],
world_size=1,
dp_size=1)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
out = fused_experts(
a_chunk,
# Chunking weights like this only works for batched format
chunk_by_rank(w1, rank, world_size).to(device),
chunk_by_rank(w2, rank, world_size).to(device),
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts)
return out
def _pplx_moe(
pgi: ProcessGroupInfo,
dp_size: int,
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
):
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
m, k = a.shape
e, _, n = w2.shape
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
topk_weight, topk_ids)
# TODO (bnell): fix + re-enable
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
# topk_ids)
torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device)
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
#torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)
nvshmem_finalize()
@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@requires_pplx
def test_pplx_moe(
mnk: tuple[int, int, int],
e: int,
topk: int,
dtype: torch.dtype,
world_dp_size: tuple[int, int],
):
current_platform.seed_everything(7)
m, n, k = mnk
world_size, dp_size = world_dp_size
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk)
# SPDX-License-Identifier: Apache-2.0
# This is a test for the AITER ops.
# It tests if the AITER ops are
# 1. correctly registered as custom ops
# 2. correctly defined the relationship between
# implementation and fake function
# 3. can be used with torch.compile
# This file will be skipped if AITER is not installed
# and the platform is not ROCm.
import importlib.util
import pytest
import torch
# this import statement is needed to ensure the ops are registered
import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe # noqa: F401
from vllm.platforms import current_platform
# need to import once to ensure the ops are registered
# Check if aiter package is installed
aiter_available = importlib.util.find_spec("aiter") is not None
pytestmark = pytest.mark.skipif(
not (current_platform.is_rocm() and aiter_available),
reason="AITER ops are only available on ROCm with aiter package installed")
def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
"""Test that the custom op is correctly registered."""
# Check if the op exists in torch.ops.vllm
assert hasattr(torch.ops.vllm, 'rocm_aiter_biased_grouped_topk')
# Check if the op is callable
assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)
def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
"""Test that the op can be used with torch.compile."""
# Create test tensors
token = 64
expert = 256
num_expert_group = 8
topk = 8
topk_group = 4
renormalize = True
scale_factor = 1.0
gating_output = torch.randn((token, expert),
dtype=torch.bfloat16,
device="cuda")
e_score_correction_bias = torch.randn((expert, ),
dtype=torch.bfloat16,
device="cuda")
device = gating_output.device
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
topk_weights = torch.empty((token, topk),
dtype=torch.float32,
device=device)
# Define a function that uses the op
def biased_grouped_topk_fn(gating_output, e_score_correction_bias,
topk_weights, topk_ids):
return torch.ops.vllm.rocm_aiter_biased_grouped_topk(
gating_output, e_score_correction_bias, topk_weights, topk_ids,
num_expert_group, topk_group, renormalize, scale_factor)
# Verify the op's fake implementation
torch.library.opcheck(
torch.ops.vllm.rocm_aiter_biased_grouped_topk,
(gating_output, e_score_correction_bias, topk_weights, topk_ids),
kwargs={
"num_expert_group": num_expert_group,
"topk_group": topk_group,
"need_renorm": renormalize,
"routed_scaling_factor": scale_factor
},
test_utils=("test_faketensor"))
# Compile the function with appropriate settings
compiled_fn = torch.compile(biased_grouped_topk_fn,
fullgraph=True,
backend="inductor",
mode="reduce-overhead",
dynamic=False)
topk_weights_original = torch.empty((token, topk),
dtype=torch.float32,
device=device)
topk_ids_original = torch.empty((token, topk),
dtype=torch.int32,
device=device)
topk_weights_compiled = torch.empty((token, topk),
dtype=torch.float32,
device=device)
topk_ids_compiled = torch.empty((token, topk),
dtype=torch.int32,
device=device)
# Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
biased_grouped_topk_fn(gating_output, e_score_correction_bias,
topk_weights_original, topk_ids_original)
compiled_fn(gating_output, e_score_correction_bias, topk_weights_compiled,
topk_ids_compiled)
# Sort the results for comparison since the order might not be deterministic
topk_ids_original, indices_original = torch.sort(topk_ids_original)
topk_weights_original = torch.gather(topk_weights_original, 1,
indices_original)
topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
indices_compiled)
# Verify results match
assert torch.allclose(topk_weights_original,
topk_weights_compiled,
rtol=1e-2,
atol=1e-2)
assert torch.allclose(topk_ids_original, topk_ids_compiled)
...@@ -7,6 +7,7 @@ import pytest ...@@ -7,6 +7,7 @@ import pytest
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -15,6 +16,10 @@ if current_platform.get_device_capability() < (9, 0): ...@@ -15,6 +16,10 @@ if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True) allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
"""Matrix multiplication function that supports per-token input """Matrix multiplication function that supports per-token input
...@@ -137,20 +142,21 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): ...@@ -137,20 +142,21 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) with set_current_vllm_config(vllm_config):
out = fused_moe( ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
a, out = fused_moe(
w1, a,
w2, w1,
score, w2,
topk, score,
renormalize=False, topk,
use_fp8_w8a8=True, # using fp8 renormalize=False,
per_channel_quant=True, use_fp8_w8a8=True, # using fp8
w1_scale=w1_s, per_channel_quant=True,
w2_scale=w2_s, w1_scale=w1_s,
block_shape=None, # Not using block quantization w2_scale=w2_s,
) block_shape=None, # Not using block quantization
)
# Check results # Check results
rel_diff = (torch.mean( rel_diff = (torch.mean(
......
# SPDX-License-Identifier: Apache-2.0
import torch
from vllm.scalar_type import scalar_types
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
dtype=torch.float32)
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
m_tiles = (m + 128 - 1) // 128
f = block_size * 4
k_tiles = (k + f - 1) // f
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
return out[0:m, 0:k]
def dequantize_nvfp4_to_dtype(tensor_fp4,
tensor_sf,
global_scale,
dtype,
device,
block_size=16):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert tensor_fp4.dtype == torch.uint8
m, packed_k = tensor_fp4.shape
k = packed_k * 2
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
# scale the tensor
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
return out.to(dtype=dtype)
def break_fp4_bytes(a, dtype):
assert a.dtype == torch.uint8
m, n = a.shape
# Vectorized nibble processing
a_flat = a.flatten()
high = (a_flat & 0xF0) >> 4 # Upper nibbles
low = a_flat & 0x0F # Lower nibbles
# Combine nibbles for batch processing
combined = torch.stack((low, high), dim=1).flatten()
# Vectorized sign and magnitude extraction
signs = (combined & 0x08).to(torch.bool) # Sign bits
abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices
# Device-aware lookup and sign application
kE2M1 = kE2M1ToFloat.to(device=a.device)
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
# Reshape to final form
return values.reshape(m, n * 2).to(dtype=dtype)
# SPDX-License-Identifier: Apache-2.0
"""Test AWQ with fused MoE Marlin kernels.
Run `pytest tests/kernels/test_awq_marlin.py`.
"""
import pytest
import torch
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
awq_marlin_quantize)
from vllm.scalar_type import scalar_types
NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6]
GROUP_SIZES = [-1, 32, 128]
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("group_size", GROUP_SIZES)
@pytest.mark.skipif(not (ops.supports_moe_ops
and hasattr(torch.ops._moe_C, "marlin_gemm_moe")),
reason="Marlin is not supported on this GPU type.")
def test_fused_marlin_moe_awq(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
):
torch.manual_seed(7)
num_bits = 4
quant_type = scalar_types.uint4
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
w_ref1_l = []
qweights1_l = []
scales1_l = []
zp1_l = []
for i in range(w1.shape[0]):
w_ref1, qweight1, scales1, zp1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size)
w_ref1_l.append(w_ref1)
qweights1_l.append(qweight1)
scales1_l.append(scales1)
zp1_l.append(zp1)
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweights1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
zp1 = stack_and_dev(zp1_l)
w_ref2_l = []
qweights2_l = []
scales2_l = []
zp2_l = []
for i in range(w2.shape[0]):
w_ref2, qweight2, scales2, zp2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size)
w_ref2_l.append(w_ref2)
qweights2_l.append(qweight2)
scales2_l.append(scales2)
zp2_l.append(zp2)
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweights2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
zp2 = stack_and_dev(zp2_l)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, False)
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
qweight1,
qweight2,
scales1,
scales2,
score,
topk_weights,
topk_ids,
w1_zeros=zp1,
w2_zeros=zp2,
num_bits=num_bits,
)
torch_output = torch_moe(a, w_ref1.transpose(1, 2), w_ref2.transpose(1, 2),
score, topk, None)
assert compute_max_diff(marlin_output, torch_output) < 4e-2
@pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
def test_single_marlin_moe_multiply_awq(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
):
torch.manual_seed(7)
num_bits = 4
quant_type = scalar_types.uint4
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
w_ref_l = []
qweights_l = []
scales_l = []
zp_l = []
for i in range(w.shape[0]):
w_ref, qweight, scales, zp = awq_marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size)
w_ref_l.append(w_ref)
qweights_l.append(qweight)
scales_l.append(scales)
zp_l.append(zp)
w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweights_l).contiguous()
scales = stack_and_dev(scales_l).contiguous()
zp = stack_and_dev(zp_l).contiguous()
score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = torch.ops.vllm.single_marlin_moe(a,
qweight,
scales,
score,
topk,
renormalize=False,
w_zeros=zp,
num_bits=num_bits)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
assert compute_max_diff(marlin_output, torch_output) < 1e-2
...@@ -11,7 +11,7 @@ from vllm.config import VllmConfig, set_current_vllm_config ...@@ -11,7 +11,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
deep_gemm_moe_fp8) _valid_deep_gemm_shape, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size) moe_align_block_size)
...@@ -30,6 +30,10 @@ if current_platform.get_device_capability() < (9, 0): ...@@ -30,6 +30,10 @@ if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True) allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
# Test configurations # Test configurations
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048] NUM_TOKENS = [7, 83, 2048]
...@@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ...@@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
# Set the context to avoid lots of warning spam. # Set the context to avoid lots of warning spam.
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
out = fused_moe( out = fused_moe(
a, a,
...@@ -258,6 +261,7 @@ def per_block_cast_to_fp8( ...@@ -258,6 +261,7 @@ def per_block_cast_to_fp8(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed", "M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode() @torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
# only aligned sizes # only aligned sizes
...@@ -338,7 +342,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, ...@@ -338,7 +342,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
M, K = a.shape M, K = a.shape
N = w2.shape[-1] N = w2.shape[-1]
topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) topk_weight, topk_ids, token_expert_indices = fused_topk(
a, score.float(), topk, False)
block_m = deep_gemm.get_m_alignment_for_contiguous_layout() block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
...@@ -380,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): ...@@ -380,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
block_size = [block_m, block_m] block_size = [block_m, block_m]
dtype = torch.bfloat16 dtype = torch.bfloat16
# only aligned sizes if topk > E:
if (N % block_m != 0 or K % block_m != 0 or topk > E): pytest.skip(f"Skipping test: topk={topk} > E={E}")
pytest.skip(
f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}")
if N <= 512:
pytest.skip("Skipping N <= 512 until performance issues solved.")
vllm_config = VllmConfig() if not _valid_deep_gemm_shape(M, N, K):
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
torch.manual_seed(seed) torch.manual_seed(seed)
fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_info = torch.finfo(torch.float8_e4m3fn)
...@@ -435,7 +436,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): ...@@ -435,7 +436,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
topk, block_size) topk, block_size)
topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score.float(), topk, False)
out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
......
...@@ -18,6 +18,10 @@ if current_platform.get_device_capability() < (7, 0): ...@@ -18,6 +18,10 @@ if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
allow_module_level=True) allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
# For test # For test
def native_per_token_group_quant_int8(x, def native_per_token_group_quant_int8(x,
...@@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ...@@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
# Set the context to avoid lots of warning spam. # Set the context to avoid lots of warning spam.
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
out = fused_moe( out = fused_moe(
a, a,
......
...@@ -95,7 +95,7 @@ def cutlass_fp8_gemm_helper(m: int, ...@@ -95,7 +95,7 @@ def cutlass_fp8_gemm_helper(m: int,
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2) torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1.5e-1)
opcheck(torch.ops._C.cutlass_scaled_mm, opcheck(torch.ops._C.cutlass_scaled_mm,
(out, a, b, scale_a, scale_b, bias)) (out, a, b, scale_a, scale_b, bias))
...@@ -161,6 +161,8 @@ def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int, ...@@ -161,6 +161,8 @@ def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
return return
if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0: if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
return return
if m % 4 != 0 and current_platform.has_device_capability(100):
return
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
use_bias) use_bias)
......
...@@ -36,3 +36,9 @@ def test_ggml_opcheck(quant_type): ...@@ -36,3 +36,9 @@ def test_ggml_opcheck(quant_type):
opcheck(torch.ops._C.ggml_moe_a8, opcheck(torch.ops._C.ggml_moe_a8,
(x, qweight, sorted_token_ids, expert_ids, num_tokens_post_padded, (x, qweight, sorted_token_ids, expert_ids, num_tokens_post_padded,
quant_type, qweight.shape[0], 1, x.shape[0])) quant_type, qweight.shape[0], 1, x.shape[0]))
topk_ids = torch.zeros((1, 1), device='cuda', dtype=torch.int32)
opcheck(
torch.ops._C.ggml_moe_a8_vec,
(x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0]))
...@@ -151,20 +151,7 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, ...@@ -151,20 +151,7 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
@pytest.mark.parametrize("hidden_size", [512]) @pytest.mark.parametrize("hidden_size", [512])
@pytest.mark.parametrize("top_k", [4, 8]) @pytest.mark.parametrize("top_k", [4, 8])
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize( @pytest.mark.parametrize("quant_type", QUANT_TYPES)
"quant_type",
[
# k-quants
GGMLQuantizationType.Q2_K,
GGMLQuantizationType.Q3_K,
GGMLQuantizationType.Q4_K,
GGMLQuantizationType.Q5_K,
GGMLQuantizationType.Q6_K,
# standard quants
GGMLQuantizationType.Q4_0,
GGMLQuantizationType.Q5_0,
GGMLQuantizationType.Q8_0,
])
@torch.inference_mode() @torch.inference_mode()
def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype, def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType, top_k: int): quant_type: GGMLQuantizationType, top_k: int):
...@@ -174,7 +161,10 @@ def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype, ...@@ -174,7 +161,10 @@ def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,
x = torch.rand((num_tokens, H), dtype=dtype, device="cuda") x = torch.rand((num_tokens, H), dtype=dtype, device="cuda")
topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype) topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype)
topk_ids = torch.randint(0, E, (num_tokens, top_k), device="cuda") topk_ids = torch.randint(0,
E, (num_tokens, top_k),
device="cuda",
dtype=torch.int32)
tensors = get_gguf_MoE_tensors(hidden_size, quant_type) tensors = get_gguf_MoE_tensors(hidden_size, quant_type)
......
...@@ -18,9 +18,12 @@ from vllm.model_executor.layers.quantization.qqq import ( ...@@ -18,9 +18,12 @@ from vllm.model_executor.layers.quantization.qqq import (
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
marlin_permute_scales, query_marlin_supported_quant_types) marlin_make_workspace_new, marlin_permute_scales,
query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
pack_fp8_to_int32) marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize, MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
marlin_weights) marlin_weights)
...@@ -73,7 +76,7 @@ def rand_data(shape, dtype=torch.float16): ...@@ -73,7 +76,7 @@ def rand_data(shape, dtype=torch.float16):
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", @pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False)) query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
...@@ -138,7 +141,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, ...@@ -138,7 +141,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", @pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False)) query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
...@@ -189,9 +192,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, ...@@ -189,9 +192,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
reason="Marlin is not supported on this GPU type.") reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", @pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
query_marlin_supported_quant_types(False)) @pytest.mark.parametrize(
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) "group_size",
set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES))
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS) @pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
...@@ -209,6 +213,7 @@ def test_gptq_marlin_gemm( ...@@ -209,6 +213,7 @@ def test_gptq_marlin_gemm(
use_fp32_reduce, use_fp32_reduce,
): ):
m_factor, n_factor, k_factor = mnk_factors m_factor, n_factor, k_factor = mnk_factors
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
size_m = m_factor size_m = m_factor
size_k = k_chunk * k_factor size_k = k_chunk * k_factor
...@@ -219,39 +224,74 @@ def test_gptq_marlin_gemm( ...@@ -219,39 +224,74 @@ def test_gptq_marlin_gemm(
return return
if group_size == size_k: if group_size == size_k:
return return
if has_zp:
return
if size_k % group_size != 0:
return
a_input = rand_data((size_m, size_k)) a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n)) b_weight = rand_data((size_k, size_n))
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( if quant_type == scalar_types.float4_e2m1f:
b_weight, quant_type, group_size, act_order) if group_size != 16 or act_order:
return
marlin_zp = marlin_make_empty_g_idx(marlin_s.device) w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
b_weight.T, group_size)
g_idx = None
sort_indices = None
marlin_zp = None
elif quant_type == scalar_types.float8_e4m3fn:
if group_size not in [-1, 128]:
return
if act_order:
return
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
b_weight.T, group_size)
g_idx = None
sort_indices = None
marlin_zp = None
marlin_s2 = None
elif has_zp:
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, quant_type, group_size)
g_idx = None
sort_indices = None
marlin_s2 = None
else:
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, act_order)
marlin_zp = None
marlin_s2 = None
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, workspace = marlin_make_workspace_new(w_ref.device)
GPTQ_MARLIN_MAX_PARALLEL)
opcheck(torch.ops._C.gptq_marlin_gemm, opcheck(torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, (a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx,
workspace.scratch, quant_type.id, a_input.shape[0], sort_indices, workspace, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1], is_k_full, False, b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
use_atomic_add, use_fp32_reduce, False), use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS) test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
a_input, a_input,
None,
marlin_q_w, marlin_q_w,
marlin_s, marlin_s,
marlin_s2,
marlin_zp, marlin_zp,
g_idx, g_idx,
sort_indices, sort_indices,
workspace.scratch, workspace,
quant_type, quant_type,
a_input.shape[0], a_input.shape[0],
b_weight.shape[1], b_weight.shape[1],
a_input.shape[1], a_input.shape[1],
is_k_full=is_k_full, is_k_full=is_k_full,
has_zp=False,
use_atomic_add=use_atomic_add, use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce, use_fp32_reduce=use_fp32_reduce,
is_zp_float=False, is_zp_float=False,
...@@ -326,143 +366,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, ...@@ -326,143 +366,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
assert max_diff < 0.04 assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", [8])
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_fp8_marlin_gemm(
k_chunk,
n_chunk,
num_bits,
group_size,
mnk_factors,
dtype,
):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
a_input = rand_data((size_m, size_k), dtype=dtype)
b_weight = rand_data((size_k, size_n), dtype=dtype)
# WEIGHTS
fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
# Repack weights to gptq format (packed int32 elements)
packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
# Repack weights to marlin format
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=packed_gptq_qweight,
perm=torch.empty(0, dtype=torch.int, device="cuda"),
size_k=size_k,
size_n=size_n,
num_bits=8,
)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
# Permute scales
marlin_scales = marlin_permute_scales(s=scales,
size_k=size_k,
size_n=size_n,
group_size=-1)
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
opcheck(torch.ops._C.fp8_marlin_gemm,
(a_input, marlin_qweight, marlin_scales, workspace.scratch,
num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1]))
output = ops.fp8_marlin_gemm(
a=a_input,
b_q_weight=marlin_qweight,
b_scales=marlin_scales,
workspace=workspace.scratch,
num_bits=num_bits,
size_m=a_input.shape[0],
size_n=b_weight.shape[1],
size_k=a_input.shape[1],
)
output_ref = torch.matmul(a_input, b_weight)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_awq_marlin_gemm(
k_chunk,
n_chunk,
quant_type,
group_size,
mnk_factors,
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, quant_type, group_size)
g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
is_k_full = True
has_zp = True
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
output = ops.gptq_marlin_gemm(
a_input,
marlin_q_w,
marlin_s,
marlin_zp,
g_idx,
sort_indices,
workspace.scratch,
quant_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full=is_k_full,
has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.") reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
...@@ -508,23 +411,23 @@ def test_hqq_marlin_gemm( ...@@ -508,23 +411,23 @@ def test_hqq_marlin_gemm(
g_idx = marlin_make_empty_g_idx(dev) g_idx = marlin_make_empty_g_idx(dev)
g_idx_sort_indices = marlin_make_empty_g_idx(dev) g_idx_sort_indices = marlin_make_empty_g_idx(dev)
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, workspace = marlin_make_workspace_new(b_weight.device)
GPTQ_MARLIN_MAX_PARALLEL)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
a_input, a_input,
None,
marlin_w_q, marlin_w_q,
marlin_s, marlin_s,
None,
marlin_zp, marlin_zp,
g_idx, g_idx,
g_idx_sort_indices, g_idx_sort_indices,
workspace.scratch, workspace,
quant_type, quant_type,
a_input.shape[0], a_input.shape[0],
b_weight.shape[0], b_weight.shape[0],
a_input.shape[1], a_input.shape[1],
is_k_full=True, is_k_full=True,
has_zp=True,
use_fp32_reduce=use_fp32_reduce, use_fp32_reduce=use_fp32_reduce,
is_zp_float=True, is_zp_float=True,
) )
...@@ -621,23 +524,23 @@ def test_marlin_gemm_subset_input(): ...@@ -621,23 +524,23 @@ def test_marlin_gemm_subset_input():
b_weight, quant_type, group_size, False) b_weight, quant_type, group_size, False)
marlin_zp = marlin_make_empty_g_idx(marlin_s.device) marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, workspace = marlin_make_workspace_new(a_input.device)
GPTQ_MARLIN_MAX_PARALLEL)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
a_input, a_input,
None,
marlin_q_w, marlin_q_w,
marlin_s, marlin_s,
None,
marlin_zp, marlin_zp,
g_idx, g_idx,
sort_indices, sort_indices,
workspace.scratch, workspace,
quant_type, quant_type,
a_input.shape[0], a_input.shape[0],
b_weight.shape[1], b_weight.shape[1],
a_input.shape[1], a_input.shape[1],
is_k_full=True, is_k_full=True,
has_zp=False,
use_atomic_add=False, use_atomic_add=False,
use_fp32_reduce=True, use_fp32_reduce=True,
is_zp_float=False, is_zp_float=False,
......
...@@ -17,7 +17,7 @@ PAD_SHAPES = [(90, 64), (150, 64), (128, 48), (128, 80), (150, 80), (90, 48), ...@@ -17,7 +17,7 @@ PAD_SHAPES = [(90, 64), (150, 64), (128, 48), (128, 80), (150, 80), (90, 48),
SEEDS = [42] SEEDS = [42]
CUDA_DEVICES = ['cuda:0'] CUDA_DEVICES = ['cuda:0']
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max() FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
# E2M1 to float # E2M1 to float
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import pytest import pytest
import torch import torch
from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
if not current_platform.has_device_capability(100): if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
...@@ -19,95 +20,24 @@ SHAPES.extend(PAD_SHAPES) ...@@ -19,95 +20,24 @@ SHAPES.extend(PAD_SHAPES)
SEEDS = [42] SEEDS = [42]
CUDA_DEVICES = ['cuda:0'] CUDA_DEVICES = ['cuda:0']
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
kE2M1ToFloatArray = [
0.,
0.5,
1.,
1.5,
2.,
3.,
4.,
6.,
]
def e2m1_to_fp32(int4_value):
signBit = (int4_value & 0x8)
int4_absValue = int4_value & 0x7
float_result = kE2M1ToFloatArray[int4_absValue]
if (signBit):
float_result = -float_result
return float_result
def break_fp4_bytes(a, dtype):
assert (a.dtype == torch.uint8)
m, n = a.shape
a = a.flatten()
# Get upper 4 bits
highHalfByte = (a & 0xF0) >> 4
# Get lower 4 bits
lowHalfByte = a & 0x0F
fH = torch.tensor([e2m1_to_fp32(x) for x in highHalfByte]).to(a.device)
fL = torch.tensor([e2m1_to_fp32(x) for x in lowHalfByte]).to(a.device)
# [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC]
out = torch.stack((fL, fH), dim=-1).reshape(m, n * 2)
return out
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
sf_m, sf_k = a_sf_swizzled.shape
m_tiles = (m + 128 - 1) // 128
f = block_size * 4
k_tiles = (k + f - 1) // f
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
return out[0:m, 0:k]
def dequantize_to_dtype(tensor_fp4,
tensor_sf,
global_scale,
dtype,
device,
block_size=16):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert tensor_fp4.dtype == torch.uint8
m, packed_k = tensor_fp4.shape
k = packed_k * 2
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
# scale the tensor
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
return out
def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale, def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale,
m, n, dtype, block_size, device): m, n, dtype, block_size, device):
_, m_k = a_fp4.shape _, m_k = a_fp4.shape
_, n_k = b_fp4.shape _, n_k = b_fp4.shape
assert (m_k == n_k) assert (m_k == n_k)
a_in_dtype = dequantize_to_dtype(a_fp4, a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_sf, a_sf,
a_global_scale, a_global_scale,
dtype=dtype, dtype=dtype,
device=device, device=device,
block_size=block_size) block_size=block_size)
b_in_dtype = dequantize_to_dtype(b_fp4, b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4,
b_sf, b_sf,
b_global_scale, b_global_scale,
dtype=dtype, dtype=dtype,
device=device, device=device,
block_size=block_size) block_size=block_size)
return torch.matmul(a_in_dtype, b_in_dtype.t()) return torch.matmul(a_in_dtype, b_in_dtype.t())
......
...@@ -8,7 +8,7 @@ from vllm.platforms import current_platform ...@@ -8,7 +8,7 @@ from vllm.platforms import current_platform
DTYPES = [torch.bfloat16, torch.float16] DTYPES = [torch.bfloat16, torch.float16]
M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192] M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192]
K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] # k % 8 == 0 K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192] # k % 8 == 0
N = [1, 2, 3, 4] N = [1, 2, 3, 4]
SEEDS = [0] SEEDS = [0]
...@@ -58,8 +58,9 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): ...@@ -58,8 +58,9 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("m", M + [28672]) # m >= 16 @pytest.mark.parametrize("m", M + [28672]) # m >= 16
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), @pytest.mark.skipif(
reason="only test for rocm") not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8")
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed) torch.manual_seed(seed)
......
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import vllm._custom_ops as ops
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms import current_platform
DTYPES = [torch.bfloat16, torch.float16]
QUANT_DTYPES = [current_platform.fp8_dtype()]
NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing
HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
scale: torch.Tensor) -> torch.Tensor:
silu_and_mul_out = silu_and_mul.forward_native(x)
out, scales = ops.scaled_fp8_quant(silu_and_mul_out, scale)
return out
def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
out_shape = (x.shape[0], x.shape[1] // 2)
out = torch.empty(out_shape,
dtype=current_platform.fp8_dtype(),
device=x.device)
torch.ops._C.silu_and_mul_quant(out, x, scale)
return out
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_silu_and_mul(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
quant_dtype: torch.dtype,
seed: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
layer = SiluAndMul()
# Make inputs
scale = (torch.randn((1), device=device, dtype=torch.float32))
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
ref_out = ref_impl(layer, x, scale)
ops_out = ops_impl(x, scale)
assert ref_out.dtype == quant_dtype
assert ops_out.dtype == quant_dtype
assert ref_out.shape == ops_out.shape
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
opcheck(torch.ops._C.silu_and_mul_quant, (ops_out, x, scale))
...@@ -14,8 +14,8 @@ import torch ...@@ -14,8 +14,8 @@ import torch
# Fixture to set up environment variables and teardown servers after tests # Fixture to set up environment variables and teardown servers after tests
@pytest.fixture(scope="module", autouse=True) @pytest.fixture(scope="module", autouse=True)
def setup_servers(): def setup_servers():
if torch.cuda.device_count() < 4: if torch.cuda.device_count() < 2:
pytest.skip("Skipping test: fewer than 4 GPUs available") pytest.skip("Skipping test: fewer than 2 GPUs available")
# Set up environment variables # Set up environment variables
VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'",
......
...@@ -47,7 +47,7 @@ def dist_init(): ...@@ -47,7 +47,7 @@ def dist_init():
temp_file = tempfile.mkstemp()[1] temp_file = tempfile.mkstemp()[1]
backend = "nccl" backend = "nccl"
if current_platform.is_cpu(): if current_platform.is_cpu() or current_platform.is_tpu():
backend = "gloo" backend = "gloo"
init_distributed_environment(world_size=1, init_distributed_environment(world_size=1,
...@@ -139,6 +139,12 @@ def dummy_model_gate_up() -> nn.Module: ...@@ -139,6 +139,12 @@ def dummy_model_gate_up() -> nn.Module:
return model return model
@pytest.fixture(scope="session")
def llama_2_7b_base_huggingface_id():
# used as a base model for testing with sql lora adapter
return "meta-llama/Llama-2-7b-hf"
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def sql_lora_huggingface_id(): def sql_lora_huggingface_id():
# huggingface repo id is used to test lora runtime downloading. # huggingface repo id is used to test lora runtime downloading.
...@@ -198,6 +204,12 @@ def qwen2vl_lora_files(): ...@@ -198,6 +204,12 @@ def qwen2vl_lora_files():
return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon") return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon")
@pytest.fixture(scope="session")
def qwen25vl_base_huggingface_id():
# used as a base model for testing with qwen25vl lora adapter
return "Qwen/Qwen2.5-VL-3B-Instruct"
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def qwen25vl_lora_files(): def qwen25vl_lora_files():
return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon") return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon")
...@@ -261,8 +273,8 @@ def run_with_both_engines_lora(request, monkeypatch): ...@@ -261,8 +273,8 @@ def run_with_both_engines_lora(request, monkeypatch):
@pytest.fixture @pytest.fixture
def reset_default_device(): def reset_default_device():
""" """
Some tests, such as `test_punica_ops.py`, explicitly set the Some tests, such as `test_punica_ops.py`, explicitly set the
default device, which can affect subsequent tests. Adding this fixture default device, which can affect subsequent tests. Adding this fixture
helps avoid this problem. helps avoid this problem.
""" """
original_device = torch.get_default_device() original_device = torch.get_default_device()
......
# SPDX-License-Identifier: Apache-2.0
import pytest
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
VllmConfig)
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.v1.engine.processor import Processor
def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id,
sql_lora_files):
"""
Test that we properly resolve the range of allowed token ids for lora
adapters that define additional tokens.
"""
# Setup a base model compatible with the sql_lora_files adapter and
# a known number of tokens in the base model.
model_config = ModelConfig(
model=llama_2_7b_base_huggingface_id,
tokenizer=llama_2_7b_base_huggingface_id,
tokenizer_mode="auto",
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
device_config=DeviceConfig(),
lora_config=LoRAConfig(),
)
tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
processor = Processor(vllm_config, tokenizer)
lora_request = LoRARequest("1", 1, str(sql_lora_files))
request_id = "1"
prompt = "a prompt"
# tokens added in the lora adapter should not raise an error
lora_token_ids = [32000, 32001, 32002, 32003]
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=lora_token_ids),
lora_request=lora_request)
# tokens in the base model should not raise an error
base_token_ids = [1000, 1001, 1002, 1003]
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=base_token_ids),
lora_request=lora_request)
# tokens not in the lora adapter should raise an error
invalid_token_ids = [35000, 35001, 35002, 35003]
with pytest.raises(ValueError):
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=invalid_token_ids),
lora_request=lora_request)
# tokens in the lora adapter with no lora request should raise an error
with pytest.raises(ValueError):
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=lora_token_ids),
)
def test_allowed_token_ids_with_lora_adapter_no_vocab(
qwen25vl_base_huggingface_id, qwen25vl_lora_files):
"""
Test that we properly resolve the range of allowed token ids for lora
adapters that do not define additional tokens.
"""
# Setup a base model compatible with the qwen25vl_lora_files adapter and
# a known number of tokens in the base model.
model_config = ModelConfig(
model=qwen25vl_base_huggingface_id,
tokenizer=qwen25vl_base_huggingface_id,
tokenizer_mode="auto",
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
device_config=DeviceConfig(),
lora_config=LoRAConfig(),
)
tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
processor = Processor(vllm_config, tokenizer)
lora_request = LoRARequest("1", 1, str(qwen25vl_lora_files))
request_id = "1"
prompt = "a prompt"
# tokens in the base model should not raise an error
base_token_ids = [1000, 1001, 1002, 1003]
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=base_token_ids),
lora_request=lora_request)
# tokens in the base model with no lora request should not raise an error
base_token_ids = [1000, 1001, 1002, 1003]
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=base_token_ids),
)
# tokens not in the base model should raise an error
invalid_token_ids = [200000, 200001, 200002, 200003]
with pytest.raises(ValueError):
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=invalid_token_ids),
lora_request=lora_request)
...@@ -30,7 +30,7 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request): ...@@ -30,7 +30,7 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
lora_path = get_adapter_absolute_path(lora_name) lora_path = get_adapter_absolute_path(lora_name)
# lora loading should work for either absolute path and hugggingface id. # lora loading should work for either absolute path and huggingface id.
peft_helper = PEFTHelper.from_local_dir(lora_path, 4096) peft_helper = PEFTHelper.from_local_dir(lora_path, 4096)
lora_model = LoRAModel.from_local_checkpoint( lora_model = LoRAModel.from_local_checkpoint(
lora_path, lora_path,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment