Unverified Commit 3224ea99, authored by Ilya Markov, committed by GitHub
Browse files

[torch.compile] Add encoder tag for compilation (#30489)


Signed-off-by: ilmarkov <markovilya197@gmail.com>
parent 3a20450d
...@@ -463,21 +463,27 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): ...@@ -463,21 +463,27 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
# the tag for the part of model being compiled, # the tag for the part of model being compiled,
# e.g. backbone/eagle_head # e.g. backbone/eagle_head
model_tag: str = "backbone" model_tag: str = "backbone"
model_is_encoder: bool = False
@contextmanager @contextmanager
def set_model_tag(tag: str): def set_model_tag(tag: str, is_encoder: bool = False):
"""Context manager to set the model tag.""" """Context manager to set the model tag."""
global model_tag global model_tag
global model_is_encoder
assert tag != model_tag, ( assert tag != model_tag, (
f"Model tag {tag} is the same as the current tag {model_tag}." f"Model tag {tag} is the same as the current tag {model_tag}."
) )
old_tag = model_tag old_tag = model_tag
old_is_encoder = model_is_encoder
model_tag = tag model_tag = tag
model_is_encoder = is_encoder
try: try:
yield yield
finally: finally:
model_tag = old_tag model_tag = old_tag
model_is_encoder = old_is_encoder
class VllmBackend: class VllmBackend:
...@@ -523,6 +529,9 @@ class VllmBackend: ...@@ -523,6 +529,9 @@ class VllmBackend:
# them, e.g. backbone (default), eagle_head, etc. # them, e.g. backbone (default), eagle_head, etc.
self.prefix = prefix or model_tag self.prefix = prefix or model_tag
# Mark compilation for encoder.
self.is_encoder = model_is_encoder
# Passes to run on the graph post-grad. # Passes to run on the graph post-grad.
self.pass_manager = resolve_obj_by_qualname( self.pass_manager = resolve_obj_by_qualname(
current_platform.get_pass_manager_cls() current_platform.get_pass_manager_cls()
......
...@@ -53,12 +53,7 @@ class PiecewiseBackend: ...@@ -53,12 +53,7 @@ class PiecewiseBackend:
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1 self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
self.is_full_graph = total_piecewise_compiles == 1 self.is_full_graph = total_piecewise_compiles == 1
# TODO: we need to generalize encoder compilation to other models self.is_encoder_compilation = vllm_backend.is_encoder
self.is_encoder_compilation = vllm_backend.prefix in [
"Qwen2_5_VisionPatchEmbed",
"Qwen2_5_VisionPatchMerger",
"Qwen2_5_VisionBlock",
]
self.compile_ranges = self.compilation_config.get_compile_ranges() self.compile_ranges = self.compilation_config.get_compile_ranges()
if self.is_encoder_compilation: if self.is_encoder_compilation:
......
...@@ -612,7 +612,7 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -612,7 +612,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
# DO NOT MOVE THIS IMPORT # DO NOT MOVE THIS IMPORT
from vllm.compilation.backends import set_model_tag from vllm.compilation.backends import set_model_tag
with set_model_tag("Qwen2_5_VisionPatchEmbed"): with set_model_tag("Qwen2_5_VisionPatchEmbed", is_encoder=True):
self.patch_embed = Qwen2_5_VisionPatchEmbed( self.patch_embed = Qwen2_5_VisionPatchEmbed(
patch_size=patch_size, patch_size=patch_size,
temporal_patch_size=temporal_patch_size, temporal_patch_size=temporal_patch_size,
...@@ -651,7 +651,7 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -651,7 +651,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
f"Qwen2.5-VL does not support {self.attn_backend} backend now." f"Qwen2.5-VL does not support {self.attn_backend} backend now."
) )
with set_model_tag("Qwen2_5_VisionBlock"): with set_model_tag("Qwen2_5_VisionBlock", is_encoder=True):
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
Qwen2_5_VisionBlock( Qwen2_5_VisionBlock(
...@@ -670,7 +670,7 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -670,7 +670,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
] ]
) )
with set_model_tag("Qwen2_5_VisionPatchMerger"): with set_model_tag("Qwen2_5_VisionPatchMerger", is_encoder=True):
self.merger = Qwen2_5_VisionPatchMerger( self.merger = Qwen2_5_VisionPatchMerger(
d_model=vision_config.out_hidden_size, d_model=vision_config.out_hidden_size,
context_dim=self.hidden_size, context_dim=self.hidden_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment