Unverified Commit e4d652ea authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[torch.compile] integration with compilation control (#9058)

parent 78c0b416
...@@ -1137,10 +1137,9 @@ class EmbeddingSequenceGroupOutput( ...@@ -1137,10 +1137,9 @@ class EmbeddingSequenceGroupOutput(
return self.embeddings == other.embeddings return self.embeddings == other.embeddings
class IntermediateTensors( # cannot use msgspec.Struct here because Dynamo does not support it
msgspec.Struct, @dataclass
omit_defaults=True, # type: ignore[call-arg] class IntermediateTensors:
array_like=True): # type: ignore[call-arg]
"""For all pipeline stages except the last, we need to return the hidden """For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request. contains the hidden states and residuals for a request.
......
...@@ -18,6 +18,8 @@ import vllm.envs as envs ...@@ -18,6 +18,8 @@ import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.levels import CompilationLevel
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
...@@ -1126,10 +1128,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1126,10 +1128,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"provided. Defaulting to scaling factors of 1.0. " "provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!") "This may lead to less accurate results!")
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo(): if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \
from vllm.compilation.backends import vllm_backend and supports_dynamo():
from vllm.plugins import get_torch_compile_backend from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend() or vllm_backend backend = get_torch_compile_backend() or "eager"
self.model = torch.compile( self.model = torch.compile(
self.model, self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
...@@ -1289,6 +1291,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1289,6 +1291,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
batch_size=batch_size, batch_size=batch_size,
dtype=self.model_config.dtype, dtype=self.model_config.dtype,
device=self.device) device=self.device)
graph_batch_size = self.max_batchsize_to_capture
batch_size_capture_list = [
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
]
if self.model_config.enforce_eager:
batch_size_capture_list = []
with set_compile_context(batch_size_capture_list):
self.execute_model(model_input, kv_caches, intermediate_tensors) self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize() torch.cuda.synchronize()
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment