Unverified commit 7291d1b2, authored by Jee Jee Li, committed by GitHub
Browse files

[Bugfix] Fix kernel benchmark (#33752)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 987506bc
...@@ -13,6 +13,7 @@ from torch.utils.benchmark import Measurement as TMeasurement ...@@ -13,6 +13,7 @@ from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm from tqdm import tqdm
import vllm._custom_ops as ops import vllm._custom_ops as ops
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, per_token_group_quant_fp8,
...@@ -291,6 +292,7 @@ def print_timers(timers: Iterable[TMeasurement]): ...@@ -291,6 +292,7 @@ def print_timers(timers: Iterable[TMeasurement]):
compare.print() compare.print()
@default_vllm_config()
def main(): def main():
torch.set_default_device("cuda") torch.set_default_device("cuda")
bench_params = get_bench_params() bench_params = get_bench_params()
......
...@@ -7,6 +7,7 @@ import itertools ...@@ -7,6 +7,7 @@ import itertools
import torch import torch
import vllm.model_executor.layers.activation # noqa F401 import vllm.model_executor.layers.activation # noqa F401
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.custom_op import op_registry from vllm.model_executor.custom_op import op_registry
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
...@@ -18,6 +19,7 @@ intermediate_size = [3072, 9728, 12288] ...@@ -18,6 +19,7 @@ intermediate_size = [3072, 9728, 12288]
configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size)) configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size))
@default_vllm_config()
def benchmark_activation( def benchmark_activation(
batch_size: int, batch_size: int,
seq_len: int, seq_len: int,
......
...@@ -8,6 +8,7 @@ os.environ["VLLM_USE_DEEP_GEMM"] = "0" ...@@ -8,6 +8,7 @@ os.environ["VLLM_USE_DEEP_GEMM"] = "0"
import torch import torch
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
) )
...@@ -40,6 +41,7 @@ DEEPSEEK_V3_SHAPES = [ ...@@ -40,6 +41,7 @@ DEEPSEEK_V3_SHAPES = [
] ]
@default_vllm_config()
def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass): def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
"""Build runner function for w8a8 block fp8 matmul.""" """Build runner function for w8a8 block fp8 matmul."""
factor_for_scale = 1e-2 factor_for_scale = 1e-2
......
...@@ -5,12 +5,14 @@ import time ...@@ -5,12 +5,14 @@ import time
import torch import torch
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed
@torch.inference_mode() @torch.inference_mode()
@default_vllm_config()
def main( def main(
num_tokens: int, num_tokens: int,
hidden_size: int, hidden_size: int,
......
...@@ -36,6 +36,7 @@ from typing import Any ...@@ -36,6 +36,7 @@ from typing import Any
import numpy as np import numpy as np
import torch import torch
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.transformers_utils.config import get_config from vllm.transformers_utils.config import get_config
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
...@@ -78,6 +79,7 @@ def calculate_stats(times: list[float]) -> dict[str, float]: ...@@ -78,6 +79,7 @@ def calculate_stats(times: list[float]) -> dict[str, float]:
} }
@default_vllm_config()
def benchmark_mrope( def benchmark_mrope(
model_name: str, model_name: str,
num_tokens: int, num_tokens: int,
......
...@@ -7,6 +7,7 @@ from unittest.mock import patch ...@@ -7,6 +7,7 @@ from unittest.mock import patch
import pandas as pd import pandas as pd
import torch import torch
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton from vllm.triton_utils import triton
...@@ -84,6 +85,7 @@ def calculate_diff( ...@@ -84,6 +85,7 @@ def calculate_diff(
configs = [] configs = []
@default_vllm_config()
def benchmark_quantization( def benchmark_quantization(
batch_size, batch_size,
hidden_size, hidden_size,
......
...@@ -5,6 +5,7 @@ import itertools ...@@ -5,6 +5,7 @@ import itertools
import torch import torch
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
...@@ -29,6 +30,7 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device): ...@@ -29,6 +30,7 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device):
args={}, args={},
) )
) )
@default_vllm_config()
def benchmark(batch_size, seq_len, num_heads, provider): def benchmark(batch_size, seq_len, num_heads, provider):
dtype = torch.bfloat16 dtype = torch.bfloat16
max_position = 8192 max_position = 8192
......
...@@ -5,6 +5,7 @@ import argparse ...@@ -5,6 +5,7 @@ import argparse
import json import json
import math import math
import os import os
from contextlib import contextmanager
from typing import Any from typing import Any
...@@ -117,3 +118,14 @@ def write_to_json(filename: str, records: list) -> None: ...@@ -117,3 +118,14 @@ def write_to_json(filename: str, records: list) -> None:
cls=InfEncoder, cls=InfEncoder,
default=lambda o: f"<{type(o).__name__} is not JSON serializable>", default=lambda o: f"<{type(o).__name__} is not JSON serializable>",
) )
@contextmanager
def default_vllm_config():
    """Install a default-constructed VllmConfig as the current config.

    Meant for code that exercises CustomOps, or any other pathway calling
    get_current_vllm_config(), directly rather than from inside a full
    engine context. Usable in a ``with`` statement or, as the benchmark
    scripts do, as a decorator: ``@default_vllm_config()``.
    """
    # Import stays function-local, exactly as in the original block.
    from vllm.config import VllmConfig, set_current_vllm_config

    fallback_config = VllmConfig()
    with set_current_vllm_config(fallback_config):
        # Control returns to the caller while the default config is active.
        yield
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment