Unverified Commit 714c6e0e authored by Lucas Kabela's avatar Lucas Kabela Committed by GitHub
Browse files

[torch.compile][BE] Modify cudagraph callable to check for is_forward_context_set (#36288)


Signed-off-by: default avatarLucas Kabela <lucaskabela@meta.com>
parent 0fefd00e
......@@ -34,9 +34,6 @@ relies on caching artifacts to reduce start time, we must properly propagate the
with the LLM text-backbone, or other instances of the same artifact (as is the case with vision block). `is_encoder=True` is also needed for encoder
components (see Compile Range Integration).
3. `with set_forward_context` context manager should be used around the nn.Module's forward call. This will properly forward the vllm_config which is needed
for torch.compile integration.
### CompilationConfig
With the exception of `compile_mm_encoder: true`, the multimodal encoder will inherit from the same compilation config as the text LLM. We may extend
......
......@@ -16,7 +16,11 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.forward_context import (
BatchDescriptor,
get_forward_context,
is_forward_context_available,
)
from vllm.logger import init_logger
from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform
......@@ -224,6 +228,12 @@ class CUDAGraphWrapper:
self.concrete_cudagraph_entries.clear()
def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
if not is_forward_context_available():
# No forward context means we are outside the normal
# inference path (e.g. a vision encoder forward pass).
# Just run the underlying function without cudagraphs.
return self.runnable(*args, **kwargs)
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
......
......@@ -38,7 +38,6 @@ from vllm.compilation.decorators import (
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
......@@ -872,10 +871,7 @@ class Llama4ForConditionalGeneration(
if image_input is None:
return []
with (
set_forward_context(None, self.vllm_config),
):
return self._process_image_input(image_input)
return self._process_image_input(image_input)
def forward(
self,
......
......@@ -49,7 +49,6 @@ from vllm.compilation.decorators import (
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.attention import MMEncoderAttention
......@@ -1207,13 +1206,12 @@ class Qwen2_5_VLForConditionalGeneration(
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else:
pixel_values = image_input["pixel_values"]
with set_forward_context(None, self.vllm_config):
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
)
else:
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
)
else:
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
# Split concatenated embeddings for each image item.
merge_size = self.visual.spatial_merge_size
......@@ -1262,18 +1260,15 @@ class Qwen2_5_VLForConditionalGeneration(
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
else:
pixel_values_videos = video_input["pixel_values_videos"]
with set_forward_context(None, self.vllm_config):
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual,
pixel_values_videos,
grid_thw_list,
rope_type="rope_3d",
)
else:
video_embeds = self.visual(
pixel_values_videos, grid_thw=grid_thw_list
)
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual,
pixel_values_videos,
grid_thw_list,
rope_type="rope_3d",
)
else:
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment