Commit 006693ed authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.11.2' into v0.11.2-ori

parents 4b51e6f1 275de341
......@@ -7,7 +7,8 @@ import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
@torch.inference_mode()
......
......@@ -6,11 +6,12 @@ import copy
import json
import pickle
import time
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum, auto
from itertools import product
from pathlib import Path
from typing import Any, Callable, Optional
from typing import Any
import torch
import torch.utils.benchmark as TBenchmark
......@@ -18,13 +19,24 @@ from torch.utils.benchmark import Measurement as TMeasurement
from utils import ArgPool, Bench, CudaGraphBenchParams
from weight_shapes import WEIGHT_SHAPES
from vllm.triton_utils import HAS_TRITON
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
from vllm.triton_utils import HAS_TRITON, triton
if HAS_TRITON:
from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
from vllm.lora.ops.triton_ops import ( ## added fused_moe_lora
LoRAKernelMeta,
fused_moe_lora_expand,
fused_moe_lora_shrink,
lora_expand,
lora_shrink,
)
from vllm.lora.ops.triton_ops.fused_moe_lora_op import (
_LORA_PTR_DICT, ## added _LORA_PTR_DICT for fused_moe_lora
)
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.utils import FlexibleArgumentParser
from vllm import _custom_ops as ops
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.math_utils import round_up
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_TP_SIZES = [1]
......@@ -58,6 +70,8 @@ DEFAULT_NUM_LORAS = [1, 2, 3, 4]
DEFAULT_SORT_BY_LORA_IDS = [False, True]
DEFAULT_SEQ_LENGTHS = [1]
DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False]
DEFAULT_TOP_K_NUMS = [1] # Added for MoE LoRA top_k
DEFAULT_NUM_EXPERTS = [8] # Added for MoE LoRA num_experts
# Utilities
......@@ -158,7 +172,7 @@ def ref_group_gemm(
seq_lens_cpu: torch.Tensor,
prompt_lora_mapping_cpu: torch.Tensor,
scaling: float,
add_inputs: Optional[bool],
add_inputs: bool | None,
):
"""
Torch group gemm reference implementation to test correctness of
......@@ -190,6 +204,11 @@ class OpType(Enum):
LORA_SHRINK = auto()
LORA_EXPAND = auto()
## Adding support for fused moe lora
FUSED_MOE_LORA_GATE_UP_SHRINK = auto() ## Gate/Up projection variant with shrink
FUSED_MOE_LORA_GATE_UP_EXPAND = auto() ## Gate/Up projection variant with expand
FUSED_MOE_LORA_DOWN_SHRINK = auto() ## Down projection variant with shrink
FUSED_MOE_LORA_DOWN_EXPAND = auto() ## Down projection variant with expand
@staticmethod
def from_str(s: str) -> "OpType":
......@@ -197,6 +216,15 @@ class OpType(Enum):
return OpType.LORA_SHRINK
if s.lower() == "lora_expand":
return OpType.LORA_EXPAND
# Adding support for fused moe lora, both in gate_up and down
if s.lower() == "fused_moe_lora_gate_up_shrink": ## Gate/Up variant with shrink
return OpType.FUSED_MOE_LORA_GATE_UP_SHRINK
if s.lower() == "fused_moe_lora_gate_up_expand": ## Gate/Up variant with expand
return OpType.FUSED_MOE_LORA_GATE_UP_EXPAND
if s.lower() == "fused_moe_lora_down_shrink": ## Down variant with shrink
return OpType.FUSED_MOE_LORA_DOWN_SHRINK
if s.lower() == "fused_moe_lora_down_expand": ## Down variant with expand
return OpType.FUSED_MOE_LORA_DOWN_EXPAND
raise ValueError(f"Unrecognized str {s} to convert to OpType")
def is_shrink_fn(self) -> bool:
......@@ -205,19 +233,56 @@ class OpType(Enum):
def is_expand_fn(self) -> bool:
return self in [OpType.LORA_EXPAND]
def is_fused_moe_lora_fn(self) -> bool: ## adding for fused MoE LoRA
return self in [
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
OpType.FUSED_MOE_LORA_DOWN_SHRINK,
OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
OpType.FUSED_MOE_LORA_DOWN_EXPAND,
]
def is_fused_moe_lora_gate_up_fn(
self,
) -> bool: ## adding for fused MoE LoRA Gate/Up
return self in [
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
]
def is_fused_moe_lora_down_fn(self) -> bool: ## adding for fused MoE LoRA Down
return self in [
OpType.FUSED_MOE_LORA_DOWN_SHRINK,
OpType.FUSED_MOE_LORA_DOWN_EXPAND,
]
def is_fused_moe_lora_shrink_fn(self) -> bool:
return self in [
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
OpType.FUSED_MOE_LORA_DOWN_SHRINK,
]
def is_fused_moe_lora_expand_fn(self) -> bool:
return self in [
OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
OpType.FUSED_MOE_LORA_DOWN_EXPAND,
]
def num_slices(self) -> list[int]:
if self.is_fused_moe_lora_gate_up_fn():
return [2]
elif self.is_fused_moe_lora_down_fn():
return [1]
return [1, 2, 3]
def mkn(
self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int
) -> tuple[int, int, int]:
num_tokens = batch_size * seq_length
if self.is_shrink_fn():
if self.is_shrink_fn() or self.is_fused_moe_lora_fn():
m = num_tokens
k = hidden_size
n = lora_rank
else:
assert self.is_expand_fn()
elif self.is_expand_fn():
m = num_tokens
k = lora_rank
n = hidden_size
......@@ -231,9 +296,36 @@ class OpType(Enum):
"""
if self.is_shrink_fn():
return op_dtype, op_dtype, torch.float32
else:
assert self.is_expand_fn()
elif self.is_expand_fn():
return torch.float32, op_dtype, op_dtype
else:
assert self.is_fused_moe_lora_fn()
return op_dtype, op_dtype, op_dtype
def matmul_shapes_fused_moe_lora(
self,
m: int,
n: int,
k: int,
num_loras: int,
num_slices: int,
top_k_num: int,
num_experts: int,
) -> tuple[tuple[int], tuple[int], tuple[int], tuple[int]]:
if self.is_fused_moe_lora_shrink_fn():
input_shape = (
(m * top_k_num, n)
if self in [OpType.FUSED_MOE_LORA_DOWN_SHRINK]
else (m, n)
)
output_shape = (num_slices, m, top_k_num, k)
weight_shape = (num_loras, num_experts, k, n)
else:
assert self.is_fused_moe_lora_expand_fn()
input_shape = (num_slices, m, top_k_num, k)
output_shape = (m, top_k_num, n * num_slices)
weight_shape = (num_loras, num_experts, n, k)
return (input_shape, weight_shape, output_shape)
def matmul_shapes(
self,
......@@ -243,6 +335,8 @@ class OpType(Enum):
lora_rank: int,
num_loras: int,
num_slices: int,
top_k_num: int | None = None,
num_experts: int | None = None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
"""
Given num_slices, return the shapes of the A, B, and C matrices
......@@ -257,6 +351,16 @@ class OpType(Enum):
if self in [OpType.LORA_EXPAND]:
# LoRA expand kernels support num_slices inherently in the kernel
return ((num_slices, m, k), b_shape, (m, n * num_slices))
if self.is_fused_moe_lora_fn():
return self.matmul_shapes_fused_moe_lora(
m,
k,
n,
num_loras,
num_slices,
top_k_num,
num_experts,
)
raise ValueError(f"Unrecognized op_type {self}")
def bench_fn(self) -> Callable:
......@@ -264,6 +368,16 @@ class OpType(Enum):
return lora_shrink
if self == OpType.LORA_EXPAND:
return lora_expand
if self in [
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
OpType.FUSED_MOE_LORA_DOWN_SHRINK,
]:
return fused_moe_lora_shrink
if self in [
OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
OpType.FUSED_MOE_LORA_DOWN_EXPAND,
]:
return fused_moe_lora_expand
raise ValueError(f"Unrecognized optype {self}")
......@@ -316,8 +430,10 @@ class BenchmarkContext:
lora_rank: int
sort_by_lora_id: bool
dtype: torch.dtype
seq_length: Optional[int] = None
num_slices: Optional[int] = None # num_slices for slice based ops
seq_length: int | None = None
num_experts: int | None = None # num_experts for MoE based ops
top_k_num: int | None = None # top_k for MoE based ops
num_slices: int | None = None # num_slices for slice based ops
def with_seq_length(self, seq_length: int) -> "BenchmarkContext":
ctx = copy.copy(self)
......@@ -372,6 +488,11 @@ class BenchmarkTensors:
f"{dtype_to_str(self.output.dtype)}"
)
def get_num_tokens(self, size: int, top_k_num: int, op_type: OpType):
return (
size * top_k_num if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] else size
)
@staticmethod
def make(
ctx: BenchmarkContext, op_type: OpType, device: str = "cuda"
......@@ -384,6 +505,8 @@ class BenchmarkTensors:
ctx.lora_rank,
ctx.num_loras,
ctx.num_slices,
ctx.top_k_num,
ctx.num_experts,
)
a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype)
input_tensor, lora_weights, output_tensor = make_rand_tensors(
......@@ -431,17 +554,27 @@ class BenchmarkTensors:
prompt_lora_indices_tensor,
)
def sanity_check(self) -> None:
def sanity_check(self, ctx: BenchmarkContext, op_type: OpType) -> None:
"""
Fails asserts when non-conformality is detected.
"""
num_tokens = self.input.shape[-2]
num_tokens = (
self.input.shape[1]
if op_type.is_fused_moe_lora_expand_fn()
else self.input.shape[-2]
)
# check metadata tensors
assert torch.sum(self.seq_lens) == num_tokens
## In down shrink case, each token is repeated top_k_num times
assert num_tokens == self.get_num_tokens(
torch.sum(self.seq_lens), ctx.top_k_num, op_type
), f"Expected {num_tokens} tokens, but got {torch.sum(self.seq_lens)}"
num_seqs = self.seq_lens.shape[0]
# assert self.seq_start_loc.shape[0] == num_seqs
## In down shrink case, each prompt corresponds to top_k_num sequences
assert self.prompt_lora_mapping.shape[0] == num_seqs
assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens
assert self.get_num_tokens(
self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type
)
def to_device(self, device: str):
"""
......@@ -470,21 +603,111 @@ class BenchmarkTensors:
to_device(field) if field_name != "no_lora_flag_cpu" else field,
)
def metadata(self) -> tuple[int, int, int]:
def metadata(self, ctx: BenchmarkContext, op_type: OpType) -> tuple[int, int, int]:
"""
Return num_seqs, num_tokens and max_seq_len
"""
num_seqs = self.seq_lens.shape[0]
num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0]
num_tokens = self.get_num_tokens(
self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type
)
max_seq_len = torch.max(self.seq_lens).item()
num_slices = len(self.lora_weights_lst)
return num_seqs, num_tokens, max_seq_len, num_slices
def as_lora_shrink_kwargs(self) -> dict[str, Any]:
self.sanity_check()
def fused_moe_lora_data_prepare(
self,
block_size: int,
token_lora_mapping: torch.Tensor,
ctx: BenchmarkContext,
):
def moe_lora_align_block_size(
topk_ids: torch.Tensor,
token_lora_mapping: torch.Tensor,
block_size: int,
num_experts: int,
max_loras: int,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
sorted_ids = torch.empty(
(max_loras * max_num_tokens_padded,),
dtype=torch.int32,
device=topk_ids.device,
)
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be set default to -1 to prevent a blank block
expert_ids = torch.empty(
(max_loras * max_num_m_blocks,),
dtype=torch.int32,
device=topk_ids.device,
)
num_tokens_post_pad = torch.empty(
(max_loras), dtype=torch.int32, device=topk_ids.device
)
ops.moe_lora_align_block_size(
topk_ids,
token_lora_mapping,
num_experts,
block_size,
max_loras,
max_num_tokens_padded,
max_num_m_blocks,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad
num_tokens = ctx.batch_size
curr_topk_ids = torch.randint(
0,
ctx.num_experts,
(num_tokens, ctx.top_k_num),
device="cuda",
dtype=torch.int32,
)
topk_weights = torch.randint(
0,
ctx.num_experts,
(num_tokens, ctx.top_k_num),
device="cuda",
dtype=torch.int32,
)
(sorted_token_ids_lora, expert_ids_lora, num_tokens_post_padded_lora) = (
moe_lora_align_block_size(
topk_ids=curr_topk_ids,
token_lora_mapping=token_lora_mapping,
block_size=block_size,
num_experts=ctx.num_experts,
max_loras=ctx.num_loras,
)
)
sorted_token_ids = sorted_token_ids_lora.view(ctx.num_loras, -1)
expert_ids = expert_ids_lora.view(ctx.num_loras, -1)
num_tokens_post_padded = num_tokens_post_padded_lora
return (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded)
def as_lora_shrink_kwargs(
self, ctx: BenchmarkContext, op_type: OpType
) -> dict[str, Any]:
self.sanity_check(ctx, op_type)
self.to_device(self.input.device)
_, num_tokens, _, num_slices = self.metadata()
_, num_tokens, _, num_slices = self.metadata(ctx, op_type)
# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = (
......@@ -519,11 +742,13 @@ class BenchmarkTensors:
"no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
}
def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
self.sanity_check()
def as_lora_expand_kwargs(
self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool
) -> dict[str, Any]:
self.sanity_check(ctx, op_type)
self.to_device(self.input.device)
_, num_tokens, _, num_slices = self.metadata()
_, num_tokens, _, num_slices = self.metadata(ctx, op_type)
# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = (
......@@ -560,22 +785,177 @@ class BenchmarkTensors:
"no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
}
def as_fused_moe_lora_shrink_kwargs(
self, ctx: BenchmarkContext, op_type: OpType
) -> dict[str, Any]:
self.sanity_check(ctx, op_type)
self.to_device(self.input.device)
_, num_tokens, _, num_slices = self.metadata(ctx, op_type)
# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = (
self.input.shape,
self.lora_weights_lst[0].shape,
self.output.shape,
)
# Expected input shape : [num_tokens, hidden_size] for gate_up
# Expected input shape : [top_k_num * num_tokens, hidden_size] for down
assert len(i_shape) == 2
assert i_shape[0] == num_tokens
hidden_size = i_shape[1]
# Expected lora weight shape [max_lora, num_experts, lora_rank, hidden_size]
assert len(lw_shape) == 4
assert lw_shape[-1] == hidden_size
lora_rank = lw_shape[-2]
# Expected output shape : [num_slices, num_tokens, top_k_num, lora_rank]
assert len(o_shape) == 4
assert (
o_shape
== (num_slices, num_tokens // ctx.top_k_num, ctx.top_k_num, lora_rank)
if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK]
else o_shape == (num_slices, num_tokens, ctx.top_k_num, lora_rank)
)
kernel_config = get_lora_op_configs(
op_type.name.lower(),
max_loras=lw_shape[0],
batch=num_tokens,
hidden_size=hidden_size,
rank=lora_rank,
num_slices=num_slices,
add_inputs=False,
)
(topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = (
self.fused_moe_lora_data_prepare(
block_size=kernel_config["BLOCK_SIZE_M"],
token_lora_mapping=self.lora_kernel_meta.token_lora_mapping,
ctx=ctx,
)
)
return {
"qcurr_hidden_states": self.input,
"lora_a_stacked": self.lora_weights_lst,
"a_intermediate_cache1": self.output,
"topk_weights": topk_weights,
"sorted_token_ids": sorted_token_ids,
"expert_ids": expert_ids,
"num_tokens_post_padded": num_tokens_post_padded,
"top_k_num": ctx.top_k_num,
"device": self.input.device,
"N": lora_rank,
"M": topk_weights.shape[0],
"EM": sorted_token_ids.shape[1],
"K": self.input.shape[1],
"num_tokens": num_tokens,
"num_experts": ctx.num_experts,
"num_slices": num_slices,
"shrink_block_size_m": kernel_config["BLOCK_SIZE_M"],
"shrink_block_size_n": kernel_config["BLOCK_SIZE_N"],
"shrink_block_size_k": kernel_config["BLOCK_SIZE_K"],
"shrink_group_size_m": kernel_config["GROUP_SIZE_M"],
"shrink_num_warps": kernel_config["NUM_WARPS"],
"shrink_num_stages": kernel_config["NUM_STAGES"],
"shrink_split_k": kernel_config.get("SPLIT_K", 1),
"mul_routed_weight": op_type.is_fused_moe_lora_down_fn(),
}
def as_fused_moe_lora_expand_kwargs(
self, ctx: BenchmarkContext, op_type: OpType
) -> dict[str, Any]:
self.sanity_check(ctx, op_type)
self.to_device(self.input.device)
_, num_tokens, _, num_slices = self.metadata(ctx, op_type)
# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = (
self.input.shape,
self.lora_weights_lst[0].shape,
self.output.shape,
)
# Expected input shape : [num_slices, num_tokens, top_k_num, lora_rank]
assert len(i_shape) == 4
assert i_shape[0] == num_slices
assert i_shape[1] == num_tokens
lora_rank = i_shape[-1]
# Expected lora weight shape : [num_loras, num_experts, hidden_size, lora_rank]
assert len(lw_shape) == 4
assert lw_shape[-1] == lora_rank
hidden_size = lw_shape[-2]
# Expected output shape : [num_tokens, top_k_num, hidden_size * num_slices]
assert len(o_shape) == 3
assert o_shape == (num_tokens, ctx.top_k_num, hidden_size * num_slices)
kernel_config = get_lora_op_configs(
op_type.name.lower(),
max_loras=lw_shape[0],
batch=num_tokens,
hidden_size=hidden_size,
rank=lora_rank,
num_slices=num_slices,
add_inputs=False,
)
(topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = (
self.fused_moe_lora_data_prepare(
block_size=kernel_config["BLOCK_SIZE_M"],
token_lora_mapping=self.lora_kernel_meta.token_lora_mapping,
ctx=ctx,
)
)
return {
"a_intermediate_cache1": self.input,
"lora_b_stacked": self.lora_weights_lst,
"output": self.output,
"topk_weights": topk_weights,
"sorted_token_ids": sorted_token_ids,
"expert_ids": expert_ids,
"num_tokens_post_padded": num_tokens_post_padded,
"top_k_num": ctx.top_k_num,
"device": self.input.device,
"N": lora_rank,
"M": topk_weights.shape[0],
"EM": sorted_token_ids.shape[1],
"K": self.input.shape[1],
"num_tokens": num_tokens,
"num_experts": ctx.num_experts,
"num_slices": num_slices,
"max_lora_rank": lora_rank,
"w1_output_dim_size": lw_shape[2],
"expand_block_size_m": kernel_config["BLOCK_SIZE_M"],
"expand_block_size_n": kernel_config["BLOCK_SIZE_N"],
"expand_block_size_k": kernel_config["BLOCK_SIZE_K"],
"expand_group_size_m": kernel_config["GROUP_SIZE_M"],
"expand_num_warps": kernel_config["NUM_WARPS"],
"expand_num_stages": kernel_config["NUM_STAGES"],
"expand_split_k": kernel_config.get("SPLIT_K", 1),
"mul_routed_weight": op_type.is_fused_moe_lora_down_fn(),
}
def bench_fn_kwargs(
self, op_type: OpType, add_inputs: Optional[bool] = None
self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool | None = None
) -> dict[str, Any]:
if op_type.is_shrink_fn():
if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn():
assert add_inputs is None
else:
assert add_inputs is not None
if op_type == OpType.LORA_SHRINK:
return self.as_lora_shrink_kwargs()
return self.as_lora_shrink_kwargs(ctx, op_type)
if op_type == OpType.LORA_EXPAND:
return self.as_lora_expand_kwargs(add_inputs)
return self.as_lora_expand_kwargs(ctx, op_type, add_inputs)
if op_type.is_fused_moe_lora_shrink_fn():
return self.as_fused_moe_lora_shrink_kwargs(ctx, op_type)
if op_type.is_fused_moe_lora_expand_fn():
return self.as_fused_moe_lora_expand_kwargs(ctx, op_type)
raise ValueError(f"Unrecognized optype {self}")
def test_correctness(
self, op_type: OpType, expand_fn_add_inputs: Optional[bool]
self, op_type: OpType, expand_fn_add_inputs: bool | None
) -> bool:
"""
Test correctness of op_type implementation against a grouped gemm
......@@ -611,12 +991,12 @@ def bench_optype(
ctx: BenchmarkContext,
arg_pool_size: int,
op_type: OpType,
cuda_graph_nops: Optional[int] = None,
expand_fn_add_inputs: Optional[bool] = None,
cuda_graph_nops: int | None = None,
expand_fn_add_inputs: bool | None = None,
test_correctness: bool = False,
) -> TMeasurement:
assert arg_pool_size >= 1
if op_type.is_shrink_fn():
if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn():
assert expand_fn_add_inputs is None
else:
assert expand_fn_add_inputs is not None
......@@ -626,23 +1006,30 @@ def bench_optype(
BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)
]
for bt in bench_tensors:
bt.sanity_check()
bt.sanity_check(ctx, op_type)
# Test correctness of our implementation.
if test_correctness:
assert op_type in [OpType.LORA_SHRINK, OpType.LORA_EXPAND], (
f"Correctness testing is not supported for {op_type.name}."
)
assert all(
[bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors]
[
bt.test_correctness(ctx, op_type, expand_fn_add_inputs)
for bt in bench_tensors
]
)
# BenchmarkTensors -> dict (kwargs)
kwargs_list = [
bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs)
bt.bench_fn_kwargs(ctx, op_type, add_inputs=expand_fn_add_inputs)
for bt in bench_tensors
]
# Clear LoRA optimization hash-maps.
_LORA_A_PTR_DICT.clear()
_LORA_B_PTR_DICT.clear()
_LORA_PTR_DICT.clear()
# Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up
for kwargs in kwargs_list:
op_type.bench_fn()(**kwargs)
......@@ -679,7 +1066,7 @@ def bench_torch_mm(
ctx: BenchmarkContext,
arg_pool_size: int,
op_type: OpType,
cuda_graph_nops: Optional[int] = None,
cuda_graph_nops: int | None = None,
) -> TMeasurement:
"""
Benchmark basic torch.mm as a roofline.
......@@ -744,7 +1131,7 @@ def use_cuda_graph_recommendation() -> str:
"""
def print_timers(timers: list[TMeasurement], args: Optional[argparse.Namespace] = None):
def print_timers(timers: list[TMeasurement], args: argparse.Namespace | None = None):
compare = TBenchmark.Compare(timers)
compare.print()
......@@ -792,7 +1179,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
# Benchmark bench_op
expand_fn_add_inputs = (
[None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
[None]
if bench_op.is_shrink_fn() or bench_op.is_fused_moe_lora_fn()
else args.expand_fn_add_inputs
)
for add_input_arg in expand_fn_add_inputs:
seq_len_timers.append(
......@@ -830,12 +1219,22 @@ def as_benchmark_contexts(
hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace
) -> list[BenchmarkContext]:
ctxs: list[BenchmarkContext] = []
for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa
for (
batch_size,
hidden_size,
lora_rank,
num_loras,
sort_by_lora_id,
top_k_num,
num_experts,
) in product( # noqa
args.batch_sizes,
list(hidden_sizes),
lora_ranks,
args.num_loras,
args.sort_by_lora_id,
args.top_k_nums,
args.num_experts,
):
ctxs.append(
BenchmarkContext(
......@@ -850,6 +1249,8 @@ def as_benchmark_contexts(
seq_length=None,
sort_by_lora_id=sort_by_lora_id,
dtype=args.dtype,
top_k_num=top_k_num,
num_experts=num_experts,
# To be filled based on the OpType to benchmark
num_slices=None,
)
......@@ -1011,6 +1412,22 @@ if __name__ == "__main__":
),
)
p.add_argument(
"--top-k-nums",
nargs="+",
type=int,
default=DEFAULT_TOP_K_NUMS,
help="Top-K values for MoE LoRA operations",
)
p.add_argument(
"--num-experts",
nargs="+",
type=int,
default=DEFAULT_NUM_EXPERTS,
help="Number of experts for MoE LoRA operations",
)
parser = FlexibleArgumentParser(
description=f"""
Benchmark LoRA kernels:
......
......@@ -8,10 +8,9 @@ import math
import os
import pickle as pkl
import time
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from itertools import product
from typing import Callable, Optional
import pandas as pd
import torch
......@@ -34,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights,
)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
......@@ -63,23 +62,23 @@ class BenchmarkTensors:
a: torch.Tensor
w_q: torch.Tensor
group_size: Optional[int]
group_size: int | None
wtype: ScalarType
w_g_s: torch.Tensor
w_g_zp: Optional[torch.Tensor]
w_ch_s: Optional[torch.Tensor]
w_tok_s: Optional[torch.Tensor]
w_g_zp: torch.Tensor | None
w_ch_s: torch.Tensor | None
w_tok_s: torch.Tensor | None
@dataclass
class TypeConfig:
act_type: torch.dtype
weight_type: ScalarType
output_type: Optional[torch.dtype]
group_scale_type: Optional[torch.dtype]
group_zero_type: Optional[torch.dtype]
channel_scale_type: Optional[torch.dtype]
token_scale_type: Optional[torch.dtype]
output_type: torch.dtype | None
group_scale_type: torch.dtype | None
group_zero_type: torch.dtype | None
channel_scale_type: torch.dtype | None
token_scale_type: torch.dtype | None
def rand_data(shape, dtype=torch.float16, scale=1):
......@@ -93,8 +92,8 @@ def quantize_and_pack(
atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
stype: torch.dtype | None,
group_size: int | None,
zero_points: bool = False,
):
assert wtype.is_integer(), "TODO: support floating point weights"
......@@ -113,7 +112,7 @@ def quantize_and_pack(
def create_bench_tensors(
shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int]
shape: tuple[int, int, int], types: TypeConfig, group_size: int | None
) -> list[BenchmarkTensors]:
m, n, k = shape
......@@ -331,8 +330,8 @@ def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable])
return res
_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
_SWEEP_SCHEDULES_RESULTS: pd.DataFrame | None = None
_SWEEP_SCHEDULES_RESULTS_CSV: str | None = None
def bench(
......
......@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
sort_weights,
)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
......
......@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype()
......@@ -211,7 +211,7 @@ def get_rocm_tuning_space(use_fp16):
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 4, 8, 16, 32]
num_stage_range = [2]
waves_per_eu_range = [0]
waves_per_eu_range = [0, 1, 2, 4]
matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
kpack_range = [1, 2] if use_fp16 else []
......@@ -579,19 +579,23 @@ def main(args: argparse.Namespace):
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
hidden_size = config.hidden_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] in (
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM",
"Glm4MoeForCausalLM",
"NemotronHForCausalLM",
):
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] in (
"Qwen2MoeForCausalLM",
"Qwen3MoeForCausalLM",
......@@ -600,10 +604,23 @@ def main(args: argparse.Namespace):
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] == "Qwen3VLMoeForConditionalGeneration":
text_config = config.get_text_config()
E = text_config.num_experts
topk = text_config.num_experts_per_tok
intermediate_size = text_config.moe_intermediate_size
hidden_size = text_config.hidden_size
elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
E = config.num_experts
topk = config.moe_topk[0]
intermediate_size = config.moe_intermediate_size[0]
hidden_size = config.hidden_size
elif config.architectures[0] in ["Qwen3OmniMoeForConditionalGeneration"]:
E = config.thinker_config.text_config.num_experts
topk = config.thinker_config.text_config.num_experts_per_tok
intermediate_size = config.thinker_config.text_config.moe_intermediate_size
hidden_size = config.thinker_config.text_config.hidden_size
else:
# Support for llama4
config = config.get_text_config()
......@@ -611,6 +628,7 @@ def main(args: argparse.Namespace):
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
hidden_size = config.hidden_size
enable_ep = bool(args.enable_expert_parallel)
if enable_ep:
ensure_divisibility(E, args.tp_size, "Number of experts")
......@@ -619,8 +637,7 @@ def main(args: argparse.Namespace):
else:
ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
dtype = torch.float16 if current_platform.is_rocm() else config.dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
block_quant_shape = get_weight_block_size_safety(config)
......
......@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
)
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype()
......@@ -344,7 +344,7 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
dtype = torch.float16 if current_platform.is_rocm() else config.dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
use_customized_permute = args.use_customized_permute
......
......@@ -39,7 +39,7 @@ import torch
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
......
......@@ -3,16 +3,15 @@
import random
import time
from typing import Optional
import torch
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random,
)
......@@ -37,7 +36,7 @@ def main(
seed: int,
do_profile: bool,
device: str = "cuda",
kv_cache_dtype: Optional[str] = None,
kv_cache_dtype: str | None = None,
) -> None:
current_platform.seed_everything(seed)
......
......@@ -3,8 +3,8 @@
import argparse
import math
from collections.abc import Callable
from contextlib import contextmanager
from typing import Callable
from unittest.mock import patch
import torch
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import torch
from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton
def polynorm_naive(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
def norm(x, eps: float):
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
x = x.float()
return (
(
weight[0] * norm(x**3, eps)
+ weight[1] * norm(x**2, eps)
+ weight[2] * norm(x, eps)
+ bias
)
.to(weight.dtype)
.view(orig_shape)
)
def polynorm_vllm(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
out = torch.empty_like(x)
vllm_ops.poly_norm(out, x, weight, bias, eps)
output = out
output = output.view(orig_shape)
return output
def calculate_diff(batch_size, seq_len, hidden_dim):
dtype = torch.bfloat16
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
weight = torch.ones(3, dtype=dtype, device="cuda")
bias = torch.ones(1, dtype=dtype, device="cuda")
output_naive = polynorm_naive(x, weight, bias)
output_vllm = polynorm_vllm(x, weight, bias)
if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
dim_range = [2048, 4096]
configs = list(itertools.product(dim_range, batch_size_range, seq_length_range))
def get_benchmark():
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["dim", "batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["naive", "vllm"],
line_names=["Naive", "vLLM"],
styles=[("blue", "-"), ("red", "-")],
ylabel="us",
plot_name="polynorm-perf",
args={},
)
)
def benchmark(dim, batch_size, seq_len, provider):
dtype = torch.bfloat16
hidden_dim = dim * 4
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
weight = torch.ones(3, dtype=dtype, device="cuda")
bias = torch.ones(1, dtype=dtype, device="cuda")
quantiles = [0.5, 0.2, 0.8]
if provider == "naive":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: polynorm_naive(x, weight, bias),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: polynorm_vllm(x, weight, bias),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--batch-size",
type=int,
default=4,
help="Batch size",
)
parser.add_argument(
"--seq-len",
type=int,
default=128,
help="Sequence length",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=8192,
help="Intermediate size of MLP",
)
parser.add_argument(
"--save-path",
type=str,
default="./configs/polnorm/",
help="Path to save polnorm benchmark results",
)
args = parser.parse_args()
# Run correctness test
calculate_diff(
batch_size=args.batch_size,
seq_len=args.seq_len,
hidden_dim=args.hidden_dim,
)
benchmark = get_benchmark()
# Run performance benchmark
benchmark.run(print_data=True, save_path=args.save_path)
......@@ -7,7 +7,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
@torch.inference_mode()
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import time
import torch
from tabulate import tabulate
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
create_kv_caches_with_random,
)
logger = init_logger(__name__)
@torch.inference_mode()
def run_benchmark(
num_tokens: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
kv_cache_dtype: str,
num_iters: int,
benchmark_mode: str,
device: str = "cuda",
) -> float:
"""Return latency (seconds) for given num_tokens."""
if kv_cache_dtype == "fp8" and head_size % 16:
raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")
current_platform.seed_everything(42)
torch.set_default_device(device)
# create random key / value tensors [T, H, D].
key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
value = torch.randn_like(key)
# prepare the slot mapping.
# each token is assigned a unique slot in the KV-cache.
num_slots = block_size * num_blocks
if num_tokens > num_slots:
raise ValueError("num_tokens cannot exceed the total number of cache slots")
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
key_caches, value_caches = create_kv_caches_with_random(
num_blocks,
block_size,
1, # num_layers
num_heads,
head_size,
kv_cache_dtype,
dtype,
device=device,
)
key_cache, value_cache = key_caches[0], value_caches[0]
# to free unused memory
del key_caches, value_caches
# compute per-kernel scaling factors for fp8 conversion (if used).
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)
function_under_test = lambda: ops.reshape_and_cache(
key, # noqa: F821
value, # noqa: F821
key_cache, # noqa: F821
value_cache, # noqa: F821
slot_mapping, # noqa: F821
kv_cache_dtype,
k_scale,
v_scale,
)
if benchmark_mode == "cudagraph":
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
function_under_test()
torch.cuda.synchronize()
function_under_test = lambda: g.replay()
def run_cuda_benchmark(n_iters: int) -> float:
nonlocal key, value, key_cache, value_cache, slot_mapping
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_iters):
function_under_test()
torch.cuda.synchronize()
end = time.perf_counter()
return (end - start) / n_iters
# warm-up
run_cuda_benchmark(3)
lat = run_cuda_benchmark(num_iters)
# free tensors to mitigate OOM when sweeping
del key, value, key_cache, value_cache, slot_mapping
torch.cuda.empty_cache()
return lat
def main(args):
rows = []
for exp in range(1, 17):
n_tok = 2**exp
lat = run_benchmark(
num_tokens=n_tok,
num_heads=args.num_heads,
head_size=args.head_size,
block_size=args.block_size,
num_blocks=args.num_blocks,
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
kv_cache_dtype=args.kv_cache_dtype,
num_iters=args.iters,
benchmark_mode=args.mode,
device="cuda",
)
rows.append([n_tok, lat * 1e6]) # convert to microseconds
print(f"Benchmark results for implementation cuda (measuring with {args.mode}):")
print(tabulate(rows, headers=["num_tokens", "latency (µs)"], floatfmt=".3f"))
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--num-heads", type=int, default=128)
parser.add_argument(
"--head-size",
type=int,
choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128,
)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--num-blocks", type=int, default=128 * 128)
parser.add_argument(
"--dtype",
type=str,
choices=["half", "bfloat16", "float"],
default="bfloat16",
)
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8"],
default="auto",
)
parser.add_argument("--iters", type=int, default=200)
parser.add_argument(
"--mode",
type=str,
choices=["cudagraph", "no_graph"],
default="cudagraph",
)
args = parser.parse_args()
main(args)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import random
import time
......@@ -14,9 +12,9 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random_flash,
)
......
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Optional, Union
import torch
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
......@@ -21,8 +20,8 @@ class HuggingFaceRMSNorm(nn.Module):
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
......@@ -41,7 +40,7 @@ class HuggingFaceRMSNorm(nn.Module):
def rmsnorm_naive(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
residual: torch.Tensor | None = None,
eps: float = 1e-6,
):
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
......@@ -65,7 +64,7 @@ def rmsnorm_naive(
def rmsnorm_flashinfer(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
residual: torch.Tensor | None = None,
eps: float = 1e-6,
):
orig_shape = x.shape
......@@ -89,7 +88,7 @@ def rmsnorm_flashinfer(
def rmsnorm_vllm(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
residual: torch.Tensor | None = None,
eps: float = 1e-6,
):
orig_shape = x.shape
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from itertools import accumulate
from typing import Optional
import itertools
import nvtx
import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
batch_size_range = [2**i for i in range(0, 8, 2)]
seq_len_range = [2**i for i in range(6, 10, 1)]
num_heads_range = [32, 48]
configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range))
def benchmark_rope_kernels_multi_lora(
is_neox_style: bool,
batch_size: int,
seq_len: int,
num_heads: int,
head_size: int,
rotary_dim: Optional[int],
dtype: torch.dtype,
seed: int,
device: str,
max_position: int = 8192,
base: float = 10000,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
# silulating serving 4 LoRAs
scaling_factors = [1, 2, 4, 8]
# batched RoPE can take multiple scaling factors
batched_rope = get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
{"rope_type": "linear", "factor": tuple(scaling_factors)},
def get_benchmark(head_size, rotary_dim, is_neox_style, device):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len", "num_heads"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["torch", "flashinfer", "vllm"],
line_names=["PyTorch", "FlashInfer", "vLLM"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}",
args={},
)
)
# non-batched RoPE takes only one scaling factor, we create multiple
# instances to simulate the same behavior
non_batched_ropes: list[RotaryEmbedding] = []
for scaling_factor in scaling_factors:
non_batched_ropes.append(
get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
{"rope_type": "linear", "factor": (scaling_factor,)},
)
def benchmark(batch_size, seq_len, num_heads, provider):
dtype = torch.bfloat16
max_position = 8192
base = 10000
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
rope = rope.to(dtype=dtype, device=device)
cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)
positions = torch.randint(0, max_position, (batch_size, seq_len), device=device)
query = torch.randn(
(batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device
)
key = torch.randn_like(query)
positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype)
key = torch.randn_like(query)
quantiles = [0.5, 0.2, 0.8]
# create query offsets for batched RoPE, we concat multiple kv cache
# together and each query needs to find the right kv cache of its type
offset_map = torch.tensor(
list(
accumulate(
[0]
+ [
max_position * scaling_factor * 2
for scaling_factor in scaling_factors[:-1]
]
if provider == "torch":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rope.forward_native(positions, query.clone(), key.clone()),
quantiles=quantiles,
)
)
)
query_types = torch.randint(
0, len(scaling_factors), (batch_size, seq_len), device=device
)
# map query types to offsets
query_offsets = offset_map[query_types]
# the kernel takes flattened offsets
flatten_offsets = query_offsets.flatten()
elif provider == "flashinfer":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch.ops.vllm.flashinfer_rotary_embedding(
positions,
query.clone(),
key.clone(),
head_size,
cos_sin_cache,
is_neox_style,
),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rope.forward_cuda(positions, query.clone(), key.clone()),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
# batched queries of the same type together for non-batched RoPE
queries = [query[query_types == i] for i in range(len(scaling_factors))]
keys = [key[query_types == i] for i in range(len(scaling_factors))]
packed_qkr = zip(queries, keys, non_batched_ropes)
# synchronize before start timing
torch.cuda.synchronize()
with nvtx.annotate("non-batched", color="yellow"):
for q, k, r in packed_qkr:
r.forward(positions, q, k)
torch.cuda.synchronize()
with nvtx.annotate("batched", color="green"):
batched_rope.forward(positions, query, key, flatten_offsets)
torch.cuda.synchronize()
return benchmark
if __name__ == "__main__":
......@@ -117,17 +95,12 @@ if __name__ == "__main__":
parser.add_argument(
"--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
)
parser.add_argument("--save-path", type=str, default="./configs/rope/")
args = parser.parse_args()
print(args)
benchmark_rope_kernels_multi_lora(
is_neox_style=args.is_neox_style,
batch_size=args.batch_size,
seq_len=args.seq_len,
num_heads=args.num_heads,
head_size=args.head_size,
rotary_dim=args.rotary_dim,
dtype=getattr(torch, args.dtype),
seed=args.seed,
device=args.device,
# Get the benchmark function
benchmark = get_benchmark(
args.head_size, args.rotary_dim, args.is_neox_style, args.device
)
# Run performance benchmark
benchmark.run(print_data=True, save_path=args.save_path)
......@@ -78,11 +78,11 @@ WEIGHT_SHAPES = {
}
WEIGHT_SHAPES_MOE = {
"nm-testing/Mixtral-8x7B-Instruct-v0.1": [
"mistralai/Mixtral-8x7B-Instruct-v0.1": [
[8, 2, 4096, 28672],
[8, 2, 14336, 4096],
],
"nm-testing/deepseekv2-lite": [
"deepseek-ai/DeepSeek-V2-Lite": [
[64, 6, 2048, 1408],
],
"ibm-granite/granite-3.0-1b-a400m": [
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Comprehensive 3-way SiLU Benchmark Suite
This benchmark compares three SiLU implementations:
1. SiLU V2 (CUDA) - Optimized CUDA kernel implementation
2. Triton Kernel - Triton-based implementation
The suite generates detailed performance comparisons including:
- Memory bandwidth utilization
- Speedup ratios (baseline vs optimized implementations)
- Performance across different expert configurations and token distributions
"""
from collections.abc import Callable
import matplotlib.pyplot as plt
......@@ -7,7 +21,7 @@ import numpy as np
import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
silu_mul_fp8_quant_deep_gemm_cuda,
persistent_masked_m_silu_mul_quant,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
......@@ -94,6 +108,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(
num_parallel_tokens,
group_size: int = 128,
eps: float = 1e-10,
expert_offsets: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
......@@ -174,7 +189,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(
# Parse generation strategies
strategies = ["uniform", "max_t", "first_t"]
strategies = ["random_imbalanced", "uniform", "max_t"]
def benchmark(
......@@ -195,15 +210,27 @@ def benchmark(
current_platform.seed_everything(42 + seed_offset)
y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous()
if gen_strategy == "uniform":
r = torch.rand(size=(E,), device="cuda")
if gen_strategy == "random_imbalanced":
def generate_expert_loads(n_e, total_tokens, ratio, device="cuda"):
mean = total_tokens // n_e
min_max = mean // ratio
e = torch.ones(size=(E,), dtype=torch.int64, device=device) * mean
e[0] = min_max
r = torch.rand(size=(E - 1,))
r /= r.sum()
r *= total_tokens - min_max
r = r.round().long()
e[1:] = r.to(device=device)
return e
tokens_per_expert = generate_expert_loads(E, total_tokens, 0.7, "cuda")
elif gen_strategy == "uniform":
r = torch.rand(size=(E,))
r /= r.sum()
r *= total_tokens
tokens_per_expert = r.int()
tokens_per_expert = torch.minimum(
tokens_per_expert,
torch.ones((E,), device=r.device, dtype=torch.int) * T,
)
r = r.round().long()
tokens_per_expert = r
elif gen_strategy == "max_t":
tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda")
tokens_per_expert.fill_(total_tokens / E)
......@@ -281,40 +308,34 @@ def benchmark(
def create_comparison_plot(
ratio, cuda_times, baseline_times, config_labels, strategy_name, id
ratios, silu_v2_times, triton_times, config_labels, strategy_name, id
):
"""Create a comparison plot for a specific generation strategy"""
fig, ax = plt.subplots(1, 1, figsize=(16, 6))
fig, ax = plt.subplots(1, 1, figsize=(18, 6))
# Configure x-axis positions
x = np.arange(len(config_labels))
width = 0.35
width = 0.25
# Execution Time plot (lower is better)
ax.bar(x, silu_v2_times, width, label="SiLU V2 (CUDA)", alpha=0.8, color="blue")
ax.bar(
x - width / 2, cuda_times, width, label="CUDA Kernel", alpha=0.8, color="blue"
)
ax.bar(
x + width / 2,
baseline_times,
width,
label="Baseline",
alpha=0.8,
color="orange",
x + width, triton_times, width, label="Triton Kernel", alpha=0.8, color="green"
)
# Add speedup labels over each bar pair
# Add speedup labels over each bar trio
for i in range(len(x)):
speedup = ratio[i]
max_height = max(cuda_times[i], baseline_times[i])
triton_v2_speedup = ratios[i][1] # triton/v2
max_height = max(silu_v2_times[i], triton_times[i])
# Triton/V2 speedup
ax.text(
x[i],
x[i] + width / 2,
max_height + max_height * 0.02,
f"{speedup:.2f}x",
f"{triton_v2_speedup:.2f}x",
ha="center",
va="bottom",
fontweight="bold",
fontsize=9,
fontsize=8,
)
ax.set_xlabel("Configuration")
......@@ -332,56 +353,75 @@ def create_comparison_plot(
def create_combined_plot(all_results):
"""Create a combined plot with all strategies in one PNG"""
num_strategies = len(all_results)
fig, axes = plt.subplots(num_strategies, 1, figsize=(20, 6 * num_strategies))
fig, axes = plt.subplots(num_strategies, 1, figsize=(22, 7 * num_strategies))
if num_strategies == 1:
axes = [axes]
for idx, (
strategy_name,
ratio,
cuda_times,
baseline_times,
all_ratios,
all_silu_v2_results,
all_triton_results,
config_labels,
config_x_axis,
) in enumerate(all_results):
ax = axes[idx]
# Flatten the nested results to get bandwidth percentages for plotting
silu_v2_bandwidths = []
triton_bandwidths = []
flat_ratios = []
for config_results in all_silu_v2_results:
for result in config_results:
silu_v2_bandwidths.append(result[3]) # bandwidth percentage
for config_results in all_triton_results:
for result in config_results:
triton_bandwidths.append(result[3]) # bandwidth percentage
for config_ratios in all_ratios:
for ratio in config_ratios:
flat_ratios.append(ratio)
# Configure x-axis positions
x = np.arange(len(config_labels))
width = 0.35
width = 0.25
# Execution Time plot (lower is better)
# Bandwidth utilization plot (higher is better)
ax.bar(
x - width / 2,
cuda_times,
x,
silu_v2_bandwidths,
width,
label="CUDA Kernel",
label="SiLU V2 (CUDA)",
alpha=0.8,
color="blue",
)
ax.bar(
x + width / 2,
baseline_times,
x + width,
triton_bandwidths,
width,
label="Baseline",
label="Triton Kernel",
alpha=0.8,
color="orange",
color="green",
)
# Add speedup labels over each bar pair
# Add speedup labels over each bar trio
for i in range(len(x)):
speedup = ratio[i]
max_height = max(cuda_times[i], baseline_times[i])
triton_v2_speedup = flat_ratios[i] # triton/v2
max_height = max(silu_v2_bandwidths[i], triton_bandwidths[i])
# Triton/V2 speedup
ax.text(
x[i],
x[i] + width / 2,
max_height + max_height * 0.02,
f"{speedup:.2f}x",
f"{triton_v2_speedup:.2f}x",
ha="center",
va="bottom",
fontweight="bold",
fontsize=9,
fontsize=8,
)
ax.set_xlabel("Configuration")
......@@ -395,7 +435,7 @@ def create_combined_plot(all_results):
ax.grid(True, alpha=0.3)
plt.tight_layout()
filename = "../../silu_bench/silu_benchmark_combined.png"
filename = "silu_benchmark_combined_3way.png"
plt.savefig(filename, dpi=300, bbox_inches="tight")
plt.show()
......@@ -405,7 +445,9 @@ def create_combined_plot(all_results):
outer_dim = 7168
configs = [
# DeepSeekV3 Configs
# (1, 56, 7168),
(8, 1024, 7168),
# (32, 56, 7168),
# DeepSeekV3 Configs
(32, 1024, 7168),
# DeepSeekV3 Configs
......@@ -417,6 +459,7 @@ num_warmups = 20
strategy_descriptions = {
"uniform": "Uniform Random",
"random_imbalanced": "Imbalanced Random",
"max_t": "Even Assignment",
"first_t": "experts[0] = T, experts[1:] = 0",
}
......@@ -433,28 +476,31 @@ for id, strategy in enumerate(strategies):
print(f"Testing strategy: {strategy_descriptions[strategy]}")
print(f"{'=' * 60}")
# Collect benchmark data for both algorithms
# Collect benchmark data for all three algorithms
config_labels = []
config_x_axis = []
all_cuda_results = []
all_baseline_results = []
all_silu_v2_results = []
all_triton_results = []
all_ratios = []
for E, T, H in configs:
total_tokens_config = [8 * E, 16 * E, 32 * E, 64 * E, 128 * E, 256 * E]
total_tokens_config = []
for i in [8, 16, 32, 64, 128, 256, 512]:
if i <= T:
total_tokens_config.append(i * E)
config_x_axis.append(total_tokens_config)
cuda_results = []
baseline_results = []
silu_v2_results = []
triton_results = []
ratios = []
for total_tokens in total_tokens_config:
config_label = f"E={E},T={T},H={H},TT={total_tokens}"
config_labels.append(config_label)
# CUDA kernel results
time_ms_cuda, gflops, gbps, perc = benchmark(
silu_mul_fp8_quant_deep_gemm_cuda,
# SiLU V2 (CUDA kernel) results
time_ms_silu_v2, gflops, gbps, perc = benchmark(
persistent_masked_m_silu_mul_quant,
E,
T,
H,
......@@ -463,9 +509,9 @@ for id, strategy in enumerate(strategies):
num_warmups=num_warmups,
gen_strategy=strategy,
)
cuda_results.append((time_ms_cuda, gflops, gbps, perc))
silu_v2_results.append((time_ms_silu_v2, gflops, gbps, perc))
# Baseline results
# Triton kernel results
time_ms_triton, gflops, gbps, perc = benchmark(
silu_mul_fp8_quant_deep_gemm_triton,
E,
......@@ -476,12 +522,20 @@ for id, strategy in enumerate(strategies):
num_warmups=num_warmups,
gen_strategy=strategy,
)
baseline_results.append((time_ms_triton, gflops, gbps, perc))
ratios.append(time_ms_triton / time_ms_cuda)
triton_results.append((time_ms_triton, gflops, gbps, perc))
print(f"Completed: {config_label}")
all_cuda_results.append(cuda_results)
all_baseline_results.append(baseline_results)
# Calculate speedup ratios (triton baseline / implementation)
triton_v2_ratio = time_ms_triton / time_ms_silu_v2
ratios.append(triton_v2_ratio)
print(
f"Completed: {config_label}:"
f" V2: {time_ms_silu_v2:.3f}ms,"
f" Triton: {time_ms_triton:.3f}ms"
)
all_silu_v2_results.append(silu_v2_results)
all_triton_results.append(triton_results)
all_ratios.append(ratios)
# Store results for combined plotting
......@@ -489,8 +543,8 @@ for id, strategy in enumerate(strategies):
(
strategy_descriptions[strategy],
all_ratios,
all_cuda_results,
all_baseline_results,
all_silu_v2_results,
all_triton_results,
config_labels,
config_x_axis,
)
......@@ -498,15 +552,18 @@ for id, strategy in enumerate(strategies):
# Print summary table for this strategy
print(f"\nSummary Table - {strategy_descriptions[strategy]}:")
print(f"{'Config':<20} {'CUDA Time(ms)':<12} {'Base Time(ms)':<12} {'Speedup':<8}")
print("-" * 60)
print(f" {'V2 Time(ms)':<12} {'Triton Time(ms)':<14} {'Triton/V2':<10}")
print("-" * 90)
for i, (E, T, H) in enumerate(configs):
speedup = baseline_results[i][0] / cuda_results[i][0]
# Get the first result for each config (simplifying for summary)
v2_time = silu_v2_results[i][0]
triton_time = triton_results[i][0]
triton_v2_speedup = triton_time / v2_time
config_label = f"E={E:3d},T={T:4d},H={H:4d}"
print(
f"{config_label:<20} {cuda_results[i][0]:8.5f} "
f"{baseline_results[i][0]:8.5f} {speedup:6.2f}x"
f"{config_label:<20} {v2_time:8.5f} {triton_time:10.5f} "
f"{triton_v2_speedup:8.2f}x"
)
......@@ -514,15 +571,14 @@ def create_total_tokens_plot(all_results):
num_strategies = len(all_results)
num_configs = len(configs)
# Create side-by-side subplots: 2 columns for speedup and bandwidth percentage
fig, axs = plt.subplots(
num_strategies, num_configs * 2, figsize=(28, 6 * num_strategies)
num_strategies, num_configs * 2, figsize=(32, 8 * num_strategies)
)
# Add main title to the entire figure
fig.suptitle(
"Performance Analysis: Speedup vs Bandwidth Utilization (Triton & CUDA)",
fontsize=16,
"Performance Analysis: Speedup vs Bandwidth Utilization (SiLU V2, and Triton)",
fontsize=18,
fontweight="bold",
y=0.98,
)
......@@ -539,8 +595,8 @@ def create_total_tokens_plot(all_results):
(
strategy_name,
all_ratios,
all_cuda_results,
all_baseline_results,
all_silu_v2_results,
all_triton_results,
config_labels,
config_x_axis,
) = result
......@@ -555,42 +611,54 @@ def create_total_tokens_plot(all_results):
ratios = all_ratios[config_idx]
total_tokens_values = config_x_axis[config_idx]
# Extract CUDA and Triton bandwidth percentages
cuda_bandwidth_percentages = [
result[3] for result in all_cuda_results[config_idx]
# Extract speedup ratios
triton_v2_ratios = [ratio for ratio in ratios]
# Extract bandwidth percentages for all implementations
v2_bandwidth_percentages = [
result[3] for result in all_silu_v2_results[config_idx]
]
triton_bandwidth_percentages = [
result[3] for result in all_baseline_results[config_idx]
result[3] for result in all_triton_results[config_idx]
]
# Plot speedup ratios vs total tokens (left plot)
ax_speedup.plot(
total_tokens_values, ratios, "bo-", linewidth=3, markersize=8
total_tokens_values,
triton_v2_ratios,
"go-",
linewidth=3,
markersize=8,
label="Triton/V2 Speedup",
)
ax_speedup.set_title(
f"{strategy_name}\nSpeedup (CUDA/Triton)\nE={E}, T={T}, H={H}",
f"{strategy_name}\nSpeedup vs Baseline (Triton)\nE={E}, T={T}, H={H}",
fontsize=12,
fontweight="bold",
)
ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11)
ax_speedup.legend(prop={"weight": "bold"})
ax_speedup.grid(True, alpha=0.3)
# Plot bandwidth utilization (right plot)
ax_bandwidth.plot(
total_tokens_values,
cuda_bandwidth_percentages,
"ro-",
v2_bandwidth_percentages,
"o-",
linewidth=3,
markersize=8,
label="CUDA",
label="SiLU V2",
color="blue",
)
ax_bandwidth.plot(
total_tokens_values,
triton_bandwidth_percentages,
"go-",
"o-",
linewidth=3,
markersize=8,
label="Triton",
color="green",
)
ax_bandwidth.set_title(
f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}",
......@@ -618,38 +686,12 @@ def create_total_tokens_plot(all_results):
for label in ax.get_xticklabels() + ax.get_yticklabels():
label.set_fontweight("bold")
# Add value labels on speedup points
for x, y in zip(total_tokens_values, ratios):
# Add value labels on Triton/V2 speedup points
for x, y in zip(total_tokens_values, triton_v2_ratios):
ax_speedup.annotate(
f"{y:.2f}x",
(x, y),
textcoords="offset points",
xytext=(0, 12),
ha="center",
fontsize=10,
fontweight="bold",
bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7),
)
# Add value labels on CUDA bandwidth points
for x, y in zip(total_tokens_values, cuda_bandwidth_percentages):
ax_bandwidth.annotate(
f"{y:.1f}%",
(x, y),
textcoords="offset points",
xytext=(0, 12),
ha="center",
fontsize=9,
fontweight="bold",
bbox=dict(boxstyle="round,pad=0.2", facecolor="red", alpha=0.3),
)
# Add value labels on Triton bandwidth points
for x, y in zip(total_tokens_values, triton_bandwidth_percentages):
ax_bandwidth.annotate(
f"{y:.1f}%",
(x, y),
textcoords="offset points",
xytext=(0, -15),
ha="center",
fontsize=9,
......@@ -659,17 +701,20 @@ def create_total_tokens_plot(all_results):
plt.tight_layout()
plt.subplots_adjust(top=0.93) # Make room for main title
filename = "silu_benchmark_total_tokens.png"
filename = "silu_benchmark_total_tokens_3way.png"
plt.savefig(filename, dpi=300, bbox_inches="tight")
plt.show()
return filename
# Create combined plot with all strategies
combined_plot_filename = create_total_tokens_plot(all_results)
# Create comprehensive 3-way comparison plots
combined_plot_filename = create_combined_plot(all_results)
total_tokens_plot_filename = create_total_tokens_plot(all_results)
print(f"\n{'=' * 60}")
print("Benchmark Complete!")
print(f"Generated combined plot: {combined_plot_filename}")
print(f"{'=' * 60}")
print(f"\n{'=' * 80}")
print("3-Way Benchmark Suite Complete!")
print(f"Generated combined comparison plot: {combined_plot_filename}")
print(f"Generated total tokens analysis plot: {total_tokens_plot_filename}")
print("Compared: SiLU V2 (CUDA), and Triton implementations")
print(f"{'=' * 80}")
......@@ -4,12 +4,11 @@
import csv
import os
from datetime import datetime
from typing import Optional
import flashinfer
import torch
from vllm.utils import round_up
from vllm.utils.math_utils import round_up
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn
......@@ -28,9 +27,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@torch.no_grad()
def benchmark_decode(
dtype: torch.dtype,
quant_dtypes: tuple[
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
batch_size: int,
max_seq_len: int,
num_heads: tuple[int, int] = (64, 8),
......
......@@ -4,12 +4,11 @@
import csv
import os
from datetime import datetime
from typing import Optional
import flashinfer
import torch
from vllm.utils import round_up
from vllm.utils.math_utils import round_up
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn
......@@ -28,9 +27,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@torch.no_grad()
def benchmark_prefill(
dtype: torch.dtype,
quant_dtypes: tuple[
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
batch_size: int,
max_seq_len: int,
num_heads: tuple[int, int] = (64, 8),
......
......@@ -14,11 +14,11 @@ import torch
from tqdm import tqdm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_w8a8_block_fp8_matmul,
_w8a8_triton_block_scaled_mm,
)
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
mp.set_start_method("spawn", force=True)
......@@ -83,7 +83,7 @@ def w8a8_block_matmul(
)
if A.dtype == torch.float8_e4m3fn:
kernel = _w8a8_block_fp8_matmul
kernel = _w8a8_triton_block_scaled_mm
else:
raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment