Commit 4eabe123 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori

parents 45840cd2 58738772
...@@ -43,7 +43,7 @@ def initialize_kv_cache(runner: GPUModelRunner): ...@@ -43,7 +43,7 @@ def initialize_kv_cache(runner: GPUModelRunner):
device=runner.device, device=runner.device,
pin_memory=runner.pin_memory, pin_memory=runner.pin_memory,
vocab_size=runner.model_config.get_vocab_size(), vocab_size=runner.model_config.get_vocab_size(),
kv_cache_config=kv_cache_config, block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size,
) )
runner.initialize_attn_backend(kv_cache_config) runner.initialize_attn_backend(kv_cache_config)
......
# SPDX-License-Identifier: Apache-2.0
import subprocess
import sys
import regex as re
# Matches a direct `import triton ...` / `from triton ...` statement at the
# start of an already-stripped line, including submodules (`triton.language`).
FORBIDDEN_IMPORT_RE = re.compile(r"^(from|import)\s+triton(\s|\.|$)")

# The only sanctioned spellings for bringing triton into scope.
ALLOWED_LINES = {
    "from vllm.triton_utils import triton",
    "from vllm.triton_utils import tl",
    "from vllm.triton_utils import tl, triton",
}


def is_forbidden_import(line: str) -> bool:
    """Tell whether `line` imports triton directly.

    Surrounding whitespace is ignored. A line listed in `ALLOWED_LINES`
    is never reported.
    """
    candidate = line.strip()
    if candidate in ALLOWED_LINES:
        return False
    return FORBIDDEN_IMPORT_RE.match(candidate) is not None
def parse_diff(diff: str) -> list[str]:
    """Collect forbidden triton imports among the lines added by a diff.

    Hunk headers (`@@ -a,b +c,d @@`) are used to track how many old/new
    lines remain in the current hunk, so that:
      * an added line whose content itself starts with "+" is still
        inspected and counted,
      * context lines (when the diff has any) advance the line number,
      * a "+" line seen before any hunk header is ignored instead of
        crashing on an uninitialised line counter.

    Args:
        diff: Unified diff text as produced by `git diff`.

    Returns:
        One "<file>:<line>: <statement>" entry per offending added line,
        where <line> is the line number in the new version of the file.
    """
    violations: list[str] = []
    current_file = None
    next_lineno = 0  # new-file line number of the next hunk line
    old_left = 0  # old-file lines still expected in the current hunk
    new_left = 0  # new-file lines still expected in the current hunk

    for line in diff.splitlines():
        if old_left <= 0 and new_left <= 0:
            # Between hunks: only file and hunk headers are meaningful.
            if line.startswith("+++ b/"):
                current_file = line[6:]
            elif line.startswith("@@"):
                header = re.match(
                    r"@@ -\d+(?:,(\d+))? \+(\d+)(?:,(\d+))? @@", line)
                if header:
                    # An omitted count means the range covers one line.
                    old_left = int(header.group(1) or 1)
                    next_lineno = int(header.group(2))
                    new_left = int(header.group(3) or 1)
            continue

        if line.startswith("\\"):
            # "\ No newline at end of file" marker; not a file line.
            continue
        if line.startswith("-"):
            old_left -= 1
            continue
        if line.startswith("+"):
            code_line = line[1:]
            if is_forbidden_import(code_line):
                violations.append(
                    f"{current_file}:{next_lineno}: {code_line.strip()}")
        else:
            # Context line: present in both versions of the file.
            old_left -= 1
        new_left -= 1
        next_lineno += 1

    return violations
def get_diff(diff_type: str) -> str:
    """Return the zero-context `git diff` output for the requested scope.

    Args:
        diff_type: "staged" (index vs. HEAD) or "unstaged" (working tree
            vs. index).

    Raises:
        ValueError: If `diff_type` is neither "staged" nor "unstaged".
        subprocess.CalledProcessError: If git exits with a non-zero status.
    """
    commands = {
        "staged": ["git", "diff", "--cached", "--unified=0"],
        "unstaged": ["git", "diff", "--unified=0"],
    }
    if diff_type not in commands:
        raise ValueError(f"Unknown diff_type: {diff_type}")
    return subprocess.check_output(commands[diff_type], text=True)
def main():
    """Scan staged and unstaged changes for direct triton imports.

    A failing `git diff` is reported on stderr and that scope is skipped.

    Returns:
        1 when at least one violation was found, otherwise 0.
    """
    all_violations = []
    for diff_type in ("staged", "unstaged"):
        try:
            all_violations += parse_diff(get_diff(diff_type))
        except subprocess.CalledProcessError as e:
            print(f"[{diff_type}] Git diff failed: {e}", file=sys.stderr)

    if not all_violations:
        return 0

    print("❌ Forbidden direct `import triton` detected."
          " ➤ Use `from vllm.triton_utils import triton` instead.\n")
    for violation in all_violations:
        print(f"❌ {violation}")
    return 1
if __name__ == "__main__":
sys.exit(main())
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import subprocess
from pathlib import Path
import regex as re
# Matches `import re` (alone, followed by whitespace, or followed by a comma)
# and `from re import ...`, optionally preceded by whitespace.
FORBIDDEN_PATTERNS = re.compile(
    r'^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)')
# Import spellings that are always accepted: `import regex as re` and
# `import regex`.
ALLOWED_PATTERNS = [
    re.compile(r'^\s*import\s+regex\s+as\s+re\s*$'),
    re.compile(r'^\s*import\s+regex\s*$'),
]
def get_staged_python_files() -> list[str]:
    """List staged `.py` files that were added or modified.

    Returns:
        Paths as printed by `git diff --cached --name-only`; an empty list
        when nothing is staged or when git exits with an error.
    """
    command = ['git', 'diff', '--cached', '--name-only', '--diff-filter=AM']
    try:
        completed = subprocess.run(command,
                                   capture_output=True,
                                   text=True,
                                   check=True)
    except subprocess.CalledProcessError:
        return []

    listing = completed.stdout.strip()
    if not listing:
        return []
    return [name for name in listing.split('\n') if name.endswith('.py')]
def is_forbidden_import(line: str) -> bool:
    """Tell whether `line` imports the standard `re` module.

    The line is stripped first; it is reported only when it matches
    `FORBIDDEN_PATTERNS` and none of `ALLOWED_PATTERNS`.
    """
    text = line.strip()
    if not FORBIDDEN_PATTERNS.match(text):
        return False
    return not any(allowed.match(text) for allowed in ALLOWED_PATTERNS)
def check_file(filepath: str) -> list[tuple[int, str]]:
    """Find forbidden `re` imports in one file.

    Args:
        filepath: Path of the file to scan; it is read as UTF-8.

    Returns:
        `(line_number, stripped_line)` pairs, numbered from 1. Files that
        cannot be opened or decoded are skipped silently; anything found
        before the failure is still returned.
    """
    found: list[tuple[int, str]] = []
    try:
        with open(filepath, encoding='utf-8') as handle:
            for number, text in enumerate(handle, start=1):
                if not is_forbidden_import(text):
                    continue
                found.append((number, text.strip()))
    except (OSError, UnicodeDecodeError):
        # Best effort: unreadable files are not treated as violations.
        pass
    return found
def main() -> int:
    """Check every staged Python file for forbidden `re` imports.

    Returns:
        1 after printing a report when any violation exists, otherwise 0.
    """
    violation_count = 0
    for filepath in get_staged_python_files():
        # A staged path may no longer exist in the working tree.
        if not Path(filepath).exists():
            continue

        hits = check_file(filepath)
        if not hits:
            continue

        print(f"\n{filepath}:")
        for line_num, text in hits:
            print(f" Line {line_num}: {text}")
        violation_count += len(hits)

    if violation_count == 0:
        return 0

    print(f"\n💡 Found {violation_count} violation(s).")
    print("❌ Please replace 'import re' with 'import regex as re'")
    print(
        " Also replace 'from re import ...' with 'from regex import ...'"
    )  # noqa: E501
    print("✅ Allowed imports:")
    print(" - import regex as re")
    print(" - import regex")  # noqa: E501
    return 1
if __name__ == "__main__":
raise SystemExit(main())
#!/bin/bash
# Usage: ./install_nixl.sh [--force]
# --force reinstalls every component even when it is already detected.
FORCE=false
if [ "$1" == "--force" ]; then
    FORCE=true
fi
# SUDO is true only when sudo exists and `sudo -n true` succeeds, i.e. it
# can be used non-interactively.
SUDO=false
if command -v sudo >/dev/null 2>&1 && sudo -n true 2>/dev/null; then
    SUDO=true
fi
ARCH=$(uname -m)
# Install prefixes for the three components, all under ROOT_DIR.
ROOT_DIR="/usr/local"
mkdir -p "$ROOT_DIR"
GDR_HOME="$ROOT_DIR/gdrcopy"
UCX_HOME="$ROOT_DIR/ucx"
NIXL_HOME="$ROOT_DIR/nixl"
CUDA_HOME=/usr/local/cuda
# Put the freshly installed binaries and libraries ahead of existing entries.
export PATH="$GDR_HOME/bin:$UCX_HOME/bin:$NIXL_HOME/bin:$PATH"
export LD_LIBRARY_PATH="$GDR_HOME/lib:$UCX_HOME/lib:$NIXL_HOME/lib/$ARCH-linux-gnu:$LD_LIBRARY_PATH"
# Downloads and builds happen inside a scratch directory relative to the
# current working directory.
TEMP_DIR="nixl_installer"
mkdir -p "$TEMP_DIR"
cd "$TEMP_DIR"
# Python-side build tools used by the steps below.
pip install meson ninja pybind11
# Build and install gdrcopy unless its device node already exists
# (or --force was given).
if [ ! -e "/dev/gdrdrv" ] || [ "$FORCE" = true ]; then
    # bash's builtin echo does not expand "\n", so the old message printed a
    # literal backslash-n; use the same plain form as the other sections.
    echo "Installing gdrcopy"
    wget https://github.com/NVIDIA/gdrcopy/archive/refs/tags/v2.5.tar.gz
    tar xzf v2.5.tar.gz; rm v2.5.tar.gz
    # Stop here if the download/extract failed; otherwise `make` and the
    # later `cd ..` would run from the wrong directory.
    cd gdrcopy-2.5 || exit 1
    make prefix="$GDR_HOME" CUDA="$CUDA_HOME" all install
    if $SUDO; then
        echo "Running insmod.sh with sudo"
        sudo ./insmod.sh
    else
        echo "Skipping insmod.sh - sudo not available"
        echo "Please run 'sudo ./gdrcopy-2.5/insmod.sh' manually if needed"
    fi
    cd ..
else
    echo "Found /dev/gdrdrv. Skipping gdrcopy installation"
fi
# Build and install UCX 1.18.0 unless `ucx_info` is already on PATH
# (or --force was given).
if ! command -v ucx_info &> /dev/null || [ "$FORCE" = true ]; then
    echo "Installing UCX"
    wget https://github.com/openucx/ucx/releases/download/v1.18.0/ucx-1.18.0.tar.gz
    tar xzf ucx-1.18.0.tar.gz; rm ucx-1.18.0.tar.gz
    cd ucx-1.18.0
    # Checking Mellanox NICs
    # Extra configure flags are added when lspci lists a Mellanox device or
    # the `ibstat` tool is installed.
    MLX_OPTS=""
    if lspci | grep -i mellanox > /dev/null || command -v ibstat > /dev/null; then
        echo "Mellanox NIC detected, adding Mellanox-specific options"
        MLX_OPTS="--with-rdmacm \
--with-mlx5-dv \
--with-ib-hw-tm"
    fi
    # $MLX_OPTS is intentionally unquoted so it splits into separate flags.
    ./configure --prefix=$UCX_HOME \
        --enable-shared \
        --disable-static \
        --disable-doxygen-doc \
        --enable-optimizations \
        --enable-cma \
        --enable-devel-headers \
        --with-cuda=$CUDA_HOME \
        --with-dm \
        --with-gdrcopy=$GDR_HOME \
        --with-verbs \
        --enable-mt \
        $MLX_OPTS
    make -j
    make -j install-strip
    # Refresh the dynamic linker cache when passwordless sudo is available.
    if $SUDO; then
        echo "Running ldconfig with sudo"
        sudo ldconfig
    else
        echo "Skipping ldconfig - sudo not available"
        echo "Please run 'sudo ldconfig' manually if needed"
    fi
    cd ..
else
    echo "Found existing UCX. Skipping UCX installation"
fi
# Build and install NIXL 0.2.0 unless `nixl_test` is already on PATH
# (or --force was given).
if ! command -v nixl_test &> /dev/null || [ "$FORCE" = true ]; then
    echo "Installing NIXL"
    wget https://github.com/ai-dynamo/nixl/archive/refs/tags/0.2.0.tar.gz
    tar xzf 0.2.0.tar.gz; rm 0.2.0.tar.gz
    cd nixl-0.2.0
    # Configure against the UCX prefix chosen earlier in this script.
    meson setup build --prefix=$NIXL_HOME -Ducx_path=$UCX_HOME
    cd build
    ninja
    ninja install
    # Leave both build/ and nixl-0.2.0/.
    cd ../..
else
    echo "Found existing NIXL. Skipping NIXL installation"
fi
...@@ -24,7 +24,7 @@ if printf '%s\n' "${FILES[@]}" | grep -q "^docker/Dockerfile$"; then ...@@ -24,7 +24,7 @@ if printf '%s\n' "${FILES[@]}" | grep -q "^docker/Dockerfile$"; then
fi fi
# Define the target file path # Define the target file path
TARGET_GRAPH_FILE="docs/source/assets/contributing/dockerfile-stages-dependency.png" TARGET_GRAPH_FILE="docs/assets/contributing/dockerfile-stages-dependency.png"
# Ensure target directory exists # Ensure target directory exists
mkdir -p "$(dirname "$TARGET_GRAPH_FILE")" mkdir -p "$(dirname "$TARGET_GRAPH_FILE")"
......
...@@ -1091,7 +1091,6 @@ def scaled_fp4_experts_quant( ...@@ -1091,7 +1091,6 @@ def scaled_fp4_experts_quant(
blockscale_offsets: torch.Tensor, blockscale_offsets: torch.Tensor,
topk: int, topk: int,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
MAX_TOKENS_PER_EXPERT: int = 163840,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Quantize input tensor to FP4 and return quantized tensor and scale, for Quantize input tensor to FP4 and return quantized tensor and scale, for
...@@ -1113,9 +1112,16 @@ def scaled_fp4_experts_quant( ...@@ -1113,9 +1112,16 @@ def scaled_fp4_experts_quant(
input_tensor = input_tensor[ input_tensor = input_tensor[
expert_map] if expert_map is not None else input_tensor expert_map] if expert_map is not None else input_tensor
m_numtopk, k = input_tensor.shape m_numtopk, k = input_tensor.shape
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE Expert Quantization. This is used to prevent the kernel
# from running out of memory. This value can also be increased to support
# larger models.
MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), ( assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), (
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT * topk for" f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
f" scaled_fp4_experts_quant kernel, observed m_numtopk = {m_numtopk}") f"{MAX_TOKENS_PER_EXPERT})"
f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value.")
scales_k = k // 16 scales_k = k // 16
padded_k = (scales_k + (4 - 1)) // 4 padded_k = (scales_k + (4 - 1)) // 4
......
...@@ -861,7 +861,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -861,7 +861,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
gqa_ratio = num_heads // self.num_kv_heads gqa_ratio = num_heads // self.num_kv_heads
use_custom = use_rocm_custom_paged_attention( use_custom = use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio, decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len, self.sliding_window) decode_meta.max_decode_seq_len, self.sliding_window,
self.kv_cache_dtype, self.alibi_slopes)
use_custom = False use_custom = False
if use_custom: if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
......
...@@ -283,7 +283,8 @@ def chunked_prefill_paged_decode( ...@@ -283,7 +283,8 @@ def chunked_prefill_paged_decode(
use_custom = use_rocm_custom_paged_attention(query.dtype, head_size, use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
block_size, block_size,
num_queries_per_kv, num_queries_per_kv,
max_seq_len, sliding_window) max_seq_len, sliding_window,
kv_cache_dtype, alibi_slopes)
if use_custom: if use_custom:
_PARTITION_SIZE_ROCM = 256 _PARTITION_SIZE_ROCM = 256
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
......
...@@ -31,8 +31,8 @@ def apply_softcap(S, x): ...@@ -31,8 +31,8 @@ def apply_softcap(S, x):
def kernel_unified_attention_2d( def kernel_unified_attention_2d(
output_ptr, # [num_tokens, num_query_heads, head_size] output_ptr, # [num_tokens, num_query_heads, head_size]
query_ptr, # [num_tokens, num_query_heads, head_size] query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs] seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads] alibi_slopes_ptr, # [num_query_heads]
......
...@@ -13,7 +13,6 @@ generation. Supported dataset types include: ...@@ -13,7 +13,6 @@ generation. Supported dataset types include:
TODO: Implement CustomDataset to parse a JSON file and convert its contents into TODO: Implement CustomDataset to parse a JSON file and convert its contents into
SampleRequest instances, similar to the approach used in ShareGPT. SampleRequest instances, similar to the approach used in ShareGPT.
""" """
import base64 import base64
import io import io
import json import json
...@@ -33,6 +32,7 @@ from transformers import PreTrainedTokenizerBase ...@@ -33,6 +32,7 @@ from transformers import PreTrainedTokenizerBase
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.image import convert_image_mode
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -129,16 +129,17 @@ class BenchmarkDataset(ABC): ...@@ -129,16 +129,17 @@ class BenchmarkDataset(ABC):
Args: Args:
tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no
LoRA is selected. max_loras (Optional[int]): The maximum number of LoRA is selected.
LoRAs available. If None, LoRA is not used. lora_path max_loras (Optional[int]): The maximum number of LoRAs available.
(Optional[str]): Path to the LoRA parameters on disk. If None, LoRA If `None`, LoRA is not used.
is not used. lora_path (Optional[str]): Path to the LoRA parameters on disk.
If `None`, LoRA is not used.
Returns: Returns:
tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first A tuple with the following elements:
element is a LoRARequest (or None if not applicable) and the second - A new [LoRARequest][] (or `None` if not applicable).
element is the tokenizer associated with the LoRA request (or the - The tokenizer associated with the LoRA request
base tokenizer). (or the base tokenizer).
""" """
if max_loras is None or lora_path is None: if max_loras is None or lora_path is None:
return None, tokenizer return None, tokenizer
...@@ -167,7 +168,7 @@ class BenchmarkDataset(ABC): ...@@ -167,7 +168,7 @@ class BenchmarkDataset(ABC):
Args: Args:
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
for processing the dataset's text. for processing the dataset's text.
num_requests (int): The number of sample requests to generate. num_requests (int): The number of sample requests to generate.
Returns: Returns:
...@@ -184,7 +185,8 @@ class BenchmarkDataset(ABC): ...@@ -184,7 +185,8 @@ class BenchmarkDataset(ABC):
Args: Args:
requests (List[SampleRequest]): The current list of sampled requests (List[SampleRequest]): The current list of sampled
requests. num_requests (int): The target number of requests. requests.
num_requests (int): The target number of requests.
""" """
if len(requests) < num_requests: if len(requests) < num_requests:
random.seed(self.random_seed) random.seed(self.random_seed)
...@@ -259,7 +261,7 @@ def process_image(image: Any) -> Mapping[str, Any]: ...@@ -259,7 +261,7 @@ def process_image(image: Any) -> Mapping[str, Any]:
if isinstance(image, dict) and 'bytes' in image: if isinstance(image, dict) and 'bytes' in image:
image = Image.open(BytesIO(image['bytes'])) image = Image.open(BytesIO(image['bytes']))
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
image = image.convert("RGB") image = convert_image_mode(image, "RGB")
with io.BytesIO() as image_data: with io.BytesIO() as image_data:
image.save(image_data, format="JPEG") image.save(image_data, format="JPEG")
image_base64 = base64.b64encode( image_base64 = base64.b64encode(
......
...@@ -80,6 +80,9 @@ def add_cli_args(parser: argparse.ArgumentParser): ...@@ -80,6 +80,9 @@ def add_cli_args(parser: argparse.ArgumentParser):
) )
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
# V1 enables prefix caching by default which skews the latency
# numbers. We need to disable prefix caching by default.
parser.set_defaults(enable_prefix_caching=False)
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
......
...@@ -815,4 +815,4 @@ def main(): ...@@ -815,4 +815,4 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
main() main()
\ No newline at end of file
...@@ -6,9 +6,7 @@ import os ...@@ -6,9 +6,7 @@ import os
import pprint import pprint
import time import time
from collections.abc import Sequence from collections.abc import Sequence
from contextlib import ExitStack
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
from unittest.mock import patch
import torch import torch
import torch.fx as fx import torch.fx as fx
...@@ -16,13 +14,13 @@ import torch.fx as fx ...@@ -16,13 +14,13 @@ import torch.fx as fx
import vllm.envs as envs import vllm.envs as envs
from vllm.config import CompilationConfig, VllmConfig from vllm.config import CompilationConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname
from .compiler_interface import (CompilerInterface, EagerAdaptor, from .compiler_interface import (CompilerInterface, EagerAdaptor,
InductorAdaptor, InductorStandaloneAdaptor) InductorAdaptor, InductorStandaloneAdaptor)
from .counter import compilation_counter from .counter import compilation_counter
from .inductor_pass import InductorPass from .inductor_pass import InductorPass
from .monitor import end_monitoring_torch_compile
from .pass_manager import PostGradPassManager from .pass_manager import PostGradPassManager
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -297,7 +295,9 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): ...@@ -297,7 +295,9 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
num_graphs=len(self.compile_submod_names), num_graphs=len(self.compile_submod_names),
runtime_shape=None) runtime_shape=None)
self.module.__dict__[target] = PiecewiseBackend( piecewise_backend = resolve_obj_by_qualname(
current_platform.get_piecewise_backend_cls())
self.module.__dict__[target] = piecewise_backend(
submod, self.vllm_config, self.graph_pool, index, submod, self.vllm_config, self.graph_pool, index,
len(self.compile_submod_names), sym_shape_indices, len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_general_shape, self.vllm_backend) compiled_graph_for_general_shape, self.vllm_backend)
...@@ -341,7 +341,7 @@ class VllmBackend: ...@@ -341,7 +341,7 @@ class VllmBackend:
): ):
global global_graph_pool global global_graph_pool
if global_graph_pool is None: if global_graph_pool is None:
global_graph_pool = torch.cuda.graph_pool_handle() global_graph_pool = current_platform.graph_pool_handle()
# TODO: in the future, if we want to use multiple # TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool. # streams, it might not be safe to share a global pool.
...@@ -558,197 +558,3 @@ class VllmBackend: ...@@ -558,197 +558,3 @@ class VllmBackend:
return self.split_gm(*list_args) return self.split_gm(*list_args)
return copy_and_call return copy_and_call
@dataclasses.dataclass
class ConcreteSizeEntry:
    """Per-shape bookkeeping for compilation and cudagraph capture."""

    # The concrete runtime shape this entry belongs to.
    runtime_shape: int
    # True when the size is in compile_sizes.
    need_to_compile: bool
    # True when the size is in cudagraph_capture_sizes.
    use_cudagraph: bool

    compiled: bool = False
    runnable: Callable = None  # type: ignore
    num_finished_warmup: int = 0
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    output: Optional[Any] = None

    # For cudagraph debugging: input addresses are tracked during capture
    # and checked to be the same during replay.
    input_addresses: Optional[list[int]] = None
class PiecewiseBackend:
    """Callable wrapper around one piecewise-compiled subgraph.

    Each call is dispatched on the runtime shape to the graph compiled for
    the general shape, a shape-specific compiled graph, or a captured
    cudagraph replay.
    """

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: list[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.
        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.
        Independently, we will capture cudagraph for different shapes.
        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend
        # position of this piece among all piecewise graphs
        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)
        self.compile_sizes: set[int] = set(
            self.compilation_config.compile_sizes)
        # no capture sizes at all when cudagraph is disabled in the config
        self.cudagraph_capture_sizes: set[int] = set(
            self.compilation_config.cudagraph_capture_sizes
        ) if self.compilation_config.use_cudagraph else set()
        self.first_run_finished = False
        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa
        self.sym_shape_indices = sym_shape_indices
        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_cudagraph=shape in self.cudagraph_capture_sizes,
            )

    def check_for_ending_compilation(self):
        """Finish compilation bookkeeping once nothing is left to compile.

        Only acts for the last graph and only when `to_be_compiled_sizes`
        is empty.
        """
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.compiler_manager.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        """Run the subgraph for `args`, compiling/capturing on demand.

        The runtime shape is read from `args[self.sym_shape_indices[0]]`.
        """
        # The very first call always uses the general-shape graph.
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)
        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)
        entry = self.concrete_size_entries[runtime_shape]
        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape
        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = self.vllm_backend.compiler_manager.compile(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape)
            # finished compilations for all required shapes
            if self.is_last_graph and not self.to_be_compiled_sizes:
                self.check_for_ending_compilation()
        if not entry.use_cudagraph:
            return entry.runnable(*args)
        if entry.cudagraph is None:
            # Run eagerly until the configured number of warmups is reached.
            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        self.compilation_config.cudagraph_num_of_warmups,
                        runtime_shape)
                return entry.runnable(*args)
            if self.is_first_graph:
                # Since we capture cudagraph for many different shapes and
                # capturing is fast, we don't need to log it for every shape.
                # We only log it in the debug mode.
                logger.debug("Capturing a cudagraph for shape %s",
                             runtime_shape)
            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()
            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of cudagraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.cuda.empty_cache", lambda: None))
                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's cudagraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other cuda graph.
                        output = weak_ref_tensors(output)
            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph
            compilation_counter.num_cudagraph_caputured += 1
            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output
        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )
        # Replay the captured graph and return the output stored at capture.
        entry.cudagraph.replay()
        return entry.output
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Protocol
import torch.fx as fx
from vllm.compilation.backends import VllmBackend
from vllm.config import VllmConfig
class AbstractPiecewiseBackend(Protocol):
    """
    Structural interface for piecewise backends, so that platforms can
    provide their own piecewise static graph implementation.
    """

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: list[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend, **kwargs):
        """
        Set up a piecewise backend with its compilation and execution
        configuration.

        Implementations handle piecewise compilation, graph capturing and
        dispatching for specific input shapes.

        Args:
            graph (fx.GraphModule): The graph represented in fx.
            vllm_config (VllmConfig): Global configuration for vLLM.
            graph_pool (Any): Graph memory pool handle, e.g.,
                `torch.cuda.graph_pool_handle()`.
            piecewise_compile_index (int): Index of the current piecewise
                subgraph.
            total_piecewise_compiles (int): Total number of
                piecewise-compiled graphs.
            sym_shape_indices (list[int]): Indices of symbolic shape.
            compiled_graph_for_general_shape (Callable): Callable that
                executes the graph compiled for general shapes.
            vllm_backend (VllmBackend): Backend compiler that manages
                compilation and graph runtime for vLLM.

        Keyword Args:
            kwargs: Additional keyword arguments reserved for future
                extensions or custom platforms.
        """
        raise NotImplementedError

    def __call__(self, *args) -> Any:
        """Execute the compiled graph for the given input args.

        On the first invocation the general compiled graph is executed and
        compilation tracking starts. Later calls are dispatched, based on
        the input shape, to either a compiled graph or a static graph.

        Args:
            *args: Inputs forwarded to the graph. The symbolic shape is
                expected at position `sym_shape_indices[0]`.

        Returns:
            Any: Output of the executed graph: from the general compiled
            graph, from a version specialized for the given shape, or from
            a replayed static graph.
        """
        raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from vllm.config import VllmConfig
from vllm.distributed import get_tp_group
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class BasePattern:
    """Shared state for the fusion patterns in this module.

    Stores the dtype/device used for example inputs together with the
    tensor-parallel group and its world size.
    """

    def __init__(self, dtype: torch.dtype, device: str):
        self.dtype: torch.dtype = dtype
        self.device: str = device
        # Looked up once at construction time.
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()
class GEMMReduceScatterPattern(BasePattern):
    """Rewrites `mm` followed by `vllm.reduce_scatter` into the fused
    `symm_mem.fused_matmul_reduce_scatter` op."""

    def get_inputs(self):
        """Return example `[activation, weight]` tensors for tracing."""
        example_mul = torch.empty([16, 4],
                                  device=self.device,
                                  dtype=self.dtype)
        example_weight = torch.empty([4, 4],
                                     device=self.device,
                                     dtype=self.dtype)
        return [example_mul, example_weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on `pm_pass`."""

        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor):
            product = torch.ops.aten.mm.default(mul, mm_weight)
            return torch.ops.vllm.reduce_scatter.default(
                product,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)

        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor):
            # NOTE(review): the fused op is asked for an "avg" reduction;
            # confirm this matches what vllm.reduce_scatter computes in the
            # pattern above.
            return torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
class AllGatherGEMMPattern(BasePattern):
    """Rewrites `vllm.all_gather` followed by `mm` into the fused
    `symm_mem.fused_all_gather_matmul` op."""

    def get_inputs(self):
        """Return example `[x, weight]` tensors for tracing."""
        example_x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        example_weight = torch.empty([4, 4],
                                     device=self.device,
                                     dtype=self.dtype)
        return [example_x, example_weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on `pm_pass`."""

        # NOTE(review): both inner functions are annotated as returning a
        # 2-tuple of tensors, yet `pattern` returns the single mm result and
        # `replacement` returns only `mm_outputs` - confirm the annotations.
        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            gathered = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)
            return torch.ops.aten.mm.default(gathered, weight)

        def replacement(
                x: torch.Tensor,
                weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
            # The gathered input returned by the fused op is not needed.
            _gathered, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
class AsyncTPPass(VllmInductorPass):
    """Inductor pass that applies the GEMM/collective fusion patterns
    defined in this module."""

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        # Enable symmetric memory for the TP process group
        tp_group_name = get_tp_group().device_group.group_name
        enable_symm_mem_for_group(tp_group_name)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass")
        # Registration order: reduce-scatter fusion, then all-gather fusion.
        for pattern_cls in (GEMMReduceScatterPattern, AllGatherGEMMPattern):
            pattern_cls(self.model_dtype, self.device).register(self.patterns)

    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
        """Return True only for a concrete shape divisible by the TP size."""
        # only do replace for specific shapes
        tp_size = get_tensor_model_parallel_world_size()
        if shape is None:
            return False
        return shape % tp_size == 0

    def __call__(self, graph: fx.Graph):
        """Apply the registered patterns to `graph`, dumping it before and
        after."""
        self.begin()
        self.dump_graph(graph, "before_async_tp_pass")
        num_replaced = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", num_replaced)
        self.dump_graph(graph, "after_async_tp_pass")
        self.end_and_log()
...@@ -39,7 +39,8 @@ class CompilerInterface: ...@@ -39,7 +39,8 @@ class CompilerInterface:
Gather all the relevant information from the vLLM config, Gather all the relevant information from the vLLM config,
to compute a hash so that we can cache the compiled model. to compute a hash so that we can cache the compiled model.
See {meth}`VllmConfig.compute_hash` to check what information See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash]
to check what information
is already considered by default. This function should only is already considered by default. This function should only
consider the information that is specific to the compiler. consider the information that is specific to the compiler.
""" """
......
# SPDX-License-Identifier: Apache-2.0
import dataclasses
from contextlib import ExitStack
from typing import Any, Callable, Optional
from unittest.mock import patch
import torch
import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.backends import VllmBackend
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors
logger = init_logger(__name__)
@dataclasses.dataclass
class ConcreteSizeEntry:
    """Per-shape bookkeeping for one concrete runtime shape.

    Holds whether the shape needs a dedicated compilation and/or a
    cudagraph, plus the state accumulated while those are produced.
    """
    runtime_shape: int
    need_to_compile: bool  # the size is in compile_sizes
    use_cudagraph: bool  # the size is in cudagraph_capture_sizes

    # set once the shape-specific compilation has been triggered
    compiled: bool = False
    # callable used to run this shape; None until it is first needed
    runnable: Optional[Callable] = None
    # number of warmup runs already performed before cudagraph capture
    num_finished_warmup: int = 0
    # the captured graph, or None while nothing has been captured yet
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    # output recorded at capture time and handed back on replay
    output: Optional[Any] = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[list[int]] = None
class CUDAPiecewiseBackend:
    """Callable that runs one piece of a piecewise-compiled graph.

    Per runtime shape it picks between the general-shape compiled graph,
    a shape-specific compilation, and a captured cudagraph, creating the
    latter two on demand the first time the shape is seen.
    """

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: list[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.

        Args:
            graph: the graph module handed to the compiler manager for
                shape-specific compilations.
            vllm_config: source of `compilation_config`; also passed to
                `end_monitoring_torch_compile`.
            graph_pool: passed as `pool=` to `torch.cuda.graph` on capture.
            piecewise_compile_index: index of this piece; index 0 is the
                first graph, `total_piecewise_compiles - 1` the last.
            total_piecewise_compiles: total number of pieces.
            sym_shape_indices: positions in the call arguments; the first
                one is used to read the runtime shape.
            compiled_graph_for_general_shape: callable used for the first
                run and for any shape without a dedicated entry.
            vllm_backend: provides the `compiler_manager` used to compile
                and to save to file.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.compile_sizes: set[int] = set(
            self.compilation_config.compile_sizes)
        # no capture sizes at all when cudagraph use is disabled
        self.cudagraph_capture_sizes: set[int] = set(
            self.compilation_config.cudagraph_capture_sizes
        ) if self.compilation_config.use_cudagraph else set()

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        # input-address checks on replay are only done at DEBUG log level
        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_cudagraph=shape in self.cudagraph_capture_sizes,
            )

    def check_for_ending_compilation(self):
        """Finish compilation bookkeeping once nothing is left to compile.

        Only acts when this is the last piece and `to_be_compiled_sizes`
        is empty: saves the compiler manager state to file and ends the
        torch.compile monitoring.
        """
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.compiler_manager.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        """Run this piece for the given arguments.

        The first call always uses the general-shape graph. Afterwards
        the runtime shape is looked up in `concrete_size_entries`: unknown
        shapes use the general-shape graph; known shapes are compiled
        and/or captured as a cudagraph on demand, and replayed once a
        cudagraph exists.

        Returns:
            The output of the runnable, or on replay the output stored
            at capture time.
        """
        if not self.first_run_finished:
            # very first call: no shape-specific handling at all
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        # the runtime shape is the argument at the first symbolic index
        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        # default to the general-shape graph until something more
        # specific replaces it below
        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            # flag is set before compiling, so this branch runs at most
            # once per shape
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = self.vllm_backend.compiler_manager.compile(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape)

            # finished compilations for all required shapes
            if self.is_last_graph and not self.to_be_compiled_sizes:
                self.check_for_ending_compilation()

        if not entry.use_cudagraph:
            return entry.runnable(*args)

        if entry.cudagraph is None:
            # nothing captured yet: call the runnable directly until the
            # configured number of warmups has been done
            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        self.compilation_config.cudagraph_num_of_warmups,
                        runtime_shape)
                return entry.runnable(*args)

            if self.is_first_graph:
                # Since we capture cudagraph for many different shapes and
                # capturing is fast, we don't need to log it for every shape.
                # We only log it in the debug mode.
                logger.debug("Capturing a cudagraph for shape %s",
                             runtime_shape)

            # remember the tensor argument addresses seen at capture time
            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of cudagraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.cuda.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's cudagraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other cuda graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_caputured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        # replay the captured graph and hand back the output stored at
        # capture time
        entry.cudagraph.replay()
        return entry.output
...@@ -6,6 +6,7 @@ from vllm.config import VllmConfig ...@@ -6,6 +6,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from .activation_quant_fusion import ActivationQuantFusionPass from .activation_quant_fusion import ActivationQuantFusionPass
from .collective_fusion import AsyncTPPass
from .fix_functionalization import FixFunctionalizationPass from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass from .fusion import FusionPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
...@@ -54,6 +55,8 @@ class PostGradPassManager(CustomGraphPass): ...@@ -54,6 +55,8 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.enable_sequence_parallelism: if self.pass_config.enable_sequence_parallelism:
self.passes += [SequenceParallelismPass(config)] self.passes += [SequenceParallelismPass(config)]
if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)]
self.fix_functionalization = FixFunctionalizationPass(config) self.fix_functionalization = FixFunctionalizationPass(config)
......
...@@ -243,24 +243,25 @@ class SequenceParallelismPass(VllmInductorPass): ...@@ -243,24 +243,25 @@ class SequenceParallelismPass(VllmInductorPass):
pass_name="sequence_parallelism_pass") pass_name="sequence_parallelism_pass")
for epsilon in [1e-5, 1e-6]: for epsilon in [1e-5, 1e-6]:
EmbeddingAllReduceRMSNormPattern( EmbeddingAllReduceRMSNormPattern(
epsilon, self.dtype, self.device).register(self.patterns) epsilon, self.model_dtype, self.device).register(self.patterns)
MiddleAllReduceRMSNormPattern(epsilon, self.dtype, MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype,
self.device).register(self.patterns) self.device).register(self.patterns)
LastAllReduceRMSNormPattern(epsilon, self.dtype, LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
self.device).register(self.patterns) self.device).register(self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache # WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon. # and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear() torch._inductor.pattern_matcher._seen_patterns.clear()
def is_applicable_for_shape(self, shape: Optional[int]) -> bool: def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
# only do replace for specific shapes
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0 return shape is not None and shape % tp_size == 0
def __call__(self, graph: fx.Graph): def __call__(self, graph: fx.Graph):
self.begin()
self.dump_graph(graph, "before_sequence_parallelism_pass") self.dump_graph(graph, "before_sequence_parallelism_pass")
count = self.patterns.apply(graph) count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", count) logger.debug("Replaced %s patterns", count)
self.dump_graph(graph, "after_sequence_parallelism_pass") self.dump_graph(graph, "after_sequence_parallelism_pass")
self.end_and_log()
...@@ -26,7 +26,8 @@ class VllmInductorPass(InductorPass): ...@@ -26,7 +26,8 @@ class VllmInductorPass(InductorPass):
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config self.pass_config = config.compilation_config.pass_config
self.dtype = config.model_config.dtype if config.model_config else None self.model_dtype = config.model_config.dtype if config.model_config \
else None
self.device = config.device_config.device if config.device_config \ self.device = config.device_config.device if config.device_config \
else None else None
self.pass_name = self.__class__.__name__ self.pass_name = self.__class__.__name__
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment