Commit ec5e299c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.3' into v0.7.3-dev

parents 47bd229c ed6e9075
......@@ -3,14 +3,16 @@
import argparse
import dataclasses
import json
import os
import random
import time
from functools import cache
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
import uvloop
from benchmark_utils import convert_to_pytorch_benchmark_format
from PIL import Image
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
......@@ -361,6 +363,25 @@ def run_mii(
return end - start
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
results: Dict[str, Any]) -> None:
pt_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
"requests_per_second": [results["requests_per_second"]],
"tokens_per_second": [results["tokens_per_second"]],
},
extra_info={
k: results[k]
for k in ["elapsed_time", "num_requests", "total_num_tokens"]
})
if pt_records:
# Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
with open(pt_file, "w") as f:
json.dump(pt_records, f)
def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
......@@ -459,6 +480,7 @@ def main(args: argparse.Namespace):
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
save_to_pytorch_benchmark_format(args, results)
if __name__ == "__main__":
......
# SPDX-License-Identifier: Apache-2.0
import ast
import copy
import dataclasses
import os
import pprint
import time
from collections import defaultdict
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from unittest.mock import patch
......@@ -19,6 +17,7 @@ from vllm.config import CompilationConfig, VllmConfig
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors
from .compiler_interface import EagerAdaptor, InductorAdaptor
from .counter import compilation_counter
from .inductor_pass import InductorPass
from .monitor import end_monitoring_torch_compile
......@@ -27,306 +26,128 @@ from .pass_manager import PostGradPassManager
logger = init_logger(__name__)
@dataclasses.dataclass
class InductorArtifact:
hash_str: str = ""
file_path: str = ""
class CompilerManager:
"""
A manager to manage the compilation process, including
caching the compiled graph, loading the compiled graph,
and compiling the graph.
The cache is a dict mapping
`(runtime_shape, graph_index, backend_name)`
to `any_data` returned from the compiler.
class InductorHashCache:
When serializing the cache, we save it to a Python file
for readability. We don't use json here because json doesn't
support int as key.
"""
Disk format: a Python list of tuples, each tuple is
(runtime_shape, graph_index, hash_str, file_path)
We use list of tuple for readability.
In-memory format: a defaultdict of dict, where the key is
runtime_shape, and the value is a dict of graph_index to hash_str.
def __init__(self, use_inductor: bool):
self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
cls = InductorAdaptor if use_inductor else EagerAdaptor
self.compiler = cls()
The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`,
we don't use json here because json doesn't support int as key.
TODO: better off-the-shelf solution to serialize the data?
"""
def compute_hash(self, vllm_config: VllmConfig) -> str:
return self.compiler.compute_hash(vllm_config)
def __init__(self, cache_dir: str, disabled: bool = False):
self.cache: Dict[Optional[int],
Dict[int, InductorArtifact]] = defaultdict(dict)
self.disabled = disabled
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
self.disable_cache = disable_cache
self.cache_dir = cache_dir
self.cache_file_path = os.path.join(cache_dir,
"inductor_hash_cache.py")
if disabled:
return
# set flags so that Inductor and Triton store their cache
# in the cache_dir, then users only need to copy the cache_dir
# to another machine to reuse the cache.
inductor_cache = os.path.join(cache_dir, "inductor_cache")
os.makedirs(inductor_cache, exist_ok=True)
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
triton_cache = os.path.join(cache_dir, "triton_cache")
os.makedirs(triton_cache, exist_ok=True)
os.environ["TRITON_CACHE_DIR"] = triton_cache
if os.path.exists(self.cache_file_path):
self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
if not disable_cache and os.path.exists(self.cache_file_path):
# load the cache from the file
with open(self.cache_file_path) as f:
self.deserialize(f.read())
def deserialize(self, data: str):
# we use ast.literal_eval to parse the data
# because it is a safe way to parse Python literals.
# do not use eval(), it is unsafe.
list_data = ast.literal_eval(data)
for item in list_data:
runtime_shape = item[0]
graph_index = item[1]
hash_str = item[2]
# for compatibility of old version,
# where we don't have file_path.
# NOTE: after running the new code, the file_path
# will be updated.
file_path = "" if len(item) == 3 else item[3]
self.cache[runtime_shape][graph_index] = InductorArtifact(
hash_str=hash_str, file_path=file_path)
def serialize(self) -> str:
data = []
for runtime_shape, value in self.cache.items():
for graph_index, inductor_artifact in value.items():
data.append(
(runtime_shape, graph_index, inductor_artifact.hash_str,
inductor_artifact.file_path))
printer = pprint.PrettyPrinter(indent=4)
return printer.pformat(data)
# we use ast.literal_eval to parse the data
# because it is a safe way to parse Python literals.
# do not use eval(), it is unsafe.
self.cache = ast.literal_eval(f.read())
self.compiler.initialize_cache(cache_dir=cache_dir,
disable_cache=disable_cache)
def save_to_file(self):
if self.disabled:
if self.disable_cache:
return
with open(self.cache_file_path, "w") as f:
f.write(self.serialize())
def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
if self.disabled:
return False
runtime_shape, graph_index = key
return runtime_shape in self.cache and graph_index in self.cache[
runtime_shape]
def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact:
if self.disabled:
raise KeyError("cannot read from disabled cache")
runtime_shape, graph_index = key
return self.cache[runtime_shape][graph_index]
def __setitem__(self, key: Tuple[Optional[int], int],
value: InductorArtifact):
# setitem for disabled cache is fine, because we
# don't actually write to the disk
runtime_shape, graph_index = key
self.cache[runtime_shape][graph_index] = value
class AlwaysHitShapeEnv:
"""
Why do we need this class:
For normal `torch.compile` usage, every compilation will have
one Dynamo bytecode compilation and one Inductor compilation.
The Inductor compilation happens under the context of the
Dynamo bytecode compilation, and that context is used to
determine the dynamic shape information, etc.
For our use case, we only run Dynamo bytecode compilation once,
and run Inductor compilation multiple times with different shapes
plus a general shape. The compilation for specific shapes happens
outside of the context of the Dynamo bytecode compilation. At that
time, we don't have shape environment to provide to Inductor, and
it will fail the Inductor code cache lookup.
By providing a dummy shape environment that always hits, we can
make the Inductor code cache lookup always hit, and we can
compile the graph for different shapes as needed.
The following dummy methods are obtained by trial-and-error
until it works.
"""
def __init__(self) -> None:
self.guards: List[Any] = []
def evaluate_guards_expression(self, *args, **kwargs):
return True
def get_pruned_guards(self, *args, **kwargs):
return []
def produce_guards_expression(self, *args, **kwargs):
return ""
def wrap_inductor(graph: fx.GraphModule,
example_inputs,
additional_inductor_config,
compilation_config: CompilationConfig,
vllm_backend: "VllmBackend",
graph_index: int = 0,
num_graphs: int = 1,
runtime_shape: Optional[int] = None,
use_inductor: bool = True) -> Any:
if graph_index == 0:
# before compiling the first graph, record the start time
global compilation_start_time
compilation_start_time = time.time()
if not use_inductor:
return graph
compilation_counter.num_inductor_compilations += 1
from torch._inductor import config
current_config = config.get_config_copy()
from torch._inductor.compile_fx import compile_fx
if additional_inductor_config is not None:
current_config.update(additional_inductor_config)
if isinstance(runtime_shape, int):
# for a specific batchsize, tuning triton kernel parameters
# can be beneficial
current_config["max_autotune"] = True
current_config["coordinate_descent_tuning"] = True
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
graph = copy.deepcopy(graph)
cache_data = vllm_backend.inductor_hash_cache
if (runtime_shape, graph_index) in cache_data:
# we compiled this graph before
# so we can directly lookup the compiled graph via hash
inductor_artifact = cache_data[(runtime_shape, graph_index)]
hash_str = inductor_artifact.hash_str
if graph_index == 0:
# adds some info logging for the first graph
logger.info(
"Directly lookup the graph for shape %s from the cache",
str(runtime_shape)) # noqa
printer = pprint.PrettyPrinter(indent=4)
data = printer.pformat(self.cache)
f.write(data)
def load(self,
graph: fx.GraphModule,
example_inputs: List[Any],
graph_index: int,
runtime_shape: Optional[int] = None) -> Optional[Callable]:
if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
return None
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
compiled_graph = self.compiler.load(handle, graph, example_inputs,
graph_index, runtime_shape)
logger.debug(
"directly lookup the %s-th graph for shape %s via hash %s",
graph_index, str(runtime_shape), hash_str)
from torch._inductor.codecache import FxGraphCache
with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv()):
inductor_compiled_graph = FxGraphCache._lookup_graph(
hash_str, example_inputs, True, False)
assert inductor_compiled_graph is not None, (
"Inductor cache lookup failed. Please remove"
f"the cache file {cache_data.cache_file_path} and try again." # noqa
)
inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
# Inductor calling convention (function signature):
# f(list) -> tuple
# Dynamo calling convention (function signature):
# f(*args) -> Any
# need to know if the graph returns a tuple
from torch._inductor.compile_fx import graph_returns_tuple
returns_tuple = graph_returns_tuple(graph)
# this is the callable we return to Dynamo to run
def compiled_graph(*args):
# convert args to list
list_args = list(args)
graph_output = inductor_compiled_graph(list_args)
# unpack the tuple if needed
if returns_tuple:
return graph_output
else:
return graph_output[0]
else:
# it's the first time we compile this graph
# the assumption is that we don't have nested Inductor compilation.
# compiled_fx_graph_hash will only be called once, and we can hook
# it to get the hash of the compiled graph directly.
inductor_artifact = InductorArtifact()
from torch._inductor.codecache import (FxGraphCache,
compiled_fx_graph_hash)
original_load = FxGraphCache.load
def hijack_load(*args, **kwargs):
inductor_compiled_graph = original_load(*args, **kwargs)
inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
return inductor_compiled_graph
def hijack_compiled_fx_graph_hash(*args, **kwargs):
out = compiled_fx_graph_hash(*args, **kwargs)
inductor_artifact.hash_str = out[0]
return out
def _check_can_cache(*args, **kwargs):
# no error means it can be cached.
# Inductor refuses to cache the graph outside of Dynamo
# tracing context, and also disables caching for graphs
# with high-order ops.
# For vLLM, in either case, we want to cache the graph.
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
return
def _get_shape_env() -> AlwaysHitShapeEnv:
return AlwaysHitShapeEnv()
with ExitStack() as stack:
if not cache_data.disabled:
# compilation cache is enabled, patch several functions
# hijack to get the compiled graph itself
stack.enter_context(
patch("torch._inductor.codecache.FxGraphCache.load",
hijack_load))
# for hijacking the hash of the compiled graph
stack.enter_context(
patch("torch._inductor.codecache.compiled_fx_graph_hash",
hijack_compiled_fx_graph_hash))
# for providing a dummy shape environment
stack.enter_context(
patch(
"torch._inductor.codecache.FxGraphCache._get_shape_env",
_get_shape_env))
# for forcing the graph to be cached
stack.enter_context(
patch(
"torch._inductor.codecache.FxGraphCache._check_can_cache",
_check_can_cache))
compiled_graph = compile_fx(graph,
example_inputs,
config_patches=current_config)
# store the inductor_artifact in the cache
cache_data[(runtime_shape, graph_index)] = inductor_artifact
"Directly load the %s-th graph for shape %s from %s via "
"handle %s", graph_index, str(runtime_shape), self.compiler.name,
handle)
return compiled_graph
def compile(self,
graph: fx.GraphModule,
example_inputs,
additional_inductor_config,
compilation_config: CompilationConfig,
graph_index: int = 0,
num_graphs: int = 1,
runtime_shape: Optional[int] = None) -> Any:
if graph_index == 0:
# adds some info logging for the first graph
logger.info("Cache the graph of shape %s for later use",
str(runtime_shape))
logger.debug(
"store the %s-th graph for shape %s via hash %s from file %s",
graph_index, str(runtime_shape), inductor_artifact.hash_str,
inductor_artifact.file_path)
# after compiling the last graph, record the end time
if graph_index == num_graphs - 1:
now = time.time()
elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed
if runtime_shape is None:
logger.info("Compiling a graph for general shape takes %.2f s",
elapsed)
else:
logger.info("Compiling a graph for shape %s takes %.2f s",
runtime_shape, elapsed)
# before compiling the first graph, record the start time
global compilation_start_time
compilation_start_time = time.time()
compilation_counter.num_backend_compilations += 1
compiled_graph = None
# try to load from the cache
compiled_graph = self.load(graph, example_inputs, graph_index,
runtime_shape)
if compiled_graph is not None:
if graph_index == 0:
# adds some info logging for the first graph
logger.info("Directly load the compiled graph for shape %s "
"from the cache", str(runtime_shape)) # noqa
return compiled_graph
# no compiler cached the graph, or the cache is disabled,
# we need to compile it
compiled_graph, handle = self.compiler.compile(
graph, example_inputs, additional_inductor_config, runtime_shape)
assert compiled_graph is not None, "Failed to compile the graph"
# store the artifact in the cache
if handle is not None:
self.cache[(runtime_shape, graph_index,
self.compiler.name)] = handle
if graph_index == 0:
# adds some info logging for the first graph
logger.info("Cache the graph of shape %s for later use",
str(runtime_shape))
logger.debug(
"store the %s-th graph for shape %s from %s via handle %s",
graph_index, str(runtime_shape), self.compiler.name, handle)
# after compiling the last graph, record the end time
if graph_index == num_graphs - 1:
now = time.time()
elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed
if runtime_shape is None:
logger.info("Compiling a graph for general shape takes %.2f s",
elapsed)
else:
logger.info("Compiling a graph for shape %s takes %.2f s",
runtime_shape, elapsed)
return compiled_graph
return compiled_graph
@dataclasses.dataclass
......@@ -436,16 +257,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
]
global compilation_start_time
compiled_graph_for_general_shape = wrap_inductor(
compiled_graph_for_general_shape = self.vllm_backend.\
compiler_manager.compile(
submod,
args,
self.compilation_config.inductor_compile_config,
self.compilation_config,
self.vllm_backend,
graph_index=index,
num_graphs=len(self.compile_submod_names),
runtime_shape=None,
use_inductor=self.compilation_config.use_inductor)
runtime_shape=None)
self.module.__dict__[target] = PiecewiseBackend(
submod, self.vllm_config, self.graph_pool, index,
......@@ -483,7 +303,7 @@ class VllmBackend:
post_grad_passes: Sequence[Callable]
sym_tensor_indices: List[int]
input_buffers: List[torch.Tensor]
inductor_hash_cache: InductorHashCache
compiler_manager: CompilerManager
def __init__(
self,
......@@ -507,6 +327,9 @@ class VllmBackend:
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.compiler_manager: CompilerManager = CompilerManager(
self.compilation_config.use_inductor)
# `torch.compile` is JIT compiled, so we don't need to
# do anything here
......@@ -533,9 +356,11 @@ class VllmBackend:
# the cache dir will be the same so that we can reuse the compiled
# graph.
factors = []
# 1. factors come from the vllm_config (it mainly summarizes how the
# model is created)
config_hash = vllm_config.compute_hash()
factors.append(config_hash)
# 2. factors come from the code files that are traced by Dynamo (
# it mainly summarizes how the model is used in forward pass)
......@@ -553,10 +378,15 @@ class VllmBackend:
import hashlib
code_hash = hashlib.md5(
"\n".join(hash_content).encode()).hexdigest()
factors.append(code_hash)
# 3. compiler hash
compiler_hash = self.compiler_manager.compute_hash(vllm_config)
factors.append(compiler_hash)
# combine all factors to generate the cache dir
hash_key = hashlib.md5(str(factors).encode()).hexdigest()[:10]
# combine the two hashes to generate the cache dir
hash_key = hashlib.md5(
f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT,
"torch_compile_cache",
......@@ -570,15 +400,16 @@ class VllmBackend:
cache_dir, f"rank_{vllm_config.parallel_config.rank}")
self.compilation_config.local_cache_dir = local_cache_dir
disabled = envs.VLLM_DISABLE_COMPILE_CACHE
self.inductor_hash_cache: InductorHashCache = InductorHashCache(
local_cache_dir, disabled=disabled)
if disabled:
disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE
if disable_cache:
logger.info("vLLM's torch.compile cache is disabled.")
else:
logger.info("Using cache directory: %s for vLLM's torch.compile",
local_cache_dir)
self.compiler_manager.initialize_cache(local_cache_dir, disable_cache)
# when dynamo calls the backend, it means the bytecode
# transform and analysis are done
compilation_counter.num_graphs_seen += 1
......@@ -759,7 +590,7 @@ class PiecewiseBackend:
if self.is_last_graph and not self.to_be_compiled_sizes:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self.vllm_backend.inductor_hash_cache.save_to_file()
self.vllm_backend.compiler_manager.save_to_file()
end_monitoring_torch_compile(self.vllm_config)
def __call__(self, *args) -> Any:
......@@ -782,16 +613,14 @@ class PiecewiseBackend:
entry.compiled = True
self.to_be_compiled_sizes.remove(runtime_shape)
# args are real arguments
entry.runnable = wrap_inductor(
entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args,
self.compilation_config.inductor_compile_config,
self.compilation_config,
self.vllm_backend,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
runtime_shape=runtime_shape,
use_inductor=self.compilation_config.use_inductor)
runtime_shape=runtime_shape)
# finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes:
......
# SPDX-License-Identifier: Apache-2.0
import copy
import hashlib
import os
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch
import torch
import torch._inductor.compile_fx
import torch.fx as fx
from vllm.config import VllmConfig
class CompilerInterface:
"""
The interface for a compiler that can be used by vLLM.
"""
# The name of the compiler, e.g. inductor.
# This is a class-level attribute.
name: str
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
"""
when the vLLM process uses `cache_dir` as the cache directory,
the compiler should initialize itself with the cache directory,
e.g. by re-directing its own cache directory to a sub-directory.
"""
pass
def compute_hash(self, vllm_config: VllmConfig) -> str:
"""
Gather all the relevant information from the VLLM config,
to compute a hash so that we can cache the compiled model.
See :meth:`VllmConfig.compute_hash` to check what information
is already considered by default. This function should only
consider the information that is specific to the compiler.
"""
return ""
def compile(
self,
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
runtime_shape: Optional[int] = None
) -> Tuple[Optional[Callable], Optional[Any]]:
"""
Compile the graph with the given example inputs and compiler config,
with a runtime shape. If the `runtime_shape` is None, it means
the `example_inputs` have a dynamic shape. Otherwise, the
`runtime_shape` specifies the shape of the inputs. Right now we only
support one variable shape for all inputs, which is the batchsize
(number of tokens) during inference.
Dynamo will make sure `graph(*example_inputs)` is valid.
The function should return a compiled callable function, as well as
a handle that can be used to directly load the compiled function.
The handle should be a plain Python object, preferably a string or a
file path for readability.
If the compiler doesn't support caching, it should return None for the
handle. If the compiler fails to compile the graph, it should return
None for the compiled function as well.
"""
return None, None
def load(self,
handle: Any,
graph: fx.GraphModule,
example_inputs: List[Any],
graph_index: int,
runtime_shape: Optional[int] = None) -> Callable:
"""
Load the compiled function from the handle.
Raises an error if the handle is invalid.
The handle is the second return value of the `compile` function.
"""
raise NotImplementedError("caching is not supported")
class AlwaysHitShapeEnv:
"""
Why do we need this class:
For normal `torch.compile` usage, every compilation will have
one Dynamo bytecode compilation and one Inductor compilation.
The Inductor compilation happens under the context of the
Dynamo bytecode compilation, and that context is used to
determine the dynamic shape information, etc.
For our use case, we only run Dynamo bytecode compilation once,
and run Inductor compilation multiple times with different shapes
plus a general shape. The compilation for specific shapes happens
outside of the context of the Dynamo bytecode compilation. At that
time, we don't have shape environment to provide to Inductor, and
it will fail the Inductor code cache lookup.
By providing a dummy shape environment that always hits, we can
make the Inductor code cache lookup always hit, and we can
compile the graph for different shapes as needed.
The following dummy methods are obtained by trial-and-error
until it works.
"""
def __init__(self) -> None:
self.guards: List[Any] = []
def evaluate_guards_expression(self, *args, **kwargs):
return True
def get_pruned_guards(self, *args, **kwargs):
return []
def produce_guards_expression(self, *args, **kwargs):
return ""
class InductorAdaptor(CompilerInterface):
"""
The adaptor for the Inductor compiler, version 2.5 and 2.6.
"""
name = "inductor"
def compute_hash(self, vllm_config: VllmConfig) -> str:
factors: List[Any] = []
# summarize system state
from torch._inductor.codecache import CacheBase
system_factors = CacheBase.get_system()
factors.append(system_factors)
# summarize pytorch state
from torch._inductor.codecache import torch_key
torch_factors = torch_key()
factors.append(torch_factors)
hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
return hash_str
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
if disable_cache:
return
# redirect the cache directory to a sub-directory
# set flags so that Inductor and Triton store their cache
# in the cache_dir, then users only need to copy the cache_dir
# to another machine to reuse the cache.
inductor_cache = os.path.join(cache_dir, "inductor_cache")
os.makedirs(inductor_cache, exist_ok=True)
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
triton_cache = os.path.join(cache_dir, "triton_cache")
os.makedirs(triton_cache, exist_ok=True)
os.environ["TRITON_CACHE_DIR"] = triton_cache
def compile(
self,
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
runtime_shape: Optional[int] = None
) -> Tuple[Optional[Callable], Optional[Any]]:
from torch._inductor import config
current_config = config.get_config_copy()
from torch._inductor.compile_fx import compile_fx
# disable remote cache
current_config["fx_graph_cache"] = True
current_config["fx_graph_remote_cache"] = False
if compiler_config is not None:
current_config.update(compiler_config)
if isinstance(runtime_shape, int):
# for a specific batchsize, tuning triton kernel parameters
# can be beneficial
current_config["max_autotune"] = True
current_config["coordinate_descent_tuning"] = True
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
graph = copy.deepcopy(graph)
# it's the first time we compile this graph
# the assumption is that we don't have nested Inductor compilation.
# compiled_fx_graph_hash will only be called once, and we can hook
# it to get the hash of the compiled graph directly.
hash_str, file_path = None, None
from torch._inductor.codecache import (FxGraphCache,
compiled_fx_graph_hash)
if torch.__version__.startswith("2.5"):
original_load = FxGraphCache.load
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
def hijack_load(*args, **kwargs):
inductor_compiled_graph = original_load(*args, **kwargs)
nonlocal file_path
file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
return inductor_compiled_graph
hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner # noqa
elif torch.__version__ >= "2.6":
# function renamed in 2.6
original_load_name = None
def hijacked_compile_fx_inner(*args, **kwargs):
output = torch._inductor.compile_fx.compile_fx_inner(
*args, **kwargs)
nonlocal hash_str
inductor_compiled_graph = output
if inductor_compiled_graph is not None:
nonlocal file_path
file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
hash_str = inductor_compiled_graph._fx_graph_cache_key
return output
def hijack_compiled_fx_graph_hash(*args, **kwargs):
out = compiled_fx_graph_hash(*args, **kwargs)
nonlocal hash_str
hash_str = out[0]
return out
def _check_can_cache(*args, **kwargs):
# no error means it can be cached.
# Inductor refuses to cache the graph outside of Dynamo
# tracing context, and also disables caching for graphs
# with high-order ops.
# For vLLM, in either case, we want to cache the graph.
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
return
def _get_shape_env() -> AlwaysHitShapeEnv:
return AlwaysHitShapeEnv()
with ExitStack() as stack:
# hijack to get the compiled graph itself
if original_load_name is not None:
stack.enter_context(patch(original_load_name, hijack_load))
# for hijacking the hash of the compiled graph
stack.enter_context(
patch("torch._inductor.codecache.compiled_fx_graph_hash",
hijack_compiled_fx_graph_hash))
# for providing a dummy shape environment
stack.enter_context(
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
_get_shape_env))
# for forcing the graph to be cached
stack.enter_context(
patch(
"torch._inductor.codecache.FxGraphCache._check_can_cache",
_check_can_cache))
compiled_graph = compile_fx(
graph,
example_inputs,
inner_compile=hijacked_compile_fx_inner,
config_patches=current_config)
assert hash_str is not None, (
"failed to get the hash of the compiled graph")
assert file_path is not None, (
"failed to get the file path of the compiled graph")
return compiled_graph, (hash_str, file_path)
def load(self,
handle: Any,
graph: fx.GraphModule,
example_inputs: List[Any],
graph_index: int,
runtime_shape: Optional[int] = None) -> Callable:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
assert isinstance(handle[1], str)
hash_str = handle[0]
from torch._inductor.codecache import FxGraphCache
with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv()):
if torch.__version__.startswith("2.5"):
inductor_compiled_graph = FxGraphCache._lookup_graph(
hash_str, example_inputs, True, False)
assert inductor_compiled_graph is not None, (
"Inductor cache lookup failed. Please remove"
f"the cache directory and try again." # noqa
)
elif torch.__version__ >= "2.6":
from torch._inductor.output_code import (
CompiledFxGraphConstantsWithGm)
constants = CompiledFxGraphConstantsWithGm(graph)
inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
hash_str, example_inputs, True, None, constants)
assert inductor_compiled_graph is not None, (
"Inductor cache lookup failed. Please remove"
f"the cache directory and try again." # noqa
)
# Inductor calling convention (function signature):
# f(list) -> tuple
# Dynamo calling convention (function signature):
# f(*args) -> Any
# need to know if the graph returns a tuple
from torch._inductor.compile_fx import graph_returns_tuple
returns_tuple = graph_returns_tuple(graph)
# this is the callable we return to Dynamo to run
def compiled_graph(*args):
# convert args to list
list_args = list(args)
graph_output = inductor_compiled_graph(list_args)
# unpack the tuple if needed
if returns_tuple:
return graph_output
else:
return graph_output[0]
return compiled_graph
class EagerAdaptor(CompilerInterface):
name = "eager"
def compile(
self,
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
runtime_shape: Optional[int] = None
) -> Tuple[Optional[Callable], Optional[Any]]:
# we don't need to compile the graph, just return the graph itself.
# It does not support caching, return None for the handle.
return graph, None
......@@ -13,7 +13,7 @@ class CompilationCounter:
num_piecewise_graphs_seen: int = 0
# not including the splitting ops
num_piecewise_capturable_graphs_seen: int = 0
num_inductor_compilations: int = 0
num_backend_compilations: int = 0
num_cudagraph_caputured: int = 0
def clone(self) -> "CompilationCounter":
......
......@@ -13,7 +13,6 @@ from torch import fx
class InductorPass(ABC):
"""
General custom inductor pass interface.
TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass
"""
@abstractmethod
......
......@@ -2,6 +2,7 @@
from typing import Any, Dict, List
import torch
from torch import fx as fx
from vllm.config import CompilationConfig
......@@ -15,7 +16,17 @@ from .reshapes import RedundantReshapesPass
logger = init_logger(__name__)
class PostGradPassManager:
class PlaceHolder:
pass
if torch.__version__ < "2.6":
Parent = PlaceHolder # type: ignore
else:
Parent = torch._inductor.custom_graph_pass.CustomGraphPass # type: ignore
class PostGradPassManager(Parent):
"""
The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes.
......@@ -55,6 +66,9 @@ class PostGradPassManager:
assert isinstance(pass_, InductorPass)
self.passes.append(pass_)
def uuid(self):
return self.__getstate__()
def __getstate__(self) -> Dict[str, List[Any]]:
"""
Custom pickling for the pass manager, as some passes cannot be pickled.
......
......@@ -54,17 +54,18 @@ _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
"score", "reward"]
"score", "reward", "transcription"]
_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
"draft"]
"draft", "transcription"]
RunnerType = Literal["generate", "pooling", "draft"]
RunnerType = Literal["generate", "pooling", "draft", "transcription"]
_RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = {
"generate": ["generate"],
"pooling": ["embed", "classify", "score", "reward"],
"draft": ["draft"],
"transcription": ["transcription"],
}
_TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = {
......@@ -102,8 +103,9 @@ class ModelConfig:
it; otherwise, you must specify explicitly which task to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, "slow" will always use the slow tokenizer, and
"mistral" will always use the tokenizer from `mistral_common`.
available, "slow" will always use the slow tokenizer,
"mistral" will always use the tokenizer from `mistral_common`, and
"custom" will use --tokenizer to select the preregistered tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
allowed_local_media_path: Allowing API requests to read local images or
......@@ -407,7 +409,8 @@ class ModelConfig:
if is_s3(model) or is_s3(tokenizer):
if is_s3(model):
s3_model = S3Model()
s3_model.pull_files(model, allow_pattern=["*config.json"])
s3_model.pull_files(
model, allow_pattern=["*.model", "*.py", "*.json"])
self.model_weights = self.model
self.model = s3_model.dir
......@@ -467,10 +470,10 @@ class ModelConfig:
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow", "mistral"]:
if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
raise ValueError(
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
"either 'auto', 'slow' or 'mistral'.")
"either 'auto', 'slow', 'mistral' or 'custom'.")
self.tokenizer_mode = tokenizer_mode
def _get_preferred_task(
......@@ -483,6 +486,8 @@ class ModelConfig:
return "embed"
if ModelRegistry.is_cross_encoder_model(architectures):
return "score"
if ModelRegistry.is_transcription_model(architectures):
return "transcription"
suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [
# Other models follow this pattern
......@@ -515,6 +520,8 @@ class ModelConfig:
runner_support: Dict[RunnerType, bool] = {
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"transcription":
ModelRegistry.is_transcription_model(architectures),
"generate": ModelRegistry.is_text_generation_model(architectures),
"pooling": ModelRegistry.is_pooling_model(architectures),
}
......@@ -756,7 +763,7 @@ class ModelConfig:
def is_deepseek_mla(self) -> bool:
return (hasattr(self.hf_text_config, "model_type")) \
and (self.hf_text_config.model_type in \
('deepseek_v2', 'deepseek_v3'))\
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\
and (self.hf_text_config.kv_lora_rank is not None)
def get_head_size(self) -> int:
......@@ -849,8 +856,12 @@ class ModelConfig:
def get_layers_start_end_indices(
self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
from vllm.distributed.utils import get_pp_indices
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
if self.hf_text_config.model_type == "deepseek_mtp":
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
else:
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
pp_size = parallel_config.pipeline_parallel_size
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
......@@ -985,37 +996,7 @@ class ModelConfig:
@property
def use_mla(self) -> bool:
if not self.is_deepseek_mla or envs.VLLM_MLA_DISABLE:
return False
if self.quantization is not None and self.quantization not in [\
"fp8", "compressed-tensors"]:
logger.warning(
"MLA is not supported with %s quantization. "
"Disabling MLA.", self.quantization)
return False
# If using a "compressed-tensors" checkpoint, check that all groups
# have fp8 for both weights and activations.
if self.quantization == "compressed-tensors":
quant_config = self._parse_quant_hf_config()
for group_name, cfg in quant_config.get("config_groups", {
"": {}
}).items():
act_cfg = cfg.get("input_activations", {})
act_type = None if act_cfg is None else act_cfg.get("type", "")
w_cfg = cfg.get("weights", {})
w_type = None if w_cfg is None else w_cfg.get("type", "")
if act_type != "fp8" or w_type != "fp8":
logger.warning(
"compressed-tensors MLA support requires fp8 "
"activations and weights in group '%s', but got "
"activations type '%s' and weights type '%s'.\n "
"Full config: %s", group_name, act_type, w_type,
quant_config)
return False
return True
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
@property
def supported_runner_types(self) -> Set[RunnerType]:
......@@ -1403,6 +1384,9 @@ class ParallelConfig:
logger.info("Defaulting to use %s for distributed inference",
backend)
if self.distributed_executor_backend is None and self.world_size == 1:
self.distributed_executor_backend = "uni"
self._verify_args()
@property
......@@ -1453,6 +1437,17 @@ class SchedulerConfig:
# Maximum length of a sequence (including prompt and generated text).
max_model_len: int = 8192
# Maximum number of sequences that can be partially prefilled concurrently
max_num_partial_prefills: int = 1
# Maximum number of "very long prompt" sequences that can be prefilled
# concurrently (long is defined by long_prefill_threshold)
max_long_partial_prefills: int = 1
# calculate context length that determines which sequences are
# considered "long"
long_prefill_token_threshold: int = 0
# The number of slots to allocate per sequence per
# step, beyond the known token ids. This is used in speculative
# decoding to store KV activations of tokens which may or may not be
......@@ -1502,6 +1497,10 @@ class SchedulerConfig:
chunked_prefill_enabled: bool = field(init=False)
# scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
# or "mod.custom_class".
scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
......@@ -1560,6 +1559,18 @@ class SchedulerConfig:
self.max_num_batched_tokens)
self.chunked_prefill_enabled = self.enable_chunked_prefill
if self.max_num_partial_prefills > 1:
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len *
0.04)
logger.info(
"Concurrent partial prefills enabled with "
"max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
"long_prefill_token_threshold=%d",
self.max_num_partial_prefills, self.max_long_partial_prefills,
self.long_prefill_token_threshold)
self._verify_args()
def _verify_args(self) -> None:
......@@ -1591,6 +1602,29 @@ class SchedulerConfig:
f"({self.num_scheduler_steps}) must be greater than or "
"equal to 1.")
if self.max_num_partial_prefills < 1:
raise ValueError(
f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
"must be greater than or equal to 1.")
elif self.max_num_partial_prefills > 1:
if not self.chunked_prefill_enabled:
raise ValueError("Chunked prefill must be enabled to set "
"max_num_partial_prefills > 1.")
if self.long_prefill_token_threshold > self.max_model_len:
raise ValueError(
"long_prefill_token_threshold "
f"({self.long_prefill_token_threshold}) cannot be greater "
f"than the max_model_len ({self.max_model_len}).")
if (self.max_long_partial_prefills
< 1) or (self.max_long_partial_prefills
> self.max_num_partial_prefills):
raise ValueError(
f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
"must be greater than or equal to 1 and less than or equal to "
f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
@property
def is_multi_step(self) -> bool:
return self.num_scheduler_steps > 1
......@@ -1665,6 +1699,18 @@ class SpeculativeConfig:
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
if hf_config.model_type == "deepseek_v3":
hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["DeepSeekMTPModel"]
})
return hf_config
@staticmethod
def maybe_create_spec_config(
target_model_config: ModelConfig,
......@@ -1750,12 +1796,18 @@ class SpeculativeConfig:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
the necessary conditions are met, else None.
"""
if speculative_model is None:
if num_speculative_tokens is not None:
raise ValueError("num_speculative_tokens was provided without "
"speculative_model.")
return None
if target_model_config.hf_text_config.model_type \
== "deepseek_v3":
# use the draft model from the same model:
speculative_model = target_model_config.model
else:
raise ValueError(
"num_speculative_tokens was provided without "
"speculative_model.")
else:
return None
if (speculative_disable_by_batch_size is not None
and speculative_disable_by_batch_size < 2):
......@@ -1809,10 +1861,20 @@ class SpeculativeConfig:
max_seq_len_to_capture=target_model_config.
max_seq_len_to_capture,
max_logprobs=target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override,
)
draft_hf_config = draft_model_config.hf_config
# Detect EAGLE prefix to replace hf_config for EAGLE draft_model
if "eagle-" in draft_model_config.model.lower():
from vllm.transformers_utils.configs.eagle import EAGLEConfig
if isinstance(draft_model_config.hf_config, EAGLEConfig):
pass
else:
eagle_config = EAGLEConfig(draft_model_config.hf_config)
draft_model_config.hf_config = eagle_config
if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens
......@@ -1934,8 +1996,9 @@ class SpeculativeConfig:
speculative_draft_tensor_parallel_size = 1
if target_parallel_config.tensor_parallel_size > 1:
logger.warning(
"MLPSpeculator cannot currently be run with tp>1; "
"setting speculative_draft_tensor_parallel_size=1")
"%s cannot currently be run with tp>1; "
"setting speculative_draft_tensor_parallel_size=1",
draft_hf_config.model_type)
else:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
......@@ -3070,7 +3133,8 @@ class VllmConfig:
kv_transfer_config: KVTransferConfig = field(default=None,
init=True) # type: ignore
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing and debugging.
# for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
additional_config: SupportsHash = field(default=None,
init=True) # type: ignore
instance_id: str = ""
......@@ -3088,15 +3152,6 @@ class VllmConfig:
the final hidden states.
"""
factors: List[Any] = []
# summarize system state
from torch._inductor.codecache import CacheBase
system_factors = CacheBase.get_system()
factors.append(system_factors)
# summarize pytorch state
from torch._inductor.codecache import torch_key
torch_factors = torch_key()
factors.append(torch_factors)
# summarize vllm config
vllm_factors: List[Any] = []
......
......@@ -17,7 +17,7 @@ from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceGroupMetadataDelta,
SequenceStatus)
SequenceStage, SequenceStatus)
from vllm.utils import Device, PyObjectCache
logger = init_logger(__name__)
......@@ -39,6 +39,7 @@ class PreemptionMode(enum.Enum):
recompute them when the sequences are resumed, treating the sequences as
new prompts.
"""
SWAP = enum.auto()
RECOMPUTE = enum.auto()
......@@ -54,6 +55,7 @@ class SchedulingBudget:
happen if we only have chunked prefill scheduling, we can remove this
feature from the API when chunked prefill is enabled by default.
"""
token_budget: int
max_num_seqs: int
_request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
......@@ -132,6 +134,7 @@ class ScheduledSequenceGroup:
@dataclass
class SchedulerOutputs:
"""The scheduling decision made from a scheduler."""
# Scheduled sequence groups.
scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup]
# Number of prefill groups scheduled.
......@@ -205,6 +208,7 @@ class SchedulerRunningOutputs:
Could contain prefill (prefill that's chunked) or decodes. If there's not
enough memory, it can be preempted (for recompute) or swapped out.
"""
# Selected sequences that are running and in a decoding phase.
decode_seq_groups: List[ScheduledSequenceGroup]
# Selected sequences that are running and in a prefill phase.
......@@ -246,6 +250,7 @@ class SchedulerSwappedInOutputs:
Could contain prefill (prefill that's chunked) or decodes.
"""
# Selected sequences that are going to be swapped in and is in a
# decoding phase.
decode_seq_groups: List[ScheduledSequenceGroup]
......@@ -280,6 +285,7 @@ class SchedulerPrefillOutputs:
Could contain a fresh prefill requests or preempted requests that need
to be recomputed from scratch.
"""
# Selected sequences for prefill.
seq_groups: List[ScheduledSequenceGroup]
# Ignored sequence groups.
......@@ -321,6 +327,100 @@ def scheduled_seq_group_builder():
# return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
@dataclass
class PartialPrefillMetadata:
"""Holds information about the partial prefills that are currently running
during a single iteration of the Scheduler.
When chunked prefill is enabled, we allow a certain number of seqs to be
partially prefilled during each iteration. Having multiple partial prefills
in flight allows us to minimize TTFT and avoid decode starvation in cases
where a single sequence group with a very large prompt blocks the queue for
too many iterations.
The number of long prefill requests is limited so that smaller
requests may jump the queue in front of them and get to the decode
phase faster.
"""
# A minimum bound on the total number of prefills to be scheduled during
# this iteration
schedulable_prefills: int
# The number of long prefill requests currently running
long_prefills: int
scheduler_config: SchedulerConfig
def can_schedule(self, seq_group: SequenceGroup) -> bool:
"""When concurrent partial prefills are enabled,
we limit the number of long requests and only accept
shorter requests from the queue while running them
concurrently"""
return not (seq_group.first_seq.get_num_new_tokens()
> self.scheduler_config.long_prefill_token_threshold
and self.long_prefills
>= self.scheduler_config.max_long_partial_prefills
and self.scheduler_config.max_num_partial_prefills > 1)
def maybe_increment_partial_prefills(self,
seq_group: SequenceGroup) -> None:
# When a new prefill is scheduled, we need to know if it is a
# long request
if (seq_group.first_seq.get_num_new_tokens()
> self.scheduler_config.long_prefill_token_threshold):
self.long_prefills += 1
@classmethod
def from_queues(
cls,
running: Deque[SequenceGroup],
waiting: Deque[SequenceGroup],
scheduler_config: SchedulerConfig,
) -> "PartialPrefillMetadata":
"""Create a PartialPrefillMetadata object from the current state of
the scheduler's queues.
This accounts for the currently running prefill requests, and peeks into
the waiting queue to see if there are more prefills to potentially be
scheduled during this iteration."""
prefills = 0
long_prefills = 0
waiting_long_prefills = 0
for sg in running:
if sg.first_seq.data.stage == SequenceStage.PREFILL:
prefills += 1
if (sg.first_seq.get_num_new_tokens()
> scheduler_config.long_prefill_token_threshold):
long_prefills += 1
for sg in waiting:
# Don't bother looping through the rest of the queue if we know
# there are already at
# least max_partial_prefills requests to fill
if prefills >= scheduler_config.max_num_partial_prefills:
break
# Don't count long requests from the waiting queue if we aren't
# going to schedule them anyway
if (sg.first_seq.get_num_new_tokens()
> scheduler_config.long_prefill_token_threshold):
if (long_prefills + waiting_long_prefills
>= scheduler_config.max_long_partial_prefills):
continue
waiting_long_prefills += 1
prefills += 1
# NB: long_prefills and waiting_long_prefills are tracked separately.
# We don't account for the waiting requests here because we need to use
# this metadata to track how many have actually been scheduled.
return PartialPrefillMetadata(
schedulable_prefills=min(
prefills, scheduler_config.max_num_partial_prefills),
long_prefills=long_prefills,
scheduler_config=scheduler_config,
)
class Scheduler:
def __init__(
......@@ -360,7 +460,8 @@ class Scheduler:
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)
enable_caching=self.cache_config.enable_prefix_caching,
)
# Sequence groups in the WAITING state.
# Contain new prefill or preempted requests.
......@@ -421,6 +522,18 @@ class Scheduler:
# for processing and deallocation by the free_finished_seq_groups()
self._async_stopped: List[SequenceGroup] = []
# List with the chunk sizes to hand out to each sequence depending
# on how many partial prefills are running. This is slightly faster than
# running an integer division every time a prefill is scheduled.
# This splits the budget evenly among all prefills.
self.partial_prefill_budget_lookup_list = [0] * (
self.scheduler_config.max_num_partial_prefills + 1)
self.partial_prefill_budget_lookup_list[0] = (
scheduler_config.max_num_batched_tokens)
for i in range(1, self.scheduler_config.max_num_partial_prefills + 1):
self.partial_prefill_budget_lookup_list[i] = (
scheduler_config.max_num_batched_tokens // i)
@property
def next_cache_id(self):
return (self.cache_id + 1) % self.num_cache_iters
......@@ -500,8 +613,8 @@ class Scheduler:
self.block_manager.free_cross(seq_group)
def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0
return (len(self.waiting) != 0 or len(self.running) != 0
or len(self.swapped) != 0)
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device)
......@@ -523,6 +636,7 @@ class Scheduler:
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running.
......@@ -537,12 +651,14 @@ class Scheduler:
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
partial_prefill_metadata: information about the partial prefills
that are currently running
Returns:
SchedulerRunningOutputs.
"""
ret: SchedulerRunningOutputs = \
self._scheduler_running_outputs_cache[self.cache_id].get_object()
ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[
self.cache_id].get_object()
ret.blocks_to_swap_out.clear()
ret.blocks_to_copy.clear()
ret.decode_seq_groups.clear()
......@@ -577,10 +693,14 @@ class Scheduler:
# 2. If a sequence is running with non-chunked prefill, then
# there it's a decoding sequence, and the cached tokens info is
# irrelevant.
num_uncached_new_tokens, _ = (
num_uncached_new_tokens, _ = \
self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.RUNNING, enable_chunking,
budget))
seq_group,
SequenceStatus.RUNNING,
enable_chunking,
budget,
partial_prefill_metadata,
)
num_running_tokens = num_uncached_new_tokens
if num_running_tokens == 0:
......@@ -593,8 +713,8 @@ class Scheduler:
# to process the final tokens. The check below avoids this extra
# decode run when the model max len is reached, in order to avoid
# a memory overflow.
if self.use_async_output_proc and seq_group.seqs[0].get_len(
) > self.scheduler_config.max_model_len:
if (self.use_async_output_proc and seq_group.seqs[0].get_len()
> self.scheduler_config.max_model_len):
self._async_stopped.append(seq_group)
continue
......@@ -653,8 +773,9 @@ class Scheduler:
self._append_slots(seq_group, blocks_to_copy, enable_chunking)
is_prefill = seq_group.is_prefill()
scheduled_seq_group: ScheduledSequenceGroup = \
self._scheduled_seq_group_cache[self.cache_id].get_object()
scheduled_seq_group: ScheduledSequenceGroup = (
self._scheduled_seq_group_cache[
self.cache_id].get_object())
scheduled_seq_group.seq_group = seq_group
if is_prefill:
scheduled_seq_group.token_chunk_size = num_running_tokens
......@@ -731,7 +852,8 @@ class Scheduler:
logger.warning(
"Failing the request %s because there's not enough kv "
"cache blocks to run the entire sequence.",
seq_group.request_id)
seq_group.request_id,
)
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_IGNORED
infeasible_seq_groups.append(seq_group)
......@@ -770,7 +892,6 @@ class Scheduler:
swapped_queue.popleft()
self._swap_in(seq_group, blocks_to_swap_in)
self._append_slots(seq_group, blocks_to_copy, enable_chunking)
is_prefill = seq_group.is_prefill()
if is_prefill:
prefill_seq_groups.append(
ScheduledSequenceGroup(
......@@ -801,16 +922,17 @@ class Scheduler:
)
def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
if self.scheduler_config.chunked_prefill_enabled and \
not self.scheduler_config.is_multi_step:
if (self.scheduler_config.chunked_prefill_enabled
and not self.scheduler_config.is_multi_step):
prompt_limit = self.scheduler_config.max_model_len
else:
prompt_limit = min(self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens)
prompt_limit = min(
self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens,
)
# Model is fine tuned with long context. Return the fine tuned max_len.
if (seq_group.lora_request
and seq_group.lora_request.long_lora_max_len):
if seq_group.lora_request and seq_group.lora_request.long_lora_max_len:
assert prompt_limit <= seq_group.lora_request.long_lora_max_len
return seq_group.lora_request.long_lora_max_len
else:
......@@ -818,7 +940,7 @@ class Scheduler:
def _get_priority(self,
seq_group: SequenceGroup) -> Tuple[Optional[int], float]:
""" Get the priority of the sequence group.
"""Get the priority of the sequence group.
Highest preference to user-defined priority, followed by arrival time.
Args:
seq_group: The sequence group input.
......@@ -851,14 +973,14 @@ class Scheduler:
if waiting_queue:
seq_group = waiting_queue.popleft()
num_new_seqs = seq_group.get_max_num_running_seqs()
num_new_tokens_uncached, _ = (
num_new_tokens_uncached, _ = \
self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.WAITING, False, budget))
seq_group, SequenceStatus.WAITING, False, budget)
#Only preempt if priority inversion exists
# Only preempt if priority inversion exists
while running_queue and self._get_priority(
running_queue[-1]) > self._get_priority(seq_group):
#Only preempt if waiting sequence cannot be allocated
# Only preempt if waiting sequence cannot be allocated
can_allocate = self.block_manager.can_allocate(seq_group)
if (num_new_tokens_uncached > 0
and can_allocate == AllocStatus.OK
......@@ -868,7 +990,7 @@ class Scheduler:
)):
break
#Adjust budget to remove the victim sequence group
# Adjust budget to remove the victim sequence group
vseq_group = running_queue.pop()
num_running_tokens_uncached, _ = (
self._get_num_new_uncached_and_cached_tokens(
......@@ -879,11 +1001,11 @@ class Scheduler:
budget.subtract_num_seqs(vseq_group.request_id,
num_running_seqs)
#Preempt out the victim sequence group
# Preempt out the victim sequence group
self._preempt(vseq_group, blocks_to_swap_out)
waiting_queue.appendleft(vseq_group)
force_preemption_count += 1
#Put the sequence back into the waiting queue
# Put the sequence back into the waiting queue
waiting_queue.appendleft(seq_group)
waiting_queue = deque(sorted(waiting_queue, key=self._get_priority))
......@@ -897,6 +1019,7 @@ class Scheduler:
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
) -> SchedulerPrefillOutputs:
"""Schedule sequence groups that are in prefill stage.
......@@ -917,10 +1040,20 @@ class Scheduler:
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
partial_prefill_metadata: information about the partial prefills
that are currently running
Returns:
SchedulerPrefillOutputs.
"""
if budget.remaining_token_budget() == 0:
# Do nothing: Can't add any more prefill anyway
return SchedulerPrefillOutputs(
seq_groups=[],
ignored_seq_groups=[],
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking),
)
ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[ScheduledSequenceGroup] = []
......@@ -934,10 +1067,19 @@ class Scheduler:
assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
if (partial_prefill_metadata is not None
and not partial_prefill_metadata.can_schedule(seq_group)):
leftover_waiting_sequences.appendleft(seq_group)
waiting_queue.popleft()
continue
num_new_tokens_uncached, num_new_tokens_cached = (
self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.WAITING, enable_chunking,
budget))
seq_group,
SequenceStatus.WAITING,
enable_chunking,
budget,
partial_prefill_metadata=partial_prefill_metadata,
))
num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached
if not enable_chunking:
......@@ -948,7 +1090,10 @@ class Scheduler:
if num_new_tokens > prompt_limit:
logger.warning(
"Input prompt (%d tokens) is too long"
" and exceeds limit of %d", num_new_tokens, prompt_limit)
" and exceeds limit of %d",
num_new_tokens,
prompt_limit,
)
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
......@@ -969,7 +1114,9 @@ class Scheduler:
logger.warning(
"Input prompt (%d tokens) + lookahead slots (%d) is "
"too long and exceeds the capacity of block_manager",
num_new_tokens, num_lookahead_slots)
num_new_tokens,
num_lookahead_slots,
)
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
......@@ -1010,6 +1157,10 @@ class Scheduler:
waiting_queue.popleft()
self._allocate_and_set_running(seq_group)
if partial_prefill_metadata is not None:
partial_prefill_metadata.maybe_increment_partial_prefills(
seq_group)
if enable_chunking and self.scheduler_config.is_multi_step:
blocks_to_copy: List[Tuple[int, int]] = []
# init_multi_step_from_lookahead_slots happens in append_slots
......@@ -1025,7 +1176,8 @@ class Scheduler:
num_scheduler_steps=self.scheduler_config.
num_scheduler_steps,
is_multi_step=self.scheduler_config.is_multi_step,
enable_chunking=enable_chunking)
enable_chunking=enable_chunking,
)
seq_groups.append(
ScheduledSequenceGroup(seq_group=seq_group,
......@@ -1046,11 +1198,12 @@ class Scheduler:
seq_groups=seq_groups,
ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking))
is_prefill=True, enable_chunking=enable_chunking),
)
def _schedule_default(self) -> SchedulerOutputs:
"""Schedule queued requests.
The current policy is designed to optimize the throughput. First,
it batches as many prefill requests as possible. And it schedules
decodes. If there's a pressure on GPU memory, decode requests can
......@@ -1066,9 +1219,9 @@ class Scheduler:
for seq_group in self.running:
budget.add_num_seqs(seq_group.request_id,
seq_group.get_max_num_running_seqs())
curr_loras = set(
curr_loras = (set(
seq_group.lora_int_id for seq_group in self.running
if seq_group.lora_int_id > 0) if self.lora_enabled else None
if seq_group.lora_int_id > 0) if self.lora_enabled else None)
prefills = SchedulerPrefillOutputs.create_empty()
running_scheduled = SchedulerRunningOutputs.create_empty()
......@@ -1094,9 +1247,10 @@ class Scheduler:
# If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests.
if len(running_scheduled.preempted) + len(
running_scheduled.swapped_out) == 0:
swapped_in = self._schedule_swapped(budget, curr_loras)
if (len(running_scheduled.preempted) +
len(running_scheduled.swapped_out) == 0):
swapped_in = \
self._schedule_swapped(budget, curr_loras)
assert (budget.num_batched_tokens
<= self.scheduler_config.max_num_batched_tokens)
......@@ -1116,8 +1270,8 @@ class Scheduler:
# Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out)
preempted = (len(running_scheduled.preempted) +
len(running_scheduled.swapped_out))
preempted = len(running_scheduled.preempted) + len(
running_scheduled.swapped_out)
# There should be no prefill from running queue because this policy
# doesn't allow chunked prefills.
......@@ -1155,7 +1309,7 @@ class Scheduler:
def _schedule_chunked_prefill(self) -> SchedulerOutputs:
"""Schedule queued requests.
Chunked prefill allows to chunk prefill requests, batch them together
with decode requests. This policy 1. schedule as many decoding requests
as possible. 2. schedule chunked prefill requests that are not
......@@ -1176,10 +1330,20 @@ class Scheduler:
prefills = SchedulerPrefillOutputs.create_empty()
swapped_in = SchedulerSwappedInOutputs.create_empty()
# Create partial prefill metadata
partial_prefill_metadata = PartialPrefillMetadata.from_queues(
running=self.running,
waiting=self.waiting,
scheduler_config=self.scheduler_config,
)
# Decoding should be always scheduled first by fcfs.
running_scheduled = self._schedule_running(budget,
curr_loras,
enable_chunking=True)
running_scheduled = self._schedule_running(
budget,
curr_loras,
enable_chunking=True,
partial_prefill_metadata=partial_prefill_metadata,
)
# Schedule swapped out requests.
# If preemption happens, it means we don't have space for swap-in.
......@@ -1187,9 +1351,12 @@ class Scheduler:
running_scheduled.swapped_out) == 0:
swapped_in = self._schedule_swapped(budget, curr_loras)
prefills = self._schedule_prefills(budget,
curr_loras,
enable_chunking=True)
prefills = self._schedule_prefills(
budget,
curr_loras,
enable_chunking=True,
partial_prefill_metadata=partial_prefill_metadata,
)
assert (budget.num_batched_tokens
<= self.scheduler_config.max_num_batched_tokens)
......@@ -1208,8 +1375,15 @@ class Scheduler:
[s.seq_group for s in swapped_in.prefill_seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
# Because multiple prefills may be running concurrently, we need to
# make sure that prefills which are scheduled to finish are listed
# before those that won't. This is so that on the next scheduling
# iteration when they have transitioned to the decode stage, they are
# properly prioritized over sequences that are still in the prefill
# stage.
self.running.extend(
[s.seq_group for s in running_scheduled.prefill_seq_groups])
self._order_finishing_prefills_first(
running_scheduled.prefill_seq_groups))
self.running.extend([s.seq_group for s in prefills.seq_groups])
# Update swapped requests.
......@@ -1226,7 +1400,7 @@ class Scheduler:
# If all prompts, then we set num_lookahead_slots to 0
# this allows us to go through the `no_spec` path in
# `spec_decode_worker.py`
all_prefills = (len(scheduled_seq_groups) == num_prefill_groups)
all_prefills = len(scheduled_seq_groups) == num_prefill_groups
num_lookahead_slots = (0 if
(all_prefills
and not self.scheduler_config.is_multi_step)
......@@ -1248,6 +1422,21 @@ class Scheduler:
len(running_scheduled.swapped_out)),
)
def _order_finishing_prefills_first(
self, scheduled_prefill_seqs: List[ScheduledSequenceGroup]
) -> List[SequenceGroup]:
"""Returns a list of prefilling SequenceGroups where sequences that are
scheduled to finish prefilling are listed first"""
finishing = [
s.seq_group for s in scheduled_prefill_seqs
if s.seq_group.get_num_uncomputed_tokens() == s.token_chunk_size
]
not_finishing = [
s.seq_group for s in scheduled_prefill_seqs
if s.seq_group.get_num_uncomputed_tokens() != s.token_chunk_size
]
return finishing + not_finishing
def _schedule(self) -> SchedulerOutputs:
"""Schedule queued requests."""
if self.scheduler_config.chunked_prefill_enabled:
......@@ -1386,10 +1575,12 @@ class Scheduler:
# between engine and worker.
# the subsequent comms can still use delta, but
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None,
multi_modal_placeholders=seq_group.multi_modal_placeholders
if scheduler_outputs.num_prefill_groups > 0 else None,
multi_modal_data=(seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups
> 0 else None),
multi_modal_placeholders=(
seq_group.multi_modal_placeholders
if scheduler_outputs.num_prefill_groups > 0 else None),
mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request,
)
......@@ -1495,10 +1686,12 @@ class Scheduler:
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
seq.status = SequenceStatus.RUNNING
def _append_slots(self,
seq_group: SequenceGroup,
blocks_to_copy: List[Tuple[int, int]],
enable_chunking: bool = False) -> None:
def _append_slots(
self,
seq_group: SequenceGroup,
blocks_to_copy: List[Tuple[int, int]],
enable_chunking: bool = False,
) -> None:
"""Appends new slots to the sequences in the given sequence group.
Args:
......@@ -1519,7 +1712,8 @@ class Scheduler:
num_lookahead_slots,
num_scheduler_steps=self.scheduler_config.num_scheduler_steps,
is_multi_step=self.scheduler_config.is_multi_step,
enable_chunking=enable_chunking)
enable_chunking=enable_chunking,
)
seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING
if self.scheduler_config.is_multi_step and enable_chunking:
......@@ -1562,8 +1756,11 @@ class Scheduler:
"not enough KV cache space. This can affect the end-to-end "
"performance. Increase gpu_memory_utilization or "
"tensor_parallel_size to provide more KV cache memory. "
"total_num_cumulative_preemption=%d", seq_group.request_id,
preemption_mode, self.num_cumulative_preemption + 1)
"total_num_cumulative_preemption=%d",
seq_group.request_id,
preemption_mode,
self.num_cumulative_preemption + 1,
)
self.num_cumulative_preemption += 1
if preemption_mode == PreemptionMode.RECOMPUTE:
......@@ -1669,6 +1866,7 @@ class Scheduler:
status: SequenceStatus,
enable_chunking: bool,
budget: SchedulingBudget,
partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
) -> Tuple[int, int]:
"""
Returns the number of new uncached and cached tokens to schedule for a
......@@ -1692,6 +1890,8 @@ class Scheduler:
to schedule.
enable_chunking: Whether to chunk the number of tokens to compute.
budget: The budget to chunk the number of tokens to compute.
partial_prefill_metadata: information about the partial prefills
that are currently running
Returns:
......@@ -1769,6 +1969,8 @@ class Scheduler:
budget,
self._get_prompt_limit(seq_group),
num_uncached_new_tokens,
self.partial_prefill_budget_lookup_list,
partial_prefill_metadata,
)
return num_uncached_new_tokens, num_cached_new_tokens
......@@ -1780,6 +1982,8 @@ class Scheduler:
budget: SchedulingBudget,
prompt_limit: int,
num_new_tokens: int,
partial_prefill_budget_lookup_list: List[int],
partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
) -> int:
"""
Chunks the number of new tokens to schedule based on the budget when
......@@ -1812,29 +2016,31 @@ class Scheduler:
# the sequence.
return num_new_tokens
return (0 if num_new_tokens > remaining_token_budget else
num_new_tokens)
return 0 if num_new_tokens > \
remaining_token_budget else num_new_tokens
if cache_config.enable_prefix_caching:
# Adjust the remaining token budget to be divisible by the block
# size when prefix caching is enabled.
# Get the number of tokens to allocate to this prefill slot
prefill_slot_budget = (
remaining_token_budget if partial_prefill_metadata is None else
partial_prefill_budget_lookup_list[
partial_prefill_metadata.schedulable_prefills])
# When prefix caching is enabled, we always allocate
# the number of new tokens that is dividable by the block
# size to avoid partial block matching.
if cache_config.enable_prefix_caching:
# When prefix caching is enabled and we're partially prefilling
# a sequence, we always allocate a number of new tokens that is
# divisible by the block size to avoid partial block matching.
block_size = cache_config.block_size
remainder = budget.token_budget % block_size
if remainder != 0:
raise ValueError("When enabling chunked prefill and "
"prefix caching, max_num_batched_tokens "
"(chunk size) must be dividable by "
"block size, but got chunk_size "
f"({budget.token_budget}) % block_size "
f"({block_size}) = {remainder}")
# Round down to block size.
remaining_token_budget = (remaining_token_budget // block_size *
block_size)
num_new_tokens = min(num_new_tokens, remaining_token_budget)
# Don't exceed either the total budget or slot budget.
# Take min of those and get the next lowest multiple of the
# block size:
remaining_token_budget = (
min(remaining_token_budget, prefill_slot_budget) //
block_size) * block_size
# NB: In the case where num_new_tokens < budget, we are
# finishing prefill for this sequence, so we do not need to
# allocate a full block.
num_new_tokens = min(num_new_tokens, remaining_token_budget,
prefill_slot_budget)
return num_new_tokens
......@@ -9,7 +9,7 @@
# the only successful approach is to call cuda driver API in C.
import dataclasses
from contextlib import contextmanager
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
......@@ -97,7 +97,7 @@ def use_memory_pool_with_allocator(
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
with torch.cuda.memory.use_mem_pool(mem_pool):
yield mem_pool
yield mem_pool, new_alloc
class CuMemAllocator:
......@@ -142,6 +142,7 @@ class CuMemAllocator:
def __init__(self):
self.pointer_to_data: Dict[int, AllocationData] = {}
self.current_tag: str = CuMemAllocator.default_tag
self.allocator_and_pools: Dict[str, Any] = {}
def python_malloc_callback(self, allocation_handle: HandleType) -> None:
"""
......@@ -231,7 +232,13 @@ class CuMemAllocator:
old_tag = self.current_tag
self.current_tag = tag
with use_memory_pool_with_allocator(self.python_malloc_callback,
self.python_free_callback):
self.python_free_callback) as data:
# start to hit another PyTorch bug in PyTorch 2.6,
# possibly because of gc-related issue w.r.t. the allocator and
# the memory pool.
# to avoid the issue, we keep a reference of the data.
# see https://github.com/pytorch/pytorch/issues/146431 .
self.allocator_and_pools[tag] = data
yield
# PyTorch's bug, calling torch.cuda.empty_cache() will error
# when using pluggable allocator, see
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
class DeviceCommunicatorBase:
"""
Base class for device-specific communicator.
It can use the `cpu_group` to initialize the communicator.
If the device has PyTorch integration (PyTorch can recognize its
communication backend), the `device_group` will also be given.
"""
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
self.unique_name = unique_name
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
return input_
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self):
pass
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
from torch.distributed import ProcessGroup
from .base_device_communicator import DeviceCommunicatorBase
class CpuCommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
self.ipex_available = False
self.dist_module = torch.distributed
try:
import intel_extension_for_pytorch as ipex
self.ipex_available = True
self.dist_module = ipex.distributed
except ImportError:
"""
Intel IPEX not found. Falling back to PyTorch native
all_reduce for CPU (e.g. MacOS)
"""
pass
def all_reduce(self, input_):
return self.dist_module.all_reduce(input_, group=self.device_group)
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
from torch.distributed import ProcessGroup
from .base_device_communicator import DeviceCommunicatorBase
class CudaCommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
if "pp" in unique_name:
# pipeline parallel does not need custom allreduce
use_custom_allreduce = False
else:
from vllm.distributed.parallel_state import (
_ENABLE_CUSTOM_ALL_REDUCE)
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
use_pynccl = True
self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
)
self.ca_comm: Optional[CustomAllreduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
)
def all_reduce(self, input_):
# always try custom allreduce first,
# and then pynccl.
ca_comm = self.ca_comm
if ca_comm is not None and not ca_comm.disabled and \
ca_comm.should_custom_ar(input_):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.send(tensor, dst)
else:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.recv(tensor, src)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self):
if self.pynccl_comm is not None:
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
......@@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional
# this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch # noqa
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
......@@ -105,8 +106,13 @@ class CudaRTLibrary:
def __init__(self, so_file: Optional[str] = None):
if so_file is None:
so_file = find_loaded_library("libcudart")
if so_file is None:
so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var
assert so_file is not None, \
"libcudart is not loaded in the current process"
(
"libcudart is not loaded in the current process, "
"try setting VLLM_CUDART_SO_PATH"
)
if so_file not in CudaRTLibrary.path_to_library_cache:
lib = ctypes.CDLL(so_file)
CudaRTLibrary.path_to_library_cache[so_file] = lib
......
......@@ -2,45 +2,40 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.platforms import current_platform
from .base_device_communicator import DeviceCommunicatorBase
if current_platform.is_hpu():
import habana_frameworks.torch as htorch # noqa: F401
class HpuCommunicator:
def __init__(self, group: ProcessGroup):
if not current_platform.is_hpu():
self.disabled = True
return
self.disabled = False
self.group = group
self.world_size = dist.get_world_size(self.group)
class HpuCommunicator(DeviceCommunicatorBase):
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
# (which is required for tensor parallel HPUGraph inference)
htorch.core.mark_step()
dist.all_reduce(x, group=self.group)
return x
dist.all_reduce(input_, group=self.device_group)
return input_
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
if dim < 0:
# Convert negative dim to positive.
dim += x.dim()
input_size = x.size()
dim += input_.dim()
input_size = input_.size()
# Allocate output tensor.
output_tensor = torch.empty((world_size, ) + input_size,
dtype=x.dtype,
device=x.device)
dtype=input_.dtype,
device=input_.device)
# All-gather.
htorch.core.mark_step()
dist.all_gather_into_tensor(output_tensor, x, group=self.group)
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
......
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Optional
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.platforms import current_platform
from .base_device_communicator import DeviceCommunicatorBase
if current_platform.is_tpu():
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
......@@ -16,19 +18,20 @@ if current_platform.is_tpu():
from vllm.executor import ray_utils
class TpuCommunicator:
class TpuCommunicator(DeviceCommunicatorBase):
def __init__(self, group: ProcessGroup):
if not current_platform.is_tpu():
self.disabled = True
return
self.disabled = False
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
# NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
# must be used together. Therefore, the local rank and world size can
# be simply calculated as follows.
global_rank = dist.get_rank(group)
global_world_size = dist.get_world_size(group)
global_rank = self.global_rank
global_world_size = self.global_world_size
# Calculate how many TPU nodes are in the current deployment. This
# is the Ray placement group if it is deployed with Ray. Default
......@@ -55,9 +58,9 @@ class TpuCommunicator:
pjrt.initialize_multiprocess(local_rank, local_world_size)
xr._init_world_size_ordinal()
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, x)
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, input_)
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "TPUs only support dim=-1 for all-gather."
return xm.all_gather(x, dim=dim)
return xm.all_gather(input_, dim=dim)
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.platforms import current_platform
from .base_device_communicator import DeviceCommunicatorBase
class XpuCommunicator:
class XpuCommunicator(DeviceCommunicatorBase):
def __init__(self, group: ProcessGroup):
if not current_platform.is_xpu():
self.disabled = True
return
self.disabled = False
self.group = group
self.world_size = dist.get_world_size(self.group)
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
dist.all_reduce(x, group=self.group)
return x
def all_reduce(self, input_) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
return input_
def gather(self,
input_: torch.Tensor,
rank_in_group: int,
dst: int = 0,
dim: int = -1):
dim: int = -1) -> Optional[torch.Tensor]:
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# For xpu path, gather doesn't work properly together with ray
# cluster so we use all_gather instead for now.
input_size = input_.size()
......@@ -34,10 +39,10 @@ class XpuCommunicator:
dtype=input_.dtype,
device=input_.device)
# All-gather.
torch.distributed.all_gather_into_tensor(output_tensor,
input_,
group=self.group)
if rank_in_group == dst:
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
if self.rank_in_group == dst:
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
......
......@@ -14,8 +14,8 @@ The KV cache transfer contains three layer of abstractions:
Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer.
NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed
communication service already supports key-value-based lookup (like redis or
NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed
communication service already supports key-value-based lookup (like redis or
RDMA database).
NOTE: If you want to not only transfer KV caches, but adjust the model execution flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates.
......@@ -27,4 +27,3 @@ The example usage is in [this file](../../../examples/online_serving/disaggregat
Here is the diagram of how we run disaggretgated prefilling.
![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg)
......@@ -10,7 +10,6 @@
stop the prefill instance when the decode instance is slow.
"""
import threading
import time
from collections import deque
from typing import Deque, List, Optional, Union
......@@ -29,13 +28,13 @@ class SimpleBuffer(KVLookupBufferBase):
def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
buffer_size_thresh: float):
"""
signal_pipe: on CPU
NOTE: on-device recv will block all threads in the process, making the
KV cache producer unable to listen to new request while transmitting
KV cache. Luckily CPU recv only blocks the current thread so we use
signal_pipe: on CPU
NOTE: on-device recv will block all threads in the process, making the
KV cache producer unable to listen to new request while transmitting
KV cache. Luckily CPU recv only blocks the current thread so we use
CPU recv to listen to new request.
data_pipe: on device (e.g. GPU)
"""
......@@ -43,7 +42,7 @@ class SimpleBuffer(KVLookupBufferBase):
self.buffer_size = 0
self.buffer_size_threshold = buffer_size_thresh
self.buffer_lock = threading.Lock()
self.buffer_cv = threading.Condition()
self.signal_pipe = signal_pipe
self.data_pipe = data_pipe
self.request_handling_thread: Optional[threading.Thread] = None
......@@ -116,11 +115,19 @@ class SimpleBuffer(KVLookupBufferBase):
hidden = hidden.clone()
buffer_item = [input_tokens, roi, key, value, hidden]
data_size = sum([self._get_element_size(data) for data in buffer_item])
with self.buffer_cv:
if self.buffer_size + data_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger.debug("KV transfer buffer is full. Handling...")
while self.buffer_size + data_size > self.buffer_size_threshold:
self.buffer_cv.wait()
with self.buffer_lock:
for data in buffer_item:
self.buffer_size += self._get_element_size(data)
self.buffer_size += data_size
self.buffer.append(buffer_item)
self.buffer_cv.notify()
def _is_end_signal(self, signal):
return signal is None
......@@ -143,35 +150,31 @@ class SimpleBuffer(KVLookupBufferBase):
roi = (roi > 0.5)
tokens_roi_recver = [input_tokens, roi]
matched_length = 0
# perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so
# the fix is not urgent.
with self.buffer_lock:
def is_buffer_available(
tokens_roi_recver: List[torch.Tensor], ) -> bool:
# perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so
# the fix is not urgent.
for _ in range(len(self.buffer)):
temp_length = self._matches(self.buffer[0],
tokens_roi_recver)
if temp_length > 0:
matched_length = temp_length
break
if self._matches(self.buffer[0],
tokens_roi_recver) > 0:
return True
# rotate the element we just accessed to the end
self.buffer.rotate(-1)
if matched_length > 0:
# need to clone the tensor
# in case the tensor is freed before sending finishes
matched_item = self.buffer.popleft()
for tensor in matched_item:
self._send_tensor_and_dec_size(tensor)
else:
# no match, just send None
for _ in range(5):
self.data_pipe.send_tensor(None)
return False
with self.buffer_cv:
while not is_buffer_available(tokens_roi_recver):
logger.debug(
"KV transfer buffer is not available. Waiting...")
self.buffer_cv.wait()
# need to clone the tensor
# in case the tensor is freed before sending finishes
matched_item = self.buffer.popleft()
for tensor in matched_item:
self._send_tensor_and_dec_size(tensor)
self.buffer_cv.notify()
except RuntimeError as e:
if 'Connection closed by peer' not in str(e):
......@@ -208,20 +211,10 @@ class SimpleBuffer(KVLookupBufferBase):
return [input_tokens, roi, key, value, hidden]
def full_handler(self):
time.sleep(0.001)
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
if self.buffer_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger.debug("KV transfer buffer is full. Handling...")
while self.buffer_size > self.buffer_size_threshold:
self.full_handler()
self._add_to_buffer(input_tokens, roi, key, value, hidden)
# when calling the insert, the current process is a sender
......
......@@ -39,9 +39,12 @@ from torch.distributed import Backend, ProcessGroup
import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
import vllm.envs as envs
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import direct_register_custom_op, supports_custom_op
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
supports_custom_op)
if TYPE_CHECKING:
from vllm.config import VllmConfig
......@@ -130,9 +133,8 @@ class GroupCoordinator:
PyTorch ProcessGroup is bound to one specific communication backend,
e.g. NCCL, Gloo, MPI, etc.
GroupCoordinator takes charge of all the communication operations among
the processes in the group. It can route the communication to
a specific implementation (e.g. switch allreduce implementation
based on the tensor size and cuda graph mode).
the processes in the group. It manages both CPU and device
communication.
"""
# available attributes:
......@@ -150,11 +152,8 @@ class GroupCoordinator:
rank_in_group: int # rank inside the group
cpu_group: ProcessGroup # group for CPU communication
device_group: ProcessGroup # group for device communication
use_pynccl: bool # a hint of whether to use PyNccl
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
# communicators are only created for world size > 1
pynccl_comm: Optional[Any] # PyNccl communicator
ca_comm: Optional[Any] # Custom allreduce communicator
use_device_communicator: bool # whether to use device communicator
device_communicator: DeviceCommunicatorBase # device communicator
mq_broadcaster: Optional[Any] # shared memory broadcaster
def __init__(
......@@ -162,11 +161,7 @@ class GroupCoordinator:
group_ranks: List[List[int]],
local_rank: int,
torch_distributed_backend: Union[str, Backend],
use_pynccl: bool,
use_custom_allreduce: bool,
use_tpu_communicator: bool,
use_hpu_communicator: bool,
use_xpu_communicator: bool,
use_device_communicator: bool,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
):
......@@ -196,56 +191,26 @@ class GroupCoordinator:
assert self.device_group is not None
from vllm.platforms import current_platform
# TODO: fix it for other platforms
if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}")
else:
self.device = torch.device("cpu")
self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce
self.use_tpu_communicator = use_tpu_communicator
self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
)
self.use_device_communicator = use_device_communicator
self.ca_comm: Optional[CustomAllreduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
self.device_communicator: DeviceCommunicatorBase = None # type: ignore
if use_device_communicator and self.world_size > 1:
device_comm_cls = resolve_obj_by_qualname(
current_platform.get_device_communicator_cls())
self.device_communicator = device_comm_cls(
cpu_group=self.cpu_group,
device=self.device,
device_group=self.device_group,
unique_name=self.unique_name,
)
from vllm.distributed.device_communicators.tpu_communicator import (
TpuCommunicator)
self.tpu_communicator: Optional[TpuCommunicator] = None
if use_tpu_communicator and self.world_size > 1:
self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
from vllm.distributed.device_communicators.hpu_communicator import (
HpuCommunicator)
self.hpu_communicator: Optional[HpuCommunicator]
if use_hpu_communicator and self.world_size > 1:
self.hpu_communicator = HpuCommunicator(group=self.device_group)
from vllm.distributed.device_communicators.xpu_communicator import (
XpuCommunicator)
self.xpu_communicator: Optional[XpuCommunicator]
if use_xpu_communicator and self.world_size > 1:
self.xpu_communicator = XpuCommunicator(group=self.device_group)
from vllm.distributed.device_communicators.shm_broadcast import (
MessageQueue)
self.mq_broadcaster: Optional[MessageQueue] = None
......@@ -253,6 +218,9 @@ class GroupCoordinator:
self.mq_broadcaster = MessageQueue.create_from_process_group(
self.cpu_group, 1 << 22, 6)
from vllm.platforms import current_platform
self.use_custom_op_call = current_platform.is_cuda_alike()
@property
def first_rank(self):
"""Return the global rank of the first process in the group"""
......@@ -296,9 +264,16 @@ class GroupCoordinator:
else:
stream = graph_capture_context.stream
ca_comm = self.ca_comm
maybe_ca_context = nullcontext(
) if ca_comm is None else ca_comm.capture()
# only cuda uses this function,
# so we don't abstract it into the base class
maybe_ca_context = nullcontext()
from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator)
if self.device_communicator is not None:
assert isinstance(self.device_communicator, CudaCommunicator)
ca_comm = self.device_communicator.ca_comm
if ca_comm is not None:
maybe_ca_context = ca_comm.capture() # type: ignore
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
......@@ -328,54 +303,14 @@ class GroupCoordinator:
if self.world_size == 1:
return input_
if input_.is_cpu:
try:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
return input_
except ImportError:
"""
Intel IPEX not found. Falling back to PyTorch native
all_reduce for CPU
"""
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
if self.tpu_communicator is not None and \
not self.tpu_communicator.disabled:
# TPU handles Dynamo with its own logic.
return self.tpu_communicator.all_reduce(input_)
if self.hpu_communicator is not None and \
not self.hpu_communicator.disabled:
return self.hpu_communicator.all_reduce(input_)
if self.xpu_communicator is not None and \
not self.xpu_communicator.disabled:
return self.xpu_communicator.all_reduce(input_)
return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
if self.use_custom_op_call:
return torch.ops.vllm.all_reduce(input_,
group_name=self.unique_name)
else:
return self._all_reduce_out_place(input_)
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
# always try custom allreduce first,
# and then pynccl.
ca_comm = self.ca_comm
if ca_comm is not None and not ca_comm.disabled and \
ca_comm.should_custom_ar(input_):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
return self.device_communicator.all_reduce(input_)
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
......@@ -385,40 +320,7 @@ class GroupCoordinator:
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
# For TPUs, use TPU communicator.
tpu_comm = self.tpu_communicator
if tpu_comm is not None and not tpu_comm.disabled:
return tpu_comm.all_gather(input_, dim)
# For HPUs, use HPU communicator.
hpu_comm = self.hpu_communicator
if hpu_comm is not None and not hpu_comm.disabled:
return hpu_comm.all_gather(input_, dim)
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
torch.distributed.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
return self.device_communicator.all_gather(input_, dim)
def gather(self,
input_: torch.Tensor,
......@@ -433,30 +335,7 @@ class GroupCoordinator:
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
if self.xpu_communicator is not None and \
not self.xpu_communicator.disabled:
return self.xpu_communicator.gather(input_, self.rank_in_group,
dst, dim)
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
return self.device_communicator.gather(input_, dst, dim)
def broadcast(self, input_: torch.Tensor, src: int = 0):
"""Broadcast the input tensor.
......@@ -798,14 +677,7 @@ class GroupCoordinator:
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.send(tensor, dst)
else:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
self.device_communicator.send(tensor, dst)
def recv(self,
size: torch.Size,
......@@ -813,16 +685,7 @@ class GroupCoordinator:
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.recv(tensor, src)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
return self.device_communicator.recv(size, dtype, src)
def destroy(self):
if self.device_group is not None:
......@@ -831,10 +694,8 @@ class GroupCoordinator:
if self.cpu_group is not None:
torch.distributed.destroy_process_group(self.cpu_group)
self.cpu_group = None
if self.pynccl_comm is not None:
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
if self.device_communicator is not None:
self.device_communicator.destroy()
if self.mq_broadcaster is not None:
self.mq_broadcaster = None
......@@ -853,11 +714,7 @@ def init_world_group(ranks: List[int], local_rank: int,
group_ranks=[ranks],
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=False,
use_custom_allreduce=False,
use_tpu_communicator=False,
use_hpu_communicator=False,
use_xpu_communicator=False,
use_device_communicator=False,
group_name="world",
)
......@@ -866,23 +723,15 @@ def init_model_parallel_group(
group_ranks: List[List[int]],
local_rank: int,
backend: str,
use_custom_allreduce: Optional[bool] = None,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
from vllm.platforms import current_platform
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=current_platform.is_cuda_alike(),
use_custom_allreduce=current_platform.is_cuda_alike()
and use_custom_allreduce,
use_tpu_communicator=True,
use_hpu_communicator=True,
use_xpu_communicator=True,
use_device_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
)
......@@ -1024,13 +873,6 @@ def initialize_model_parallel(
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)
if (world_size
!= tensor_model_parallel_size * pipeline_model_parallel_size):
raise RuntimeError(
f"world_size ({world_size}) is not equal to "
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
# Build the tensor model-parallel groups.
num_tensor_model_parallel_groups: int = (world_size //
tensor_model_parallel_size)
......@@ -1060,11 +902,9 @@ def initialize_model_parallel(
for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group_ranks.append(ranks)
# pipeline parallel does not need custom allreduce
_PP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
use_custom_allreduce=False,
group_name="pp")
......
......@@ -20,6 +20,7 @@ from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.plugins import load_general_plugins
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, StoreBoolean
......@@ -119,6 +120,9 @@ class EngineArgs:
cpu_offload_gb: float = 0 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_partial_prefills: Optional[int] = 1
max_long_partial_prefills: Optional[int] = 1
long_prefill_token_threshold: Optional[int] = 0
max_num_seqs: Optional[int] = None
max_logprobs: int = 20 # Default value for OpenAI Chat Completions API
disable_log_stats: bool = False
......@@ -191,6 +195,7 @@ class EngineArgs:
collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False
scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
override_neuron_config: Optional[Dict[str, Any]] = None
override_pooler_config: Optional[PoolerConfig] = None
......@@ -206,6 +211,7 @@ class EngineArgs:
calculate_kv_scales: Optional[bool] = None
additional_config: Optional[Dict[str, Any]] = None
moe_ep_size: int = 1
def __post_init__(self):
......@@ -286,11 +292,13 @@ class EngineArgs:
'--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow', 'mistral'],
choices=['auto', 'slow', 'mistral', 'custom'],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')
'"mistral" will always use the `mistral_common` tokenizer. \n* '
'"custom" will use --tokenizer to select the '
'preregistered tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
......@@ -520,6 +528,31 @@ class EngineArgs:
default=EngineArgs.max_num_batched_tokens,
help='Maximum number of batched tokens per '
'iteration.')
parser.add_argument(
"--max-num-partial-prefills",
type=int,
default=EngineArgs.max_num_partial_prefills,
help="For chunked prefill, the max number of concurrent \
partial prefills."
"Defaults to 1",
)
parser.add_argument(
"--max-long-partial-prefills",
type=int,
default=EngineArgs.max_long_partial_prefills,
help="For chunked prefill, the maximum number of prompts longer "
"than --long-prefill-token-threshold that will be prefilled "
"concurrently. Setting this less than --max-num-partial-prefills "
"will allow shorter prompts to jump the queue in front of longer "
"prompts in some cases, improving latency. Defaults to 1.")
parser.add_argument(
"--long-prefill-token-threshold",
type=float,
default=EngineArgs.long_prefill_token_threshold,
help="For chunked prefill, a request is considered long if the "
"prompt is longer than this number of tokens. Defaults to 4%% of "
"the model's context length.",
)
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
......@@ -929,6 +962,13 @@ class EngineArgs:
'priority (lower value means earlier handling) and time of '
'arrival deciding any ties).')
parser.add_argument(
'--scheduler-cls',
default=EngineArgs.scheduler_cls,
help='The scheduler class to use. "vllm.core.scheduler.Scheduler" '
'is the default scheduler. Can be a class directly or the path to '
'a class of form "mod.custom_class".')
parser.add_argument(
'--override-neuron-config',
type=json.loads,
......@@ -1008,6 +1048,14 @@ class EngineArgs:
'be loaded from the model checkpoint if available. '
'Otherwise, the scales will default to 1.0.')
parser.add_argument(
"--additional-config",
type=json.loads,
default=None,
help="Additional config for specified platform in JSON format. "
"Different platforms may support different configs. Make sure the "
"configs are valid for the platform you are using. The input format"
" is like '{\"config_key\":\"config_value\"}'")
return parser
@classmethod
......@@ -1068,6 +1116,9 @@ class EngineArgs:
def create_engine_config(self,
usage_context: Optional[UsageContext] = None
) -> VllmConfig:
from vllm.platforms import current_platform
current_platform.pre_register_and_update()
if envs.VLLM_USE_V1:
self._override_v1_engine_args(usage_context)
......@@ -1254,7 +1305,13 @@ class EngineArgs:
multi_step_stream_outputs=self.multi_step_stream_outputs,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray),
policy=self.scheduling_policy)
policy=self.scheduling_policy,
scheduler_cls=self.scheduler_cls,
max_num_partial_prefills=self.max_num_partial_prefills,
max_long_partial_prefills=self.max_long_partial_prefills,
long_prefill_token_threshold=self.long_prefill_token_threshold,
)
lora_config = LoRAConfig(
bias_enabled=self.enable_lora_bias,
max_lora_rank=self.max_lora_rank,
......@@ -1315,6 +1372,7 @@ class EngineArgs:
prompt_adapter_config=prompt_adapter_config,
compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
additional_config=self.additional_config,
)
if envs.VLLM_USE_V1:
......@@ -1375,6 +1433,12 @@ class AsyncEngineArgs(EngineArgs):
parser.add_argument('--disable-log-requests',
action='store_true',
help='Disable logging requests.')
# Initialize plugin to update the parser, for example, The plugin may
# adding a new kind of quantization method to --quantization argument or
# a new device to --device argument.
load_general_plugins()
from vllm.platforms import current_platform
current_platform.pre_register_and_update(parser)
return parser
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment