Unverified Commit 1d726528 authored by Yuwei An's avatar Yuwei An Committed by GitHub
Browse files

Eager Compiler for Torch Compile (#11803)


Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
parent f4488e9d
...@@ -17,24 +17,30 @@ from torch._dispatch.python import enable_python_dispatcher ...@@ -17,24 +17,30 @@ from torch._dispatch.python import enable_python_dispatcher
from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.compilation.compilation_config import CompilationConfig
from sglang.srt.compilation.compilation_counter import compilation_counter from sglang.srt.compilation.compilation_counter import compilation_counter
from sglang.srt.compilation.compiler_interface import InductorAdaptor from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor
from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend
from sglang.srt.compilation.pass_manager import PostGradPassManager from sglang.srt.compilation.pass_manager import PostGradPassManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def make_compiler(config: CompilationConfig):
    """Instantiate the compiler backend selected by ``config.compiler``.

    Supported backend names are ``"eager"`` and ``"inductor"``; any other
    value raises ``ValueError``.
    """
    selected = config.compiler
    if selected == "inductor":
        return InductorAdaptor()
    if selected == "eager":
        return EagerAdapter()
    raise ValueError(f"Unknown compiler: {config.compiler}")
class CompilerManager: class CompilerManager:
def __init__(
    self,
    config: CompilationConfig,
):
    """Set up an empty compilation cache and the configured backend."""
    # The cache starts empty and is only flagged dirty once a compile lands.
    self.is_cache_updated = False
    self.cache = {}
    self.compiler = make_compiler(config)
def compute_hash(self):
    """Delegate hash computation to the underlying compiler backend."""
    backend = self.compiler
    return backend.compute_hash()
...@@ -348,7 +354,7 @@ class SGLangBackend: ...@@ -348,7 +354,7 @@ class SGLangBackend:
self.sym_tensor_indices = [] self.sym_tensor_indices = []
self.input_buffers = [] self.input_buffers = []
self.compiler_manager = CompilerManager() self.compiler_manager = CompilerManager(config)
self.inductor_config = { self.inductor_config = {
"enable_auto_functionalized_v2": False, "enable_auto_functionalized_v2": False,
} }
......
...@@ -5,9 +5,10 @@ from typing import List ...@@ -5,9 +5,10 @@ from typing import List
# TODO(Yuwei): support better compile config support # TODO(Yuwei): support better compile config support
class CompilationConfig: class CompilationConfig:
def __init__(self, capture_sizes: List[int], compiler: str = "eager"):
    """Record the capture sizes and the chosen compiler backend name."""
    self.compiler = compiler
    self.capture_sizes = capture_sizes
    # Files touched during tracing are registered here lazily via
    # add_traced_file().
    self.traced_files = set()
def add_traced_file(self, file_path: str): def add_traced_file(self, file_path: str):
self.traced_files.add(file_path) self.traced_files.add(file_path)
......
...@@ -475,3 +475,29 @@ def set_inductor_config(config, runtime_shape): ...@@ -475,3 +475,29 @@ def set_inductor_config(config, runtime_shape):
# can be beneficial # can be beneficial
config["max_autotune"] = True config["max_autotune"] = True
config["coordinate_descent_tuning"] = True config["coordinate_descent_tuning"] = True
class EagerAdapter(CompilerInterface):
    """Compiler backend that skips compilation and runs the FX graph eagerly.

    ``compile`` hands back the unmodified ``fx.GraphModule`` as the callable;
    there is no serialized artifact, so ``load`` can never restore one from
    cache.
    """

    name = "eager"

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
        num_graphs: int = 1,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        # No transformation is applied: the graph itself is the callable, and
        # there is no cache handle (None) to persist.
        return graph, None

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
        num_graphs: int = 1,
    ) -> Callable:
        # Fix: the previous message ("eager compilation is not supported")
        # was misleading — eager compilation works; it is cache *loading*
        # that cannot exist because compile() produces no artifact.
        raise NotImplementedError(
            "the eager compiler produces no artifact to load from cache"
        )
...@@ -9,6 +9,7 @@ from unittest.mock import patch ...@@ -9,6 +9,7 @@ from unittest.mock import patch
import torch import torch
import torch.fx as fx import torch.fx as fx
import sglang.srt.compilation.weak_ref_tensor_jit # noqa: F401
from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.compilation.compilation_config import CompilationConfig
from sglang.srt.compilation.compilation_counter import compilation_counter from sglang.srt.compilation.compilation_counter import compilation_counter
......
...@@ -103,9 +103,10 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): ...@@ -103,9 +103,10 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
@contextmanager
def patch_model(model: torch.nn.Module, compiler: str):
    """Temporarily convert *model* to plain-torch mode for compilation.

    For non-eager compilers the model is switched with
    ``_to_torch(reverse=False)`` on entry and restored on exit.  The eager
    compiler runs the graph as-is, so no conversion happens in either
    direction.

    Fix: the forward conversion was guarded by ``compiler != "eager"`` but the
    teardown in ``finally`` was unconditional, so the eager path reversed a
    conversion that never happened.  Setup and teardown are now symmetric.
    """
    try:
        if compiler != "eager":
            _to_torch(model, reverse=False, num_tokens=16)
        yield model
    finally:
        # Only undo the conversion if we actually performed it on entry.
        if compiler != "eager":
            _to_torch(model, reverse=True, num_tokens=16)
...@@ -144,8 +145,13 @@ class PiecewiseCudaGraphRunner: ...@@ -144,8 +145,13 @@ class PiecewiseCudaGraphRunner:
assert ( assert (
self.model_runner.server_args.piecewise_cuda_graph_tokens is not None self.model_runner.server_args.piecewise_cuda_graph_tokens is not None
), "piecewise_cuda_graph_tokens is not set" ), "piecewise_cuda_graph_tokens is not set"
assert self.model_runner.server_args.piecewise_cuda_graph_compiler in [
"eager",
"inductor",
], "By now, only eager and inductor are supported for piecewise cuda graph compiler."
self.compile_config = CompilationConfig( self.compile_config = CompilationConfig(
self.model_runner.server_args.piecewise_cuda_graph_tokens self.model_runner.server_args.piecewise_cuda_graph_tokens,
self.model_runner.server_args.piecewise_cuda_graph_compiler,
) )
# Batch sizes to capture # Batch sizes to capture
...@@ -179,7 +185,9 @@ class PiecewiseCudaGraphRunner: ...@@ -179,7 +185,9 @@ class PiecewiseCudaGraphRunner:
# Set graph pool id globally to be able to use symmetric memory # Set graph pool id globally to be able to use symmetric memory
set_graph_pool_id(get_global_graph_memory_pool()) set_graph_pool_id(get_global_graph_memory_pool())
with patch_model(self.model_runner.model.model) as patched_model: with patch_model(
self.model_runner.model.model, self.compile_config.compiler
) as patched_model:
install_torch_compiled( install_torch_compiled(
patched_model, patched_model,
fullgraph=True, fullgraph=True,
...@@ -191,14 +199,14 @@ class PiecewiseCudaGraphRunner: ...@@ -191,14 +199,14 @@ class PiecewiseCudaGraphRunner:
with set_compiled(True): with set_compiled(True):
self.warmup_and_capture() self.warmup_and_capture()
# Capture # Capture
try: try:
with model_capture_mode(): with model_capture_mode():
self.capture() self.capture()
except RuntimeError as e: except RuntimeError as e:
raise Exception( raise Exception(
f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}" f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}"
) )
self.raw_num_tokens = 0 self.raw_num_tokens = 0
......
...@@ -436,6 +436,7 @@ class ServerArgs: ...@@ -436,6 +436,7 @@ class ServerArgs:
torch_compile_max_bs: int = 32 torch_compile_max_bs: int = 32
piecewise_cuda_graph_max_tokens: int = 4096 piecewise_cuda_graph_max_tokens: int = 4096
piecewise_cuda_graph_tokens: Optional[List[int]] = None piecewise_cuda_graph_tokens: Optional[List[int]] = None
piecewise_cuda_graph_compiler: str = "eager"
torchao_config: str = "" torchao_config: str = ""
enable_nan_detection: bool = False enable_nan_detection: bool = False
enable_p2p_check: bool = False enable_p2p_check: bool = False
...@@ -2815,6 +2816,13 @@ class ServerArgs: ...@@ -2815,6 +2816,13 @@ class ServerArgs:
default=ServerArgs.piecewise_cuda_graph_tokens, default=ServerArgs.piecewise_cuda_graph_tokens,
help="Set the list of tokens when using piecewise cuda graph.", help="Set the list of tokens when using piecewise cuda graph.",
) )
parser.add_argument(
"--piecewise-cuda-graph-compiler",
type=str,
default=ServerArgs.piecewise_cuda_graph_compiler,
help="Set the compiler for piecewise cuda graph. Choices are: eager, inductor.",
choices=["eager", "inductor"],
)
parser.add_argument( parser.add_argument(
"--torch-compile-max-bs", "--torch-compile-max-bs",
type=int, type=int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.