Unverified commit 1d726528, authored by Yuwei An, committed by GitHub
Browse files

Eager Compiler for Torch Compile (#11803)


Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
parent f4488e9d
......@@ -17,24 +17,30 @@ from torch._dispatch.python import enable_python_dispatcher
from sglang.srt.compilation.compilation_config import CompilationConfig
from sglang.srt.compilation.compilation_counter import compilation_counter
from sglang.srt.compilation.compiler_interface import InductorAdaptor
from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor
from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend
from sglang.srt.compilation.pass_manager import PostGradPassManager
logger = logging.getLogger(__name__)
def make_compiler(config: "CompilationConfig"):
    """Build the compiler backend selected by the compilation config.

    Args:
        config: Compilation settings; only ``config.compiler`` is read here.

    Returns:
        An ``EagerAdapter`` or ``InductorAdaptor`` instance.

    Raises:
        ValueError: If ``config.compiler`` names an unknown backend.
    """
    if config.compiler == "eager":
        return EagerAdapter()
    elif config.compiler == "inductor":
        return InductorAdaptor()
    else:
        raise ValueError(f"Unknown compiler: {config.compiler}")
class CompilerManager:
    """Owns the compiler backend chosen by the compilation config.

    Holds a cache of compiled artifacts and delegates hashing to the
    underlying backend.
    """

    def __init__(
        self,
        config: "CompilationConfig",
    ):
        # NOTE(review): cache population/lookup is not visible in this chunk;
        # presumably maps compilation keys to compiled handles — confirm.
        self.cache = dict()
        self.is_cache_updated = False
        # Backend is selected from config.compiler ("eager" or "inductor").
        self.compiler = make_compiler(config)

    def compute_hash(self):
        """Delegate cache-key hashing to the active compiler backend."""
        return self.compiler.compute_hash()
......@@ -348,7 +354,7 @@ class SGLangBackend:
self.sym_tensor_indices = []
self.input_buffers = []
self.compiler_manager = CompilerManager()
self.compiler_manager = CompilerManager(config)
self.inductor_config = {
"enable_auto_functionalized_v2": False,
}
......
......@@ -5,9 +5,10 @@ from typing import List
# TODO(Yuwei): support better compile config support
class CompilationConfig:
    """Settings for piecewise torch.compile.

    Attributes:
        traced_files: Source files recorded as traced during compilation.
        capture_sizes: Token counts at which CUDA graphs are captured.
        compiler: Backend name; "eager" (default) or "inductor".
    """

    def __init__(self, capture_sizes: List[int], compiler: str = "eager"):
        self.traced_files = set()
        self.capture_sizes = capture_sizes
        self.compiler = compiler

    def add_traced_file(self, file_path: str):
        """Record that *file_path* was traced (idempotent; backed by a set)."""
        self.traced_files.add(file_path)
......
......@@ -475,3 +475,29 @@ def set_inductor_config(config, runtime_shape):
# can be beneficial
config["max_autotune"] = True
config["coordinate_descent_tuning"] = True
class EagerAdapter(CompilerInterface):
    """No-op compiler backend: runs the FX graph eagerly, skipping compilation."""

    name = "eager"

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
        num_graphs: int = 1,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        # The unmodified graph IS the "compiled" callable; there is no
        # persistent handle to cache, hence the None second element.
        return graph, None

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
        num_graphs: int = 1,
    ) -> Callable:
        # Nothing is ever persisted for the eager backend, so a cached
        # artifact can never be loaded back.
        raise NotImplementedError("eager compilation is not supported")
......@@ -9,6 +9,7 @@ from unittest.mock import patch
import torch
import torch.fx as fx
import sglang.srt.compilation.weak_ref_tensor_jit # noqa: F401
from sglang.srt.compilation.compilation_config import CompilationConfig
from sglang.srt.compilation.compilation_counter import compilation_counter
......
......@@ -103,9 +103,10 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
@contextmanager
def patch_model(model: torch.nn.Module, compiler: str):
    """Context manager that applies ``_to_torch`` patching around compilation.

    Args:
        model: Model whose submodules are patched in place by ``_to_torch``.
        compiler: Backend name; the forward patch is skipped for "eager"
            since no real compilation happens.

    Yields:
        The (possibly patched) model.
    """
    try:
        if compiler != "eager":
            _to_torch(model, reverse=False, num_tokens=16)
        yield model
    finally:
        # NOTE(review): the reverse patch runs even when the forward patch was
        # skipped for the eager backend — confirm _to_torch(reverse=True) is
        # safe/idempotent on an unpatched model.
        _to_torch(model, reverse=True, num_tokens=16)
......@@ -144,8 +145,13 @@ class PiecewiseCudaGraphRunner:
assert (
self.model_runner.server_args.piecewise_cuda_graph_tokens is not None
), "piecewise_cuda_graph_tokens is not set"
assert self.model_runner.server_args.piecewise_cuda_graph_compiler in [
"eager",
"inductor",
], "By now, only eager and inductor are supported for piecewise cuda graph compiler."
self.compile_config = CompilationConfig(
self.model_runner.server_args.piecewise_cuda_graph_tokens
self.model_runner.server_args.piecewise_cuda_graph_tokens,
self.model_runner.server_args.piecewise_cuda_graph_compiler,
)
# Batch sizes to capture
......@@ -179,7 +185,9 @@ class PiecewiseCudaGraphRunner:
# Set graph pool id globally to be able to use symmetric memory
set_graph_pool_id(get_global_graph_memory_pool())
with patch_model(self.model_runner.model.model) as patched_model:
with patch_model(
self.model_runner.model.model, self.compile_config.compiler
) as patched_model:
install_torch_compiled(
patched_model,
fullgraph=True,
......@@ -191,14 +199,14 @@ class PiecewiseCudaGraphRunner:
with set_compiled(True):
self.warmup_and_capture()
# Capture
try:
with model_capture_mode():
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}"
)
# Capture
try:
with model_capture_mode():
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}"
)
self.raw_num_tokens = 0
......
......@@ -436,6 +436,7 @@ class ServerArgs:
torch_compile_max_bs: int = 32
piecewise_cuda_graph_max_tokens: int = 4096
piecewise_cuda_graph_tokens: Optional[List[int]] = None
piecewise_cuda_graph_compiler: str = "eager"
torchao_config: str = ""
enable_nan_detection: bool = False
enable_p2p_check: bool = False
......@@ -2815,6 +2816,13 @@ class ServerArgs:
default=ServerArgs.piecewise_cuda_graph_tokens,
help="Set the list of tokens when using piecewise cuda graph.",
)
parser.add_argument(
"--piecewise-cuda-graph-compiler",
type=str,
default=ServerArgs.piecewise_cuda_graph_compiler,
help="Set the compiler for piecewise cuda graph. Choices are: eager, inductor.",
choices=["eager", "inductor"],
)
parser.add_argument(
"--torch-compile-max-bs",
type=int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment