Commit ec5e299c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.3' into v0.7.3-dev

parents 47bd229c ed6e9075
...@@ -3,14 +3,16 @@ ...@@ -3,14 +3,16 @@
import argparse import argparse
import dataclasses import dataclasses
import json import json
import os
import random import random
import time import time
from functools import cache from functools import cache
from typing import Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
import uvloop import uvloop
from benchmark_utils import convert_to_pytorch_benchmark_format
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer, from transformers import (AutoModelForCausalLM, AutoTokenizer,
...@@ -361,6 +363,25 @@ def run_mii( ...@@ -361,6 +363,25 @@ def run_mii(
return end - start return end - start
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
results: Dict[str, Any]) -> None:
pt_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
"requests_per_second": [results["requests_per_second"]],
"tokens_per_second": [results["tokens_per_second"]],
},
extra_info={
k: results[k]
for k in ["elapsed_time", "num_requests", "total_num_tokens"]
})
if pt_records:
# Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
with open(pt_file, "w") as f:
json.dump(pt_records, f)
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
print(args) print(args)
random.seed(args.seed) random.seed(args.seed)
...@@ -459,6 +480,7 @@ def main(args: argparse.Namespace): ...@@ -459,6 +480,7 @@ def main(args: argparse.Namespace):
} }
with open(args.output_json, "w") as f: with open(args.output_json, "w") as f:
json.dump(results, f, indent=4) json.dump(results, f, indent=4)
save_to_pytorch_benchmark_format(args, results)
if __name__ == "__main__": if __name__ == "__main__":
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import ast import ast
import copy
import dataclasses import dataclasses
import os import os
import pprint import pprint
import time import time
from collections import defaultdict
from contextlib import ExitStack from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from unittest.mock import patch from unittest.mock import patch
...@@ -19,6 +17,7 @@ from vllm.config import CompilationConfig, VllmConfig ...@@ -19,6 +17,7 @@ from vllm.config import CompilationConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors from vllm.utils import weak_ref_tensors
from .compiler_interface import EagerAdaptor, InductorAdaptor
from .counter import compilation_counter from .counter import compilation_counter
from .inductor_pass import InductorPass from .inductor_pass import InductorPass
from .monitor import end_monitoring_torch_compile from .monitor import end_monitoring_torch_compile
...@@ -27,306 +26,128 @@ from .pass_manager import PostGradPassManager ...@@ -27,306 +26,128 @@ from .pass_manager import PostGradPassManager
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclasses.dataclass class CompilerManager:
class InductorArtifact: """
hash_str: str = "" A manager to manage the compilation process, including
file_path: str = "" caching the compiled graph, loading the compiled graph,
and compiling the graph.
The cache is a dict mapping
`(runtime_shape, graph_index, backend_name)`
to `any_data` returned from the compiler.
class InductorHashCache: When serializing the cache, we save it to a Python file
for readability. We don't use json here because json doesn't
support int as key.
""" """
Disk format: a Python list of tuples, each tuple is
(runtime_shape, graph_index, hash_str, file_path)
We use list of tuple for readability.
In-memory format: a defaultdict of dict, where the key is def __init__(self, use_inductor: bool):
runtime_shape, and the value is a dict of graph_index to hash_str. self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
cls = InductorAdaptor if use_inductor else EagerAdaptor
self.compiler = cls()
The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`, def compute_hash(self, vllm_config: VllmConfig) -> str:
we don't use json here because json doesn't support int as key. return self.compiler.compute_hash(vllm_config)
TODO: better off-the-shelf solution to serialize the data?
"""
def __init__(self, cache_dir: str, disabled: bool = False): def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
self.cache: Dict[Optional[int], self.disable_cache = disable_cache
Dict[int, InductorArtifact]] = defaultdict(dict)
self.disabled = disabled
self.cache_dir = cache_dir self.cache_dir = cache_dir
self.cache_file_path = os.path.join(cache_dir, self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
"inductor_hash_cache.py")
if disabled: if not disable_cache and os.path.exists(self.cache_file_path):
return # load the cache from the file
# set flags so that Inductor and Triton store their cache
# in the cache_dir, then users only need to copy the cache_dir
# to another machine to reuse the cache.
inductor_cache = os.path.join(cache_dir, "inductor_cache")
os.makedirs(inductor_cache, exist_ok=True)
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
triton_cache = os.path.join(cache_dir, "triton_cache")
os.makedirs(triton_cache, exist_ok=True)
os.environ["TRITON_CACHE_DIR"] = triton_cache
if os.path.exists(self.cache_file_path):
with open(self.cache_file_path) as f: with open(self.cache_file_path) as f:
self.deserialize(f.read()) # we use ast.literal_eval to parse the data
# because it is a safe way to parse Python literals.
def deserialize(self, data: str): # do not use eval(), it is unsafe.
# we use ast.literal_eval to parse the data self.cache = ast.literal_eval(f.read())
# because it is a safe way to parse Python literals.
# do not use eval(), it is unsafe. self.compiler.initialize_cache(cache_dir=cache_dir,
list_data = ast.literal_eval(data) disable_cache=disable_cache)
for item in list_data:
runtime_shape = item[0]
graph_index = item[1]
hash_str = item[2]
# for compatibility of old version,
# where we don't have file_path.
# NOTE: after running the new code, the file_path
# will be updated.
file_path = "" if len(item) == 3 else item[3]
self.cache[runtime_shape][graph_index] = InductorArtifact(
hash_str=hash_str, file_path=file_path)
def serialize(self) -> str:
data = []
for runtime_shape, value in self.cache.items():
for graph_index, inductor_artifact in value.items():
data.append(
(runtime_shape, graph_index, inductor_artifact.hash_str,
inductor_artifact.file_path))
printer = pprint.PrettyPrinter(indent=4)
return printer.pformat(data)
def save_to_file(self): def save_to_file(self):
if self.disabled: if self.disable_cache:
return return
with open(self.cache_file_path, "w") as f: with open(self.cache_file_path, "w") as f:
f.write(self.serialize()) printer = pprint.PrettyPrinter(indent=4)
data = printer.pformat(self.cache)
def __contains__(self, key: Tuple[Optional[int], int]) -> bool: f.write(data)
if self.disabled:
return False def load(self,
runtime_shape, graph_index = key graph: fx.GraphModule,
return runtime_shape in self.cache and graph_index in self.cache[ example_inputs: List[Any],
runtime_shape] graph_index: int,
runtime_shape: Optional[int] = None) -> Optional[Callable]:
def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact: if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
if self.disabled: return None
raise KeyError("cannot read from disabled cache") handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
runtime_shape, graph_index = key compiled_graph = self.compiler.load(handle, graph, example_inputs,
return self.cache[runtime_shape][graph_index] graph_index, runtime_shape)
def __setitem__(self, key: Tuple[Optional[int], int],
value: InductorArtifact):
# setitem for disabled cache is fine, because we
# don't actually write to the disk
runtime_shape, graph_index = key
self.cache[runtime_shape][graph_index] = value
class AlwaysHitShapeEnv:
"""
Why do we need this class:
For normal `torch.compile` usage, every compilation will have
one Dynamo bytecode compilation and one Inductor compilation.
The Inductor compilation happens under the context of the
Dynamo bytecode compilation, and that context is used to
determine the dynamic shape information, etc.
For our use case, we only run Dynamo bytecode compilation once,
and run Inductor compilation multiple times with different shapes
plus a general shape. The compilation for specific shapes happens
outside of the context of the Dynamo bytecode compilation. At that
time, we don't have shape environment to provide to Inductor, and
it will fail the Inductor code cache lookup.
By providing a dummy shape environment that always hits, we can
make the Inductor code cache lookup always hit, and we can
compile the graph for different shapes as needed.
The following dummy methods are obtained by trial-and-error
until it works.
"""
def __init__(self) -> None:
self.guards: List[Any] = []
def evaluate_guards_expression(self, *args, **kwargs):
return True
def get_pruned_guards(self, *args, **kwargs):
return []
def produce_guards_expression(self, *args, **kwargs):
return ""
def wrap_inductor(graph: fx.GraphModule,
example_inputs,
additional_inductor_config,
compilation_config: CompilationConfig,
vllm_backend: "VllmBackend",
graph_index: int = 0,
num_graphs: int = 1,
runtime_shape: Optional[int] = None,
use_inductor: bool = True) -> Any:
if graph_index == 0:
# before compiling the first graph, record the start time
global compilation_start_time
compilation_start_time = time.time()
if not use_inductor:
return graph
compilation_counter.num_inductor_compilations += 1
from torch._inductor import config
current_config = config.get_config_copy()
from torch._inductor.compile_fx import compile_fx
if additional_inductor_config is not None:
current_config.update(additional_inductor_config)
if isinstance(runtime_shape, int):
# for a specific batchsize, tuning triton kernel parameters
# can be beneficial
current_config["max_autotune"] = True
current_config["coordinate_descent_tuning"] = True
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
graph = copy.deepcopy(graph)
cache_data = vllm_backend.inductor_hash_cache
if (runtime_shape, graph_index) in cache_data:
# we compiled this graph before
# so we can directly lookup the compiled graph via hash
inductor_artifact = cache_data[(runtime_shape, graph_index)]
hash_str = inductor_artifact.hash_str
if graph_index == 0:
# adds some info logging for the first graph
logger.info(
"Directly lookup the graph for shape %s from the cache",
str(runtime_shape)) # noqa
logger.debug( logger.debug(
"directly lookup the %s-th graph for shape %s via hash %s", "Directly load the %s-th graph for shape %s from %s via "
graph_index, str(runtime_shape), hash_str) "handle %s", graph_index, str(runtime_shape), self.compiler.name,
from torch._inductor.codecache import FxGraphCache handle)
with patch("torch._inductor.codecache.FxGraphCache._get_shape_env", return compiled_graph
lambda *args, **kwargs: AlwaysHitShapeEnv()):
inductor_compiled_graph = FxGraphCache._lookup_graph( def compile(self,
hash_str, example_inputs, True, False) graph: fx.GraphModule,
assert inductor_compiled_graph is not None, ( example_inputs,
"Inductor cache lookup failed. Please remove" additional_inductor_config,
f"the cache file {cache_data.cache_file_path} and try again." # noqa compilation_config: CompilationConfig,
) graph_index: int = 0,
inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa num_graphs: int = 1,
runtime_shape: Optional[int] = None) -> Any:
# Inductor calling convention (function signature):
# f(list) -> tuple
# Dynamo calling convention (function signature):
# f(*args) -> Any
# need to know if the graph returns a tuple
from torch._inductor.compile_fx import graph_returns_tuple
returns_tuple = graph_returns_tuple(graph)
# this is the callable we return to Dynamo to run
def compiled_graph(*args):
# convert args to list
list_args = list(args)
graph_output = inductor_compiled_graph(list_args)
# unpack the tuple if needed
if returns_tuple:
return graph_output
else:
return graph_output[0]
else:
# it's the first time we compile this graph
# the assumption is that we don't have nested Inductor compilation.
# compiled_fx_graph_hash will only be called once, and we can hook
# it to get the hash of the compiled graph directly.
inductor_artifact = InductorArtifact()
from torch._inductor.codecache import (FxGraphCache,
compiled_fx_graph_hash)
original_load = FxGraphCache.load
def hijack_load(*args, **kwargs):
inductor_compiled_graph = original_load(*args, **kwargs)
inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
return inductor_compiled_graph
def hijack_compiled_fx_graph_hash(*args, **kwargs):
out = compiled_fx_graph_hash(*args, **kwargs)
inductor_artifact.hash_str = out[0]
return out
def _check_can_cache(*args, **kwargs):
# no error means it can be cached.
# Inductor refuses to cache the graph outside of Dynamo
# tracing context, and also disables caching for graphs
# with high-order ops.
# For vLLM, in either case, we want to cache the graph.
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
return
def _get_shape_env() -> AlwaysHitShapeEnv:
return AlwaysHitShapeEnv()
with ExitStack() as stack:
if not cache_data.disabled:
# compilation cache is enabled, patch several functions
# hijack to get the compiled graph itself
stack.enter_context(
patch("torch._inductor.codecache.FxGraphCache.load",
hijack_load))
# for hijacking the hash of the compiled graph
stack.enter_context(
patch("torch._inductor.codecache.compiled_fx_graph_hash",
hijack_compiled_fx_graph_hash))
# for providing a dummy shape environment
stack.enter_context(
patch(
"torch._inductor.codecache.FxGraphCache._get_shape_env",
_get_shape_env))
# for forcing the graph to be cached
stack.enter_context(
patch(
"torch._inductor.codecache.FxGraphCache._check_can_cache",
_check_can_cache))
compiled_graph = compile_fx(graph,
example_inputs,
config_patches=current_config)
# store the inductor_artifact in the cache
cache_data[(runtime_shape, graph_index)] = inductor_artifact
if graph_index == 0: if graph_index == 0:
# adds some info logging for the first graph # before compiling the first graph, record the start time
logger.info("Cache the graph of shape %s for later use", global compilation_start_time
str(runtime_shape)) compilation_start_time = time.time()
logger.debug(
"store the %s-th graph for shape %s via hash %s from file %s", compilation_counter.num_backend_compilations += 1
graph_index, str(runtime_shape), inductor_artifact.hash_str,
inductor_artifact.file_path) compiled_graph = None
# after compiling the last graph, record the end time
if graph_index == num_graphs - 1: # try to load from the cache
now = time.time() compiled_graph = self.load(graph, example_inputs, graph_index,
elapsed = now - compilation_start_time runtime_shape)
compilation_config.compilation_time += elapsed if compiled_graph is not None:
if runtime_shape is None: if graph_index == 0:
logger.info("Compiling a graph for general shape takes %.2f s", # adds some info logging for the first graph
elapsed) logger.info("Directly load the compiled graph for shape %s "
else: "from the cache", str(runtime_shape)) # noqa
logger.info("Compiling a graph for shape %s takes %.2f s", return compiled_graph
runtime_shape, elapsed)
# no compiler cached the graph, or the cache is disabled,
# we need to compile it
compiled_graph, handle = self.compiler.compile(
graph, example_inputs, additional_inductor_config, runtime_shape)
assert compiled_graph is not None, "Failed to compile the graph"
# store the artifact in the cache
if handle is not None:
self.cache[(runtime_shape, graph_index,
self.compiler.name)] = handle
if graph_index == 0:
# adds some info logging for the first graph
logger.info("Cache the graph of shape %s for later use",
str(runtime_shape))
logger.debug(
"store the %s-th graph for shape %s from %s via handle %s",
graph_index, str(runtime_shape), self.compiler.name, handle)
# after compiling the last graph, record the end time
if graph_index == num_graphs - 1:
now = time.time()
elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed
if runtime_shape is None:
logger.info("Compiling a graph for general shape takes %.2f s",
elapsed)
else:
logger.info("Compiling a graph for shape %s takes %.2f s",
runtime_shape, elapsed)
return compiled_graph return compiled_graph
@dataclasses.dataclass @dataclasses.dataclass
...@@ -436,16 +257,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): ...@@ -436,16 +257,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
i for i, x in enumerate(args) if isinstance(x, torch.SymInt) i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
] ]
global compilation_start_time global compilation_start_time
compiled_graph_for_general_shape = wrap_inductor( compiled_graph_for_general_shape = self.vllm_backend.\
compiler_manager.compile(
submod, submod,
args, args,
self.compilation_config.inductor_compile_config, self.compilation_config.inductor_compile_config,
self.compilation_config, self.compilation_config,
self.vllm_backend,
graph_index=index, graph_index=index,
num_graphs=len(self.compile_submod_names), num_graphs=len(self.compile_submod_names),
runtime_shape=None, runtime_shape=None)
use_inductor=self.compilation_config.use_inductor)
self.module.__dict__[target] = PiecewiseBackend( self.module.__dict__[target] = PiecewiseBackend(
submod, self.vllm_config, self.graph_pool, index, submod, self.vllm_config, self.graph_pool, index,
...@@ -483,7 +303,7 @@ class VllmBackend: ...@@ -483,7 +303,7 @@ class VllmBackend:
post_grad_passes: Sequence[Callable] post_grad_passes: Sequence[Callable]
sym_tensor_indices: List[int] sym_tensor_indices: List[int]
input_buffers: List[torch.Tensor] input_buffers: List[torch.Tensor]
inductor_hash_cache: InductorHashCache compiler_manager: CompilerManager
def __init__( def __init__(
self, self,
...@@ -507,6 +327,9 @@ class VllmBackend: ...@@ -507,6 +327,9 @@ class VllmBackend:
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.compiler_manager: CompilerManager = CompilerManager(
self.compilation_config.use_inductor)
# `torch.compile` is JIT compiled, so we don't need to # `torch.compile` is JIT compiled, so we don't need to
# do anything here # do anything here
...@@ -533,9 +356,11 @@ class VllmBackend: ...@@ -533,9 +356,11 @@ class VllmBackend:
# the cache dir will be the same so that we can reuse the compiled # the cache dir will be the same so that we can reuse the compiled
# graph. # graph.
factors = []
# 1. factors come from the vllm_config (it mainly summarizes how the # 1. factors come from the vllm_config (it mainly summarizes how the
# model is created) # model is created)
config_hash = vllm_config.compute_hash() config_hash = vllm_config.compute_hash()
factors.append(config_hash)
# 2. factors come from the code files that are traced by Dynamo ( # 2. factors come from the code files that are traced by Dynamo (
# it mainly summarizes how the model is used in forward pass) # it mainly summarizes how the model is used in forward pass)
...@@ -553,10 +378,15 @@ class VllmBackend: ...@@ -553,10 +378,15 @@ class VllmBackend:
import hashlib import hashlib
code_hash = hashlib.md5( code_hash = hashlib.md5(
"\n".join(hash_content).encode()).hexdigest() "\n".join(hash_content).encode()).hexdigest()
factors.append(code_hash)
# 3. compiler hash
compiler_hash = self.compiler_manager.compute_hash(vllm_config)
factors.append(compiler_hash)
# combine all factors to generate the cache dir
hash_key = hashlib.md5(str(factors).encode()).hexdigest()[:10]
# combine the two hashes to generate the cache dir
hash_key = hashlib.md5(
f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
cache_dir = os.path.join( cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT, envs.VLLM_CACHE_ROOT,
"torch_compile_cache", "torch_compile_cache",
...@@ -570,15 +400,16 @@ class VllmBackend: ...@@ -570,15 +400,16 @@ class VllmBackend:
cache_dir, f"rank_{vllm_config.parallel_config.rank}") cache_dir, f"rank_{vllm_config.parallel_config.rank}")
self.compilation_config.local_cache_dir = local_cache_dir self.compilation_config.local_cache_dir = local_cache_dir
disabled = envs.VLLM_DISABLE_COMPILE_CACHE disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE
self.inductor_hash_cache: InductorHashCache = InductorHashCache(
local_cache_dir, disabled=disabled) if disable_cache:
if disabled:
logger.info("vLLM's torch.compile cache is disabled.") logger.info("vLLM's torch.compile cache is disabled.")
else: else:
logger.info("Using cache directory: %s for vLLM's torch.compile", logger.info("Using cache directory: %s for vLLM's torch.compile",
local_cache_dir) local_cache_dir)
self.compiler_manager.initialize_cache(local_cache_dir, disable_cache)
# when dynamo calls the backend, it means the bytecode # when dynamo calls the backend, it means the bytecode
# transform and analysis are done # transform and analysis are done
compilation_counter.num_graphs_seen += 1 compilation_counter.num_graphs_seen += 1
...@@ -759,7 +590,7 @@ class PiecewiseBackend: ...@@ -759,7 +590,7 @@ class PiecewiseBackend:
if self.is_last_graph and not self.to_be_compiled_sizes: if self.is_last_graph and not self.to_be_compiled_sizes:
# no specific sizes to compile # no specific sizes to compile
# save the hash of the inductor graph for the next run # save the hash of the inductor graph for the next run
self.vllm_backend.inductor_hash_cache.save_to_file() self.vllm_backend.compiler_manager.save_to_file()
end_monitoring_torch_compile(self.vllm_config) end_monitoring_torch_compile(self.vllm_config)
def __call__(self, *args) -> Any: def __call__(self, *args) -> Any:
...@@ -782,16 +613,14 @@ class PiecewiseBackend: ...@@ -782,16 +613,14 @@ class PiecewiseBackend:
entry.compiled = True entry.compiled = True
self.to_be_compiled_sizes.remove(runtime_shape) self.to_be_compiled_sizes.remove(runtime_shape)
# args are real arguments # args are real arguments
entry.runnable = wrap_inductor( entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph, self.graph,
args, args,
self.compilation_config.inductor_compile_config, self.compilation_config.inductor_compile_config,
self.compilation_config, self.compilation_config,
self.vllm_backend,
graph_index=self.piecewise_compile_index, graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles, num_graphs=self.total_piecewise_compiles,
runtime_shape=runtime_shape, runtime_shape=runtime_shape)
use_inductor=self.compilation_config.use_inductor)
# finished compilations for all required shapes # finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes: if self.is_last_graph and not self.to_be_compiled_sizes:
......
# SPDX-License-Identifier: Apache-2.0
import copy
import hashlib
import os
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch
import torch
import torch._inductor.compile_fx
import torch.fx as fx
from vllm.config import VllmConfig
class CompilerInterface:
"""
The interface for a compiler that can be used by vLLM.
"""
# The name of the compiler, e.g. inductor.
# This is a class-level attribute.
name: str
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
"""
when the vLLM process uses `cache_dir` as the cache directory,
the compiler should initialize itself with the cache directory,
e.g. by re-directing its own cache directory to a sub-directory.
"""
pass
def compute_hash(self, vllm_config: VllmConfig) -> str:
"""
Gather all the relevant information from the VLLM config,
to compute a hash so that we can cache the compiled model.
See :meth:`VllmConfig.compute_hash` to check what information
is already considered by default. This function should only
consider the information that is specific to the compiler.
"""
return ""
def compile(
self,
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
runtime_shape: Optional[int] = None
) -> Tuple[Optional[Callable], Optional[Any]]:
"""
Compile the graph with the given example inputs and compiler config,
with a runtime shape. If the `runtime_shape` is None, it means
the `example_inputs` have a dynamic shape. Otherwise, the
`runtime_shape` specifies the shape of the inputs. Right now we only
support one variable shape for all inputs, which is the batchsize
(number of tokens) during inference.
Dynamo will make sure `graph(*example_inputs)` is valid.
The function should return a compiled callable function, as well as
a handle that can be used to directly load the compiled function.
The handle should be a plain Python object, preferably a string or a
file path for readability.
If the compiler doesn't support caching, it should return None for the
handle. If the compiler fails to compile the graph, it should return
None for the compiled function as well.
"""
return None, None
def load(self,
handle: Any,
graph: fx.GraphModule,
example_inputs: List[Any],
graph_index: int,
runtime_shape: Optional[int] = None) -> Callable:
"""
Load the compiled function from the handle.
Raises an error if the handle is invalid.
The handle is the second return value of the `compile` function.
"""
raise NotImplementedError("caching is not supported")
class AlwaysHitShapeEnv:
"""
Why do we need this class:
For normal `torch.compile` usage, every compilation will have
one Dynamo bytecode compilation and one Inductor compilation.
The Inductor compilation happens under the context of the
Dynamo bytecode compilation, and that context is used to
determine the dynamic shape information, etc.
For our use case, we only run Dynamo bytecode compilation once,
and run Inductor compilation multiple times with different shapes
plus a general shape. The compilation for specific shapes happens
outside of the context of the Dynamo bytecode compilation. At that
time, we don't have shape environment to provide to Inductor, and
it will fail the Inductor code cache lookup.
By providing a dummy shape environment that always hits, we can
make the Inductor code cache lookup always hit, and we can
compile the graph for different shapes as needed.
The following dummy methods are obtained by trial-and-error
until it works.
"""
def __init__(self) -> None:
self.guards: List[Any] = []
def evaluate_guards_expression(self, *args, **kwargs):
return True
def get_pruned_guards(self, *args, **kwargs):
return []
def produce_guards_expression(self, *args, **kwargs):
return ""
class InductorAdaptor(CompilerInterface):
"""
The adaptor for the Inductor compiler, version 2.5 and 2.6.
"""
name = "inductor"
def compute_hash(self, vllm_config: VllmConfig) -> str:
factors: List[Any] = []
# summarize system state
from torch._inductor.codecache import CacheBase
system_factors = CacheBase.get_system()
factors.append(system_factors)
# summarize pytorch state
from torch._inductor.codecache import torch_key
torch_factors = torch_key()
factors.append(torch_factors)
hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
return hash_str
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
if disable_cache:
return
# redirect the cache directory to a sub-directory
# set flags so that Inductor and Triton store their cache
# in the cache_dir, then users only need to copy the cache_dir
# to another machine to reuse the cache.
inductor_cache = os.path.join(cache_dir, "inductor_cache")
os.makedirs(inductor_cache, exist_ok=True)
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
triton_cache = os.path.join(cache_dir, "triton_cache")
os.makedirs(triton_cache, exist_ok=True)
os.environ["TRITON_CACHE_DIR"] = triton_cache
def compile(
self,
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
runtime_shape: Optional[int] = None
) -> Tuple[Optional[Callable], Optional[Any]]:
from torch._inductor import config
current_config = config.get_config_copy()
from torch._inductor.compile_fx import compile_fx
# disable remote cache
current_config["fx_graph_cache"] = True
current_config["fx_graph_remote_cache"] = False
if compiler_config is not None:
current_config.update(compiler_config)
if isinstance(runtime_shape, int):
# for a specific batchsize, tuning triton kernel parameters
# can be beneficial
current_config["max_autotune"] = True
current_config["coordinate_descent_tuning"] = True
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
graph = copy.deepcopy(graph)
# it's the first time we compile this graph
# the assumption is that we don't have nested Inductor compilation.
# compiled_fx_graph_hash will only be called once, and we can hook
# it to get the hash of the compiled graph directly.
hash_str, file_path = None, None
from torch._inductor.codecache import (FxGraphCache,
compiled_fx_graph_hash)
if torch.__version__.startswith("2.5"):
original_load = FxGraphCache.load
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
def hijack_load(*args, **kwargs):
inductor_compiled_graph = original_load(*args, **kwargs)
nonlocal file_path
file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
return inductor_compiled_graph
hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner # noqa
elif torch.__version__ >= "2.6":
# function renamed in 2.6
original_load_name = None
def hijacked_compile_fx_inner(*args, **kwargs):
output = torch._inductor.compile_fx.compile_fx_inner(
*args, **kwargs)
nonlocal hash_str
inductor_compiled_graph = output
if inductor_compiled_graph is not None:
nonlocal file_path
file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
hash_str = inductor_compiled_graph._fx_graph_cache_key
return output
def hijack_compiled_fx_graph_hash(*args, **kwargs):
out = compiled_fx_graph_hash(*args, **kwargs)
nonlocal hash_str
hash_str = out[0]
return out
def _check_can_cache(*args, **kwargs):
# no error means it can be cached.
# Inductor refuses to cache the graph outside of Dynamo
# tracing context, and also disables caching for graphs
# with high-order ops.
# For vLLM, in either case, we want to cache the graph.
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
return
def _get_shape_env() -> AlwaysHitShapeEnv:
return AlwaysHitShapeEnv()
with ExitStack() as stack:
# hijack to get the compiled graph itself
if original_load_name is not None:
stack.enter_context(patch(original_load_name, hijack_load))
# for hijacking the hash of the compiled graph
stack.enter_context(
patch("torch._inductor.codecache.compiled_fx_graph_hash",
hijack_compiled_fx_graph_hash))
# for providing a dummy shape environment
stack.enter_context(
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
_get_shape_env))
# for forcing the graph to be cached
stack.enter_context(
patch(
"torch._inductor.codecache.FxGraphCache._check_can_cache",
_check_can_cache))
compiled_graph = compile_fx(
graph,
example_inputs,
inner_compile=hijacked_compile_fx_inner,
config_patches=current_config)
assert hash_str is not None, (
"failed to get the hash of the compiled graph")
assert file_path is not None, (
"failed to get the file path of the compiled graph")
return compiled_graph, (hash_str, file_path)
def load(self,
handle: Any,
graph: fx.GraphModule,
example_inputs: List[Any],
graph_index: int,
runtime_shape: Optional[int] = None) -> Callable:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
assert isinstance(handle[1], str)
hash_str = handle[0]
from torch._inductor.codecache import FxGraphCache
with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv()):
if torch.__version__.startswith("2.5"):
inductor_compiled_graph = FxGraphCache._lookup_graph(
hash_str, example_inputs, True, False)
assert inductor_compiled_graph is not None, (
"Inductor cache lookup failed. Please remove"
f"the cache directory and try again." # noqa
)
elif torch.__version__ >= "2.6":
from torch._inductor.output_code import (
CompiledFxGraphConstantsWithGm)
constants = CompiledFxGraphConstantsWithGm(graph)
inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
hash_str, example_inputs, True, None, constants)
assert inductor_compiled_graph is not None, (
"Inductor cache lookup failed. Please remove"
f"the cache directory and try again." # noqa
)
# Inductor calling convention (function signature):
# f(list) -> tuple
# Dynamo calling convention (function signature):
# f(*args) -> Any
# need to know if the graph returns a tuple
from torch._inductor.compile_fx import graph_returns_tuple
returns_tuple = graph_returns_tuple(graph)
# this is the callable we return to Dynamo to run
def compiled_graph(*args):
# convert args to list
list_args = list(args)
graph_output = inductor_compiled_graph(list_args)
# unpack the tuple if needed
if returns_tuple:
return graph_output
else:
return graph_output[0]
return compiled_graph
class EagerAdaptor(CompilerInterface):
name = "eager"
def compile(
self,
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
runtime_shape: Optional[int] = None
) -> Tuple[Optional[Callable], Optional[Any]]:
# we don't need to compile the graph, just return the graph itself.
# It does not support caching, return None for the handle.
return graph, None
...@@ -13,7 +13,7 @@ class CompilationCounter: ...@@ -13,7 +13,7 @@ class CompilationCounter:
num_piecewise_graphs_seen: int = 0 num_piecewise_graphs_seen: int = 0
# not including the splitting ops # not including the splitting ops
num_piecewise_capturable_graphs_seen: int = 0 num_piecewise_capturable_graphs_seen: int = 0
num_inductor_compilations: int = 0 num_backend_compilations: int = 0
num_cudagraph_caputured: int = 0 num_cudagraph_caputured: int = 0
def clone(self) -> "CompilationCounter": def clone(self) -> "CompilationCounter":
......
...@@ -13,7 +13,6 @@ from torch import fx ...@@ -13,7 +13,6 @@ from torch import fx
class InductorPass(ABC): class InductorPass(ABC):
""" """
General custom inductor pass interface. General custom inductor pass interface.
TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass
""" """
@abstractmethod @abstractmethod
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from typing import Any, Dict, List from typing import Any, Dict, List
import torch
from torch import fx as fx from torch import fx as fx
from vllm.config import CompilationConfig from vllm.config import CompilationConfig
...@@ -15,7 +16,17 @@ from .reshapes import RedundantReshapesPass ...@@ -15,7 +16,17 @@ from .reshapes import RedundantReshapesPass
logger = init_logger(__name__) logger = init_logger(__name__)
class PostGradPassManager: class PlaceHolder:
pass
if torch.__version__ < "2.6":
Parent = PlaceHolder # type: ignore
else:
Parent = torch._inductor.custom_graph_pass.CustomGraphPass # type: ignore
class PostGradPassManager(Parent):
""" """
The pass manager for post-grad passes. The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes. It handles configuration, adding custom passes, and running passes.
...@@ -55,6 +66,9 @@ class PostGradPassManager: ...@@ -55,6 +66,9 @@ class PostGradPassManager:
assert isinstance(pass_, InductorPass) assert isinstance(pass_, InductorPass)
self.passes.append(pass_) self.passes.append(pass_)
def uuid(self):
return self.__getstate__()
def __getstate__(self) -> Dict[str, List[Any]]: def __getstate__(self) -> Dict[str, List[Any]]:
""" """
Custom pickling for the pass manager, as some passes cannot be pickled. Custom pickling for the pass manager, as some passes cannot be pickled.
......
...@@ -54,17 +54,18 @@ _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 ...@@ -54,17 +54,18 @@ _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
"score", "reward"] "score", "reward", "transcription"]
_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward", _ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
"draft"] "draft", "transcription"]
RunnerType = Literal["generate", "pooling", "draft"] RunnerType = Literal["generate", "pooling", "draft", "transcription"]
_RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = { _RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = {
"generate": ["generate"], "generate": ["generate"],
"pooling": ["embed", "classify", "score", "reward"], "pooling": ["embed", "classify", "score", "reward"],
"draft": ["draft"], "draft": ["draft"],
"transcription": ["transcription"],
} }
_TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = { _TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = {
...@@ -102,8 +103,9 @@ class ModelConfig: ...@@ -102,8 +103,9 @@ class ModelConfig:
it; otherwise, you must specify explicitly which task to use. it; otherwise, you must specify explicitly which task to use.
tokenizer: Name or path of the huggingface tokenizer to use. tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, "slow" will always use the slow tokenizer, and available, "slow" will always use the slow tokenizer,
"mistral" will always use the tokenizer from `mistral_common`. "mistral" will always use the tokenizer from `mistral_common`, and
"custom" will use --tokenizer to select the preregistered tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer. downloading the model and tokenizer.
allowed_local_media_path: Allowing API requests to read local images or allowed_local_media_path: Allowing API requests to read local images or
...@@ -407,7 +409,8 @@ class ModelConfig: ...@@ -407,7 +409,8 @@ class ModelConfig:
if is_s3(model) or is_s3(tokenizer): if is_s3(model) or is_s3(tokenizer):
if is_s3(model): if is_s3(model):
s3_model = S3Model() s3_model = S3Model()
s3_model.pull_files(model, allow_pattern=["*config.json"]) s3_model.pull_files(
model, allow_pattern=["*.model", "*.py", "*.json"])
self.model_weights = self.model self.model_weights = self.model
self.model = s3_model.dir self.model = s3_model.dir
...@@ -467,10 +470,10 @@ class ModelConfig: ...@@ -467,10 +470,10 @@ class ModelConfig:
def _verify_tokenizer_mode(self) -> None: def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower() tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow", "mistral"]: if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
raise ValueError( raise ValueError(
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
"either 'auto', 'slow' or 'mistral'.") "either 'auto', 'slow', 'mistral' or 'custom'.")
self.tokenizer_mode = tokenizer_mode self.tokenizer_mode = tokenizer_mode
def _get_preferred_task( def _get_preferred_task(
...@@ -483,6 +486,8 @@ class ModelConfig: ...@@ -483,6 +486,8 @@ class ModelConfig:
return "embed" return "embed"
if ModelRegistry.is_cross_encoder_model(architectures): if ModelRegistry.is_cross_encoder_model(architectures):
return "score" return "score"
if ModelRegistry.is_transcription_model(architectures):
return "transcription"
suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [ suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [
# Other models follow this pattern # Other models follow this pattern
...@@ -515,6 +520,8 @@ class ModelConfig: ...@@ -515,6 +520,8 @@ class ModelConfig:
runner_support: Dict[RunnerType, bool] = { runner_support: Dict[RunnerType, bool] = {
# NOTE: Listed from highest to lowest priority, # NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them # in case the model supports multiple of them
"transcription":
ModelRegistry.is_transcription_model(architectures),
"generate": ModelRegistry.is_text_generation_model(architectures), "generate": ModelRegistry.is_text_generation_model(architectures),
"pooling": ModelRegistry.is_pooling_model(architectures), "pooling": ModelRegistry.is_pooling_model(architectures),
} }
...@@ -756,7 +763,7 @@ class ModelConfig: ...@@ -756,7 +763,7 @@ class ModelConfig:
def is_deepseek_mla(self) -> bool: def is_deepseek_mla(self) -> bool:
return (hasattr(self.hf_text_config, "model_type")) \ return (hasattr(self.hf_text_config, "model_type")) \
and (self.hf_text_config.model_type in \ and (self.hf_text_config.model_type in \
('deepseek_v2', 'deepseek_v3'))\ ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\
and (self.hf_text_config.kv_lora_rank is not None) and (self.hf_text_config.kv_lora_rank is not None)
def get_head_size(self) -> int: def get_head_size(self) -> int:
...@@ -849,8 +856,12 @@ class ModelConfig: ...@@ -849,8 +856,12 @@ class ModelConfig:
def get_layers_start_end_indices( def get_layers_start_end_indices(
self, parallel_config: "ParallelConfig") -> Tuple[int, int]: self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
from vllm.distributed.utils import get_pp_indices from vllm.distributed.utils import get_pp_indices
total_num_hidden_layers = getattr(self.hf_text_config, if self.hf_text_config.model_type == "deepseek_mtp":
"num_hidden_layers", 0) total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
else:
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
pp_size = parallel_config.pipeline_parallel_size pp_size = parallel_config.pipeline_parallel_size
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size) start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
...@@ -985,37 +996,7 @@ class ModelConfig: ...@@ -985,37 +996,7 @@ class ModelConfig:
@property @property
def use_mla(self) -> bool: def use_mla(self) -> bool:
if not self.is_deepseek_mla or envs.VLLM_MLA_DISABLE: return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
return False
if self.quantization is not None and self.quantization not in [\
"fp8", "compressed-tensors"]:
logger.warning(
"MLA is not supported with %s quantization. "
"Disabling MLA.", self.quantization)
return False
# If using a "compressed-tensors" checkpoint, check that all groups
# have fp8 for both weights and activations.
if self.quantization == "compressed-tensors":
quant_config = self._parse_quant_hf_config()
for group_name, cfg in quant_config.get("config_groups", {
"": {}
}).items():
act_cfg = cfg.get("input_activations", {})
act_type = None if act_cfg is None else act_cfg.get("type", "")
w_cfg = cfg.get("weights", {})
w_type = None if w_cfg is None else w_cfg.get("type", "")
if act_type != "fp8" or w_type != "fp8":
logger.warning(
"compressed-tensors MLA support requires fp8 "
"activations and weights in group '%s', but got "
"activations type '%s' and weights type '%s'.\n "
"Full config: %s", group_name, act_type, w_type,
quant_config)
return False
return True
@property @property
def supported_runner_types(self) -> Set[RunnerType]: def supported_runner_types(self) -> Set[RunnerType]:
...@@ -1403,6 +1384,9 @@ class ParallelConfig: ...@@ -1403,6 +1384,9 @@ class ParallelConfig:
logger.info("Defaulting to use %s for distributed inference", logger.info("Defaulting to use %s for distributed inference",
backend) backend)
if self.distributed_executor_backend is None and self.world_size == 1:
self.distributed_executor_backend = "uni"
self._verify_args() self._verify_args()
@property @property
...@@ -1453,6 +1437,17 @@ class SchedulerConfig: ...@@ -1453,6 +1437,17 @@ class SchedulerConfig:
# Maximum length of a sequence (including prompt and generated text). # Maximum length of a sequence (including prompt and generated text).
max_model_len: int = 8192 max_model_len: int = 8192
# Maximum number of sequences that can be partially prefilled concurrently
max_num_partial_prefills: int = 1
# Maximum number of "very long prompt" sequences that can be prefilled
# concurrently (long is defined by long_prefill_threshold)
max_long_partial_prefills: int = 1
# calculate context length that determines which sequences are
# considered "long"
long_prefill_token_threshold: int = 0
# The number of slots to allocate per sequence per # The number of slots to allocate per sequence per
# step, beyond the known token ids. This is used in speculative # step, beyond the known token ids. This is used in speculative
# decoding to store KV activations of tokens which may or may not be # decoding to store KV activations of tokens which may or may not be
...@@ -1502,6 +1497,10 @@ class SchedulerConfig: ...@@ -1502,6 +1497,10 @@ class SchedulerConfig:
chunked_prefill_enabled: bool = field(init=False) chunked_prefill_enabled: bool = field(init=False)
# scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
# or "mod.custom_class".
scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
...@@ -1560,6 +1559,18 @@ class SchedulerConfig: ...@@ -1560,6 +1559,18 @@ class SchedulerConfig:
self.max_num_batched_tokens) self.max_num_batched_tokens)
self.chunked_prefill_enabled = self.enable_chunked_prefill self.chunked_prefill_enabled = self.enable_chunked_prefill
if self.max_num_partial_prefills > 1:
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len *
0.04)
logger.info(
"Concurrent partial prefills enabled with "
"max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
"long_prefill_token_threshold=%d",
self.max_num_partial_prefills, self.max_long_partial_prefills,
self.long_prefill_token_threshold)
self._verify_args() self._verify_args()
def _verify_args(self) -> None: def _verify_args(self) -> None:
...@@ -1591,6 +1602,29 @@ class SchedulerConfig: ...@@ -1591,6 +1602,29 @@ class SchedulerConfig:
f"({self.num_scheduler_steps}) must be greater than or " f"({self.num_scheduler_steps}) must be greater than or "
"equal to 1.") "equal to 1.")
if self.max_num_partial_prefills < 1:
raise ValueError(
f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
"must be greater than or equal to 1.")
elif self.max_num_partial_prefills > 1:
if not self.chunked_prefill_enabled:
raise ValueError("Chunked prefill must be enabled to set "
"max_num_partial_prefills > 1.")
if self.long_prefill_token_threshold > self.max_model_len:
raise ValueError(
"long_prefill_token_threshold "
f"({self.long_prefill_token_threshold}) cannot be greater "
f"than the max_model_len ({self.max_model_len}).")
if (self.max_long_partial_prefills
< 1) or (self.max_long_partial_prefills
> self.max_num_partial_prefills):
raise ValueError(
f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
"must be greater than or equal to 1 and less than or equal to "
f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
@property @property
def is_multi_step(self) -> bool: def is_multi_step(self) -> bool:
return self.num_scheduler_steps > 1 return self.num_scheduler_steps > 1
...@@ -1665,6 +1699,18 @@ class SpeculativeConfig: ...@@ -1665,6 +1699,18 @@ class SpeculativeConfig:
hash_str = hashlib.md5(str(factors).encode()).hexdigest() hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str return hash_str
@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
if hf_config.model_type == "deepseek_v3":
hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["DeepSeekMTPModel"]
})
return hf_config
@staticmethod @staticmethod
def maybe_create_spec_config( def maybe_create_spec_config(
target_model_config: ModelConfig, target_model_config: ModelConfig,
...@@ -1750,12 +1796,18 @@ class SpeculativeConfig: ...@@ -1750,12 +1796,18 @@ class SpeculativeConfig:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
the necessary conditions are met, else None. the necessary conditions are met, else None.
""" """
if speculative_model is None: if speculative_model is None:
if num_speculative_tokens is not None: if num_speculative_tokens is not None:
raise ValueError("num_speculative_tokens was provided without " if target_model_config.hf_text_config.model_type \
"speculative_model.") == "deepseek_v3":
return None # use the draft model from the same model:
speculative_model = target_model_config.model
else:
raise ValueError(
"num_speculative_tokens was provided without "
"speculative_model.")
else:
return None
if (speculative_disable_by_batch_size is not None if (speculative_disable_by_batch_size is not None
and speculative_disable_by_batch_size < 2): and speculative_disable_by_batch_size < 2):
...@@ -1809,10 +1861,20 @@ class SpeculativeConfig: ...@@ -1809,10 +1861,20 @@ class SpeculativeConfig:
max_seq_len_to_capture=target_model_config. max_seq_len_to_capture=target_model_config.
max_seq_len_to_capture, max_seq_len_to_capture,
max_logprobs=target_model_config.max_logprobs, max_logprobs=target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override,
) )
draft_hf_config = draft_model_config.hf_config draft_hf_config = draft_model_config.hf_config
# Detect EAGLE prefix to replace hf_config for EAGLE draft_model
if "eagle-" in draft_model_config.model.lower():
from vllm.transformers_utils.configs.eagle import EAGLEConfig
if isinstance(draft_model_config.hf_config, EAGLEConfig):
pass
else:
eagle_config = EAGLEConfig(draft_model_config.hf_config)
draft_model_config.hf_config = eagle_config
if (num_speculative_tokens is not None if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")): and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens draft_hf_config.num_lookahead_tokens = num_speculative_tokens
...@@ -1934,8 +1996,9 @@ class SpeculativeConfig: ...@@ -1934,8 +1996,9 @@ class SpeculativeConfig:
speculative_draft_tensor_parallel_size = 1 speculative_draft_tensor_parallel_size = 1
if target_parallel_config.tensor_parallel_size > 1: if target_parallel_config.tensor_parallel_size > 1:
logger.warning( logger.warning(
"MLPSpeculator cannot currently be run with tp>1; " "%s cannot currently be run with tp>1; "
"setting speculative_draft_tensor_parallel_size=1") "setting speculative_draft_tensor_parallel_size=1",
draft_hf_config.model_type)
else: else:
speculative_draft_tensor_parallel_size = \ speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size target_parallel_config.tensor_parallel_size
...@@ -3070,7 +3133,8 @@ class VllmConfig: ...@@ -3070,7 +3133,8 @@ class VllmConfig:
kv_transfer_config: KVTransferConfig = field(default=None, kv_transfer_config: KVTransferConfig = field(default=None,
init=True) # type: ignore init=True) # type: ignore
# some opaque config, only used to provide additional information # some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing and debugging. # for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
additional_config: SupportsHash = field(default=None, additional_config: SupportsHash = field(default=None,
init=True) # type: ignore init=True) # type: ignore
instance_id: str = "" instance_id: str = ""
...@@ -3088,15 +3152,6 @@ class VllmConfig: ...@@ -3088,15 +3152,6 @@ class VllmConfig:
the final hidden states. the final hidden states.
""" """
factors: List[Any] = [] factors: List[Any] = []
# summarize system state
from torch._inductor.codecache import CacheBase
system_factors = CacheBase.get_system()
factors.append(system_factors)
# summarize pytorch state
from torch._inductor.codecache import torch_key
torch_factors = torch_key()
factors.append(torch_factors)
# summarize vllm config # summarize vllm config
vllm_factors: List[Any] = [] vllm_factors: List[Any] = []
......
...@@ -17,7 +17,7 @@ from vllm.lora.request import LoRARequest ...@@ -17,7 +17,7 @@ from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup, from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceGroupMetadataDelta, SequenceGroupMetadata, SequenceGroupMetadataDelta,
SequenceStatus) SequenceStage, SequenceStatus)
from vllm.utils import Device, PyObjectCache from vllm.utils import Device, PyObjectCache
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -39,6 +39,7 @@ class PreemptionMode(enum.Enum): ...@@ -39,6 +39,7 @@ class PreemptionMode(enum.Enum):
recompute them when the sequences are resumed, treating the sequences as recompute them when the sequences are resumed, treating the sequences as
new prompts. new prompts.
""" """
SWAP = enum.auto() SWAP = enum.auto()
RECOMPUTE = enum.auto() RECOMPUTE = enum.auto()
...@@ -54,6 +55,7 @@ class SchedulingBudget: ...@@ -54,6 +55,7 @@ class SchedulingBudget:
happen if we only have chunked prefill scheduling, we can remove this happen if we only have chunked prefill scheduling, we can remove this
feature from the API when chunked prefill is enabled by default. feature from the API when chunked prefill is enabled by default.
""" """
token_budget: int token_budget: int
max_num_seqs: int max_num_seqs: int
_request_ids_num_batched_tokens: Set[str] = field(default_factory=set) _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
...@@ -132,6 +134,7 @@ class ScheduledSequenceGroup: ...@@ -132,6 +134,7 @@ class ScheduledSequenceGroup:
@dataclass @dataclass
class SchedulerOutputs: class SchedulerOutputs:
"""The scheduling decision made from a scheduler.""" """The scheduling decision made from a scheduler."""
# Scheduled sequence groups. # Scheduled sequence groups.
scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup] scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup]
# Number of prefill groups scheduled. # Number of prefill groups scheduled.
...@@ -205,6 +208,7 @@ class SchedulerRunningOutputs: ...@@ -205,6 +208,7 @@ class SchedulerRunningOutputs:
Could contain prefill (prefill that's chunked) or decodes. If there's not Could contain prefill (prefill that's chunked) or decodes. If there's not
enough memory, it can be preempted (for recompute) or swapped out. enough memory, it can be preempted (for recompute) or swapped out.
""" """
# Selected sequences that are running and in a decoding phase. # Selected sequences that are running and in a decoding phase.
decode_seq_groups: List[ScheduledSequenceGroup] decode_seq_groups: List[ScheduledSequenceGroup]
# Selected sequences that are running and in a prefill phase. # Selected sequences that are running and in a prefill phase.
...@@ -246,6 +250,7 @@ class SchedulerSwappedInOutputs: ...@@ -246,6 +250,7 @@ class SchedulerSwappedInOutputs:
Could contain prefill (prefill that's chunked) or decodes. Could contain prefill (prefill that's chunked) or decodes.
""" """
# Selected sequences that are going to be swapped in and is in a # Selected sequences that are going to be swapped in and is in a
# decoding phase. # decoding phase.
decode_seq_groups: List[ScheduledSequenceGroup] decode_seq_groups: List[ScheduledSequenceGroup]
...@@ -280,6 +285,7 @@ class SchedulerPrefillOutputs: ...@@ -280,6 +285,7 @@ class SchedulerPrefillOutputs:
Could contain a fresh prefill requests or preempted requests that need Could contain a fresh prefill requests or preempted requests that need
to be recomputed from scratch. to be recomputed from scratch.
""" """
# Selected sequences for prefill. # Selected sequences for prefill.
seq_groups: List[ScheduledSequenceGroup] seq_groups: List[ScheduledSequenceGroup]
# Ignored sequence groups. # Ignored sequence groups.
...@@ -321,6 +327,100 @@ def scheduled_seq_group_builder(): ...@@ -321,6 +327,100 @@ def scheduled_seq_group_builder():
# return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
@dataclass
class PartialPrefillMetadata:
"""Holds information about the partial prefills that are currently running
during a single iteration of the Scheduler.
When chunked prefill is enabled, we allow a certain number of seqs to be
partially prefilled during each iteration. Having multiple partial prefills
in flight allows us to minimize TTFT and avoid decode starvation in cases
where a single sequence group with a very large prompt blocks the queue for
too many iterations.
The number of long prefill requests is limited so that smaller
requests may jump the queue in front of them and get to the decode
phase faster.
"""
# A minimum bound on the total number of prefills to be scheduled during
# this iteration
schedulable_prefills: int
# The number of long prefill requests currently running
long_prefills: int
scheduler_config: SchedulerConfig
def can_schedule(self, seq_group: SequenceGroup) -> bool:
"""When concurrent partial prefills are enabled,
we limit the number of long requests and only accept
shorter requests from the queue while running them
concurrently"""
return not (seq_group.first_seq.get_num_new_tokens()
> self.scheduler_config.long_prefill_token_threshold
and self.long_prefills
>= self.scheduler_config.max_long_partial_prefills
and self.scheduler_config.max_num_partial_prefills > 1)
def maybe_increment_partial_prefills(self,
seq_group: SequenceGroup) -> None:
# When a new prefill is scheduled, we need to know if it is a
# long request
if (seq_group.first_seq.get_num_new_tokens()
> self.scheduler_config.long_prefill_token_threshold):
self.long_prefills += 1
@classmethod
def from_queues(
cls,
running: Deque[SequenceGroup],
waiting: Deque[SequenceGroup],
scheduler_config: SchedulerConfig,
) -> "PartialPrefillMetadata":
"""Create a PartialPrefillMetadata object from the current state of
the scheduler's queues.
This accounts for the currently running prefill requests, and peeks into
the waiting queue to see if there are more prefills to potentially be
scheduled during this iteration."""
prefills = 0
long_prefills = 0
waiting_long_prefills = 0
for sg in running:
if sg.first_seq.data.stage == SequenceStage.PREFILL:
prefills += 1
if (sg.first_seq.get_num_new_tokens()
> scheduler_config.long_prefill_token_threshold):
long_prefills += 1
for sg in waiting:
# Don't bother looping through the rest of the queue if we know
# there are already at
# least max_partial_prefills requests to fill
if prefills >= scheduler_config.max_num_partial_prefills:
break
# Don't count long requests from the waiting queue if we aren't
# going to schedule them anyway
if (sg.first_seq.get_num_new_tokens()
> scheduler_config.long_prefill_token_threshold):
if (long_prefills + waiting_long_prefills
>= scheduler_config.max_long_partial_prefills):
continue
waiting_long_prefills += 1
prefills += 1
# NB: long_prefills and waiting_long_prefills are tracked separately.
# We don't account for the waiting requests here because we need to use
# this metadata to track how many have actually been scheduled.
return PartialPrefillMetadata(
schedulable_prefills=min(
prefills, scheduler_config.max_num_partial_prefills),
long_prefills=long_prefills,
scheduler_config=scheduler_config,
)
class Scheduler: class Scheduler:
def __init__( def __init__(
...@@ -360,7 +460,8 @@ class Scheduler: ...@@ -360,7 +460,8 @@ class Scheduler:
num_gpu_blocks=num_gpu_blocks, num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks, num_cpu_blocks=num_cpu_blocks,
sliding_window=self.cache_config.sliding_window, sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching) enable_caching=self.cache_config.enable_prefix_caching,
)
# Sequence groups in the WAITING state. # Sequence groups in the WAITING state.
# Contain new prefill or preempted requests. # Contain new prefill or preempted requests.
...@@ -421,6 +522,18 @@ class Scheduler: ...@@ -421,6 +522,18 @@ class Scheduler:
# for processing and deallocation by the free_finished_seq_groups() # for processing and deallocation by the free_finished_seq_groups()
self._async_stopped: List[SequenceGroup] = [] self._async_stopped: List[SequenceGroup] = []
# List with the chunk sizes to hand out to each sequence depending
# on how many partial prefills are running. This is slightly faster than
# running an integer division every time a prefill is scheduled.
# This splits the budget evenly among all prefills.
self.partial_prefill_budget_lookup_list = [0] * (
self.scheduler_config.max_num_partial_prefills + 1)
self.partial_prefill_budget_lookup_list[0] = (
scheduler_config.max_num_batched_tokens)
for i in range(1, self.scheduler_config.max_num_partial_prefills + 1):
self.partial_prefill_budget_lookup_list[i] = (
scheduler_config.max_num_batched_tokens // i)
@property @property
def next_cache_id(self): def next_cache_id(self):
return (self.cache_id + 1) % self.num_cache_iters return (self.cache_id + 1) % self.num_cache_iters
...@@ -500,8 +613,8 @@ class Scheduler: ...@@ -500,8 +613,8 @@ class Scheduler:
self.block_manager.free_cross(seq_group) self.block_manager.free_cross(seq_group)
def has_unfinished_seqs(self) -> bool: def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len( return (len(self.waiting) != 0 or len(self.running) != 0
self.swapped) != 0 or len(self.swapped) != 0)
def get_prefix_cache_hit_rate(self, device: Device) -> float: def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device) return self.block_manager.get_prefix_cache_hit_rate(device)
...@@ -523,6 +636,7 @@ class Scheduler: ...@@ -523,6 +636,7 @@ class Scheduler:
budget: SchedulingBudget, budget: SchedulingBudget,
curr_loras: Optional[Set[int]], curr_loras: Optional[Set[int]],
enable_chunking: bool = False, enable_chunking: bool = False,
partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
) -> SchedulerRunningOutputs: ) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running. """Schedule sequence groups that are running.
...@@ -537,12 +651,14 @@ class Scheduler: ...@@ -537,12 +651,14 @@ class Scheduler:
chunked number of tokens are scheduled if chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule `budget.num_batched_tokens` has not enough capacity to schedule
all tokens. all tokens.
partial_prefill_metadata: information about the partial prefills
that are currently running
Returns: Returns:
SchedulerRunningOutputs. SchedulerRunningOutputs.
""" """
ret: SchedulerRunningOutputs = \ ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[
self._scheduler_running_outputs_cache[self.cache_id].get_object() self.cache_id].get_object()
ret.blocks_to_swap_out.clear() ret.blocks_to_swap_out.clear()
ret.blocks_to_copy.clear() ret.blocks_to_copy.clear()
ret.decode_seq_groups.clear() ret.decode_seq_groups.clear()
...@@ -577,10 +693,14 @@ class Scheduler: ...@@ -577,10 +693,14 @@ class Scheduler:
# 2. If a sequence is running with non-chunked prefill, then # 2. If a sequence is running with non-chunked prefill, then
# there it's a decoding sequence, and the cached tokens info is # there it's a decoding sequence, and the cached tokens info is
# irrelevant. # irrelevant.
num_uncached_new_tokens, _ = ( num_uncached_new_tokens, _ = \
self._get_num_new_uncached_and_cached_tokens( self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.RUNNING, enable_chunking, seq_group,
budget)) SequenceStatus.RUNNING,
enable_chunking,
budget,
partial_prefill_metadata,
)
num_running_tokens = num_uncached_new_tokens num_running_tokens = num_uncached_new_tokens
if num_running_tokens == 0: if num_running_tokens == 0:
...@@ -593,8 +713,8 @@ class Scheduler: ...@@ -593,8 +713,8 @@ class Scheduler:
# to process the final tokens. The check below avoids this extra # to process the final tokens. The check below avoids this extra
# decode run when the model max len is reached, in order to avoid # decode run when the model max len is reached, in order to avoid
# a memory overflow. # a memory overflow.
if self.use_async_output_proc and seq_group.seqs[0].get_len( if (self.use_async_output_proc and seq_group.seqs[0].get_len()
) > self.scheduler_config.max_model_len: > self.scheduler_config.max_model_len):
self._async_stopped.append(seq_group) self._async_stopped.append(seq_group)
continue continue
...@@ -653,8 +773,9 @@ class Scheduler: ...@@ -653,8 +773,9 @@ class Scheduler:
self._append_slots(seq_group, blocks_to_copy, enable_chunking) self._append_slots(seq_group, blocks_to_copy, enable_chunking)
is_prefill = seq_group.is_prefill() is_prefill = seq_group.is_prefill()
scheduled_seq_group: ScheduledSequenceGroup = \ scheduled_seq_group: ScheduledSequenceGroup = (
self._scheduled_seq_group_cache[self.cache_id].get_object() self._scheduled_seq_group_cache[
self.cache_id].get_object())
scheduled_seq_group.seq_group = seq_group scheduled_seq_group.seq_group = seq_group
if is_prefill: if is_prefill:
scheduled_seq_group.token_chunk_size = num_running_tokens scheduled_seq_group.token_chunk_size = num_running_tokens
...@@ -731,7 +852,8 @@ class Scheduler: ...@@ -731,7 +852,8 @@ class Scheduler:
logger.warning( logger.warning(
"Failing the request %s because there's not enough kv " "Failing the request %s because there's not enough kv "
"cache blocks to run the entire sequence.", "cache blocks to run the entire sequence.",
seq_group.request_id) seq_group.request_id,
)
for seq in seq_group.get_seqs(): for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_IGNORED seq.status = SequenceStatus.FINISHED_IGNORED
infeasible_seq_groups.append(seq_group) infeasible_seq_groups.append(seq_group)
...@@ -770,7 +892,6 @@ class Scheduler: ...@@ -770,7 +892,6 @@ class Scheduler:
swapped_queue.popleft() swapped_queue.popleft()
self._swap_in(seq_group, blocks_to_swap_in) self._swap_in(seq_group, blocks_to_swap_in)
self._append_slots(seq_group, blocks_to_copy, enable_chunking) self._append_slots(seq_group, blocks_to_copy, enable_chunking)
is_prefill = seq_group.is_prefill()
if is_prefill: if is_prefill:
prefill_seq_groups.append( prefill_seq_groups.append(
ScheduledSequenceGroup( ScheduledSequenceGroup(
...@@ -801,16 +922,17 @@ class Scheduler: ...@@ -801,16 +922,17 @@ class Scheduler:
) )
def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
if self.scheduler_config.chunked_prefill_enabled and \ if (self.scheduler_config.chunked_prefill_enabled
not self.scheduler_config.is_multi_step: and not self.scheduler_config.is_multi_step):
prompt_limit = self.scheduler_config.max_model_len prompt_limit = self.scheduler_config.max_model_len
else: else:
prompt_limit = min(self.scheduler_config.max_model_len, prompt_limit = min(
self.scheduler_config.max_num_batched_tokens) self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens,
)
# Model is fine tuned with long context. Return the fine tuned max_len. # Model is fine tuned with long context. Return the fine tuned max_len.
if (seq_group.lora_request if seq_group.lora_request and seq_group.lora_request.long_lora_max_len:
and seq_group.lora_request.long_lora_max_len):
assert prompt_limit <= seq_group.lora_request.long_lora_max_len assert prompt_limit <= seq_group.lora_request.long_lora_max_len
return seq_group.lora_request.long_lora_max_len return seq_group.lora_request.long_lora_max_len
else: else:
...@@ -818,7 +940,7 @@ class Scheduler: ...@@ -818,7 +940,7 @@ class Scheduler:
def _get_priority(self, def _get_priority(self,
seq_group: SequenceGroup) -> Tuple[Optional[int], float]: seq_group: SequenceGroup) -> Tuple[Optional[int], float]:
""" Get the priority of the sequence group. """Get the priority of the sequence group.
Highest preference to user-defined priority, followed by arrival time. Highest preference to user-defined priority, followed by arrival time.
Args: Args:
seq_group: The sequence group input. seq_group: The sequence group input.
...@@ -851,14 +973,14 @@ class Scheduler: ...@@ -851,14 +973,14 @@ class Scheduler:
if waiting_queue: if waiting_queue:
seq_group = waiting_queue.popleft() seq_group = waiting_queue.popleft()
num_new_seqs = seq_group.get_max_num_running_seqs() num_new_seqs = seq_group.get_max_num_running_seqs()
num_new_tokens_uncached, _ = ( num_new_tokens_uncached, _ = \
self._get_num_new_uncached_and_cached_tokens( self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.WAITING, False, budget)) seq_group, SequenceStatus.WAITING, False, budget)
#Only preempt if priority inversion exists # Only preempt if priority inversion exists
while running_queue and self._get_priority( while running_queue and self._get_priority(
running_queue[-1]) > self._get_priority(seq_group): running_queue[-1]) > self._get_priority(seq_group):
#Only preempt if waiting sequence cannot be allocated # Only preempt if waiting sequence cannot be allocated
can_allocate = self.block_manager.can_allocate(seq_group) can_allocate = self.block_manager.can_allocate(seq_group)
if (num_new_tokens_uncached > 0 if (num_new_tokens_uncached > 0
and can_allocate == AllocStatus.OK and can_allocate == AllocStatus.OK
...@@ -868,7 +990,7 @@ class Scheduler: ...@@ -868,7 +990,7 @@ class Scheduler:
)): )):
break break
#Adjust budget to remove the victim sequence group # Adjust budget to remove the victim sequence group
vseq_group = running_queue.pop() vseq_group = running_queue.pop()
num_running_tokens_uncached, _ = ( num_running_tokens_uncached, _ = (
self._get_num_new_uncached_and_cached_tokens( self._get_num_new_uncached_and_cached_tokens(
...@@ -879,11 +1001,11 @@ class Scheduler: ...@@ -879,11 +1001,11 @@ class Scheduler:
budget.subtract_num_seqs(vseq_group.request_id, budget.subtract_num_seqs(vseq_group.request_id,
num_running_seqs) num_running_seqs)
#Preempt out the victim sequence group # Preempt out the victim sequence group
self._preempt(vseq_group, blocks_to_swap_out) self._preempt(vseq_group, blocks_to_swap_out)
waiting_queue.appendleft(vseq_group) waiting_queue.appendleft(vseq_group)
force_preemption_count += 1 force_preemption_count += 1
#Put the sequence back into the waiting queue # Put the sequence back into the waiting queue
waiting_queue.appendleft(seq_group) waiting_queue.appendleft(seq_group)
waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) waiting_queue = deque(sorted(waiting_queue, key=self._get_priority))
...@@ -897,6 +1019,7 @@ class Scheduler: ...@@ -897,6 +1019,7 @@ class Scheduler:
budget: SchedulingBudget, budget: SchedulingBudget,
curr_loras: Optional[Set[int]], curr_loras: Optional[Set[int]],
enable_chunking: bool = False, enable_chunking: bool = False,
partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
) -> SchedulerPrefillOutputs: ) -> SchedulerPrefillOutputs:
"""Schedule sequence groups that are in prefill stage. """Schedule sequence groups that are in prefill stage.
...@@ -917,10 +1040,20 @@ class Scheduler: ...@@ -917,10 +1040,20 @@ class Scheduler:
chunked number of tokens are scheduled if chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule `budget.num_batched_tokens` has not enough capacity to schedule
all tokens. all tokens.
partial_prefill_metadata: information about the partial prefills
that are currently running
Returns: Returns:
SchedulerPrefillOutputs. SchedulerPrefillOutputs.
""" """
if budget.remaining_token_budget() == 0:
# Do nothing: Can't add any more prefill anyway
return SchedulerPrefillOutputs(
seq_groups=[],
ignored_seq_groups=[],
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking),
)
ignored_seq_groups: List[SequenceGroup] = [] ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[ScheduledSequenceGroup] = [] seq_groups: List[ScheduledSequenceGroup] = []
...@@ -934,10 +1067,19 @@ class Scheduler: ...@@ -934,10 +1067,19 @@ class Scheduler:
assert len(waiting_seqs) == 1, ( assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt " "Waiting sequence group should have only one prompt "
"sequence.") "sequence.")
if (partial_prefill_metadata is not None
and not partial_prefill_metadata.can_schedule(seq_group)):
leftover_waiting_sequences.appendleft(seq_group)
waiting_queue.popleft()
continue
num_new_tokens_uncached, num_new_tokens_cached = ( num_new_tokens_uncached, num_new_tokens_cached = (
self._get_num_new_uncached_and_cached_tokens( self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.WAITING, enable_chunking, seq_group,
budget)) SequenceStatus.WAITING,
enable_chunking,
budget,
partial_prefill_metadata=partial_prefill_metadata,
))
num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached
if not enable_chunking: if not enable_chunking:
...@@ -948,7 +1090,10 @@ class Scheduler: ...@@ -948,7 +1090,10 @@ class Scheduler:
if num_new_tokens > prompt_limit: if num_new_tokens > prompt_limit:
logger.warning( logger.warning(
"Input prompt (%d tokens) is too long" "Input prompt (%d tokens) is too long"
" and exceeds limit of %d", num_new_tokens, prompt_limit) " and exceeds limit of %d",
num_new_tokens,
prompt_limit,
)
for seq in waiting_seqs: for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group) ignored_seq_groups.append(seq_group)
...@@ -969,7 +1114,9 @@ class Scheduler: ...@@ -969,7 +1114,9 @@ class Scheduler:
logger.warning( logger.warning(
"Input prompt (%d tokens) + lookahead slots (%d) is " "Input prompt (%d tokens) + lookahead slots (%d) is "
"too long and exceeds the capacity of block_manager", "too long and exceeds the capacity of block_manager",
num_new_tokens, num_lookahead_slots) num_new_tokens,
num_lookahead_slots,
)
for seq in waiting_seqs: for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group) ignored_seq_groups.append(seq_group)
...@@ -1010,6 +1157,10 @@ class Scheduler: ...@@ -1010,6 +1157,10 @@ class Scheduler:
waiting_queue.popleft() waiting_queue.popleft()
self._allocate_and_set_running(seq_group) self._allocate_and_set_running(seq_group)
if partial_prefill_metadata is not None:
partial_prefill_metadata.maybe_increment_partial_prefills(
seq_group)
if enable_chunking and self.scheduler_config.is_multi_step: if enable_chunking and self.scheduler_config.is_multi_step:
blocks_to_copy: List[Tuple[int, int]] = [] blocks_to_copy: List[Tuple[int, int]] = []
# init_multi_step_from_lookahead_slots happens in append_slots # init_multi_step_from_lookahead_slots happens in append_slots
...@@ -1025,7 +1176,8 @@ class Scheduler: ...@@ -1025,7 +1176,8 @@ class Scheduler:
num_scheduler_steps=self.scheduler_config. num_scheduler_steps=self.scheduler_config.
num_scheduler_steps, num_scheduler_steps,
is_multi_step=self.scheduler_config.is_multi_step, is_multi_step=self.scheduler_config.is_multi_step,
enable_chunking=enable_chunking) enable_chunking=enable_chunking,
)
seq_groups.append( seq_groups.append(
ScheduledSequenceGroup(seq_group=seq_group, ScheduledSequenceGroup(seq_group=seq_group,
...@@ -1046,11 +1198,12 @@ class Scheduler: ...@@ -1046,11 +1198,12 @@ class Scheduler:
seq_groups=seq_groups, seq_groups=seq_groups,
ignored_seq_groups=ignored_seq_groups, ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots( num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking)) is_prefill=True, enable_chunking=enable_chunking),
)
def _schedule_default(self) -> SchedulerOutputs: def _schedule_default(self) -> SchedulerOutputs:
"""Schedule queued requests. """Schedule queued requests.
The current policy is designed to optimize the throughput. First, The current policy is designed to optimize the throughput. First,
it batches as many prefill requests as possible. And it schedules it batches as many prefill requests as possible. And it schedules
decodes. If there's a pressure on GPU memory, decode requests can decodes. If there's a pressure on GPU memory, decode requests can
...@@ -1066,9 +1219,9 @@ class Scheduler: ...@@ -1066,9 +1219,9 @@ class Scheduler:
for seq_group in self.running: for seq_group in self.running:
budget.add_num_seqs(seq_group.request_id, budget.add_num_seqs(seq_group.request_id,
seq_group.get_max_num_running_seqs()) seq_group.get_max_num_running_seqs())
curr_loras = set( curr_loras = (set(
seq_group.lora_int_id for seq_group in self.running seq_group.lora_int_id for seq_group in self.running
if seq_group.lora_int_id > 0) if self.lora_enabled else None if seq_group.lora_int_id > 0) if self.lora_enabled else None)
prefills = SchedulerPrefillOutputs.create_empty() prefills = SchedulerPrefillOutputs.create_empty()
running_scheduled = SchedulerRunningOutputs.create_empty() running_scheduled = SchedulerRunningOutputs.create_empty()
...@@ -1094,9 +1247,10 @@ class Scheduler: ...@@ -1094,9 +1247,10 @@ class Scheduler:
# If any sequence group is preempted, do not swap in any sequence # If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests. # group. because it means there's no slot for new running requests.
if len(running_scheduled.preempted) + len( if (len(running_scheduled.preempted) +
running_scheduled.swapped_out) == 0: len(running_scheduled.swapped_out) == 0):
swapped_in = self._schedule_swapped(budget, curr_loras) swapped_in = \
self._schedule_swapped(budget, curr_loras)
assert (budget.num_batched_tokens assert (budget.num_batched_tokens
<= self.scheduler_config.max_num_batched_tokens) <= self.scheduler_config.max_num_batched_tokens)
...@@ -1116,8 +1270,8 @@ class Scheduler: ...@@ -1116,8 +1270,8 @@ class Scheduler:
# Update swapped requests. # Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out) self.swapped.extend(running_scheduled.swapped_out)
preempted = (len(running_scheduled.preempted) + preempted = len(running_scheduled.preempted) + len(
len(running_scheduled.swapped_out)) running_scheduled.swapped_out)
# There should be no prefill from running queue because this policy # There should be no prefill from running queue because this policy
# doesn't allow chunked prefills. # doesn't allow chunked prefills.
...@@ -1155,7 +1309,7 @@ class Scheduler: ...@@ -1155,7 +1309,7 @@ class Scheduler:
def _schedule_chunked_prefill(self) -> SchedulerOutputs: def _schedule_chunked_prefill(self) -> SchedulerOutputs:
"""Schedule queued requests. """Schedule queued requests.
Chunked prefill allows to chunk prefill requests, batch them together Chunked prefill allows to chunk prefill requests, batch them together
with decode requests. This policy 1. schedule as many decoding requests with decode requests. This policy 1. schedule as many decoding requests
as possible. 2. schedule chunked prefill requests that are not as possible. 2. schedule chunked prefill requests that are not
...@@ -1176,10 +1330,20 @@ class Scheduler: ...@@ -1176,10 +1330,20 @@ class Scheduler:
prefills = SchedulerPrefillOutputs.create_empty() prefills = SchedulerPrefillOutputs.create_empty()
swapped_in = SchedulerSwappedInOutputs.create_empty() swapped_in = SchedulerSwappedInOutputs.create_empty()
# Create partial prefill metadata
partial_prefill_metadata = PartialPrefillMetadata.from_queues(
running=self.running,
waiting=self.waiting,
scheduler_config=self.scheduler_config,
)
# Decoding should be always scheduled first by fcfs. # Decoding should be always scheduled first by fcfs.
running_scheduled = self._schedule_running(budget, running_scheduled = self._schedule_running(
curr_loras, budget,
enable_chunking=True) curr_loras,
enable_chunking=True,
partial_prefill_metadata=partial_prefill_metadata,
)
# Schedule swapped out requests. # Schedule swapped out requests.
# If preemption happens, it means we don't have space for swap-in. # If preemption happens, it means we don't have space for swap-in.
...@@ -1187,9 +1351,12 @@ class Scheduler: ...@@ -1187,9 +1351,12 @@ class Scheduler:
running_scheduled.swapped_out) == 0: running_scheduled.swapped_out) == 0:
swapped_in = self._schedule_swapped(budget, curr_loras) swapped_in = self._schedule_swapped(budget, curr_loras)
prefills = self._schedule_prefills(budget, prefills = self._schedule_prefills(
curr_loras, budget,
enable_chunking=True) curr_loras,
enable_chunking=True,
partial_prefill_metadata=partial_prefill_metadata,
)
assert (budget.num_batched_tokens assert (budget.num_batched_tokens
<= self.scheduler_config.max_num_batched_tokens) <= self.scheduler_config.max_num_batched_tokens)
...@@ -1208,8 +1375,15 @@ class Scheduler: ...@@ -1208,8 +1375,15 @@ class Scheduler:
[s.seq_group for s in swapped_in.prefill_seq_groups]) [s.seq_group for s in swapped_in.prefill_seq_groups])
self.running.extend( self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups]) [s.seq_group for s in running_scheduled.decode_seq_groups])
# Because multiple prefills may be running concurrently, we need to
# make sure that prefills which are scheduled to finish are listed
# before those that won't. This is so that on the next scheduling
# iteration when they have transitioned to the decode stage, they are
# properly prioritized over sequences that are still in the prefill
# stage.
self.running.extend( self.running.extend(
[s.seq_group for s in running_scheduled.prefill_seq_groups]) self._order_finishing_prefills_first(
running_scheduled.prefill_seq_groups))
self.running.extend([s.seq_group for s in prefills.seq_groups]) self.running.extend([s.seq_group for s in prefills.seq_groups])
# Update swapped requests. # Update swapped requests.
...@@ -1226,7 +1400,7 @@ class Scheduler: ...@@ -1226,7 +1400,7 @@ class Scheduler:
# If all prompts, then we set num_lookahead_slots to 0 # If all prompts, then we set num_lookahead_slots to 0
# this allows us to go through the `no_spec` path in # this allows us to go through the `no_spec` path in
# `spec_decode_worker.py` # `spec_decode_worker.py`
all_prefills = (len(scheduled_seq_groups) == num_prefill_groups) all_prefills = len(scheduled_seq_groups) == num_prefill_groups
num_lookahead_slots = (0 if num_lookahead_slots = (0 if
(all_prefills (all_prefills
and not self.scheduler_config.is_multi_step) and not self.scheduler_config.is_multi_step)
...@@ -1248,6 +1422,21 @@ class Scheduler: ...@@ -1248,6 +1422,21 @@ class Scheduler:
len(running_scheduled.swapped_out)), len(running_scheduled.swapped_out)),
) )
def _order_finishing_prefills_first(
self, scheduled_prefill_seqs: List[ScheduledSequenceGroup]
) -> List[SequenceGroup]:
"""Returns a list of prefilling SequenceGroups where sequences that are
scheduled to finish prefilling are listed first"""
finishing = [
s.seq_group for s in scheduled_prefill_seqs
if s.seq_group.get_num_uncomputed_tokens() == s.token_chunk_size
]
not_finishing = [
s.seq_group for s in scheduled_prefill_seqs
if s.seq_group.get_num_uncomputed_tokens() != s.token_chunk_size
]
return finishing + not_finishing
def _schedule(self) -> SchedulerOutputs: def _schedule(self) -> SchedulerOutputs:
"""Schedule queued requests.""" """Schedule queued requests."""
if self.scheduler_config.chunked_prefill_enabled: if self.scheduler_config.chunked_prefill_enabled:
...@@ -1386,10 +1575,12 @@ class Scheduler: ...@@ -1386,10 +1575,12 @@ class Scheduler:
# between engine and worker. # between engine and worker.
# the subsequent comms can still use delta, but # the subsequent comms can still use delta, but
# `multi_modal_data` will be None. # `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data multi_modal_data=(seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None, if scheduler_outputs.num_prefill_groups
multi_modal_placeholders=seq_group.multi_modal_placeholders > 0 else None),
if scheduler_outputs.num_prefill_groups > 0 else None, multi_modal_placeholders=(
seq_group.multi_modal_placeholders
if scheduler_outputs.num_prefill_groups > 0 else None),
mm_processor_kwargs=seq_group.mm_processor_kwargs, mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request, prompt_adapter_request=seq_group.prompt_adapter_request,
) )
...@@ -1495,10 +1686,12 @@ class Scheduler: ...@@ -1495,10 +1686,12 @@ class Scheduler:
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
seq.status = SequenceStatus.RUNNING seq.status = SequenceStatus.RUNNING
def _append_slots(self, def _append_slots(
seq_group: SequenceGroup, self,
blocks_to_copy: List[Tuple[int, int]], seq_group: SequenceGroup,
enable_chunking: bool = False) -> None: blocks_to_copy: List[Tuple[int, int]],
enable_chunking: bool = False,
) -> None:
"""Appends new slots to the sequences in the given sequence group. """Appends new slots to the sequences in the given sequence group.
Args: Args:
...@@ -1519,7 +1712,8 @@ class Scheduler: ...@@ -1519,7 +1712,8 @@ class Scheduler:
num_lookahead_slots, num_lookahead_slots,
num_scheduler_steps=self.scheduler_config.num_scheduler_steps, num_scheduler_steps=self.scheduler_config.num_scheduler_steps,
is_multi_step=self.scheduler_config.is_multi_step, is_multi_step=self.scheduler_config.is_multi_step,
enable_chunking=enable_chunking) enable_chunking=enable_chunking,
)
seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING
if self.scheduler_config.is_multi_step and enable_chunking: if self.scheduler_config.is_multi_step and enable_chunking:
...@@ -1562,8 +1756,11 @@ class Scheduler: ...@@ -1562,8 +1756,11 @@ class Scheduler:
"not enough KV cache space. This can affect the end-to-end " "not enough KV cache space. This can affect the end-to-end "
"performance. Increase gpu_memory_utilization or " "performance. Increase gpu_memory_utilization or "
"tensor_parallel_size to provide more KV cache memory. " "tensor_parallel_size to provide more KV cache memory. "
"total_num_cumulative_preemption=%d", seq_group.request_id, "total_num_cumulative_preemption=%d",
preemption_mode, self.num_cumulative_preemption + 1) seq_group.request_id,
preemption_mode,
self.num_cumulative_preemption + 1,
)
self.num_cumulative_preemption += 1 self.num_cumulative_preemption += 1
if preemption_mode == PreemptionMode.RECOMPUTE: if preemption_mode == PreemptionMode.RECOMPUTE:
...@@ -1669,6 +1866,7 @@ class Scheduler: ...@@ -1669,6 +1866,7 @@ class Scheduler:
status: SequenceStatus, status: SequenceStatus,
enable_chunking: bool, enable_chunking: bool,
budget: SchedulingBudget, budget: SchedulingBudget,
partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
) -> Tuple[int, int]: ) -> Tuple[int, int]:
""" """
Returns the number of new uncached and cached tokens to schedule for a Returns the number of new uncached and cached tokens to schedule for a
...@@ -1692,6 +1890,8 @@ class Scheduler: ...@@ -1692,6 +1890,8 @@ class Scheduler:
to schedule. to schedule.
enable_chunking: Whether to chunk the number of tokens to compute. enable_chunking: Whether to chunk the number of tokens to compute.
budget: The budget to chunk the number of tokens to compute. budget: The budget to chunk the number of tokens to compute.
partial_prefill_metadata: information about the partial prefills
that are currently running
Returns: Returns:
...@@ -1769,6 +1969,8 @@ class Scheduler: ...@@ -1769,6 +1969,8 @@ class Scheduler:
budget, budget,
self._get_prompt_limit(seq_group), self._get_prompt_limit(seq_group),
num_uncached_new_tokens, num_uncached_new_tokens,
self.partial_prefill_budget_lookup_list,
partial_prefill_metadata,
) )
return num_uncached_new_tokens, num_cached_new_tokens return num_uncached_new_tokens, num_cached_new_tokens
...@@ -1780,6 +1982,8 @@ class Scheduler: ...@@ -1780,6 +1982,8 @@ class Scheduler:
budget: SchedulingBudget, budget: SchedulingBudget,
prompt_limit: int, prompt_limit: int,
num_new_tokens: int, num_new_tokens: int,
partial_prefill_budget_lookup_list: List[int],
partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
) -> int: ) -> int:
""" """
Chunks the number of new tokens to schedule based on the budget when Chunks the number of new tokens to schedule based on the budget when
...@@ -1812,29 +2016,31 @@ class Scheduler: ...@@ -1812,29 +2016,31 @@ class Scheduler:
# the sequence. # the sequence.
return num_new_tokens return num_new_tokens
return (0 if num_new_tokens > remaining_token_budget else return 0 if num_new_tokens > \
num_new_tokens) remaining_token_budget else num_new_tokens
if cache_config.enable_prefix_caching: # Get the number of tokens to allocate to this prefill slot
# Adjust the remaining token budget to be divisible by the block prefill_slot_budget = (
# size when prefix caching is enabled. remaining_token_budget if partial_prefill_metadata is None else
partial_prefill_budget_lookup_list[
partial_prefill_metadata.schedulable_prefills])
# When prefix caching is enabled, we always allocate if cache_config.enable_prefix_caching:
# the number of new tokens that is dividable by the block # When prefix caching is enabled and we're partially prefilling
# size to avoid partial block matching. # a sequence, we always allocate a number of new tokens that is
# divisible by the block size to avoid partial block matching.
block_size = cache_config.block_size block_size = cache_config.block_size
remainder = budget.token_budget % block_size # Don't exceed either the total budget or slot budget.
if remainder != 0: # Take min of those and get the next lowest multiple of the
raise ValueError("When enabling chunked prefill and " # block size:
"prefix caching, max_num_batched_tokens " remaining_token_budget = (
"(chunk size) must be dividable by " min(remaining_token_budget, prefill_slot_budget) //
"block size, but got chunk_size " block_size) * block_size
f"({budget.token_budget}) % block_size " # NB: In the case where num_new_tokens < budget, we are
f"({block_size}) = {remainder}") # finishing prefill for this sequence, so we do not need to
# Round down to block size. # allocate a full block.
remaining_token_budget = (remaining_token_budget // block_size *
block_size) num_new_tokens = min(num_new_tokens, remaining_token_budget,
prefill_slot_budget)
num_new_tokens = min(num_new_tokens, remaining_token_budget)
return num_new_tokens return num_new_tokens
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
# the only successful approach is to call cuda driver API in C. # the only successful approach is to call cuda driver API in C.
import dataclasses import dataclasses
from contextlib import contextmanager from contextlib import contextmanager
from typing import Callable, Dict, Optional, Tuple, Union from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch import torch
...@@ -97,7 +97,7 @@ def use_memory_pool_with_allocator( ...@@ -97,7 +97,7 @@ def use_memory_pool_with_allocator(
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator) mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
with torch.cuda.memory.use_mem_pool(mem_pool): with torch.cuda.memory.use_mem_pool(mem_pool):
yield mem_pool yield mem_pool, new_alloc
class CuMemAllocator: class CuMemAllocator:
...@@ -142,6 +142,7 @@ class CuMemAllocator: ...@@ -142,6 +142,7 @@ class CuMemAllocator:
def __init__(self): def __init__(self):
self.pointer_to_data: Dict[int, AllocationData] = {} self.pointer_to_data: Dict[int, AllocationData] = {}
self.current_tag: str = CuMemAllocator.default_tag self.current_tag: str = CuMemAllocator.default_tag
self.allocator_and_pools: Dict[str, Any] = {}
def python_malloc_callback(self, allocation_handle: HandleType) -> None: def python_malloc_callback(self, allocation_handle: HandleType) -> None:
""" """
...@@ -231,7 +232,13 @@ class CuMemAllocator: ...@@ -231,7 +232,13 @@ class CuMemAllocator:
old_tag = self.current_tag old_tag = self.current_tag
self.current_tag = tag self.current_tag = tag
with use_memory_pool_with_allocator(self.python_malloc_callback, with use_memory_pool_with_allocator(self.python_malloc_callback,
self.python_free_callback): self.python_free_callback) as data:
# start to hit another PyTorch bug in PyTorch 2.6,
# possibly because of gc-related issue w.r.t. the allocator and
# the memory pool.
# to avoid the issue, we keep a reference of the data.
# see https://github.com/pytorch/pytorch/issues/146431 .
self.allocator_and_pools[tag] = data
yield yield
# PyTorch's bug, calling torch.cuda.empty_cache() will error # PyTorch's bug, calling torch.cuda.empty_cache() will error
# when using pluggable allocator, see # when using pluggable allocator, see
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
class DeviceCommunicatorBase:
"""
Base class for device-specific communicator.
It can use the `cpu_group` to initialize the communicator.
If the device has PyTorch integration (PyTorch can recognize its
communication backend), the `device_group` will also be given.
"""
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
self.unique_name = unique_name
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
return input_
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self):
pass
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
from torch.distributed import ProcessGroup
from .base_device_communicator import DeviceCommunicatorBase
class CpuCommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
self.ipex_available = False
self.dist_module = torch.distributed
try:
import intel_extension_for_pytorch as ipex
self.ipex_available = True
self.dist_module = ipex.distributed
except ImportError:
"""
Intel IPEX not found. Falling back to PyTorch native
all_reduce for CPU (e.g. MacOS)
"""
pass
def all_reduce(self, input_):
return self.dist_module.all_reduce(input_, group=self.device_group)
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
from torch.distributed import ProcessGroup
from .base_device_communicator import DeviceCommunicatorBase
class CudaCommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
if "pp" in unique_name:
# pipeline parallel does not need custom allreduce
use_custom_allreduce = False
else:
from vllm.distributed.parallel_state import (
_ENABLE_CUSTOM_ALL_REDUCE)
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
use_pynccl = True
self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
)
self.ca_comm: Optional[CustomAllreduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
)
def all_reduce(self, input_):
# always try custom allreduce first,
# and then pynccl.
ca_comm = self.ca_comm
if ca_comm is not None and not ca_comm.disabled and \
ca_comm.should_custom_ar(input_):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.send(tensor, dst)
else:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.recv(tensor, src)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self):
if self.pynccl_comm is not None:
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
...@@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional ...@@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional
# this line makes it possible to directly load `libcudart.so` using `ctypes` # this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch # noqa import torch # noqa
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -105,8 +106,13 @@ class CudaRTLibrary: ...@@ -105,8 +106,13 @@ class CudaRTLibrary:
def __init__(self, so_file: Optional[str] = None): def __init__(self, so_file: Optional[str] = None):
if so_file is None: if so_file is None:
so_file = find_loaded_library("libcudart") so_file = find_loaded_library("libcudart")
if so_file is None:
so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var
assert so_file is not None, \ assert so_file is not None, \
"libcudart is not loaded in the current process" (
"libcudart is not loaded in the current process, "
"try setting VLLM_CUDART_SO_PATH"
)
if so_file not in CudaRTLibrary.path_to_library_cache: if so_file not in CudaRTLibrary.path_to_library_cache:
lib = ctypes.CDLL(so_file) lib = ctypes.CDLL(so_file)
CudaRTLibrary.path_to_library_cache[so_file] = lib CudaRTLibrary.path_to_library_cache[so_file] = lib
......
...@@ -2,45 +2,40 @@ ...@@ -2,45 +2,40 @@
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .base_device_communicator import DeviceCommunicatorBase
if current_platform.is_hpu(): if current_platform.is_hpu():
import habana_frameworks.torch as htorch # noqa: F401 import habana_frameworks.torch as htorch # noqa: F401
class HpuCommunicator: class HpuCommunicator(DeviceCommunicatorBase):
def __init__(self, group: ProcessGroup):
if not current_platform.is_hpu():
self.disabled = True
return
self.disabled = False
self.group = group
self.world_size = dist.get_world_size(self.group)
def all_reduce(self, x: torch.Tensor) -> torch.Tensor: def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
# (which is required for tensor parallel HPUGraph inference) # (which is required for tensor parallel HPUGraph inference)
htorch.core.mark_step() htorch.core.mark_step()
dist.all_reduce(x, group=self.group) dist.all_reduce(input_, group=self.device_group)
return x return input_
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size world_size = self.world_size
if dim < 0: if dim < 0:
# Convert negative dim to positive. # Convert negative dim to positive.
dim += x.dim() dim += input_.dim()
input_size = x.size() input_size = input_.size()
# Allocate output tensor. # Allocate output tensor.
output_tensor = torch.empty((world_size, ) + input_size, output_tensor = torch.empty((world_size, ) + input_size,
dtype=x.dtype, dtype=input_.dtype,
device=x.device) device=input_.device)
# All-gather. # All-gather.
htorch.core.mark_step() htorch.core.mark_step()
dist.all_gather_into_tensor(output_tensor, x, group=self.group) dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape # Reshape
output_tensor = output_tensor.movedim(0, dim) output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] + output_tensor = output_tensor.reshape(input_size[:dim] +
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import os
from typing import Optional
import torch import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .base_device_communicator import DeviceCommunicatorBase
if current_platform.is_tpu(): if current_platform.is_tpu():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr import torch_xla.runtime as xr
...@@ -16,19 +18,20 @@ if current_platform.is_tpu(): ...@@ -16,19 +18,20 @@ if current_platform.is_tpu():
from vllm.executor import ray_utils from vllm.executor import ray_utils
class TpuCommunicator: class TpuCommunicator(DeviceCommunicatorBase):
def __init__(self, group: ProcessGroup): def __init__(self,
if not current_platform.is_tpu(): cpu_group: ProcessGroup,
self.disabled = True device: Optional[torch.device] = None,
return device_group: Optional[ProcessGroup] = None,
self.disabled = False unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
# NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
# must be used together. Therefore, the local rank and world size can # must be used together. Therefore, the local rank and world size can
# be simply calculated as follows. # be simply calculated as follows.
global_rank = dist.get_rank(group) global_rank = self.global_rank
global_world_size = dist.get_world_size(group) global_world_size = self.global_world_size
# Calculate how many TPU nodes are in the current deployment. This # Calculate how many TPU nodes are in the current deployment. This
# is the Ray placement group if it is deployed with Ray. Default # is the Ray placement group if it is deployed with Ray. Default
...@@ -55,9 +58,9 @@ class TpuCommunicator: ...@@ -55,9 +58,9 @@ class TpuCommunicator:
pjrt.initialize_multiprocess(local_rank, local_world_size) pjrt.initialize_multiprocess(local_rank, local_world_size)
xr._init_world_size_ordinal() xr._init_world_size_ordinal()
def all_reduce(self, x: torch.Tensor) -> torch.Tensor: def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, x) return xm.all_reduce(xm.REDUCE_SUM, input_)
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "TPUs only support dim=-1 for all-gather." assert dim == -1, "TPUs only support dim=-1 for all-gather."
return xm.all_gather(x, dim=dim) return xm.all_gather(input_, dim=dim)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from vllm.platforms import current_platform from .base_device_communicator import DeviceCommunicatorBase
class XpuCommunicator: class XpuCommunicator(DeviceCommunicatorBase):
def __init__(self, group: ProcessGroup): def __init__(self,
if not current_platform.is_xpu(): cpu_group: ProcessGroup,
self.disabled = True device: Optional[torch.device] = None,
return device_group: Optional[ProcessGroup] = None,
self.disabled = False unique_name: str = ""):
self.group = group super().__init__(cpu_group, device, device_group, unique_name)
self.world_size = dist.get_world_size(self.group)
def all_reduce(self, x: torch.Tensor) -> torch.Tensor: def all_reduce(self, input_) -> torch.Tensor:
dist.all_reduce(x, group=self.group) dist.all_reduce(input_, group=self.device_group)
return x return input_
def gather(self, def gather(self,
input_: torch.Tensor, input_: torch.Tensor,
rank_in_group: int,
dst: int = 0, dst: int = 0,
dim: int = -1): dim: int = -1) -> Optional[torch.Tensor]:
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# For xpu path, gather doesn't work properly together with ray # For xpu path, gather doesn't work properly together with ray
# cluster so we use all_gather instead for now. # cluster so we use all_gather instead for now.
input_size = input_.size() input_size = input_.size()
...@@ -34,10 +39,10 @@ class XpuCommunicator: ...@@ -34,10 +39,10 @@ class XpuCommunicator:
dtype=input_.dtype, dtype=input_.dtype,
device=input_.device) device=input_.device)
# All-gather. # All-gather.
torch.distributed.all_gather_into_tensor(output_tensor, dist.all_gather_into_tensor(output_tensor,
input_, input_,
group=self.group) group=self.device_group)
if rank_in_group == dst: if self.rank_in_group == dst:
# Reshape # Reshape
output_tensor = output_tensor.movedim(0, dim) output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] + output_tensor = output_tensor.reshape(input_size[:dim] +
......
...@@ -14,8 +14,8 @@ The KV cache transfer contains three layer of abstractions: ...@@ -14,8 +14,8 @@ The KV cache transfer contains three layer of abstractions:
Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer. Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer.
NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed
communication service already supports key-value-based lookup (like redis or communication service already supports key-value-based lookup (like redis or
RDMA database). RDMA database).
NOTE: If you want to not only transfer KV caches, but adjust the model execution flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates. NOTE: If you want to not only transfer KV caches, but adjust the model execution flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates.
...@@ -27,4 +27,3 @@ The example usage is in [this file](../../../examples/online_serving/disaggregat ...@@ -27,4 +27,3 @@ The example usage is in [this file](../../../examples/online_serving/disaggregat
Here is the diagram of how we run disaggretgated prefilling. Here is the diagram of how we run disaggretgated prefilling.
![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg) ![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg)
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
stop the prefill instance when the decode instance is slow. stop the prefill instance when the decode instance is slow.
""" """
import threading import threading
import time
from collections import deque from collections import deque
from typing import Deque, List, Optional, Union from typing import Deque, List, Optional, Union
...@@ -29,13 +28,13 @@ class SimpleBuffer(KVLookupBufferBase): ...@@ -29,13 +28,13 @@ class SimpleBuffer(KVLookupBufferBase):
def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
buffer_size_thresh: float): buffer_size_thresh: float):
""" """
signal_pipe: on CPU signal_pipe: on CPU
NOTE: on-device recv will block all threads in the process, making the NOTE: on-device recv will block all threads in the process, making the
KV cache producer unable to listen to new request while transmitting KV cache producer unable to listen to new request while transmitting
KV cache. Luckily CPU recv only blocks the current thread so we use KV cache. Luckily CPU recv only blocks the current thread so we use
CPU recv to listen to new request. CPU recv to listen to new request.
data_pipe: on device (e.g. GPU) data_pipe: on device (e.g. GPU)
""" """
...@@ -43,7 +42,7 @@ class SimpleBuffer(KVLookupBufferBase): ...@@ -43,7 +42,7 @@ class SimpleBuffer(KVLookupBufferBase):
self.buffer_size = 0 self.buffer_size = 0
self.buffer_size_threshold = buffer_size_thresh self.buffer_size_threshold = buffer_size_thresh
self.buffer_lock = threading.Lock() self.buffer_cv = threading.Condition()
self.signal_pipe = signal_pipe self.signal_pipe = signal_pipe
self.data_pipe = data_pipe self.data_pipe = data_pipe
self.request_handling_thread: Optional[threading.Thread] = None self.request_handling_thread: Optional[threading.Thread] = None
...@@ -116,11 +115,19 @@ class SimpleBuffer(KVLookupBufferBase): ...@@ -116,11 +115,19 @@ class SimpleBuffer(KVLookupBufferBase):
hidden = hidden.clone() hidden = hidden.clone()
buffer_item = [input_tokens, roi, key, value, hidden] buffer_item = [input_tokens, roi, key, value, hidden]
data_size = sum([self._get_element_size(data) for data in buffer_item])
with self.buffer_cv:
if self.buffer_size + data_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger.debug("KV transfer buffer is full. Handling...")
while self.buffer_size + data_size > self.buffer_size_threshold:
self.buffer_cv.wait()
with self.buffer_lock: self.buffer_size += data_size
for data in buffer_item:
self.buffer_size += self._get_element_size(data)
self.buffer.append(buffer_item) self.buffer.append(buffer_item)
self.buffer_cv.notify()
def _is_end_signal(self, signal): def _is_end_signal(self, signal):
return signal is None return signal is None
...@@ -143,35 +150,31 @@ class SimpleBuffer(KVLookupBufferBase): ...@@ -143,35 +150,31 @@ class SimpleBuffer(KVLookupBufferBase):
roi = (roi > 0.5) roi = (roi > 0.5)
tokens_roi_recver = [input_tokens, roi] tokens_roi_recver = [input_tokens, roi]
matched_length = 0 def is_buffer_available(
tokens_roi_recver: List[torch.Tensor], ) -> bool:
# perform input tokens and roi matching # perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1) # FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so # but this buffer size won't (and shouldn't) be too large so
# the fix is not urgent. # the fix is not urgent.
with self.buffer_lock:
for _ in range(len(self.buffer)): for _ in range(len(self.buffer)):
if self._matches(self.buffer[0],
temp_length = self._matches(self.buffer[0], tokens_roi_recver) > 0:
tokens_roi_recver) return True
if temp_length > 0:
matched_length = temp_length
break
# rotate the element we just accessed to the end # rotate the element we just accessed to the end
self.buffer.rotate(-1) self.buffer.rotate(-1)
return False
if matched_length > 0:
# need to clone the tensor with self.buffer_cv:
# in case the tensor is freed before sending finishes while not is_buffer_available(tokens_roi_recver):
matched_item = self.buffer.popleft() logger.debug(
for tensor in matched_item: "KV transfer buffer is not available. Waiting...")
self._send_tensor_and_dec_size(tensor) self.buffer_cv.wait()
# need to clone the tensor
else: # in case the tensor is freed before sending finishes
# no match, just send None matched_item = self.buffer.popleft()
for _ in range(5): for tensor in matched_item:
self.data_pipe.send_tensor(None) self._send_tensor_and_dec_size(tensor)
self.buffer_cv.notify()
except RuntimeError as e: except RuntimeError as e:
if 'Connection closed by peer' not in str(e): if 'Connection closed by peer' not in str(e):
...@@ -208,20 +211,10 @@ class SimpleBuffer(KVLookupBufferBase): ...@@ -208,20 +211,10 @@ class SimpleBuffer(KVLookupBufferBase):
return [input_tokens, roi, key, value, hidden] return [input_tokens, roi, key, value, hidden]
def full_handler(self):
time.sleep(0.001)
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None: hidden: torch.Tensor) -> None:
if self.buffer_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger.debug("KV transfer buffer is full. Handling...")
while self.buffer_size > self.buffer_size_threshold:
self.full_handler()
self._add_to_buffer(input_tokens, roi, key, value, hidden) self._add_to_buffer(input_tokens, roi, key, value, hidden)
# when calling the insert, the current process is a sender # when calling the insert, the current process is a sender
......
...@@ -39,9 +39,12 @@ from torch.distributed import Backend, ProcessGroup ...@@ -39,9 +39,12 @@ from torch.distributed import Backend, ProcessGroup
import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import direct_register_custom_op, supports_custom_op from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
supports_custom_op)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -130,9 +133,8 @@ class GroupCoordinator: ...@@ -130,9 +133,8 @@ class GroupCoordinator:
PyTorch ProcessGroup is bound to one specific communication backend, PyTorch ProcessGroup is bound to one specific communication backend,
e.g. NCCL, Gloo, MPI, etc. e.g. NCCL, Gloo, MPI, etc.
GroupCoordinator takes charge of all the communication operations among GroupCoordinator takes charge of all the communication operations among
the processes in the group. It can route the communication to the processes in the group. It manages both CPU and device
a specific implementation (e.g. switch allreduce implementation communication.
based on the tensor size and cuda graph mode).
""" """
# available attributes: # available attributes:
...@@ -150,11 +152,8 @@ class GroupCoordinator: ...@@ -150,11 +152,8 @@ class GroupCoordinator:
rank_in_group: int # rank inside the group rank_in_group: int # rank inside the group
cpu_group: ProcessGroup # group for CPU communication cpu_group: ProcessGroup # group for CPU communication
device_group: ProcessGroup # group for device communication device_group: ProcessGroup # group for device communication
use_pynccl: bool # a hint of whether to use PyNccl use_device_communicator: bool # whether to use device communicator
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce device_communicator: DeviceCommunicatorBase # device communicator
# communicators are only created for world size > 1
pynccl_comm: Optional[Any] # PyNccl communicator
ca_comm: Optional[Any] # Custom allreduce communicator
mq_broadcaster: Optional[Any] # shared memory broadcaster mq_broadcaster: Optional[Any] # shared memory broadcaster
def __init__( def __init__(
...@@ -162,11 +161,7 @@ class GroupCoordinator: ...@@ -162,11 +161,7 @@ class GroupCoordinator:
group_ranks: List[List[int]], group_ranks: List[List[int]],
local_rank: int, local_rank: int,
torch_distributed_backend: Union[str, Backend], torch_distributed_backend: Union[str, Backend],
use_pynccl: bool, use_device_communicator: bool,
use_custom_allreduce: bool,
use_tpu_communicator: bool,
use_hpu_communicator: bool,
use_xpu_communicator: bool,
use_message_queue_broadcaster: bool = False, use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None, group_name: Optional[str] = None,
): ):
...@@ -196,56 +191,26 @@ class GroupCoordinator: ...@@ -196,56 +191,26 @@ class GroupCoordinator:
assert self.device_group is not None assert self.device_group is not None
from vllm.platforms import current_platform from vllm.platforms import current_platform
# TODO: fix it for other platforms
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}") self.device = torch.device(f"cuda:{local_rank}")
else: else:
self.device = torch.device("cpu") self.device = torch.device("cpu")
self.use_pynccl = use_pynccl self.use_device_communicator = use_device_communicator
self.use_custom_allreduce = use_custom_allreduce
self.use_tpu_communicator = use_tpu_communicator
self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
)
self.ca_comm: Optional[CustomAllreduce] = None self.device_communicator: DeviceCommunicatorBase = None # type: ignore
if use_custom_allreduce and self.world_size > 1: if use_device_communicator and self.world_size > 1:
# Initialize a custom fast all-reduce implementation. device_comm_cls = resolve_obj_by_qualname(
self.ca_comm = CustomAllreduce( current_platform.get_device_communicator_cls())
group=self.cpu_group, self.device_communicator = device_comm_cls(
cpu_group=self.cpu_group,
device=self.device, device=self.device,
device_group=self.device_group,
unique_name=self.unique_name,
) )
from vllm.distributed.device_communicators.tpu_communicator import (
TpuCommunicator)
self.tpu_communicator: Optional[TpuCommunicator] = None
if use_tpu_communicator and self.world_size > 1:
self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
from vllm.distributed.device_communicators.hpu_communicator import (
HpuCommunicator)
self.hpu_communicator: Optional[HpuCommunicator]
if use_hpu_communicator and self.world_size > 1:
self.hpu_communicator = HpuCommunicator(group=self.device_group)
from vllm.distributed.device_communicators.xpu_communicator import (
XpuCommunicator)
self.xpu_communicator: Optional[XpuCommunicator]
if use_xpu_communicator and self.world_size > 1:
self.xpu_communicator = XpuCommunicator(group=self.device_group)
from vllm.distributed.device_communicators.shm_broadcast import ( from vllm.distributed.device_communicators.shm_broadcast import (
MessageQueue) MessageQueue)
self.mq_broadcaster: Optional[MessageQueue] = None self.mq_broadcaster: Optional[MessageQueue] = None
...@@ -253,6 +218,9 @@ class GroupCoordinator: ...@@ -253,6 +218,9 @@ class GroupCoordinator:
self.mq_broadcaster = MessageQueue.create_from_process_group( self.mq_broadcaster = MessageQueue.create_from_process_group(
self.cpu_group, 1 << 22, 6) self.cpu_group, 1 << 22, 6)
from vllm.platforms import current_platform
self.use_custom_op_call = current_platform.is_cuda_alike()
@property @property
def first_rank(self): def first_rank(self):
"""Return the global rank of the first process in the group""" """Return the global rank of the first process in the group"""
...@@ -296,9 +264,16 @@ class GroupCoordinator: ...@@ -296,9 +264,16 @@ class GroupCoordinator:
else: else:
stream = graph_capture_context.stream stream = graph_capture_context.stream
ca_comm = self.ca_comm # only cuda uses this function,
maybe_ca_context = nullcontext( # so we don't abstract it into the base class
) if ca_comm is None else ca_comm.capture() maybe_ca_context = nullcontext()
from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator)
if self.device_communicator is not None:
assert isinstance(self.device_communicator, CudaCommunicator)
ca_comm = self.device_communicator.ca_comm
if ca_comm is not None:
maybe_ca_context = ca_comm.capture() # type: ignore
# ensure all initialization operations complete before attempting to # ensure all initialization operations complete before attempting to
# capture the graph on another stream # capture the graph on another stream
...@@ -328,54 +303,14 @@ class GroupCoordinator: ...@@ -328,54 +303,14 @@ class GroupCoordinator:
if self.world_size == 1: if self.world_size == 1:
return input_ return input_
if input_.is_cpu: if self.use_custom_op_call:
try: return torch.ops.vllm.all_reduce(input_,
import intel_extension_for_pytorch as ipex group_name=self.unique_name)
ipex.distributed.all_reduce(input_, group=self.device_group) else:
return input_ return self._all_reduce_out_place(input_)
except ImportError:
"""
Intel IPEX not found. Falling back to PyTorch native
all_reduce for CPU
"""
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
if self.tpu_communicator is not None and \
not self.tpu_communicator.disabled:
# TPU handles Dynamo with its own logic.
return self.tpu_communicator.all_reduce(input_)
if self.hpu_communicator is not None and \
not self.hpu_communicator.disabled:
return self.hpu_communicator.all_reduce(input_)
if self.xpu_communicator is not None and \
not self.xpu_communicator.disabled:
return self.xpu_communicator.all_reduce(input_)
return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
# always try custom allreduce first, return self.device_communicator.all_reduce(input_)
# and then pynccl.
ca_comm = self.ca_comm
if ca_comm is not None and not ca_comm.disabled and \
ca_comm.should_custom_ar(input_):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size world_size = self.world_size
...@@ -385,40 +320,7 @@ class GroupCoordinator: ...@@ -385,40 +320,7 @@ class GroupCoordinator:
assert -input_.dim() <= dim < input_.dim(), ( assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
# For TPUs, use TPU communicator. return self.device_communicator.all_gather(input_, dim)
tpu_comm = self.tpu_communicator
if tpu_comm is not None and not tpu_comm.disabled:
return tpu_comm.all_gather(input_, dim)
# For HPUs, use HPU communicator.
hpu_comm = self.hpu_communicator
if hpu_comm is not None and not hpu_comm.disabled:
return hpu_comm.all_gather(input_, dim)
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
torch.distributed.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
def gather(self, def gather(self,
input_: torch.Tensor, input_: torch.Tensor,
...@@ -433,30 +335,7 @@ class GroupCoordinator: ...@@ -433,30 +335,7 @@ class GroupCoordinator:
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if world_size == 1: if world_size == 1:
return input_ return input_
assert -input_.dim() <= dim < input_.dim(), ( return self.device_communicator.gather(input_, dst, dim)
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
if self.xpu_communicator is not None and \
not self.xpu_communicator.disabled:
return self.xpu_communicator.gather(input_, self.rank_in_group,
dst, dim)
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def broadcast(self, input_: torch.Tensor, src: int = 0): def broadcast(self, input_: torch.Tensor, src: int = 0):
"""Broadcast the input tensor. """Broadcast the input tensor.
...@@ -798,14 +677,7 @@ class GroupCoordinator: ...@@ -798,14 +677,7 @@ class GroupCoordinator:
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way""" """Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank.""" """NOTE: `dst` is the local rank of the destination rank."""
if dst is None: self.device_communicator.send(tensor, dst)
dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.send(tensor, dst)
else:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self, def recv(self,
size: torch.Size, size: torch.Size,
...@@ -813,16 +685,7 @@ class GroupCoordinator: ...@@ -813,16 +685,7 @@ class GroupCoordinator:
src: Optional[int] = None) -> torch.Tensor: src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank.""" """Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank.""" """NOTE: `src` is the local rank of the source rank."""
if src is None: return self.device_communicator.recv(size, dtype, src)
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.recv(tensor, src)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self): def destroy(self):
if self.device_group is not None: if self.device_group is not None:
...@@ -831,10 +694,8 @@ class GroupCoordinator: ...@@ -831,10 +694,8 @@ class GroupCoordinator:
if self.cpu_group is not None: if self.cpu_group is not None:
torch.distributed.destroy_process_group(self.cpu_group) torch.distributed.destroy_process_group(self.cpu_group)
self.cpu_group = None self.cpu_group = None
if self.pynccl_comm is not None: if self.device_communicator is not None:
self.pynccl_comm = None self.device_communicator.destroy()
if self.ca_comm is not None:
self.ca_comm = None
if self.mq_broadcaster is not None: if self.mq_broadcaster is not None:
self.mq_broadcaster = None self.mq_broadcaster = None
...@@ -853,11 +714,7 @@ def init_world_group(ranks: List[int], local_rank: int, ...@@ -853,11 +714,7 @@ def init_world_group(ranks: List[int], local_rank: int,
group_ranks=[ranks], group_ranks=[ranks],
local_rank=local_rank, local_rank=local_rank,
torch_distributed_backend=backend, torch_distributed_backend=backend,
use_pynccl=False, use_device_communicator=False,
use_custom_allreduce=False,
use_tpu_communicator=False,
use_hpu_communicator=False,
use_xpu_communicator=False,
group_name="world", group_name="world",
) )
...@@ -866,23 +723,15 @@ def init_model_parallel_group( ...@@ -866,23 +723,15 @@ def init_model_parallel_group(
group_ranks: List[List[int]], group_ranks: List[List[int]],
local_rank: int, local_rank: int,
backend: str, backend: str,
use_custom_allreduce: Optional[bool] = None,
use_message_queue_broadcaster: bool = False, use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None, group_name: Optional[str] = None,
) -> GroupCoordinator: ) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
from vllm.platforms import current_platform
return GroupCoordinator( return GroupCoordinator(
group_ranks=group_ranks, group_ranks=group_ranks,
local_rank=local_rank, local_rank=local_rank,
torch_distributed_backend=backend, torch_distributed_backend=backend,
use_pynccl=current_platform.is_cuda_alike(), use_device_communicator=True,
use_custom_allreduce=current_platform.is_cuda_alike()
and use_custom_allreduce,
use_tpu_communicator=True,
use_hpu_communicator=True,
use_xpu_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster, use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name, group_name=group_name,
) )
...@@ -1024,13 +873,6 @@ def initialize_model_parallel( ...@@ -1024,13 +873,6 @@ def initialize_model_parallel(
backend = backend or torch.distributed.get_backend( backend = backend or torch.distributed.get_backend(
get_world_group().device_group) get_world_group().device_group)
if (world_size
!= tensor_model_parallel_size * pipeline_model_parallel_size):
raise RuntimeError(
f"world_size ({world_size}) is not equal to "
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
# Build the tensor model-parallel groups. # Build the tensor model-parallel groups.
num_tensor_model_parallel_groups: int = (world_size // num_tensor_model_parallel_groups: int = (world_size //
tensor_model_parallel_size) tensor_model_parallel_size)
...@@ -1060,11 +902,9 @@ def initialize_model_parallel( ...@@ -1060,11 +902,9 @@ def initialize_model_parallel(
for i in range(num_pipeline_model_parallel_groups): for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group_ranks.append(ranks) group_ranks.append(ranks)
# pipeline parallel does not need custom allreduce
_PP = init_model_parallel_group(group_ranks, _PP = init_model_parallel_group(group_ranks,
get_world_group().local_rank, get_world_group().local_rank,
backend, backend,
use_custom_allreduce=False,
group_name="pp") group_name="pp")
......
...@@ -20,6 +20,7 @@ from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat, ...@@ -20,6 +20,7 @@ from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.plugins import load_general_plugins
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, StoreBoolean from vllm.utils import FlexibleArgumentParser, StoreBoolean
...@@ -119,6 +120,9 @@ class EngineArgs: ...@@ -119,6 +120,9 @@ class EngineArgs:
cpu_offload_gb: float = 0 # GiB cpu_offload_gb: float = 0 # GiB
gpu_memory_utilization: float = 0.90 gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None max_num_batched_tokens: Optional[int] = None
max_num_partial_prefills: Optional[int] = 1
max_long_partial_prefills: Optional[int] = 1
long_prefill_token_threshold: Optional[int] = 0
max_num_seqs: Optional[int] = None max_num_seqs: Optional[int] = None
max_logprobs: int = 20 # Default value for OpenAI Chat Completions API max_logprobs: int = 20 # Default value for OpenAI Chat Completions API
disable_log_stats: bool = False disable_log_stats: bool = False
...@@ -191,6 +195,7 @@ class EngineArgs: ...@@ -191,6 +195,7 @@ class EngineArgs:
collect_detailed_traces: Optional[str] = None collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False disable_async_output_proc: bool = False
scheduling_policy: Literal["fcfs", "priority"] = "fcfs" scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
override_neuron_config: Optional[Dict[str, Any]] = None override_neuron_config: Optional[Dict[str, Any]] = None
override_pooler_config: Optional[PoolerConfig] = None override_pooler_config: Optional[PoolerConfig] = None
...@@ -206,6 +211,7 @@ class EngineArgs: ...@@ -206,6 +211,7 @@ class EngineArgs:
calculate_kv_scales: Optional[bool] = None calculate_kv_scales: Optional[bool] = None
additional_config: Optional[Dict[str, Any]] = None
moe_ep_size: int = 1 moe_ep_size: int = 1
def __post_init__(self): def __post_init__(self):
...@@ -286,11 +292,13 @@ class EngineArgs: ...@@ -286,11 +292,13 @@ class EngineArgs:
'--tokenizer-mode', '--tokenizer-mode',
type=str, type=str,
default=EngineArgs.tokenizer_mode, default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow', 'mistral'], choices=['auto', 'slow', 'mistral', 'custom'],
help='The tokenizer mode.\n\n* "auto" will use the ' help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will ' 'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* ' 'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.') '"mistral" will always use the `mistral_common` tokenizer. \n* '
'"custom" will use --tokenizer to select the '
'preregistered tokenizer.')
parser.add_argument('--trust-remote-code', parser.add_argument('--trust-remote-code',
action='store_true', action='store_true',
help='Trust remote code from huggingface.') help='Trust remote code from huggingface.')
...@@ -520,6 +528,31 @@ class EngineArgs: ...@@ -520,6 +528,31 @@ class EngineArgs:
default=EngineArgs.max_num_batched_tokens, default=EngineArgs.max_num_batched_tokens,
help='Maximum number of batched tokens per ' help='Maximum number of batched tokens per '
'iteration.') 'iteration.')
parser.add_argument(
"--max-num-partial-prefills",
type=int,
default=EngineArgs.max_num_partial_prefills,
help="For chunked prefill, the max number of concurrent \
partial prefills."
"Defaults to 1",
)
parser.add_argument(
"--max-long-partial-prefills",
type=int,
default=EngineArgs.max_long_partial_prefills,
help="For chunked prefill, the maximum number of prompts longer "
"than --long-prefill-token-threshold that will be prefilled "
"concurrently. Setting this less than --max-num-partial-prefills "
"will allow shorter prompts to jump the queue in front of longer "
"prompts in some cases, improving latency. Defaults to 1.")
parser.add_argument(
"--long-prefill-token-threshold",
type=float,
default=EngineArgs.long_prefill_token_threshold,
help="For chunked prefill, a request is considered long if the "
"prompt is longer than this number of tokens. Defaults to 4%% of "
"the model's context length.",
)
parser.add_argument('--max-num-seqs', parser.add_argument('--max-num-seqs',
type=int, type=int,
default=EngineArgs.max_num_seqs, default=EngineArgs.max_num_seqs,
...@@ -929,6 +962,13 @@ class EngineArgs: ...@@ -929,6 +962,13 @@ class EngineArgs:
'priority (lower value means earlier handling) and time of ' 'priority (lower value means earlier handling) and time of '
'arrival deciding any ties).') 'arrival deciding any ties).')
parser.add_argument(
'--scheduler-cls',
default=EngineArgs.scheduler_cls,
help='The scheduler class to use. "vllm.core.scheduler.Scheduler" '
'is the default scheduler. Can be a class directly or the path to '
'a class of form "mod.custom_class".')
parser.add_argument( parser.add_argument(
'--override-neuron-config', '--override-neuron-config',
type=json.loads, type=json.loads,
...@@ -1008,6 +1048,14 @@ class EngineArgs: ...@@ -1008,6 +1048,14 @@ class EngineArgs:
'be loaded from the model checkpoint if available. ' 'be loaded from the model checkpoint if available. '
'Otherwise, the scales will default to 1.0.') 'Otherwise, the scales will default to 1.0.')
parser.add_argument(
"--additional-config",
type=json.loads,
default=None,
help="Additional config for specified platform in JSON format. "
"Different platforms may support different configs. Make sure the "
"configs are valid for the platform you are using. The input format"
" is like '{\"config_key\":\"config_value\"}'")
return parser return parser
@classmethod @classmethod
...@@ -1068,6 +1116,9 @@ class EngineArgs: ...@@ -1068,6 +1116,9 @@ class EngineArgs:
def create_engine_config(self, def create_engine_config(self,
usage_context: Optional[UsageContext] = None usage_context: Optional[UsageContext] = None
) -> VllmConfig: ) -> VllmConfig:
from vllm.platforms import current_platform
current_platform.pre_register_and_update()
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
self._override_v1_engine_args(usage_context) self._override_v1_engine_args(usage_context)
...@@ -1254,7 +1305,13 @@ class EngineArgs: ...@@ -1254,7 +1305,13 @@ class EngineArgs:
multi_step_stream_outputs=self.multi_step_stream_outputs, multi_step_stream_outputs=self.multi_step_stream_outputs,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray), and parallel_config.use_ray),
policy=self.scheduling_policy) policy=self.scheduling_policy,
scheduler_cls=self.scheduler_cls,
max_num_partial_prefills=self.max_num_partial_prefills,
max_long_partial_prefills=self.max_long_partial_prefills,
long_prefill_token_threshold=self.long_prefill_token_threshold,
)
lora_config = LoRAConfig( lora_config = LoRAConfig(
bias_enabled=self.enable_lora_bias, bias_enabled=self.enable_lora_bias,
max_lora_rank=self.max_lora_rank, max_lora_rank=self.max_lora_rank,
...@@ -1315,6 +1372,7 @@ class EngineArgs: ...@@ -1315,6 +1372,7 @@ class EngineArgs:
prompt_adapter_config=prompt_adapter_config, prompt_adapter_config=prompt_adapter_config,
compilation_config=self.compilation_config, compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config, kv_transfer_config=self.kv_transfer_config,
additional_config=self.additional_config,
) )
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
...@@ -1375,6 +1433,12 @@ class AsyncEngineArgs(EngineArgs): ...@@ -1375,6 +1433,12 @@ class AsyncEngineArgs(EngineArgs):
parser.add_argument('--disable-log-requests', parser.add_argument('--disable-log-requests',
action='store_true', action='store_true',
help='Disable logging requests.') help='Disable logging requests.')
# Initialize plugin to update the parser, for example, The plugin may
# adding a new kind of quantization method to --quantization argument or
# a new device to --device argument.
load_general_plugins()
from vllm.platforms import current_platform
current_platform.pre_register_and_update(parser)
return parser return parser
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment