Unverified Commit eea55cca authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[1/N] torch.compile user interface design (#10237)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent 9cdba966
...@@ -12,10 +12,9 @@ from vllm.compilation.compile_context import set_compile_context ...@@ -12,10 +12,9 @@ from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
global_counter = 0 global_counter = 0
# create a library to hold the custom op # create a library to hold the custom op
...@@ -48,7 +47,11 @@ direct_register_custom_op( ...@@ -48,7 +47,11 @@ direct_register_custom_op(
@support_torch_compile @support_torch_compile
class SillyModel(nn.Module): class SillyModel(nn.Module):
def __init__(self) -> None: def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__() super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
...@@ -74,11 +77,12 @@ class SillyModel(nn.Module): ...@@ -74,11 +77,12 @@ class SillyModel(nn.Module):
def test_simple_piecewise_compile(): def test_simple_piecewise_compile():
model = SillyModel()
directory = os.path.dirname(__file__) directory = os.path.dirname(__file__)
config = os.path.join(directory, "piecewise_compilation_config.json") config = os.path.join(directory, "piecewise_compilation_config.json")
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
model = SillyModel(vllm_config=VllmConfig(), prefix='')
inputs = torch.randn(100).cuda() inputs = torch.randn(100).cuda()
......
...@@ -19,6 +19,7 @@ from vllm.compilation.config import CompilationConfig ...@@ -19,6 +19,7 @@ from vllm.compilation.config import CompilationConfig
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.plugins import set_compilation_config from vllm.plugins import set_compilation_config
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -195,9 +196,15 @@ class LlamaDecoderLayer(nn.Module): ...@@ -195,9 +196,15 @@ class LlamaDecoderLayer(nn.Module):
return hidden_states, residual return hidden_states, residual
@support_torch_compile
class LlamaModel(nn.Module): class LlamaModel(nn.Module):
def __init__(self, config: LlamaConfig) -> None: def __init__(self,
*,
vllm_config: VllmConfig,
config: LlamaConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__() super().__init__()
self.embedding_tokens = nn.Embedding( self.embedding_tokens = nn.Embedding(
num_embeddings=config.vocab_size, num_embeddings=config.vocab_size,
...@@ -265,10 +272,9 @@ def run_model(llama_config, ...@@ -265,10 +272,9 @@ def run_model(llama_config,
CompilationLevel.NO_COMPILATION) CompilationLevel.NO_COMPILATION)
set_compilation_config(None) set_compilation_config(None)
cls = LlamaModel model = LlamaModel(config=llama_config,
if use_compile: vllm_config=VllmConfig(),
cls = support_torch_compile(LlamaModel) prefix="").eval().cuda()
model = cls(llama_config).eval().cuda()
B = 16 # max batch size B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
...@@ -357,7 +363,6 @@ def test_toy_llama(): ...@@ -357,7 +363,6 @@ def test_toy_llama():
def benchmark(): def benchmark():
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
from triton.testing import do_bench from triton.testing import do_bench
cls = support_torch_compile(LlamaModel)
# similar to llama 3.1-8B # similar to llama 3.1-8B
llama_config = LlamaConfig(hidden_size=4096, llama_config = LlamaConfig(hidden_size=4096,
...@@ -390,7 +395,9 @@ def benchmark(): ...@@ -390,7 +395,9 @@ def benchmark():
else: else:
set_compilation_config(None) set_compilation_config(None)
model = cls(llama_config).eval().cuda().to(torch.bfloat16) model = LlamaModel(config=llama_config,
vllm_config=VllmConfig(),
prefix="").eval().cuda().to(torch.bfloat16)
B = 256 # max batch size B = 256 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
......
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel from vllm.compilation.levels import CompilationLevel
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo from vllm.utils import supports_dynamo
...@@ -110,26 +111,26 @@ def _support_torch_compile(cls: type, ...@@ -110,26 +111,26 @@ def _support_torch_compile(cls: type,
""" """
A decorator to add support for compiling the forward method of a class. A decorator to add support for compiling the forward method of a class.
""" """
if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # support decorating multiple times
# will handle the compilation, so we don't need to do anything here.
if envs.VLLM_TORCH_COMPILE_LEVEL in [
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo():
return cls return cls
# take care of method resolution order # take care of method resolution order
# make sure super().__init__ is called on the base class # make sure super().__init__ is called on the base class
# other than TorchCompileWrapperWithCustomDispatcher # other than TorchCompileWrapperWithCustomDispatcher
if TorchCompileWrapperWithCustomDispatcher not in cls.__bases__: cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
# support decorating multiple times
cls.__bases__ = cls.__bases__ + (
TorchCompileWrapperWithCustomDispatcher, )
old_init = cls.__init__ # type: ignore old_init = cls.__init__ # type: ignore
def __init__(self, *args, **kwargs): def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
old_init(self, *args, **kwargs) old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo()
if self.do_not_compile:
return
TorchCompileWrapperWithCustomDispatcher.__init__(self) TorchCompileWrapperWithCustomDispatcher.__init__(self)
cls.__init__ = __init__ # type: ignore cls.__init__ = __init__ # type: ignore
...@@ -138,7 +139,7 @@ def _support_torch_compile(cls: type, ...@@ -138,7 +139,7 @@ def _support_torch_compile(cls: type,
# torch.compiler.is_compiling() means we are inside the compilation # torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't # e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside. # need to compile the model inside.
if torch.compiler.is_compiling(): if self.do_not_compile or torch.compiler.is_compiling():
return self.forward(*args, **kwargs) return self.forward(*args, **kwargs)
# the first compilation needs to have dynamic shapes marked # the first compilation needs to have dynamic shapes marked
......
...@@ -2041,12 +2041,15 @@ class VllmConfig: ...@@ -2041,12 +2041,15 @@ class VllmConfig:
simplifies passing around the distinct configurations in the codebase. simplifies passing around the distinct configurations in the codebase.
""" """
model_config: ModelConfig model_config: ModelConfig = field(default=None, init=True) # type: ignore
cache_config: CacheConfig cache_config: CacheConfig = field(default=None, init=True) # type: ignore
parallel_config: ParallelConfig parallel_config: ParallelConfig = field(default=None,
scheduler_config: SchedulerConfig init=True) # type: ignore
device_config: DeviceConfig scheduler_config: SchedulerConfig = field(default=None,
load_config: LoadConfig init=True) # type: ignore
device_config: DeviceConfig = field(default=None,
init=True) # type: ignore
load_config: LoadConfig = field(default=None, init=True) # type: ignore
lora_config: Optional[LoRAConfig] = None lora_config: Optional[LoRAConfig] = None
speculative_config: Optional[SpeculativeConfig] = None speculative_config: Optional[SpeculativeConfig] = None
decoding_config: Optional[DecodingConfig] = None decoding_config: Optional[DecodingConfig] = None
...@@ -2091,11 +2094,14 @@ class VllmConfig: ...@@ -2091,11 +2094,14 @@ class VllmConfig:
def __post_init__(self): def __post_init__(self):
"""Verify configs are valid & consistent with each other. """Verify configs are valid & consistent with each other.
""" """
self.model_config.verify_async_output_proc(self.parallel_config, if self.model_config is not None:
self.speculative_config, self.model_config.verify_async_output_proc(self.parallel_config,
self.device_config) self.speculative_config,
self.model_config.verify_with_parallel_config(self.parallel_config) self.device_config)
self.cache_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
if self.cache_config is not None:
self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.lora_config: if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_model_config(self.model_config)
...@@ -2149,4 +2155,4 @@ class VllmConfig: ...@@ -2149,4 +2155,4 @@ class VllmConfig:
self.scheduler_config.num_scheduler_steps, self.scheduler_config.num_scheduler_steps,
self.cache_config.enable_prefix_caching, self.cache_config.enable_prefix_caching,
self.model_config.use_async_output_proc, self.model_config.use_async_output_proc,
self.model_config.mm_processor_kwargs) self.model_config.mm_processor_kwargs)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment