Commit 081057de authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-ori

parents 7cf5d5c4 ba41cc90
...@@ -28,7 +28,34 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): ...@@ -28,7 +28,34 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
assert (backend.get_name() == "ROCM_FLASH" assert (backend.get_name() == "ROCM_FLASH"
or backend.get_name() == "TRITON_ATTN_VLLM_V1") or backend.get_name() == "TRITON_ATTN_VLLM_V1")
# mla test for deepseek related # MLA test for deepseek related
# change the attention backend to triton MLA
m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
False, True)
assert backend.get_name() == "TRITON_MLA"
# If attention backend is None
# If use_mla is true
# The selected backend is triton MLA
m.setenv(STR_BACKEND_ENV_VAR, None)
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
False, True) False, True)
assert backend.get_name() == "TRITON_MLA" assert backend.get_name() == "TRITON_MLA"
# change the attention backend to AITER MLA
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
False, True)
assert backend.get_name() == "ROCM_AITER_MLA"
# If attention backend is None
# If use_mla is true
# If VLLM_ROCM_USE_AITER is enabled
# The selected backend is ROCM_AITER_MLA
m.setenv(STR_BACKEND_ENV_VAR, None)
m.setenv("VLLM_ROCM_USE_AITER", "1")
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
False, True)
assert backend.get_name() == "ROCM_AITER_MLA"
...@@ -5,6 +5,7 @@ import random ...@@ -5,6 +5,7 @@ import random
import pytest import pytest
import torch import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul, from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
GeluAndMul, MulAndSilu, GeluAndMul, MulAndSilu,
...@@ -12,8 +13,6 @@ from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul, ...@@ -12,8 +13,6 @@ from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
SiluAndMul) SiluAndMul)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .allclose_default import get_default_atol, get_default_rtol
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
D = [512, 13824] # Arbitrary values for testing D = [512, 13824] # Arbitrary values for testing
......
...@@ -3,11 +3,9 @@ ...@@ -3,11 +3,9 @@
Tests for miscellaneous utilities Tests for miscellaneous utilities
""" """
import pytest
import torch import torch
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm.platforms import current_platform
def test_convert_fp8_opcheck(): def test_convert_fp8_opcheck():
...@@ -16,10 +14,12 @@ def test_convert_fp8_opcheck(): ...@@ -16,10 +14,12 @@ def test_convert_fp8_opcheck():
opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8")) opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8"))
@pytest.mark.skipif(not current_platform.is_cuda(), # TODO: Add this back, currently fails with
reason="Only supported for CUDA") # csrc/cuda_utils_kernels.cu:15 'invalid argument'
def test_cuda_utils_opcheck(): # @pytest.mark.skipif(not current_platform.is_cuda(),
opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0)) # reason="Only supported for CUDA")
opcheck( # def test_cuda_utils_opcheck():
torch.ops._C_cuda_utils. # opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0))
get_max_shared_memory_per_block_device_attribute, (0, )) # opcheck(
# torch.ops._C_cuda_utils.
# get_max_shared_memory_per_block_device_attribute, (0, ))
...@@ -6,11 +6,10 @@ from typing import Callable, Optional ...@@ -6,11 +6,10 @@ from typing import Callable, Optional
import pytest import pytest
import torch import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .allclose_default import get_default_atol, get_default_rtol
IS_NEOX_STYLE = [True, False] IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 112, 120, 256] HEAD_SIZES = [64, 80, 112, 120, 256]
......
...@@ -5,6 +5,8 @@ import torch ...@@ -5,6 +5,8 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.mamba2_metadata import (
_seq_idx_to_chunk_indices_offsets)
from vllm.model_executor.layers.mamba.ops.ssd_combined import ( from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined) mamba_chunk_scan_combined)
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -160,14 +162,14 @@ def generate_continous_batched_examples(example_lens_by_batch, ...@@ -160,14 +162,14 @@ def generate_continous_batched_examples(example_lens_by_batch,
# get the metadata # get the metadata
cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0) cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0)
sed_idx = torch.zeros(cu_seqlens[-1], seq_idx = torch.zeros(cu_seqlens[-1],
dtype=torch.int32, dtype=torch.int32,
device=cu_seqlens.device) device=cu_seqlens.device)
for i, (srt, end) in enumerate(zip( for i, (srt, end) in enumerate(zip(
cu_seqlens, cu_seqlens,
cu_seqlens[1:], cu_seqlens[1:],
)): )):
sed_idx[srt:end] = i seq_idx[srt:end] = i
# for cont batch # for cont batch
if IND_E is None: if IND_E is None:
...@@ -177,7 +179,7 @@ def generate_continous_batched_examples(example_lens_by_batch, ...@@ -177,7 +179,7 @@ def generate_continous_batched_examples(example_lens_by_batch,
IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)], yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
@pytest.mark.parametrize("itype", @pytest.mark.parametrize("itype",
...@@ -266,12 +268,15 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, ...@@ -266,12 +268,15 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
states = None states = None
for Y_min, cu_seqlens, sed_idx, (A, dt, X, B, for Y_min, cu_seqlens, seq_idx, (A, dt, X, B,
C) in generate_continous_batched_examples( C) in generate_continous_batched_examples(
cases, num_examples, seqlen, cases, num_examples, seqlen,
last_taken, exhausted, n_heads, last_taken, exhausted, n_heads,
d_head, itype): d_head, itype):
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
seq_idx, chunk_size)
Y, new_states = mamba_chunk_scan_combined( Y, new_states = mamba_chunk_scan_combined(
X, X,
dt, dt,
...@@ -281,7 +286,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, ...@@ -281,7 +286,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
chunk_size, chunk_size,
D=None, D=None,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
seq_idx=sed_idx, seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True, return_varlen_states=True,
initial_states=states, initial_states=states,
) )
......
# SPDX-License-Identifier: Apache-2.0
import dataclasses
from typing import Optional
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
fused_topk)
from vllm.platforms import current_platform
NUM_EXPERTS = [40, 64]
TOP_KS = [6, 8]
MNK_FACTORS = [
(2, 1024, 1024),
(2, 1024, 1536),
(2, 3072, 1024),
(2, 3072, 1536),
(64, 1024, 1024),
(64, 1024, 1536),
(64, 3072, 1024),
(64, 3072, 1536),
(224, 1024, 1024),
(224, 1024, 1536),
(224, 3072, 1024),
(224, 3072, 1536),
]
@dataclasses.dataclass
class MOETensors:
    """Unquantized MoE test inputs together with per-expert stride tensors.

    Instances are normally built through :meth:`make_moe_tensors`, which
    allocates everything on the CUDA device.
    """
    a: torch.Tensor  # activations, created with shape (m, k)
    w1: torch.Tensor  # first expert weights, created with shape (e, 2 * n, k)
    w2: torch.Tensor  # second expert weights, created with shape (e, k, n)
    ab_strides1: torch.Tensor  # int64, one entry per expert, filled with k
    c_strides1: torch.Tensor  # int64, one entry per expert, filled with 2 * n
    ab_strides2: torch.Tensor  # int64, one entry per expert, filled with n
    c_strides2: torch.Tensor  # int64, one entry per expert, filled with k

    @staticmethod
    def make_moe_tensors(m: int, k: int, n: int, e: int,
                         dtype: torch.dtype) -> "MOETensors":
        """Create random activations/weights and constant stride tensors.

        Args:
            m: Number of activation rows.
            k: Trailing dimension of ``a`` and ``w1``.
            n: Trailing dimension of ``w2`` (``w1`` has ``2 * n`` rows per
                expert).
            e: Number of experts.
            dtype: Element type of ``a``, ``w1`` and ``w2``.

        Returns:
            A populated ``MOETensors`` whose tensors all live on "cuda".
        """
        device = "cuda"

        def scaled_randn(*shape: int) -> torch.Tensor:
            # Standard-normal samples shrunk by a factor of 10.
            return torch.randn(shape, device=device, dtype=dtype) / 10

        def per_expert(value: int) -> torch.Tensor:
            # One int64 entry per expert, all equal to ``value``.
            return torch.full((e, ), value, device=device, dtype=torch.int64)

        # The draw order a -> w1 -> w2 is kept so seeded runs stay
        # reproducible.
        activations = scaled_randn(m, k)
        first_weights = scaled_randn(e, 2 * n, k)
        second_weights = scaled_randn(e, k, n)

        return MOETensors(
            a=activations,
            w1=first_weights,
            w2=second_weights,
            ab_strides1=per_expert(k),
            c_strides1=per_expert(2 * n),
            ab_strides2=per_expert(n),
            c_strides2=per_expert(k),
        )
@dataclasses.dataclass
class MOETensors8Bit(MOETensors):
    """``MOETensors`` extended with fp8-quantized and dequantized copies.

    The ``*_q`` fields hold the quantized tensors, the ``*_scale`` fields the
    scales returned by ``ops.scaled_fp8_quant``, and the ``*_d`` fields the
    result of multiplying the quantized values back by their scales.
    """
    # Quantized tensors (x -> x_q).
    a_q: Optional[torch.Tensor] = None
    w1_q: Optional[torch.Tensor] = None
    w2_q: Optional[torch.Tensor] = None
    # Scales produced while quantizing.
    a_scale: Optional[torch.Tensor] = None
    w1_scale: Optional[torch.Tensor] = None
    w2_scale: Optional[torch.Tensor] = None
    # Dequantized tensors (x -> x_q -> x_d).
    a_d: Optional[torch.Tensor] = None
    w1_d: Optional[torch.Tensor] = None
    w2_d: Optional[torch.Tensor] = None

    @staticmethod
    def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
                              per_act_token: bool,
                              per_out_channel: bool) -> "MOETensors8Bit":
        """Build half-precision MoE tensors plus their fp8 round-trips.

        Args:
            m, k, n, e: Forwarded to ``MOETensors.make_moe_tensors``.
            per_act_token: Passed as ``use_per_token_if_dynamic`` when
                quantizing the activations.
            per_out_channel: Passed as ``use_per_token_if_dynamic`` when
                quantizing each expert's weights; also selects whether the
                weight-scale buffers have one row per output row or a single
                row.

        Returns:
            A fully populated ``MOETensors8Bit``.
        """
        dtype = torch.half
        q_dtype = torch.float8_e4m3fn
        device = "cuda"

        base = MOETensors.make_moe_tensors(m, k, n, e, dtype)

        # Activations: the first call is used only for the scale it returns;
        # the second call quantizes with that scale supplied explicitly.
        _, a_scale = ops.scaled_fp8_quant(
            base.a, use_per_token_if_dynamic=per_act_token)
        a_q, _ = ops.scaled_fp8_quant(base.a,
                                      a_scale,
                                      use_per_token_if_dynamic=per_act_token)
        a_d = a_q.float().mul(a_scale).to(dtype)

        # Weight-scale buffers: one row per output row when quantizing per
        # output channel, otherwise a single row per expert.
        w1_scale_rows = 2 * n if per_out_channel else 1
        w2_scale_rows = k if per_out_channel else 1

        w1_q = torch.empty((e, 2 * n, k), device=device, dtype=q_dtype)
        w2_q = torch.empty((e, k, n), device=device, dtype=q_dtype)
        w1_scale = torch.empty((e, w1_scale_rows, 1),
                               device=device,
                               dtype=torch.float32)
        w2_scale = torch.empty((e, w2_scale_rows, 1),
                               device=device,
                               dtype=torch.float32)
        w1_d = torch.empty_like(base.w1)
        w2_d = torch.empty_like(base.w2)

        # Quantize each expert independently, then immediately dequantize it.
        for idx in range(e):
            w1_q[idx], w1_scale[idx] = ops.scaled_fp8_quant(
                base.w1[idx], use_per_token_if_dynamic=per_out_channel)
            w2_q[idx], w2_scale[idx] = ops.scaled_fp8_quant(
                base.w2[idx], use_per_token_if_dynamic=per_out_channel)
            w1_d[idx] = (w1_q[idx].float() * w1_scale[idx]).half()
            w2_d[idx] = (w2_q[idx].float() * w2_scale[idx]).half()

        return MOETensors8Bit(
            a=base.a,
            w1=base.w1,
            w2=base.w2,
            ab_strides1=base.ab_strides1,
            c_strides1=base.c_strides1,
            ab_strides2=base.ab_strides2,
            c_strides2=base.c_strides2,
            a_q=a_q,
            w1_q=w1_q,
            w2_q=w2_q,
            a_scale=a_scale,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a_d=a_d,
            w1_d=w1_d,
            w2_d=w2_d,
        )
def run_with_expert_maps(num_experts: int, num_local_experts: int,
                         **cutlass_moe_kwargs):
    """Run ``cutlass_moe_fp8`` once per expert shard and sum the outputs.

    The experts are split into consecutive groups of ``num_local_experts``.
    For each group the per-expert tensors are sliced to that group and an
    ``expert_map`` is supplied that maps the group's global expert ids to
    ``0..num_local_experts-1`` and every other id to ``-1``.

    Args:
        num_experts: Total number of experts.
        num_local_experts: Number of experts in each shard.
        **cutlass_moe_kwargs: Keyword arguments for ``cutlass_moe_fp8``; must
            contain ``"a"``, which fixes the shape/dtype of the accumulator.

    Returns:
        The element-wise sum of the per-shard outputs.
    """
    per_expert_keys = (
        "w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1",
        "c_strides2", "w1_scale", "w2_scale"
    )

    def shard_kwargs():
        # Keep references to the unsliced tensors; the entries in
        # ``cutlass_moe_kwargs`` are overwritten with slices on every step.
        unsliced = {
            name: tensor
            for name, tensor in cutlass_moe_kwargs.items()
            if name in per_expert_keys
        }

        for start in range(0, num_experts, num_local_experts):
            stop = start + num_local_experts

            # Global id -> local id for this shard, -1 everywhere else.
            local_ids = [-1] * num_experts
            local_ids[start:stop] = list(range(num_local_experts))
            cutlass_moe_kwargs["expert_map"] = torch.tensor(
                local_ids, dtype=torch.int32, device="cuda")

            for name, tensor in unsliced.items():
                cutlass_moe_kwargs[name] = tensor[start:stop]

            yield cutlass_moe_kwargs

    accumulated = torch.zeros_like(cutlass_moe_kwargs["a"])
    for shard in shard_kwargs():
        accumulated = accumulated + cutlass_moe_fp8(**shard)
    return accumulated
def run_8_bit(moe_tensors: MOETensors8Bit,
              topk_weights: torch.Tensor,
              topk_ids: torch.Tensor,
              num_local_experts: Optional[int] = None) -> torch.Tensor:
    """Invoke the CUTLASS fp8 MoE path on a prepared tensor bundle.

    Args:
        moe_tensors: Bundle whose quantized weights and all three scales must
            already be populated.
        topk_weights: Routing weights forwarded as ``topk_weights``.
        topk_ids: Selected expert ids forwarded as ``topk_ids_``.
        num_local_experts: When ``None`` a single ``cutlass_moe_fp8`` call is
            made. Otherwise the work is sharded through
            ``run_with_expert_maps`` with this many experts per shard; this
            includes the case where it equals the total expert count.

    Returns:
        The output tensor of the CUTLASS MoE computation.
    """
    required = (
        moe_tensors.w1_q,
        moe_tensors.w2_q,
        moe_tensors.w1_scale,
        moe_tensors.w2_scale,
        moe_tensors.a_scale,
    )
    assert all(t is not None for t in required)

    kwargs = {
        'a': moe_tensors.a,
        'w1_q': moe_tensors.w1_q.transpose(1, 2),  # type: ignore[union-attr]
        'w2_q': moe_tensors.w2_q.transpose(1, 2),  # type: ignore[union-attr]
        'topk_weights': topk_weights,
        'topk_ids_': topk_ids,
        'ab_strides1': moe_tensors.ab_strides1,
        'c_strides1': moe_tensors.c_strides1,
        'ab_strides2': moe_tensors.ab_strides2,
        'c_strides2': moe_tensors.c_strides2,
        'w1_scale': moe_tensors.w1_scale,
        'w2_scale': moe_tensors.w2_scale,
        'a1_scale': moe_tensors.a_scale
    }

    # The earlier condition
    # ``num_local_experts is not None or num_local_experts == num_experts``
    # reduced to ``num_local_experts is not None`` (``None == int`` is never
    # true), so the branch is spelled out directly.
    if num_local_experts is None:
        return cutlass_moe_fp8(**kwargs)

    num_experts = moe_tensors.w1.size(0)
    return run_with_expert_maps(num_experts, num_local_experts, **kwargs)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_moe_8_bit_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
):
    """Compare eager CUTLASS fp8 MoE output against ``fused_experts``."""
    current_platform.seed_everything(7)

    vllm_config = VllmConfig(parallel_config=ParallelConfig(
        pipeline_parallel_size=1))
    with set_current_vllm_config(vllm_config):
        tensors = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e,
                                                       per_act_token,
                                                       per_out_ch)

        router_scores = torch.randn((m, e), device="cuda", dtype=torch.half)
        topk_weights, topk_ids = fused_topk(tensors.a,
                                            router_scores,
                                            topk,
                                            renormalize=False)

        # The reference is computed from the dequantized tensors; feeding
        # a, w1 and w2 directly results in minor output differences.
        expected = fused_experts(tensors.a_d, tensors.w1_d, tensors.w2_d,
                                 topk_weights, topk_ids)

        actual = run_8_bit(tensors, topk_weights, topk_ids)

        torch.testing.assert_close(expected, actual, atol=5e-2, rtol=1e-2)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_moe_8_bit_cuda_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
):
    """Compare CUTLASS fp8 MoE output captured in a CUDA graph against
    ``fused_experts``."""
    current_platform.seed_everything(7)

    vllm_config = VllmConfig(parallel_config=ParallelConfig(
        pipeline_parallel_size=1))
    with set_current_vllm_config(vllm_config):
        dtype = torch.half

        tensors = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e,
                                                       per_act_token,
                                                       per_out_ch)

        router_scores = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids = fused_topk(tensors.a,
                                            router_scores,
                                            topk,
                                            renormalize=False)

        # The reference is computed from the dequantized tensors; feeding
        # a, w1 and w2 directly results in minor output differences.
        expected = fused_experts(tensors.a_d, tensors.w1_d, tensors.w2_d,
                                 topk_weights, topk_ids)

        # Record the CUTLASS call into a graph on a dedicated stream, then
        # replay it; the output tensor is the one bound during capture.
        capture_stream = torch.cuda.Stream()
        cuda_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(cuda_graph, stream=capture_stream):
            actual = run_8_bit(tensors, topk_weights, topk_ids)
        torch.cuda.synchronize()
        cuda_graph.replay()
        torch.cuda.synchronize()

        torch.testing.assert_close(expected, actual, atol=9e-2, rtol=1e-2)
@pytest.mark.parametrize("m", [64])
@pytest.mark.parametrize("n", [1024])
@pytest.mark.parametrize("k", [4096])
@pytest.mark.parametrize("e", [16])
@pytest.mark.parametrize("topk", [1, 8])
@pytest.mark.parametrize("per_act_token", [True])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_moe_8_bit_EP(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
):
    """Compare expert-sharded CUTLASS fp8 MoE output against
    ``fused_experts``, using ``e // ep_size`` experts per shard."""
    current_platform.seed_everything(7)

    vllm_config = VllmConfig(parallel_config=ParallelConfig(
        pipeline_parallel_size=1))
    with set_current_vllm_config(vllm_config):
        tensors = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e,
                                                       per_act_token,
                                                       per_out_channel)

        router_scores = torch.randn((m, e), device="cuda", dtype=torch.half)
        topk_weights, topk_ids = fused_topk(tensors.a,
                                            router_scores,
                                            topk,
                                            renormalize=False)

        # The reference is computed from the dequantized tensors; feeding
        # a, w1 and w2 directly results in minor output differences.
        expected = fused_experts(tensors.a_d, tensors.w1_d, tensors.w2_d,
                                 topk_weights, topk_ids)

        assert e % ep_size == 0, "Cannot distribute experts evenly"
        experts_per_shard = e // ep_size
        actual = run_8_bit(tensors,
                           topk_weights,
                           topk_ids,
                           num_local_experts=experts_per_shard)

        torch.testing.assert_close(expected, actual, atol=5e-2, rtol=1e-2)
...@@ -11,16 +11,14 @@ from transformers import MixtralConfig ...@@ -11,16 +11,14 @@ from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe,
torch_moe, torch_moe_single) torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
fused_topk, moe_align_block_size)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe) fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize) awq_marlin_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights) quantize_weights)
from vllm.model_executor.models.mixtral import MixtralMoE from vllm.model_executor.models.mixtral import MixtralMoE
...@@ -287,14 +285,17 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, ...@@ -287,14 +285,17 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
atol=mixtral_moe_tol[dtype]) atol=mixtral_moe_tol[dtype])
@pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("m", [1, 33, 123])
@pytest.mark.parametrize("n", [128, 2048]) @pytest.mark.parametrize("n", [128, 1024])
@pytest.mark.parametrize("k", [128, 1024]) @pytest.mark.parametrize("k", [256, 2048])
@pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("e", [4, 12])
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", [2, 3])
@pytest.mark.parametrize("ep_size", [1, 4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [-1, 32, 128]) @pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8]) @pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("is_k_full", [True, False]) @pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe( def test_fused_marlin_moe(
...@@ -303,9 +304,12 @@ def test_fused_marlin_moe( ...@@ -303,9 +304,12 @@ def test_fused_marlin_moe(
k: int, k: int,
e: int, e: int,
topk: int, topk: int,
ep_size: int,
dtype: torch.dtype,
group_size: int, group_size: int,
act_order: bool, act_order: bool,
num_bits: int, num_bits: int,
has_zp: bool,
is_k_full: bool, is_k_full: bool,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
...@@ -316,75 +320,110 @@ def test_fused_marlin_moe( ...@@ -316,75 +320,110 @@ def test_fused_marlin_moe(
return return
if group_size in (k, n): if group_size in (k, n):
return return
if has_zp:
return
else: else:
if not is_k_full: if not is_k_full:
return return
quant_type = (scalar_types.uint4b8 if has_zp:
if num_bits == 4 else scalar_types.uint8b128) # we don't build kernel for int8 with zero
dtype = torch.float16 if num_bits == 8:
return
quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
else:
quant_type = scalar_types.uint4b8 \
if num_bits == 4 else scalar_types.uint8b128
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
else:
e_map = None
w_ref1_l = [] w_ref1_l = []
qweight1_l = [] qweight1_l = []
scales1_l = [] scales1_l = []
zeros1_l = []
g_idx1_l = [] g_idx1_l = []
sort_indices1_l = [] sort_indices1_l = []
for i in range(w1.shape[0]): for i in range(w1.shape[0]):
test_perm = torch.randperm(k) if has_zp:
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order, w1[i].transpose(1, 0), quant_type, group_size)
test_perm)
w_ref1_l.append(w_ref1) w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1) qweight1_l.append(qweight1)
scales1_l.append(scales1) scales1_l.append(scales1)
g_idx1_l.append(g_idx1) zeros1_l.append(zeros1)
sort_indices1_l.append(sort_indices1) else:
test_perm = torch.randperm(k)
quant_res = marlin_quantize(w1[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
w_ref1 = stack_and_dev(w_ref1_l) w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous() qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l) scales1 = stack_and_dev(scales1_l)
g_idx1 = stack_and_dev(g_idx1_l) g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
sort_indices1 = stack_and_dev(sort_indices1_l) zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
w_ref2_l = [] w_ref2_l = []
qweight2_l = [] qweight2_l = []
scales2_l = [] scales2_l = []
zeros2_l = []
g_idx2_l = [] g_idx2_l = []
sort_indices2_l = [] sort_indices2_l = []
for i in range(w2.shape[0]): for i in range(w2.shape[0]):
test_perm = torch.randperm(n) if has_zp:
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order, w2[i].transpose(1, 0), quant_type, group_size)
test_perm)
w_ref2_l.append(w_ref2) w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2) qweight2_l.append(qweight2)
scales2_l.append(scales2) scales2_l.append(scales2)
g_idx2_l.append(g_idx2) zeros2_l.append(zeros2)
sort_indices2_l.append(sort_indices2) else:
test_perm = torch.randperm(n)
quant_res = marlin_quantize(w2[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
w_ref2 = stack_and_dev(w_ref2_l) w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous() qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l) scales2 = stack_and_dev(scales2_l)
g_idx2 = stack_and_dev(g_idx2_l) g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
sort_indices2 = stack_and_dev(sort_indices2_l) zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
score = torch.randn((m, e), device="cuda", dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, False) topk_weights, topk_ids = fused_topk(a, score, topk, False)
triton_output = fused_moe( torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
a,
w_ref1.transpose(1, 2).contiguous(),
w_ref2.transpose(1, 2).contiguous(),
score,
topk,
renormalize=False,
)
marlin_output = torch.ops.vllm.fused_marlin_moe( marlin_output = torch.ops.vllm.fused_marlin_moe(
a, a,
qweight1, qweight1,
...@@ -394,111 +433,91 @@ def test_fused_marlin_moe( ...@@ -394,111 +433,91 @@ def test_fused_marlin_moe(
score, score,
topk_weights, topk_weights,
topk_ids, topk_ids,
global_num_experts=e,
expert_map=e_map,
g_idx1=g_idx1, g_idx1=g_idx1,
g_idx2=g_idx2, g_idx2=g_idx2,
sort_indices1=sort_indices1, sort_indices1=sort_indices1,
sort_indices2=sort_indices2, sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
num_bits=num_bits, num_bits=num_bits,
is_k_full=is_k_full, is_k_full=is_k_full)
)
assert compute_max_diff(marlin_output, triton_output) < 4e-2
if ops.supports_moe_ops:
token_expert_indicies = torch.empty(m,
topk,
dtype=torch.int32,
device=a.device)
opcheck(torch.ops._moe_C.topk_softmax, (
topk_weights,
topk_ids,
token_expert_indicies,
score.float(),
))
block_size_m = 4
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m,
e)
max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16 torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)
zp = torch.empty((0, 0),
dtype=dtype,
device="cuda",
requires_grad=False)
opcheck(torch.ops._moe_C.marlin_gemm_moe,
(a, qweight1, sorted_token_ids, topk_weights, topk_ids,
scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id,
m, 2 * n, k, True, e, topk, block_size_m, True, False))
@pytest.mark.skip("This test is here for the sake of debugging, " @pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.") "don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("m", [1, 33, 123])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("n", [128, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512]) @pytest.mark.parametrize("k", [256, 2048])
@pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("e", [4, 12])
@pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("topk", [2, 3])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8]) @pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("is_k_full", [True, False]) @pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_single_marlin_moe_multiply(m: int, n: int, k: int, e: int, topk: int,
def test_single_marlin_moe_multiply( dtype: torch.dtype, group_size: int,
m: int, act_order: bool, num_bits: int,
n: int, has_zp: bool, is_k_full: bool):
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
num_bits: int,
is_k_full: bool,
):
# Filter act_order # Filter act_order
if act_order: if act_order:
if group_size == -1: if group_size == -1:
return return
if group_size == k: if group_size in (k, n):
return
if has_zp:
return return
else: else:
if not is_k_full: if not is_k_full:
return return
quant_type = (scalar_types.uint4b8 if has_zp:
if num_bits == 4 else scalar_types.uint8b128) quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
dtype = torch.float16 else:
quant_type = scalar_types.uint4b8 \
if num_bits == 4 else scalar_types.uint8b128
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
w_ref_l = [] w_ref_l = []
qweights_l = [] qweight_l = []
scales_l = [] scales_l = []
zeros_l = []
g_idx_l = [] g_idx_l = []
sort_indices_l = [] sort_indices_l = []
for i in range(w.shape[0]): for i in range(w.shape[0]):
test_perm = torch.randperm(k) if has_zp:
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( w_ref, qweight, scales, zeros = awq_marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm) w[i].transpose(1, 0), quant_type, group_size)
w_ref_l.append(w_ref)
qweights_l.append(qweight) w_ref_l.append(w_ref.T)
scales_l.append(scales) qweight_l.append(qweight)
g_idx_l.append(g_idx) scales_l.append(scales)
sort_indices_l.append(sort_indices) zeros_l.append(zeros)
else:
test_perm = torch.randperm(k)
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref_l.append(w_ref.T)
qweight_l.append(qweight)
scales_l.append(scales)
g_idx_l.append(g_idx)
sort_indices_l.append(sort_indices)
w_ref = stack_and_dev(w_ref_l) w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweights_l).contiguous() qweight = stack_and_dev(qweight_l).contiguous()
scales = stack_and_dev(scales_l) scales = stack_and_dev(scales_l)
g_idx = stack_and_dev(g_idx_l) g_idx = stack_and_dev(g_idx_l) if g_idx_l else None
sort_indices = stack_and_dev(sort_indices_l) zeros = stack_and_dev(zeros_l) if zeros_l else None
sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
score = torch.randn((m, e), device="cuda", dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = torch.ops.vllm.single_marlin_moe( marlin_output = torch.ops.vllm.single_marlin_moe(
...@@ -510,13 +529,14 @@ def test_single_marlin_moe_multiply( ...@@ -510,13 +529,14 @@ def test_single_marlin_moe_multiply(
renormalize=False, renormalize=False,
g_idx=g_idx, g_idx=g_idx,
sort_indices=sort_indices, sort_indices=sort_indices,
w_zeros=zeros,
num_bits=num_bits, num_bits=num_bits,
is_k_full=is_k_full, is_k_full=is_k_full,
) )
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) torch_output = torch_moe_single(a, w_ref, score, topk)
assert compute_max_diff(marlin_output, torch_output) < 1e-2 torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
def test_moe_align_block_size_opcheck(): def test_moe_align_block_size_opcheck():
......
...@@ -87,3 +87,63 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ ...@@ -87,3 +87,63 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
ref_out = (as_float32_tensor(x) * ref_iscale).clamp( ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
fp8_traits_min, fp8_traits_max).to(FP8_DTYPE) fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
return ref_out, ref_scale.view((1, )) return ref_out, ref_scale.view((1, ))
def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
                             As: torch.Tensor, Bs: torch.Tensor, block_size,
                             output_dtype):
    """Reference block-wise-scaled matmul implemented with plain torch ops.

    Both operands are cast to float32, so the routine does not depend on the
    quantized storage type of ``A`` and ``B``. The product ``A @ B.T`` is
    accumulated tile by tile; each partial product is multiplied by the
    matching activation scale from ``As`` and weight scale from ``Bs``.

    Args:
        A: Left operand with shape ``(..., K)``.
        B: Contiguous 2-D right operand with shape ``(N, K)``.
        As: Scales for ``A`` with shape ``(..., ceil(K / block_k))``; leading
            dimensions must match those of ``A``.
        Bs: 2-D scales for ``B`` with shape
            ``(ceil(N / block_n), ceil(K / block_k))``.
        block_size: Two-element sequence ``(block_n, block_k)``.
        output_dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(..., N)`` in ``output_dtype``.
    """

    def ceil_div(value, divisor):
        return (value + divisor - 1) // divisor

    A = A.to(torch.float32)
    B = B.to(torch.float32)

    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size
    assert ceil_div(A.shape[-1], block_k) == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    # Collapse every leading dimension of A (and As) into a single row axis.
    M = A.numel() // A.shape[-1]
    N, K = B.shape
    result_shape = A.shape[:-1] + (N, )
    A = A.reshape(M, A.shape[-1])
    As = As.reshape(M, As.shape[-1])

    n_tiles = ceil_div(N, block_n)
    k_tiles = ceil_div(K, block_k)
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C = torch.zeros((M, N), dtype=torch.float32, device=A.device)

    for k_idx in range(k_tiles):
        k_lo = k_idx * block_k
        k_hi = min(k_lo + block_k, K)
        a_tile = A[:, k_lo:k_hi]
        # Kept as a column so it broadcasts across the output tile.
        a_scale = As[:, k_idx:k_idx + 1]

        for n_idx in range(n_tiles):
            n_lo = n_idx * block_n
            n_hi = min(n_lo + block_n, N)
            b_tile = B[n_lo:n_hi, k_lo:k_hi]
            tile_scale = a_scale * Bs[n_idx][k_idx]
            C[:, n_lo:n_hi] += torch.matmul(a_tile, b_tile.t()) * tile_scale

    return C.reshape(result_shape).to(output_dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment