Commit 41199996 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.12.0' into v0.12.0-dev

parents 31021d81 4fd9d6a8
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools import itertools
from typing import Callable from collections.abc import Callable
from unittest.mock import patch from unittest.mock import patch
import pandas as pd import pandas as pd
...@@ -10,7 +10,8 @@ import torch ...@@ -10,7 +10,8 @@ import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
def with_triton_mode(fn): def with_triton_mode(fn):
......
...@@ -10,7 +10,8 @@ import vllm.model_executor.layers.activation # noqa F401 ...@@ -10,7 +10,8 @@ import vllm.model_executor.layers.activation # noqa F401
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
batch_size_range = [1, 16, 32, 64, 128] batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
......
...@@ -28,7 +28,7 @@ except ImportError as e: ...@@ -28,7 +28,7 @@ except ImportError as e:
from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark BitBLAS int4 on a specific target." description="Benchmark BitBLAS int4 on a specific target."
......
...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
WEIGHT_SHAPES_MOE = { WEIGHT_SHAPES_MOE = {
"nvidia/DeepSeek-R1-FP4": [ "nvidia/DeepSeek-R1-FP4": [
......
...@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_confi ...@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_confi
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
# Weight shapes for different models: [num_experts, topk, hidden_size, # Weight shapes for different models: [num_experts, topk, hidden_size,
# intermediate_size] # intermediate_size]
...@@ -255,8 +255,8 @@ def bench_run( ...@@ -255,8 +255,8 @@ def bench_run(
torch.cuda.synchronize() torch.cuda.synchronize()
# Timing # Timing
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
latencies = [] latencies = []
for _ in range(num_iters): for _ in range(num_iters):
......
...@@ -22,8 +22,8 @@ Example: ...@@ -22,8 +22,8 @@ Example:
import json import json
import os import os
import time import time
from collections.abc import Callable
from contextlib import nullcontext from contextlib import nullcontext
from typing import Callable, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -39,7 +39,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import ( ...@@ -39,7 +39,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
) )
from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -264,12 +264,12 @@ class CommunicatorBenchmark: ...@@ -264,12 +264,12 @@ class CommunicatorBenchmark:
def benchmark_allreduce_single( def benchmark_allreduce_single(
self, self,
sequence_length: int, sequence_length: int,
allreduce_fn: Callable[[torch.Tensor], Optional[torch.Tensor]], allreduce_fn: Callable[[torch.Tensor], torch.Tensor | None],
should_use_fn: Callable[[torch.Tensor], bool], should_use_fn: Callable[[torch.Tensor], bool],
context, context,
num_warmup: int, num_warmup: int,
num_trials: int, num_trials: int,
) -> Optional[float]: ) -> float | None:
"""Benchmark method with CUDA graph optimization.""" """Benchmark method with CUDA graph optimization."""
try: try:
# Create test tensor (2D: sequence_length x hidden_size) # Create test tensor (2D: sequence_length x hidden_size)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark for FlashInfer fused collective operations vs standard operations.
This benchmark compares:
1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant)
2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
Usage with torchrun:
torchrun --nproc_per_node=2 benchmark_fused_collective.py
"""
import argparse
import itertools
import os
import time
import pandas as pd
import torch # type: ignore
import torch.distributed as dist # type: ignore
from vllm.config.vllm import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import (
get_tp_group,
tensor_model_parallel_all_reduce,
)
from vllm.distributed.parallel_state import (
graph_capture,
init_distributed_environment,
initialize_model_parallel,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm # noqa
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 # noqa
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape # noqa
from vllm.platforms import current_platform # noqa
RMS_NORM_OP = torch.ops._C.rms_norm
FUSED_ADD_RMS_NORM_OP = torch.ops._C.fused_add_rms_norm
RMS_NORM_STATIC_FP8_QUANT_OP = torch.ops._C.rms_norm_static_fp8_quant
FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP = (
torch.ops._C.fused_add_rms_norm_static_fp8_quant
)
SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant
logger = init_logger(__name__)
# Try to import FlashInfer
try:
import flashinfer.comm as flashinfer_comm # type: ignore
if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"):
flashinfer_comm = None
logger.warning(
"FlashInfer comm module found but missing trtllm_allreduce_fusion"
)
except ImportError:
flashinfer_comm = None
logger.warning("FlashInfer not found, only benchmarking standard operations")
# Constants
FP8_DTYPE = current_platform.fp8_dtype()
MiB = 1024 * 1024
# FlashInfer max sizes per world size
# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes
# use --disable-oneshot to disable oneshot mode for very large input sizes
_FI_MAX_SIZES = {
2: 64 * MiB, # 64MB
4: 64 * MiB, # 64MB
8: 64 * MiB, # 64MB
}
# Global workspace tensor for FlashInfer
_FI_WORKSPACE_TENSOR = None
def setup_flashinfer_workspace(
world_size: int,
rank: int,
hidden_dim: int,
max_token_num: int,
use_fp32_lamport: bool = False,
):
"""Setup FlashInfer workspace for fused allreduce operations."""
global _FI_WORKSPACE_TENSOR
if flashinfer_comm is None:
return None, None
if world_size not in _FI_MAX_SIZES:
logger.warning("FlashInfer not supported for world size %s", world_size)
return None, None
try:
# Create IPC workspace
ipc_handles, workspace_tensor = (
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
tp_rank=rank,
tp_size=world_size,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
group=get_tp_group().device_group,
use_fp32_lamport=use_fp32_lamport,
)
)
_FI_WORKSPACE_TENSOR = workspace_tensor
return ipc_handles, workspace_tensor
except Exception as e:
logger.error("Failed to setup FlashInfer workspace: %s", e)
return None, None
def cleanup_flashinfer_workspace(ipc_handles):
"""Cleanup FlashInfer workspace."""
if flashinfer_comm is None or ipc_handles is None:
return
try:
group = get_tp_group().device_group
flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group)
except Exception as e:
logger.error("Failed to cleanup FlashInfer workspace: %s", e)
class FlashInferFusedAllReduceParams:
"""Parameters for FlashInfer fused allreduce operations."""
def __init__(
self,
rank: int,
world_size: int,
use_fp32_lamport: bool = False,
max_token_num: int = 1024,
):
self.rank = rank
self.world_size = world_size
self.use_fp32_lamport = use_fp32_lamport
self.trigger_completion_at_end = True
self.launch_with_pdl = True
self.fp32_acc = True
self.max_token_num = max_token_num
def get_trtllm_fused_allreduce_kwargs(self):
return {
"world_rank": self.rank,
"world_size": self.world_size,
"launch_with_pdl": self.launch_with_pdl,
"trigger_completion_at_end": self.trigger_completion_at_end,
"fp32_acc": self.fp32_acc,
}
def flashinfer_fused_allreduce_rmsnorm(
input_tensor: torch.Tensor,
residual: torch.Tensor | None,
rms_gamma: torch.Tensor,
rms_eps: float,
allreduce_params: "FlashInferFusedAllReduceParams",
use_oneshot: bool,
norm_out: torch.Tensor | None = None,
):
"""FlashInfer fused allreduce + rmsnorm operation."""
if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
norm_out = input_tensor
residual_out = residual
else:
residual_out = input_tensor
flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=input_tensor,
token_num=input_tensor.shape[0],
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
hidden_dim=input_tensor.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
allreduce_out=None,
quant_out=None,
scale_out=None,
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
scale_factor=None,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
input_tensor: torch.Tensor,
residual: torch.Tensor | None,
rms_gamma: torch.Tensor,
rms_eps: float,
scale_factor: torch.Tensor,
allreduce_params: FlashInferFusedAllReduceParams,
use_oneshot: bool = True,
norm_out: torch.Tensor | None = None,
quant_out: torch.Tensor | None = None,
):
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization."""
if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
norm_out = input_tensor
residual_out = residual
else:
residual_out = input_tensor
flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=input_tensor,
token_num=input_tensor.shape[0],
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
hidden_dim=input_tensor.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
allreduce_out=None,
quant_out=quant_out,
scale_out=None,
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
scale_factor=scale_factor,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
input_tensor: torch.Tensor,
residual: torch.Tensor | None,
rms_gamma: torch.Tensor,
rms_eps: float,
input_global_scale: torch.Tensor,
allreduce_params: FlashInferFusedAllReduceParams,
quant_out: torch.Tensor,
use_oneshot: bool,
output_scale: torch.Tensor,
norm_out: torch.Tensor | None = None,
):
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization."""
if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
norm_out = input_tensor
residual_out = residual
else:
residual_out = input_tensor
flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=input_tensor,
token_num=input_tensor.shape[0],
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
hidden_dim=input_tensor.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
allreduce_out=None,
quant_out=quant_out,
scale_out=output_scale,
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
scale_factor=input_global_scale,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
class VllmFusedAllreduce:
def __init__(self, hidden_dim, dtype):
self.rms_eps = 1e-6
self.rms_norm = RMSNorm(hidden_dim, eps=self.rms_eps, dtype=dtype)
self.fp8_quant = QuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
)
def allreduce_rmsnorm(
self, input_tensor: torch.Tensor, residual: torch.Tensor | None
):
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
return self.rms_norm(allreduce_out, residual)
def allreduce_rmsnorm_fp8_quant(
self,
input_tensor: torch.Tensor,
residual: torch.Tensor | None,
scale_factor: torch.Tensor,
):
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
rms_out = self.rms_norm(allreduce_out, residual)
if residual is None:
quant_out = self.fp8_quant(rms_out, scale_factor)
return quant_out
else:
rms_out, residual_out = rms_out
quant_out = self.fp8_quant(rms_out, scale_factor)
return quant_out, residual_out
def allreduce_rmsnorm_fp4_quant(
self,
input_tensor: torch.Tensor,
residual: torch.Tensor | None,
input_global_scale: torch.Tensor,
quant_out: torch.Tensor,
output_scale: torch.Tensor,
):
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
rms_out = self.rms_norm(allreduce_out, residual)
if residual is None:
SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale)
return quant_out, output_scale
else:
rms_out, residual_out = rms_out
SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale)
return quant_out, residual_out, output_scale
def create_test_tensors(
num_tokens: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True
):
"""Create test tensors for benchmarking."""
input_tensor = torch.randn(num_tokens, hidden_dim, dtype=dtype)
residual = (
torch.randn_like(input_tensor)
if use_residual
else torch.zeros_like(input_tensor)
)
rms_gamma = torch.ones(hidden_dim, dtype=dtype)
norm_out = None if use_residual else torch.empty_like(input_tensor)
# Quantization scales
scale_fp8 = torch.tensor(1.0, dtype=torch.float32)
scale_fp4 = torch.tensor(1.0, dtype=torch.float32)
quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE)
# Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks)
fp4_quant_out = torch.empty((num_tokens, hidden_dim // 2), dtype=torch.uint8)
fp4_output_scale = torch.empty((128, 4), dtype=torch.int32)
return (
input_tensor,
norm_out,
residual,
rms_gamma,
scale_fp8,
quant_out_fp8,
scale_fp4,
fp4_quant_out,
fp4_output_scale,
)
def benchmark_operation(
operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs
):
"""Benchmark a single operation using CUDA graphs."""
# Warmup before graph capture
for _ in range(warmup):
operation_func(*args, **kwargs)
torch.cuda.synchronize()
# Create CUDA graph
graph = torch.cuda.CUDAGraph()
num_op_per_cudagraph = 10
# Use vLLM's graph_capture to make tensor_model_parallel_all_reduce graph-safe
device = torch.device(f"cuda:{torch.cuda.current_device()}")
with graph_capture(device=device), torch.cuda.graph(graph):
for _ in range(num_op_per_cudagraph):
operation_func(*args, **kwargs)
# Graph warmup
torch.cuda.synchronize()
for _ in range(warmup):
graph.replay()
# Benchmark with CUDA graph
torch.cuda.synchronize()
start_time = time.perf_counter()
for _ in range(trials // num_op_per_cudagraph):
# operation_func(*args, **kwargs)
graph.replay()
torch.cuda.synchronize()
end_time = time.perf_counter()
avg_time_ms = ((end_time - start_time) / trials) * 1000
return avg_time_ms
def run_benchmarks(
num_tokens: int,
hidden_dim: int,
dtype: torch.dtype,
use_residual: bool,
allreduce_params: FlashInferFusedAllReduceParams | None,
quant_modes: set[str],
no_oneshot: bool,
):
"""Run all benchmarks for given configuration.
Args:
quant_mode: "none", "fp8_only", "fp4_only", or "all"
"""
(
input_tensor,
norm_out,
residual,
rms_gamma,
scale_fp8,
quant_out_fp8,
scale_fp4,
fp4_quant_out,
fp4_output_scale,
) = create_test_tensors(num_tokens, hidden_dim, dtype, use_residual)
rms_eps = 1e-6
results = {}
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
use_oneshot_options = [False] if no_oneshot else [True, False]
# Create RMSNorm and QuantFP8 layers once for native benchmarks
if "none" in quant_modes:
# Standard AllReduce + RMSNorm
for custom_op in ["-rms_norm", "+rms_norm"]:
with set_current_vllm_config(
VllmConfig(compilation_config=CompilationConfig(custom_ops=[custom_op]))
):
try:
suffix = (
"_custom_rms_norm" if "+" in custom_op else "_native_rms_norm"
)
time_ms = benchmark_operation(
vllm_fused_allreduce.allreduce_rmsnorm,
input_tensor,
residual=residual,
)
results[f"standard_allreduce_{suffix}"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm failed: %s", e)
results[f"standard_allreduce_{suffix}"] = float("inf")
# Standard AllReduce + RMSNorm Native Compiled
with set_current_vllm_config(
VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
):
try:
standard_allreduce_rmsnorm_native_compiled = torch.compile(
vllm_fused_allreduce.allreduce_rmsnorm,
fullgraph=True,
dynamic=False,
)
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_native_compiled,
input_tensor,
residual=residual,
)
results["standard_allreduce_rmsnorm_native_compiled"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e)
results["standard_allreduce_rmsnorm_native_compiled"] = float("inf")
# FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot
if flashinfer_comm is not None and allreduce_params is not None:
for use_oneshot in use_oneshot_options:
suffix = "_oneshot" if use_oneshot else "_twoshot"
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
allreduce_params=allreduce_params,
use_oneshot=use_oneshot,
)
results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = time_ms
except Exception as e:
logger.error("FlashInfer Fused AllReduce+RMSNorm failed: %s", e)
results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = float(
"inf"
)
if "fp8" in quant_modes:
# Standard AllReduce + RMSNorm + FP8 Quant
for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]:
suffix = (
"_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm"
)
for quant_fp8_custom_op in ["-quant_fp8", "+quant_fp8"]:
suffix += (
"_custom_quant_fp8"
if "+" in quant_fp8_custom_op
else "_native_quant_fp8"
)
with set_current_vllm_config(
VllmConfig(
compilation_config=CompilationConfig(
custom_ops=[rms_norm_custom_op, quant_fp8_custom_op]
)
)
):
try:
time_ms = benchmark_operation(
vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
input_tensor,
residual=residual,
scale_factor=scale_fp8,
)
results[f"standard_allreduce{suffix}"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e)
results[f"standard_allreduce{suffix}"] = float("inf")
# Standard AllReduce + RMSNorm + FP8 Quant Native Compiled
with set_current_vllm_config(
VllmConfig(
compilation_config=CompilationConfig(
custom_ops=["-rms_norm", "-quant_fp8"]
)
)
):
try:
standard_allreduce_rmsnorm_fp8_quant_native_compiled = torch.compile(
vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
fullgraph=True,
dynamic=False,
)
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_fp8_quant_native_compiled,
input_tensor,
residual=residual,
scale_factor=scale_fp8,
)
results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = (
time_ms
)
except Exception as e:
logger.error(
"Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e
)
results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
for use_oneshot in use_oneshot_options:
suffix = "_oneshot" if use_oneshot else "_twoshot"
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp8_quant,
input_tensor,
norm_out=norm_out,
residual=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_fp8,
quant_out=quant_out_fp8,
allreduce_params=allreduce_params,
use_oneshot=use_oneshot,
)
results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s",
e,
)
results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = (
float("inf")
)
if "fp4" in quant_modes and current_platform.has_device_capability(100):
# Standard AllReduce + RMSNorm + FP4 Quant
for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]:
suffix = (
"_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm"
)
with set_current_vllm_config(
VllmConfig(
compilation_config=CompilationConfig(
custom_ops=[rms_norm_custom_op]
)
)
):
try:
time_ms = benchmark_operation(
vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
input_tensor,
residual=residual,
input_global_scale=scale_fp4,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
)
results[f"standard_allreduce_{suffix}_fp4_quant"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e)
results[f"standard_allreduce_{suffix}_fp4_quant"] = float("inf")
# Standard AllReduce + RMSNorm + FP4 Quant Native Compiled
with set_current_vllm_config(
VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
):
try:
standard_allreduce_rmsnorm_fp4_quant_native_compiled = torch.compile(
vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
fullgraph=True,
dynamic=False,
)
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_fp4_quant_native_compiled,
input_tensor,
residual=residual,
quant_out=fp4_quant_out,
input_global_scale=scale_fp4,
output_scale=fp4_output_scale,
)
results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = (
time_ms
)
except Exception as e:
logger.error(
"Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e
)
results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
for use_oneshot in use_oneshot_options:
suffix = "_oneshot" if use_oneshot else "_twoshot"
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
input_global_scale=scale_fp4,
allreduce_params=allreduce_params,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
use_oneshot=use_oneshot,
)
results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s",
e,
)
results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = (
float("inf")
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot
if flashinfer_comm is not None and allreduce_params is not None:
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
input_global_scale=scale_fp4,
allreduce_params=allreduce_params,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
use_oneshot=False,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s",
e,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float(
"inf"
)
return results
def prepare_results_with_speedups(results_dict):
"""Prepare results with speedup calculations based on dynamic baseline selection."""
prepared_results = []
# Determine the fastest baseline for each operation type
def get_fastest_baseline(op_name, results_dict):
"""Get the fastest baseline between standard and native_compiled versions."""
if "fp8_quant" in op_name:
candidates = [
"standard_allreduce_rmsnorm_fp8_quant",
"standard_allreduce_rmsnorm_fp8_quant_native_compiled",
]
elif "fp4_quant" in op_name:
candidates = [
"standard_allreduce_rmsnorm_fp4_quant",
"standard_allreduce_rmsnorm_fp4_quant_native_compiled",
]
else:
candidates = [
"standard_allreduce_rmsnorm",
"standard_allreduce_rmsnorm_native_compiled",
]
# Find the fastest among available candidates
fastest_time = float("inf")
fastest_baseline = None
for candidate in candidates:
if (
candidate in results_dict
and results_dict[candidate] != float("inf")
and results_dict[candidate] < fastest_time
):
fastest_time = results_dict[candidate]
fastest_baseline = candidate
return fastest_baseline
# Create dynamic baseline mapping
dynamic_baseline_mapping = {}
for op_name in results_dict:
if (
op_name.startswith("flashinfer_")
or op_name.startswith("standard_")
and not op_name.endswith("_native_compiled")
):
dynamic_baseline_mapping[op_name] = get_fastest_baseline(
op_name, results_dict
)
for op_name, time_ms in results_dict.items():
if time_ms == float("inf"):
speedup_str = "FAILED"
time_str = "FAILED"
else:
time_str = f"{time_ms:.3f}"
# Find the appropriate baseline for this operation
baseline_op = dynamic_baseline_mapping.get(op_name)
if baseline_op and baseline_op in results_dict:
baseline_time = results_dict[baseline_op]
if baseline_time != float("inf") and baseline_time > 0:
speedup = baseline_time / time_ms
speedup_str = f"{speedup:.2f}x"
else:
speedup_str = "N/A"
else:
# For baseline operations, determine if this is the fastest baseline
if op_name.endswith("_native_compiled") or (
op_name.startswith("standard_")
and not op_name.endswith("_native_compiled")
):
fastest_baseline = get_fastest_baseline(op_name, results_dict)
if fastest_baseline == op_name:
speedup_str = "baseline"
else:
if fastest_baseline and fastest_baseline in results_dict:
baseline_time = results_dict[fastest_baseline]
if baseline_time != float("inf") and baseline_time > 0:
speedup = baseline_time / time_ms
speedup_str = f"{speedup:.2f}x"
else:
speedup_str = "N/A"
else:
speedup_str = "N/A"
else:
speedup_str = "N/A"
prepared_results.append(
{
"operation": op_name,
"time_ms": time_ms,
"time_str": time_str,
"speedup_str": speedup_str,
}
)
return prepared_results
def print_results(
results_dict,
num_tokens,
hidden_dim,
dtype,
use_residual,
quant_modes,
input_size_mb,
):
"""Print benchmark results in a formatted table."""
print(f"\n{'=' * 80}")
print(
f"Results: num_tokens={num_tokens}, hidden_dim={hidden_dim} "
f"(input size: {input_size_mb:.2f} MB)"
)
print(
f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, "
f"quant_modes={','.join(sorted(list(quant_modes)))}"
)
print(f"{'=' * 80}")
print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}")
print(f"{'-' * 80}")
# Prepare results with speedup calculations
prepared_results = prepare_results_with_speedups(results_dict)
for result in prepared_results:
if result["time_ms"] == float("inf"):
time_display = result["time_str"]
else:
time_display = f"{result['time_ms']:.3f}"
print(
f"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}"
)
def format_results_markdown(
all_results: list[dict], world_size: int, args: argparse.Namespace
) -> str:
"""Format all benchmark results as markdown."""
lines: list[str] = []
lines.append("# FlashInfer Fused Collective Operations Benchmark Results")
lines.append("")
lines.append(f"**World Size:** {world_size} ")
lines.append(f"**Hidden Dimension:** {args.hidden_dim} ")
lines.append(f"**Warmup Iterations:** {args.warmup} ")
lines.append(f"**Benchmark Trials:** {args.trials} ")
modes = ",".join(all_results[0]["quant_modes"]) if all_results else "N/A"
lines.append(f"**Quantization Modes:** {modes} ")
lines.append("")
lines.append("---")
lines.append("")
for entry in all_results:
num_tokens = entry["num_tokens"]
dtype = entry["dtype"]
use_residual = entry["use_residual"]
results_dict = entry["results"]
input_size_mb = entry["input_size_mb"]
residual_str = "with residual" if use_residual else "no residual"
lines.append(
f"## Configuration: num_tokens={num_tokens}, dtype={dtype}, {residual_str}"
)
lines.append(f"**Input Size:** {input_size_mb:.2f} MB")
lines.append("")
prepared = prepare_results_with_speedups(results_dict)
# Build DataFrame for markdown export
rows = [
{
"Operation": r["operation"].replace("_", " ").title(),
"Time (ms)": r["time_str"],
"Speedup": r["speedup_str"],
}
for r in prepared
]
df = pd.DataFrame(rows)
if df.empty:
lines.append("No results.")
else:
lines.append(df.to_markdown(index=False))
lines.append("")
return "\n".join(lines)
def save_results_to_file(
all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int
):
"""Save benchmark results to markdown file (only on rank 0)."""
if rank != 0:
return
if not all_results:
logger.warning("No results to save")
return
output_path = args.output_file
try:
markdown_content = format_results_markdown(all_results, world_size, args)
with open(output_path, "a") as f:
f.write(markdown_content)
except Exception as e:
logger.error("Failed to save results to file: %s", e)
def main():
parser = argparse.ArgumentParser(
description="Benchmark fused collective operations"
)
parser.add_argument(
"--num-tokens",
type=int,
nargs="+",
default=[128, 512, 1024, 2048],
help="Numbers of tokens to test",
)
parser.add_argument(
"--hidden-dim", type=int, default=8192, help="Hidden dimension size"
)
parser.add_argument(
"--dtypes",
type=str,
nargs="+",
default=["bfloat16"],
choices=["float16", "bfloat16", "float32"],
help="Data types to test",
)
parser.add_argument(
"--no-residual",
action="store_true",
help="Skip residual connection tests",
)
parser.add_argument(
"--quant-modes",
type=str,
default="none,fp8,fp4",
help=(
"Comma-separated quantization modes to run: none, fp8, fp4. "
"Default: none,fp8,fp4"
),
)
parser.add_argument(
"--warmup", type=int, default=5, help="Number of warmup iterations"
)
parser.add_argument(
"--trials", type=int, default=20, help="Number of benchmark trials"
)
parser.add_argument(
"--output-file",
type=str,
help="""Output file path for markdown results
(default: benchmark_results_<timestamp>.md)
""",
)
parser.add_argument(
"--no-oneshot",
action="store_true",
help="Skip oneshot benchmarks",
)
args = parser.parse_args()
# Check if running with torchrun (required for collective operations)
if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
raise RuntimeError(
"Must run with torchrun for distributed benchmarking. "
"Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py"
)
# Initialize distributed environment
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Validate world size (must be > 1 for collective operations)
if world_size <= 1:
raise ValueError(
"World size must be > 1 for collective operations benchmarking. "
f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1."
)
# Parse quantization modes
valid_quant_modes = {"none", "fp8", "fp4"}
raw_modes = [
m.strip().lower() for m in (args.quant_modes or "").split(",") if m.strip()
]
quant_modes = set(raw_modes) if raw_modes else {"none", "fp8", "fp4"}
invalid = sorted(list(quant_modes - valid_quant_modes))
if invalid:
raise ValueError(
f"Invalid --quant-modes entries: {','.join(invalid)}. "
f"Valid options are: {','.join(sorted(valid_quant_modes))}."
)
if rank == 0:
logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank)
logger.info("Quantization modes: %s", ",".join(sorted(list(quant_modes))))
if flashinfer_comm is not None:
logger.info(
"FlashInfer available - will benchmark fused operations",
)
else:
logger.info(
"FlashInfer not available - only benchmarking standard operations"
)
# Convert dtype strings to torch dtypes
dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
dtypes = [dtype_map[dt] for dt in args.dtypes]
# Test configurations
residual_options = [True] if not args.no_residual else [False]
configs = list(itertools.product(args.num_tokens, dtypes, residual_options))
# Setup FlashInfer workspace if available
ipc_handles = None
allreduce_params = None
if flashinfer_comm is not None:
# Use the largest hidden dimension for workspace setup
max_num_token = _FI_MAX_SIZES.get(world_size) // (
args.hidden_dim * world_size * 2
)
ipc_handles, workspace_tensor = setup_flashinfer_workspace(
world_size, rank, args.hidden_dim, max_num_token
)
if workspace_tensor is not None:
allreduce_params = FlashInferFusedAllReduceParams(
rank=rank,
world_size=world_size,
max_token_num=max_num_token,
)
# Collect all results for markdown export
all_results = []
try:
# Run benchmarks
for num_tokens, dtype, use_residual in configs:
if rank == 0:
logger.info(
"\nTesting: num_tokens=%s, hidden_dim=%s, dtype=%s, residual=%s",
num_tokens,
args.hidden_dim,
dtype,
use_residual,
)
results = run_benchmarks(
num_tokens,
args.hidden_dim,
dtype,
use_residual,
allreduce_params,
quant_modes=quant_modes,
no_oneshot=args.no_oneshot,
)
# Store results for markdown export
if rank == 0:
# Calculate input size in MB
input_size_mb = (
num_tokens * args.hidden_dim * torch.finfo(dtype).bits
) / (8 * 1024 * 1024)
all_results.append(
{
"num_tokens": num_tokens,
"hidden_dim": args.hidden_dim,
"dtype": str(dtype).replace("torch.", ""),
"use_residual": use_residual,
"quant_modes": sorted(list(quant_modes)),
"input_size_mb": input_size_mb,
"results": results,
}
)
print_results(
results,
num_tokens,
args.hidden_dim,
dtype,
use_residual,
quant_modes,
input_size_mb,
)
# Save results to markdown file
if args.output_file and rank == 0:
save_results_to_file(all_results, world_size, args, rank)
finally:
# Cleanup
if ipc_handles is not None:
cleanup_flashinfer_workspace(ipc_handles)
dist.barrier()
if __name__ == "__main__":
main()
...@@ -13,11 +13,11 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( ...@@ -13,11 +13,11 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_experts,
fused_topk, fused_topk,
) )
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
DEFAULT_MODELS = [ DEFAULT_MODELS = [
"nm-testing/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1",
"nm-testing/deepseekv2-lite", "deepseek-ai/DeepSeek-V2-Lite",
"ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-1b-a400m",
"ibm-granite/granite-3.0-3b-a800m", "ibm-granite/granite-3.0-3b-a800m",
] ]
......
...@@ -7,7 +7,8 @@ import torch ...@@ -7,7 +7,8 @@ import torch
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
@torch.inference_mode() @torch.inference_mode()
......
...@@ -6,11 +6,12 @@ import copy ...@@ -6,11 +6,12 @@ import copy
import json import json
import pickle import pickle
import time import time
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from itertools import product from itertools import product
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Optional from typing import Any
import torch import torch
import torch.utils.benchmark as TBenchmark import torch.utils.benchmark as TBenchmark
...@@ -18,13 +19,24 @@ from torch.utils.benchmark import Measurement as TMeasurement ...@@ -18,13 +19,24 @@ from torch.utils.benchmark import Measurement as TMeasurement
from utils import ArgPool, Bench, CudaGraphBenchParams from utils import ArgPool, Bench, CudaGraphBenchParams
from weight_shapes import WEIGHT_SHAPES from weight_shapes import WEIGHT_SHAPES
from vllm.triton_utils import HAS_TRITON from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
from vllm.triton_utils import HAS_TRITON, triton
if HAS_TRITON: if HAS_TRITON:
from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink from vllm.lora.ops.triton_ops import ( ## added fused_moe_lora
LoRAKernelMeta,
fused_moe_lora_expand,
fused_moe_lora_shrink,
lora_expand,
lora_shrink,
)
from vllm.lora.ops.triton_ops.fused_moe_lora_op import (
_LORA_PTR_DICT, ## added _LORA_PTR_DICT for fused_moe_lora
)
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm import _custom_ops as ops
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.math_utils import round_up
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_TP_SIZES = [1] DEFAULT_TP_SIZES = [1]
...@@ -58,6 +70,8 @@ DEFAULT_NUM_LORAS = [1, 2, 3, 4] ...@@ -58,6 +70,8 @@ DEFAULT_NUM_LORAS = [1, 2, 3, 4]
DEFAULT_SORT_BY_LORA_IDS = [False, True] DEFAULT_SORT_BY_LORA_IDS = [False, True]
DEFAULT_SEQ_LENGTHS = [1] DEFAULT_SEQ_LENGTHS = [1]
DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False] DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False]
DEFAULT_TOP_K_NUMS = [1] # Added for MoE LoRA top_k
DEFAULT_NUM_EXPERTS = [8] # Added for MoE LoRA num_experts
# Utilities # Utilities
...@@ -158,7 +172,7 @@ def ref_group_gemm( ...@@ -158,7 +172,7 @@ def ref_group_gemm(
seq_lens_cpu: torch.Tensor, seq_lens_cpu: torch.Tensor,
prompt_lora_mapping_cpu: torch.Tensor, prompt_lora_mapping_cpu: torch.Tensor,
scaling: float, scaling: float,
add_inputs: Optional[bool], add_inputs: bool | None,
): ):
""" """
Torch group gemm reference implementation to test correctness of Torch group gemm reference implementation to test correctness of
...@@ -190,6 +204,11 @@ class OpType(Enum): ...@@ -190,6 +204,11 @@ class OpType(Enum):
LORA_SHRINK = auto() LORA_SHRINK = auto()
LORA_EXPAND = auto() LORA_EXPAND = auto()
## Adding support for fused moe lora
FUSED_MOE_LORA_GATE_UP_SHRINK = auto() ## Gate/Up projection variant with shrink
FUSED_MOE_LORA_GATE_UP_EXPAND = auto() ## Gate/Up projection variant with expand
FUSED_MOE_LORA_DOWN_SHRINK = auto() ## Down projection variant with shrink
FUSED_MOE_LORA_DOWN_EXPAND = auto() ## Down projection variant with expand
@staticmethod @staticmethod
def from_str(s: str) -> "OpType": def from_str(s: str) -> "OpType":
...@@ -197,6 +216,15 @@ class OpType(Enum): ...@@ -197,6 +216,15 @@ class OpType(Enum):
return OpType.LORA_SHRINK return OpType.LORA_SHRINK
if s.lower() == "lora_expand": if s.lower() == "lora_expand":
return OpType.LORA_EXPAND return OpType.LORA_EXPAND
# Adding support for fused moe lora, both in gate_up and down
if s.lower() == "fused_moe_lora_gate_up_shrink": ## Gate/Up variant with shrink
return OpType.FUSED_MOE_LORA_GATE_UP_SHRINK
if s.lower() == "fused_moe_lora_gate_up_expand": ## Gate/Up variant with expand
return OpType.FUSED_MOE_LORA_GATE_UP_EXPAND
if s.lower() == "fused_moe_lora_down_shrink": ## Down variant with shrink
return OpType.FUSED_MOE_LORA_DOWN_SHRINK
if s.lower() == "fused_moe_lora_down_expand": ## Down variant with expand
return OpType.FUSED_MOE_LORA_DOWN_EXPAND
raise ValueError(f"Unrecognized str {s} to convert to OpType") raise ValueError(f"Unrecognized str {s} to convert to OpType")
def is_shrink_fn(self) -> bool: def is_shrink_fn(self) -> bool:
...@@ -205,19 +233,56 @@ class OpType(Enum): ...@@ -205,19 +233,56 @@ class OpType(Enum):
def is_expand_fn(self) -> bool: def is_expand_fn(self) -> bool:
return self in [OpType.LORA_EXPAND] return self in [OpType.LORA_EXPAND]
def is_fused_moe_lora_fn(self) -> bool: ## adding for fused MoE LoRA
return self in [
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
OpType.FUSED_MOE_LORA_DOWN_SHRINK,
OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
OpType.FUSED_MOE_LORA_DOWN_EXPAND,
]
def is_fused_moe_lora_gate_up_fn(
self,
) -> bool: ## adding for fused MoE LoRA Gate/Up
return self in [
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
]
def is_fused_moe_lora_down_fn(self) -> bool: ## adding for fused MoE LoRA Down
return self in [
OpType.FUSED_MOE_LORA_DOWN_SHRINK,
OpType.FUSED_MOE_LORA_DOWN_EXPAND,
]
def is_fused_moe_lora_shrink_fn(self) -> bool:
return self in [
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
OpType.FUSED_MOE_LORA_DOWN_SHRINK,
]
def is_fused_moe_lora_expand_fn(self) -> bool:
return self in [
OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
OpType.FUSED_MOE_LORA_DOWN_EXPAND,
]
def num_slices(self) -> list[int]: def num_slices(self) -> list[int]:
if self.is_fused_moe_lora_gate_up_fn():
return [2]
elif self.is_fused_moe_lora_down_fn():
return [1]
return [1, 2, 3] return [1, 2, 3]
def mkn( def mkn(
self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int
) -> tuple[int, int, int]: ) -> tuple[int, int, int]:
num_tokens = batch_size * seq_length num_tokens = batch_size * seq_length
if self.is_shrink_fn(): if self.is_shrink_fn() or self.is_fused_moe_lora_fn():
m = num_tokens m = num_tokens
k = hidden_size k = hidden_size
n = lora_rank n = lora_rank
else: elif self.is_expand_fn():
assert self.is_expand_fn()
m = num_tokens m = num_tokens
k = lora_rank k = lora_rank
n = hidden_size n = hidden_size
...@@ -231,9 +296,36 @@ class OpType(Enum): ...@@ -231,9 +296,36 @@ class OpType(Enum):
""" """
if self.is_shrink_fn(): if self.is_shrink_fn():
return op_dtype, op_dtype, torch.float32 return op_dtype, op_dtype, torch.float32
else: elif self.is_expand_fn():
assert self.is_expand_fn()
return torch.float32, op_dtype, op_dtype return torch.float32, op_dtype, op_dtype
else:
assert self.is_fused_moe_lora_fn()
return op_dtype, op_dtype, op_dtype
def matmul_shapes_fused_moe_lora(
self,
m: int,
n: int,
k: int,
num_loras: int,
num_slices: int,
top_k_num: int,
num_experts: int,
) -> tuple[tuple[int], tuple[int], tuple[int], tuple[int]]:
if self.is_fused_moe_lora_shrink_fn():
input_shape = (
(m * top_k_num, n)
if self in [OpType.FUSED_MOE_LORA_DOWN_SHRINK]
else (m, n)
)
output_shape = (num_slices, m, top_k_num, k)
weight_shape = (num_loras, num_experts, k, n)
else:
assert self.is_fused_moe_lora_expand_fn()
input_shape = (num_slices, m, top_k_num, k)
output_shape = (m, top_k_num, n * num_slices)
weight_shape = (num_loras, num_experts, n, k)
return (input_shape, weight_shape, output_shape)
def matmul_shapes( def matmul_shapes(
self, self,
...@@ -243,6 +335,8 @@ class OpType(Enum): ...@@ -243,6 +335,8 @@ class OpType(Enum):
lora_rank: int, lora_rank: int,
num_loras: int, num_loras: int,
num_slices: int, num_slices: int,
top_k_num: int | None = None,
num_experts: int | None = None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
""" """
Given num_slices, return the shapes of the A, B, and C matrices Given num_slices, return the shapes of the A, B, and C matrices
...@@ -257,6 +351,16 @@ class OpType(Enum): ...@@ -257,6 +351,16 @@ class OpType(Enum):
if self in [OpType.LORA_EXPAND]: if self in [OpType.LORA_EXPAND]:
# LoRA expand kernels support num_slices inherently in the kernel # LoRA expand kernels support num_slices inherently in the kernel
return ((num_slices, m, k), b_shape, (m, n * num_slices)) return ((num_slices, m, k), b_shape, (m, n * num_slices))
if self.is_fused_moe_lora_fn():
return self.matmul_shapes_fused_moe_lora(
m,
k,
n,
num_loras,
num_slices,
top_k_num,
num_experts,
)
raise ValueError(f"Unrecognized op_type {self}") raise ValueError(f"Unrecognized op_type {self}")
def bench_fn(self) -> Callable: def bench_fn(self) -> Callable:
...@@ -264,6 +368,16 @@ class OpType(Enum): ...@@ -264,6 +368,16 @@ class OpType(Enum):
return lora_shrink return lora_shrink
if self == OpType.LORA_EXPAND: if self == OpType.LORA_EXPAND:
return lora_expand return lora_expand
if self in [
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
OpType.FUSED_MOE_LORA_DOWN_SHRINK,
]:
return fused_moe_lora_shrink
if self in [
OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
OpType.FUSED_MOE_LORA_DOWN_EXPAND,
]:
return fused_moe_lora_expand
raise ValueError(f"Unrecognized optype {self}") raise ValueError(f"Unrecognized optype {self}")
...@@ -316,8 +430,10 @@ class BenchmarkContext: ...@@ -316,8 +430,10 @@ class BenchmarkContext:
lora_rank: int lora_rank: int
sort_by_lora_id: bool sort_by_lora_id: bool
dtype: torch.dtype dtype: torch.dtype
seq_length: Optional[int] = None seq_length: int | None = None
num_slices: Optional[int] = None # num_slices for slice based ops num_experts: int | None = None # num_experts for MoE based ops
top_k_num: int | None = None # top_k for MoE based ops
num_slices: int | None = None # num_slices for slice based ops
def with_seq_length(self, seq_length: int) -> "BenchmarkContext": def with_seq_length(self, seq_length: int) -> "BenchmarkContext":
ctx = copy.copy(self) ctx = copy.copy(self)
...@@ -372,6 +488,11 @@ class BenchmarkTensors: ...@@ -372,6 +488,11 @@ class BenchmarkTensors:
f"{dtype_to_str(self.output.dtype)}" f"{dtype_to_str(self.output.dtype)}"
) )
def get_num_tokens(self, size: int, top_k_num: int, op_type: OpType):
return (
size * top_k_num if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] else size
)
@staticmethod @staticmethod
def make( def make(
ctx: BenchmarkContext, op_type: OpType, device: str = "cuda" ctx: BenchmarkContext, op_type: OpType, device: str = "cuda"
...@@ -384,6 +505,8 @@ class BenchmarkTensors: ...@@ -384,6 +505,8 @@ class BenchmarkTensors:
ctx.lora_rank, ctx.lora_rank,
ctx.num_loras, ctx.num_loras,
ctx.num_slices, ctx.num_slices,
ctx.top_k_num,
ctx.num_experts,
) )
a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype) a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype)
input_tensor, lora_weights, output_tensor = make_rand_tensors( input_tensor, lora_weights, output_tensor = make_rand_tensors(
...@@ -431,17 +554,27 @@ class BenchmarkTensors: ...@@ -431,17 +554,27 @@ class BenchmarkTensors:
prompt_lora_indices_tensor, prompt_lora_indices_tensor,
) )
def sanity_check(self) -> None: def sanity_check(self, ctx: BenchmarkContext, op_type: OpType) -> None:
""" """
Fails asserts when non-conformality is detected. Fails asserts when non-conformality is detected.
""" """
num_tokens = self.input.shape[-2] num_tokens = (
self.input.shape[1]
if op_type.is_fused_moe_lora_expand_fn()
else self.input.shape[-2]
)
# check metadata tensors # check metadata tensors
assert torch.sum(self.seq_lens) == num_tokens ## In down shrink case, each token is repeated top_k_num times
assert num_tokens == self.get_num_tokens(
torch.sum(self.seq_lens), ctx.top_k_num, op_type
), f"Expected {num_tokens} tokens, but got {torch.sum(self.seq_lens)}"
num_seqs = self.seq_lens.shape[0] num_seqs = self.seq_lens.shape[0]
# assert self.seq_start_loc.shape[0] == num_seqs # assert self.seq_start_loc.shape[0] == num_seqs
## In down shrink case, each prompt corresponds to top_k_num sequences
assert self.prompt_lora_mapping.shape[0] == num_seqs assert self.prompt_lora_mapping.shape[0] == num_seqs
assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens assert self.get_num_tokens(
self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type
)
def to_device(self, device: str): def to_device(self, device: str):
""" """
...@@ -470,21 +603,111 @@ class BenchmarkTensors: ...@@ -470,21 +603,111 @@ class BenchmarkTensors:
to_device(field) if field_name != "no_lora_flag_cpu" else field, to_device(field) if field_name != "no_lora_flag_cpu" else field,
) )
def metadata(self) -> tuple[int, int, int]: def metadata(self, ctx: BenchmarkContext, op_type: OpType) -> tuple[int, int, int]:
""" """
Return num_seqs, num_tokens and max_seq_len Return num_seqs, num_tokens and max_seq_len
""" """
num_seqs = self.seq_lens.shape[0] num_seqs = self.seq_lens.shape[0]
num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0] num_tokens = self.get_num_tokens(
self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type
)
max_seq_len = torch.max(self.seq_lens).item() max_seq_len = torch.max(self.seq_lens).item()
num_slices = len(self.lora_weights_lst) num_slices = len(self.lora_weights_lst)
return num_seqs, num_tokens, max_seq_len, num_slices return num_seqs, num_tokens, max_seq_len, num_slices
def as_lora_shrink_kwargs(self) -> dict[str, Any]: def fused_moe_lora_data_prepare(
self.sanity_check() self,
block_size: int,
token_lora_mapping: torch.Tensor,
ctx: BenchmarkContext,
):
def moe_lora_align_block_size(
topk_ids: torch.Tensor,
token_lora_mapping: torch.Tensor,
block_size: int,
num_experts: int,
max_loras: int,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
sorted_ids = torch.empty(
(max_loras * max_num_tokens_padded,),
dtype=torch.int32,
device=topk_ids.device,
)
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be set default to -1 to prevent a blank block
expert_ids = torch.empty(
(max_loras * max_num_m_blocks,),
dtype=torch.int32,
device=topk_ids.device,
)
num_tokens_post_pad = torch.empty(
(max_loras), dtype=torch.int32, device=topk_ids.device
)
ops.moe_lora_align_block_size(
topk_ids,
token_lora_mapping,
num_experts,
block_size,
max_loras,
max_num_tokens_padded,
max_num_m_blocks,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad
num_tokens = ctx.batch_size
curr_topk_ids = torch.randint(
0,
ctx.num_experts,
(num_tokens, ctx.top_k_num),
device="cuda",
dtype=torch.int32,
)
topk_weights = torch.randint(
0,
ctx.num_experts,
(num_tokens, ctx.top_k_num),
device="cuda",
dtype=torch.int32,
)
(sorted_token_ids_lora, expert_ids_lora, num_tokens_post_padded_lora) = (
moe_lora_align_block_size(
topk_ids=curr_topk_ids,
token_lora_mapping=token_lora_mapping,
block_size=block_size,
num_experts=ctx.num_experts,
max_loras=ctx.num_loras,
)
)
sorted_token_ids = sorted_token_ids_lora.view(ctx.num_loras, -1)
expert_ids = expert_ids_lora.view(ctx.num_loras, -1)
num_tokens_post_padded = num_tokens_post_padded_lora
return (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded)
def as_lora_shrink_kwargs(
self, ctx: BenchmarkContext, op_type: OpType
) -> dict[str, Any]:
self.sanity_check(ctx, op_type)
self.to_device(self.input.device) self.to_device(self.input.device)
_, num_tokens, _, num_slices = self.metadata() _, num_tokens, _, num_slices = self.metadata(ctx, op_type)
# Sanity check matrix shapes. # Sanity check matrix shapes.
i_shape, lw_shape, o_shape = ( i_shape, lw_shape, o_shape = (
...@@ -519,11 +742,13 @@ class BenchmarkTensors: ...@@ -519,11 +742,13 @@ class BenchmarkTensors:
"no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
} }
def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: def as_lora_expand_kwargs(
self.sanity_check() self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool
) -> dict[str, Any]:
self.sanity_check(ctx, op_type)
self.to_device(self.input.device) self.to_device(self.input.device)
_, num_tokens, _, num_slices = self.metadata() _, num_tokens, _, num_slices = self.metadata(ctx, op_type)
# Sanity check matrix shapes. # Sanity check matrix shapes.
i_shape, lw_shape, o_shape = ( i_shape, lw_shape, o_shape = (
...@@ -560,22 +785,177 @@ class BenchmarkTensors: ...@@ -560,22 +785,177 @@ class BenchmarkTensors:
"no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
} }
def as_fused_moe_lora_shrink_kwargs(
self, ctx: BenchmarkContext, op_type: OpType
) -> dict[str, Any]:
self.sanity_check(ctx, op_type)
self.to_device(self.input.device)
_, num_tokens, _, num_slices = self.metadata(ctx, op_type)
# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = (
self.input.shape,
self.lora_weights_lst[0].shape,
self.output.shape,
)
# Expected input shape : [num_tokens, hidden_size] for gate_up
# Expected input shape : [top_k_num * num_tokens, hidden_size] for down
assert len(i_shape) == 2
assert i_shape[0] == num_tokens
hidden_size = i_shape[1]
# Expected lora weight shape [max_lora, num_experts, lora_rank, hidden_size]
assert len(lw_shape) == 4
assert lw_shape[-1] == hidden_size
lora_rank = lw_shape[-2]
# Expected output shape : [num_slices, num_tokens, top_k_num, lora_rank]
assert len(o_shape) == 4
assert (
o_shape
== (num_slices, num_tokens // ctx.top_k_num, ctx.top_k_num, lora_rank)
if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK]
else o_shape == (num_slices, num_tokens, ctx.top_k_num, lora_rank)
)
kernel_config = get_lora_op_configs(
op_type.name.lower(),
max_loras=lw_shape[0],
batch=num_tokens,
hidden_size=hidden_size,
rank=lora_rank,
num_slices=num_slices,
add_inputs=False,
)
(topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = (
self.fused_moe_lora_data_prepare(
block_size=kernel_config["BLOCK_SIZE_M"],
token_lora_mapping=self.lora_kernel_meta.token_lora_mapping,
ctx=ctx,
)
)
return {
"qcurr_hidden_states": self.input,
"lora_a_stacked": self.lora_weights_lst,
"a_intermediate_cache1": self.output,
"topk_weights": topk_weights,
"sorted_token_ids": sorted_token_ids,
"expert_ids": expert_ids,
"num_tokens_post_padded": num_tokens_post_padded,
"top_k_num": ctx.top_k_num,
"device": self.input.device,
"N": lora_rank,
"M": topk_weights.shape[0],
"EM": sorted_token_ids.shape[1],
"K": self.input.shape[1],
"num_tokens": num_tokens,
"num_experts": ctx.num_experts,
"num_slices": num_slices,
"shrink_block_size_m": kernel_config["BLOCK_SIZE_M"],
"shrink_block_size_n": kernel_config["BLOCK_SIZE_N"],
"shrink_block_size_k": kernel_config["BLOCK_SIZE_K"],
"shrink_group_size_m": kernel_config["GROUP_SIZE_M"],
"shrink_num_warps": kernel_config["NUM_WARPS"],
"shrink_num_stages": kernel_config["NUM_STAGES"],
"shrink_split_k": kernel_config.get("SPLIT_K", 1),
"mul_routed_weight": op_type.is_fused_moe_lora_down_fn(),
}
def as_fused_moe_lora_expand_kwargs(
self, ctx: BenchmarkContext, op_type: OpType
) -> dict[str, Any]:
self.sanity_check(ctx, op_type)
self.to_device(self.input.device)
_, num_tokens, _, num_slices = self.metadata(ctx, op_type)
# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = (
self.input.shape,
self.lora_weights_lst[0].shape,
self.output.shape,
)
# Expected input shape : [num_slices, num_tokens, top_k_num, lora_rank]
assert len(i_shape) == 4
assert i_shape[0] == num_slices
assert i_shape[1] == num_tokens
lora_rank = i_shape[-1]
# Expected lora weight shape : [num_loras, num_experts, hidden_size, lora_rank]
assert len(lw_shape) == 4
assert lw_shape[-1] == lora_rank
hidden_size = lw_shape[-2]
# Expected output shape : [num_tokens, top_k_num, hidden_size * num_slices]
assert len(o_shape) == 3
assert o_shape == (num_tokens, ctx.top_k_num, hidden_size * num_slices)
kernel_config = get_lora_op_configs(
op_type.name.lower(),
max_loras=lw_shape[0],
batch=num_tokens,
hidden_size=hidden_size,
rank=lora_rank,
num_slices=num_slices,
add_inputs=False,
)
(topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = (
self.fused_moe_lora_data_prepare(
block_size=kernel_config["BLOCK_SIZE_M"],
token_lora_mapping=self.lora_kernel_meta.token_lora_mapping,
ctx=ctx,
)
)
return {
"a_intermediate_cache1": self.input,
"lora_b_stacked": self.lora_weights_lst,
"output": self.output,
"topk_weights": topk_weights,
"sorted_token_ids": sorted_token_ids,
"expert_ids": expert_ids,
"num_tokens_post_padded": num_tokens_post_padded,
"top_k_num": ctx.top_k_num,
"device": self.input.device,
"N": lora_rank,
"M": topk_weights.shape[0],
"EM": sorted_token_ids.shape[1],
"K": self.input.shape[1],
"num_tokens": num_tokens,
"num_experts": ctx.num_experts,
"num_slices": num_slices,
"max_lora_rank": lora_rank,
"w1_output_dim_size": lw_shape[2],
"expand_block_size_m": kernel_config["BLOCK_SIZE_M"],
"expand_block_size_n": kernel_config["BLOCK_SIZE_N"],
"expand_block_size_k": kernel_config["BLOCK_SIZE_K"],
"expand_group_size_m": kernel_config["GROUP_SIZE_M"],
"expand_num_warps": kernel_config["NUM_WARPS"],
"expand_num_stages": kernel_config["NUM_STAGES"],
"expand_split_k": kernel_config.get("SPLIT_K", 1),
"mul_routed_weight": op_type.is_fused_moe_lora_down_fn(),
}
def bench_fn_kwargs( def bench_fn_kwargs(
self, op_type: OpType, add_inputs: Optional[bool] = None self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool | None = None
) -> dict[str, Any]: ) -> dict[str, Any]:
if op_type.is_shrink_fn(): if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn():
assert add_inputs is None assert add_inputs is None
else: else:
assert add_inputs is not None assert add_inputs is not None
if op_type == OpType.LORA_SHRINK: if op_type == OpType.LORA_SHRINK:
return self.as_lora_shrink_kwargs() return self.as_lora_shrink_kwargs(ctx, op_type)
if op_type == OpType.LORA_EXPAND: if op_type == OpType.LORA_EXPAND:
return self.as_lora_expand_kwargs(add_inputs) return self.as_lora_expand_kwargs(ctx, op_type, add_inputs)
if op_type.is_fused_moe_lora_shrink_fn():
return self.as_fused_moe_lora_shrink_kwargs(ctx, op_type)
if op_type.is_fused_moe_lora_expand_fn():
return self.as_fused_moe_lora_expand_kwargs(ctx, op_type)
raise ValueError(f"Unrecognized optype {self}") raise ValueError(f"Unrecognized optype {self}")
def test_correctness( def test_correctness(
self, op_type: OpType, expand_fn_add_inputs: Optional[bool] self, op_type: OpType, expand_fn_add_inputs: bool | None
) -> bool: ) -> bool:
""" """
Test correctness of op_type implementation against a grouped gemm Test correctness of op_type implementation against a grouped gemm
...@@ -611,12 +991,12 @@ def bench_optype( ...@@ -611,12 +991,12 @@ def bench_optype(
ctx: BenchmarkContext, ctx: BenchmarkContext,
arg_pool_size: int, arg_pool_size: int,
op_type: OpType, op_type: OpType,
cuda_graph_nops: Optional[int] = None, cuda_graph_nops: int | None = None,
expand_fn_add_inputs: Optional[bool] = None, expand_fn_add_inputs: bool | None = None,
test_correctness: bool = False, test_correctness: bool = False,
) -> TMeasurement: ) -> TMeasurement:
assert arg_pool_size >= 1 assert arg_pool_size >= 1
if op_type.is_shrink_fn(): if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn():
assert expand_fn_add_inputs is None assert expand_fn_add_inputs is None
else: else:
assert expand_fn_add_inputs is not None assert expand_fn_add_inputs is not None
...@@ -626,23 +1006,30 @@ def bench_optype( ...@@ -626,23 +1006,30 @@ def bench_optype(
BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size) BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)
] ]
for bt in bench_tensors: for bt in bench_tensors:
bt.sanity_check() bt.sanity_check(ctx, op_type)
# Test correctness of our implementation. # Test correctness of our implementation.
if test_correctness: if test_correctness:
assert op_type in [OpType.LORA_SHRINK, OpType.LORA_EXPAND], (
f"Correctness testing is not supported for {op_type.name}."
)
assert all( assert all(
[bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors] [
bt.test_correctness(ctx, op_type, expand_fn_add_inputs)
for bt in bench_tensors
]
) )
# BenchmarkTensors -> dict (kwargs) # BenchmarkTensors -> dict (kwargs)
kwargs_list = [ kwargs_list = [
bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs) bt.bench_fn_kwargs(ctx, op_type, add_inputs=expand_fn_add_inputs)
for bt in bench_tensors for bt in bench_tensors
] ]
# Clear LoRA optimization hash-maps. # Clear LoRA optimization hash-maps.
_LORA_A_PTR_DICT.clear() _LORA_A_PTR_DICT.clear()
_LORA_B_PTR_DICT.clear() _LORA_B_PTR_DICT.clear()
_LORA_PTR_DICT.clear()
# Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up
for kwargs in kwargs_list: for kwargs in kwargs_list:
op_type.bench_fn()(**kwargs) op_type.bench_fn()(**kwargs)
...@@ -679,7 +1066,7 @@ def bench_torch_mm( ...@@ -679,7 +1066,7 @@ def bench_torch_mm(
ctx: BenchmarkContext, ctx: BenchmarkContext,
arg_pool_size: int, arg_pool_size: int,
op_type: OpType, op_type: OpType,
cuda_graph_nops: Optional[int] = None, cuda_graph_nops: int | None = None,
) -> TMeasurement: ) -> TMeasurement:
""" """
Benchmark basic torch.mm as a roofline. Benchmark basic torch.mm as a roofline.
...@@ -744,7 +1131,7 @@ def use_cuda_graph_recommendation() -> str: ...@@ -744,7 +1131,7 @@ def use_cuda_graph_recommendation() -> str:
""" """
def print_timers(timers: list[TMeasurement], args: Optional[argparse.Namespace] = None): def print_timers(timers: list[TMeasurement], args: argparse.Namespace | None = None):
compare = TBenchmark.Compare(timers) compare = TBenchmark.Compare(timers)
compare.print() compare.print()
...@@ -792,7 +1179,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): ...@@ -792,7 +1179,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
# Benchmark bench_op # Benchmark bench_op
expand_fn_add_inputs = ( expand_fn_add_inputs = (
[None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs [None]
if bench_op.is_shrink_fn() or bench_op.is_fused_moe_lora_fn()
else args.expand_fn_add_inputs
) )
for add_input_arg in expand_fn_add_inputs: for add_input_arg in expand_fn_add_inputs:
seq_len_timers.append( seq_len_timers.append(
...@@ -830,12 +1219,22 @@ def as_benchmark_contexts( ...@@ -830,12 +1219,22 @@ def as_benchmark_contexts(
hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace
) -> list[BenchmarkContext]: ) -> list[BenchmarkContext]:
ctxs: list[BenchmarkContext] = [] ctxs: list[BenchmarkContext] = []
for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa for (
batch_size,
hidden_size,
lora_rank,
num_loras,
sort_by_lora_id,
top_k_num,
num_experts,
) in product( # noqa
args.batch_sizes, args.batch_sizes,
list(hidden_sizes), list(hidden_sizes),
lora_ranks, lora_ranks,
args.num_loras, args.num_loras,
args.sort_by_lora_id, args.sort_by_lora_id,
args.top_k_nums,
args.num_experts,
): ):
ctxs.append( ctxs.append(
BenchmarkContext( BenchmarkContext(
...@@ -850,6 +1249,8 @@ def as_benchmark_contexts( ...@@ -850,6 +1249,8 @@ def as_benchmark_contexts(
seq_length=None, seq_length=None,
sort_by_lora_id=sort_by_lora_id, sort_by_lora_id=sort_by_lora_id,
dtype=args.dtype, dtype=args.dtype,
top_k_num=top_k_num,
num_experts=num_experts,
# To be filled based on the OpType to benchmark # To be filled based on the OpType to benchmark
num_slices=None, num_slices=None,
) )
...@@ -1011,6 +1412,22 @@ if __name__ == "__main__": ...@@ -1011,6 +1412,22 @@ if __name__ == "__main__":
), ),
) )
p.add_argument(
"--top-k-nums",
nargs="+",
type=int,
default=DEFAULT_TOP_K_NUMS,
help="Top-K values for MoE LoRA operations",
)
p.add_argument(
"--num-experts",
nargs="+",
type=int,
default=DEFAULT_NUM_EXPERTS,
help="Number of experts for MoE LoRA operations",
)
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description=f""" description=f"""
Benchmark LoRA kernels: Benchmark LoRA kernels:
......
...@@ -8,10 +8,9 @@ import math ...@@ -8,10 +8,9 @@ import math
import os import os
import pickle as pkl import pickle as pkl
import time import time
from collections.abc import Iterable from collections.abc import Callable, Iterable
from dataclasses import dataclass from dataclasses import dataclass
from itertools import product from itertools import product
from typing import Callable, Optional
import pandas as pd import pandas as pd
import torch import torch
...@@ -34,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -34,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights, quantize_weights,
) )
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"] DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024] DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
...@@ -63,23 +62,23 @@ class BenchmarkTensors: ...@@ -63,23 +62,23 @@ class BenchmarkTensors:
a: torch.Tensor a: torch.Tensor
w_q: torch.Tensor w_q: torch.Tensor
group_size: Optional[int] group_size: int | None
wtype: ScalarType wtype: ScalarType
w_g_s: torch.Tensor w_g_s: torch.Tensor
w_g_zp: Optional[torch.Tensor] w_g_zp: torch.Tensor | None
w_ch_s: Optional[torch.Tensor] w_ch_s: torch.Tensor | None
w_tok_s: Optional[torch.Tensor] w_tok_s: torch.Tensor | None
@dataclass @dataclass
class TypeConfig: class TypeConfig:
act_type: torch.dtype act_type: torch.dtype
weight_type: ScalarType weight_type: ScalarType
output_type: Optional[torch.dtype] output_type: torch.dtype | None
group_scale_type: Optional[torch.dtype] group_scale_type: torch.dtype | None
group_zero_type: Optional[torch.dtype] group_zero_type: torch.dtype | None
channel_scale_type: Optional[torch.dtype] channel_scale_type: torch.dtype | None
token_scale_type: Optional[torch.dtype] token_scale_type: torch.dtype | None
def rand_data(shape, dtype=torch.float16, scale=1): def rand_data(shape, dtype=torch.float16, scale=1):
...@@ -93,8 +92,8 @@ def quantize_and_pack( ...@@ -93,8 +92,8 @@ def quantize_and_pack(
atype: torch.dtype, atype: torch.dtype,
w: torch.Tensor, w: torch.Tensor,
wtype: ScalarType, wtype: ScalarType,
stype: Optional[torch.dtype], stype: torch.dtype | None,
group_size: Optional[int], group_size: int | None,
zero_points: bool = False, zero_points: bool = False,
): ):
assert wtype.is_integer(), "TODO: support floating point weights" assert wtype.is_integer(), "TODO: support floating point weights"
...@@ -113,7 +112,7 @@ def quantize_and_pack( ...@@ -113,7 +112,7 @@ def quantize_and_pack(
def create_bench_tensors( def create_bench_tensors(
shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int] shape: tuple[int, int, int], types: TypeConfig, group_size: int | None
) -> list[BenchmarkTensors]: ) -> list[BenchmarkTensors]:
m, n, k = shape m, n, k = shape
...@@ -238,6 +237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: ...@@ -238,6 +237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
b_q_weight=w_q, b_q_weight=w_q,
b_bias=None, b_bias=None,
b_scales=w_s, b_scales=w_s,
a_scales=None,
global_scale=None, global_scale=None,
b_zeros=w_zp, b_zeros=w_zp,
g_idx=g_idx, g_idx=g_idx,
...@@ -331,8 +331,8 @@ def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable]) ...@@ -331,8 +331,8 @@ def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable])
return res return res
_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None _SWEEP_SCHEDULES_RESULTS: pd.DataFrame | None = None
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None _SWEEP_SCHEDULES_RESULTS_CSV: str | None = None
def bench( def bench(
......
...@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
sort_weights, sort_weights,
) )
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
...@@ -263,7 +263,7 @@ def bench_run( ...@@ -263,7 +263,7 @@ def bench_run(
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
...@@ -273,7 +273,7 @@ def bench_run( ...@@ -273,7 +273,7 @@ def bench_run(
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
......
...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.transformers_utils.config import get_config from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
# 移除全局的 current_platform 导入,改为在需要时局部导入 # 移除全局的 current_platform 导入,改为在需要时局部导入
# FP8_DTYPE = current_platform.fp8_dtype() # FP8_DTYPE = current_platform.fp8_dtype()
...@@ -228,8 +228,8 @@ def benchmark_config( ...@@ -228,8 +228,8 @@ def benchmark_config(
# run() # run()
torch.cuda.synchronize() torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
...@@ -253,10 +253,10 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False): ...@@ -253,10 +253,10 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
block_k_range = [32, 64, 128, 256] block_k_range = [32, 64, 128, 256]
if not use_fp16: if not use_fp16:
block_k_range.remove(16) # BLOCK_K=16 not supported for fp8 block_k_range.remove(16) # BLOCK_K=16 not supported for fp8
num_warps_range = [2, 4, 8] num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 16, 32, 64] group_m_range = [1, 4, 8, 16, 32]
num_stage_range = [2, 3, 4, 5] num_stage_range = [2]
# waves_per_eu_range = [0] # waves_per_eu_range = [0, 1, 2, 4]
# matrix_instr_nonkdim_range = [16, 32] if use_fp16 else [] # matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
# kpack_range = [1, 2] if use_fp16 else [] # kpack_range = [1, 2] if use_fp16 else []
...@@ -669,19 +669,23 @@ def main(args: argparse.Namespace): ...@@ -669,19 +669,23 @@ def main(args: argparse.Namespace):
E = config.ffn_config.moe_num_experts E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size intermediate_size = config.ffn_config.ffn_hidden_size
hidden_size = config.hidden_size
elif config.architectures[0] == "JambaForCausalLM": elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts E = config.num_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] in ( elif config.architectures[0] in (
"DeepseekV2ForCausalLM", "DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM", "DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM", "DeepseekV32ForCausalLM",
"Glm4MoeForCausalLM", "Glm4MoeForCausalLM",
"NemotronHForCausalLM",
): ):
E = config.n_routed_experts E = config.n_routed_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] in ( elif config.architectures[0] in (
"Qwen2MoeForCausalLM", "Qwen2MoeForCausalLM",
"Qwen3MoeForCausalLM", "Qwen3MoeForCausalLM",
...@@ -690,14 +694,27 @@ def main(args: argparse.Namespace): ...@@ -690,14 +694,27 @@ def main(args: argparse.Namespace):
E = config.num_experts E = config.num_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] == "Qwen3VLMoeForConditionalGeneration":
text_config = config.get_text_config()
E = text_config.num_experts
topk = text_config.num_experts_per_tok
intermediate_size = text_config.moe_intermediate_size
hidden_size = text_config.hidden_size
elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"): elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
E = config.num_experts E = config.num_experts
topk = config.moe_topk[0] topk = config.moe_topk[0]
intermediate_size = config.moe_intermediate_size[0] intermediate_size = config.moe_intermediate_size[0]
hidden_size = config.hidden_size
elif config.architectures[0] in ("Step3VLForConditionalGeneration"): elif config.architectures[0] in ("Step3VLForConditionalGeneration"):
E = config.text_config.moe_num_experts E = config.text_config.moe_num_experts
topk = config.text_config.moe_top_k topk = config.text_config.moe_top_k
intermediate_size = config.text_config.moe_intermediate_size intermediate_size = config.text_config.moe_intermediate_size
elif config.architectures[0] in ["Qwen3OmniMoeForConditionalGeneration"]:
E = config.thinker_config.text_config.num_experts
topk = config.thinker_config.text_config.num_experts_per_tok
intermediate_size = config.thinker_config.text_config.moe_intermediate_size
hidden_size = config.thinker_config.text_config.hidden_size
else: else:
# Support for llama4 # Support for llama4
config = config.get_text_config() config = config.get_text_config()
...@@ -705,16 +722,16 @@ def main(args: argparse.Namespace): ...@@ -705,16 +722,16 @@ def main(args: argparse.Namespace):
E = config.num_local_experts E = config.num_local_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
hidden_size = config.hidden_size
enable_ep = bool(args.enable_expert_parallel) enable_ep = bool(args.enable_expert_parallel)
if enable_ep: if enable_ep:
ensure_divisibility(E, tp_size, "Number of experts") ensure_divisibility(E, tp_size, "Number of experts")
E = E // tp_size E = E // tp_size
shard_intermediate_size = 2 * intermediate_size shard_intermediate_size = 2 * intermediate_size
else: else:
ensure_divisibility(intermediate_size, tp_size, "intermediate_size") ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
shard_intermediate_size = 2 * intermediate_size // tp_size shard_intermediate_size = 2 * intermediate_size // tp_size
hidden_size = config.hidden_size dtype = torch.float16 if current_platform.is_rocm() else config.dtype
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
block_quant_shape = get_weight_block_size_safety(config) block_quant_shape = get_weight_block_size_safety(config)
......
...@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( ...@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
) )
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
...@@ -105,8 +105,8 @@ def benchmark_permute( ...@@ -105,8 +105,8 @@ def benchmark_permute(
graph.replay() graph.replay()
torch.cuda.synchronize() torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
...@@ -241,8 +241,8 @@ def benchmark_unpermute( ...@@ -241,8 +241,8 @@ def benchmark_unpermute(
graph.replay() graph.replay()
torch.cuda.synchronize() torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
...@@ -344,7 +344,7 @@ def main(args: argparse.Namespace): ...@@ -344,7 +344,7 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
hidden_size = config.hidden_size hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype dtype = torch.float16 if current_platform.is_rocm() else config.dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
use_customized_permute = args.use_customized_permute use_customized_permute = args.use_customized_permute
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
# #
# The CSV file (named with current date/time) contains these columns: # The CSV file (named with current date/time) contains these columns:
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position, # model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
# rope_theta, is_neox_style, rope_scaling, dtype, torch_mean, torch_median, torch_p99, # is_neox_style, rope_parameters, dtype, torch_mean, torch_median, torch_p99,
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max, # torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
# speedup # speedup
# #
...@@ -39,7 +39,7 @@ import torch ...@@ -39,7 +39,7 @@ import torch
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config from vllm.transformers_utils.config import get_config
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
...@@ -86,9 +86,8 @@ def benchmark_mrope( ...@@ -86,9 +86,8 @@ def benchmark_mrope(
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
max_position: int = 8192, max_position: int = 8192,
rope_theta: float = 10000,
is_neox_style: bool = True, is_neox_style: bool = True,
rope_scaling: dict[str, Any] = None, rope_parameters: dict[str, Any] | None = None,
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
seed: int = 0, seed: int = 0,
warmup_iter: int = 10, warmup_iter: int = 10,
...@@ -102,9 +101,8 @@ def benchmark_mrope( ...@@ -102,9 +101,8 @@ def benchmark_mrope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim, rotary_dim=head_dim,
max_position=max_position, max_position=max_position,
base=rope_theta,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
rope_scaling=rope_scaling, rope_parameters=rope_parameters,
dtype=dtype, dtype=dtype,
).to(device=device) ).to(device=device)
...@@ -203,9 +201,8 @@ def benchmark_mrope( ...@@ -203,9 +201,8 @@ def benchmark_mrope(
num_kv_heads, num_kv_heads,
head_dim, head_dim,
max_position, max_position,
rope_theta,
is_neox_style, is_neox_style,
str(rope_scaling), str(rope_parameters),
str(dtype).split(".")[-1], str(dtype).split(".")[-1],
torch_stats["mean"], torch_stats["mean"],
torch_stats["median"], torch_stats["median"],
...@@ -255,9 +252,8 @@ if __name__ == "__main__": ...@@ -255,9 +252,8 @@ if __name__ == "__main__":
"num_kv_heads", "num_kv_heads",
"head_dim", "head_dim",
"max_position", "max_position",
"rope_theta",
"is_neox_style", "is_neox_style",
"rope_scaling", "rope_parameters",
"dtype", "dtype",
"torch_mean", "torch_mean",
"torch_median", "torch_median",
...@@ -303,7 +299,7 @@ if __name__ == "__main__": ...@@ -303,7 +299,7 @@ if __name__ == "__main__":
q_size = num_heads * head_dim q_size = num_heads * head_dim
kv_size = num_kv_heads * head_dim kv_size = num_kv_heads * head_dim
is_neox_style = True is_neox_style = True
rope_theta = config.rope_theta rope_parameters = config.rope_parameters
max_position = config.max_position_embeddings max_position = config.max_position_embeddings
for num_tokens in num_tokens_list: for num_tokens in num_tokens_list:
...@@ -315,9 +311,8 @@ if __name__ == "__main__": ...@@ -315,9 +311,8 @@ if __name__ == "__main__":
num_heads=num_heads, num_heads=num_heads,
num_kv_heads=num_kv_heads, num_kv_heads=num_kv_heads,
max_position=max_position, max_position=max_position,
rope_theta=rope_theta,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
rope_scaling=config.rope_scaling, rope_parameters=rope_parameters,
dtype=getattr(torch, args.dtype), dtype=getattr(torch, args.dtype),
seed=args.seed, seed=args.seed,
warmup_iter=args.warmup_iter, warmup_iter=args.warmup_iter,
......
...@@ -3,20 +3,18 @@ ...@@ -3,20 +3,18 @@
import random import random
import time import time
from typing import Optional
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
import vllm.envs as envs from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
from vllm.utils import (
STR_DTYPE_TO_TORCH_DTYPE, STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random, create_kv_caches_with_random,
) )
import vllm.envs as envs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -39,7 +37,7 @@ def main( ...@@ -39,7 +37,7 @@ def main(
seed: int, seed: int,
do_profile: bool, do_profile: bool,
device: str = "cuda", device: str = "cuda",
kv_cache_dtype: Optional[str] = None, kv_cache_dtype: str | None = None,
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
import argparse import argparse
import math import math
from collections.abc import Callable
from contextlib import contextmanager from contextlib import contextmanager
from typing import Callable
from unittest.mock import patch from unittest.mock import patch
import torch import torch
...@@ -30,8 +30,8 @@ def _time_cuda( ...@@ -30,8 +30,8 @@ def _time_cuda(
fn() fn()
torch.cuda.synchronize() torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True) start = torch.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end = torch.Event(enable_timing=True)
start.record() start.record()
for _ in range(bench_iters): for _ in range(bench_iters):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import torch
from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton
def polynorm_naive(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
def norm(x, eps: float):
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
x = x.float()
return (
(
weight[0] * norm(x**3, eps)
+ weight[1] * norm(x**2, eps)
+ weight[2] * norm(x, eps)
+ bias
)
.to(weight.dtype)
.view(orig_shape)
)
def polynorm_vllm(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
out = torch.empty_like(x)
vllm_ops.poly_norm(out, x, weight, bias, eps)
output = out
output = output.view(orig_shape)
return output
def calculate_diff(batch_size, seq_len, hidden_dim):
dtype = torch.bfloat16
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
weight = torch.ones(3, dtype=dtype, device="cuda")
bias = torch.ones(1, dtype=dtype, device="cuda")
output_naive = polynorm_naive(x, weight, bias)
output_vllm = polynorm_vllm(x, weight, bias)
if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
dim_range = [2048, 4096]
configs = list(itertools.product(dim_range, batch_size_range, seq_length_range))
def get_benchmark():
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["dim", "batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["naive", "vllm"],
line_names=["Naive", "vLLM"],
styles=[("blue", "-"), ("red", "-")],
ylabel="us",
plot_name="polynorm-perf",
args={},
)
)
def benchmark(dim, batch_size, seq_len, provider):
dtype = torch.bfloat16
hidden_dim = dim * 4
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
weight = torch.ones(3, dtype=dtype, device="cuda")
bias = torch.ones(1, dtype=dtype, device="cuda")
quantiles = [0.5, 0.2, 0.8]
if provider == "naive":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: polynorm_naive(x, weight, bias),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: polynorm_vllm(x, weight, bias),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--batch-size",
type=int,
default=4,
help="Batch size",
)
parser.add_argument(
"--seq-len",
type=int,
default=128,
help="Sequence length",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=8192,
help="Intermediate size of MLP",
)
parser.add_argument(
"--save-path",
type=str,
default="./configs/polnorm/",
help="Path to save polnorm benchmark results",
)
args = parser.parse_args()
# Run correctness test
calculate_diff(
batch_size=args.batch_size,
seq_len=args.seq_len,
hidden_dim=args.hidden_dim,
)
benchmark = get_benchmark()
# Run performance benchmark
benchmark.run(print_data=True, save_path=args.save_path)
...@@ -7,7 +7,8 @@ import torch ...@@ -7,7 +7,8 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
@torch.inference_mode() @torch.inference_mode()
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import time
import torch
from tabulate import tabulate
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
create_kv_caches_with_random,
)
logger = init_logger(__name__)
@torch.inference_mode()
def run_benchmark(
num_tokens: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
kv_cache_dtype: str,
num_iters: int,
benchmark_mode: str,
device: str = "cuda",
) -> float:
"""Return latency (seconds) for given num_tokens."""
if kv_cache_dtype == "fp8" and head_size % 16:
raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")
current_platform.seed_everything(42)
torch.set_default_device(device)
# create random key / value tensors [T, H, D].
key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
value = torch.randn_like(key)
# prepare the slot mapping.
# each token is assigned a unique slot in the KV-cache.
num_slots = block_size * num_blocks
if num_tokens > num_slots:
raise ValueError("num_tokens cannot exceed the total number of cache slots")
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
key_caches, value_caches = create_kv_caches_with_random(
num_blocks,
block_size,
1, # num_layers
num_heads,
head_size,
kv_cache_dtype,
dtype,
device=device,
)
key_cache, value_cache = key_caches[0], value_caches[0]
# to free unused memory
del key_caches, value_caches
# compute per-kernel scaling factors for fp8 conversion (if used).
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)
function_under_test = lambda: ops.reshape_and_cache(
key, # noqa: F821
value, # noqa: F821
key_cache, # noqa: F821
value_cache, # noqa: F821
slot_mapping, # noqa: F821
kv_cache_dtype,
k_scale,
v_scale,
)
if benchmark_mode == "cudagraph":
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
function_under_test()
torch.cuda.synchronize()
function_under_test = lambda: g.replay()
def run_cuda_benchmark(n_iters: int) -> float:
nonlocal key, value, key_cache, value_cache, slot_mapping
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_iters):
function_under_test()
torch.cuda.synchronize()
end = time.perf_counter()
return (end - start) / n_iters
# warm-up
run_cuda_benchmark(3)
lat = run_cuda_benchmark(num_iters)
# free tensors to mitigate OOM when sweeping
del key, value, key_cache, value_cache, slot_mapping
torch.cuda.empty_cache()
return lat
def main(args):
rows = []
for exp in range(1, 17):
n_tok = 2**exp
lat = run_benchmark(
num_tokens=n_tok,
num_heads=args.num_heads,
head_size=args.head_size,
block_size=args.block_size,
num_blocks=args.num_blocks,
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
kv_cache_dtype=args.kv_cache_dtype,
num_iters=args.iters,
benchmark_mode=args.mode,
device="cuda",
)
rows.append([n_tok, lat * 1e6]) # convert to microseconds
print(f"Benchmark results for implementation cuda (measuring with {args.mode}):")
print(tabulate(rows, headers=["num_tokens", "latency (µs)"], floatfmt=".3f"))
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--num-heads", type=int, default=128)
parser.add_argument(
"--head-size",
type=int,
choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128,
)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--num-blocks", type=int, default=128 * 128)
parser.add_argument(
"--dtype",
type=str,
choices=["half", "bfloat16", "float"],
default="bfloat16",
)
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8"],
default="auto",
)
parser.add_argument("--iters", type=int, default=200)
parser.add_argument(
"--mode",
type=str,
choices=["cudagraph", "no_graph"],
default="cudagraph",
)
args = parser.parse_args()
main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment