Unverified Commit ce6bf3a2 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[torch.compile] avoid Dynamo guard evaluation overhead (#7898)


Co-authored-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 3cdfe1f3
...@@ -12,4 +12,4 @@ remove_docker_container ...@@ -12,4 +12,4 @@ remove_docker_container
# For HF_TOKEN. # For HF_TOKEN.
source /etc/environment source /etc/environment
# Run a simple end-to-end example. # Run a simple end-to-end example.
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py" docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
...@@ -173,6 +173,7 @@ steps: ...@@ -173,6 +173,7 @@ steps:
- vllm/ - vllm/
commands: commands:
- pytest -v -s ./compile/test_full_graph.py - pytest -v -s ./compile/test_full_graph.py
- pytest -v -s ./compile/test_wrapper.py
- label: Vision Language Models Test # 42min - label: Vision Language Models Test # 42min
......
from typing import Optional
import torch
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
class MyMod(torch.nn.Module):
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
if cache is not None:
return x + cache
return x * 2
class MyWrapper(TorchCompileWrapperWithCustomDispacther):
def __init__(self, model):
self.model = model
compiled_callable = torch.compile(self.forward, backend="eager")
super().__init__(compiled_callable)
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
# this is the function to be compiled
return self.model(x, cache)
def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
# let torch.compile compile twice
if len(self.compiled_codes) == 2:
dispatch_id = 0 if cache is None else 1
with self.dispatch_to_code(dispatch_id):
return self.forward(x, cache)
else:
return self.compiled_callable(x, cache)
def test_torch_compile_wrapper():
mod = MyMod()
wrappers = []
for i in range(3):
torch._dynamo.reset()
wrapper = MyWrapper(mod)
wrappers.append(wrapper)
x = torch.tensor([1])
wrapper(x, None) # profile run, compile
# create a cache tensor
cache = torch.tensor([2])
wrapper(x, cache) # warm up with cache, recompile
# for new input, dispatch to the compiled code directly
new_x = torch.tensor([3])
assert wrapper(new_x,
None).item() == 6 # dispatch to the first compiled code
assert wrapper(
new_x, cache).item() == 5 # dispatch to the second compiled code
for wrapper in wrappers:
# make sure they have independent compiled codes
assert len(wrapper.compiled_codes) == 2
from ..utils import compare_two_settings
def test_custom_dispatcher():
compare_two_settings("google/gemma-2b",
arg1=["--enforce-eager"],
arg2=["--enforce-eager"],
env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"},
env2={})
import os
import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Callable, List
import torch
import vllm.envs as envs
class TorchCompileWrapperWithCustomDispacther:
"""
A wrapper class for torch.compile, with a custom dispatch logic.
Subclasses should:
1. Implement the forward method
2. Implement the dispatch logic in the __call__ method
It can use `self.compiled_codes` to access the compiled bytecode,
and `with self.dispatch_to_code(index):` to dispatch to
the compiled code.
3. Implement the `__init__` method to determine how to call
`torch.compile` over the forward method.
"""
def __init__(self, compiled_callable: Callable):
self.compiled_callable = compiled_callable
self.original_code_object = self.__class__.forward.__code__
self.compiled_codes: List[CodeType] = []
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
# read the env var to determine whether to use the custom dispatcher
# subclasses can use this to switch between the custom dispatcher
# and the default Dynamo guard mechanism.
self.use_custom_dispatcher: bool = \
envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER
def __call__(self, *args, **kwargs):
"""Implement the dispatch logic here, beyond the torch.compile level.
NOTE: this function can have additional arguments beyond the forward
method, for directly dispatching to the compiled code.
"""
return self.compiled_callable(*args, **kwargs)
@abstractmethod
def forward(self, *args, **kwargs):
...
def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
"""Hook to save the compiled bytecode for direct execution."""
if old_code is not self.original_code_object:
return
# code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
frame = sys._getframe()
while True:
frame = frame.f_back
code_name = frame.f_code.co_name
file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
if code_name == "_compile" and file_name == "convert_frame.py":
break
frame = frame.f_locals["frame"]
assert frame.f_code == old_code
if frame.f_locals["self"] is not self:
return
self.compiled_codes.append(new_code)
@contextmanager
def dispatch_to_code(self, index: int):
"""Context manager to dispatch to the compiled code.
Why does this work? Because Dynamo guarantees that the compiled
bytecode has exactly the same arguments, cell variables, and free
variables as the original code. Therefore we can directly switch
the code object in the function and call it.
See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
""" # noqa
self.__class__.forward.__code__ = self.compiled_codes[index]
yield
self.__class__.forward.__code__ = self.original_code_object
...@@ -196,6 +196,10 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -196,6 +196,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Internal flag to enable Dynamo graph capture # Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE": "VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER":
lambda:
(os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
("true", "1")),
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
# the GPU device id # the GPU device id
......
...@@ -10,6 +10,7 @@ import torch_xla.core.xla_model as xm ...@@ -10,6 +10,7 @@ import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -144,11 +145,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -144,11 +145,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
) )
model = model.eval() model = model.eval()
xm.wait_device_ops() xm.wait_device_ops()
model = ModelWrapper(model) self.model = ModelWrapper(model)
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=False)
def _dummy_run( def _dummy_run(
self, self,
...@@ -235,8 +232,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -235,8 +232,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(t, 0)
torch._dynamo.mark_dynamic(p, 0) torch._dynamo.mark_dynamic(p, 0)
# Dummy run. # Dummy run.
self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, self.model(token_ids,
num_samples, kv_caches) position_ids,
attn_metadata,
input_lens,
t,
p,
num_samples,
kv_caches,
is_prompt=is_prompt)
def warmup_model( def warmup_model(
self, self,
...@@ -530,7 +534,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -530,7 +534,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
if getattr(arg, "context_lens", None) is not None: if getattr(arg, "context_lens", None) is not None:
arg.context_lens = arg.context_lens.to(self.device) arg.context_lens = arg.context_lens.to(self.device)
new_args.append(arg) new_args.append(arg)
return self.model(*new_args) return self.model(*new_args, is_prompt=is_prompt)
num_prefills = model_input.attn_metadata.num_prefills num_prefills = model_input.attn_metadata.num_prefills
is_prompt = num_prefills > 0 is_prompt = num_prefills > 0
...@@ -601,11 +605,32 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -601,11 +605,32 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
return [SamplerOutput(sampler_outputs)] return [SamplerOutput(sampler_outputs)]
class ModelWrapper(nn.Module): class ModelWrapper(TorchCompileWrapperWithCustomDispacther):
def __init__(self, model: nn.Module): def __init__(self, model: nn.Module):
super().__init__()
self.model = model self.model = model
compiled_callable = torch.compile(self.forward,
backend="openxla",
fullgraph=True,
dynamic=False)
super().__init__(compiled_callable)
def __call__(self, *args, is_prompt: bool, **kwargs):
if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher:
# not fully compiled yet, or not using the custom dispatcher,
# let PyTorch handle it
return self.compiled_callable(*args, **kwargs)
# the 3 compiled codes are:
# 0: for profiling
# 1: for prompt
# 2: for decode
# dispatch to the compiled code directly, skip PyTorch
if is_prompt:
with self.dispatch_to_code(1):
return self.forward(*args, **kwargs)
else:
with self.dispatch_to_code(2):
return self.forward(*args, **kwargs)
def forward( def forward(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment