"tests/vscode:/vscode.git/clone" did not exist on "61b8cea3b42feab021d506e9143551de18f9165c"
Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
......@@ -93,12 +93,12 @@ def torch_sdpa_wrapper(
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_chunks = torch.split(q, lens, dim=1)
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = (
einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
......
......@@ -2,19 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
import os
from collections.abc import Generator
from contextlib import contextmanager
from functools import cache
from typing import cast, get_args
import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import (
MAMBA_TYPE_TO_BACKEND_MAP,
AttentionBackendEnum,
MambaAttentionBackendEnum,
)
from vllm.config.cache import CacheDType
......@@ -24,60 +19,6 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
    """
    Read the attention backend override from the process environment.

    The value of `VLLM_ATTENTION_BACKEND` is looked up by name in
    `AttentionBackendEnum`. An unrecognised name propagates whatever error
    the enum lookup itself raises.

    Returns:

    * The AttentionBackendEnum member named by the variable, if it is set
    * None if the variable is not set

    Raises:

    * ValueError if the variable names the removed 'XFORMERS' backend
    """
    requested = os.environ.get("VLLM_ATTENTION_BACKEND")
    if requested is not None:
        # XFORMERS is rejected explicitly so the user gets a pointer to the
        # removal instead of a bare lookup failure.
        if requested == "XFORMERS":
            raise ValueError(
                "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
                "details). Please select a supported attention backend."
            )
        return AttentionBackendEnum[requested]
    return None
# Module-level override for the attention backend.
#
# While this is None (the default), the backend is chosen by the normal
# selection logic. Setting it to an AttentionBackendEnum member forces that
# backend instead.
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: AttentionBackendEnum | None = None
def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
    """
    Set the module-level attention backend override.

    The given value is stored in the `forced_attn_backend` global. Passing
    `None` clears the override, which re-enables automatic backend
    selection.

    Arguments:

    * attn_backend: backend to force, or None to revert to auto-selection
    """
    global forced_attn_backend
    forced_attn_backend = attn_backend
def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
    """
    Return the current value of the `forced_attn_backend` global.

    Returns:

    * The forced AttentionBackendEnum member, if an override is active
    * None if auto-selection is currently enabled
    """
    current_override = forced_attn_backend
    return current_override
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
......@@ -86,6 +27,7 @@ def get_attn_backend(
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
......@@ -97,7 +39,13 @@ def get_attn_backend(
f"Valid values are: {valid_cache_dtypes}"
)
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
backend_enum = vllm_config.attention_config.backend
return _cached_get_attn_backend(
backend=backend_enum,
head_size=head_size,
dtype=dtype,
kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
......@@ -105,12 +53,14 @@ def get_attn_backend(
use_mla=use_mla,
has_sink=has_sink,
use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix,
attn_type=attn_type,
)
@cache
def _cached_get_attn_backend(
backend,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
......@@ -118,41 +68,9 @@ def _cached_get_attn_backend(
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
# Check whether a particular choice of backend was
# previously forced.
#
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE.
selected_backend = None
backend_by_global_setting: AttentionBackendEnum | None = (
get_global_forced_attn_backend()
)
if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting
else:
# Check the environment variable and override if specified
backend_by_env_var: str | None = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
if backend_by_env_var.endswith("_VLLM_V1"):
logger.warning(
"The suffix '_VLLM_V1' in the environment variable "
"VLLM_ATTENTION_BACKEND is no longer necessary as "
"V0 backends have been deprecated. "
"Please remove this suffix from your "
"environment variable setting.",
)
backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")
try:
selected_backend = AttentionBackendEnum[backend_by_env_var]
except KeyError as e:
raise ValueError(
f"Invalid attention backend: '{backend_by_env_var}'. Valid "
f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
) from e
# get device-specific attn_backend
from vllm.platforms import current_platform
sig = inspect.signature(current_platform.get_attn_backend_cls)
......@@ -163,7 +81,7 @@ def _cached_get_attn_backend(
"remove it from your plugin code."
)
attention_cls = current_platform.get_attn_backend_cls(
selected_backend,
backend,
head_size,
dtype,
kv_cache_dtype,
......@@ -172,11 +90,12 @@ def _cached_get_attn_backend(
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
)
else:
attention_cls = current_platform.get_attn_backend_cls(
selected_backend,
backend,
head_size,
dtype,
kv_cache_dtype,
......@@ -184,6 +103,7 @@ def _cached_get_attn_backend(
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
)
if not attention_cls:
......@@ -232,37 +152,3 @@ def _cached_get_mamba_attn_backend(
mamba_attn_backend = selected_backend.get_class()
return mamba_attn_backend
@contextmanager
def global_force_attn_backend_context_manager(
    attn_backend: AttentionBackendEnum,
) -> Generator[None, None, None]:
    """
    Globally force a vLLM attention backend override within a
    context manager, reverting the global attention backend
    override to its prior state upon exiting the context
    manager.

    The memoized results of `_cached_get_attn_backend` are discarded both
    when the override is installed and when it is reverted. The forced
    backend is read inside that cached function and is not part of its
    cache key, so without the clear on entry a selection memoized before
    the `with` block would be returned inside it and the override would
    be silently ignored.

    Arguments:

    * attn_backend: attention backend to force

    Returns:

    * Generator
    """

    # Save the current state of the global backend override (if any)
    original_value = get_global_forced_attn_backend()

    # Globally force the new backend override
    global_force_attn_backend(attn_backend)

    try:
        # Drop selections cached under the previous override. Done inside
        # the try so the override is still reverted if this raises.
        _cached_get_attn_backend.cache_clear()

        # Yield control back to the enclosed code block
        yield
    finally:
        # Revert the original global backend override, if any
        global_force_attn_backend(original_value)
        # Drop selections cached while the override was active
        _cached_get_attn_backend.cache_clear()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
......@@ -49,10 +48,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
)
# 2. override if passed by environment
if envs.VLLM_FLASH_ATTN_VERSION is not None:
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
fa_version = envs.VLLM_FLASH_ATTN_VERSION
# 2. override if passed by environment or config
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
if vllm_config.attention_config.flash_attn_version is not None:
fa_version = vllm_config.attention_config.flash_attn_version
# 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3:
......
......@@ -32,7 +32,6 @@ from typing import Any, cast
import numpy as np
from PIL import Image
from transformers import PreTrainedTokenizerBase
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
......@@ -189,7 +188,7 @@ class BenchmarkDataset(ABC):
@abstractmethod
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
request_id_prefix: str = "",
no_oversample: bool = False,
......@@ -201,7 +200,7 @@ class BenchmarkDataset(ABC):
for generating a list of SampleRequest objects.
Args:
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
tokenizer (TokenizerLike): The tokenizer to be used
for processing the dataset's text.
num_requests (int): The number of sample requests to generate.
request_id_prefix (str): The prefix of request_id.
......@@ -380,7 +379,7 @@ def process_video(video: Any) -> Mapping[str, Any]:
def gen_prompt_decode_to_target_len(
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
token_sequence: list[int],
target_token_len: int,
max_retry: int = 10,
......@@ -468,7 +467,7 @@ class RandomDataset(BenchmarkDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
request_id_prefix: str = "",
no_oversample: bool = False,
......@@ -580,7 +579,7 @@ class RandomDataset(BenchmarkDataset):
range_ratio: float,
input_len: int,
output_len: int,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Get the sampling parameters for the dataset.
......@@ -626,7 +625,7 @@ class RandomDataset(BenchmarkDataset):
def generate_token_sequence(
self,
*,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
prefix_token_ids: list[int],
prefix_len: int,
vocab_size: int,
......@@ -686,7 +685,7 @@ class RandomDatasetForReranking(RandomDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
request_id_prefix: str = "",
range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
......@@ -716,7 +715,11 @@ class RandomDatasetForReranking(RandomDataset):
doc_lens, _, doc_offsets = self.get_sampling_params(
num_requests, range_ratio, doc_len_param, 0, tokenizer
)
vocab_size = tokenizer.vocab_size
prohibited_tokens = tokenizer.all_special_ids
all_tokens = np.arange(vocab_size)
allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens)))
query_prompt, query_input_len, token_mismatch_total = (
self.generate_token_sequence(
......@@ -727,6 +730,7 @@ class RandomDatasetForReranking(RandomDataset):
input_len=query_len,
offset=int(query_offsets[0]),
index=0,
allowed_tokens=allowed_tokens,
)
)
......@@ -740,6 +744,7 @@ class RandomDatasetForReranking(RandomDataset):
input_len=int(doc_lens[i]),
offset=int(doc_offsets[i]),
index=i + 1,
allowed_tokens=allowed_tokens,
)
token_mismatch_total += token_mismatch
requests.append((prompt, total_input_len))
......@@ -1077,7 +1082,7 @@ class RandomMultiModalDataset(RandomDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
request_id_prefix: str = "",
no_oversample: bool = False,
......@@ -1231,7 +1236,7 @@ class ShareGPTDataset(BenchmarkDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
lora_path: str | None = None,
max_loras: int | None = None,
......@@ -1633,7 +1638,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
)
def get_samples(args, tokenizer) -> list[SampleRequest]:
def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
if not hasattr(args, "request_id_prefix"):
args.request_id_prefix = ""
......@@ -1842,6 +1847,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
random_seed=args.seed,
dataset_path=args.dataset_path,
disable_shuffle=args.disable_shuffle,
prefix_len=args.common_prefix_len,
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
......@@ -1970,7 +1976,7 @@ class CustomDataset(BenchmarkDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
lora_path: str | None = None,
max_loras: int | None = None,
......@@ -2100,7 +2106,7 @@ class SonnetDataset(BenchmarkDataset):
def sample(
self,
tokenizer,
tokenizer: TokenizerLike,
num_requests: int,
prefix_len: int = DEFAULT_PREFIX_LEN,
input_len: int = DEFAULT_INPUT_LEN,
......@@ -2201,7 +2207,7 @@ class BurstGPTDataset(BenchmarkDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
max_loras: int | None = None,
lora_path: str | None = None,
......@@ -2286,7 +2292,7 @@ class ConversationDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
enable_multimodal_chat: bool = False,
......@@ -2346,7 +2352,7 @@ class MultiModalConversationDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
enable_multimodal_chat: bool = False,
......@@ -2415,7 +2421,7 @@ class VisionArenaDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
enable_multimodal_chat: bool = False,
......@@ -2469,7 +2475,7 @@ class MMVUDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
enable_multimodal_chat: bool = False,
......@@ -2530,7 +2536,7 @@ class InstructCoderDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
enable_multimodal_chat: bool = False,
......@@ -2594,7 +2600,7 @@ class MTBenchDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
enable_multimodal_chat: bool = False,
......@@ -2660,7 +2666,7 @@ class BlazeditDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
skip_chat_template: bool = False,
......@@ -2741,7 +2747,7 @@ class AIMODataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
request_id_prefix: str = "",
......@@ -2851,7 +2857,7 @@ class NextEditPredictionDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
request_id_prefix: str = "",
no_oversample: bool = False,
......@@ -2923,7 +2929,7 @@ class ASRDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
request_id_prefix: str = "",
......@@ -3001,7 +3007,7 @@ class MLPerfDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
request_id_prefix: str = "",
......@@ -3080,7 +3086,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
prefix_len: int = DEFAULT_PREFIX_LEN,
suffix_len: int = DEFAULT_SUFFIX_LEN,
......@@ -3166,7 +3172,7 @@ class MMStarDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
enable_multimodal_chat: bool = False,
......
......@@ -12,7 +12,6 @@ from typing import Any
import numpy as np
from tqdm import tqdm
import vllm.envs as envs
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
......@@ -79,12 +78,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
def main(args: argparse.Namespace):
if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
raise OSError(
"The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
"Please set it to a valid path to use torch profiler."
)
engine_args = EngineArgs.from_cli_args(args)
if args.profile and not engine_args.profiler_config.profiler == "torch":
raise ValueError(
"The torch profiler is not enabled. Please provide profiler_config."
)
# Lazy import to avoid importing LLM when the bench command is not selected.
from vllm import LLM, SamplingParams
......@@ -144,7 +142,7 @@ def main(args: argparse.Namespace):
run_to_completion(profile_dir=None)
if args.profile:
profile_dir = envs.VLLM_TORCH_PROFILER_DIR
profile_dir = engine_args.profiler_config.torch_profiler_dir
print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=profile_dir)
return
......
......@@ -36,7 +36,6 @@ from typing import Any, Literal
import aiohttp
import numpy as np
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
from vllm.benchmarks.datasets import SampleRequest, add_dataset_parser, get_samples
from vllm.benchmarks.lib.endpoint_request_func import (
......@@ -47,7 +46,7 @@ from vllm.benchmarks.lib.endpoint_request_func import (
)
from vllm.benchmarks.lib.ready_checker import wait_for_endpoint
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.tokenizers import get_tokenizer
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.utils.gc_utils import freeze_gc_heap
from vllm.utils.network_utils import join_host_port
......@@ -286,7 +285,7 @@ def calculate_metrics(
input_requests: list[SampleRequest],
outputs: list[RequestFuncOutput],
dur_s: float,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
selected_percentiles: list[float],
goodput_config_dict: dict[str, float],
) -> tuple[BenchmarkMetrics, list[int]]:
......@@ -489,7 +488,7 @@ async def benchmark(
base_url: str,
model_id: str,
model_name: str,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
input_requests: list[SampleRequest],
logprobs: int | None,
request_rate: float,
......@@ -1032,6 +1031,19 @@ def add_cli_args(parser: argparse.ArgumentParser):
type=str,
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument(
"--tokenizer-mode",
type=str,
default="auto",
help="""Tokenizer mode:\n
- "auto" will use the tokenizer from `mistral_common` for Mistral models
if available, otherwise it will use the "hf" tokenizer.\n
- "hf" will use the fast tokenizer if available.\n
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
- Other custom values can be supported via plugins.""",
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
"--logprobs",
......@@ -1085,8 +1097,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
help="Use vLLM Profiling. --profiler-config must be provided on the server.",
)
parser.add_argument(
"--save-result",
......@@ -1221,17 +1232,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Repetition penalty sampling parameter. Only has effect on "
"openai-compatible backends.",
)
parser.add_argument(
"--tokenizer-mode",
type=str,
default="auto",
choices=["auto", "slow", "mistral", "custom"],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
"always use the slow tokenizer. \n* "
'"mistral" will always use the `mistral_common` tokenizer. \n*'
'"custom" will use --tokenizer to select the preregistered tokenizer.',
sampling_group.add_argument(
"--common-prefix-len",
type=int,
default=None,
help="Common prefix length shared by all prompts (used by random dataset)",
)
parser.add_argument(
......
......@@ -9,8 +9,26 @@ class ParameterSweep(list["ParameterSweepItem"]):
@classmethod
def read_json(cls, filepath: os.PathLike):
with open(filepath, "rb") as f:
records = json.load(f)
data = json.load(f)
# Support both list and dict formats
if isinstance(data, dict):
return cls.read_from_dict(data)
return cls.from_records(data)
@classmethod
def read_from_dict(cls, data: dict[str, dict[str, object]]):
    """
    Build a parameter sweep from a mapping of benchmark name to parameters.

    Each entry becomes one record whose "_benchmark_name" field is the
    entry's key, followed by the entry's own parameters. The records are
    then handed to `from_records`.

    Example:
    {
        "experiment1": {"max_tokens": 100, "temperature": 0.7},
        "experiment2": {"max_tokens": 200, "temperature": 0.9}
    }
    """
    named_records = []
    for benchmark_name, benchmark_params in data.items():
        # The name goes first; a "_benchmark_name" key inside the
        # parameters would overwrite it, as later keys win.
        named_records.append({"_benchmark_name": benchmark_name, **benchmark_params})
    return cls.from_records(named_records)
@classmethod
......@@ -21,6 +39,15 @@ class ParameterSweep(list["ParameterSweepItem"]):
f"but found type: {type(records)}"
)
# Validate that all _benchmark_name values are unique if provided
names = [r["_benchmark_name"] for r in records if "_benchmark_name" in r]
if names and len(names) != len(set(names)):
duplicates = [name for name in names if names.count(name) > 1]
raise ValueError(
f"Duplicate _benchmark_name values found: {set(duplicates)}. "
f"All _benchmark_name values must be unique."
)
return cls(ParameterSweepItem.from_record(record) for record in records)
......@@ -38,6 +65,18 @@ class ParameterSweepItem(dict[str, object]):
def __or__(self, other: dict[str, Any]):
return type(self)(super().__or__(other))
@property
def name(self) -> str:
    """
    Name used to identify this parameter sweep item.

    This is the value stored under '_benchmark_name' when that key is
    present; otherwise it is the "-"-joined text form produced by
    `as_text`.
    """
    return (
        self["_benchmark_name"]
        if "_benchmark_name" in self
        else self.as_text(sep="-")
    )
# In JSON, we prefer "_"
def _iter_param_key_candidates(self, param_key: str):
# Inner config arguments are not converted by the CLI
......@@ -63,29 +102,57 @@ class ParameterSweepItem(dict[str, object]):
def has_param(self, param_key: str) -> bool:
return any(k in self for k in self._iter_param_key_candidates(param_key))
def _normalize_cmd_kv_pair(self, k: str, v: object) -> list[str]:
    """
    Turn one parameter into the command-line tokens that express it.

    Returns a list containing either:
    - One element for booleans: ['--flag'] / ['--no-flag'] for plain keys,
      or ['--a.b=true'] / ['--a.b=false'] for nested keys (containing ".")
    - Two elements for everything else: ['--key', 'value']
    """
    # Non-boolean values are always emitted as a separate value token.
    if not isinstance(v, bool):
        return [self._normalize_cmd_key(k), str(v)]

    # Nested params (containing ".") use the explicit =true/false syntax.
    if "." in k:
        bool_text = "true" if v else "false"
        return [f"{self._normalize_cmd_key(k)}={bool_text}"]

    # Plain boolean flags are negated with a "no-" prefix.
    flag_key = k if v else "no-" + k
    return [self._normalize_cmd_key(flag_key)]
def apply_to_cmd(self, cmd: list[str]) -> list[str]:
cmd = list(cmd)
for k, v in self.items():
# Skip the '_benchmark_name' field, not a parameter
if k == "_benchmark_name":
continue
# Serialize dict values as JSON
if isinstance(v, dict):
v = json.dumps(v)
for k_candidate in self._iter_cmd_key_candidates(k):
try:
k_idx = cmd.index(k_candidate)
if isinstance(v, bool):
cmd[k_idx] = self._normalize_cmd_key(k if v else "no-" + k)
# Replace existing parameter
normalized = self._normalize_cmd_kv_pair(k, v)
if len(normalized) == 1:
# Boolean flag
cmd[k_idx] = normalized[0]
else:
cmd[k_idx + 1] = str(v)
# Key-value pair
cmd[k_idx] = normalized[0]
cmd[k_idx + 1] = normalized[1]
break
except ValueError:
continue
else:
if isinstance(v, bool):
cmd.append(self._normalize_cmd_key(k if v else "no-" + k))
else:
cmd.extend([self._normalize_cmd_key(k), str(v)])
# Add new parameter
cmd.extend(self._normalize_cmd_kv_pair(k, v))
return cmd
def as_text(self, sep: str = ", ") -> str:
return sep.join(f"{k}={v}" for k, v in self.items())
return sep.join(f"{k}={v}" for k, v in self.items() if k != "_benchmark_name")
......@@ -65,6 +65,18 @@ class PlotEqualTo(PlotFilterBase):
return df[df[self.var] == target]
@dataclass
class PlotNotEqualTo(PlotFilterBase):
    """Filter that keeps the rows whose `var` column is not equal to `target`."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """
        Return the rows of `df` where column `self.var` differs from the target.

        The target is compared as a float when `float(self.target)` succeeds,
        and as the raw `self.target` value when that raises ValueError.
        """
        try:
            wanted = float(self.target)
        except ValueError:
            # Not numeric: compare against the target as given.
            wanted = self.target

        keep_mask = df[self.var] != wanted
        return df[keep_mask]
@dataclass
class PlotLessThan(PlotFilterBase):
@override
......@@ -96,6 +108,7 @@ class PlotGreaterThanOrEqualTo(PlotFilterBase):
# NOTE: The ordering is important! Match longer op_keys first
PLOT_FILTERS: dict[str, type[PlotFilterBase]] = {
"==": PlotEqualTo,
"!=": PlotNotEqualTo,
"<=": PlotLessThanOrEqualTo,
">=": PlotGreaterThanOrEqualTo,
"<": PlotLessThan,
......@@ -167,6 +180,27 @@ def _json_load_bytes(path: Path) -> list[dict[str, object]]:
return json.load(f)
def _convert_inf_nan_strings(data: list[dict[str, object]]) -> list[dict[str, object]]:
    """
    Return a copy of `data` with the strings "inf", "-inf" and "nan"
    replaced by the corresponding float values.

    Only values that are exactly one of those three strings are converted;
    every other value (string or not) is carried over unchanged. The input
    records are not modified: a new list of new dicts is returned.
    """
    float_spellings = ("inf", "-inf", "nan")

    def _as_float_if_special(value: object) -> object:
        # Only strings are candidates; anything else is passed through.
        if isinstance(value, str) and value in float_spellings:
            return float(value)
        return value

    return [
        {key: _as_float_if_special(value) for key, value in record.items()}
        for record in data
    ]
def _get_metric(run_data: dict[str, object], metric_key: str):
try:
return run_data[metric_key]
......@@ -178,12 +212,15 @@ def _get_group(run_data: dict[str, object], group_keys: list[str]):
return tuple((k, str(_get_metric(run_data, k))) for k in group_keys)
def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...]):
def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...], fig_name: str):
parts = list[str]()
# Start with figure name (always provided, defaults to "FIGURE")
parts.append(fig_name)
# Always append group data if present
if group:
parts.extend(("FIGURE-", *(f"{k}={v}" for k, v in group)))
else:
parts.append("figure")
parts.extend(f"{k}={v}" for k, v in group)
return fig_dir / sanitize_filename("-".join(parts) + ".png")
......@@ -217,6 +254,10 @@ def _plot_fig(
scale_x: str | None,
scale_y: str | None,
dry_run: bool,
fig_name: str,
error_bars: bool,
fig_height: float,
fig_dpi: int,
):
fig_group, fig_data = fig_group_data
......@@ -230,7 +271,7 @@ def _plot_fig(
for _, row_data in row_groups
)
fig_path = _get_fig_path(fig_dir, fig_group)
fig_path = _get_fig_path(fig_dir, fig_group, fig_name)
print("[BEGIN FIGURE]")
print(f"Group: {dict(fig_group)}")
......@@ -241,6 +282,8 @@ def _plot_fig(
print("[END FIGURE]")
return
# Convert string "inf", "-inf", and "nan" to their float equivalents
fig_data = _convert_inf_nan_strings(fig_data)
df = pd.DataFrame.from_records(fig_data)
if var_x not in df.columns:
......@@ -275,6 +318,10 @@ def _plot_fig(
df = filter_by.apply(df)
df = bin_by.apply(df)
# Sort by curve_by columns alphabetically for consistent legend ordering
if curve_by:
df = df.sort_values(by=curve_by)
df["row_group"] = (
pd.concat(
[k + "=" + df[k].astype(str) for k in row_by],
......@@ -293,7 +340,7 @@ def _plot_fig(
else "(All)"
)
g = sns.FacetGrid(df, row="row_group", col="col_group")
g = sns.FacetGrid(df, row="row_group", col="col_group", height=fig_height)
if row_by and col_by:
g.set_titles("{row_name}\n{col_name}")
......@@ -320,6 +367,7 @@ def _plot_fig(
style=style,
size=size,
markers=True,
errorbar="sd" if error_bars else None,
)
g.add_legend(title=hue)
......@@ -339,11 +387,12 @@ def _plot_fig(
y=var_y,
hue="curve_group",
markers=True,
errorbar="sd" if error_bars else None,
)
g.add_legend()
g.savefig(fig_path)
g.savefig(fig_path, dpi=fig_dpi)
plt.close(g.figure)
print("[END FIGURE]")
......@@ -364,6 +413,10 @@ def plot(
scale_x: str | None,
scale_y: str | None,
dry_run: bool,
fig_name: str = "FIGURE",
error_bars: bool = True,
fig_height: float = 6.4,
fig_dpi: int = 300,
):
all_data = [
run_data
......@@ -398,6 +451,10 @@ def plot(
scale_x=scale_x,
scale_y=scale_y,
dry_run=dry_run,
fig_name=fig_name,
error_bars=error_bars,
fig_height=fig_height,
fig_dpi=fig_dpi,
),
fig_groups,
)
......@@ -419,6 +476,10 @@ class SweepPlotArgs:
scale_x: str | None
scale_y: str | None
dry_run: bool
fig_name: str = "FIGURE"
error_bars: bool = True
fig_height: float = 6.4
fig_dpi: int = 300
parser_name: ClassVar[str] = "plot"
parser_help: ClassVar[str] = "Plot performance curves from parameter sweep results."
......@@ -448,6 +509,10 @@ class SweepPlotArgs:
scale_x=args.scale_x,
scale_y=args.scale_y,
dry_run=args.dry_run,
fig_name=args.fig_name,
error_bars=not args.no_error_bars,
fig_height=args.fig_height,
fig_dpi=args.fig_dpi,
)
@classmethod
......@@ -541,6 +606,32 @@ class SweepPlotArgs:
"Currently only accepts string values such as 'log' and 'sqrt'. "
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
)
parser.add_argument(
"--fig-name",
type=str,
default="FIGURE",
help="Name prefix for the output figure file. "
"Group data is always appended when present. "
"Default: 'FIGURE'. Example: --fig-name my_performance_plot",
)
parser.add_argument(
"--no-error-bars",
action="store_true",
help="If set, disables error bars on the plot. "
"By default, error bars are shown.",
)
parser.add_argument(
"--fig-height",
type=float,
default=6.4,
help="Height of each subplot in inches. Default: 6.4",
)
parser.add_argument(
"--fig-dpi",
type=int,
default=300,
help="Resolution of the output figure in dots per inch. Default: 300",
)
parser.add_argument(
"--dry-run",
action="store_true",
......@@ -566,6 +657,10 @@ def run_main(args: SweepPlotArgs):
scale_x=args.scale_x,
scale_y=args.scale_y,
dry_run=args.dry_run,
fig_name=args.fig_name,
error_bars=args.error_bars,
fig_height=args.fig_height,
fig_dpi=args.fig_dpi,
)
......
......@@ -138,9 +138,9 @@ def _get_comb_base_path(
):
parts = list[str]()
if serve_comb:
parts.extend(("SERVE-", serve_comb.as_text(sep="-")))
parts.extend(("SERVE-", serve_comb.name))
if bench_comb:
parts.extend(("BENCH-", bench_comb.as_text(sep="-")))
parts.extend(("BENCH-", bench_comb.name))
return output_dir / sanitize_filename("-".join(parts))
......@@ -345,8 +345,9 @@ class SweepServeArgs:
"--serve-params",
type=str,
default=None,
help="Path to JSON file containing a list of parameter combinations "
"for the `vllm serve` command. "
help="Path to JSON file containing parameter combinations "
"for the `vllm serve` command. Can be either a list of dicts or a dict "
"where keys are benchmark names. "
"If both `serve_params` and `bench_params` are given, "
"this script will iterate over their Cartesian product.",
)
......@@ -354,8 +355,9 @@ class SweepServeArgs:
"--bench-params",
type=str,
default=None,
help="Path to JSON file containing a list of parameter combinations "
"for the `vllm bench serve` command. "
help="Path to JSON file containing parameter combinations "
"for the `vllm bench serve` command. Can be either a list of dicts or "
"a dict where keys are benchmark names. "
"If both `serve_params` and `bench_params` are given, "
"this script will iterate over their Cartesian product.",
)
......
......@@ -14,7 +14,7 @@ from typing import Any
import torch
import uvloop
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
from vllm.benchmarks.datasets import (
AIMODataset,
......@@ -35,6 +35,7 @@ from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.utils.async_utils import merge_async_iterators
......@@ -246,12 +247,15 @@ async def run_vllm_async(
def run_hf(
requests: list[SampleRequest],
model: str,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
n: int,
max_batch_size: int,
trust_remote_code: bool,
disable_detokenize: bool = False,
) -> float:
assert isinstance(tokenizer, PreTrainedTokenizerBase), (
"the hf backend only supports HF tokenizers"
)
llm = AutoModelForCausalLM.from_pretrained(
model, dtype=torch.float16, trust_remote_code=trust_remote_code
)
......@@ -651,8 +655,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
"--profile",
action="store_true",
default=False,
help="Use Torch Profiler. The env variable "
"VLLM_TORCH_PROFILER_DIR must be set to enable profiler.",
help="Use vLLM Profiling. --profiler-config must be provided on the server.",
)
# prefix repetition dataset
......@@ -692,15 +695,21 @@ def add_cli_args(parser: argparse.ArgumentParser):
def main(args: argparse.Namespace):
if args.tokenizer is None:
args.tokenizer = args.model
validate_args(args)
if args.seed is None:
args.seed = 0
random.seed(args.seed)
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code
if (
args.backend == "hf" or args.backend == "mii"
) and args.tokenizer_mode == "auto":
# mistral_common tokenizer is only supported on vllm and vllm-chat backends;
# for hf and mii backends, we use hf tokenizer
args.tokenizer_mode = "hf"
tokenizer = get_tokenizer(
args.tokenizer,
tokenizer_mode=args.tokenizer_mode,
trust_remote_code=args.trust_remote_code,
)
requests = get_requests(args, tokenizer)
is_multi_modal = any(request.multi_modal_data is not None for request in requests)
......
......@@ -26,7 +26,8 @@ from vllm.compilation.partition_rules import (
should_split,
)
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.utils import hash_factors
from vllm.config.compilation import DynamicShapesType
from vllm.config.utils import Range, hash_factors
from vllm.logger import init_logger
from vllm.logging_utils import lazy
from vllm.platforms import current_platform
......@@ -90,7 +91,7 @@ class CompilerManager:
"""
def __init__(self, compilation_config: CompilationConfig):
self.cache: dict[tuple[int | None, int, str], Any] = dict()
self.cache: dict[tuple[Range, int, str], Any] = dict()
self.is_cache_updated = False
self.compilation_config = compilation_config
self.compiler = make_compiler(compilation_config)
......@@ -99,11 +100,11 @@ class CompilerManager:
return self.compiler.compute_hash(vllm_config)
@contextmanager
def compile_context(self, runtime_shape: int | None = None):
def compile_context(self, compile_range: Range):
"""Provide compilation context for the duration of compilation to set
any torch global properties we want to scope to a single Inductor
compilation (e.g. partition rules, pass context)."""
with pass_context(runtime_shape):
with pass_context(compile_range):
if self.compilation_config.use_inductor_graph_partition:
with inductor_partition_rule_context(
self.compilation_config.splitting_ops
......@@ -159,26 +160,18 @@ class CompilerManager:
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
runtime_shape: int | None = None,
compile_range: Range,
) -> Callable | None:
if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
if (compile_range, graph_index, self.compiler.name) not in self.cache:
return None
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
handle = self.cache[(compile_range, graph_index, self.compiler.name)]
compiled_graph = self.compiler.load(
handle, graph, example_inputs, graph_index, runtime_shape
handle, graph, example_inputs, graph_index, compile_range
)
if runtime_shape is None:
logger.debug(
"Directly load the %s-th graph for dynamic shape from %s via handle %s",
"Directly load the %s-th graph for compile range %sfrom %s via handle %s",
graph_index,
self.compiler.name,
handle,
)
else:
logger.debug(
"Directly load the %s-th graph for shape %s from %s via handle %s",
graph_index,
str(runtime_shape),
str(compile_range),
self.compiler.name,
handle,
)
......@@ -190,9 +183,9 @@ class CompilerManager:
example_inputs,
additional_inductor_config,
compilation_config: CompilationConfig,
compile_range: Range,
graph_index: int = 0,
num_graphs: int = 1,
runtime_shape: int | None = None,
) -> Any:
if graph_index == 0:
# before compiling the first graph, record the start time
......@@ -204,7 +197,7 @@ class CompilerManager:
compiled_graph = None
# try to load from the cache
compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
compiled_graph = self.load(graph, example_inputs, graph_index, compile_range)
if compiled_graph is not None:
if graph_index == num_graphs - 1:
# after loading the last graph for this shape, record the time.
......@@ -212,17 +205,10 @@ class CompilerManager:
now = time.time()
elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed
if runtime_shape is None:
logger.info(
"Directly load the compiled graph(s) for dynamic shape "
"from the cache, took %.3f s",
elapsed,
)
else:
logger.info(
"Directly load the compiled graph(s) for shape %s "
"Directly load the compiled graph(s) for compile range %s "
"from the cache, took %.3f s",
str(runtime_shape),
str(compile_range),
elapsed,
)
return compiled_graph
......@@ -233,14 +219,15 @@ class CompilerManager:
# Let compile_fx generate a key for us
maybe_key = None
else:
maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
with self.compile_context(runtime_shape):
maybe_key = "artifact_compile_range_"
maybe_key += f"{compile_range.start}_{compile_range.end}"
maybe_key += f"_subgraph_{graph_index}"
with self.compile_context(compile_range):
compiled_graph, handle = self.compiler.compile(
graph,
example_inputs,
additional_inductor_config,
runtime_shape,
compile_range,
maybe_key,
)
......@@ -248,33 +235,19 @@ class CompilerManager:
# store the artifact in the cache
if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
self.cache[(compile_range, graph_index, self.compiler.name)] = handle
compilation_counter.num_cache_entries_updated += 1
self.is_cache_updated = True
if graph_index == 0:
# adds some info logging for the first graph
if runtime_shape is None:
logger.info_once(
"Cache the graph for dynamic shape for later use", scope="local"
"Cache the graph of compile range %s for later use",
str(compile_range),
)
else:
logger.info_once(
"Cache the graph of shape %s for later use",
str(runtime_shape),
scope="local",
)
if runtime_shape is None:
logger.debug(
"Store the %s-th graph for dynamic shape from %s via handle %s",
"Store the %s-th graph for compile range%s from %s via handle %s",
graph_index,
self.compiler.name,
handle,
)
else:
logger.debug(
"Store the %s-th graph for shape %s from %s via handle %s",
graph_index,
str(runtime_shape),
str(compile_range),
self.compiler.name,
handle,
)
......@@ -284,16 +257,9 @@ class CompilerManager:
now = time.time()
elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed
if runtime_shape is None:
logger.info_once(
"Compiling a graph for dynamic shape takes %.2f s",
elapsed,
scope="local",
)
else:
logger.info_once(
"Compiling a graph for shape %s takes %.2f s",
runtime_shape,
"Compiling a graph for compile range %s takes %.2f s",
str(compile_range),
elapsed,
scope="local",
)
......@@ -402,6 +368,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
self.extra_traceback = False
def run(self, *args):
# maybe instead just assert inputs are fake?
fake_args = [
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
for t in args
......@@ -416,27 +383,17 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
kwargs: dict[str, Any],
) -> Any:
assert isinstance(target, str)
output = super().call_module(target, args, kwargs)
if target in self.compile_submod_names:
index = self.compile_submod_names.index(target)
submod = self.fetch_attr(target)
sym_shape_indices = [
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
]
global compilation_start_time
compiled_graph_for_dynamic_shape = (
self.vllm_backend.compiler_manager.compile(
submod,
args,
self.vllm_backend.inductor_config,
self.compilation_config,
graph_index=index,
num_graphs=len(self.compile_submod_names),
runtime_shape=None,
)
)
# Lazy import here to avoid circular import
from .piecewise_backend import PiecewiseBackend
......@@ -446,7 +403,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
index,
len(self.compile_submod_names),
sym_shape_indices,
compiled_graph_for_dynamic_shape,
self.vllm_backend,
)
......@@ -586,8 +542,13 @@ class VllmBackend:
)
else:
# Config should automatically wrap all inductor passes
assert isinstance(self.inductor_config[self.pass_key], InductorPass)
self.pass_manager.add(self.inductor_config[self.pass_key])
assert isinstance(
self.compilation_config.inductor_compile_config[self.pass_key],
InductorPass,
)
self.pass_manager.add(
self.compilation_config.inductor_compile_config[self.pass_key]
)
self.inductor_config[self.pass_key] = self.pass_manager
def __call__(
......@@ -746,11 +707,44 @@ class VllmBackend:
if not item.is_splitting_graph
]
# Extract fake values from the graph to use them when needed.
all_fake_values = []
for i in graph.graph.find_nodes(op="placeholder"):
all_fake_values.append(i.meta["example_value"])
fake_args = [
all_fake_values[i] if isinstance(t, torch.Tensor) else t
for i, t in enumerate(example_inputs)
]
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
PiecewiseCompileInterpreter(
self.split_gm, submod_names_to_compile, self.vllm_config, self
).run(*example_inputs)
).run(*fake_args)
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode()
if (
self.compilation_config.dynamic_shapes_config.evaluate_guards
and self.compilation_config.dynamic_shapes_config.type
== DynamicShapesType.BACKED
):
from torch.utils._sympy.value_ranges import ValueRanges
# Drop the 0/1 specialization guards. For backed dynamic shapes,
# torch.compile either specializes on 0/1 inputs or guards that the
# shape is >= 2, because it is really hard not to hit a check against
# 0/1. When we evaluate shape guards we have to exclude those guards
# (otherwise we would always fail).
# We do that by updating the ranges of backed sizes: whenever the lower
# bound of a range is 2, we reset it to 0.
for s, r in fake_mode.shape_env.var_to_range.items():
if r.lower == 2:
fake_mode.shape_env.var_to_range[s] = ValueRanges(0, r.upper)
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
if not os.path.exists(graph_path):
......@@ -779,15 +773,6 @@ class VllmBackend:
graph, example_inputs, self.prefix, self.split_gm
)
# if we need to copy input buffers for cudagraph
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode()
fake_args = [
fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
for t in example_inputs
]
# index of tensors that have symbolic shapes (batch size)
# for weights and static buffers, they will have concrete shapes.
# symbolic shape only happens for input tensors.
......
......@@ -10,6 +10,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
......@@ -431,7 +432,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns)
def is_applicable(self, shape: int | None) -> bool:
def is_applicable_for_range(self, compile_range: Range) -> bool:
# This pass is applied on top of the sequence parallelism pass.
# It inherits the same applicability condition as `SequenceParallelismPass`.
# See `SequenceParallelismPass.is_applicable` for more details.
......@@ -441,7 +442,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
):
return True
tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0
return compile_range.is_single_size() and compile_range.end % tp_size == 0
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
......@@ -505,18 +506,21 @@ if flashinfer_comm is not None:
num_tokens, hidden_size = allreduce_in.shape
element_size = allreduce_in.element_size()
current_tensor_size = num_tokens * hidden_size * element_size
if num_tokens <= max_token_num:
max_tensor_size = max_token_num * hidden_size * element_size
assert current_tensor_size <= max_tensor_size, (
f"Current tensor size {current_tensor_size} is larger than "
f"max token num {max_token_num} * hidden size {hidden_size} * "
f"element size {element_size}"
)
device_capability = current_platform.get_device_capability().to_int()
# Get one shot input size limit for the current world size
# for the current device capability
max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
device_capability, {}
).get(world_size, None)
# Use one shot if no max size for one shot is specified
# Use one shot if no max size is specified
use_oneshot = (
max_one_shot_size_mb is None
or current_tensor_size <= max_one_shot_size_mb * MiB
max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB
)
assert _FI_WORKSPACE_TENSOR is not None, (
......@@ -556,40 +560,6 @@ if flashinfer_comm is not None:
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
scale_factor=scale_factor,
)
else:
allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
if scale_factor is not None and scale_out is None:
# Do fused rms norm static fp8 quant fused op
if norm_out is None:
torch.ops._C.fused_add_rms_norm_static_fp8_quant(
quant_out,
allreduce_out,
residual,
rms_gamma,
scale_factor,
rms_eps,
)
else:
torch.ops._C.rms_norm_static_fp8_quant(
quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps
)
else:
if norm_out is None:
torch.ops._C.fused_add_rms_norm(
allreduce_out, residual, rms_gamma, rms_eps
)
norm_out = allreduce_out
else:
torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps)
if scale_factor is not None and scale_out is not None:
torch.ops._C.scaled_fp4_quant(
quant_out, norm_out, scale_out, scale_factor
)
if scale_factor is None or norm_out is not None:
# we need to return allreduce output
# in cases of non quant fused AR + RMS norm
# and fused AR + RMS norm + quant without fused add
allreduce_in.copy_(allreduce_out)
def call_trtllm_fused_allreduce_norm_fake(
allreduce_in: torch.Tensor,
......@@ -1106,11 +1076,15 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self.disabled = True
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size <= 1:
logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
return
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="all_reduce_fusion_pass"
)
if config.model_config is None:
logger.warning_once(
"AllReduce fusion pass is disabled for missing model_config."
)
return
self.hidden_dim = config.model_config.get_hidden_size()
self.group = get_tp_group().device_group
......@@ -1128,7 +1102,8 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
if max_size is None:
# Flashinfer doesn't support current world size
logger.warning(
"Flashinfer allreduce fusion is not supported for world size %s",
"Flashinfer allreduce fusion is not supported for world size %s"
" or max size is not provided",
self.tp_size,
)
return
......@@ -1216,6 +1191,12 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self.disabled = False
def is_applicable_for_range(self, compile_range: Range) -> bool:
    """Return whether this pass should run for ``compile_range``.

    A disabled pass is never applicable (a one-time warning is logged).
    Otherwise the pass applies only when the upper end of the range does
    not exceed ``self.max_token_num``.
    """
    if not self.disabled:
        return compile_range.end <= self.max_token_num
    logger.warning_once("AllReduce fusion pass is disabled.")
    return False
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
if self.disabled:
......
......@@ -15,6 +15,7 @@ import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.utils.hashing import safe_hash
from vllm.utils.torch_utils import is_torch_equal_or_newer
......@@ -63,16 +64,16 @@ class CompilerInterface:
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: int | None = None,
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
"""
Compile the graph with the given example inputs and compiler config,
with a runtime shape. If the `runtime_shape` is None, it means
the `example_inputs` have a dynamic shape. Otherwise, the
`runtime_shape` specifies the shape of the inputs. Right now we only
support one variable shape for all inputs, which is the batchsize
(number of tokens) during inference.
with a range. The `compile_range` specifies the range of the inputs,
it could be concrete size (if compile_sizes is provided), e.g. [4, 4]
or a range [5, 8].
Right now we only support one variable in ranges for all inputs,
which is the batchsize (number of tokens) during inference.
Dynamo will make sure `graph(*example_inputs)` is valid.
......@@ -98,7 +99,7 @@ class CompilerInterface:
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
runtime_shape: int | None = None,
compile_range: Range,
) -> Callable:
"""
Load the compiled function from the handle.
......@@ -212,20 +213,20 @@ class InductorStandaloneAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: int | None = None,
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
compilation_counter.num_inductor_compiles += 1
current_config = {}
if compiler_config is not None:
current_config.update(compiler_config)
set_inductor_config(current_config, runtime_shape)
set_inductor_config(current_config, compile_range)
set_functorch_config()
if isinstance(runtime_shape, int):
if compile_range.is_single_size():
dynamic_shapes = "from_example_inputs"
else:
dynamic_shapes = "from_tracing_context"
dynamic_shapes = "from_graph"
from torch._inductor import standalone_compile
......@@ -235,7 +236,6 @@ class InductorStandaloneAdaptor(CompilerInterface):
dynamic_shapes=dynamic_shapes,
options={"config_patches": current_config},
)
# Save the compiled artifact to disk in the specified path
assert key is not None
path = os.path.join(self.cache_dir, key)
......@@ -251,7 +251,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
runtime_shape: int | None = None,
compile_range: Range,
) -> Callable:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
......@@ -315,7 +315,7 @@ class InductorAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: int | None = None,
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
compilation_counter.num_inductor_compiles += 1
......@@ -329,7 +329,7 @@ class InductorAdaptor(CompilerInterface):
current_config["fx_graph_cache"] = True
current_config["fx_graph_remote_cache"] = False
set_inductor_config(current_config, runtime_shape)
set_inductor_config(current_config, compile_range)
set_functorch_config()
# inductor can inplace modify the graph, so we need to copy it
......@@ -512,7 +512,7 @@ class InductorAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
runtime_shape: int | None = None,
compile_range: Range,
) -> Callable:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
......@@ -608,9 +608,9 @@ class InductorAdaptor(CompilerInterface):
return contextlib.nullcontext()
def set_inductor_config(config, runtime_shape):
if isinstance(runtime_shape, int):
# for a specific batchsize, tuning triton kernel parameters
def set_inductor_config(config, compile_range: Range):
if compile_range.is_single_size():
# for a specific batch size, tuning triton kernel parameters
# can be beneficial
config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
config["coordinate_descent_tuning"] = (
......@@ -630,7 +630,7 @@ class EagerAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: int | None = None,
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
compilation_counter.num_eager_compiles += 1
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from collections import Counter
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any
......@@ -22,6 +23,99 @@ from vllm.utils.torch_utils import weak_ref_tensors
logger = init_logger(__name__)
@dataclasses.dataclass(frozen=True)
class CUDAGraphStat:
    """One immutable cudagraph observation.

    The dataclass is frozen, so instances are hashable and equal
    observations can be tallied together (e.g. with a ``Counter``).
    """

    # Token count before padding (as reported by the caller).
    num_unpadded_tokens: int
    # Token count after padding (as reported by the caller).
    num_padded_tokens: int
    # Amount of padding for this observation; presumably
    # num_padded_tokens - num_unpadded_tokens -- confirm at the call site.
    num_paddings: int
    # Runtime mode label; presumably the name of a CUDAGraphMode value --
    # confirm at the call site.
    runtime_mode: str
class CUDAGraphLogging:
    """Collect cudagraph stats and render them as a Markdown-style table."""

    COLUMN_HEADERS = [
        "Unpadded Tokens",
        "Padded Tokens",
        "Num Paddings",
        "Runtime Mode",
        "Count",
    ]

    def __init__(self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None):
        """Store the stringified config and start with an empty stat list."""
        self.reset()
        self.cg_mode = str(cg_mode)
        # A missing capture-size list is shown as an empty list.
        self.cg_capture_sizes = str(cg_capture_sizes or [])
        self.settings_header = (
            "**CUDAGraph Config Settings:**\n\n"
            f"- Mode: {self.cg_mode}\n"
            f"- Capture sizes: {self.cg_capture_sizes}\n\n"
            "**CUDAGraph Stats:**\n\n"
        )

    def reset(self):
        """Forget every stat observed so far."""
        self.stats = []

    def observe(self, cudagraph_stat: CUDAGraphStat):
        """Record a single stat for the next report."""
        self.stats.append(cudagraph_stat)

    def generate_metric_table(self) -> str:
        """Return the settings header followed by a table of stat counts.

        Identical stats are merged into one row with their occurrence
        count; rows are listed from most to least frequent.
        """
        # most_common() sorts by descending count and keeps first-seen
        # order among equal counts.
        rows = [
            [
                str(stat.num_unpadded_tokens),
                str(stat.num_padded_tokens),
                str(stat.num_paddings),
                stat.runtime_mode,
                str(count),
            ]
            for stat, count in Counter(self.stats).most_common()
        ]

        # Each column is as wide as its widest cell, header included.
        col_widths = [
            max([len(header), *(len(row[idx]) for row in rows)])
            for idx, header in enumerate(self.COLUMN_HEADERS)
        ]

        header_cells = [
            header.ljust(width)
            for header, width in zip(self.COLUMN_HEADERS, col_widths)
        ]
        header_line = "| " + " | ".join(header_cells) + " |\n"
        separator_line = (
            "|" + "|".join("-" * (width + 2) for width in col_widths) + "|\n"
        )

        # Left-align every cell to its column width.
        body_lines = [
            "| "
            + " | ".join(
                str(cell).ljust(width) for cell, width in zip(row, col_widths)
            )
            + " |"
            for row in rows
        ]

        return "".join(
            [
                self.settings_header,
                header_line,
                separator_line,
                "\n".join(body_lines),
                "\n",
            ]
        )

    def log(self, log_fn=logger.info):
        """Emit the table through ``log_fn`` and reset; no-op without stats."""
        if self.stats:
            log_fn(self.generate_metric_table())
            self.reset()
@dataclasses.dataclass
class CUDAGraphEntry:
batch_descriptor: BatchDescriptor
......
......@@ -392,7 +392,6 @@ def _support_torch_compile(
factors.append(_model_hash_key(self.forward))
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT,
"torch_aot_compile",
......@@ -409,8 +408,11 @@ def _support_torch_compile(
open(aot_compilation_path, "rb") as f,
):
start_monitoring_torch_compile(self.vllm_config)
loaded_fn = torch.compiler.load_compiled_function(f)
loaded_fn = torch.compiler.load_compiled_function(
f, f_globals=self.forward.__globals__
)
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
loaded_fn.disable_guard_check()
self.aot_compiled_fn = loaded_fn
except Exception as e:
......@@ -433,7 +435,6 @@ def _support_torch_compile(
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
# This is the path for the first compilation.
# the first compilation needs to have dynamic shapes marked
_mark_dynamic_inputs(
self,
......
......@@ -103,6 +103,19 @@ class FixFunctionalizationPass(VllmInductorPass):
]:
mutated_args = {1: "result"}
self.defunctionalize(graph, node, mutated_args)
elif (
hasattr(torch.ops.vllm, "flashinfer_trtllm_fused_allreduce_norm")
and at_target
== torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
):
mutated_args = {
1: "allreduce_in",
2: "residual",
3: "norm_out",
4: "quant_out",
5: "scale_out",
}
self.defunctionalize(graph, node, mutated_args)
# For some reason we need to specify the args for both
# silu_and_mul and silu_and_mul_quant. The kwargs
# pathway gets the wrong answer.
......
......@@ -15,13 +15,22 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kNvfp4Quant,
kStaticTensorScale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
should_use_deepgemm_for_fp8_linear_for_nk,
)
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
......@@ -58,6 +67,9 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
class FusedRMSQuantKey(NamedTuple):
......@@ -90,6 +102,18 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
# FusedRMSQuantKey(
# kFp8DynamicTokenSym, True
# ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
# FusedRMSQuantKey(
# kFp8Dynamic128Sym, False
# ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
# FusedRMSQuantKey(
# kFp8Dynamic128Sym, True
# ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
# FusedRMSQuantKey(
# kFp8Dynamic64Sym, False
# ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
# FusedRMSQuantKey(
# kFp8Dynamic64Sym, True
# ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
}
......@@ -100,6 +124,15 @@ class RMSNormQuantPattern:
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
# groupwise FP8 linear uses column-major scales when DeepGEMM is used
# or CUTLASS block FP8 is supported
using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk(
self.model_dtype,
config.model_config.hf_config.intermediate_size,
config.model_config.hf_config.hidden_size,
)
use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported()
use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
self.FUSED_OP = FUSED_OPS[key]
......@@ -108,7 +141,9 @@ class RMSNormQuantPattern:
if not key.fused_add
else MatcherFusedAddRMSNorm(epsilon)
)
self.quant_matcher = MatcherQuantFP8(key.quant)
self.quant_matcher = MatcherQuantFP8(
key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0
)
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
......@@ -218,6 +253,120 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
)
class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
    """Rewrite fused-add RMSNorm followed by group quant into ``FUSED_OP``."""

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric=True,
    ):
        """Build the fused-add lookup key for this group shape and init the base."""
        quant_key = QuantKey(
            dtype=quant_dtype,
            scale=ScaleDesc(torch.float32, False, group_shape),
            symmetric=symmetric,
        )
        self.group_shape = group_shape
        super().__init__(epsilon, FusedRMSQuantKey(fused_add=True, quant=quant_key))

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its fused replacement on ``pm_pass``."""

        def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
            normed, residual = self.rmsnorm_matcher(input, weight, residual)
            quantized, scale = self.quant_matcher(normed)
            return quantized, residual, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ):
            # When native rms-norm is matched, dtype conversions may have
            # been optimized away, so cast explicitly to be safe.
            input = input.to(dtype=self.model_dtype)

            col_major = self.quant_matcher.use_col_major_scales
            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input, transposed=col_major)
            fused = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
                group_size=self.group_shape[1],
                is_scale_transposed=col_major,
            )

            # result, residual, scale
            return fused[1], fused[3], fused[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )
class RMSNormGroupQuantPattern(RMSNormQuantPattern):
    """Rewrite RMSNorm followed by group quant into ``FUSED_OP``."""

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric=True,
    ):
        """Build the non-fused-add lookup key for this group shape and init the base."""
        quant_key = QuantKey(
            dtype=quant_dtype,
            scale=ScaleDesc(torch.float32, False, group_shape),
            symmetric=symmetric,
        )
        self.group_shape = group_shape
        super().__init__(epsilon, FusedRMSQuantKey(fused_add=False, quant=quant_key))

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its fused replacement on ``pm_pass``."""

        def pattern(input: torch.Tensor, weight: torch.Tensor):
            normed = self.rmsnorm_matcher(input, weight)
            quantized, scale = self.quant_matcher(normed)
            return quantized, scale

        def replacement(input: torch.Tensor, weight: torch.Tensor):
            # When native rms-norm is matched, dtype conversions may have
            # been optimized away, so cast explicitly to be safe.
            input = input.to(dtype=self.model_dtype)

            col_major = self.quant_matcher.use_col_major_scales
            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input, transposed=col_major)
            fused = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
                group_size=self.group_shape[1],
                is_scale_transposed=col_major,
            )

            # result, scale
            return fused[1], fused[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(
self,
......@@ -340,6 +489,27 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse fused_add_rms_norm + fp8 group quant
# Only register group quant patterns on CUDA where the C++ op exists
if current_platform.is_cuda():
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
# Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns
......@@ -366,9 +536,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
def uuid(self) -> Any:
return self.hash_source(
self,
RMSNormGroupQuantPattern,
RMSNormQuantPattern,
RMSNormStaticQuantPattern,
RMSNormDynamicQuantPattern,
FusedAddRMSNormStaticQuantPattern,
FusedAddRMSNormDynamicQuantPattern,
FusedAddRMSNormGroupQuantPattern,
)
......@@ -75,7 +75,7 @@ def find_op_nodes(
return
assert isinstance(op, OpOverload)
if not op._schema.is_mutable:
yield from graph.find_nodes(op="call_function", target=op)
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import functools
import hashlib
import inspect
......@@ -8,7 +10,7 @@ import json
import types
from collections.abc import Callable
from contextlib import contextmanager
from typing import Any
from typing import TYPE_CHECKING, Any
import torch
from torch import fx
......@@ -16,6 +18,9 @@ from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
from vllm.utils.torch_utils import is_torch_equal_or_newer
if TYPE_CHECKING:
from vllm.config.utils import Range
if is_torch_equal_or_newer("2.6"):
from torch._inductor.custom_graph_pass import CustomGraphPass
else:
......@@ -28,8 +33,8 @@ _pass_context = None
class PassContext:
    """Store the ``compile_range`` the context is constructed with."""

    def __init__(self, compile_range: Range):
        # Only one constructor is kept: the earlier ``runtime_shape``
        # variant was shadowed by this definition and therefore never ran.
        self.compile_range: Range = compile_range
def get_pass_context() -> PassContext:
......@@ -39,13 +44,13 @@ def get_pass_context() -> PassContext:
@contextmanager
def pass_context(runtime_shape: int | None):
def pass_context(compile_range: Range):
"""A context manager that stores the current pass context,
usually it is a list of sizes to specialize.
"""
global _pass_context
prev_context = _pass_context
_pass_context = PassContext(runtime_shape)
_pass_context = PassContext(compile_range)
try:
yield
finally:
......@@ -96,7 +101,7 @@ class InductorPass(CustomGraphPass):
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()
def is_applicable(self, shape: int | None):
def is_applicable_for_range(self, compile_range: Range):
return True
......
......@@ -13,6 +13,8 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
_normalize_quant_group_shape,
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
......@@ -35,6 +37,10 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
# The FP4 quant op is registered only when running on CUDA and the op exists.
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501

# Both FP8 group-quant keys (128 and 64) map to the same per-token-group op.
if current_platform.is_cuda():
    QUANT_OPS.update(
        {
            kFp8Dynamic128Sym: torch.ops._C.per_token_group_fp8_quant.default,
            kFp8Dynamic64Sym: torch.ops._C.per_token_group_fp8_quant.default,
        }
    )

SILU_MUL_OP = torch.ops._C.silu_and_mul.default
......@@ -224,12 +230,20 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
class MatcherQuantFP8(MatcherCustomOp):
def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
def __init__(
self,
quant_key: QuantKey,
enabled: bool | None = None,
use_col_major_scales: bool = False,
use_e8m0: bool = False,
):
if enabled is None:
enabled = QuantFP8.enabled()
super().__init__(enabled)
self.quant_key = quant_key
self.use_col_major_scales = use_col_major_scales
self.use_e8m0 = use_e8m0
assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
self.QUANT_OP = QUANT_OPS[quant_key]
......@@ -248,6 +262,27 @@ class MatcherQuantFP8(MatcherCustomOp):
input.shape, device=input.device, dtype=self.quant_key.dtype
)
if self.quant_key.scale.group_shape.is_per_group():
assert scale is None
scale = self.make_scale(input, transposed=self.use_col_major_scales)
finfo = torch.finfo(self.quant_key.dtype)
fp8_min = finfo.min
fp8_max = finfo.max
_, result, scale = auto_functionalized(
self.QUANT_OP,
input=input,
output_q=result,
output_s=scale,
group_size=self.quant_key.scale.group_shape[1],
eps=1e-10,
fp8_min=fp8_min,
fp8_max=fp8_max,
scale_ue8m0=self.use_e8m0,
)
return result, scale
if self.quant_key.scale.static:
assert scale is not None
_, result = auto_functionalized(
......@@ -269,7 +304,7 @@ class MatcherQuantFP8(MatcherCustomOp):
) -> tuple[torch.Tensor, torch.Tensor]:
return self.quant_fp8(input, scale)
def make_scale(self, input: torch.Tensor):
def make_scale(self, input: torch.Tensor, transposed: bool = False):
normalized_group_shape = _normalize_quant_group_shape(
input, self.quant_key.scale.group_shape
)
......@@ -277,6 +312,11 @@ class MatcherQuantFP8(MatcherCustomOp):
input.shape[0] // normalized_group_shape[0],
input.shape[1] // normalized_group_shape[1],
)
if transposed:
scale_shape = tuple(reversed(scale_shape))
return torch.empty(
scale_shape, device=input.device, dtype=torch.float32
).permute(-1, -2)
return torch.empty(scale_shape, device=input.device, dtype=torch.float32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment