Commit 006693ed authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.11.2' into v0.11.2-ori

parents 4b51e6f1 275de341
...@@ -7,7 +7,8 @@ import torch ...@@ -7,7 +7,8 @@ import torch
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
@torch.inference_mode() @torch.inference_mode()
......
...@@ -6,11 +6,12 @@ import copy ...@@ -6,11 +6,12 @@ import copy
import json import json
import pickle import pickle
import time import time
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from itertools import product from itertools import product
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Optional from typing import Any
import torch import torch
import torch.utils.benchmark as TBenchmark import torch.utils.benchmark as TBenchmark
...@@ -18,13 +19,24 @@ from torch.utils.benchmark import Measurement as TMeasurement ...@@ -18,13 +19,24 @@ from torch.utils.benchmark import Measurement as TMeasurement
from utils import ArgPool, Bench, CudaGraphBenchParams from utils import ArgPool, Bench, CudaGraphBenchParams
from weight_shapes import WEIGHT_SHAPES from weight_shapes import WEIGHT_SHAPES
from vllm.triton_utils import HAS_TRITON from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
from vllm.triton_utils import HAS_TRITON, triton
if HAS_TRITON: if HAS_TRITON:
from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink from vllm.lora.ops.triton_ops import ( ## added fused_moe_lora
LoRAKernelMeta,
fused_moe_lora_expand,
fused_moe_lora_shrink,
lora_expand,
lora_shrink,
)
from vllm.lora.ops.triton_ops.fused_moe_lora_op import (
_LORA_PTR_DICT, ## added _LORA_PTR_DICT for fused_moe_lora
)
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm import _custom_ops as ops
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.math_utils import round_up
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_TP_SIZES = [1] DEFAULT_TP_SIZES = [1]
...@@ -58,6 +70,8 @@ DEFAULT_NUM_LORAS = [1, 2, 3, 4] ...@@ -58,6 +70,8 @@ DEFAULT_NUM_LORAS = [1, 2, 3, 4]
DEFAULT_SORT_BY_LORA_IDS = [False, True] DEFAULT_SORT_BY_LORA_IDS = [False, True]
DEFAULT_SEQ_LENGTHS = [1] DEFAULT_SEQ_LENGTHS = [1]
DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False] DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False]
DEFAULT_TOP_K_NUMS = [1] # Added for MoE LoRA top_k
DEFAULT_NUM_EXPERTS = [8] # Added for MoE LoRA num_experts
# Utilities # Utilities
...@@ -158,7 +172,7 @@ def ref_group_gemm( ...@@ -158,7 +172,7 @@ def ref_group_gemm(
seq_lens_cpu: torch.Tensor, seq_lens_cpu: torch.Tensor,
prompt_lora_mapping_cpu: torch.Tensor, prompt_lora_mapping_cpu: torch.Tensor,
scaling: float, scaling: float,
add_inputs: Optional[bool], add_inputs: bool | None,
): ):
""" """
Torch group gemm reference implementation to test correctness of Torch group gemm reference implementation to test correctness of
...@@ -190,6 +204,11 @@ class OpType(Enum): ...@@ -190,6 +204,11 @@ class OpType(Enum):
LORA_SHRINK = auto() LORA_SHRINK = auto()
LORA_EXPAND = auto() LORA_EXPAND = auto()
## Adding support for fused moe lora
FUSED_MOE_LORA_GATE_UP_SHRINK = auto() ## Gate/Up projection variant with shrink
FUSED_MOE_LORA_GATE_UP_EXPAND = auto() ## Gate/Up projection variant with expand
FUSED_MOE_LORA_DOWN_SHRINK = auto() ## Down projection variant with shrink
FUSED_MOE_LORA_DOWN_EXPAND = auto() ## Down projection variant with expand
@staticmethod @staticmethod
def from_str(s: str) -> "OpType": def from_str(s: str) -> "OpType":
...@@ -197,6 +216,15 @@ class OpType(Enum): ...@@ -197,6 +216,15 @@ class OpType(Enum):
return OpType.LORA_SHRINK return OpType.LORA_SHRINK
if s.lower() == "lora_expand": if s.lower() == "lora_expand":
return OpType.LORA_EXPAND return OpType.LORA_EXPAND
# Adding support for fused moe lora, both in gate_up and down
if s.lower() == "fused_moe_lora_gate_up_shrink": ## Gate/Up variant with shrink
return OpType.FUSED_MOE_LORA_GATE_UP_SHRINK
if s.lower() == "fused_moe_lora_gate_up_expand": ## Gate/Up variant with expand
return OpType.FUSED_MOE_LORA_GATE_UP_EXPAND
if s.lower() == "fused_moe_lora_down_shrink": ## Down variant with shrink
return OpType.FUSED_MOE_LORA_DOWN_SHRINK
if s.lower() == "fused_moe_lora_down_expand": ## Down variant with expand
return OpType.FUSED_MOE_LORA_DOWN_EXPAND
raise ValueError(f"Unrecognized str {s} to convert to OpType") raise ValueError(f"Unrecognized str {s} to convert to OpType")
def is_shrink_fn(self) -> bool: def is_shrink_fn(self) -> bool:
...@@ -205,19 +233,56 @@ class OpType(Enum): ...@@ -205,19 +233,56 @@ class OpType(Enum):
def is_expand_fn(self) -> bool: def is_expand_fn(self) -> bool:
return self in [OpType.LORA_EXPAND] return self in [OpType.LORA_EXPAND]
def is_fused_moe_lora_fn(self) -> bool: ## adding for fused MoE LoRA
return self in [
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
OpType.FUSED_MOE_LORA_DOWN_SHRINK,
OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
OpType.FUSED_MOE_LORA_DOWN_EXPAND,
]
def is_fused_moe_lora_gate_up_fn(
self,
) -> bool: ## adding for fused MoE LoRA Gate/Up
return self in [
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
]
def is_fused_moe_lora_down_fn(self) -> bool: ## adding for fused MoE LoRA Down
return self in [
OpType.FUSED_MOE_LORA_DOWN_SHRINK,
OpType.FUSED_MOE_LORA_DOWN_EXPAND,
]
def is_fused_moe_lora_shrink_fn(self) -> bool:
return self in [
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
OpType.FUSED_MOE_LORA_DOWN_SHRINK,
]
def is_fused_moe_lora_expand_fn(self) -> bool:
return self in [
OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
OpType.FUSED_MOE_LORA_DOWN_EXPAND,
]
def num_slices(self) -> list[int]: def num_slices(self) -> list[int]:
if self.is_fused_moe_lora_gate_up_fn():
return [2]
elif self.is_fused_moe_lora_down_fn():
return [1]
return [1, 2, 3] return [1, 2, 3]
def mkn( def mkn(
self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int
) -> tuple[int, int, int]: ) -> tuple[int, int, int]:
num_tokens = batch_size * seq_length num_tokens = batch_size * seq_length
if self.is_shrink_fn(): if self.is_shrink_fn() or self.is_fused_moe_lora_fn():
m = num_tokens m = num_tokens
k = hidden_size k = hidden_size
n = lora_rank n = lora_rank
else: elif self.is_expand_fn():
assert self.is_expand_fn()
m = num_tokens m = num_tokens
k = lora_rank k = lora_rank
n = hidden_size n = hidden_size
...@@ -231,9 +296,36 @@ class OpType(Enum): ...@@ -231,9 +296,36 @@ class OpType(Enum):
""" """
if self.is_shrink_fn(): if self.is_shrink_fn():
return op_dtype, op_dtype, torch.float32 return op_dtype, op_dtype, torch.float32
else: elif self.is_expand_fn():
assert self.is_expand_fn()
return torch.float32, op_dtype, op_dtype return torch.float32, op_dtype, op_dtype
else:
assert self.is_fused_moe_lora_fn()
return op_dtype, op_dtype, op_dtype
def matmul_shapes_fused_moe_lora(
self,
m: int,
n: int,
k: int,
num_loras: int,
num_slices: int,
top_k_num: int,
num_experts: int,
) -> tuple[tuple[int], tuple[int], tuple[int], tuple[int]]:
if self.is_fused_moe_lora_shrink_fn():
input_shape = (
(m * top_k_num, n)
if self in [OpType.FUSED_MOE_LORA_DOWN_SHRINK]
else (m, n)
)
output_shape = (num_slices, m, top_k_num, k)
weight_shape = (num_loras, num_experts, k, n)
else:
assert self.is_fused_moe_lora_expand_fn()
input_shape = (num_slices, m, top_k_num, k)
output_shape = (m, top_k_num, n * num_slices)
weight_shape = (num_loras, num_experts, n, k)
return (input_shape, weight_shape, output_shape)
def matmul_shapes( def matmul_shapes(
self, self,
...@@ -243,6 +335,8 @@ class OpType(Enum): ...@@ -243,6 +335,8 @@ class OpType(Enum):
lora_rank: int, lora_rank: int,
num_loras: int, num_loras: int,
num_slices: int, num_slices: int,
top_k_num: int | None = None,
num_experts: int | None = None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
""" """
Given num_slices, return the shapes of the A, B, and C matrices Given num_slices, return the shapes of the A, B, and C matrices
...@@ -257,6 +351,16 @@ class OpType(Enum): ...@@ -257,6 +351,16 @@ class OpType(Enum):
if self in [OpType.LORA_EXPAND]: if self in [OpType.LORA_EXPAND]:
# LoRA expand kernels support num_slices inherently in the kernel # LoRA expand kernels support num_slices inherently in the kernel
return ((num_slices, m, k), b_shape, (m, n * num_slices)) return ((num_slices, m, k), b_shape, (m, n * num_slices))
if self.is_fused_moe_lora_fn():
return self.matmul_shapes_fused_moe_lora(
m,
k,
n,
num_loras,
num_slices,
top_k_num,
num_experts,
)
raise ValueError(f"Unrecognized op_type {self}") raise ValueError(f"Unrecognized op_type {self}")
def bench_fn(self) -> Callable: def bench_fn(self) -> Callable:
...@@ -264,6 +368,16 @@ class OpType(Enum): ...@@ -264,6 +368,16 @@ class OpType(Enum):
return lora_shrink return lora_shrink
if self == OpType.LORA_EXPAND: if self == OpType.LORA_EXPAND:
return lora_expand return lora_expand
if self in [
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
OpType.FUSED_MOE_LORA_DOWN_SHRINK,
]:
return fused_moe_lora_shrink
if self in [
OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
OpType.FUSED_MOE_LORA_DOWN_EXPAND,
]:
return fused_moe_lora_expand
raise ValueError(f"Unrecognized optype {self}") raise ValueError(f"Unrecognized optype {self}")
...@@ -316,8 +430,10 @@ class BenchmarkContext: ...@@ -316,8 +430,10 @@ class BenchmarkContext:
lora_rank: int lora_rank: int
sort_by_lora_id: bool sort_by_lora_id: bool
dtype: torch.dtype dtype: torch.dtype
seq_length: Optional[int] = None seq_length: int | None = None
num_slices: Optional[int] = None # num_slices for slice based ops num_experts: int | None = None # num_experts for MoE based ops
top_k_num: int | None = None # top_k for MoE based ops
num_slices: int | None = None # num_slices for slice based ops
def with_seq_length(self, seq_length: int) -> "BenchmarkContext": def with_seq_length(self, seq_length: int) -> "BenchmarkContext":
ctx = copy.copy(self) ctx = copy.copy(self)
...@@ -372,6 +488,11 @@ class BenchmarkTensors: ...@@ -372,6 +488,11 @@ class BenchmarkTensors:
f"{dtype_to_str(self.output.dtype)}" f"{dtype_to_str(self.output.dtype)}"
) )
def get_num_tokens(self, size: int, top_k_num: int, op_type: OpType):
return (
size * top_k_num if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] else size
)
@staticmethod @staticmethod
def make( def make(
ctx: BenchmarkContext, op_type: OpType, device: str = "cuda" ctx: BenchmarkContext, op_type: OpType, device: str = "cuda"
...@@ -384,6 +505,8 @@ class BenchmarkTensors: ...@@ -384,6 +505,8 @@ class BenchmarkTensors:
ctx.lora_rank, ctx.lora_rank,
ctx.num_loras, ctx.num_loras,
ctx.num_slices, ctx.num_slices,
ctx.top_k_num,
ctx.num_experts,
) )
a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype) a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype)
input_tensor, lora_weights, output_tensor = make_rand_tensors( input_tensor, lora_weights, output_tensor = make_rand_tensors(
...@@ -431,17 +554,27 @@ class BenchmarkTensors: ...@@ -431,17 +554,27 @@ class BenchmarkTensors:
prompt_lora_indices_tensor, prompt_lora_indices_tensor,
) )
def sanity_check(self) -> None: def sanity_check(self, ctx: BenchmarkContext, op_type: OpType) -> None:
""" """
Fails asserts when non-conformality is detected. Fails asserts when non-conformality is detected.
""" """
num_tokens = self.input.shape[-2] num_tokens = (
self.input.shape[1]
if op_type.is_fused_moe_lora_expand_fn()
else self.input.shape[-2]
)
# check metadata tensors # check metadata tensors
assert torch.sum(self.seq_lens) == num_tokens ## In down shrink case, each token is repeated top_k_num times
assert num_tokens == self.get_num_tokens(
torch.sum(self.seq_lens), ctx.top_k_num, op_type
), f"Expected {num_tokens} tokens, but got {torch.sum(self.seq_lens)}"
num_seqs = self.seq_lens.shape[0] num_seqs = self.seq_lens.shape[0]
# assert self.seq_start_loc.shape[0] == num_seqs # assert self.seq_start_loc.shape[0] == num_seqs
## In down shrink case, each prompt corresponds to top_k_num sequences
assert self.prompt_lora_mapping.shape[0] == num_seqs assert self.prompt_lora_mapping.shape[0] == num_seqs
assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens assert self.get_num_tokens(
self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type
)
def to_device(self, device: str): def to_device(self, device: str):
""" """
...@@ -470,21 +603,111 @@ class BenchmarkTensors: ...@@ -470,21 +603,111 @@ class BenchmarkTensors:
to_device(field) if field_name != "no_lora_flag_cpu" else field, to_device(field) if field_name != "no_lora_flag_cpu" else field,
) )
def metadata(self) -> tuple[int, int, int]: def metadata(self, ctx: BenchmarkContext, op_type: OpType) -> tuple[int, int, int]:
""" """
Return num_seqs, num_tokens and max_seq_len Return num_seqs, num_tokens and max_seq_len
""" """
num_seqs = self.seq_lens.shape[0] num_seqs = self.seq_lens.shape[0]
num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0] num_tokens = self.get_num_tokens(
self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type
)
max_seq_len = torch.max(self.seq_lens).item() max_seq_len = torch.max(self.seq_lens).item()
num_slices = len(self.lora_weights_lst) num_slices = len(self.lora_weights_lst)
return num_seqs, num_tokens, max_seq_len, num_slices return num_seqs, num_tokens, max_seq_len, num_slices
def as_lora_shrink_kwargs(self) -> dict[str, Any]: def fused_moe_lora_data_prepare(
self.sanity_check() self,
block_size: int,
token_lora_mapping: torch.Tensor,
ctx: BenchmarkContext,
):
def moe_lora_align_block_size(
topk_ids: torch.Tensor,
token_lora_mapping: torch.Tensor,
block_size: int,
num_experts: int,
max_loras: int,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
sorted_ids = torch.empty(
(max_loras * max_num_tokens_padded,),
dtype=torch.int32,
device=topk_ids.device,
)
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be set default to -1 to prevent a blank block
expert_ids = torch.empty(
(max_loras * max_num_m_blocks,),
dtype=torch.int32,
device=topk_ids.device,
)
num_tokens_post_pad = torch.empty(
(max_loras), dtype=torch.int32, device=topk_ids.device
)
ops.moe_lora_align_block_size(
topk_ids,
token_lora_mapping,
num_experts,
block_size,
max_loras,
max_num_tokens_padded,
max_num_m_blocks,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad
num_tokens = ctx.batch_size
curr_topk_ids = torch.randint(
0,
ctx.num_experts,
(num_tokens, ctx.top_k_num),
device="cuda",
dtype=torch.int32,
)
topk_weights = torch.randint(
0,
ctx.num_experts,
(num_tokens, ctx.top_k_num),
device="cuda",
dtype=torch.int32,
)
(sorted_token_ids_lora, expert_ids_lora, num_tokens_post_padded_lora) = (
moe_lora_align_block_size(
topk_ids=curr_topk_ids,
token_lora_mapping=token_lora_mapping,
block_size=block_size,
num_experts=ctx.num_experts,
max_loras=ctx.num_loras,
)
)
sorted_token_ids = sorted_token_ids_lora.view(ctx.num_loras, -1)
expert_ids = expert_ids_lora.view(ctx.num_loras, -1)
num_tokens_post_padded = num_tokens_post_padded_lora
return (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded)
def as_lora_shrink_kwargs(
self, ctx: BenchmarkContext, op_type: OpType
) -> dict[str, Any]:
self.sanity_check(ctx, op_type)
self.to_device(self.input.device) self.to_device(self.input.device)
_, num_tokens, _, num_slices = self.metadata() _, num_tokens, _, num_slices = self.metadata(ctx, op_type)
# Sanity check matrix shapes. # Sanity check matrix shapes.
i_shape, lw_shape, o_shape = ( i_shape, lw_shape, o_shape = (
...@@ -519,11 +742,13 @@ class BenchmarkTensors: ...@@ -519,11 +742,13 @@ class BenchmarkTensors:
"no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
} }
def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: def as_lora_expand_kwargs(
self.sanity_check() self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool
) -> dict[str, Any]:
self.sanity_check(ctx, op_type)
self.to_device(self.input.device) self.to_device(self.input.device)
_, num_tokens, _, num_slices = self.metadata() _, num_tokens, _, num_slices = self.metadata(ctx, op_type)
# Sanity check matrix shapes. # Sanity check matrix shapes.
i_shape, lw_shape, o_shape = ( i_shape, lw_shape, o_shape = (
...@@ -560,22 +785,177 @@ class BenchmarkTensors: ...@@ -560,22 +785,177 @@ class BenchmarkTensors:
"no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
} }
def as_fused_moe_lora_shrink_kwargs(
self, ctx: BenchmarkContext, op_type: OpType
) -> dict[str, Any]:
self.sanity_check(ctx, op_type)
self.to_device(self.input.device)
_, num_tokens, _, num_slices = self.metadata(ctx, op_type)
# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = (
self.input.shape,
self.lora_weights_lst[0].shape,
self.output.shape,
)
# Expected input shape : [num_tokens, hidden_size] for gate_up
# Expected input shape : [top_k_num * num_tokens, hidden_size] for down
assert len(i_shape) == 2
assert i_shape[0] == num_tokens
hidden_size = i_shape[1]
# Expected lora weight shape [max_lora, num_experts, lora_rank, hidden_size]
assert len(lw_shape) == 4
assert lw_shape[-1] == hidden_size
lora_rank = lw_shape[-2]
# Expected output shape : [num_slices, num_tokens, top_k_num, lora_rank]
assert len(o_shape) == 4
assert (
o_shape
== (num_slices, num_tokens // ctx.top_k_num, ctx.top_k_num, lora_rank)
if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK]
else o_shape == (num_slices, num_tokens, ctx.top_k_num, lora_rank)
)
kernel_config = get_lora_op_configs(
op_type.name.lower(),
max_loras=lw_shape[0],
batch=num_tokens,
hidden_size=hidden_size,
rank=lora_rank,
num_slices=num_slices,
add_inputs=False,
)
(topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = (
self.fused_moe_lora_data_prepare(
block_size=kernel_config["BLOCK_SIZE_M"],
token_lora_mapping=self.lora_kernel_meta.token_lora_mapping,
ctx=ctx,
)
)
return {
"qcurr_hidden_states": self.input,
"lora_a_stacked": self.lora_weights_lst,
"a_intermediate_cache1": self.output,
"topk_weights": topk_weights,
"sorted_token_ids": sorted_token_ids,
"expert_ids": expert_ids,
"num_tokens_post_padded": num_tokens_post_padded,
"top_k_num": ctx.top_k_num,
"device": self.input.device,
"N": lora_rank,
"M": topk_weights.shape[0],
"EM": sorted_token_ids.shape[1],
"K": self.input.shape[1],
"num_tokens": num_tokens,
"num_experts": ctx.num_experts,
"num_slices": num_slices,
"shrink_block_size_m": kernel_config["BLOCK_SIZE_M"],
"shrink_block_size_n": kernel_config["BLOCK_SIZE_N"],
"shrink_block_size_k": kernel_config["BLOCK_SIZE_K"],
"shrink_group_size_m": kernel_config["GROUP_SIZE_M"],
"shrink_num_warps": kernel_config["NUM_WARPS"],
"shrink_num_stages": kernel_config["NUM_STAGES"],
"shrink_split_k": kernel_config.get("SPLIT_K", 1),
"mul_routed_weight": op_type.is_fused_moe_lora_down_fn(),
}
def as_fused_moe_lora_expand_kwargs(
self, ctx: BenchmarkContext, op_type: OpType
) -> dict[str, Any]:
self.sanity_check(ctx, op_type)
self.to_device(self.input.device)
_, num_tokens, _, num_slices = self.metadata(ctx, op_type)
# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = (
self.input.shape,
self.lora_weights_lst[0].shape,
self.output.shape,
)
# Expected input shape : [num_slices, num_tokens, top_k_num, lora_rank]
assert len(i_shape) == 4
assert i_shape[0] == num_slices
assert i_shape[1] == num_tokens
lora_rank = i_shape[-1]
# Expected lora weight shape : [num_loras, num_experts, hidden_size, lora_rank]
assert len(lw_shape) == 4
assert lw_shape[-1] == lora_rank
hidden_size = lw_shape[-2]
# Expected output shape : [num_tokens, top_k_num, hidden_size * num_slices]
assert len(o_shape) == 3
assert o_shape == (num_tokens, ctx.top_k_num, hidden_size * num_slices)
kernel_config = get_lora_op_configs(
op_type.name.lower(),
max_loras=lw_shape[0],
batch=num_tokens,
hidden_size=hidden_size,
rank=lora_rank,
num_slices=num_slices,
add_inputs=False,
)
(topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = (
self.fused_moe_lora_data_prepare(
block_size=kernel_config["BLOCK_SIZE_M"],
token_lora_mapping=self.lora_kernel_meta.token_lora_mapping,
ctx=ctx,
)
)
return {
"a_intermediate_cache1": self.input,
"lora_b_stacked": self.lora_weights_lst,
"output": self.output,
"topk_weights": topk_weights,
"sorted_token_ids": sorted_token_ids,
"expert_ids": expert_ids,
"num_tokens_post_padded": num_tokens_post_padded,
"top_k_num": ctx.top_k_num,
"device": self.input.device,
"N": lora_rank,
"M": topk_weights.shape[0],
"EM": sorted_token_ids.shape[1],
"K": self.input.shape[1],
"num_tokens": num_tokens,
"num_experts": ctx.num_experts,
"num_slices": num_slices,
"max_lora_rank": lora_rank,
"w1_output_dim_size": lw_shape[2],
"expand_block_size_m": kernel_config["BLOCK_SIZE_M"],
"expand_block_size_n": kernel_config["BLOCK_SIZE_N"],
"expand_block_size_k": kernel_config["BLOCK_SIZE_K"],
"expand_group_size_m": kernel_config["GROUP_SIZE_M"],
"expand_num_warps": kernel_config["NUM_WARPS"],
"expand_num_stages": kernel_config["NUM_STAGES"],
"expand_split_k": kernel_config.get("SPLIT_K", 1),
"mul_routed_weight": op_type.is_fused_moe_lora_down_fn(),
}
def bench_fn_kwargs( def bench_fn_kwargs(
self, op_type: OpType, add_inputs: Optional[bool] = None self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool | None = None
) -> dict[str, Any]: ) -> dict[str, Any]:
if op_type.is_shrink_fn(): if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn():
assert add_inputs is None assert add_inputs is None
else: else:
assert add_inputs is not None assert add_inputs is not None
if op_type == OpType.LORA_SHRINK: if op_type == OpType.LORA_SHRINK:
return self.as_lora_shrink_kwargs() return self.as_lora_shrink_kwargs(ctx, op_type)
if op_type == OpType.LORA_EXPAND: if op_type == OpType.LORA_EXPAND:
return self.as_lora_expand_kwargs(add_inputs) return self.as_lora_expand_kwargs(ctx, op_type, add_inputs)
if op_type.is_fused_moe_lora_shrink_fn():
return self.as_fused_moe_lora_shrink_kwargs(ctx, op_type)
if op_type.is_fused_moe_lora_expand_fn():
return self.as_fused_moe_lora_expand_kwargs(ctx, op_type)
raise ValueError(f"Unrecognized optype {self}") raise ValueError(f"Unrecognized optype {self}")
def test_correctness( def test_correctness(
self, op_type: OpType, expand_fn_add_inputs: Optional[bool] self, op_type: OpType, expand_fn_add_inputs: bool | None
) -> bool: ) -> bool:
""" """
Test correctness of op_type implementation against a grouped gemm Test correctness of op_type implementation against a grouped gemm
...@@ -611,12 +991,12 @@ def bench_optype( ...@@ -611,12 +991,12 @@ def bench_optype(
ctx: BenchmarkContext, ctx: BenchmarkContext,
arg_pool_size: int, arg_pool_size: int,
op_type: OpType, op_type: OpType,
cuda_graph_nops: Optional[int] = None, cuda_graph_nops: int | None = None,
expand_fn_add_inputs: Optional[bool] = None, expand_fn_add_inputs: bool | None = None,
test_correctness: bool = False, test_correctness: bool = False,
) -> TMeasurement: ) -> TMeasurement:
assert arg_pool_size >= 1 assert arg_pool_size >= 1
if op_type.is_shrink_fn(): if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn():
assert expand_fn_add_inputs is None assert expand_fn_add_inputs is None
else: else:
assert expand_fn_add_inputs is not None assert expand_fn_add_inputs is not None
...@@ -626,23 +1006,30 @@ def bench_optype( ...@@ -626,23 +1006,30 @@ def bench_optype(
BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size) BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)
] ]
for bt in bench_tensors: for bt in bench_tensors:
bt.sanity_check() bt.sanity_check(ctx, op_type)
# Test correctness of our implementation. # Test correctness of our implementation.
if test_correctness: if test_correctness:
assert op_type in [OpType.LORA_SHRINK, OpType.LORA_EXPAND], (
f"Correctness testing is not supported for {op_type.name}."
)
assert all( assert all(
[bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors] [
bt.test_correctness(ctx, op_type, expand_fn_add_inputs)
for bt in bench_tensors
]
) )
# BenchmarkTensors -> dict (kwargs) # BenchmarkTensors -> dict (kwargs)
kwargs_list = [ kwargs_list = [
bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs) bt.bench_fn_kwargs(ctx, op_type, add_inputs=expand_fn_add_inputs)
for bt in bench_tensors for bt in bench_tensors
] ]
# Clear LoRA optimization hash-maps. # Clear LoRA optimization hash-maps.
_LORA_A_PTR_DICT.clear() _LORA_A_PTR_DICT.clear()
_LORA_B_PTR_DICT.clear() _LORA_B_PTR_DICT.clear()
_LORA_PTR_DICT.clear()
# Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up
for kwargs in kwargs_list: for kwargs in kwargs_list:
op_type.bench_fn()(**kwargs) op_type.bench_fn()(**kwargs)
...@@ -679,7 +1066,7 @@ def bench_torch_mm( ...@@ -679,7 +1066,7 @@ def bench_torch_mm(
ctx: BenchmarkContext, ctx: BenchmarkContext,
arg_pool_size: int, arg_pool_size: int,
op_type: OpType, op_type: OpType,
cuda_graph_nops: Optional[int] = None, cuda_graph_nops: int | None = None,
) -> TMeasurement: ) -> TMeasurement:
""" """
Benchmark basic torch.mm as a roofline. Benchmark basic torch.mm as a roofline.
...@@ -744,7 +1131,7 @@ def use_cuda_graph_recommendation() -> str: ...@@ -744,7 +1131,7 @@ def use_cuda_graph_recommendation() -> str:
""" """
def print_timers(timers: list[TMeasurement], args: Optional[argparse.Namespace] = None): def print_timers(timers: list[TMeasurement], args: argparse.Namespace | None = None):
compare = TBenchmark.Compare(timers) compare = TBenchmark.Compare(timers)
compare.print() compare.print()
...@@ -792,7 +1179,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): ...@@ -792,7 +1179,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
# Benchmark bench_op # Benchmark bench_op
expand_fn_add_inputs = ( expand_fn_add_inputs = (
[None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs [None]
if bench_op.is_shrink_fn() or bench_op.is_fused_moe_lora_fn()
else args.expand_fn_add_inputs
) )
for add_input_arg in expand_fn_add_inputs: for add_input_arg in expand_fn_add_inputs:
seq_len_timers.append( seq_len_timers.append(
...@@ -830,12 +1219,22 @@ def as_benchmark_contexts( ...@@ -830,12 +1219,22 @@ def as_benchmark_contexts(
hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace
) -> list[BenchmarkContext]: ) -> list[BenchmarkContext]:
ctxs: list[BenchmarkContext] = [] ctxs: list[BenchmarkContext] = []
for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa for (
batch_size,
hidden_size,
lora_rank,
num_loras,
sort_by_lora_id,
top_k_num,
num_experts,
) in product( # noqa
args.batch_sizes, args.batch_sizes,
list(hidden_sizes), list(hidden_sizes),
lora_ranks, lora_ranks,
args.num_loras, args.num_loras,
args.sort_by_lora_id, args.sort_by_lora_id,
args.top_k_nums,
args.num_experts,
): ):
ctxs.append( ctxs.append(
BenchmarkContext( BenchmarkContext(
...@@ -850,6 +1249,8 @@ def as_benchmark_contexts( ...@@ -850,6 +1249,8 @@ def as_benchmark_contexts(
seq_length=None, seq_length=None,
sort_by_lora_id=sort_by_lora_id, sort_by_lora_id=sort_by_lora_id,
dtype=args.dtype, dtype=args.dtype,
top_k_num=top_k_num,
num_experts=num_experts,
# To be filled based on the OpType to benchmark # To be filled based on the OpType to benchmark
num_slices=None, num_slices=None,
) )
...@@ -1011,6 +1412,22 @@ if __name__ == "__main__": ...@@ -1011,6 +1412,22 @@ if __name__ == "__main__":
), ),
) )
p.add_argument(
"--top-k-nums",
nargs="+",
type=int,
default=DEFAULT_TOP_K_NUMS,
help="Top-K values for MoE LoRA operations",
)
p.add_argument(
"--num-experts",
nargs="+",
type=int,
default=DEFAULT_NUM_EXPERTS,
help="Number of experts for MoE LoRA operations",
)
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description=f""" description=f"""
Benchmark LoRA kernels: Benchmark LoRA kernels:
......
...@@ -8,10 +8,9 @@ import math ...@@ -8,10 +8,9 @@ import math
import os import os
import pickle as pkl import pickle as pkl
import time import time
from collections.abc import Iterable from collections.abc import Callable, Iterable
from dataclasses import dataclass from dataclasses import dataclass
from itertools import product from itertools import product
from typing import Callable, Optional
import pandas as pd import pandas as pd
import torch import torch
...@@ -34,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -34,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights, quantize_weights,
) )
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"] DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024] DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
...@@ -63,23 +62,23 @@ class BenchmarkTensors: ...@@ -63,23 +62,23 @@ class BenchmarkTensors:
a: torch.Tensor a: torch.Tensor
w_q: torch.Tensor w_q: torch.Tensor
group_size: Optional[int] group_size: int | None
wtype: ScalarType wtype: ScalarType
w_g_s: torch.Tensor w_g_s: torch.Tensor
w_g_zp: Optional[torch.Tensor] w_g_zp: torch.Tensor | None
w_ch_s: Optional[torch.Tensor] w_ch_s: torch.Tensor | None
w_tok_s: Optional[torch.Tensor] w_tok_s: torch.Tensor | None
@dataclass @dataclass
class TypeConfig: class TypeConfig:
act_type: torch.dtype act_type: torch.dtype
weight_type: ScalarType weight_type: ScalarType
output_type: Optional[torch.dtype] output_type: torch.dtype | None
group_scale_type: Optional[torch.dtype] group_scale_type: torch.dtype | None
group_zero_type: Optional[torch.dtype] group_zero_type: torch.dtype | None
channel_scale_type: Optional[torch.dtype] channel_scale_type: torch.dtype | None
token_scale_type: Optional[torch.dtype] token_scale_type: torch.dtype | None
def rand_data(shape, dtype=torch.float16, scale=1): def rand_data(shape, dtype=torch.float16, scale=1):
...@@ -93,8 +92,8 @@ def quantize_and_pack( ...@@ -93,8 +92,8 @@ def quantize_and_pack(
atype: torch.dtype, atype: torch.dtype,
w: torch.Tensor, w: torch.Tensor,
wtype: ScalarType, wtype: ScalarType,
stype: Optional[torch.dtype], stype: torch.dtype | None,
group_size: Optional[int], group_size: int | None,
zero_points: bool = False, zero_points: bool = False,
): ):
assert wtype.is_integer(), "TODO: support floating point weights" assert wtype.is_integer(), "TODO: support floating point weights"
...@@ -113,7 +112,7 @@ def quantize_and_pack( ...@@ -113,7 +112,7 @@ def quantize_and_pack(
def create_bench_tensors( def create_bench_tensors(
shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int] shape: tuple[int, int, int], types: TypeConfig, group_size: int | None
) -> list[BenchmarkTensors]: ) -> list[BenchmarkTensors]:
m, n, k = shape m, n, k = shape
...@@ -331,8 +330,8 @@ def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable]) ...@@ -331,8 +330,8 @@ def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable])
return res return res
_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None _SWEEP_SCHEDULES_RESULTS: pd.DataFrame | None = None
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None _SWEEP_SCHEDULES_RESULTS_CSV: str | None = None
def bench( def bench(
......
...@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
sort_weights, sort_weights,
) )
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
......
...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import * ...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
...@@ -211,7 +211,7 @@ def get_rocm_tuning_space(use_fp16): ...@@ -211,7 +211,7 @@ def get_rocm_tuning_space(use_fp16):
num_warps_range = [1, 2, 4, 8] num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 4, 8, 16, 32] group_m_range = [1, 4, 8, 16, 32]
num_stage_range = [2] num_stage_range = [2]
waves_per_eu_range = [0] waves_per_eu_range = [0, 1, 2, 4]
matrix_instr_nonkdim_range = [16, 32] if use_fp16 else [] matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
kpack_range = [1, 2] if use_fp16 else [] kpack_range = [1, 2] if use_fp16 else []
...@@ -579,19 +579,23 @@ def main(args: argparse.Namespace): ...@@ -579,19 +579,23 @@ def main(args: argparse.Namespace):
E = config.ffn_config.moe_num_experts E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size intermediate_size = config.ffn_config.ffn_hidden_size
hidden_size = config.hidden_size
elif config.architectures[0] == "JambaForCausalLM": elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts E = config.num_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] in ( elif config.architectures[0] in (
"DeepseekV2ForCausalLM", "DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM", "DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM", "DeepseekV32ForCausalLM",
"Glm4MoeForCausalLM", "Glm4MoeForCausalLM",
"NemotronHForCausalLM",
): ):
E = config.n_routed_experts E = config.n_routed_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] in ( elif config.architectures[0] in (
"Qwen2MoeForCausalLM", "Qwen2MoeForCausalLM",
"Qwen3MoeForCausalLM", "Qwen3MoeForCausalLM",
...@@ -600,10 +604,23 @@ def main(args: argparse.Namespace): ...@@ -600,10 +604,23 @@ def main(args: argparse.Namespace):
E = config.num_experts E = config.num_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] == "Qwen3VLMoeForConditionalGeneration":
text_config = config.get_text_config()
E = text_config.num_experts
topk = text_config.num_experts_per_tok
intermediate_size = text_config.moe_intermediate_size
hidden_size = text_config.hidden_size
elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"): elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
E = config.num_experts E = config.num_experts
topk = config.moe_topk[0] topk = config.moe_topk[0]
intermediate_size = config.moe_intermediate_size[0] intermediate_size = config.moe_intermediate_size[0]
hidden_size = config.hidden_size
elif config.architectures[0] in ["Qwen3OmniMoeForConditionalGeneration"]:
E = config.thinker_config.text_config.num_experts
topk = config.thinker_config.text_config.num_experts_per_tok
intermediate_size = config.thinker_config.text_config.moe_intermediate_size
hidden_size = config.thinker_config.text_config.hidden_size
else: else:
# Support for llama4 # Support for llama4
config = config.get_text_config() config = config.get_text_config()
...@@ -611,6 +628,7 @@ def main(args: argparse.Namespace): ...@@ -611,6 +628,7 @@ def main(args: argparse.Namespace):
E = config.num_local_experts E = config.num_local_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
hidden_size = config.hidden_size
enable_ep = bool(args.enable_expert_parallel) enable_ep = bool(args.enable_expert_parallel)
if enable_ep: if enable_ep:
ensure_divisibility(E, args.tp_size, "Number of experts") ensure_divisibility(E, args.tp_size, "Number of experts")
...@@ -619,8 +637,7 @@ def main(args: argparse.Namespace): ...@@ -619,8 +637,7 @@ def main(args: argparse.Namespace):
else: else:
ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size") ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size dtype = torch.float16 if current_platform.is_rocm() else config.dtype
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
block_quant_shape = get_weight_block_size_safety(config) block_quant_shape = get_weight_block_size_safety(config)
......
...@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( ...@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
) )
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
...@@ -344,7 +344,7 @@ def main(args: argparse.Namespace): ...@@ -344,7 +344,7 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
hidden_size = config.hidden_size hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype dtype = torch.float16 if current_platform.is_rocm() else config.dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
use_customized_permute = args.use_customized_permute use_customized_permute = args.use_customized_permute
......
...@@ -39,7 +39,7 @@ import torch ...@@ -39,7 +39,7 @@ import torch
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config from vllm.transformers_utils.config import get_config
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
......
...@@ -3,16 +3,15 @@ ...@@ -3,16 +3,15 @@
import random import random
import time import time
from typing import Optional
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import ( from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE, STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random, create_kv_caches_with_random,
) )
...@@ -37,7 +36,7 @@ def main( ...@@ -37,7 +36,7 @@ def main(
seed: int, seed: int,
do_profile: bool, do_profile: bool,
device: str = "cuda", device: str = "cuda",
kv_cache_dtype: Optional[str] = None, kv_cache_dtype: str | None = None,
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
import argparse import argparse
import math import math
from collections.abc import Callable
from contextlib import contextmanager from contextlib import contextmanager
from typing import Callable
from unittest.mock import patch from unittest.mock import patch
import torch import torch
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import torch
from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton
def polynorm_naive(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
def norm(x, eps: float):
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
x = x.float()
return (
(
weight[0] * norm(x**3, eps)
+ weight[1] * norm(x**2, eps)
+ weight[2] * norm(x, eps)
+ bias
)
.to(weight.dtype)
.view(orig_shape)
)
def polynorm_vllm(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
out = torch.empty_like(x)
vllm_ops.poly_norm(out, x, weight, bias, eps)
output = out
output = output.view(orig_shape)
return output
def calculate_diff(batch_size, seq_len, hidden_dim):
dtype = torch.bfloat16
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
weight = torch.ones(3, dtype=dtype, device="cuda")
bias = torch.ones(1, dtype=dtype, device="cuda")
output_naive = polynorm_naive(x, weight, bias)
output_vllm = polynorm_vllm(x, weight, bias)
if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
dim_range = [2048, 4096]
configs = list(itertools.product(dim_range, batch_size_range, seq_length_range))
def get_benchmark():
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["dim", "batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["naive", "vllm"],
line_names=["Naive", "vLLM"],
styles=[("blue", "-"), ("red", "-")],
ylabel="us",
plot_name="polynorm-perf",
args={},
)
)
def benchmark(dim, batch_size, seq_len, provider):
dtype = torch.bfloat16
hidden_dim = dim * 4
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
weight = torch.ones(3, dtype=dtype, device="cuda")
bias = torch.ones(1, dtype=dtype, device="cuda")
quantiles = [0.5, 0.2, 0.8]
if provider == "naive":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: polynorm_naive(x, weight, bias),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: polynorm_vllm(x, weight, bias),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--batch-size",
type=int,
default=4,
help="Batch size",
)
parser.add_argument(
"--seq-len",
type=int,
default=128,
help="Sequence length",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=8192,
help="Intermediate size of MLP",
)
parser.add_argument(
"--save-path",
type=str,
default="./configs/polnorm/",
help="Path to save polnorm benchmark results",
)
args = parser.parse_args()
# Run correctness test
calculate_diff(
batch_size=args.batch_size,
seq_len=args.seq_len,
hidden_dim=args.hidden_dim,
)
benchmark = get_benchmark()
# Run performance benchmark
benchmark.run(print_data=True, save_path=args.save_path)
...@@ -7,7 +7,8 @@ import torch ...@@ -7,7 +7,8 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
@torch.inference_mode() @torch.inference_mode()
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import time
import torch
from tabulate import tabulate
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
create_kv_caches_with_random,
)
logger = init_logger(__name__)
@torch.inference_mode()
def run_benchmark(
num_tokens: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
kv_cache_dtype: str,
num_iters: int,
benchmark_mode: str,
device: str = "cuda",
) -> float:
"""Return latency (seconds) for given num_tokens."""
if kv_cache_dtype == "fp8" and head_size % 16:
raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")
current_platform.seed_everything(42)
torch.set_default_device(device)
# create random key / value tensors [T, H, D].
key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
value = torch.randn_like(key)
# prepare the slot mapping.
# each token is assigned a unique slot in the KV-cache.
num_slots = block_size * num_blocks
if num_tokens > num_slots:
raise ValueError("num_tokens cannot exceed the total number of cache slots")
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
key_caches, value_caches = create_kv_caches_with_random(
num_blocks,
block_size,
1, # num_layers
num_heads,
head_size,
kv_cache_dtype,
dtype,
device=device,
)
key_cache, value_cache = key_caches[0], value_caches[0]
# to free unused memory
del key_caches, value_caches
# compute per-kernel scaling factors for fp8 conversion (if used).
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)
function_under_test = lambda: ops.reshape_and_cache(
key, # noqa: F821
value, # noqa: F821
key_cache, # noqa: F821
value_cache, # noqa: F821
slot_mapping, # noqa: F821
kv_cache_dtype,
k_scale,
v_scale,
)
if benchmark_mode == "cudagraph":
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
function_under_test()
torch.cuda.synchronize()
function_under_test = lambda: g.replay()
def run_cuda_benchmark(n_iters: int) -> float:
nonlocal key, value, key_cache, value_cache, slot_mapping
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_iters):
function_under_test()
torch.cuda.synchronize()
end = time.perf_counter()
return (end - start) / n_iters
# warm-up
run_cuda_benchmark(3)
lat = run_cuda_benchmark(num_iters)
# free tensors to mitigate OOM when sweeping
del key, value, key_cache, value_cache, slot_mapping
torch.cuda.empty_cache()
return lat
def main(args):
rows = []
for exp in range(1, 17):
n_tok = 2**exp
lat = run_benchmark(
num_tokens=n_tok,
num_heads=args.num_heads,
head_size=args.head_size,
block_size=args.block_size,
num_blocks=args.num_blocks,
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
kv_cache_dtype=args.kv_cache_dtype,
num_iters=args.iters,
benchmark_mode=args.mode,
device="cuda",
)
rows.append([n_tok, lat * 1e6]) # convert to microseconds
print(f"Benchmark results for implementation cuda (measuring with {args.mode}):")
print(tabulate(rows, headers=["num_tokens", "latency (µs)"], floatfmt=".3f"))
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--num-heads", type=int, default=128)
parser.add_argument(
"--head-size",
type=int,
choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128,
)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--num-blocks", type=int, default=128 * 128)
parser.add_argument(
"--dtype",
type=str,
choices=["half", "bfloat16", "float"],
default="bfloat16",
)
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8"],
default="auto",
)
parser.add_argument("--iters", type=int, default=200)
parser.add_argument(
"--mode",
type=str,
choices=["cudagraph", "no_graph"],
default="cudagraph",
)
args = parser.parse_args()
main(args)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import random import random
import time import time
...@@ -14,9 +12,9 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import ( ...@@ -14,9 +12,9 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import ( from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE, STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random_flash, create_kv_caches_with_random_flash,
) )
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools import itertools
from typing import Optional, Union
import torch import torch
from flashinfer.norm import fused_add_rmsnorm, rmsnorm from flashinfer.norm import fused_add_rmsnorm, rmsnorm
...@@ -21,8 +20,8 @@ class HuggingFaceRMSNorm(nn.Module): ...@@ -21,8 +20,8 @@ class HuggingFaceRMSNorm(nn.Module):
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: torch.Tensor | None = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
orig_dtype = x.dtype orig_dtype = x.dtype
x = x.to(torch.float32) x = x.to(torch.float32)
if residual is not None: if residual is not None:
...@@ -41,7 +40,7 @@ class HuggingFaceRMSNorm(nn.Module): ...@@ -41,7 +40,7 @@ class HuggingFaceRMSNorm(nn.Module):
def rmsnorm_naive( def rmsnorm_naive(
x: torch.Tensor, x: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: torch.Tensor | None = None,
eps: float = 1e-6, eps: float = 1e-6,
): ):
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
...@@ -65,7 +64,7 @@ def rmsnorm_naive( ...@@ -65,7 +64,7 @@ def rmsnorm_naive(
def rmsnorm_flashinfer( def rmsnorm_flashinfer(
x: torch.Tensor, x: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: torch.Tensor | None = None,
eps: float = 1e-6, eps: float = 1e-6,
): ):
orig_shape = x.shape orig_shape = x.shape
...@@ -89,7 +88,7 @@ def rmsnorm_flashinfer( ...@@ -89,7 +88,7 @@ def rmsnorm_flashinfer(
def rmsnorm_vllm( def rmsnorm_vllm(
x: torch.Tensor, x: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: torch.Tensor | None = None,
eps: float = 1e-6, eps: float = 1e-6,
): ):
orig_shape = x.shape orig_shape = x.shape
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from itertools import accumulate import itertools
from typing import Optional
import nvtx
import torch import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
batch_size_range = [2**i for i in range(0, 8, 2)]
seq_len_range = [2**i for i in range(6, 10, 1)]
num_heads_range = [32, 48]
configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range))
def benchmark_rope_kernels_multi_lora(
is_neox_style: bool, def get_benchmark(head_size, rotary_dim, is_neox_style, device):
batch_size: int, @triton.testing.perf_report(
seq_len: int, triton.testing.Benchmark(
num_heads: int, x_names=["batch_size", "seq_len", "num_heads"],
head_size: int, x_vals=[list(_) for _ in configs],
rotary_dim: Optional[int], line_arg="provider",
dtype: torch.dtype, line_vals=["torch", "flashinfer", "vllm"],
seed: int, line_names=["PyTorch", "FlashInfer", "vLLM"],
device: str, styles=[("blue", "-"), ("green", "-"), ("red", "-")],
max_position: int = 8192, ylabel="us",
base: float = 10000, plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}",
) -> None: args={},
current_platform.seed_everything(seed) )
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
# silulating serving 4 LoRAs
scaling_factors = [1, 2, 4, 8]
# batched RoPE can take multiple scaling factors
batched_rope = get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
{"rope_type": "linear", "factor": tuple(scaling_factors)},
) )
# non-batched RoPE takes only one scaling factor, we create multiple def benchmark(batch_size, seq_len, num_heads, provider):
# instances to simulate the same behavior dtype = torch.bfloat16
non_batched_ropes: list[RotaryEmbedding] = [] max_position = 8192
for scaling_factor in scaling_factors: base = 10000
non_batched_ropes.append( rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
get_rope( rope = rope.to(dtype=dtype, device=device)
head_size, cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)
rotary_dim,
max_position, positions = torch.randint(0, max_position, (batch_size, seq_len), device=device)
base, query = torch.randn(
is_neox_style, (batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device
{"rope_type": "linear", "factor": (scaling_factor,)},
)
) )
key = torch.randn_like(query)
positions = torch.randint(0, max_position, (batch_size, seq_len)) quantiles = [0.5, 0.2, 0.8]
query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype)
key = torch.randn_like(query)
# create query offsets for batched RoPE, we concat multiple kv cache if provider == "torch":
# together and each query needs to find the right kv cache of its type ms, min_ms, max_ms = triton.testing.do_bench(
offset_map = torch.tensor( lambda: rope.forward_native(positions, query.clone(), key.clone()),
list( quantiles=quantiles,
accumulate(
[0]
+ [
max_position * scaling_factor * 2
for scaling_factor in scaling_factors[:-1]
]
) )
) elif provider == "flashinfer":
) ms, min_ms, max_ms = triton.testing.do_bench(
query_types = torch.randint( lambda: torch.ops.vllm.flashinfer_rotary_embedding(
0, len(scaling_factors), (batch_size, seq_len), device=device positions,
) query.clone(),
# map query types to offsets key.clone(),
query_offsets = offset_map[query_types] head_size,
# the kernel takes flattened offsets cos_sin_cache,
flatten_offsets = query_offsets.flatten() is_neox_style,
),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rope.forward_cuda(positions, query.clone(), key.clone()),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
# batched queries of the same type together for non-batched RoPE return benchmark
queries = [query[query_types == i] for i in range(len(scaling_factors))]
keys = [key[query_types == i] for i in range(len(scaling_factors))]
packed_qkr = zip(queries, keys, non_batched_ropes)
# synchronize before start timing
torch.cuda.synchronize()
with nvtx.annotate("non-batched", color="yellow"):
for q, k, r in packed_qkr:
r.forward(positions, q, k)
torch.cuda.synchronize()
with nvtx.annotate("batched", color="green"):
batched_rope.forward(positions, query, key, flatten_offsets)
torch.cuda.synchronize()
if __name__ == "__main__": if __name__ == "__main__":
...@@ -117,17 +95,12 @@ if __name__ == "__main__": ...@@ -117,17 +95,12 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0" "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
) )
parser.add_argument("--save-path", type=str, default="./configs/rope/")
args = parser.parse_args() args = parser.parse_args()
print(args)
benchmark_rope_kernels_multi_lora( # Get the benchmark function
is_neox_style=args.is_neox_style, benchmark = get_benchmark(
batch_size=args.batch_size, args.head_size, args.rotary_dim, args.is_neox_style, args.device
seq_len=args.seq_len,
num_heads=args.num_heads,
head_size=args.head_size,
rotary_dim=args.rotary_dim,
dtype=getattr(torch, args.dtype),
seed=args.seed,
device=args.device,
) )
# Run performance benchmark
benchmark.run(print_data=True, save_path=args.save_path)
...@@ -78,11 +78,11 @@ WEIGHT_SHAPES = { ...@@ -78,11 +78,11 @@ WEIGHT_SHAPES = {
} }
WEIGHT_SHAPES_MOE = { WEIGHT_SHAPES_MOE = {
"nm-testing/Mixtral-8x7B-Instruct-v0.1": [ "mistralai/Mixtral-8x7B-Instruct-v0.1": [
[8, 2, 4096, 28672], [8, 2, 4096, 28672],
[8, 2, 14336, 4096], [8, 2, 14336, 4096],
], ],
"nm-testing/deepseekv2-lite": [ "deepseek-ai/DeepSeek-V2-Lite": [
[64, 6, 2048, 1408], [64, 6, 2048, 1408],
], ],
"ibm-granite/granite-3.0-1b-a400m": [ "ibm-granite/granite-3.0-1b-a400m": [
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Comprehensive 3-way SiLU Benchmark Suite
This benchmark compares three SiLU implementations:
1. SiLU V2 (CUDA) - Optimized CUDA kernel implementation
2. Triton Kernel - Triton-based implementation
The suite generates detailed performance comparisons including:
- Memory bandwidth utilization
- Speedup ratios (baseline vs optimized implementations)
- Performance across different expert configurations and token distributions
"""
from collections.abc import Callable from collections.abc import Callable
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -7,7 +21,7 @@ import numpy as np ...@@ -7,7 +21,7 @@ import numpy as np
import torch import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
silu_mul_fp8_quant_deep_gemm_cuda, persistent_masked_m_silu_mul_quant,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
...@@ -94,6 +108,7 @@ def silu_mul_fp8_quant_deep_gemm_triton( ...@@ -94,6 +108,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(
num_parallel_tokens, num_parallel_tokens,
group_size: int = 128, group_size: int = 128,
eps: float = 1e-10, eps: float = 1e-10,
expert_offsets: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
...@@ -174,7 +189,7 @@ def silu_mul_fp8_quant_deep_gemm_triton( ...@@ -174,7 +189,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(
# Parse generation strategies # Parse generation strategies
strategies = ["uniform", "max_t", "first_t"] strategies = ["random_imbalanced", "uniform", "max_t"]
def benchmark( def benchmark(
...@@ -195,15 +210,27 @@ def benchmark( ...@@ -195,15 +210,27 @@ def benchmark(
current_platform.seed_everything(42 + seed_offset) current_platform.seed_everything(42 + seed_offset)
y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous() y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous()
if gen_strategy == "uniform": if gen_strategy == "random_imbalanced":
r = torch.rand(size=(E,), device="cuda")
def generate_expert_loads(n_e, total_tokens, ratio, device="cuda"):
mean = total_tokens // n_e
min_max = mean // ratio
e = torch.ones(size=(E,), dtype=torch.int64, device=device) * mean
e[0] = min_max
r = torch.rand(size=(E - 1,))
r /= r.sum()
r *= total_tokens - min_max
r = r.round().long()
e[1:] = r.to(device=device)
return e
tokens_per_expert = generate_expert_loads(E, total_tokens, 0.7, "cuda")
elif gen_strategy == "uniform":
r = torch.rand(size=(E,))
r /= r.sum() r /= r.sum()
r *= total_tokens r *= total_tokens
tokens_per_expert = r.int() r = r.round().long()
tokens_per_expert = torch.minimum( tokens_per_expert = r
tokens_per_expert,
torch.ones((E,), device=r.device, dtype=torch.int) * T,
)
elif gen_strategy == "max_t": elif gen_strategy == "max_t":
tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda") tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda")
tokens_per_expert.fill_(total_tokens / E) tokens_per_expert.fill_(total_tokens / E)
...@@ -281,40 +308,34 @@ def benchmark( ...@@ -281,40 +308,34 @@ def benchmark(
def create_comparison_plot( def create_comparison_plot(
ratio, cuda_times, baseline_times, config_labels, strategy_name, id ratios, silu_v2_times, triton_times, config_labels, strategy_name, id
): ):
"""Create a comparison plot for a specific generation strategy""" fig, ax = plt.subplots(1, 1, figsize=(18, 6))
fig, ax = plt.subplots(1, 1, figsize=(16, 6))
# Configure x-axis positions # Configure x-axis positions
x = np.arange(len(config_labels)) x = np.arange(len(config_labels))
width = 0.35 width = 0.25
# Execution Time plot (lower is better) # Execution Time plot (lower is better)
ax.bar(x, silu_v2_times, width, label="SiLU V2 (CUDA)", alpha=0.8, color="blue")
ax.bar( ax.bar(
x - width / 2, cuda_times, width, label="CUDA Kernel", alpha=0.8, color="blue" x + width, triton_times, width, label="Triton Kernel", alpha=0.8, color="green"
)
ax.bar(
x + width / 2,
baseline_times,
width,
label="Baseline",
alpha=0.8,
color="orange",
) )
# Add speedup labels over each bar pair # Add speedup labels over each bar trio
for i in range(len(x)): for i in range(len(x)):
speedup = ratio[i] triton_v2_speedup = ratios[i][1] # triton/v2
max_height = max(cuda_times[i], baseline_times[i]) max_height = max(silu_v2_times[i], triton_times[i])
# Triton/V2 speedup
ax.text( ax.text(
x[i], x[i] + width / 2,
max_height + max_height * 0.02, max_height + max_height * 0.02,
f"{speedup:.2f}x", f"{triton_v2_speedup:.2f}x",
ha="center", ha="center",
va="bottom", va="bottom",
fontweight="bold", fontweight="bold",
fontsize=9, fontsize=8,
) )
ax.set_xlabel("Configuration") ax.set_xlabel("Configuration")
...@@ -332,56 +353,75 @@ def create_comparison_plot( ...@@ -332,56 +353,75 @@ def create_comparison_plot(
def create_combined_plot(all_results): def create_combined_plot(all_results):
"""Create a combined plot with all strategies in one PNG"""
num_strategies = len(all_results) num_strategies = len(all_results)
fig, axes = plt.subplots(num_strategies, 1, figsize=(20, 6 * num_strategies)) fig, axes = plt.subplots(num_strategies, 1, figsize=(22, 7 * num_strategies))
if num_strategies == 1: if num_strategies == 1:
axes = [axes] axes = [axes]
for idx, ( for idx, (
strategy_name, strategy_name,
ratio, all_ratios,
cuda_times, all_silu_v2_results,
baseline_times, all_triton_results,
config_labels, config_labels,
config_x_axis,
) in enumerate(all_results): ) in enumerate(all_results):
ax = axes[idx] ax = axes[idx]
# Flatten the nested results to get bandwidth percentages for plotting
silu_v2_bandwidths = []
triton_bandwidths = []
flat_ratios = []
for config_results in all_silu_v2_results:
for result in config_results:
silu_v2_bandwidths.append(result[3]) # bandwidth percentage
for config_results in all_triton_results:
for result in config_results:
triton_bandwidths.append(result[3]) # bandwidth percentage
for config_ratios in all_ratios:
for ratio in config_ratios:
flat_ratios.append(ratio)
# Configure x-axis positions # Configure x-axis positions
x = np.arange(len(config_labels)) x = np.arange(len(config_labels))
width = 0.35 width = 0.25
# Execution Time plot (lower is better) # Bandwidth utilization plot (higher is better)
ax.bar( ax.bar(
x - width / 2, x,
cuda_times, silu_v2_bandwidths,
width, width,
label="CUDA Kernel", label="SiLU V2 (CUDA)",
alpha=0.8, alpha=0.8,
color="blue", color="blue",
) )
ax.bar( ax.bar(
x + width / 2, x + width,
baseline_times, triton_bandwidths,
width, width,
label="Baseline", label="Triton Kernel",
alpha=0.8, alpha=0.8,
color="orange", color="green",
) )
# Add speedup labels over each bar pair # Add speedup labels over each bar trio
for i in range(len(x)): for i in range(len(x)):
speedup = ratio[i] triton_v2_speedup = flat_ratios[i] # triton/v2
max_height = max(cuda_times[i], baseline_times[i]) max_height = max(silu_v2_bandwidths[i], triton_bandwidths[i])
# Triton/V2 speedup
ax.text( ax.text(
x[i], x[i] + width / 2,
max_height + max_height * 0.02, max_height + max_height * 0.02,
f"{speedup:.2f}x", f"{triton_v2_speedup:.2f}x",
ha="center", ha="center",
va="bottom", va="bottom",
fontweight="bold", fontweight="bold",
fontsize=9, fontsize=8,
) )
ax.set_xlabel("Configuration") ax.set_xlabel("Configuration")
...@@ -395,7 +435,7 @@ def create_combined_plot(all_results): ...@@ -395,7 +435,7 @@ def create_combined_plot(all_results):
ax.grid(True, alpha=0.3) ax.grid(True, alpha=0.3)
plt.tight_layout() plt.tight_layout()
filename = "../../silu_bench/silu_benchmark_combined.png" filename = "silu_benchmark_combined_3way.png"
plt.savefig(filename, dpi=300, bbox_inches="tight") plt.savefig(filename, dpi=300, bbox_inches="tight")
plt.show() plt.show()
...@@ -405,7 +445,9 @@ def create_combined_plot(all_results): ...@@ -405,7 +445,9 @@ def create_combined_plot(all_results):
outer_dim = 7168 outer_dim = 7168
configs = [ configs = [
# DeepSeekV3 Configs # DeepSeekV3 Configs
# (1, 56, 7168),
(8, 1024, 7168), (8, 1024, 7168),
# (32, 56, 7168),
# DeepSeekV3 Configs # DeepSeekV3 Configs
(32, 1024, 7168), (32, 1024, 7168),
# DeepSeekV3 Configs # DeepSeekV3 Configs
...@@ -417,6 +459,7 @@ num_warmups = 20 ...@@ -417,6 +459,7 @@ num_warmups = 20
strategy_descriptions = { strategy_descriptions = {
"uniform": "Uniform Random", "uniform": "Uniform Random",
"random_imbalanced": "Imbalanced Random",
"max_t": "Even Assignment", "max_t": "Even Assignment",
"first_t": "experts[0] = T, experts[1:] = 0", "first_t": "experts[0] = T, experts[1:] = 0",
} }
...@@ -433,28 +476,31 @@ for id, strategy in enumerate(strategies): ...@@ -433,28 +476,31 @@ for id, strategy in enumerate(strategies):
print(f"Testing strategy: {strategy_descriptions[strategy]}") print(f"Testing strategy: {strategy_descriptions[strategy]}")
print(f"{'=' * 60}") print(f"{'=' * 60}")
# Collect benchmark data for both algorithms # Collect benchmark data for all three algorithms
config_labels = [] config_labels = []
config_x_axis = [] config_x_axis = []
all_cuda_results = [] all_silu_v2_results = []
all_baseline_results = [] all_triton_results = []
all_ratios = [] all_ratios = []
for E, T, H in configs: for E, T, H in configs:
total_tokens_config = [8 * E, 16 * E, 32 * E, 64 * E, 128 * E, 256 * E] total_tokens_config = []
for i in [8, 16, 32, 64, 128, 256, 512]:
if i <= T:
total_tokens_config.append(i * E)
config_x_axis.append(total_tokens_config) config_x_axis.append(total_tokens_config)
cuda_results = [] silu_v2_results = []
baseline_results = [] triton_results = []
ratios = [] ratios = []
for total_tokens in total_tokens_config: for total_tokens in total_tokens_config:
config_label = f"E={E},T={T},H={H},TT={total_tokens}" config_label = f"E={E},T={T},H={H},TT={total_tokens}"
config_labels.append(config_label) config_labels.append(config_label)
# CUDA kernel results # SiLU V2 (CUDA kernel) results
time_ms_cuda, gflops, gbps, perc = benchmark( time_ms_silu_v2, gflops, gbps, perc = benchmark(
silu_mul_fp8_quant_deep_gemm_cuda, persistent_masked_m_silu_mul_quant,
E, E,
T, T,
H, H,
...@@ -463,9 +509,9 @@ for id, strategy in enumerate(strategies): ...@@ -463,9 +509,9 @@ for id, strategy in enumerate(strategies):
num_warmups=num_warmups, num_warmups=num_warmups,
gen_strategy=strategy, gen_strategy=strategy,
) )
cuda_results.append((time_ms_cuda, gflops, gbps, perc)) silu_v2_results.append((time_ms_silu_v2, gflops, gbps, perc))
# Baseline results # Triton kernel results
time_ms_triton, gflops, gbps, perc = benchmark( time_ms_triton, gflops, gbps, perc = benchmark(
silu_mul_fp8_quant_deep_gemm_triton, silu_mul_fp8_quant_deep_gemm_triton,
E, E,
...@@ -476,12 +522,20 @@ for id, strategy in enumerate(strategies): ...@@ -476,12 +522,20 @@ for id, strategy in enumerate(strategies):
num_warmups=num_warmups, num_warmups=num_warmups,
gen_strategy=strategy, gen_strategy=strategy,
) )
baseline_results.append((time_ms_triton, gflops, gbps, perc)) triton_results.append((time_ms_triton, gflops, gbps, perc))
ratios.append(time_ms_triton / time_ms_cuda)
print(f"Completed: {config_label}") # Calculate speedup ratios (triton baseline / implementation)
all_cuda_results.append(cuda_results) triton_v2_ratio = time_ms_triton / time_ms_silu_v2
all_baseline_results.append(baseline_results) ratios.append(triton_v2_ratio)
print(
f"Completed: {config_label}:"
f" V2: {time_ms_silu_v2:.3f}ms,"
f" Triton: {time_ms_triton:.3f}ms"
)
all_silu_v2_results.append(silu_v2_results)
all_triton_results.append(triton_results)
all_ratios.append(ratios) all_ratios.append(ratios)
# Store results for combined plotting # Store results for combined plotting
...@@ -489,8 +543,8 @@ for id, strategy in enumerate(strategies): ...@@ -489,8 +543,8 @@ for id, strategy in enumerate(strategies):
( (
strategy_descriptions[strategy], strategy_descriptions[strategy],
all_ratios, all_ratios,
all_cuda_results, all_silu_v2_results,
all_baseline_results, all_triton_results,
config_labels, config_labels,
config_x_axis, config_x_axis,
) )
...@@ -498,15 +552,18 @@ for id, strategy in enumerate(strategies): ...@@ -498,15 +552,18 @@ for id, strategy in enumerate(strategies):
# Print summary table for this strategy # Print summary table for this strategy
print(f"\nSummary Table - {strategy_descriptions[strategy]}:") print(f"\nSummary Table - {strategy_descriptions[strategy]}:")
print(f"{'Config':<20} {'CUDA Time(ms)':<12} {'Base Time(ms)':<12} {'Speedup':<8}") print(f" {'V2 Time(ms)':<12} {'Triton Time(ms)':<14} {'Triton/V2':<10}")
print("-" * 60) print("-" * 90)
for i, (E, T, H) in enumerate(configs): for i, (E, T, H) in enumerate(configs):
speedup = baseline_results[i][0] / cuda_results[i][0] # Get the first result for each config (simplifying for summary)
v2_time = silu_v2_results[i][0]
triton_time = triton_results[i][0]
triton_v2_speedup = triton_time / v2_time
config_label = f"E={E:3d},T={T:4d},H={H:4d}" config_label = f"E={E:3d},T={T:4d},H={H:4d}"
print( print(
f"{config_label:<20} {cuda_results[i][0]:8.5f} " f"{config_label:<20} {v2_time:8.5f} {triton_time:10.5f} "
f"{baseline_results[i][0]:8.5f} {speedup:6.2f}x" f"{triton_v2_speedup:8.2f}x"
) )
...@@ -514,15 +571,14 @@ def create_total_tokens_plot(all_results): ...@@ -514,15 +571,14 @@ def create_total_tokens_plot(all_results):
num_strategies = len(all_results) num_strategies = len(all_results)
num_configs = len(configs) num_configs = len(configs)
# Create side-by-side subplots: 2 columns for speedup and bandwidth percentage
fig, axs = plt.subplots( fig, axs = plt.subplots(
num_strategies, num_configs * 2, figsize=(28, 6 * num_strategies) num_strategies, num_configs * 2, figsize=(32, 8 * num_strategies)
) )
# Add main title to the entire figure # Add main title to the entire figure
fig.suptitle( fig.suptitle(
"Performance Analysis: Speedup vs Bandwidth Utilization (Triton & CUDA)", "Performance Analysis: Speedup vs Bandwidth Utilization (SiLU V2, and Triton)",
fontsize=16, fontsize=18,
fontweight="bold", fontweight="bold",
y=0.98, y=0.98,
) )
...@@ -539,8 +595,8 @@ def create_total_tokens_plot(all_results): ...@@ -539,8 +595,8 @@ def create_total_tokens_plot(all_results):
( (
strategy_name, strategy_name,
all_ratios, all_ratios,
all_cuda_results, all_silu_v2_results,
all_baseline_results, all_triton_results,
config_labels, config_labels,
config_x_axis, config_x_axis,
) = result ) = result
...@@ -555,42 +611,54 @@ def create_total_tokens_plot(all_results): ...@@ -555,42 +611,54 @@ def create_total_tokens_plot(all_results):
ratios = all_ratios[config_idx] ratios = all_ratios[config_idx]
total_tokens_values = config_x_axis[config_idx] total_tokens_values = config_x_axis[config_idx]
# Extract CUDA and Triton bandwidth percentages # Extract speedup ratios
cuda_bandwidth_percentages = [ triton_v2_ratios = [ratio for ratio in ratios]
result[3] for result in all_cuda_results[config_idx]
# Extract bandwidth percentages for all implementations
v2_bandwidth_percentages = [
result[3] for result in all_silu_v2_results[config_idx]
] ]
triton_bandwidth_percentages = [ triton_bandwidth_percentages = [
result[3] for result in all_baseline_results[config_idx] result[3] for result in all_triton_results[config_idx]
] ]
# Plot speedup ratios vs total tokens (left plot) # Plot speedup ratios vs total tokens (left plot)
ax_speedup.plot( ax_speedup.plot(
total_tokens_values, ratios, "bo-", linewidth=3, markersize=8 total_tokens_values,
triton_v2_ratios,
"go-",
linewidth=3,
markersize=8,
label="Triton/V2 Speedup",
) )
ax_speedup.set_title( ax_speedup.set_title(
f"{strategy_name}\nSpeedup (CUDA/Triton)\nE={E}, T={T}, H={H}", f"{strategy_name}\nSpeedup vs Baseline (Triton)\nE={E}, T={T}, H={H}",
fontsize=12, fontsize=12,
fontweight="bold", fontweight="bold",
) )
ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11) ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11) ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11)
ax_speedup.legend(prop={"weight": "bold"})
ax_speedup.grid(True, alpha=0.3) ax_speedup.grid(True, alpha=0.3)
# Plot bandwidth utilization (right plot)
ax_bandwidth.plot( ax_bandwidth.plot(
total_tokens_values, total_tokens_values,
cuda_bandwidth_percentages, v2_bandwidth_percentages,
"ro-", "o-",
linewidth=3, linewidth=3,
markersize=8, markersize=8,
label="CUDA", label="SiLU V2",
color="blue",
) )
ax_bandwidth.plot( ax_bandwidth.plot(
total_tokens_values, total_tokens_values,
triton_bandwidth_percentages, triton_bandwidth_percentages,
"go-", "o-",
linewidth=3, linewidth=3,
markersize=8, markersize=8,
label="Triton", label="Triton",
color="green",
) )
ax_bandwidth.set_title( ax_bandwidth.set_title(
f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}", f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}",
...@@ -618,38 +686,12 @@ def create_total_tokens_plot(all_results): ...@@ -618,38 +686,12 @@ def create_total_tokens_plot(all_results):
for label in ax.get_xticklabels() + ax.get_yticklabels(): for label in ax.get_xticklabels() + ax.get_yticklabels():
label.set_fontweight("bold") label.set_fontweight("bold")
# Add value labels on speedup points # Add value labels on Triton/V2 speedup points
for x, y in zip(total_tokens_values, ratios): for x, y in zip(total_tokens_values, triton_v2_ratios):
ax_speedup.annotate( ax_speedup.annotate(
f"{y:.2f}x", f"{y:.2f}x",
(x, y), (x, y),
textcoords="offset points", textcoords="offset points",
xytext=(0, 12),
ha="center",
fontsize=10,
fontweight="bold",
bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7),
)
# Add value labels on CUDA bandwidth points
for x, y in zip(total_tokens_values, cuda_bandwidth_percentages):
ax_bandwidth.annotate(
f"{y:.1f}%",
(x, y),
textcoords="offset points",
xytext=(0, 12),
ha="center",
fontsize=9,
fontweight="bold",
bbox=dict(boxstyle="round,pad=0.2", facecolor="red", alpha=0.3),
)
# Add value labels on Triton bandwidth points
for x, y in zip(total_tokens_values, triton_bandwidth_percentages):
ax_bandwidth.annotate(
f"{y:.1f}%",
(x, y),
textcoords="offset points",
xytext=(0, -15), xytext=(0, -15),
ha="center", ha="center",
fontsize=9, fontsize=9,
...@@ -659,17 +701,20 @@ def create_total_tokens_plot(all_results): ...@@ -659,17 +701,20 @@ def create_total_tokens_plot(all_results):
plt.tight_layout() plt.tight_layout()
plt.subplots_adjust(top=0.93) # Make room for main title plt.subplots_adjust(top=0.93) # Make room for main title
filename = "silu_benchmark_total_tokens.png" filename = "silu_benchmark_total_tokens_3way.png"
plt.savefig(filename, dpi=300, bbox_inches="tight") plt.savefig(filename, dpi=300, bbox_inches="tight")
plt.show() plt.show()
return filename return filename
# Create combined plot with all strategies # Create comprehensive 3-way comparison plots
combined_plot_filename = create_total_tokens_plot(all_results) combined_plot_filename = create_combined_plot(all_results)
total_tokens_plot_filename = create_total_tokens_plot(all_results)
print(f"\n{'=' * 60}") print(f"\n{'=' * 80}")
print("Benchmark Complete!") print("3-Way Benchmark Suite Complete!")
print(f"Generated combined plot: {combined_plot_filename}") print(f"Generated combined comparison plot: {combined_plot_filename}")
print(f"{'=' * 60}") print(f"Generated total tokens analysis plot: {total_tokens_plot_filename}")
print("Compared: SiLU V2 (CUDA), and Triton implementations")
print(f"{'=' * 80}")
...@@ -4,12 +4,11 @@ ...@@ -4,12 +4,11 @@
import csv import csv
import os import os
from datetime import datetime from datetime import datetime
from typing import Optional
import flashinfer import flashinfer
import torch import torch
from vllm.utils import round_up from vllm.utils.math_utils import round_up
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn FP8_DTYPE = torch.float8_e4m3fn
...@@ -28,9 +27,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn): ...@@ -28,9 +27,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@torch.no_grad() @torch.no_grad()
def benchmark_decode( def benchmark_decode(
dtype: torch.dtype, dtype: torch.dtype,
quant_dtypes: tuple[ quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
batch_size: int, batch_size: int,
max_seq_len: int, max_seq_len: int,
num_heads: tuple[int, int] = (64, 8), num_heads: tuple[int, int] = (64, 8),
......
...@@ -4,12 +4,11 @@ ...@@ -4,12 +4,11 @@
import csv import csv
import os import os
from datetime import datetime from datetime import datetime
from typing import Optional
import flashinfer import flashinfer
import torch import torch
from vllm.utils import round_up from vllm.utils.math_utils import round_up
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn FP8_DTYPE = torch.float8_e4m3fn
...@@ -28,9 +27,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn): ...@@ -28,9 +27,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@torch.no_grad() @torch.no_grad()
def benchmark_prefill( def benchmark_prefill(
dtype: torch.dtype, dtype: torch.dtype,
quant_dtypes: tuple[ quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
batch_size: int, batch_size: int,
max_seq_len: int, max_seq_len: int,
num_heads: tuple[int, int] = (64, 8), num_heads: tuple[int, int] = (64, 8),
......
...@@ -14,11 +14,11 @@ import torch ...@@ -14,11 +14,11 @@ import torch
from tqdm import tqdm from tqdm import tqdm
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_w8a8_block_fp8_matmul, _w8a8_triton_block_scaled_mm,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
mp.set_start_method("spawn", force=True) mp.set_start_method("spawn", force=True)
...@@ -83,7 +83,7 @@ def w8a8_block_matmul( ...@@ -83,7 +83,7 @@ def w8a8_block_matmul(
) )
if A.dtype == torch.float8_e4m3fn: if A.dtype == torch.float8_e4m3fn:
kernel = _w8a8_block_fp8_matmul kernel = _w8a8_triton_block_scaled_mm
else: else:
raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.") raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment