Commit 7e63ef82 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0' into v0.14.0-dev

parents 8cbcac5d b17039bc
...@@ -10,6 +10,7 @@ from transformers import __version__ as TRANSFORMERS_VERSION ...@@ -10,6 +10,7 @@ from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config from vllm.transformers_utils.config import get_config
from vllm.utils.torch_utils import set_random_seed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
...@@ -24,7 +25,7 @@ def generate_test_data( ...@@ -24,7 +25,7 @@ def generate_test_data(
device: torch.device, device: torch.device,
): ):
"""Generate test data for given configuration.""" """Generate test data for given configuration."""
current_platform.seed_everything(42) set_random_seed(42)
# Create 2D positions (3, num_tokens) for multimodal case # Create 2D positions (3, num_tokens) for multimodal case
positions = torch.randint( positions = torch.randint(
0, max_position_embeddings // 4, (3, num_tokens), device=device 0, max_position_embeddings // 4, (3, num_tokens), device=device
...@@ -89,6 +90,7 @@ num_tokens_list = [11, 8192] ...@@ -89,6 +90,7 @@ num_tokens_list = [11, 8192]
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list) @pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope( def test_mrope(
default_vllm_config,
model_name: str, model_name: str,
model_info: MRoPETestInfo, model_info: MRoPETestInfo,
tp_size: int, tp_size: int,
...@@ -158,6 +160,7 @@ def test_mrope( ...@@ -158,6 +160,7 @@ def test_mrope(
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list) @pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope_torch_compile_tracing( def test_mrope_torch_compile_tracing(
default_vllm_config,
model_name: str, model_name: str,
model_info: MRoPETestInfo, model_info: MRoPETestInfo,
tp_size: int, tp_size: int,
......
...@@ -9,7 +9,7 @@ import torch ...@@ -9,7 +9,7 @@ import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed
IS_NEOX_STYLE = [True, False] IS_NEOX_STYLE = [True, False]
DTYPES = [torch.bfloat16, torch.float] DTYPES = [torch.bfloat16, torch.float]
...@@ -62,6 +62,7 @@ TENSORS_SHAPES_FN = [ ...@@ -62,6 +62,7 @@ TENSORS_SHAPES_FN = [
@pytest.mark.parametrize("use_key", USE_KEY) @pytest.mark.parametrize("use_key", USE_KEY)
@torch.inference_mode() @torch.inference_mode()
def test_rotary_embedding( def test_rotary_embedding(
default_vllm_config,
is_neox_style: bool, is_neox_style: bool,
tensor_shape_fn: Callable[[int, int, int, int], tuple[int, ...]], tensor_shape_fn: Callable[[int, int, int, int], tuple[int, ...]],
batch_size: int, batch_size: int,
...@@ -79,7 +80,7 @@ def test_rotary_embedding( ...@@ -79,7 +80,7 @@ def test_rotary_embedding(
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
current_platform.seed_everything(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
...@@ -123,7 +124,7 @@ def test_rotary_embedding( ...@@ -123,7 +124,7 @@ def test_rotary_embedding(
@torch.inference_mode() @torch.inference_mode()
def test_rope_module_cache(): def test_rope_module_cache(default_vllm_config):
MAX_POSITIONS = [123, 1234] MAX_POSITIONS = [123, 1234]
ROPE_THETAS = [10000, 1000000] ROPE_THETAS = [10000, 1000000]
ROPE_PARAMETERS = ( ROPE_PARAMETERS = (
......
...@@ -36,6 +36,7 @@ def rotary_embedding_opcheck( ...@@ -36,6 +36,7 @@ def rotary_embedding_opcheck(
@pytest.mark.parametrize("use_key", [True, False]) @pytest.mark.parametrize("use_key", [True, False])
@pytest.mark.parametrize("head_stride_is_contiguous", [True, False]) @pytest.mark.parametrize("head_stride_is_contiguous", [True, False])
def test_rotary_embedding_opcheck( def test_rotary_embedding_opcheck(
default_vllm_config,
dist_init, dist_init,
device, device,
max_position, max_position,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for fused MLA KV-cache write and RoPE fused kernel
"""
import random
import pytest
import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.utils.torch_utils import set_random_seed
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float])
@pytest.mark.parametrize("is_neox_style", [False, True])
@pytest.mark.parametrize("seq_len", [11, 42])
@pytest.mark.parametrize("qk_rope_head_dim", [64, 128])
@pytest.mark.parametrize("num_q_heads", [128])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("num_blocks", [64])
@pytest.mark.parametrize("block_size", [16, 64, 256])
@pytest.mark.parametrize("seed", [0])
@pytest.mark.parametrize(
"device", [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
)
@torch.inference_mode()
def test_concat_and_cache_mla_rope_fused(
default_vllm_config,
dtype: torch.dtype,
is_neox_style: bool,
seq_len: int,
qk_rope_head_dim: int,
num_q_heads: int,
kv_cache_dtype: str,
kv_lora_rank: int,
num_blocks: int,
block_size: int,
seed: int,
device: str,
max_position: int = 8192,
base: float = 10000,
) -> None:
set_random_seed(seed)
torch.set_default_device(device)
rope = RotaryEmbedding(
qk_rope_head_dim,
qk_rope_head_dim,
max_position,
base,
is_neox_style,
torch.float32,
)
rope = rope.to(dtype=dtype, device=torch.get_default_device())
positions = torch.randint(0, max_position, (seq_len,))
query = torch.randn(seq_len, num_q_heads, qk_rope_head_dim, dtype=dtype)
key = torch.randn(seq_len, 1, qk_rope_head_dim + kv_lora_rank, dtype=dtype)
k_pe = torch.flatten(key[..., :qk_rope_head_dim], start_dim=1).to(device=device)
kv_c = torch.flatten(key[..., qk_rope_head_dim:], start_dim=1).to(device=device)
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_q_pe, ref_k_pe = rope.forward_native(positions, query, k_pe)
assert ref_k_pe is not None
ref_k_pe = torch.flatten(ref_k_pe, start_dim=1).to(device=device)
ref_k_rope = ref_k_pe[..., :qk_rope_head_dim]
total_available_slots = num_blocks * block_size
total_needed_slots = seq_len
assert total_available_slots >= total_needed_slots, "Not enough kv slots!"
slot_mapping_lst = random.sample(range(total_available_slots), total_needed_slots)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
entry_size = kv_lora_rank + qk_rope_head_dim
kv_cache_scale = torch.tensor([0.1], dtype=torch.float32, device=device)
kv_cache = torch.zeros(
num_blocks,
block_size,
entry_size,
dtype=torch.uint8 if kv_cache_dtype == "fp8" else dtype,
device=device,
)
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
for i in range(seq_len):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
ref_temp[block_idx, block_offset] = torch.cat((kv_c[i], ref_k_rope[i]), -1)
if kv_cache_dtype == "fp8":
ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
ops.convert_fp8(
ref_kv_cache, ref_temp, kv_cache_scale.item(), kv_dtype=kv_cache_dtype
)
else:
ref_kv_cache = ref_temp
opcheck(
torch.ops._C_cache_ops.concat_and_cache_mla_rope_fused,
(
positions,
query,
k_pe,
kv_c,
rope.cos_sin_cache,
is_neox_style,
slot_mapping,
kv_cache,
kv_cache_dtype,
kv_cache_scale,
),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.concat_and_cache_mla_rope_fused(
positions,
query,
k_pe,
kv_c,
rope.cos_sin_cache,
is_neox_style,
slot_mapping,
kv_cache,
kv_cache_dtype,
kv_cache_scale,
)
if kv_cache_dtype == "fp8":
result_temp = torch.empty_like(kv_cache, dtype=torch.float16)
ops.convert_fp8(
result_temp,
kv_cache.contiguous(),
kv_cache_scale.item(),
kv_dtype=kv_cache_dtype,
)
expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16)
ops.convert_fp8(
expected_temp, ref_kv_cache, kv_cache_scale.item(), kv_dtype=kv_cache_dtype
)
torch.testing.assert_close(result_temp, expected_temp, atol=0.001, rtol=0.1)
else:
torch.testing.assert_close(kv_cache, ref_kv_cache)
torch.testing.assert_close(
query, ref_q_pe, atol=get_default_atol(query), rtol=get_default_rtol(query)
)
...@@ -12,8 +12,8 @@ from vllm.distributed.parallel_state import ( ...@@ -12,8 +12,8 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel, initialize_model_parallel,
) )
from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
...@@ -68,7 +68,7 @@ def mixer2_gated_norm_tensor_parallel( ...@@ -68,7 +68,7 @@ def mixer2_gated_norm_tensor_parallel(
dtype: torch.dtype, dtype: torch.dtype,
device: str, device: str,
): ):
current_platform.seed_everything(0) set_random_seed(0)
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
......
...@@ -7,12 +7,12 @@ import torch ...@@ -7,12 +7,12 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_fn,
causal_conv1d_update, causal_conv1d_update,
) )
from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
def causal_conv1d_ref( def causal_conv1d_ref(
...@@ -154,7 +154,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity ...@@ -154,7 +154,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
if itype == torch.bfloat16: if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2 rtol, atol = 1e-2, 5e-2
# set seed # set seed
current_platform.seed_everything(0) set_random_seed(0)
batch = 2 batch = 2
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype) x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
x_ref = x.clone() x_ref = x.clone()
...@@ -201,7 +201,7 @@ def test_causal_conv1d_update_with_batch_gather( ...@@ -201,7 +201,7 @@ def test_causal_conv1d_update_with_batch_gather(
rtol, atol = 1e-2, 5e-2 rtol, atol = 1e-2, 5e-2
# set seed # set seed
current_platform.seed_everything(0) set_random_seed(0)
padding = 5 if with_padding else 0 padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding padded_batch_size = batch_size + padding
...@@ -278,7 +278,7 @@ def test_causal_conv1d_varlen( ...@@ -278,7 +278,7 @@ def test_causal_conv1d_varlen(
if itype == torch.bfloat16: if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2 rtol, atol = 1e-2, 5e-2
# set seed # set seed
current_platform.seed_everything(0) set_random_seed(0)
seqlens = [] seqlens = []
batch_size = batch batch_size = batch
padding = 3 if with_padding else 0 padding = 3 if with_padding else 0
......
...@@ -8,12 +8,12 @@ from einops import rearrange, repeat ...@@ -8,12 +8,12 @@ from einops import rearrange, repeat
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401 from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_scan_fn,
selective_state_update, selective_state_update,
) )
from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
def selective_state_update_ref( def selective_state_update_ref(
...@@ -271,7 +271,7 @@ def test_selective_scan( ...@@ -271,7 +271,7 @@ def test_selective_scan(
rtolw = max(rtolw, rtol) rtolw = max(rtolw, rtol)
atolw = max(atolw, atol) atolw = max(atolw, atol)
# set seed # set seed
current_platform.seed_everything(0) set_random_seed(0)
batch_size = 1 batch_size = 1
dim = 4 dim = 4
dstate = 8 dstate = 8
...@@ -401,7 +401,7 @@ def test_selective_state_update(dim, dstate, has_z, itype): ...@@ -401,7 +401,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
if torch.version.hip: if torch.version.hip:
atol *= 2 atol *= 2
# set seed # set seed
current_platform.seed_everything(0) set_random_seed(0)
batch_size = 1 batch_size = 1
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=itype) x = torch.randn(batch_size, dim, device=device, dtype=itype)
...@@ -438,7 +438,7 @@ def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len): ...@@ -438,7 +438,7 @@ def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len):
if torch.version.hip: if torch.version.hip:
atol *= 2 atol *= 2
# set seed # set seed
current_platform.seed_everything(0) set_random_seed(0)
batch_size = 4 batch_size = 4
token_counts = torch.randint(1, max_seq_len + 1, (batch_size,), device=device) token_counts = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
total_tokens = int(token_counts.sum().item()) total_tokens = int(token_counts.sum().item())
...@@ -857,7 +857,7 @@ def test_selective_state_update_with_num_accepted_tokens( ...@@ -857,7 +857,7 @@ def test_selective_state_update_with_num_accepted_tokens(
if torch.version.hip: if torch.version.hip:
atol *= 2 atol *= 2
current_platform.seed_everything(0) set_random_seed(0)
batch_size = 4 batch_size = 4
tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device) tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
...@@ -983,7 +983,7 @@ def test_selective_state_update_varlen_with_num_accepted( ...@@ -983,7 +983,7 @@ def test_selective_state_update_varlen_with_num_accepted(
if torch.version.hip: if torch.version.hip:
atol *= 2 atol *= 2
current_platform.seed_everything(0) set_random_seed(0)
batch_size = 4 batch_size = 4
tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device) tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
......
...@@ -9,7 +9,7 @@ from einops import rearrange, repeat ...@@ -9,7 +9,7 @@ from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.ops.ssd_combined import ( from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined_varlen, mamba_chunk_scan_combined_varlen,
) )
from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.mamba2_attn import compute_varlen_chunk_metadata from vllm.v1.attention.backends.mamba2_attn import compute_varlen_chunk_metadata
# Added by the IBM Team, 2024 # Added by the IBM Team, 2024
...@@ -82,7 +82,7 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): ...@@ -82,7 +82,7 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
def generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype, device="cuda"): def generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype, device="cuda"):
current_platform.seed_everything(0) set_random_seed(0)
A = -torch.exp(torch.rand(n_heads, dtype=itype, device=device)) A = -torch.exp(torch.rand(n_heads, dtype=itype, device=device))
dt = F.softplus( dt = F.softplus(
torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - 4 torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - 4
......
...@@ -258,16 +258,16 @@ class Config: ...@@ -258,16 +258,16 @@ class Config:
f"{self.fe_supported_types()}." f"{self.fe_supported_types()}."
) )
# Check block quanization support # Check block quantization support
is_block_quatized = self.quant_block_shape is not None is_block_quantized = self.quant_block_shape is not None
if is_block_quatized and self.quant_dtype is None: if is_block_quantized and self.quant_dtype is None:
return False, "No block quantization support." return False, "No block quantization support."
if is_block_quatized and not self.is_block_quant_supported(): if is_block_quantized and not self.is_block_quant_supported():
return False, "Mismatched block quantization support." return False, "Mismatched block quantization support."
# deep_gemm only works with block-quantized # deep_gemm only works with block-quantized
if self.needs_deep_gemm() and not is_block_quatized: if self.needs_deep_gemm() and not is_block_quantized:
return False, "Needs DeepGEMM but not block quantized." return False, "Needs DeepGEMM but not block quantized."
# Check dependencies (turn into asserts?) # Check dependencies (turn into asserts?)
......
...@@ -10,7 +10,7 @@ from tqdm import tqdm ...@@ -10,7 +10,7 @@ from tqdm import tqdm
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import FUSED_MOE_UNQUANTIZED_CONFIG from vllm.model_executor.layers.fused_moe.config import FUSED_MOE_UNQUANTIZED_CONFIG
from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed
from .common import ( from .common import (
Config, Config,
...@@ -40,7 +40,7 @@ def rank_worker( ...@@ -40,7 +40,7 @@ def rank_worker(
config: Config, config: Config,
weights: WeightTensors, weights: WeightTensors,
): ):
current_platform.seed_everything(pgi.rank) set_random_seed(pgi.rank)
# sanity check # sanity check
from vllm import envs from vllm import envs
......
...@@ -9,7 +9,7 @@ from typing import Any ...@@ -9,7 +9,7 @@ from typing import Any
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed
from .common import Config, RankTensors, WeightTensors, make_modular_kernel from .common import Config, RankTensors, WeightTensors, make_modular_kernel
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
...@@ -82,7 +82,7 @@ def rank_worker( ...@@ -82,7 +82,7 @@ def rank_worker(
config: Config, config: Config,
weights: WeightTensors, weights: WeightTensors,
): ):
current_platform.seed_everything(pgi.rank) set_random_seed(pgi.rank)
# sanity check # sanity check
from vllm import envs from vllm import envs
......
...@@ -21,6 +21,7 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( ...@@ -21,6 +21,7 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl from vllm.triton_utils import tl
from vllm.utils.torch_utils import set_random_seed
MNK_FACTORS = [ MNK_FACTORS = [
(1, 128, 128), (1, 128, 128),
...@@ -115,7 +116,7 @@ def test_batched_mm( ...@@ -115,7 +116,7 @@ def test_batched_mm(
): ):
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89, """Note: float8_e4m3fn is not supported on CUDA architecture < 89,
and those tests will be skipped on unsupported hardware.""" and those tests will be skipped on unsupported hardware."""
current_platform.seed_everything(7) set_random_seed(7)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn use_fp8_w8a8 = dtype == torch.float8_e4m3fn
...@@ -252,7 +253,7 @@ def test_fused_moe_batched_experts( ...@@ -252,7 +253,7 @@ def test_fused_moe_batched_experts(
): ):
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89, """Note: float8_e4m3fn is not supported on CUDA architecture < 89,
and those tests will be skipped on unsupported hardware.""" and those tests will be skipped on unsupported hardware."""
current_platform.seed_everything(7) set_random_seed(7)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn use_fp8_w8a8 = dtype == torch.float8_e4m3fn
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import _CPU_MOE_ACT
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
if not current_platform.is_cpu():
pytest.skip("skipping CPU-only tests", allow_module_level=True)
EXPERT_NUM = [
8,
]
HIDDEN_DIM = [128, 2880]
INTERMEDIATE_DIM = [128, 2880]
BATCH_SIZE = [1, 64, 256]
ACT = ["silu", "swigluoai"]
USE_BIAS = [True, False]
ISA = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
DTYPE = [torch.bfloat16]
def ref_fused_moe(
input: torch.Tensor,
w13: torch.Tensor,
w2: torch.Tensor,
w13_bias: torch.Tensor | None,
w2_bias: torch.Tensor | None,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
) -> torch.Tensor:
len_experts = w13.size(0)
cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
cnts.scatter_(1, topk_ids.to(torch.int64), 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = input[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx].float()
curr_w13 = w13[i].float()
curr_w2 = w2[i].float()
curr_w13_bias = None
if w13_bias is not None:
curr_w13_bias = w13_bias[i].float()
curr_w2_bias = None
if w2_bias is not None:
curr_w2_bias = w2_bias[i].float()
gate_up = torch.nn.functional.linear(
tokens_for_this_expert, curr_w13, curr_w13_bias
)
# Note: to simulate the kernel implementation
gate_up = (
_CPU_MOE_ACT[activation]
.forward_native(gate_up)
.to(dtype=input.dtype)
.float()
)
expert_out = torch.nn.functional.linear(gate_up, curr_w2, curr_w2_bias)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.mul_(topk_weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(input.dtype)
)
return final_out
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("expert_num", EXPERT_NUM)
@pytest.mark.parametrize("hidden_size", HIDDEN_DIM)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_DIM)
@pytest.mark.parametrize("use_bias", USE_BIAS)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("act", ACT)
@pytest.mark.parametrize("isa", ISA)
def test_cpu_fused_moe(
default_vllm_config,
batch_size: int,
expert_num: int,
hidden_size: int,
intermediate_size: int,
use_bias: bool,
dtype: torch.dtype,
act: str,
isa: str,
):
set_random_seed(0)
topk_num = max(expert_num // 2, 1)
up_dim = 2 * intermediate_size
input = torch.randn((batch_size, hidden_size), dtype=dtype) / (
0.5 * hidden_size**0.5
)
w13 = torch.randn((expert_num, up_dim, hidden_size), dtype=dtype) / (
0.5 * hidden_size**0.5
)
w2 = torch.randn((expert_num, hidden_size, intermediate_size), dtype=dtype) / (
0.5 * intermediate_size**0.5
)
router_logits = torch.randn((batch_size, expert_num), dtype=dtype)
w13_bias = None
w2_bias = None
if use_bias:
w13_bias = torch.randn((expert_num, up_dim), dtype=dtype) / (0.5 * up_dim**0.5)
w2_bias = torch.randn((expert_num, hidden_size), dtype=dtype) / (
0.5 * hidden_size**0.5
)
score = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk_num)
topk_ids = topk_ids.to(torch.int32)
ref_output = ref_fused_moe(
input,
w13,
w2,
w13_bias,
w2_bias,
topk_weight,
topk_ids,
act,
)
packed_w13 = cpu_prepack_moe_weight(w13, isa)
packed_w2 = cpu_prepack_moe_weight(w2, isa)
output = cpu_fused_moe(
input,
packed_w13,
packed_w2,
w13_bias,
w2_bias,
topk_weight,
topk_ids,
act,
isa,
)
atol, rtol = get_default_atol(output), get_default_rtol(output)
(
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - ref_output))}",
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# DeepGEMM Style Cutlass Grouped GEMM Test
# See https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py
import random
import pytest
import torch
from tests.kernels.moe.utils import per_token_cast_to_fp8
from tests.kernels.utils import baseline_scaled_mm
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import per_block_cast_to_fp8
from vllm.utils.math_utils import cdiv
@pytest.mark.parametrize(
"num_groups, expected_m_per_group, k, n",
[
(4, 8192, 7168, 4096),
(4, 8192, 2048, 7168),
(8, 4096, 7168, 4096),
(8, 4096, 2048, 7168),
(32, 1024, 7168, 4096),
(32, 1024, 2048, 7168),
],
)
@pytest.mark.parametrize("out_dtype", [torch.float16])
@pytest.mark.skipif(
(lambda x: x is None or x.to_int() != 100)(
current_platform.get_device_capability()
),
reason="Block Scaled Grouped GEMM is only supported on SM100.",
)
def test_cutlass_grouped_gemm(
num_groups: int,
expected_m_per_group: int,
k: int,
n: int,
out_dtype: torch.dtype,
):
device = "cuda"
alignment = 128
group_ms = [
int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)
]
m = sum([cdiv(m, alignment) * alignment for m in group_ms])
x = torch.randn((m, k), device=device, dtype=out_dtype)
y = torch.randn((num_groups, n, k), device=device, dtype=out_dtype)
out = torch.empty((m, n), device=device, dtype=out_dtype)
ref_out = torch.randn((m, n), device=device, dtype=out_dtype)
ep_offset = [0] + [sum(group_ms[:i]) for i in range(1, num_groups)] + [m]
pb_size = []
for i in range(num_groups):
pb_size.append([ep_offset[i + 1] - ep_offset[i], n, k])
problem_sizes = torch.tensor(pb_size, device=device, dtype=torch.int32)
expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32)
x_fp8 = per_token_cast_to_fp8(x)
y_fp8 = (
torch.empty_like(y, dtype=torch.float8_e4m3fn),
torch.empty(
(num_groups, cdiv(n, 128), k // 128), device=device, dtype=torch.float
),
)
for i in range(num_groups):
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128])
for i in range(num_groups):
a = x_fp8[0][ep_offset[i] : ep_offset[i + 1]]
a_scale = x_fp8[1][ep_offset[i] : ep_offset[i + 1]]
b = y_fp8[0][i].t()
b_scale = y_fp8[1][i].t()
baseline = baseline_scaled_mm(a, b, a_scale, b_scale, out_dtype)
ref_out[ep_offset[i] : ep_offset[i + 1]] = baseline
ops.cutlass_blockwise_scaled_grouped_mm(
out,
x_fp8[0],
y_fp8[0],
x_fp8[1],
y_fp8[1],
problem_sizes,
expert_offsets[:-1],
)
torch.testing.assert_close(ref_out, out, atol=5e-1, rtol=1e-3)
...@@ -22,13 +22,13 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -22,13 +22,13 @@ from vllm.model_executor.layers.fused_moe.config import (
) )
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
get_mk_alignment_for_contiguous_layout, get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used, is_deep_gemm_e8m0_used,
is_deep_gemm_supported, is_deep_gemm_supported,
) )
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
from ...utils import multi_gpu_test from ...utils import multi_gpu_test
...@@ -367,7 +367,7 @@ def _test_deepep_deepgemm_moe( ...@@ -367,7 +367,7 @@ def _test_deepep_deepgemm_moe(
device = torch.device(f"cuda:{pgi.local_rank}") device = torch.device(f"cuda:{pgi.local_rank}")
init_workspace_manager(device) init_workspace_manager(device)
current_platform.seed_everything(pgi.rank) set_random_seed(pgi.rank)
w1 = w1.to(device=torch.cuda.current_device()) w1 = w1.to(device=torch.cuda.current_device())
w2 = w2.to(device=torch.cuda.current_device()) w2 = w2.to(device=torch.cuda.current_device())
...@@ -456,7 +456,7 @@ def test_ht_deepep_deepgemm_moe( ...@@ -456,7 +456,7 @@ def test_ht_deepep_deepgemm_moe(
""" """
m, n, k = mnk m, n, k = mnk
current_platform.seed_everything(7) set_random_seed(7)
if topk > num_experts: if topk > num_experts:
pytest.skip(f"Skipping test: topk={topk} > E={num_experts}") pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")
...@@ -531,7 +531,7 @@ def test_ll_deepep_deepgemm_moe( ...@@ -531,7 +531,7 @@ def test_ll_deepep_deepgemm_moe(
assert not is_deep_gemm_e8m0_used() assert not is_deep_gemm_e8m0_used()
m, n, k = mnk m, n, k = mnk
current_platform.seed_everything(7) set_random_seed(7)
if topk > num_experts: if topk > num_experts:
pytest.skip(f"Skipping test: topk={topk} > E={num_experts}") pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")
......
...@@ -20,8 +20,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularK ...@@ -20,8 +20,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularK
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, per_token_group_quant_fp8,
) )
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep from vllm.utils.import_utils import has_deep_ep
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
from ...utils import multi_gpu_test from ...utils import multi_gpu_test
...@@ -446,7 +446,7 @@ def test_deep_ep_moe( ...@@ -446,7 +446,7 @@ def test_deep_ep_moe(
low_latency_mode = False low_latency_mode = False
use_fp8_dispatch = False use_fp8_dispatch = False
current_platform.seed_everything(7) set_random_seed(7)
world_size, dp_size = world_dp_size world_size, dp_size = world_dp_size
config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts) config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts)
...@@ -507,7 +507,7 @@ def test_low_latency_deep_ep_moe( ...@@ -507,7 +507,7 @@ def test_low_latency_deep_ep_moe(
f"hidden sizes {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}" f"hidden sizes {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}"
) )
current_platform.seed_everything(7) set_random_seed(7)
world_size, dp_size = world_dp_size world_size, dp_size = world_dp_size
config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts) config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts)
......
...@@ -11,17 +11,23 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -11,17 +11,23 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, apply_fi_trtllm_fp8_per_tensor_moe,
flashinfer_cutlass_moe_fp8, register_scales_for_trtllm_fp8_per_tensor_moe,
register_moe_scaling_factors, rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31, swap_w13_to_w31,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8 from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
from vllm.model_executor.models.llama4 import Llama4MoE from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
try: try:
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
...@@ -84,7 +90,7 @@ class TestData: ...@@ -84,7 +90,7 @@ class TestData:
@staticmethod @staticmethod
def make_moe_tensors_8bit( def make_moe_tensors_8bit(
m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu" m: int, k: int, n: int, e: int, is_trtllm: bool, activation: str = "silu"
) -> "TestData": ) -> "TestData":
is_gated = activation != "relu2_no_mul" is_gated = activation != "relu2_no_mul"
...@@ -102,6 +108,7 @@ class TestData: ...@@ -102,6 +108,7 @@ class TestData:
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2) w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
layer = torch.nn.Module() layer = torch.nn.Module()
layer.orig_dtype = torch.bfloat16
layer.w13_weight = w13_quantized.clone() layer.w13_weight = w13_quantized.clone()
layer.w2_weight = w2_quantized.clone() layer.w2_weight = w2_quantized.clone()
layer.w13_input_scale = a1_scale layer.w13_input_scale = a1_scale
...@@ -114,20 +121,27 @@ class TestData: ...@@ -114,20 +121,27 @@ class TestData:
pcp_size=1, pcp_size=1,
dp_size=1, dp_size=1,
ep_size=1, ep_size=1,
tp_rank=1, tp_rank=0,
pcp_rank=1, pcp_rank=0,
dp_rank=1, dp_rank=0,
ep_rank=1, ep_rank=0,
use_ep=False, use_ep=False,
all2all_backend="naive", all2all_backend="naive",
) )
register_moe_scaling_factors(layer)
# flashinfer expects swapped rows for w13 # flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if reorder: if is_trtllm:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
layer.w13_weight, layer.w2_weight
)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer,
layer.w13_weight_scale,
layer.w13_input_scale,
layer.w2_weight_scale,
layer.w2_input_scale,
)
layer.custom_routing_function = Llama4MoE.custom_routing_function layer.custom_routing_function = Llama4MoE.custom_routing_function
layer.intermediate_size_per_partition = n layer.intermediate_size_per_partition = n
layer.ep_rank = 0 layer.ep_rank = 0
...@@ -158,10 +172,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( ...@@ -158,10 +172,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
): ):
if not current_platform.has_device_capability(100): if not current_platform.has_device_capability(100):
pytest.skip("Test is only supported for sm >= 100") pytest.skip("Test is only supported for sm >= 100")
current_platform.seed_everything(7) set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True) td = TestData.make_moe_tensors_8bit(m, k, n, e, is_trtllm=True)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = Llama4MoE.custom_routing_function( topk_weights, topk_ids = Llama4MoE.custom_routing_function(
...@@ -193,7 +207,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( ...@@ -193,7 +207,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
quant_config=quant_config, quant_config=quant_config,
) )
flashinfer_output = apply_flashinfer_per_tensor_scale_fp8( flashinfer_output = apply_fi_trtllm_fp8_per_tensor_moe(
layer=td.layer, layer=td.layer,
hidden_states=td.hidden_states, hidden_states=td.hidden_states,
router_logits=score, router_logits=score,
...@@ -222,11 +236,11 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ...@@ -222,11 +236,11 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
monkeypatch, monkeypatch,
workspace_init, workspace_init,
): ):
current_platform.seed_everything(7) set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit( td = TestData.make_moe_tensors_8bit(
m, k, n, e, reorder=False, activation=activation m, k, n, e, is_trtllm=False, activation=activation
) )
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
...@@ -271,17 +285,34 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ...@@ -271,17 +285,34 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
td.layer.quant_method = td.layer td.layer.quant_method = td.layer
flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8( kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
defer_input_quant=quant_config.is_block_quantized
),
FlashInferExperts(
out_dtype=td.layer.orig_dtype,
quant_config=quant_config,
ep_rank=td.layer.moe_parallel_config.ep_rank,
ep_size=td.layer.moe_parallel_config.ep_size,
tp_rank=td.layer.moe_parallel_config.tp_rank,
tp_size=td.layer.moe_parallel_config.tp_size,
use_dp=False,
use_deepseek_fp8_block_scale=False,
),
)
flashinfer_cutlass_output = kernel(
td.hidden_states, td.hidden_states,
td.layer, td.layer.w13_weight,
td.layer.w2_weight,
topk_weights, topk_weights,
topk_ids, topk_ids,
inplace=False,
activation=activation, activation=activation,
global_num_experts=e, global_num_experts=e,
expert_map=None, expert_map=None,
apply_router_weight_on_input=True, apply_router_weight_on_input=True,
) )
torch.testing.assert_close( torch.testing.assert_close(
output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2 output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
) )
...@@ -23,6 +23,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk ...@@ -23,6 +23,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.torch_utils import set_random_seed
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability( if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
100 100
...@@ -60,7 +61,7 @@ def test_flashinfer_fp4_moe_no_graph( ...@@ -60,7 +61,7 @@ def test_flashinfer_fp4_moe_no_graph(
activation: str, activation: str,
workspace_init, workspace_init,
): ):
current_platform.seed_everything(7) set_random_seed(7)
with set_current_vllm_config( with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
): ):
......
...@@ -8,11 +8,18 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`. ...@@ -8,11 +8,18 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`.
import pytest import pytest
import torch import torch
from vllm.config import (
CompilationConfig,
VllmConfig,
get_cached_compilation_config,
set_current_vllm_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
GroupedTopk,
fused_grouped_topk, fused_grouped_topk,
grouped_topk,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -27,7 +34,8 @@ from vllm.platforms import current_platform ...@@ -27,7 +34,8 @@ from vllm.platforms import current_platform
@pytest.mark.parametrize("topk_group", [2]) @pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"]) @pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5]) @pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("bias_dtype", [torch.float32])
def test_grouped_topk( def test_grouped_topk(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
n_token: int, n_token: int,
...@@ -39,26 +47,33 @@ def test_grouped_topk( ...@@ -39,26 +47,33 @@ def test_grouped_topk(
topk_group: int, topk_group: int,
scoring_func: str, scoring_func: str,
routed_scaling_factor: float, routed_scaling_factor: float,
dtype: torch.dtype, input_dtype: torch.dtype,
bias_dtype: torch.dtype,
): ):
current_platform.seed_everything(0) vllm_config = VllmConfig(
hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda") compilation_config=CompilationConfig(custom_ops=["all", "+grouped_topk"])
gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda")
e_score_correction_bias = torch.randn(
(n_expert,), dtype=torch.float32, device="cuda"
) )
get_cached_compilation_config.cache_clear()
set_random_seed(0)
hidden_states = torch.randn((n_token, n_hidden), dtype=input_dtype, device="cuda")
gating_output = torch.randn((n_token, n_expert), dtype=input_dtype, device="cuda")
e_score_correction_bias = torch.randn((n_expert,), dtype=bias_dtype, device="cuda")
with monkeypatch.context() as m: with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0") m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
baseline_topk_weights, baseline_topk_ids = grouped_topk( grouped_topk = GroupedTopk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk, topk=topk,
renormalize=renormalize, renormalize=renormalize,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
topk_group=topk_group, topk_group=topk_group,
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
)
assert grouped_topk._forward_method.__name__ == "forward_cuda"
baseline_topk_weights, baseline_topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
) )
......
...@@ -15,7 +15,7 @@ from vllm.config import VllmConfig, set_current_vllm_config ...@@ -15,7 +15,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.utils.torch_utils import cuda_device_count_stateless, set_random_seed
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
from .modular_kernel_tools.common import ( from .modular_kernel_tools.common import (
...@@ -82,7 +82,7 @@ def rank_worker( ...@@ -82,7 +82,7 @@ def rank_worker(
device = torch.device(f"cuda:{pgi.local_rank}") device = torch.device(f"cuda:{pgi.local_rank}")
init_workspace_manager(device) init_workspace_manager(device)
current_platform.seed_everything(pgi.rank) set_random_seed(pgi.rank)
# sanity check # sanity check
from vllm import envs from vllm import envs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment