Unverified Commit 4ac8e09d authored by Yuwei An, committed by GitHub

Piecewise CUDA Graph Support & Torch Compile Backend (#10062)


Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
parent 20a6c0a6
......@@ -185,7 +185,7 @@ class CustomAllreduce:
# is enough for 131072 such tuples. The largest model I've seen only
# needs less than 10000 of registered tuples.
self.rank_data = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
max_size, dtype=torch.uint8, device=self.device
)
self._ptr = ops.init_custom_ar(
self.meta_ptrs, self.rank_data, rank, self.full_nvlink
......@@ -202,7 +202,7 @@ class CustomAllreduce:
)
handles, offsets = self._gather_ipc_meta(shard_data)
self.rank_data = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
max_size, dtype=torch.uint8, device=self.device
)
self._ptr = ops.init_custom_ar(
self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
......
......@@ -239,6 +239,7 @@ class GroupCoordinator:
use_npu_communicator: bool,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
torch_compile: Optional[bool] = None,
):
# Set group info
group_name = group_name or "anonymous"
......@@ -326,10 +327,18 @@ class GroupCoordinator:
self.qr_comm: Optional[QuickAllReduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
if torch_compile is not None and torch_compile:
# For piecewise CUDA graphs, the custom allreduce buffer must be larger
# to avoid illegal CUDA memory access.
ca_max_size = 256 * 1024 * 1024
else:
ca_max_size = 8 * 1024 * 1024
try:
# print(f"ca_max_size: {ca_max_size}")
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
max_size=ca_max_size,
)
except Exception as e:
logger.warning(
......@@ -1260,6 +1269,7 @@ def init_model_parallel_group(
group_name: Optional[str] = None,
use_mscclpp_allreduce: Optional[bool] = None,
use_symm_mem_allreduce: Optional[bool] = None,
torch_compile: Optional[bool] = None,
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
......@@ -1280,6 +1290,7 @@ def init_model_parallel_group(
use_npu_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
torch_compile=torch_compile,
)
......@@ -1439,6 +1450,7 @@ def initialize_model_parallel(
pipeline_model_parallel_size: int = 1,
backend: Optional[str] = None,
duplicate_tp_group: bool = False,
torch_compile: Optional[bool] = None,
) -> None:
"""
Initialize model parallel groups.
......@@ -1494,6 +1506,7 @@ def initialize_model_parallel(
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
),
group_name="tp",
torch_compile=torch_compile,
)
if duplicate_tp_group:
......@@ -1509,6 +1522,7 @@ def initialize_model_parallel(
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
),
group_name="pdmux_prefill_tp",
torch_compile=torch_compile,
)
_TP.pynccl_comm.disabled = False
_PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
......@@ -1518,7 +1532,6 @@ def initialize_model_parallel(
global _MOE_EP
assert _MOE_EP is None, "expert model parallel group is already initialized"
if moe_ep_size == tensor_model_parallel_size:
_MOE_EP = _TP
else:
......@@ -1539,7 +1552,6 @@ def initialize_model_parallel(
global _MOE_TP
assert _MOE_TP is None, "expert model parallel group is already initialized"
if moe_tp_size == tensor_model_parallel_size:
_MOE_TP = _TP
else:
......
......@@ -43,11 +43,16 @@ _is_cpu = is_cpu()
_is_xpu = is_xpu()
if _is_cuda:
if _is_flashinfer_available:
from flashinfer.norm import fused_add_rmsnorm
else:
from sgl_kernel import fused_add_rmsnorm
from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
# if _is_flashinfer_available:
# from flashinfer.norm import fused_add_rmsnorm
# else:
from sgl_kernel import (
fused_add_rmsnorm,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
rmsnorm,
)
if _use_aiter:
from aiter import rmsnorm2d_fwd as rms_norm
......
......@@ -17,12 +17,18 @@ from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Optional
import torch
from torch import nn
if TYPE_CHECKING:
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.compilation.piecewise_context_manager import (
get_forward_context,
)
from sglang.srt.utils import direct_register_custom_op
class AttentionType(Enum):
"""
......@@ -105,12 +111,58 @@ class RadixAttention(nn.Module):
else:
k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
return forward_batch.attn_backend.forward(
q,
k,
v,
self,
forward_batch,
save_kv_cache,
**kwargs,
)
if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
output = torch.zeros_like(q)
torch.ops.sglang.unified_attention_with_output(
q, k, v, output, save_kv_cache, self.layer_id
)
return output
else:
return forward_batch.attn_backend.forward(
q,
k,
v,
self,
forward_batch,
save_kv_cache,
**kwargs,
)
def unified_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
save_kv_cache: bool,
layer_id: int,
) -> None:
context = get_forward_context()
forward_batch = context.forward_batch
attention_layers = context.attention_layers
attention_layer = attention_layers[layer_id]
ret = forward_batch.attn_backend.forward(
query, key, value, attention_layer, forward_batch, save_kv_cache
)
assert output.shape == ret.shape
output.copy_(ret)
return
def unified_attention_with_output_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
save_kv_cache: bool,
layer_id: int,
) -> None:
return
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["output"],
fake_impl=unified_attention_with_output_fake,
)
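# Note on the registration above: exposing attention as a custom op with a
# no-op fake implementation lets torch.compile trace through the model without
# looking inside the attention backend. The SGLangBackend defined below then
# uses "sglang.unified_attention_with_output" as its graph-splitting point, so
# every piece between attention calls can be compiled and CUDA-graph captured
# on its own.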
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backend.py
import ast
import dataclasses
import logging
import os
import pprint
import time
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any, Callable, Optional
import torch
import torch.fx as fx
from torch._dispatch.python import enable_python_dispatcher
from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig
from sglang.srt.model_executor.compilation.compilation_counter import (
compilation_counter,
)
from sglang.srt.model_executor.compilation.compiler_interface import InductorAdaptor
from sglang.srt.model_executor.compilation.cuda_piecewise_backend import (
CUDAPiecewiseBackend,
)
from sglang.srt.model_executor.compilation.pass_manager import PostGradPassManager
logger = logging.getLogger(__name__)
def make_compiler():
return InductorAdaptor()
class CompilerManager:
def __init__(
self,
):
self.cache = dict()
self.is_cache_updated = False
self.compiler = make_compiler()
def compute_hash(self):
return self.compiler.compute_hash()
def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
):
self.disable_cache = disable_cache
self.cache_dir = cache_dir
self.cache_file_path = os.path.join(cache_dir, "sglang_compile_cache.py")
if not disable_cache and os.path.exists(self.cache_file_path):
with open(self.cache_file_path) as f:
self.cache = ast.literal_eval(f.read())
self.compiler.initialize_cache(
cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
)
def save_to_file(self):
if self.disable_cache or not self.is_cache_updated:
return
printer = pprint.PrettyPrinter(indent=4)
data = printer.pformat(self.cache)
with open(self.cache_file_path, "w") as f:
f.write(data)
def load(
self,
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
runtime_shape: Optional[int] = None,
) -> Optional[Callable]:
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
compiled_graph = self.compiler.load(
handle, graph, example_inputs, graph_index, runtime_shape
)
if runtime_shape is None:
logger.debug(
"Directly load the %s-th graph for dynamic shape from %s via "
"handle %s",
graph_index,
self.compiler.name,
handle,
)
else:
logger.debug(
"Directly load the %s-th graph for shape %s from %s via " "handle %s",
graph_index,
str(runtime_shape),
self.compiler.name,
handle,
)
return compiled_graph
def compile(
self,
graph: fx.GraphModule,
example_inputs,
inductor_config: dict[str, Any],
graph_index: int = 0,
num_graphs: int = 1,
runtime_shape: Optional[int] = None,
) -> Any:
if graph_index == 0:
# before compiling the first graph, record the start time
global compilation_start_time
compilation_start_time = time.time()
compilation_counter.num_backend_compilations += 1
compiled_graph = None
# TODO(Yuwei): support cache loading
# no compiler cached the graph, or the cache is disabled,
# we need to compile it
if isinstance(self.compiler, InductorAdaptor):
maybe_key = None
else:
maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
compiled_graph, handle = self.compiler.compile(
graph, example_inputs, inductor_config, runtime_shape, maybe_key
)
assert compiled_graph is not None, "Failed to compile the graph"
# store the artifact in the cache
if handle is not None:
self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
compilation_counter.num_cache_entries_updated += 1
self.is_cache_updated = True
if graph_index == 0:
# adds some info logging for the first graph
if runtime_shape is None:
logger.info("Cache the graph for dynamic shape for later use")
else:
logger.info(
"Cache the graph of shape %s for later use", str(runtime_shape)
)
if runtime_shape is None:
logger.debug(
"Store the %s-th graph for dynamic shape from %s via " "handle %s",
graph_index,
self.compiler.name,
handle,
)
else:
logger.debug(
"Store the %s-th graph for shape %s from %s via handle %s",
graph_index,
str(runtime_shape),
self.compiler.name,
handle,
)
# after compiling the last graph, record the end time
if graph_index == num_graphs - 1:
now = time.time()
elapsed = now - compilation_start_time
if runtime_shape is None:
logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
else:
logger.info(
"Compiling a graph for shape %s takes %.2f s",
runtime_shape,
elapsed,
)
return compiled_graph
@dataclasses.dataclass
class SplitItem:
submod_name: str
graph_id: int
is_splitting_graph: bool
graph: fx.GraphModule
def split_graph(
graph: fx.GraphModule, ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
# split graph by ops
subgraph_id = 0
node_to_subgraph_id = {}
split_op_graphs = []
for node in graph.graph.nodes:
if node.op in ("output", "placeholder"):
continue
if node.op == "call_function" and str(node.target) in ops:
subgraph_id += 1
node_to_subgraph_id[node] = subgraph_id
split_op_graphs.append(subgraph_id)
subgraph_id += 1
else:
node_to_subgraph_id[node] = subgraph_id
# `keep_original_order` is important!
# otherwise pytorch might reorder the nodes and
# the semantics of the graph will change when we
# have mutations in the graph
split_gm = torch.fx.passes.split_module.split_module(
graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
)
outputs = []
names = [name for (name, module) in split_gm.named_modules()]
for name in names:
if "." in name or name == "":
# recursive child module or the root module
continue
module = getattr(split_gm, name)
graph_id = int(name.replace("submod_", ""))
outputs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module))
# sort by integer graph_id, rather than string name
outputs.sort(key=lambda x: x.graph_id)
return split_gm, outputs
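# Illustration (not part of the commit): for a model with two attention calls,
# split_module produces submodules roughly like
#   submod_0  -> compiled and CUDA-graph captured
#   submod_1  -> the attention op (is_splitting_graph=True, left eager)
#   submod_2  -> compiled and CUDA-graph captured
#   ...
# PiecewiseCompileInterpreter below compiles only the non-splitting submodules.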
# we share the global graph pool among all the backends
global_graph_pool = None
compilation_start_time = 0.0
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
def __init__(
self,
module: torch.fx.GraphModule,
compile_submod_names: list[str],
inductor_config: dict[str, Any],
graph_pool,
compile_config: CompilationConfig,
sglang_backend: "SGLangBackend",
):
super().__init__(module)
from torch._guards import detect_fake_mode
self.fake_mode = detect_fake_mode()
self.compile_submod_names = compile_submod_names
self.graph_pool = graph_pool
self.sglang_backend = sglang_backend
# When True, it annoyingly dumps the torch.fx.Graph on errors.
self.extra_traceback = False
self.inductor_config = inductor_config
self.compile_config = compile_config
def run(self, *args):
fake_args = [
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
for t in args
]
with self.fake_mode, enable_python_dispatcher():
return super().run(*fake_args)
def call_module(
self,
target: torch.fx.node.Target,
args: tuple[torch.fx.node.Argument, ...],
kwargs: dict[str, Any],
) -> Any:
assert isinstance(target, str)
output = super().call_module(target, args, kwargs)
if target in self.compile_submod_names:
index = self.compile_submod_names.index(target)
submod = self.fetch_attr(target)
sym_shape_indices = [
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
]
global compilation_start_time
compiled_graph_for_dynamic_shape = (
self.sglang_backend.compiler_manager.compile(
submod,
args,
self.inductor_config,
graph_index=index,
num_graphs=len(self.compile_submod_names),
runtime_shape=None,
)
)
self.module.__dict__[target] = CUDAPiecewiseBackend(
submod,
self.compile_config,
self.inductor_config,
self.graph_pool,
index,
len(self.compile_submod_names),
sym_shape_indices,
compiled_graph_for_dynamic_shape,
self.sglang_backend,
)
compilation_counter.num_piecewise_capturable_graphs_seen += 1
return output
model_tag: str = "backbone"
@contextmanager
def set_model_tag(tag: str):
"""Context manager to set the model tag."""
global model_tag
assert (
tag != model_tag
), f"Model tag {tag} is the same as the current tag {model_tag}."
old_tag = model_tag
model_tag = tag
try:
yield
finally:
model_tag = old_tag
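# Usage sketch (illustrative; the tag and variable names are hypothetical).
# Compiling a second model part under its own tag keeps its artifacts in a
# separate `rank_{rank}_{dp_rank}/<tag>` cache sub-directory:
#
#   with set_model_tag("eagle_head"):
#       install_torch_compiled(draft_model, compile_config=cfg, graph_pool=pool)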
class SGLangBackend:
graph_pool: Any
_called: bool = False
# the graph we compiled
graph: fx.GraphModule
# the stitching graph module for all the piecewise graphs
split_gm: fx.GraphModule
piecewise_graphs: list[SplitItem]
returned_callable: Callable
# Inductor passes to run on the graph pre-defunctionalization
post_grad_passes: Sequence[Callable]
sym_tensor_indices: list[int]
input_buffers: list[torch.Tensor]
compiler_manager: CompilerManager
def __init__(
self,
config: CompilationConfig,
graph_pool: Any,
):
assert graph_pool is not None
self.graph_pool = graph_pool
self.post_grad_pass_manager = PostGradPassManager()
self.sym_tensor_indices = []
self.input_buffers = []
self.compiler_manager = CompilerManager()
self.inductor_config = {
"enable_auto_functionalized_v2": False,
}
self.compile_config = config
def configure_post_pass(self):
self.post_grad_pass_manager.configure()
self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
base_cache_dir = os.path.expanduser(
os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/")
)
cache_hash = self.compiler_manager.compute_hash()
cache_dir = os.path.join(
base_cache_dir,
"torch_compile_cache",
cache_hash,
)
os.makedirs(cache_dir, exist_ok=True)
rank = 0
dp_rank = 0
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", model_tag)
os.makedirs(local_cache_dir, exist_ok=True)
self.compiler_manager.initialize_cache(
local_cache_dir, disable_cache=False, prefix=""
)
compilation_counter.num_graphs_seen += 1
assert not self._called, "SGLangBackend can only be called once"
self.graph = graph
self.configure_post_pass()
self.split_gm, self.piecewise_graphs = split_graph(
graph, ["sglang.unified_attention_with_output"]
)
from torch._dynamo.utils import lazy_format_graph_code
# depyf will hook lazy_format_graph_code and dump the graph
# for debugging, no need to print the graph here
lazy_format_graph_code("before split", self.graph)
lazy_format_graph_code("after split", self.split_gm)
compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
submod_names_to_compile = [
item.submod_name
for item in self.piecewise_graphs
if not item.is_splitting_graph
]
PiecewiseCompileInterpreter(
self.split_gm,
submod_names_to_compile,
self.inductor_config,
self.graph_pool,
self.compile_config,
self,
).run(*example_inputs)
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
if not os.path.exists(graph_path):
# code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
# use `print_readable` because it can include submodules
src = (
"from __future__ import annotations\nimport torch\n"
+ self.split_gm.print_readable(print_output=False)
)
src = src.replace("<lambda>", "GraphModule")
with open(graph_path, "w") as f:
f.write(src)
logger.debug("Computation graph saved to %s", graph_path)
self._called = True
return self.split_gm
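# Usage sketch (illustrative; `model` is any nn.Module whose forward contains
# the splitting op). install_torch_compiled, defined later in this change,
# wires this up automatically; the manual equivalent is roughly:
#
#   graph_pool = torch.cuda.graph_pool_handle()
#   backend = SGLangBackend(CompilationConfig(capture_sizes=[1, 2, 4, 8]), graph_pool)
#   compiled_forward = torch.compile(model.forward, fullgraph=True, backend=backend)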
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py
from typing import List
# TODO(Yuwei): improve compile config support
class CompilationConfig:
def __init__(self, capture_sizes: List[int]):
self.traced_files = set()
self.capture_sizes = capture_sizes
def add_traced_file(self, file_path: str):
self.traced_files.add(file_path)
def get_traced_files(self):
return self.traced_files
def get_capture_sizes(self):
return self.capture_sizes
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_counter.py
import copy
import dataclasses
from contextlib import contextmanager
@dataclasses.dataclass
class CompilationCounter:
num_models_seen: int = 0
num_graphs_seen: int = 0
# including the splitting ops
num_piecewise_graphs_seen: int = 0
# not including the splitting ops
num_piecewise_capturable_graphs_seen: int = 0
num_backend_compilations: int = 0
# Number of gpu_model_runner attempts to trigger CUDAGraphs capture
num_gpu_runner_capture_triggers: int = 0
# Number of CUDAGraphs captured
num_cudagraph_captured: int = 0
# InductorAdapter.compile calls
num_inductor_compiles: int = 0
# EagerAdapter.compile calls
num_eager_compiles: int = 0
# The number of times the compiler cache entry was updated
num_cache_entries_updated: int = 0
# The number of standalone_compile compiled artifacts saved
num_compiled_artifacts_saved: int = 0
# Number of times a model was loaded with CompilationLevel.DYNAMO_AS_IS
dynamo_as_is_count: int = 0
def clone(self) -> "CompilationCounter":
return copy.deepcopy(self)
@contextmanager
def expect(self, **kwargs):
old = self.clone()
yield
for k, v in kwargs.items():
assert getattr(self, k) - getattr(old, k) == v, (
f"{k} not as expected, before it is {getattr(old, k)}"
f", after it is {getattr(self, k)}, "
f"expected diff is {v}"
)
compilation_counter = CompilationCounter()
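# Usage sketch (illustrative, e.g. in a test; `run_one_forward` is hypothetical):
#
#   with compilation_counter.expect(num_graphs_seen=1, num_cudagraph_captured=4):
#       run_one_forward()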
import contextvars
import inspect
import logging
import os
import sys
import types
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
import torch
from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig
logger = logging.getLogger(__name__)
_COMPILE_ENABLED = contextvars.ContextVar("_COMPILE_ENABLED", default=False)
@contextmanager
def set_compiled(enabled: bool = True):
token = _COMPILE_ENABLED.set(enabled)
try:
yield
finally:
_COMPILE_ENABLED.reset(token)
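# Usage sketch (illustrative; argument names are hypothetical). The trampoline
# installed by install_torch_compiled routes into the compiled forward only
# while this flag is set:
#
#   with set_compiled(True):
#       out = model(input_ids, positions, forward_batch)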
@dataclass
class IntermediateTensors:
"""For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request.
Each stage also needs to handle its own finished_sending and
finished_recving in case of kv transfer.
"""
tensors: dict[str, torch.Tensor]
# [req_ids]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
def __init__(self, tensors):
# manually define this function, so that
# Dynamo knows `IntermediateTensors()` comes from this file.
# Otherwise, dataclass will generate this function by evaluating
# a string, and we will lose the information about the source file.
self.tensors = tensors
def __getitem__(self, key: Union[str, slice]):
if isinstance(key, str):
return self.tensors[key]
elif isinstance(key, slice):
return self.__class__({k: v[key] for k, v in self.tensors.items()})
def __setitem__(self, key: str, value: torch.Tensor):
self.tensors[key] = value
def items(self):
return self.tensors.items()
def __len__(self):
return len(self.tensors)
def __eq__(self, other: object):
return isinstance(other, self.__class__) and self.tensors == other.tensors
def __repr__(self) -> str:
return f"IntermediateTensors(tensors={self.tensors})"
def _normalize_dims(dims, ndim: int):
dims = [dims] if isinstance(dims, int) else list(dims)
return [d if d >= 0 else ndim + d for d in dims]
class _MaybeIntermediateTensors:
"""Duck-typed check to support your IntermediateTensors without importing."""
def __init__(self, obj):
self.is_intermediate = hasattr(obj, "tensors") and isinstance(
getattr(obj, "tensors"), dict
)
self.obj = obj
def _mark_dynamic_on_value(val, dims):
if isinstance(val, torch.Tensor):
torch._dynamo.mark_dynamic(val, _normalize_dims(dims, val.ndim))
else:
mit = _MaybeIntermediateTensors(val)
if mit.is_intermediate:
for t in mit.obj.tensors.values():
torch._dynamo.mark_dynamic(t, _normalize_dims(dims, t.ndim))
# else: ignore (None or non-tensor)
def _infer_dynamic_arg_dims_from_annotations(forward_fn):
sig = inspect.signature(forward_fn)
dyn = {}
for name, p in sig.parameters.items():
ann = p.annotation
# Accept torch.Tensor / Optional[torch.Tensor] / IntermediateTensors types by name
if (
ann is torch.Tensor
or getattr(getattr(ann, "__args__", [None])[0], "__name__", "") == "Tensor"
):
dyn[name] = 0
elif getattr(ann, "__name__", "") in ("IntermediateTensors",) or any(
getattr(a, "__name__", "") == "IntermediateTensors"
for a in getattr(ann, "__args__", [])
):
dyn[name] = 0
if not dyn:
raise ValueError("No dynamic dims inferred; pass dynamic_arg_dims explicitly.")
return dyn
def install_torch_compiled(
module: torch.nn.Module,
*,
dynamic_arg_dims: dict[str, Union[int, list[int]]] | None = None,
backend_factory: Optional[Callable[[torch.fx.GraphModule, list], Callable]] = None,
compile_config: Optional[CompilationConfig] = None,
fullgraph: bool = True,
graph_pool: Any = None,
):
unbound_fwd = module.__class__.forward
if not callable(unbound_fwd):
raise TypeError("module.__class__.forward must be callable")
original_code = unbound_fwd.__code__
dyn_map = dynamic_arg_dims or _infer_dynamic_arg_dims_from_annotations(unbound_fwd)
if backend_factory is None:
from sglang.srt.model_executor.compilation.backend import SGLangBackend
backend_factory = lambda gm, ex: SGLangBackend(compile_config, graph_pool)(
gm, ex
)
compiled_codes: list[type(original_code)] = []
state = {"compiled": False, "compiled_callable": None}
def bytecode_hook(old_code, new_code):
if old_code is not original_code:
return
frame = sys._getframe()
while frame and frame.f_back:
frame = frame.f_back
if (
frame.f_code.co_name == "_compile"
and os.path.basename(frame.f_code.co_filename) == "convert_frame.py"
):
break
try:
dynamo_frame = frame.f_locals["frame"]
except Exception:
return
if dynamo_frame.f_code is not old_code:
return
if dynamo_frame.f_locals.get("self") is not module:
return
compiled_codes.append(new_code)
torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook)
def _ensure_compiled(self, *args, **kwargs):
"""Compile on first use (with flag ON)."""
if state["compiled"]:
return
# Mark dynamic dims only when we are about to compile
sig = inspect.signature(unbound_fwd)
ba = sig.bind(self, *args, **kwargs)
ba.apply_defaults()
for name, dims in (dyn_map or {}).items():
if name in ba.arguments:
val = ba.arguments[name]
if val is not None:
_mark_dynamic_on_value(val, dims)
# Avoid cross-instance cache reuse
torch._dynamo.eval_frame.remove_from_cache(unbound_fwd.__code__)
bound = types.MethodType(unbound_fwd, self)
compiled_callable = torch.compile(
bound, fullgraph=fullgraph, backend=backend_factory
)
# Trigger Dynamo so bytecode hook can capture
compiled_callable(*args, **kwargs)
state["compiled"] = True
state["compiled_callable"] = compiled_callable
def trampoline(self, *args, **kwargs):
use_compiled = _COMPILE_ENABLED.get()
if use_compiled:
if not state["compiled"]:
_ensure_compiled(self, *args, **kwargs)
compiled_callable = state["compiled_callable"]
return compiled_callable(*args, **kwargs)
else:
# Explicitly run the original uncompiled forward
return unbound_fwd(self, *args, **kwargs)
module.forward = types.MethodType(trampoline, module)
return module
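# Usage sketch (illustrative; tensor names and capture sizes are hypothetical):
#
#   install_torch_compiled(
#       model,
#       dynamic_arg_dims={"input_ids": 0, "positions": 0},
#       compile_config=CompilationConfig(capture_sizes=[1, 2, 4, 8]),
#       graph_pool=torch.cuda.graph_pool_handle(),
#   )
#   with set_compiled(True):
#       hidden_states = model(input_ids, positions, forward_batch)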
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compiler_interface.py
import contextlib
import copy
import hashlib
import os
from contextlib import ExitStack
from typing import Any, Callable, Optional
from unittest.mock import patch
import torch
import torch._inductor.compile_fx
import torch.fx as fx
from sglang.srt.model_executor.compilation.compilation_counter import (
compilation_counter,
)
from sglang.srt.model_executor.compilation.inductor_pass import pass_context
class CompilerInterface:
"""
The interface for a compiler that can be used by SGLang.
"""
# The name of the compiler, e.g. inductor.
# This is a class-level attribute.
name: str
def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
):
"""
When the SGLang process uses `cache_dir` as the cache directory,
the compiler should initialize itself with that directory,
e.g. by redirecting its own cache directory to a sub-directory.
`prefix` can be combined with `cache_dir` to recover the base cache
directory, e.g. when multiple parts of the model are compiled but
should share the same cache directory. For example:
cache_dir = "/path/to/dir/backbone", prefix = "backbone"
cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
"""
pass
def compute_hash(self) -> str:
"""
Gather all the relevant information needed to compute a hash so that
we can cache the compiled model. This function should only consider
information that is specific to the compiler.
"""
return ""
def compile(
self,
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
"""
Compile the graph with the given example inputs and compiler config,
with a runtime shape. If the `runtime_shape` is None, it means
the `example_inputs` have a dynamic shape. Otherwise, the
`runtime_shape` specifies the shape of the inputs. Right now we only
support one variable shape for all inputs, which is the batchsize
(number of tokens) during inference.
Dynamo will make sure `graph(*example_inputs)` is valid.
The function should return a compiled callable function, as well as
a handle that can be used to directly load the compiled function.
The handle should be a plain Python object, preferably a string or a
file path for readability.
If the compiler doesn't support caching, it should return None for the
handle. If the compiler fails to compile the graph, it should return
None for the compiled function as well.
`key` is required for StandaloneInductorAdaptor; it specifies where to
save the compiled artifact. The compiled artifact gets saved to
`cache_dir/key`.
"""
return None, None
def load(
self,
handle: Any,
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
runtime_shape: Optional[int] = None,
) -> Callable:
"""
Load the compiled function from the handle.
Raises an error if the handle is invalid.
The handle is the second return value of the `compile` function.
"""
raise NotImplementedError("caching is not supported")
def get_inductor_factors() -> list[Any]:
factors: list[Any] = []
# summarize system state
from torch._inductor.codecache import CacheBase
system_factors = CacheBase.get_system()
factors.append(system_factors)
# summarize pytorch state
from torch._inductor.codecache import torch_key
torch_factors = torch_key()
factors.append(torch_factors)
return factors
class AlwaysHitShapeEnv:
"""
Why do we need this class:
For normal `torch.compile` usage, every compilation will have
one Dynamo bytecode compilation and one Inductor compilation.
The Inductor compilation happens under the context of the
Dynamo bytecode compilation, and that context is used to
determine the dynamic shape information, etc.
For our use case, we only run Dynamo bytecode compilation once,
and run Inductor compilation multiple times with different shapes
plus a general shape. The compilation for specific shapes happens
outside of the context of the Dynamo bytecode compilation. At that
time, we don't have a shape environment to provide to Inductor, and
it will fail the Inductor code cache lookup.
By providing a dummy shape environment that always hits, we can
make the Inductor code cache lookup always hit, and we can
compile the graph for different shapes as needed.
The following dummy methods are obtained by trial-and-error
until it works.
"""
def __init__(self) -> None:
self.guards: list[Any] = []
def evaluate_guards_expression(self, *args, **kwargs):
return True
def get_pruned_guards(self, *args, **kwargs):
return []
def produce_guards_expression(self, *args, **kwargs):
return ""
class InductorAdaptor(CompilerInterface):
"""
The adaptor for the Inductor compiler (PyTorch 2.5, 2.6, and 2.7).
"""
name = "inductor"
def compute_hash(self) -> str:
factors = get_inductor_factors()
hash_str = hashlib.md5(
str(factors).encode(), usedforsecurity=False
).hexdigest()[:10]
return hash_str
def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
):
self.cache_dir = cache_dir
self.prefix = prefix
self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
if disable_cache:
return
# redirect the cache directory to a sub-directory
# set flags so that Inductor and Triton store their cache
# in the cache_dir, then users only need to copy the cache_dir
# to another machine to reuse the cache.
inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
os.makedirs(inductor_cache, exist_ok=True)
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
os.makedirs(triton_cache, exist_ok=True)
os.environ["TRITON_CACHE_DIR"] = triton_cache
def compile(
self,
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
compilation_counter.num_inductor_compiles += 1
from torch._inductor.compile_fx import compile_fx
current_config = {}
if compiler_config is not None:
current_config.update(compiler_config)
# disable remote cache
current_config["fx_graph_cache"] = True
current_config["fx_graph_remote_cache"] = False
set_inductor_config(current_config, runtime_shape)
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
graph = copy.deepcopy(graph)
# it's the first time we compile this graph
# the assumption is that we don't have nested Inductor compilation.
# compiled_fx_graph_hash will only be called once, and we can hook
# it to get the hash of the compiled graph directly.
hash_str, file_path = None, None
from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash
if torch.__version__.startswith("2.5"):
original_load = FxGraphCache.load
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
def hijack_load(*args, **kwargs):
inductor_compiled_graph = original_load(*args, **kwargs)
nonlocal file_path
compiled_fn = inductor_compiled_graph.current_callable
file_path = compiled_fn.__code__.co_filename # noqa
if not file_path.startswith(self.base_cache_dir):
# hooked in the align_inputs_from_check_idxs function
# in torch/_inductor/utils.py
for cell in compiled_fn.__closure__:
if not callable(cell.cell_contents):
continue
if cell.cell_contents.__code__.co_filename.startswith(
self.base_cache_dir
):
# this is the real file path compiled from Inductor
file_path = cell.cell_contents.__code__.co_filename
break
return inductor_compiled_graph
hijacked_compile_fx_inner = (
torch._inductor.compile_fx.compile_fx_inner
) # noqa
elif torch.__version__ >= "2.6":
# function renamed in 2.6
original_load_name = None
def hijacked_compile_fx_inner(*args, **kwargs):
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
nonlocal hash_str
inductor_compiled_graph = output
if inductor_compiled_graph is not None:
nonlocal file_path
compiled_fn = inductor_compiled_graph.current_callable
file_path = compiled_fn.__code__.co_filename # noqa
if not file_path.startswith(self.base_cache_dir):
# hooked in the align_inputs_from_check_idxs function
# in torch/_inductor/utils.py
for cell in compiled_fn.__closure__:
if not callable(cell.cell_contents):
continue
code = cell.cell_contents.__code__
if code.co_filename.startswith(self.base_cache_dir):
# this is the real file path
# compiled from Inductor
file_path = code.co_filename
break
hash_str = inductor_compiled_graph._fx_graph_cache_key
return output
def hijack_compiled_fx_graph_hash(*args, **kwargs):
out = compiled_fx_graph_hash(*args, **kwargs)
nonlocal hash_str
hash_str = out[0]
return out
def _check_can_cache(*args, **kwargs):
# no error means it can be cached.
# Inductor refuses to cache the graph outside of Dynamo
# tracing context, and also disables caching for graphs
# with high-order ops.
# For SGLang, in either case, we want to cache the graph.
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
return
def _get_shape_env() -> AlwaysHitShapeEnv:
return AlwaysHitShapeEnv()
with ExitStack() as stack:
# hijack to get the compiled graph itself
if original_load_name is not None:
stack.enter_context(patch(original_load_name, hijack_load))
# for hijacking the hash of the compiled graph
stack.enter_context(
patch(
"torch._inductor.codecache.compiled_fx_graph_hash",
hijack_compiled_fx_graph_hash,
)
)
# for providing a dummy shape environment
stack.enter_context(
patch(
"torch._inductor.codecache.FxGraphCache._get_shape_env",
_get_shape_env,
)
)
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
if hasattr(AOTAutogradCache, "_get_shape_env"):
stack.enter_context(
patch(
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
_get_shape_env,
)
)
# for forcing the graph to be cached
stack.enter_context(
patch(
"torch._inductor.codecache.FxGraphCache._check_can_cache",
_check_can_cache,
)
)
# Dynamo metrics context, see method for more details.
stack.enter_context(self.metrics_context())
# Disable remote caching. When these are on, on remote cache-hit,
# the monkey-patched functions never actually get called.
# We currently assume and require the monkey-patched functions to
# get hit.
# TODO(zou3519): we're going to replace this all with
# standalone_compile sometime.
stack.enter_context(
torch._inductor.config.patch(fx_graph_remote_cache=False)
)
# InductorAdaptor (unfortunately) requires AOTAutogradCache
# to be turned off to run. It will fail to acquire the hash_str
# and error if not.
# StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
stack.enter_context(
torch._functorch.config.patch(enable_autograd_cache=False)
)
stack.enter_context(
torch._functorch.config.patch(enable_remote_autograd_cache=False)
)
with pass_context(runtime_shape):
compiled_graph = compile_fx(
graph,
example_inputs,
inner_compile=hijacked_compile_fx_inner,
config_patches=current_config,
)
return compiled_graph, (hash_str, file_path)
def load(
self,
handle: Any,
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
runtime_shape: Optional[int] = None,
) -> Callable:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
assert isinstance(handle[1], str)
hash_str = handle[0]
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._inductor.codecache import FxGraphCache
with ExitStack() as exit_stack:
exit_stack.enter_context(
patch(
"torch._inductor.codecache.FxGraphCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv(),
)
)
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
if hasattr(AOTAutogradCache, "_get_shape_env"):
exit_stack.enter_context(
patch(
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv(),
)
)
# Dynamo metrics context, see method for more details.
exit_stack.enter_context(self.metrics_context())
if torch.__version__.startswith("2.5"):
inductor_compiled_graph = FxGraphCache._lookup_graph(
hash_str, example_inputs, True, False
)
assert inductor_compiled_graph is not None, (
"Inductor cache lookup failed. Please remove"
f"the cache directory and try again." # noqa
)
elif torch.__version__ >= "2.6":
from torch._inductor.output_code import CompiledFxGraphConstantsWithGm
constants = CompiledFxGraphConstantsWithGm(graph)
inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
hash_str, example_inputs, True, None, constants
)
assert inductor_compiled_graph is not None, (
"Inductor cache lookup failed. Please remove"
f"the cache directory and try again." # noqa
)
# Inductor calling convention (function signature):
# f(list) -> tuple
# Dynamo calling convention (function signature):
# f(*args) -> Any
# need to know if the graph returns a tuple
from torch._inductor.compile_fx import graph_returns_tuple
returns_tuple = graph_returns_tuple(graph)
# this is the callable we return to Dynamo to run
def compiled_graph(*args):
# convert args to list
list_args = list(args)
graph_output = inductor_compiled_graph(list_args)
# unpack the tuple if needed
if returns_tuple:
return graph_output
else:
return graph_output[0]
return compiled_graph
def metrics_context(self) -> contextlib.AbstractContextManager:
"""
This method returns the Dynamo metrics context (if it exists,
otherwise a null context). It is used by various compile components.
Present in torch>=2.6, it's used inside FxGraphCache in
torch==2.6 (but not after). It might also be used in various other
torch.compile internal functions.
Because it is re-entrant, we always set it (even if entering via Dynamo
and the context was already entered). We might want to revisit if it
should be set at a different level of compilation.
This is likely a bug in PyTorch: public APIs should not rely on
manually setting up internal contexts. But we also rely on non-public
APIs which might not provide these guarantees.
"""
import torch._dynamo.utils
return torch._dynamo.utils.get_metrics_context()
def set_inductor_config(config, runtime_shape):
if isinstance(runtime_shape, int):
# for a specific batchsize, tuning triton kernel parameters
# can be beneficial
config["max_autotune"] = True
config["coordinate_descent_tuning"] = True
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/cuda_piecewise_backend.py
import dataclasses
import logging
from contextlib import ExitStack
from typing import Any, Callable, Optional, Union
from unittest.mock import patch
import torch
import torch.fx as fx
import sglang.srt.model_executor.compilation.weak_ref_tensor_jit
from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig
from sglang.srt.model_executor.compilation.compilation_counter import (
compilation_counter,
)
logger = logging.getLogger(__name__)
def weak_ref_tensor(tensor: Any) -> Any:
"""
Create a weak reference to a tensor.
The new tensor will share the same data as the original tensor,
but will not keep the original tensor alive.
"""
if isinstance(tensor, torch.Tensor):
# TODO(yuwei): introduce weak_ref_tensor from sgl_kernel
return torch.ops.jit_weak_ref_tensor.weak_ref_tensor(tensor)
return tensor
def weak_ref_tensors(
tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
"""
Convenience function to create weak references to tensors,
for single tensor, list of tensors or tuple of tensors.
"""
if isinstance(tensors, torch.Tensor):
return weak_ref_tensor(tensors)
if isinstance(tensors, list):
return [weak_ref_tensor(t) for t in tensors]
if isinstance(tensors, tuple):
return tuple(weak_ref_tensor(t) for t in tensors)
raise ValueError("Invalid type for tensors")
@dataclasses.dataclass
class ConcreteSizeEntry:
runtime_shape: int
need_to_compile: bool # the size is in compile_sizes
use_cudagraph: bool # the size is in cudagraph_capture_sizes
compiled: bool = False
runnable: Callable = None # type: ignore
num_finished_warmup: int = 0
cudagraph: Optional[torch.cuda.CUDAGraph] = None
output: Optional[Any] = None
# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None
class CUDAPiecewiseBackend:
def __init__(
self,
graph: fx.GraphModule,
compile_config: CompilationConfig,
inductor_config: dict[str, Any],
graph_pool: Any,
piecewise_compile_index: int,
total_piecewise_compiles: int,
sym_shape_indices: list[int],
compiled_graph_for_general_shape: Callable,
sglang_backend,
):
"""
The backend for piecewise compilation.
It mainly handles the compilation and cudagraph capturing.
We will compile `self.graph` once for the general shape,
and then compile for different shapes specified in
`compilation_config.compile_sizes`.
Independently, we will capture cudagraph for different shapes.
If a shape needs both compilation and cudagraph, we will
compile it first, and then capture cudagraph.
"""
self.graph = graph
self.inductor_config = inductor_config
self.graph_pool = graph_pool
self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles
self.sglang_backend = sglang_backend
self.is_first_graph = piecewise_compile_index == 0
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
self.compile_sizes: set[int] = set([])
self.compile_config = compile_config
self.cudagraph_capture_sizes: set[int] = set(compile_config.get_capture_sizes())
self.first_run_finished = False
self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa
self.sym_shape_indices = sym_shape_indices
self.is_debugging_mode = True
# the entries for different shapes that we need to either
# compile or capture cudagraph
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
# to_be_compiled_sizes tracks the remaining sizes to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape,
need_to_compile=shape in self.compile_sizes,
use_cudagraph=shape in self.cudagraph_capture_sizes,
)
def check_for_ending_compilation(self):
if self.is_last_graph and not self.to_be_compiled_sizes:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self.sglang_backend.compiler_manager.save_to_file()
def __call__(self, *args) -> Any:
if not self.first_run_finished:
self.first_run_finished = True
self.check_for_ending_compilation()
return self.compiled_graph_for_general_shape(*args)
runtime_shape = args[self.sym_shape_indices[0]]
if runtime_shape not in self.concrete_size_entries:
# we don't need to do anything for this shape
return self.compiled_graph_for_general_shape(*args)
entry = self.concrete_size_entries[runtime_shape]
if entry.runnable is None:
entry.runnable = self.compiled_graph_for_general_shape
if entry.need_to_compile and not entry.compiled:
entry.compiled = True
self.to_be_compiled_sizes.remove(runtime_shape)
# args are real arguments
entry.runnable = self.sglang_backend.compiler_manager.compile(
self.graph,
args,
self.inductor_config,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
runtime_shape=runtime_shape,
)
# finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes:
self.check_for_ending_compilation()
# Skip CUDA graphs if this entry doesn't use them OR
# if we're supposed to skip them globally
# skip_cuda_graphs = get_forward_context().skip_cuda_graphs
# if not entry.use_cudagraph or skip_cuda_graphs:
# return entry.runnable(*args)
if entry.cudagraph is None:
if entry.num_finished_warmup < 1: # noqa
entry.num_finished_warmup += 1
return entry.runnable(*args)
input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()
with ExitStack() as stack:
if not self.is_first_graph:
# during every model forward, we will capture
# many pieces of cudagraphs (roughly one per layer).
# running gc again and again across layers will
# make the cudagraph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))
# Note: the reference and memory management below is subtle; be careful.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
output = entry.runnable(*args)
if self.is_last_graph:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph, because the output of the last graph
# will not be used by any other cuda graph.
output = weak_ref_tensors(output)
# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
entry.cudagraph = cudagraph
compilation_counter.num_cudagraph_captured += 1
# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during cuda graph capture
return output
if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
"Input addresses for cudagraphs are different during replay."
f" Expected {entry.input_addresses}, got {new_input_addresses}"
)
entry.cudagraph.replay()
return entry.output
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fix_functionalization.py
import logging
import operator
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from sglang.srt.model_executor.compilation.fx_utils import is_func
from sglang.srt.model_executor.compilation.inductor_pass import SGLangInductorPass
logger = logging.getLogger(__name__)
class FixFunctionalizationPass(SGLangInductorPass):
"""
This pass defunctionalizes certain nodes to avoid redundant tensor copies.
After this pass, DCE (dead-code elimination) should never be run,
as de-functionalized nodes may appear as dead code.
To add new nodes to defunctionalize, add to the if-elif chain in __call__.
"""
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before_fix_functionalization")
self.nodes_to_remove: list[torch.fx.Node] = []
count = 0
for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue # Avoid deep if-elif nesting
count += 1
self.dump_graph(graph, "before_fix_functionalization_cleanup")
# Remove the nodes all at once
count_removed = len(self.nodes_to_remove)
for node in self.nodes_to_remove:
graph.erase_node(node)
logger.debug(
"De-functionalized %s nodes, removed %s nodes", count, count_removed
)
self.dump_graph(graph, "after_fix_functionalization")
self.end_and_log()
def _remove(self, node_or_nodes: Union[torch.fx.Node, Iterable[torch.fx.Node]]):
"""
Stage a node (or nodes) for removal at the end of the pass.
"""
if isinstance(node_or_nodes, torch.fx.Node):
self.nodes_to_remove.append(node_or_nodes)
else:
self.nodes_to_remove.extend(node_or_nodes)
def defunctionalize(
self,
graph: torch.fx.Graph,
node: torch.fx.Node,
mutated_args: dict[int, Union[torch.fx.Node, str]],
args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
):
"""
De-functionalize a node by replacing it with a call to the original.
It also replaces the getitem users with the mutated arguments.
See replace_users_with_mutated_args and insert_defunctionalized.
"""
self.replace_users_with_mutated_args(node, mutated_args)
self.insert_defunctionalized(graph, node, args=args)
self._remove(node)
def replace_users_with_mutated_args(
self, node: torch.fx.Node, mutated_args: dict[int, Union[torch.fx.Node, str]]
):
"""
Replace all getitem users of the auto-functionalized node with the
mutated arguments.
:param node: The auto-functionalized node
:param mutated_args: The mutated arguments, indexed by getitem index.
If the value of an arg is a string, `node.kwargs[arg]` is used.
"""
for idx, user in self.getitem_users(node).items():
arg = mutated_args[idx]
arg = node.kwargs[arg] if isinstance(arg, str) else arg
user.replace_all_uses_with(arg)
self._remove(user)
def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
"""
Returns the operator.getitem users of the auto-functionalized node,
indexed by the index they are getting.
"""
users = {}
for user in node.users:
if is_func(user, operator.getitem):
idx = user.args[1]
users[idx] = user
return users
def insert_defunctionalized(
self,
graph: torch.fx.Graph,
node: torch.fx.Node,
args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
):
"""
Insert a new defunctionalized node into the graph before node.
If one of the kwargs is 'out', provide args directly,
as node.kwargs cannot be used.
See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
:param graph: Graph to insert the defunctionalized node into
:param node: The auto-functionalized node to defunctionalize
:param args: If we cannot use kwargs, specify args directly.
If an arg is a string, `node.kwargs[arg]` is used.
""" # noqa: E501
assert is_func(
node, auto_functionalized
), f"node must be auto-functionalized, is {node} instead"
# Create a new call to the original function
with graph.inserting_before(node):
function = node.args[0]
if args is None:
graph.call_function(function, kwargs=node.kwargs)
else:
# Args passed as strings refer to items in node.kwargs
args = tuple(
node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
)
graph.call_function(function, args=args)
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fx_utils.py
import operator
from collections.abc import Iterable, Iterator
from typing import Optional
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._ops import OpOverload
def is_func(node: fx.Node, target) -> bool:
return node.op == "call_function" and node.target == target
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
return is_func(node, auto_functionalized) and node.args[0] == op
# Returns the first specified node with the given op (if it exists)
def find_specified_fn_maybe(
nodes: Iterable[fx.Node], op: OpOverload
) -> Optional[fx.Node]:
for node in nodes:
if node.target == op:
return node
return None
# Returns the first specified node with the given op
def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
node = find_specified_fn_maybe(nodes, op)
assert node is not None, f"Could not find {op} in nodes {nodes}"
return node
# Returns the first auto_functionalized node with the given op (if it exists)
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]:
for node in nodes:
if is_func(node, auto_functionalized) and node.args[0] == op: # noqa
return node
return None
# Returns the first auto_functionalized node with the given op
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
node = find_auto_fn_maybe(nodes, op)
assert node is not None, f"Could not find {op} in nodes {nodes}"
return node
# Returns the getitem node that extracts the idx-th element from node
# (if it exists)
def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
for user in node.users:
if is_func(user, operator.getitem) and user.args[1] == idx:
return user
return None
# Returns the getitem node that extracts the idx-th element from node
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
ret = find_getitem_maybe(node, idx)
assert ret is not None, f"Could not find getitem {idx} in node {node}"
return ret
# An auto-functionalization-aware utility for finding nodes with a specific op
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
if not op._schema.is_mutable:
yield from graph.find_nodes(op="call_function", target=op)
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
if n.args[0] == op:
yield n
# Asserts that the node only has one user and returns it
# Even if a node has only 1 user, it might share storage with another node,
# which might need to be taken into account.
def get_only_user(node: fx.Node) -> fx.Node:
assert len(node.users) == 1
return next(iter(node.users))
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/inductor_pass.py
import hashlib
import inspect
import json
import logging
import time
import types
from contextlib import contextmanager
from typing import Any, Callable, Optional, Union
import torch
from torch import fx
from torch._dynamo.utils import lazy_format_graph_code
from torch._inductor.custom_graph_pass import CustomGraphPass
logger = logging.getLogger(__name__)
_pass_context = None
class PassContext:
def __init__(self, runtime_shape: Optional[int]):
self.runtime_shape = runtime_shape
def get_pass_context() -> PassContext:
"""Get the current pass context."""
assert _pass_context is not None
return _pass_context
@contextmanager
def pass_context(runtime_shape: Optional[int]):
"""A context manager that stores the current pass context,
usually it is a list of sizes to specialize.
"""
global _pass_context
prev_context = _pass_context
_pass_context = PassContext(runtime_shape)
try:
yield
finally:
_pass_context = prev_context
class InductorPass(CustomGraphPass):
"""
A custom graph pass that uses a hash of its source as the UUID.
This is defined as a convenience and should work in most cases.
"""
def uuid(self) -> Any:
"""
Provide a unique identifier for the pass, used in Inductor code cache.
This should depend on the pass implementation, so that changes to the
pass result in recompilation.
By default, the object source is hashed.
"""
return InductorPass.hash_source(self)
@staticmethod
def hash_source(*srcs: Union[str, Any]):
"""
Utility method to hash the sources of functions or objects.
:param srcs: strings or objects to add to the hash.
Objects and functions have their source inspected.
:return:
"""
hasher = hashlib.sha256()
for src in srcs:
if isinstance(src, str):
src_str = src
elif isinstance(src, types.FunctionType):
src_str = inspect.getsource(src)
else:
src_str = inspect.getsource(src.__class__)
hasher.update(src_str.encode("utf-8"))
return hasher.hexdigest()
@staticmethod
def hash_dict(dict_: dict[Any, Any]):
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary.
"""
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()
def is_applicable_for_shape(self, shape: Optional[int]):
return True
class CallableInductorPass(InductorPass):
"""
This class is a wrapper for a callable that automatically provides an
implementation of the UUID.
"""
def __init__(
self, callable: Callable[[fx.Graph], None], uuid: Optional[Any] = None
):
self.callable = callable
self._uuid = self.hash_source(callable) if uuid is None else uuid
def __call__(self, graph: torch.fx.Graph):
self.callable(graph)
def uuid(self) -> Any:
return self._uuid
class SGLangInductorPass(InductorPass):
def __init__(
self,
):
self.pass_name = self.__class__.__name__
def dump_graph(self, graph: torch.fx.Graph, stage: str):
lazy_format_graph_code(stage, graph.owning_module)
def begin(self):
self._start_time = time.perf_counter_ns()
def end_and_log(self):
self._end_time = time.perf_counter_ns()
duration_ms = float(self._end_time - self._start_time) / 1.0e6
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
class PrinterInductorPass(SGLangInductorPass):
def __init__(self, name: str):
super().__init__()
self.name = name
def __call__(self, graph: torch.fx.Graph):
self.dump_graph(graph, self.name)
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/pass_manager.py
import logging
from torch import fx as fx
from sglang.srt.model_executor.compilation.fix_functionalization import (
FixFunctionalizationPass,
)
from sglang.srt.model_executor.compilation.inductor_pass import (
CustomGraphPass,
InductorPass,
SGLangInductorPass,
get_pass_context,
)
logger = logging.getLogger(__name__)
class PostGradPassManager(CustomGraphPass):
"""
The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes.
It supports uuid for the Inductor code cache. That includes torch<2.6
support using pickling (in .inductor_pass.CustomGraphPass).
    The order of the post-grad post-passes is:
    1. passes added via add()
    2. fix_functionalization (always last)
    This way, all passes operate on a functionalized graph.
"""
def __init__(self):
self.passes: list[SGLangInductorPass] = []
def __call__(self, graph: fx.Graph):
shape = get_pass_context().runtime_shape
for pass_ in self.passes:
if pass_.is_applicable_for_shape(shape):
pass_(graph)
# always run fix_functionalization last
self.fix_functionalization(graph)
def configure(
self,
):
self.pass_config = dict()
self.fix_functionalization = FixFunctionalizationPass()
def add(self, pass_: InductorPass):
assert isinstance(pass_, InductorPass)
self.passes.append(pass_)
def uuid(self):
"""
The PostGradPassManager is set as a custom pass in the Inductor and
affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info.
"""
        # The pass config is currently empty, so a fixed placeholder stands in
        # for its hash; only the registered passes affect the cache key.
        pass_manager_uuid = "fshdakhsa"
state = {"pass_config": pass_manager_uuid, "passes": []}
for pass_ in self.passes:
state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid())
return InductorPass.hash_dict(state)
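# A hedged wiring sketch (assumes the torch._inductor "post_grad_custom_post_pass"
# config hook, present in recent torch releases): registering the manager as
# Inductor's custom post-grad pass makes every compiled piece run the configured
# passes followed by fix_functionalization.
def _example_install_pass_manager(extra_passes: list[InductorPass]) -> PostGradPassManager:
    import torch._inductor.config as inductor_config

    pass_manager = PostGradPassManager()
    pass_manager.configure()
    for p in extra_passes:
        pass_manager.add(p)
    inductor_config.post_grad_custom_post_pass = pass_manager
    return pass_manager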
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, List, Optional
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@dataclass
class ForwardContext:
def __init__(self):
self.forward_batch = None
        self.attention_layers = None
def set_forward_batch(self, forward_batch: ForwardBatch):
self.forward_batch = forward_batch
def set_attention_layers(self, layers: List[Any]):
self.attention_layers = layers
_forward_context: Optional[ForwardContext] = None
def get_forward_context() -> Optional[ForwardContext]:
    return _forward_context
@contextmanager
def set_forward_context(forward_batch: ForwardBatch, attention_layers: List[Any]):
global _forward_context
prev_forward_context = _forward_context
_forward_context = ForwardContext()
_forward_context.set_forward_batch(forward_batch)
_forward_context.set_attention_layers(attention_layers)
try:
yield
finally:
_forward_context = prev_forward_context
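# A minimal usage sketch: the model runner presumably enters
# set_forward_context() around each forward pass so that split/compiled
# submodules (e.g. attention ops) can recover the current ForwardBatch and
# attention layers without threading them through the traced call signature.
def _example_forward_context_usage(forward_batch: ForwardBatch, layers: List[Any]):
    prev = get_forward_context()
    with set_forward_context(forward_batch, layers):
        ctx = get_forward_context()
        assert ctx is not None and ctx.forward_batch is forward_batch
        assert ctx.attention_layers is layers
    # The previous context (if any) is restored once the block exits.
    assert get_forward_context() is prev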
// Adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/ops.h
#include <torch/extension.h>
#include <vector>
static at::Tensor weak_ref_tensor(at::Tensor &tensor) {
TORCH_CHECK(tensor.is_cuda(), "weak_ref_tensor expects a CUDA tensor");
void *data_ptr = tensor.data_ptr();
std::vector<int64_t> sizes = tensor.sizes().vec();
std::vector<int64_t> strides = tensor.strides().vec();
auto options = tensor.options();
auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options);
return new_tensor;
}
TORCH_LIBRARY(jit_weak_ref_tensor, ops) {
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
}
TORCH_LIBRARY_IMPL(jit_weak_ref_tensor, CUDA, ops) {
ops.impl("weak_ref_tensor", weak_ref_tensor);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
import os
import torch
from torch.utils.cpp_extension import load
_abs_path = os.path.dirname(os.path.abspath(__file__))
load(
name="weak_ref_tensor_ext",
sources=[f"{_abs_path}/weak_ref_tensor.cpp"],
extra_cflags=["-O3"],
)
x = torch.arange(12, device="cuda").reshape(3, 4)
y = torch.ops.jit_weak_ref_tensor.weak_ref_tensor(x)
print("alias:", x.data_ptr() == y.data_ptr())
......@@ -108,8 +108,15 @@ from sglang.srt.mem_cache.memory_pool import (
)
from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
PPProxyTensors,
)
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
PiecewiseCudaGraphRunner,
)
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
......@@ -307,6 +314,26 @@ class ModelRunner:
self._model_update_group = {}
self._weights_send_group = {}
if (
self.server_args.enable_piecewise_cuda_graph
and self.can_run_piecewise_cuda_graph()
):
self.attention_layers = []
for layer in self.model.model.layers:
if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "attn"):
self.attention_layers.append(layer.self_attn.attn)
if len(self.attention_layers) < self.model_config.num_hidden_layers:
# TODO(yuwei): support Non-Standard GQA
log_info_on_rank0(
logger,
"Disable piecewise CUDA graph because some layers do not apply Standard GQA",
)
self.piecewise_cuda_graph_runner = None
else:
self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
else:
self.piecewise_cuda_graph_runner = None
def initialize(self, min_per_gpu_memory: float):
server_args = self.server_args
......@@ -692,6 +719,7 @@ class ModelRunner:
pipeline_model_parallel_size=self.pp_size,
expert_model_parallel_size=self.moe_ep_size,
duplicate_tp_group=self.server_args.enable_pdmux,
torch_compile=self.server_args.enable_piecewise_cuda_graph,
)
initialize_dp_attention(
server_args=self.server_args,
......@@ -1411,6 +1439,27 @@ class ModelRunner:
f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
)
def can_run_piecewise_cuda_graph(self):
if self.server_args.disable_cuda_graph:
log_info_on_rank0(
logger, "Disable piecewise CUDA graph because disable_cuda_graph is set"
)
return False
if self.server_args.enable_torch_compile:
log_info_on_rank0(
logger,
"Disable piecewise CUDA graph because piecewise_cuda_graph has conflict with torch compile",
)
return False
if self.pp_size > 1:
# TODO(yuwei): support PP
log_info_on_rank0(
logger,
"Disable piecewise CUDA graph because piecewise_cuda_graph does not support PP",
)
return False
return True
def init_memory_pool(
self,
total_gpu_memory: int,
......@@ -1932,6 +1981,11 @@ class ModelRunner:
kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
if not self.is_generation:
kwargs["get_embedding"] = True
if self.piecewise_cuda_graph_runner is not None:
if self.piecewise_cuda_graph_runner.can_run(forward_batch):
return self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs)
return self.model.forward(
forward_batch.input_ids,
forward_batch.positions,
......
......@@ -417,7 +417,10 @@ class ServerArgs:
enable_single_batch_overlap: bool = False
tbo_token_distribution_threshold: float = 0.48
enable_torch_compile: bool = False
enable_piecewise_cuda_graph: bool = False
torch_compile_max_bs: int = 32
piecewise_cuda_graph_max_tokens: int = 4096
piecewise_cuda_graph_tokens: Optional[List[int]] = None
torchao_config: str = ""
enable_nan_detection: bool = False
enable_p2p_check: bool = False
......@@ -675,6 +678,11 @@ class ServerArgs:
else:
self.cuda_graph_max_bs = max(self.cuda_graph_bs)
if self.piecewise_cuda_graph_tokens is None:
self.piecewise_cuda_graph_tokens = (
self._generate_piecewise_cuda_graph_tokens()
)
if self.mem_fraction_static is None:
# Constant meta data (e.g., from attention backend)
reserved_mem = 512
......@@ -753,6 +761,25 @@ class ServerArgs:
return capture_bs
def _generate_piecewise_cuda_graph_tokens(self):
"""
        Generate the list of token counts (capture sizes) for piecewise CUDA
        graph capture, based on piecewise_cuda_graph_max_tokens.
"""
capture_sizes = (
list(range(4, 33, 4))
+ list(range(48, 257, 16))
+ list(range(288, 513, 32))
+ list(range(640, 4096 + 1, 128))
+ list(range(4352, self.piecewise_cuda_graph_max_tokens + 1, 256))
)
capture_sizes = [
s for s in capture_sizes if s <= self.piecewise_cuda_graph_max_tokens
]
return capture_sizes
def _handle_hpu_backends(self):
if self.device == "hpu":
self.attention_backend = "torch_native"
......@@ -2649,12 +2676,29 @@ class ServerArgs:
action="store_true",
help="Optimize the model with torch.compile. Experimental feature.",
)
parser.add_argument(
"--enable-piecewise-cuda-graph",
action="store_true",
help="Optimize the model with piecewise cuda graph for extend/prefill only. Experimental feature.",
)
parser.add_argument(
"--piecewise-cuda-graph-tokens",
type=json_list_type,
default=ServerArgs.piecewise_cuda_graph_tokens,
help="Set the list of tokens when using piecewise cuda graph.",
)
parser.add_argument(
"--torch-compile-max-bs",
type=int,
default=ServerArgs.torch_compile_max_bs,
help="Set the maximum batch size when using torch compile.",
)
parser.add_argument(
"--piecewise-cuda-graph-max-tokens",
type=int,
default=ServerArgs.piecewise_cuda_graph_max_tokens,
help="Set the maximum tokens when using piecewise cuda graph.",
)
parser.add_argument(
"--torchao-config",
type=str,
......