Unverified Commit c1909e7e authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernels] MoE refactor (#19636)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
Signed-off-by: default avatarElizaWszola <ewszola@redhat.com>
Co-authored-by: default avatarElizaWszola <ewszola@redhat.com>
parent b9587750
...@@ -113,6 +113,7 @@ def bench_run( ...@@ -113,6 +113,7 @@ def bench_run(
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
per_act_token: bool,
num_repeats: int, num_repeats: int,
): ):
for _ in range(num_repeats): for _ in range(num_repeats):
...@@ -124,7 +125,8 @@ def bench_run( ...@@ -124,7 +125,8 @@ def bench_run(
topk_ids, topk_ids,
w1_scale, w1_scale,
w2_scale, w2_scale,
a1_scale=a_scale, per_act_token,
a1_scale=None,
) )
def run_cutlass_from_graph( def run_cutlass_from_graph(
...@@ -148,7 +150,8 @@ def bench_run( ...@@ -148,7 +150,8 @@ def bench_run(
topk_ids, topk_ids,
w1_scale, w1_scale,
w2_scale, w2_scale,
a1_scale=a_scale, per_act_token,
a1_scale=None,
) )
def run_triton_from_graph( def run_triton_from_graph(
...@@ -227,6 +230,7 @@ def bench_run( ...@@ -227,6 +230,7 @@ def bench_run(
"w2_q": w2_q, "w2_q": w2_q,
"w1_scale": w1_scale, "w1_scale": w1_scale,
"w2_scale": w2_scale, "w2_scale": w2_scale,
"per_act_token": per_act_token,
# cuda graph params # cuda graph params
"cutlass_graph": cutlass_graph, "cutlass_graph": cutlass_graph,
"triton_graph": triton_graph, "triton_graph": triton_graph,
...@@ -287,12 +291,13 @@ def bench_run( ...@@ -287,12 +291,13 @@ def bench_run(
w2_scale, w2_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
per_act_token,
num_warmup, num_warmup,
) )
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501 stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
......
# SPDX-License-Identifier: Apache-2.0
"""
DeepEP test utilities
"""
import dataclasses
import importlib
import os
import traceback
from typing import Callable, Optional
import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.utils import get_open_port
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
## Parallel Processes Utils
P = ParamSpec("P")
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
worker,
) + args,
nprocs=world_size,
join=True,
)
## DeepEP specific utils
@dataclasses.dataclass
class DeepEPHTArgs:
num_local_experts: int
@dataclasses.dataclass
class DeepEPLLArgs:
max_tokens_per_rank: int
hidden_size: int
num_experts: int
use_fp8_dispatch: bool
def make_deepep_ht_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
ht_args: DeepEPHTArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
import deep_ep
# high throughput a2a
num_nvl_bytes = 1024 * 1024 * 1024 # 1GB
num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
buffer = deep_ep.Buffer(group=pg,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=low_latency_mode,
num_qps_per_rank=num_qps_per_rank)
return DeepEPHTPrepareAndFinalize(buffer=buffer,
world_size=pgi.world_size,
rank=pgi.rank,
dp_size=dp_size,
rank_expert_offset=pgi.rank *
ht_args.num_local_experts)
def make_deepep_ll_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ll_args: DeepEPLLArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
import deep_ep
# low-latency a2a
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
pgi.world_size, deepep_ll_args.num_experts)
buffer = deep_ep.Buffer(group=pg,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=deepep_ll_args.num_experts //
pgi.world_size)
return DeepEPLLPrepareAndFinalize(
buffer=buffer,
world_size=pgi.world_size,
dp_size=dp_size,
max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
)
def make_deepep_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ht_args: Optional[DeepEPHTArgs],
deepep_ll_args: Optional[DeepEPLLArgs],
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
if deepep_ht_args is not None:
assert deepep_ll_args is None
return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
block_shape)
assert deepep_ll_args is not None
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
block_shape)
...@@ -2,18 +2,59 @@ ...@@ -2,18 +2,59 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
import pytest import pytest
import torch import torch
import triton.language as tl import triton.language as tl
from tests.kernels.moe.utils import (batched_moe,
make_quantized_test_activations,
make_test_weights, triton_moe)
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
invoke_moe_batched_triton_kernel) invoke_moe_batched_triton_kernel)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
MNK_FACTORS = [
(1, 128, 128),
(1, 128, 2048),
(1, 512, 512),
(1, 1024, 128),
(1, 1024, 2048),
(32, 128, 128),
(32, 512, 512),
(32, 1024, 2048),
(45, 128, 128),
(45, 128, 2048),
(45, 512, 512),
(45, 1024, 128),
(45, 1024, 2048),
(64, 128, 128),
(64, 512, 512),
(64, 1024, 2048),
(222, 128, 128),
(222, 128, 2048),
(222, 512, 512),
(222, 1024, 128),
(222, 1024, 2048),
]
NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@dataclass @dataclass
class BatchedMMConfig: class BatchedMMConfig:
dtype: torch.dtype in_dtype: torch.dtype
quant_dtype: Optional[torch.dtype]
out_dtype: torch.dtype
num_experts: int num_experts: int
max_tokens_per_expert: int max_tokens_per_expert: int
K: int K: int
...@@ -32,79 +73,127 @@ class BatchedMMTensors: ...@@ -32,79 +73,127 @@ class BatchedMMTensors:
A = torch.randn( A = torch.randn(
(config.num_experts, config.max_tokens_per_expert, config.K), (config.num_experts, config.max_tokens_per_expert, config.K),
device="cuda", device="cuda",
dtype=config.dtype) / 10 dtype=config.in_dtype) / 10
B = torch.randn((config.num_experts, config.N, config.K), B = torch.randn((config.num_experts, config.N, config.K),
device="cuda", device="cuda",
dtype=config.dtype) dtype=config.in_dtype)
C = torch.zeros( C = torch.zeros(
(config.num_experts, config.max_tokens_per_expert, config.N), (config.num_experts, config.max_tokens_per_expert, config.N),
device="cuda", device="cuda",
dtype=config.dtype) dtype=config.out_dtype)
num_expert_tokens = torch.randint(low=0, num_expert_tokens = torch.randint(low=0,
high=config.max_tokens_per_expert, high=config.max_tokens_per_expert,
size=(config.num_experts, ), size=(config.num_experts, ),
device="cuda", device="cuda",
dtype=torch.int32) dtype=torch.int32)
return BatchedMMTensors(A, B, C, num_expert_tokens)
def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, return BatchedMMTensors(A, B, C, num_expert_tokens)
num_expert_tokens: torch.Tensor) -> torch.Tensor:
num_expert_tokens_cpu = num_expert_tokens.clone()
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
num_experts = num_expert_tokens.size(0)
for e in range(num_experts):
num_tokens = num_expert_tokens_cpu[e]
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
return C
@pytest.mark.parametrize("num_experts", [16, 32]) @pytest.mark.parametrize("num_experts", [8, 16, 32])
@pytest.mark.parametrize("max_tokens_per_expert", @pytest.mark.parametrize("max_tokens_per_expert",
[32, 64, 128, 192, 224, 256, 512]) [32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024]) @pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024]) @pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype", @pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16]) [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None])
@pytest.mark.parametrize("per_act_token_quant", [False])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype): N: int, dtype: torch.dtype,
block_shape: Optional[list[int]],
per_act_token_quant: bool):
current_platform.seed_everything(7)
config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N) use_fp8_w8a8 = dtype == torch.float8_e4m3fn
tensors = BatchedMMTensors.make_tensors(config)
test_output = tensors.C if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
ref_output = test_output.clone() pytest.skip("Don't test blocking for non-quantized types.")
if per_act_token_quant and block_shape is not None:
pytest.skip("Skip illegal quantization test.")
if dtype.itemsize == 1:
act_dtype = torch.bfloat16
quant_dtype = dtype
else:
act_dtype = dtype
quant_dtype = None
num_expert_tokens = torch.randint(low=0,
high=max_tokens_per_expert,
size=(num_experts, ),
device="cuda",
dtype=torch.int32)
A, A_q, A_scale = make_quantized_test_activations(
num_experts,
max_tokens_per_expert,
K,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant)
B, B_q, B_scale, _, _, _ = make_test_weights(
num_experts,
N // 2,
K,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
block_shape=block_shape,
)
out_shape = (num_experts, max_tokens_per_expert, N)
test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
compute_tl_dtype = { compute_tl_dtype = {
torch.float16: tl.float16, torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16, torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32 torch.float32: tl.float32
}[test_output.dtype] }[test_output.dtype]
assert A_q.dtype == B_q.dtype
invoke_moe_batched_triton_kernel( invoke_moe_batched_triton_kernel(
tensors.A, A_q,
tensors.B, B_q,
test_output, test_output,
tensors.num_expert_tokens, num_expert_tokens,
compute_tl_dtype, compute_tl_dtype,
# Quantization data # Quantization data
None, A_scale,
None, B_scale,
None, None,
# Quantization schemes # Quantization schemes
False, use_fp8_w8a8,
False, False,
False, False,
config={ config={
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16, "BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 16 "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
}) },
block_shape=block_shape,
)
ref_output = ref_impl(tensors.A, tensors.B, ref_output, ref_output = native_batched_masked_quant_matmul(
tensors.num_expert_tokens) A,
B,
ref_output,
num_expert_tokens,
None,
None,
None,
)
q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
num_expert_tokens,
A_scale, B_scale,
block_shape)
rtol, atol = { rtol, atol = {
torch.float16: (6e-2, 6e-2), torch.float16: (6e-2, 6e-2),
...@@ -112,4 +201,98 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, ...@@ -112,4 +201,98 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
torch.float32: (1e-2, 1e-2), torch.float32: (1e-2, 1e-2),
}[test_output.dtype] }[test_output.dtype]
torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol) torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False])
@pytest.mark.parametrize("block_shape", [None])
def test_fused_moe_batched_experts(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
per_act_token_quant: bool,
block_shape: Optional[list[int]],
):
current_platform.seed_everything(7)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
pytest.skip("Skip quantization test for non-quantized type")
if per_act_token_quant and block_shape is not None or topk > e:
pytest.skip("Skip illegal quantization test.")
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
if dtype.itemsize == 1:
act_dtype = torch.bfloat16
quant_dtype = dtype
else:
act_dtype = dtype
quant_dtype = None
_, w1, w1_s, _, w2, w2_s = make_test_weights(e,
n,
k,
block_shape=block_shape,
in_dtype=act_dtype,
quant_dtype=quant_dtype)
with set_current_vllm_config(vllm_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
batched_output = batched_moe(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
baseline_output = torch_experts(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape)
triton_output = triton_moe(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
torch.testing.assert_close(triton_output,
baseline_output,
atol=2e-2,
rtol=2e-2)
torch.testing.assert_close(triton_output,
batched_output,
atol=2e-2,
rtol=2e-2)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.platforms import current_platform
dg_available = False
try:
import deep_gemm
dg_available = True
except ImportError:
pass
if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
# Test configurations
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
# and its hidden size is 7168.
MNK_FACTORS = [
(1, 128, 128),
(1, 512, 512),
(1, 128, 7168),
(1, 1024, 7168),
(1, 4608, 128),
(1, 4608, 512),
(1, 4608, 7168),
(83, 128, 128),
(83, 512, 512),
(83, 1024, 7168),
(83, 4608, 512),
(83, 4608, 7168),
(128, 128, 128),
(128, 512, 512),
(128, 1024, 7168),
(128, 4608, 512),
(128, 4608, 7168),
(2048, 128, 128),
(2048, 1024, 7168),
(2048, 4608, 512),
(2048, 4608, 7168),
(8192, 128, 128),
(8192, 512, 512),
(8192, 128, 7168),
(8192, 1024, 7168),
(8192, 4608, 512),
(8192, 4608, 7168),
]
MNK_FACTORS_DG = [
(128, 128, 128),
(128, 512, 512),
(128, 128, 7168),
(128, 1024, 7168),
(128, 4608, 128),
(128, 4608, 512),
(128, 4608, 7168),
(192, 128, 128),
(192, 512, 512),
(192, 1024, 7168),
(192, 4608, 512),
(192, 4608, 7168),
(1335, 128, 128),
(1335, 1024, 7168),
(1335, 4608, 512),
(1335, 4608, 7168),
(2048, 128, 128),
(2048, 512, 512),
(2048, 128, 7168),
(2048, 1024, 7168),
(2048, 4608, 128),
(2048, 4608, 512),
(2048, 4608, 7168),
]
BLOCK_SIZE = [[128, 128]]
E = [2, 8, 16] # [128, 256]
TOP_KS = [1, 2, 6]
SEEDS = [0]
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
block_shape):
"""Fused moe with block-wise quantization using native torch."""
B, D = a.shape
topk = topk_ids.size(1)
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
a_q = a_q.to(torch.float32)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_fp8(
act_out, block_k)
out[mask] = native_w8a8_block_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
# Skip all tests if CUDA is not available
pytest.importorskip("torch.cuda")
@pytest.fixture(autouse=True)
def setup_cuda():
torch.set_default_device("cuda")
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
monkeypatch):
if topk > E:
pytest.skip(f"Skipping test; topk={topk} > E={E}")
torch.manual_seed(seed)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E,
N,
K,
dtype,
torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=block_size)
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_act_token_quant=False,
block_shape=block_size)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
ref_out = torch_w8a8_block_fp8_moe(
a,
w1,
w2,
w1_s,
w2_s,
topk_weights,
topk_ids,
block_size,
)
out = fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
m_out = m_fused_moe(
a,
w1,
w2,
topk_weights,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
)
# 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0]
tol = 0.035 if M < 40000 else 0.039
torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS_DG)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
monkeypatch):
if topk > E:
pytest.skip(f"Skipping test: topk={topk} > E={E}")
if not _valid_deep_gemm_shape(M, N, K):
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
chunk_size = 1024
torch.manual_seed(seed)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]
dtype = torch.bfloat16
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E,
N,
K,
dtype,
torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=block_size)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
use_compile = False
use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
and current_platform.is_cuda_alike())
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids, block_size)
if use_compile:
deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
backend="inductor",
fullgraph=True)
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(topk_weights, 0)
torch._dynamo.mark_dynamic(topk_ids, 0)
else:
deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids)
if use_cudagraph:
out.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
torch.testing.assert_close(out, ref_out, atol=0.035, rtol=0.035)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
DTYPES = [torch.half, torch.bfloat16]
MNK_FACTORS = [
(1, 128, 128),
(1, 512, 512),
(1, 128, 7168),
(1, 1024, 7168),
(1, 4096, 128),
(1, 4096, 512),
(1, 4096, 7168),
(33, 128, 128),
(33, 512, 512),
(33, 128, 7168),
(33, 1024, 7168),
(33, 4096, 128),
(33, 4096, 512),
(33, 4096, 7168),
(128, 128, 128),
(128, 512, 512),
(128, 1024, 7168),
(128, 4096, 512),
(128, 4096, 7168),
(222, 128, 128),
(222, 512, 512),
(222, 1024, 7168),
(222, 4096, 512),
(222, 4096, 7168),
(2048, 128, 128),
(2048, 1024, 7168),
(2048, 4096, 512),
(2048, 4096, 7168),
]
E = [8, 24]
TOP_KS = [2, 6]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
# For test
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
"""This function performs fused moe with block-wise quantization using
native torch."""
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = native_per_token_group_quant_int8(a, block_k)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_int8(
act_out, block_k)
act_out = act_out.to(torch.float32)
out[mask] = native_w8a8_block_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.fixture(autouse=True, scope="module")
def setup_cuda():
"""Sets the default CUDA device for all tests in this module."""
torch.set_default_device("cuda")
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
"""Tests the fused_moe kernel with W8A8 INT8 block quantization against a
native torch reference."""
torch.manual_seed(seed)
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E,
N,
K,
dtype,
torch.int8,
per_act_token_quant=False,
block_shape=block_size)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
use_int8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
block_size)
# Check results
torch.testing.assert_close(out, ref_out, atol=0.065, rtol=0.065)
...@@ -97,11 +97,9 @@ class MOETensors8Bit(MOETensors): ...@@ -97,11 +97,9 @@ class MOETensors8Bit(MOETensors):
n_b_scales = 2 * n if per_out_channel else 1 n_b_scales = 2 * n if per_out_channel else 1
k_b_scales = k if per_out_channel else 1 k_b_scales = k if per_out_channel else 1
# Get the right scale for tests. # Get the right scale for tests.
_, a_scale = ops.scaled_fp8_quant( a_q, a_scale = ops.scaled_fp8_quant(
moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token) moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token)
a_q, _ = ops.scaled_fp8_quant(moe_tensors_fp16.a,
a_scale,
use_per_token_if_dynamic=per_act_token)
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype) w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype) w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
...@@ -187,6 +185,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int, ...@@ -187,6 +185,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
def run_8_bit(moe_tensors: MOETensors8Bit, def run_8_bit(moe_tensors: MOETensors8Bit,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
per_act_token: bool,
num_local_experts: Optional[int] = None) -> torch.Tensor: num_local_experts: Optional[int] = None) -> torch.Tensor:
assert not any([ assert not any([
t is None for t in [ t is None for t in [
...@@ -203,7 +202,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit, ...@@ -203,7 +202,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids': topk_ids, 'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale, 'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale, 'w2_scale': moe_tensors.w2_scale,
'a1_scale': moe_tensors.a_scale 'per_act_token': per_act_token,
'a1_scale': None #moe_tensors.a_scale
} }
num_experts = moe_tensors.w1.size(0) num_experts = moe_tensors.w1.size(0)
...@@ -254,11 +254,13 @@ def test_cutlass_moe_8_bit_no_graph( ...@@ -254,11 +254,13 @@ def test_cutlass_moe_8_bit_no_graph(
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
topk_ids) topk_ids)
cutlass_output = run_8_bit(mt, topk_weights, topk_ids) cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token)
# Note 5.5 only needed for larger problem sizes, 5 works ok for
# the rest.
torch.testing.assert_close(triton_output, torch.testing.assert_close(triton_output,
cutlass_output, cutlass_output,
atol=5e-2, atol=5.5e-2,
rtol=1e-2) rtol=1e-2)
...@@ -303,7 +305,8 @@ def test_cutlass_moe_8_bit_cuda_graph( ...@@ -303,7 +305,8 @@ def test_cutlass_moe_8_bit_cuda_graph(
stream = torch.cuda.Stream() stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream): with torch.cuda.graph(graph, stream=stream):
cutlass_output = run_8_bit(mt, topk_weights, topk_ids) cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
per_act_token)
torch.cuda.synchronize() torch.cuda.synchronize()
graph.replay() graph.replay()
...@@ -359,6 +362,7 @@ def test_cutlass_moe_8_bit_EP( ...@@ -359,6 +362,7 @@ def test_cutlass_moe_8_bit_EP(
cutlass_output = run_8_bit(mt, cutlass_output = run_8_bit(mt,
topk_weights, topk_weights,
topk_ids, topk_ids,
per_act_token,
num_local_experts=e // ep_size) num_local_experts=e // ep_size)
torch.testing.assert_close(triton_output, torch.testing.assert_close(triton_output,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
Test DeepEP + DeepGEMM integration Test DeepEP + DeepGEMM integration
DeepGEMM are gemm kernels specialized for the DeepGEMM are gemm kernels specialized for the
fp8 block-quantized case. fp8 block-quantized case.
""" """
...@@ -17,12 +17,11 @@ from vllm.config import VllmConfig, set_current_vllm_config ...@@ -17,12 +17,11 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm from vllm.utils import has_deep_ep, has_deep_gemm
from .utils import ProcessGroupInfo, parallel_launch from .parallel_utils import ProcessGroupInfo, parallel_launch
from .utils import make_test_weights
if has_deep_ep(): if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
...@@ -30,10 +29,9 @@ if has_deep_ep(): ...@@ -30,10 +29,9 @@ if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize) DeepEPLLPrepareAndFinalize)
from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
if has_deep_gemm(): if has_deep_gemm():
import deep_gemm
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts) BatchedDeepGemmExperts)
...@@ -60,25 +58,6 @@ def next_power_of_2(x): ...@@ -60,25 +58,6 @@ def next_power_of_2(x):
return 2**math.ceil(math.log2(x)) return 2**math.ceil(math.log2(x))
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(deep_gemm.ceil_div(m, 128) * 128,
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def make_block_quant_fp8_weights( def make_block_quant_fp8_weights(
e: int, e: int,
n: int, n: int,
...@@ -86,43 +65,11 @@ def make_block_quant_fp8_weights( ...@@ -86,43 +65,11 @@ def make_block_quant_fp8_weights(
block_size: list[int], block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Return weights w1, w2, w1q, w2q, w1_scale, w2_scale Return weights w1q, w2q, w1_scale, w2_scale
""" """
dtype = torch.bfloat16 w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights(
e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size)
fp8_info = torch.finfo(torch.float8_e4m3fn) return w1q, w2q, w1_scale, w2_scale
fp8_max, fp8_min = fp8_info.max, fp8_info.min
w1_bf16 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
w2_bf16 = torch.randn((e, k, n), dtype=dtype) / 10
w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w2 = (n + block_k - 1) // block_k
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
device="cuda",
dtype=torch.float32)
w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
device="cuda",
dtype=torch.float32)
assert w1_s.shape == (e, (2 * n + 127) // 128, (k + 127) // 128)
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(e):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
return w1, w2, w1_s, w2_s
@dataclasses.dataclass @dataclasses.dataclass
...@@ -132,6 +79,7 @@ class TestConfig: ...@@ -132,6 +79,7 @@ class TestConfig:
k: int k: int
n: int n: int
num_experts: int num_experts: int
per_act_token_quant: bool
block_size: list[int] block_size: list[int]
# configs for testing low-latency kernels # configs for testing low-latency kernels
low_latency: bool low_latency: bool
...@@ -150,8 +98,7 @@ class TestTensors: ...@@ -150,8 +98,7 @@ class TestTensors:
def make(config: TestConfig, rank) -> "TestTensors": def make(config: TestConfig, rank) -> "TestTensors":
dtype = torch.bfloat16 dtype = torch.bfloat16
topk, m, k, block_size = (config.topk, config.m, config.k, topk, m, k = (config.topk, config.m, config.k)
config.block_size)
fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min fp8_max, fp8_min = fp8_info.max, fp8_info.min
...@@ -159,9 +106,7 @@ class TestTensors: ...@@ -159,9 +106,7 @@ class TestTensors:
rank_tokens = torch.randn( rank_tokens = torch.randn(
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0 (m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max) rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
rank_token_scales = None
block_k = block_size[1]
_, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k)
topk_ids = torch.randint( topk_ids = torch.randint(
low=0, low=0,
...@@ -201,10 +146,12 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, ...@@ -201,10 +146,12 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
q_dtype=q_dtype, q_dtype=q_dtype,
block_shape=test_config.block_size) block_shape=test_config.block_size)
fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank, fused_experts = BatchedDeepGemmExperts(
world_size=pgi.world_size, max_num_tokens=max_tokens_per_rank,
dp_size=dp_size, world_size=pgi.world_size,
block_shape=test_config.block_size) dp_size=dp_size,
block_shape=test_config.block_size,
per_act_token_quant=test_config.per_act_token_quant)
mk = FusedMoEModularKernel(prepare_finalize=a2a, mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts) fused_experts=fused_experts)
return mk return mk
...@@ -426,6 +373,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, ...@@ -426,6 +373,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
""" """
Tests for High-Throughput DeepEP + DeepGemm integration. Tests for High-Throughput DeepEP + DeepGemm integration.
""" """
import deep_gemm
m, n, k = mnk m, n, k = mnk
current_platform.seed_everything(7) current_platform.seed_everything(7)
...@@ -442,6 +390,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, ...@@ -442,6 +390,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
k=k, k=k,
n=n, n=n,
num_experts=num_experts, num_experts=num_experts,
per_act_token_quant=False,
block_size=block_size, block_size=block_size,
low_latency=False, low_latency=False,
use_fp8_dispatch=None) use_fp8_dispatch=None)
...@@ -474,10 +423,14 @@ USE_FP8_DISPATCH = [False] ...@@ -474,10 +423,14 @@ USE_FP8_DISPATCH = [False]
@pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep @requires_deep_ep
@requires_deep_gemm @requires_deep_gemm
def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int, def test_ll_deepep_deepgemm_moe(
int], num_experts: int, topk: int, mnk: tuple[int, int, int],
use_fp8_dispatch: bool, block_size: list[int], num_experts: int,
world_dp_size: tuple[int, int]): topk: int,
use_fp8_dispatch: bool,
block_size: list[int],
world_dp_size: tuple[int, int],
):
""" """
Tests for Low-Latency DeepEP + DeepGemm integration. Tests for Low-Latency DeepEP + DeepGemm integration.
""" """
...@@ -495,6 +448,7 @@ def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int, ...@@ -495,6 +448,7 @@ def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
k=k, k=k,
n=n, n=n,
num_experts=num_experts, num_experts=num_experts,
per_act_token_quant=False,
block_size=block_size, block_size=block_size,
low_latency=True, low_latency=True,
use_fp8_dispatch=use_fp8_dispatch, use_fp8_dispatch=use_fp8_dispatch,
......
...@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_ep from vllm.utils import has_deep_ep
from .utils import ProcessGroupInfo, parallel_launch from .parallel_utils import ProcessGroupInfo, parallel_launch
if has_deep_ep(): if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
...@@ -31,7 +31,7 @@ if has_deep_ep(): ...@@ -31,7 +31,7 @@ if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize) DeepEPLLPrepareAndFinalize)
from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
requires_deep_ep = pytest.mark.skipif( requires_deep_ep = pytest.mark.skipif(
not has_deep_ep(), not has_deep_ep(),
...@@ -102,10 +102,6 @@ class TestTensors: ...@@ -102,10 +102,6 @@ class TestTensors:
rank_tokens = torch.randn( rank_tokens = torch.randn(
(config.m, config.k), device="cuda", dtype=token_dtype) / 10 (config.m, config.k), device="cuda", dtype=token_dtype) / 10
rank_token_scales = None rank_token_scales = None
if config.dtype == torch.float8_e4m3fn:
# low_latency_mode kernels dont support per-token quant.
_, rank_token_scales = ops.scaled_fp8_quant(
rank_tokens, use_per_token_if_dynamic=not low_latency_mode)
topk = torch.randint(low=0, topk = torch.randint(low=0,
high=config.num_experts, high=config.num_experts,
...@@ -121,11 +117,18 @@ class TestTensors: ...@@ -121,11 +117,18 @@ class TestTensors:
config=config) config=config)
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, def make_modular_kernel(
low_latency_mode: bool, hidden_size: int, dp_size: int, pg: ProcessGroup,
num_experts: int, num_local_experts: int, pgi: ProcessGroupInfo,
q_dtype: Optional[torch.dtype], low_latency_mode: bool,
use_fp8_dispatch: bool) -> FusedMoEModularKernel: hidden_size: int,
dp_size: int,
num_experts: int,
num_local_experts: int,
q_dtype: Optional[torch.dtype],
use_fp8_dispatch: bool,
per_act_token_quant: bool,
) -> FusedMoEModularKernel:
is_quantized = q_dtype is not None is_quantized = q_dtype is not None
...@@ -152,6 +155,7 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, ...@@ -152,6 +155,7 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
deepep_ll_args = ll_args) deepep_ll_args = ll_args)
if low_latency_mode: if low_latency_mode:
assert not per_act_token_quant, "not supported in ll mode"
fused_experts = BatchedTritonExperts( fused_experts = BatchedTritonExperts(
max_num_tokens=MAX_TOKENS_PER_RANK, max_num_tokens=MAX_TOKENS_PER_RANK,
world_size=pgi.world_size, world_size=pgi.world_size,
...@@ -159,25 +163,37 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, ...@@ -159,25 +163,37 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
use_fp8_w8a8=is_quantized, use_fp8_w8a8=is_quantized,
use_int8_w8a8=False, use_int8_w8a8=False,
use_int8_w8a16=False, use_int8_w8a16=False,
use_int4_w4a16=False) use_int4_w4a16=False,
per_act_token_quant=False,
)
else: else:
fused_experts = TritonExperts(use_fp8_w8a8=is_quantized, fused_experts = TritonExperts(
use_int8_w8a8=False, use_fp8_w8a8=is_quantized,
use_int8_w8a16=False, use_int8_w8a8=False,
use_int4_w4a16=False, use_int8_w8a16=False,
per_channel_quant=False) use_int4_w4a16=False,
per_act_token_quant=per_act_token_quant,
)
mk = FusedMoEModularKernel(prepare_finalize=a2a, mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts) fused_experts=fused_experts)
return mk return mk
def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, def deep_ep_moe_impl(
low_latency_mode: bool, dp_size: int, pg: ProcessGroup,
test_tensors: TestTensors, w1: torch.Tensor, pgi: ProcessGroupInfo,
w2: torch.Tensor, w1_scale: Optional[torch.Tensor], low_latency_mode: bool,
w2_scale: Optional[torch.Tensor], num_experts: int, dp_size: int,
use_fp8_dispatch: bool) -> torch.Tensor: test_tensors: TestTensors,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
num_experts: int,
use_fp8_dispatch: bool,
per_act_token_quant: bool,
) -> torch.Tensor:
num_local_experts = w1.size(0) num_local_experts = w1.size(0)
...@@ -199,11 +215,9 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, ...@@ -199,11 +215,9 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
q_dtype = torch.float8_e4m3fn q_dtype = torch.float8_e4m3fn
# Make modular kernel # Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(pg, pgi, low_latency_mode, mk: FusedMoEModularKernel = make_modular_kernel(
hidden_size, dp_size, pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
num_experts, num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant)
num_local_experts, q_dtype,
use_fp8_dispatch)
out_hidden_states = torch.empty_like(test_tensors.rank_tokens) out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
total_num_tokens = test_tensors.rank_tokens.size(0) total_num_tokens = test_tensors.rank_tokens.size(0)
...@@ -257,9 +271,15 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, ...@@ -257,9 +271,15 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
return out_hidden_states return out_hidden_states
def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor, def torch_moe_impl(
w2: torch.Tensor, w1_scale: Optional[torch.Tensor], test_tensors: TestTensors,
w2_scale: Optional[torch.Tensor], using_fp8_dispatch: bool): w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
using_fp8_dispatch: bool,
per_act_token_quant: bool,
):
a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk, a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
test_tensors.topk_weights) test_tensors.topk_weights)
...@@ -267,6 +287,7 @@ def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor, ...@@ -267,6 +287,7 @@ def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
# The DeepEP implementation is requested to dispatch using FP8. # The DeepEP implementation is requested to dispatch using FP8.
# For numerical stability for testing, emulate the fp8 dispatch by # For numerical stability for testing, emulate the fp8 dispatch by
# blockwise quant and de-quant. # blockwise quant and de-quant.
assert not per_act_token_quant
a = test_tensors.rank_tokens a = test_tensors.rank_tokens
aq, aq_scale = per_token_group_quant_fp8(a, 128) aq, aq_scale = per_token_group_quant_fp8(a, 128)
a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view( a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
...@@ -310,6 +331,7 @@ def _deep_ep_moe( ...@@ -310,6 +331,7 @@ def _deep_ep_moe(
w1_scale: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
use_fp8_dispatch: bool, use_fp8_dispatch: bool,
per_act_token_quant: bool,
): ):
if not low_latency_mode: if not low_latency_mode:
...@@ -331,7 +353,8 @@ def _deep_ep_moe( ...@@ -331,7 +353,8 @@ def _deep_ep_moe(
with set_current_vllm_config(VllmConfig()): with set_current_vllm_config(VllmConfig()):
# Reference # Reference
torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale, torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale,
w2_scale, use_fp8_dispatch) w2_scale, use_fp8_dispatch,
per_act_token_quant)
# Splice experts for this rank. # Splice experts for this rank.
num_local_experts = config.num_experts // pgi.world_size num_local_experts = config.num_experts // pgi.world_size
...@@ -356,6 +379,7 @@ def _deep_ep_moe( ...@@ -356,6 +379,7 @@ def _deep_ep_moe(
w2_scale_ep, w2_scale_ep,
config.num_experts, config.num_experts,
use_fp8_dispatch, use_fp8_dispatch,
per_act_token_quant,
) )
torch.testing.assert_close( torch.testing.assert_close(
...@@ -384,10 +408,16 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn] ...@@ -384,10 +408,16 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@requires_deep_ep @requires_deep_ep
def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], def test_deep_ep_moe(
num_experts: int, topk: int, world_dp_size: tuple[int, dtype: torch.dtype,
int]): mnk: tuple[int, int, int],
num_experts: int,
topk: int,
world_dp_size: tuple[int, int],
per_act_token_quant: bool,
):
low_latency_mode = False low_latency_mode = False
use_fp8_dispatch = False use_fp8_dispatch = False
m, n, k = mnk m, n, k = mnk
...@@ -404,7 +434,8 @@ def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], ...@@ -404,7 +434,8 @@ def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype) w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size, parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch) config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
per_act_token_quant)
MNKs = [ MNKs = [
...@@ -454,4 +485,5 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], ...@@ -454,4 +485,5 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype) w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size, parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch) config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
False)
...@@ -17,6 +17,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock ...@@ -17,6 +17,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
...@@ -142,6 +143,10 @@ def test_fused_moe( ...@@ -142,6 +143,10 @@ def test_fused_moe(
# Setup test data # Setup test data
# #
#
# Setup test data
#
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
...@@ -169,7 +174,7 @@ def test_fused_moe( ...@@ -169,7 +174,7 @@ def test_fused_moe(
use_int8_w8a8=False, use_int8_w8a8=False,
use_int8_w8a16=False, use_int8_w8a16=False,
use_int4_w4a16=False, use_int4_w4a16=False,
per_channel_quant=False, per_act_token_quant=False,
block_shape=None) block_shape=None)
def m_fused_moe( def m_fused_moe(
...@@ -365,6 +370,13 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, ...@@ -365,6 +370,13 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
if dtype == torch.float32: if dtype == torch.float32:
pytest.skip("AITER ROCm test skip for float32") pytest.skip("AITER ROCm test skip for float32")
monkeypatch.setenv('RANK', "0")
monkeypatch.setenv('LOCAL_RANK', "0")
monkeypatch.setenv('WORLD_SIZE', "1")
monkeypatch.setenv('MASTER_ADDR', 'localhost')
monkeypatch.setenv('MASTER_PORT', '12345')
init_distributed_environment()
# Instantiate our and huggingface's MoE blocks # Instantiate our and huggingface's MoE blocks
vllm_config.compilation_config.static_forward_context = dict() vllm_config.compilation_config.static_forward_context = dict()
with (set_current_vllm_config(vllm_config), with (set_current_vllm_config(vllm_config),
......
...@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk ...@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform from vllm.platforms import current_platform
if not current_platform.has_device_capability(100): if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", pytest.skip("Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True) allow_module_level=True)
MNK_FACTORS = [ MNK_FACTORS = [
......
...@@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( ...@@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .utils import ProcessGroupInfo, parallel_launch from .parallel_utils import ProcessGroupInfo, parallel_launch
try: try:
from pplx_kernels import AllToAll from pplx_kernels import AllToAll
...@@ -93,7 +93,7 @@ def pplx_cutlass_moe( ...@@ -93,7 +93,7 @@ def pplx_cutlass_moe(
num_experts=num_experts, num_experts=num_experts,
experts_per_token=topk, experts_per_token=topk,
rank=rank, rank=rank,
world_size=pgi.world_size, world_size=world_size,
dp_size=dp_size, dp_size=dp_size,
hidden_dim=hidden_dim, hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim, # because a.dtype.itemsize == 1 hidden_dim_bytes=hidden_dim, # because a.dtype.itemsize == 1
...@@ -118,8 +118,6 @@ def pplx_cutlass_moe( ...@@ -118,8 +118,6 @@ def pplx_cutlass_moe(
pgi.world_size, pgi.world_size,
rank, rank,
dp_size, dp_size,
quant_dtype=torch.float8_e4m3fn,
per_act_token=per_act_token,
) )
experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size, experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
......
...@@ -18,18 +18,20 @@ try: ...@@ -18,18 +18,20 @@ try:
except ImportError: except ImportError:
has_pplx = False has_pplx = False
from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
from tests.kernels.utils import torch_experts from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import override_config from vllm.model_executor.layers.fused_moe import fused_topk, override_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts) BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk, from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
get_default_config)
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import round_up
from .utils import ProcessGroupInfo, parallel_launch from .parallel_utils import ProcessGroupInfo, parallel_launch
requires_pplx = pytest.mark.skipif( requires_pplx = pytest.mark.skipif(
not has_pplx, not has_pplx,
...@@ -144,25 +146,6 @@ def torch_batched_moe( ...@@ -144,25 +146,6 @@ def torch_batched_moe(
return torch_finalize(out, topk_weight, topk_ids) return torch_finalize(out, topk_weight, topk_ids)
def batched_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
num_experts = w1.shape[0]
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens=a.shape[0],
world_size=1,
dp_size=1,
rank=0),
BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
@pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("k", [128, 512, 1024])
...@@ -188,7 +171,7 @@ def test_fused_moe_batched_experts( ...@@ -188,7 +171,7 @@ def test_fused_moe_batched_experts(
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids) baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids) batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids)
torch.testing.assert_close(baseline_output, torch.testing.assert_close(baseline_output,
torch_output, torch_output,
...@@ -226,7 +209,6 @@ def pplx_prepare_finalize( ...@@ -226,7 +209,6 @@ def pplx_prepare_finalize(
topk = topk_ids.shape[1] topk = topk_ids.shape[1]
num_tokens, hidden_dim = a.shape num_tokens, hidden_dim = a.shape
block_size = 128
device = pgi.device device = pgi.device
rank = pgi.rank rank = pgi.rank
world_size = pgi.world_size world_size = pgi.world_size
...@@ -241,9 +223,7 @@ def pplx_prepare_finalize( ...@@ -241,9 +223,7 @@ def pplx_prepare_finalize(
dp_size=dp_size, dp_size=dp_size,
hidden_dim=hidden_dim, hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim * a.dtype.itemsize, hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else hidden_dim_scale_bytes=0,
((hidden_dim + block_size - 1) // block_size *
torch.float32.itemsize)),
) )
if group_name is None: if group_name is None:
...@@ -260,7 +240,6 @@ def pplx_prepare_finalize( ...@@ -260,7 +240,6 @@ def pplx_prepare_finalize(
world_size, world_size,
rank, rank,
dp_size, dp_size,
a.dtype,
) )
a_chunk = chunk_by_rank(a, rank, world_size).to(device) a_chunk = chunk_by_rank(a, rank, world_size).to(device)
...@@ -276,6 +255,7 @@ def pplx_prepare_finalize( ...@@ -276,6 +255,7 @@ def pplx_prepare_finalize(
num_experts, num_experts,
None, None,
False, False,
FusedMoEQuantConfig(),
) )
b_a = b_a * 1.5 b_a = b_a * 1.5
...@@ -350,6 +330,7 @@ def _pplx_prepare_finalize( ...@@ -350,6 +330,7 @@ def _pplx_prepare_finalize(
# TODO (bnell): this test point does not work for odd M due to how the test is # TODO (bnell): this test point does not work for odd M due to how the test is
# written, not due to limitations of the pplx kernels. The pplx_moe # written, not due to limitations of the pplx kernels. The pplx_moe
# test below is able to deal with odd M. # test below is able to deal with odd M.
# TODO (bnell) add fp8 tests
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS) @pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
...@@ -386,18 +367,31 @@ def pplx_moe( ...@@ -386,18 +367,31 @@ def pplx_moe(
w2: torch.Tensor, w2: torch.Tensor,
topk_weight: torch.Tensor, topk_weight: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
qtype: Optional[torch.dtype] = None,
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
use_compile: bool = False, use_compile: bool = False,
use_cudagraphs: bool = True, use_cudagraphs: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize) PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)
device = torch.device("cuda", rank) device = torch.device("cuda", rank)
hidden_dim = a.shape[1] hidden_dim = a.shape[1]
num_experts = w1.shape[0] num_experts = w1.shape[0]
block_size = 128
topk = topk_ids.shape[1] topk = topk_ids.shape[1]
max_num_tokens = rank_chunk(a.shape[0], 0, world_size) max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64)
hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
max_num_tokens,
hidden_dim,
a.dtype,
qtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
args = dict( args = dict(
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens,
...@@ -407,10 +401,8 @@ def pplx_moe( ...@@ -407,10 +401,8 @@ def pplx_moe(
world_size=world_size, world_size=world_size,
dp_size=dp_size, dp_size=dp_size,
hidden_dim=hidden_dim, hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim * a.dtype.itemsize, hidden_dim_bytes=hidden_dim_bytes,
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else hidden_dim_scale_bytes=scale_bytes,
((hidden_dim + block_size - 1) // block_size *
torch.float32.itemsize)),
) )
if group_name is None: if group_name is None:
...@@ -429,9 +421,11 @@ def pplx_moe( ...@@ -429,9 +421,11 @@ def pplx_moe(
dp_size, dp_size,
) )
experts = BatchedTritonExperts(max_num_tokens=a.shape[0], experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
world_size=world_size, world_size=world_size,
dp_size=dp_size) dp_size=dp_size,
use_fp8_w8a8=qtype == torch.float8_e4m3fn,
block_shape=block_shape)
fused_experts = FusedMoEModularKernel( fused_experts = FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
...@@ -447,6 +441,13 @@ def pplx_moe( ...@@ -447,6 +441,13 @@ def pplx_moe(
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device) w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device) w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
if w1_scale is not None:
w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
else:
w1_scale_chunk = None
w2_scale_chunk = None
# Note: for now use_compile will error out if the problem size is # Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and # large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later. # setup code in case we are able to revisit this later.
...@@ -465,6 +466,8 @@ def pplx_moe( ...@@ -465,6 +466,8 @@ def pplx_moe(
w2_chunk, w2_chunk,
chunk_topk_weight, chunk_topk_weight,
chunk_topk_ids, chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
global_num_experts=num_experts) global_num_experts=num_experts)
if use_cudagraphs: if use_cudagraphs:
...@@ -477,6 +480,8 @@ def pplx_moe( ...@@ -477,6 +480,8 @@ def pplx_moe(
w2_chunk, w2_chunk,
chunk_topk_weight, chunk_topk_weight,
chunk_topk_ids, chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
global_num_experts=num_experts) global_num_experts=num_experts)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -505,9 +510,9 @@ def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): ...@@ -505,9 +510,9 @@ def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
rank=rank, rank=rank,
) )
experts = BatchedExperts(max_num_tokens=a.shape[0], experts = NaiveBatchedExperts(max_num_tokens=a.shape[0],
world_size=1, world_size=1,
dp_size=1) dp_size=1)
fused_experts = FusedMoEModularKernel( fused_experts = FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
...@@ -539,7 +544,12 @@ def _pplx_moe( ...@@ -539,7 +544,12 @@ def _pplx_moe(
w2: torch.Tensor, w2: torch.Tensor,
score: torch.Tensor, score: torch.Tensor,
topk: int, topk: int,
use_internode: bool, w1_s: Optional[torch.Tensor] = None,
w2_s: Optional[torch.Tensor] = None,
qtype: Optional[torch.dtype] = None,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
use_internode: bool = False,
): ):
if use_internode: if use_internode:
uid = nvshmem_get_unique_id( uid = nvshmem_get_unique_id(
...@@ -557,11 +567,28 @@ def _pplx_moe( ...@@ -557,11 +567,28 @@ def _pplx_moe(
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False) moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
device = torch.device("cuda", pgi.rank)
a = a.to(device)
w1 = w1.to(device)
w2 = w2.to(device)
w1_s = w1_s.to(device) if w1_s is not None else None
w2_s = w2_s.to(device) if w2_s is not None else None
with set_current_vllm_config(vllm_config), override_config(moe_config): with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids) torch_output = torch_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
quant_dtype=qtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape)
pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size, pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
a, w1, w2, topk_weight, topk_ids) a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
qtype, per_act_token_quant, block_shape)
# TODO (bnell): fix + re-enable # TODO (bnell): fix + re-enable
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
# topk_ids) # topk_ids)
...@@ -581,6 +608,8 @@ def _pplx_moe( ...@@ -581,6 +608,8 @@ def _pplx_moe(
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("use_internode", [False]) @pytest.mark.parametrize("use_internode", [False])
@requires_pplx @requires_pplx
def test_pplx_moe( def test_pplx_moe(
...@@ -589,15 +618,33 @@ def test_pplx_moe( ...@@ -589,15 +618,33 @@ def test_pplx_moe(
topk: int, topk: int,
dtype: torch.dtype, dtype: torch.dtype,
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
use_internode: bool, use_internode: bool,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
m, n, k = mnk m, n, k = mnk
world_size, dp_size = world_dp_size world_size, dp_size = world_dp_size
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 if dtype == torch.float8_e4m3fn:
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 use_fp8_w8a8 = True
score = torch.randn((m, e), device="cuda", dtype=dtype) quant_dtype = dtype
else:
use_fp8_w8a8 = False
quant_dtype = None
if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None:
pytest.skip("Skip quantization test for non-quantized type")
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
_, w1, w1_s, _, w2, w2_s = make_test_weights(e,
n,
k,
quant_dtype=quant_dtype,
block_shape=block_shape)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,
use_internode) use_internode)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
DeepEP test utilities from typing import Optional
"""
import dataclasses
import importlib
import os
import traceback
from typing import Callable, Optional
import torch import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.utils import get_open_port
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
## Parallel Processes Utils
P = ParamSpec("P")
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
worker,
) + args,
nprocs=world_size,
join=True,
)
## DeepEP specific utils import vllm._custom_ops as ops
from tests.kernels.quant_utils import (per_block_cast_to_fp8,
per_block_cast_to_int8)
@dataclasses.dataclass from vllm.model_executor.layers.fused_moe import fused_experts
class DeepEPHTArgs: from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
num_local_experts: int BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
@dataclasses.dataclass from vllm.model_executor.layers.fused_moe.utils import (
class DeepEPLLArgs: moe_kernel_quantize_input)
max_tokens_per_rank: int from vllm.utils import round_up
hidden_size: int
num_experts: int
use_fp8_dispatch: bool def triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
def make_deepep_ht_a2a(pg: ProcessGroup, w2: torch.Tensor,
pgi: ProcessGroupInfo, topk_weight: torch.Tensor,
dp_size: int, topk_ids: torch.Tensor,
ht_args: DeepEPHTArgs, w1_scale: Optional[torch.Tensor] = None,
q_dtype: Optional[torch.dtype] = None, w2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None): a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
import deep_ep quant_dtype: Optional[torch.dtype] = None,
per_act_token_quant=False,
# high throughput a2a block_shape: Optional[list[int]] = None,
num_nvl_bytes = 1024 * 1024 * 1024 # 1GB ) -> torch.Tensor:
num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1 return fused_experts(a,
buffer = deep_ep.Buffer(group=pg, w1,
num_nvl_bytes=num_nvl_bytes, w2,
num_rdma_bytes=num_rdma_bytes, topk_weight,
low_latency_mode=low_latency_mode, topk_ids,
num_qps_per_rank=num_qps_per_rank) w1_scale=w1_scale,
return DeepEPHTPrepareAndFinalize(buffer=buffer, w2_scale=w2_scale,
world_size=pgi.world_size, a1_scale=a1_scale,
rank=pgi.rank, a2_scale=a2_scale,
dp_size=dp_size, per_channel_quant=per_act_token_quant,
rank_expert_offset=pgi.rank * use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
ht_args.num_local_experts, block_shape=block_shape)
quant_dtype=q_dtype,
block_shape=block_shape)
def batched_moe(
a: torch.Tensor,
def make_deepep_ll_a2a(pg: ProcessGroup, w1: torch.Tensor,
pgi: ProcessGroupInfo, w2: torch.Tensor,
dp_size: int, topk_weight: torch.Tensor,
deepep_ll_args: DeepEPLLArgs, topk_ids: torch.Tensor,
q_dtype: Optional[torch.dtype] = None, w1_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None): w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
import deep_ep a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
# low-latency a2a per_act_token_quant: bool = False,
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( block_shape: Optional[list[int]] = None,
deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size, ) -> torch.Tensor:
pgi.world_size, deepep_ll_args.num_experts) max_num_tokens = round_up(a.shape[0], 64)
buffer = deep_ep.Buffer(group=pg, fused_experts = FusedMoEModularKernel(
num_rdma_bytes=num_rdma_bytes, BatchedPrepareAndFinalize(max_num_tokens,
low_latency_mode=True, world_size=1,
num_qps_per_rank=deepep_ll_args.num_experts // dp_size=1,
pgi.world_size) rank=0),
BatchedTritonExperts(
return DeepEPLLPrepareAndFinalize( max_num_tokens=max_num_tokens,
buffer=buffer, world_size=1,
world_size=pgi.world_size, dp_size=1,
dp_size=dp_size, use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank, per_act_token_quant=per_act_token_quant,
quant_dtype=q_dtype, block_shape=block_shape,
block_shape=block_shape, ),
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
) )
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
def naive_batched_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
world_size=1,
dp_size=1,
rank=0),
NaiveBatchedExperts(
max_num_tokens=max_num_tokens,
dp_size=1,
world_size=1,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
),
)
def make_deepep_a2a(pg: ProcessGroup, return fused_experts(a,
pgi: ProcessGroupInfo, w1,
dp_size: int, w2,
deepep_ht_args: Optional[DeepEPHTArgs], topk_weight,
deepep_ll_args: Optional[DeepEPLLArgs], topk_ids,
q_dtype: Optional[torch.dtype] = None, w1_scale=w1_scale,
block_shape: Optional[list[int]] = None): w2_scale=w2_scale,
if deepep_ht_args is not None: a1_scale=a1_scale,
assert deepep_ll_args is None a2_scale=a2_scale)
return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
block_shape)
def chunk_scales(scales: Optional[torch.Tensor], start: int,
assert deepep_ll_args is not None end: int) -> Optional[torch.Tensor]:
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype, if scales is not None:
block_shape) if scales.numel() == 1:
return scales
else:
return scales[start:end]
return None
def make_quantized_test_activations(
E: int,
m: int,
k: int,
in_dtype: torch.dtype,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
a_q = a
a_scale = None
if quant_dtype is not None:
assert (quant_dtype == torch.float8_e4m3fn
or quant_dtype == torch.int8), "only fp8/int8 supported"
a_q = torch.zeros_like(a, dtype=quant_dtype)
a_scale_l = [None] * E
for e in range(E):
a_q[e], a_scale_l[e] = moe_kernel_quantize_input(
a[e], None, quant_dtype, per_act_token_quant, block_shape)
a_scale = torch.stack(a_scale_l)
if not per_act_token_quant and block_shape is None:
a_scale = a_scale.view(E, 1, 1)
return a, a_q, a_scale
def moe_quantize_weights(
w: torch.Tensor,
w_s: Optional[torch.Tensor],
quant_dtype: Optional[torch.dtype],
per_token_quant: bool,
block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert (quant_dtype == torch.float8_e4m3fn
or quant_dtype == torch.int8), "only fp8/int8 supported"
if block_shape is not None:
assert not per_token_quant
if quant_dtype == torch.int8:
w, w_s = per_block_cast_to_int8(w, block_shape)
else:
w, w_s = per_block_cast_to_fp8(w, block_shape)
else:
if quant_dtype == torch.int8:
w, w_s = ops.scaled_int8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant)
else:
w, w_s = ops.scaled_fp8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant)
return w, w_s
def make_test_weight(
e: int,
rows: int,
cols: int,
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
if quant_dtype is not None:
w_l = [None] * e
w_s_l = [None] * e
for idx in range(e):
w_l[idx], w_s_l[idx] = moe_quantize_weights(
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
w = torch.stack(w_l)
w_s = torch.stack(w_s_l)
if w_s.ndim == 2:
assert w_s.shape[-1] == 1
w_s = w_s.view(-1, 1, 1)
if block_shape is not None:
block_n, block_k = block_shape
n_tiles = (rows + block_n - 1) // block_n
k_tiles = (cols + block_k - 1) // block_k
assert w_s.shape == (e, n_tiles, k_tiles)
else:
w = w_16
w_s = None
return w_16, w, w_s
def make_test_weights(
e: int,
n: int,
k: int,
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor,
torch.Tensor, Optional[torch.Tensor]]:
return (
*make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
*make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
)
...@@ -5,7 +5,10 @@ from typing import Optional, Union ...@@ -5,7 +5,10 @@ from typing import Optional, Union
import torch import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import round_up
# Using the default value (240.0) from pytorch will cause accuracy # Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm. # issue on dynamic quantization models. Here use 224.0 for rocm.
...@@ -94,9 +97,15 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ ...@@ -94,9 +97,15 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
return ref_out, ref_scale.view((1, )) return ref_out, ref_scale.view((1, ))
def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, def native_w8a8_block_matmul(
As: torch.Tensor, Bs: torch.Tensor, block_size, A: torch.Tensor,
output_dtype): B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
compute_type: torch.dtype = torch.float32,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise """This function performs matrix multiplication with block-wise
quantization using native torch. quantization using native torch.
It is agnostic to the input data type and can be used for both int8 and It is agnostic to the input data type and can be used for both int8 and
...@@ -106,8 +115,8 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, ...@@ -106,8 +115,8 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
`Bs` (float32). `Bs` (float32).
The output is returned in the specified `output_dtype`. The output is returned in the specified `output_dtype`.
""" """
A = A.to(torch.float32) A = A.to(compute_type)
B = B.to(torch.float32) B = B.to(compute_type)
assert A.shape[-1] == B.shape[-1] assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2 assert len(block_size) == 2
...@@ -122,11 +131,11 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, ...@@ -122,11 +131,11 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
As = As.reshape(M, As.shape[-1]) As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0] assert n_tiles == Bs.shape[0], f"{n_tiles} == {Bs.shape[0]}"
assert k_tiles == Bs.shape[1] assert k_tiles == Bs.shape[1], f"{k_tiles} == {Bs.shape[1]}"
C_shape = (M, N) C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) C = torch.zeros(C_shape, dtype=compute_type, device=A.device)
A_tiles = [ A_tiles = [
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
...@@ -152,3 +161,152 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, ...@@ -152,3 +161,152 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
C = C.reshape(origin_C_shape).to(output_dtype) C = C.reshape(origin_C_shape).to(output_dtype)
return C return C
def native_per_token_group_quant_fp8(x,
group_size,
eps=1e-10,
dtype=torch.float8_e4m3fn):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch."""
assert x.shape[-1] % group_size == 0, ("the last dimension of `x` must "
"be divisible by `group_size`")
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1,
keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
return x_q, x_s
def native_per_token_group_quant_int8(x,
group_size,
eps=1e-10,
dtype=torch.int8):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch.
It converts the tensor values into int8 values and returns the
quantized tensor along with the scaling factor used for quantization.
"""
assert (x.shape[-1] % group_size == 0
), "the last dimension of `x` must be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
iinfo = torch.iinfo(dtype)
int8_min = iinfo.min
int8_max = iinfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
# Use float32 for scale calculation for stability
amax = x_.abs().max(dim=-1,
keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / int8_max
x_q = (x_.to(torch.float32) / x_s).round().clamp(
min=int8_min, max=int8_max).to(dtype) # Round before clamping
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
return x_q, x_s
DEFAULT_BLOCK_SHAPE = [128, 128]
def per_block_cast_to_fp8(
x: torch.Tensor,
block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
) -> tuple[torch.Tensor, torch.Tensor]:
block_m, block_n = block_shape
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def per_block_cast_to_int8(
x: torch.Tensor,
block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
) -> tuple[torch.Tensor, torch.Tensor]:
block_m, block_n = block_shape
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (256.0 / x_amax)).to(torch.int8)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 256.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def dequant(
t: torch.Tensor,
scale: Optional[torch.Tensor],
block_shape: Optional[list[int]],
per_act_token_quant: bool,
out_dtype: Optional[torch.dtype] = torch.float32,
) -> torch.Tensor:
if scale is not None:
f32 = torch.float32
if per_act_token_quant or block_shape is None:
return (t.to(f32) * scale).to(out_dtype)
else:
return (t.to(f32) * group_broadcast(scale, t.shape)).to(out_dtype)
else:
return t.to(out_dtype)
def native_batched_masked_quant_matmul(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
num_expert_tokens: torch.Tensor,
A_scale: Optional[torch.Tensor] = None,
B_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> torch.Tensor:
num_expert_tokens_cpu = num_expert_tokens.clone()
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
num_experts = num_expert_tokens.size(0)
for e in range(num_experts):
num_tokens = num_expert_tokens_cpu[e]
if A.dtype.itemsize == 1 and block_shape is not None:
assert A_scale is not None and B_scale is not None
tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e],
block_shape, C.dtype)
C[e, :num_tokens, :] = tmp[:num_tokens, :]
elif A.dtype.itemsize == 1 and block_shape is None:
assert A_scale is not None and B_scale is not None
A_dq = dequant(A[e], A_scale[e], block_shape, per_act_token_quant)
B_dq = dequant(B[e], B_scale[e], block_shape, per_act_token_quant)
C[e, :num_tokens, :] = (
A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype)
else:
assert A_scale is None
assert B_scale is None
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
return C
...@@ -7,16 +7,10 @@ import itertools ...@@ -7,16 +7,10 @@ import itertools
import pytest import pytest
import torch import torch
from tests.kernels.quant_utils import native_w8a8_block_matmul from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
from vllm.config import VllmConfig, set_current_vllm_config native_w8a8_block_matmul,
from vllm.model_executor.layers.activation import SiluAndMul per_block_cast_to_fp8)
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.config import VllmConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul) per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -46,78 +40,10 @@ N = [128, 512, 7168, 7748, 13824] ...@@ -46,78 +40,10 @@ N = [128, 512, 7168, 7748, 13824]
K = [256, 3884, 4096, 13824, 16384] K = [256, 3884, 4096, 13824, 16384]
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
# and its hidden size is 7168. # and its hidden size is 7168.
M_moe = [1, 2, 7, 83, 128, 2048, 1024 * 128]
M_moe_dg = [128, 192, 1335, 2048]
N_moe = [128, 256, 1024, 4608] # [13824]
K_moe = [256, 512, 7168] # [13824]
BLOCK_SIZE = [[128, 128]] BLOCK_SIZE = [[128, 128]]
E = [2, 8, 16, 24] # [128, 256]
TOP_KS = [1, 2, 6]
OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16]
SEEDS = [0] SEEDS = [0]
def native_per_token_group_quant_fp8(x,
group_size,
eps=1e-10,
dtype=torch.float8_e4m3fn):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch."""
assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot "
"be divisible by `group_size`")
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1,
keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
return x_q, x_s
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
"""Fused moe with block-wise quantization using native torch."""
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
a_q = a_q.to(torch.float32)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_fp8(
act_out, block_k)
act_out = act_out.to(torch.float32)
out[mask] = native_w8a8_block_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
# Skip all tests if CUDA is not available # Skip all tests if CUDA is not available
pytest.importorskip("torch.cuda") pytest.importorskip("torch.cuda")
...@@ -177,111 +103,6 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): ...@@ -177,111 +103,6 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
assert rel_diff < 0.001 assert rel_diff < 0.001
@pytest.mark.parametrize(
"M,N,K,E,topk,block_size,dtype,seed",
itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES,
SEEDS))
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
if topk > E:
pytest.skip(f"Skipping test; topk={topk} > E={E}")
torch.manual_seed(seed)
factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a = torch.randn((M, K), dtype=dtype) / 10
w1_bf16 = (torch.rand(
(E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
del w1_bf16
w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
del w2_bf16
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = (2 * N + block_n - 1) // block_n
n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w1 = (K + block_k - 1) // block_k
k_tiles_w2 = (N + block_k - 1) // block_k
w1_s = torch.rand(
(E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale
w2_s = torch.rand(
(E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale
score = torch.randn((M, E), dtype=dtype)
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False,
block_shape=block_size)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
use_fp8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
block_size)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
m_out = m_fused_moe(a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=E,
w1_scale=w1_s,
w2_scale=w2_s)
#print(f"{out.sum()=}")
#print(f"{ref_out.sum()=}")
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.03
rel_diff = (torch.mean(
torch.abs(m_out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.03
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(deep_gemm.ceil_div(m, 128) * 128,
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed", "M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
...@@ -324,187 +145,3 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): ...@@ -324,187 +145,3 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32)))) torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.001 assert rel_diff < 0.001
def fp8_perm(m, idx):
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
else:
return m[idx, ...]
def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
M, K = a.shape
sorted_token_ids, m_indices, num_pad = moe_align_block_size(
topk_ids, block_m, num_groups, None, pad_sorted_ids=True)
num_tokens = topk * M
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
m_indices = torch.repeat_interleave(m_indices, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:M * topk]
a = fp8_perm(a, sorted_token_ids // topk)
if a_s is not None:
a_s = a_s[sorted_token_ids // topk]
return a, a_s, m_indices, inv_perm
def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
M = topk_weight.shape[0]
out = out[inv_perm, ...]
tmp_out = out.view(-1, topk, K)
return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
block_shape):
"""Fused moe with block-wise quantization using DeepGemm grouped gemm."""
num_groups = w1.shape[0]
M, K = a.shape
N = w2.shape[-1]
topk_weight, topk_ids, token_expert_indices = fused_topk(
a, score.float(), topk, False)
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = per_token_group_quant_fp8(a, block_m)
a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids,
num_groups, topk, block_m)
inter_out = torch.zeros((a_q.shape[0], N * 2),
dtype=torch.bfloat16,
device=a.device)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
inter_out, m_indices)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)
out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(act_out_q, act_out_s), (w2, w2_s), out, m_indices)
final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)
return final_out
@pytest.mark.parametrize(
"M,N,K,E,topk,seed",
itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS))
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
monkeypatch):
if topk > E:
pytest.skip(f"Skipping test: topk={topk} > E={E}")
if not _valid_deep_gemm_shape(M, N, K):
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
chunk_size = 1024
torch.manual_seed(seed)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]
dtype = torch.bfloat16
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a = torch.randn((M, K), dtype=dtype) / 10
w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 *
fp8_max).clamp(min=fp8_min, max=fp8_max)
w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 *
fp8_max).clamp(min=fp8_min, max=fp8_max)
score = torch.randn((M, E), dtype=dtype)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = ((2 * N) + block_n - 1) // block_n
k_tiles_w1 = (K + block_k - 1) // block_k
n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w2 = (N + block_k - 1) // block_k
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()
assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(E):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
use_compile = False
use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
and current_platform.is_cuda_alike())
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
if M >= 128:
ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
score, topk, block_size)
else:
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
topk, block_size)
topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score.float(), topk, False)
if use_compile:
deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
backend="inductor",
fullgraph=True)
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(topk_weights, 0)
torch._dynamo.mark_dynamic(topk_ids, 0)
else:
deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids)
if use_cudagraph:
out.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
#print(f"{out.sum()=}")
#print(f"{ref_out.sum()=}")
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.03
...@@ -8,9 +8,7 @@ import pytest ...@@ -8,9 +8,7 @@ import pytest
import torch import torch
from tests.kernels.quant_utils import native_w8a8_block_matmul from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.quantization.utils.int8_utils import ( from vllm.model_executor.layers.quantization.utils.int8_utils import (
w8a8_block_int8_matmul) w8a8_block_int8_matmul)
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -23,82 +21,10 @@ vllm_config = VllmConfig() ...@@ -23,82 +21,10 @@ vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192 vllm_config.scheduler_config.max_model_len = 8192
# For test
def native_per_token_group_quant_int8(x,
group_size,
eps=1e-10,
dtype=torch.int8):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch.
It converts the tensor values into int8 values and returns the
quantized tensor along with the scaling factor used for quantization.
"""
assert (x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
iinfo = torch.iinfo(dtype)
int8_min = iinfo.min
int8_max = iinfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
# Use float32 for scale calculation for stability
amax = x_.abs().max(dim=-1,
keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / int8_max
x_q = (x_.to(torch.float32) / x_s).round().clamp(
min=int8_min, max=int8_max).to(dtype) # Round before clamping
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
return x_q, x_s
# For test
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
"""This function performs fused moe with block-wise quantization using
native torch."""
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = native_per_token_group_quant_int8(a, block_k)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_int8(
act_out, block_k)
act_out = act_out.to(torch.float32)
out[mask] = native_w8a8_block_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
DTYPES = [torch.half, torch.bfloat16] DTYPES = [torch.half, torch.bfloat16]
M = [1, 33, 64, 222] M = [1, 33, 64, 222]
N = [128, 1024] N = [128, 1024]
K = [256, 4096] K = [256, 4096]
E = [8, 24]
TOP_KS = [2, 6]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]] # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]] BLOCK_SIZE = [[128, 128]]
SEEDS = [0] SEEDS = [0]
...@@ -140,63 +66,3 @@ def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed): ...@@ -140,63 +66,3 @@ def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32)))) torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.001 assert rel_diff < 0.001
@pytest.mark.parametrize(
"M, N, K, E, topk, block_size, dtype, seed",
itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
"""Tests the fused_moe kernel with W8A8 INT8 block quantization against a
native torch reference."""
torch.manual_seed(seed)
# Use a smaller factor for scale initialization to prevent large
# values/overflow especially when output dtype might be float16
factor_for_scale = 1e-2
int8_info = torch.iinfo(torch.int8)
int8_max, int8_min = int8_info.max, int8_info.min
a = torch.randn((M, K), dtype=dtype) / 10
w1_fp32 = (torch.rand(
(E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max
w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max
w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = (2 * N + block_n - 1) // block_n
n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w1 = (K + block_k - 1) // block_k
k_tiles_w2 = (N + block_k - 1) // block_k
w1_s = (torch.rand(
(E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale)
w2_s = (torch.rand(
(E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale)
score = torch.randn((M, E), dtype=dtype)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
use_int8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
block_size)
# Check results
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.06
...@@ -13,8 +13,11 @@ import pytest ...@@ -13,8 +13,11 @@ import pytest
import torch import torch
from torch._prims_common import TensorLikeType from torch._prims_common import TensorLikeType
from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.platforms.interface import _Backend from vllm.platforms.interface import _Backend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
...@@ -1054,32 +1057,77 @@ def compute_max_diff(output, output_ref): ...@@ -1054,32 +1057,77 @@ def compute_max_diff(output, output_ref):
torch.abs(output_ref)) torch.abs(output_ref))
def torch_experts(a: torch.Tensor, def torch_experts(
w1: torch.Tensor, a: torch.Tensor,
w2: torch.Tensor, w1: torch.Tensor,
topk_weight: torch.Tensor, w2: torch.Tensor,
topk_ids: torch.Tensor, topk_weight: torch.Tensor,
global_num_experts: int = -1, topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
assert (global_num_experts == -1 assert (global_num_experts == -1
or (global_num_experts == w1.shape[0] and expert_map is None) or (global_num_experts == w1.shape[0] and expert_map is None)
or (expert_map is not None or (expert_map is not None
and global_num_experts == expert_map.shape[0])) and global_num_experts == expert_map.shape[0]))
M, K = a.shape
topk = topk_ids.shape[1] topk = topk_ids.shape[1]
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
topk_weight = topk_weight.view(-1) out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
a, a_scale = moe_kernel_quantize_input(a, None, quant_dtype,
per_act_token_quant, block_shape)
num_experts = w1.shape[0]
topk_ids = topk_ids.view(-1) topk_ids = topk_ids.view(-1)
if expert_map is not None: if expert_map is not None:
topk_ids = expert_map[topk_ids] topk_ids = expert_map[topk_ids]
for i in range(w1.shape[0]):
for i in range(num_experts):
mask = topk_ids == i mask = topk_ids == i
if mask.sum(): if mask.sum():
out[mask] = SiluAndMul()( if quant_dtype is None:
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) tmp1 = a[mask] @ w1[i].transpose(0, 1)
return (out.view(B, -1, w2.shape[1]) * tmp2 = SiluAndMul()(tmp1)
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) out[mask] = tmp2 @ w2[i].transpose(0, 1)
elif block_shape is not None:
assert (a_scale is not None and w1_scale is not None
and w2_scale is not None)
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
w1_scale[i], block_shape,
out.dtype)
tmp2 = SiluAndMul()(tmp1)
tmp2, b_scale = moe_kernel_quantize_input(
tmp2, None, quant_dtype, per_act_token_quant, block_shape)
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
w2_scale[i], block_shape,
out.dtype)
else:
assert (a_scale is not None and w1_scale is not None
and w2_scale is not None)
f32 = torch.float32
scales = a_scale if a_scale.numel() == 1 else a_scale[mask]
tmp1 = a[mask].to(f32) * scales
w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
tmp1 = tmp1 @ w1_dq
tmp2 = SiluAndMul()(tmp1)
w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
out[mask] = (tmp2 @ w2_dq).to(out.dtype)
return (out.view(M, -1, w2.shape[1]) *
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
def torch_moe(a: torch.Tensor, def torch_moe(a: torch.Tensor,
......
...@@ -1274,7 +1274,7 @@ def scaled_fp8_quant( ...@@ -1274,7 +1274,7 @@ def scaled_fp8_quant(
scale = torch.zeros(1, device=input.device, dtype=torch.float32) scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else: else:
assert scale.numel() == 1 assert scale.numel() == 1, f"{scale.shape}"
torch.ops._C.static_scaled_fp8_quant(output, input, scale) torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale return output, scale
......
...@@ -4,8 +4,12 @@ ...@@ -4,8 +4,12 @@
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Optional from typing import Any, Optional
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
_config: Optional[dict[str, Any]] = None _config: Optional[dict[str, Any]] = None
...@@ -26,8 +30,12 @@ def get_config() -> Optional[dict[str, Any]]: ...@@ -26,8 +30,12 @@ def get_config() -> Optional[dict[str, Any]]:
__all__ = [ __all__ = [
"FusedMoE", "FusedMoE",
"FusedMoEConfig",
"FusedMoEMethodBase", "FusedMoEMethodBase",
"FusedMoeWeightScaleSupported", "FusedMoeWeightScaleSupported",
"FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize",
"override_config", "override_config",
"get_config", "get_config",
] ]
...@@ -36,11 +44,21 @@ if HAS_TRITON: ...@@ -36,11 +44,21 @@ if HAS_TRITON:
# import to register the custom ops # import to register the custom ops
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
import vllm.model_executor.layers.fused_moe.fused_moe # noqa import vllm.model_executor.layers.fused_moe.fused_moe # noqa
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4, cutlass_moe_fp8) CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts, fused_experts, fused_moe, fused_topk, TritonExperts, fused_experts, fused_moe, fused_topk,
get_config_file_name, grouped_topk) get_config_file_name, grouped_topk)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
__all__ += [ __all__ += [
"fused_moe", "fused_moe",
...@@ -50,5 +68,11 @@ if HAS_TRITON: ...@@ -50,5 +68,11 @@ if HAS_TRITON:
"grouped_topk", "grouped_topk",
"cutlass_moe_fp8", "cutlass_moe_fp8",
"cutlass_moe_fp4", "cutlass_moe_fp4",
"CutlassExpertsFp8",
"TritonExperts", "TritonExperts",
"BatchedTritonExperts",
"DeepGemmExperts",
"BatchedDeepGemmExperts",
"TritonOrDeepGemmExperts",
"BatchedTritonOrDeepGemmExperts",
] ]
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
...@@ -179,28 +180,44 @@ def silu_mul_fp8_quant_deep_gemm( ...@@ -179,28 +180,44 @@ def silu_mul_fp8_quant_deep_gemm(
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# The Deep Gemm kernels only support block size of 128 # The Deep Gemm kernels only support block size of 128
DEEPGEMM_BLOCK_SHAPE = 128 DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128]
def __init__(self, max_num_tokens: int, world_size: int, dp_size: int, def __init__(self,
block_shape: list[int]): max_num_tokens: int,
world_size: int,
dp_size: int,
block_shape: list[int],
per_act_token_quant=False):
""" """
max_num_tokens: Maximum number of tokens from a DP Rank max_num_tokens: Maximum number of tokens from a DP Rank
world_size: Number of EP ranks world_size: Number of EP ranks
dp_size: Number of data-parallel ranks dp_size: Number of data-parallel ranks
block_shape: Block quantization block shape block_shape: Block quantization block shape
""" """
super().__init__() super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.world_size = world_size self.world_size = world_size
self.dp_size = dp_size self.dp_size = dp_size
self.block_shape = block_shape
assert (len(self.block_shape) == 2 and all( @property
[v == self.DEEPGEMM_BLOCK_SHAPE for v in self.block_shape])) def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return False return False
def supports_expert_map(self) -> bool:
return False
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
...@@ -248,6 +265,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -248,6 +265,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
): ):
import deep_gemm as dg import deep_gemm as dg
assert hidden_states.ndim == 3 assert hidden_states.ndim == 3
assert self.block_shape is not None
a1q = hidden_states a1q = hidden_states
_, N, K = w1.size() _, N, K = w1.size()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment