Unverified commit 714c6e0e, authored by Lucas Kabela, committed by GitHub
Browse files

[torch.compile][BE] Modify cudagraph callable to check for is_forward_context_set (#36288)


Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
parent 0fefd00e
...@@ -34,9 +34,6 @@ relies on caching artifacts to reduce start time, we must properly propagate the ...@@ -34,9 +34,6 @@ relies on caching artifacts to reduce start time, we must properly propagate the
with the LLM text-backbone, or other instances of the same artifact (as is the case with vision block). `is_encoder=True` is also needed for encoder with the LLM text-backbone, or other instances of the same artifact (as is the case with vision block). `is_encoder=True` is also needed for encoder
components (see Compile Range Integration). components (see Compile Range Integration).
3. The `set_forward_context` context manager should be used around the nn.Module's forward call. This properly forwards the `vllm_config`, which is needed
for torch.compile integration.
### CompilationConfig ### CompilationConfig
With the exception of `compile_mm_encoder: true`, the multimodal encoder will inherit from the same compilation config as the text LLM. We may extend With the exception of `compile_mm_encoder: true`, the multimodal encoder will inherit from the same compilation config as the text LLM. We may extend
......
...@@ -16,7 +16,11 @@ from vllm.compilation.counter import compilation_counter ...@@ -16,7 +16,11 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.forward_context import (
BatchDescriptor,
get_forward_context,
is_forward_context_available,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.offloader.base import get_offloader from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -224,6 +228,12 @@ class CUDAGraphWrapper: ...@@ -224,6 +228,12 @@ class CUDAGraphWrapper:
self.concrete_cudagraph_entries.clear() self.concrete_cudagraph_entries.clear()
def __call__(self, *args: Any, **kwargs: Any) -> Any | None: def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
if not is_forward_context_available():
# No forward context means we are outside the normal
# inference path (e.g. a vision encoder forward pass).
# Just run the underlying function without cudagraphs.
return self.runnable(*args, **kwargs)
forward_context = get_forward_context() forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor batch_descriptor = forward_context.batch_descriptor
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
......
...@@ -38,7 +38,6 @@ from vllm.compilation.decorators import ( ...@@ -38,7 +38,6 @@ from vllm.compilation.decorators import (
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.attention import MMEncoderAttention from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
...@@ -872,10 +871,7 @@ class Llama4ForConditionalGeneration( ...@@ -872,10 +871,7 @@ class Llama4ForConditionalGeneration(
if image_input is None: if image_input is None:
return [] return []
with ( return self._process_image_input(image_input)
set_forward_context(None, self.vllm_config),
):
return self._process_image_input(image_input)
def forward( def forward(
self, self,
......
...@@ -49,7 +49,6 @@ from vllm.compilation.decorators import ( ...@@ -49,7 +49,6 @@ from vllm.compilation.decorators import (
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.attention import MMEncoderAttention from vllm.model_executor.layers.attention import MMEncoderAttention
...@@ -1207,13 +1206,12 @@ class Qwen2_5_VLForConditionalGeneration( ...@@ -1207,13 +1206,12 @@ class Qwen2_5_VLForConditionalGeneration(
image_embeds = image_input["image_embeds"].type(self.visual.dtype) image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else: else:
pixel_values = image_input["pixel_values"] pixel_values = image_input["pixel_values"]
with set_forward_context(None, self.vllm_config): if self.use_data_parallel:
if self.use_data_parallel: return run_dp_sharded_mrope_vision_model(
return run_dp_sharded_mrope_vision_model( self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d" )
) else:
else: image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
# Split concatenated embeddings for each image item. # Split concatenated embeddings for each image item.
merge_size = self.visual.spatial_merge_size merge_size = self.visual.spatial_merge_size
...@@ -1262,18 +1260,15 @@ class Qwen2_5_VLForConditionalGeneration( ...@@ -1262,18 +1260,15 @@ class Qwen2_5_VLForConditionalGeneration(
video_embeds = video_input["video_embeds"].type(self.visual.dtype) video_embeds = video_input["video_embeds"].type(self.visual.dtype)
else: else:
pixel_values_videos = video_input["pixel_values_videos"] pixel_values_videos = video_input["pixel_values_videos"]
with set_forward_context(None, self.vllm_config): if self.use_data_parallel:
if self.use_data_parallel: return run_dp_sharded_mrope_vision_model(
return run_dp_sharded_mrope_vision_model( self.visual,
self.visual, pixel_values_videos,
pixel_values_videos, grid_thw_list,
grid_thw_list, rope_type="rope_3d",
rope_type="rope_3d", )
) else:
else: video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
video_embeds = self.visual(
pixel_values_videos, grid_thw=grid_thw_list
)
# Split concatenated embeddings for each video item. # Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size merge_size = self.visual.spatial_merge_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment