Unverified Commit 7291d1b2 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Bugfix] Fix kernel benchmark (#33752)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 987506bc
......@@ -13,6 +13,7 @@ from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm
import vllm._custom_ops as ops
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
......@@ -291,6 +292,7 @@ def print_timers(timers: Iterable[TMeasurement]):
compare.print()
@default_vllm_config()
def main():
torch.set_default_device("cuda")
bench_params = get_bench_params()
......
......@@ -7,6 +7,7 @@ import itertools
import torch
import vllm.model_executor.layers.activation # noqa F401
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.custom_op import op_registry
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
......@@ -18,6 +19,7 @@ intermediate_size = [3072, 9728, 12288]
configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size))
@default_vllm_config()
def benchmark_activation(
batch_size: int,
seq_len: int,
......
......@@ -8,6 +8,7 @@ os.environ["VLLM_USE_DEEP_GEMM"] = "0"
import torch
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
)
......@@ -40,6 +41,7 @@ DEEPSEEK_V3_SHAPES = [
]
@default_vllm_config()
def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
"""Build runner function for w8a8 block fp8 matmul."""
factor_for_scale = 1e-2
......
......@@ -5,12 +5,14 @@ import time
import torch
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed
@torch.inference_mode()
@default_vllm_config()
def main(
num_tokens: int,
hidden_size: int,
......
......@@ -36,6 +36,7 @@ from typing import Any
import numpy as np
import torch
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.transformers_utils.config import get_config
from vllm.utils.argparse_utils import FlexibleArgumentParser
......@@ -78,6 +79,7 @@ def calculate_stats(times: list[float]) -> dict[str, float]:
}
@default_vllm_config()
def benchmark_mrope(
model_name: str,
num_tokens: int,
......
......@@ -7,6 +7,7 @@ from unittest.mock import patch
import pandas as pd
import torch
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton
......@@ -84,6 +85,7 @@ def calculate_diff(
configs = []
@default_vllm_config()
def benchmark_quantization(
batch_size,
hidden_size,
......
......@@ -5,6 +5,7 @@ import itertools
import torch
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
......@@ -29,6 +30,7 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device):
args={},
)
)
@default_vllm_config()
def benchmark(batch_size, seq_len, num_heads, provider):
dtype = torch.bfloat16
max_position = 8192
......
......@@ -5,6 +5,7 @@ import argparse
import json
import math
import os
from contextlib import contextmanager
from typing import Any
......@@ -117,3 +118,14 @@ def write_to_json(filename: str, records: list) -> None:
cls=InfEncoder,
default=lambda o: f"<{type(o).__name__} is not JSON serializable>",
)
@contextmanager
def default_vllm_config():
"""Set a default VllmConfig for cases that directly test CustomOps or pathways
that use get_current_vllm_config() outside of a full engine context.
"""
from vllm.config import VllmConfig, set_current_vllm_config
with set_current_vllm_config(VllmConfig()):
yield
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment