Unverified Commit 0db5439d authored by Lucas Kabela's avatar Lucas Kabela Committed by GitHub
Browse files

[Bugfix][torch2.10] Fix test_qwen2_5_vl_compilation with 2.10 RC (#30822)


Signed-off-by: default avatarLucas Kabela <lucaskabela@meta.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent 28d15ab5
...@@ -520,6 +520,7 @@ class VllmBackend: ...@@ -520,6 +520,7 @@ class VllmBackend:
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
is_encoder: bool = False,
): ):
# if the model is initialized with a non-empty prefix, # if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix, # then usually it's enough to use that prefix,
...@@ -530,7 +531,7 @@ class VllmBackend: ...@@ -530,7 +531,7 @@ class VllmBackend:
self.prefix = prefix or model_tag self.prefix = prefix or model_tag
# Mark compilation for encoder. # Mark compilation for encoder.
self.is_encoder = model_is_encoder self.is_encoder = is_encoder or model_is_encoder
# Passes to run on the graph post-grad. # Passes to run on the graph post-grad.
self.pass_manager = resolve_obj_by_qualname( self.pass_manager = resolve_obj_by_qualname(
...@@ -797,7 +798,7 @@ class VllmBackend: ...@@ -797,7 +798,7 @@ class VllmBackend:
or not self.compilation_config.cudagraph_copy_inputs or not self.compilation_config.cudagraph_copy_inputs
): ):
return VllmSerializableFunction( return VllmSerializableFunction(
graph, example_inputs, self.prefix, self.split_gm graph, example_inputs, self.prefix, self.split_gm, self.is_encoder
) )
# index of tensors that have symbolic shapes (batch size) # index of tensors that have symbolic shapes (batch size)
...@@ -835,5 +836,5 @@ class VllmBackend: ...@@ -835,5 +836,5 @@ class VllmBackend:
return self.split_gm(*list_args) return self.split_gm(*list_args)
return VllmSerializableFunction( return VllmSerializableFunction(
graph, example_inputs, self.prefix, copy_and_call graph, example_inputs, self.prefix, copy_and_call, self.is_encoder
) )
...@@ -37,12 +37,15 @@ class VllmSerializableFunction(SerializableCallable): ...@@ -37,12 +37,15 @@ class VllmSerializableFunction(SerializableCallable):
serializing the Dynamo fx graph plus example inputs. serializing the Dynamo fx graph plus example inputs.
""" """
def __init__(self, graph_module, example_inputs, prefix, optimized_call): def __init__(
self, graph_module, example_inputs, prefix, optimized_call, is_encoder=False
):
assert isinstance(graph_module, torch.fx.GraphModule) assert isinstance(graph_module, torch.fx.GraphModule)
self.graph_module = graph_module self.graph_module = graph_module
self.example_inputs = example_inputs self.example_inputs = example_inputs
self.prefix = prefix self.prefix = prefix
self.optimized_call = optimized_call self.optimized_call = optimized_call
self.is_encoder = is_encoder
self.shape_env = None self.shape_env = None
sym_input = next( sym_input = next(
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
...@@ -106,7 +109,10 @@ class VllmSerializableFunction(SerializableCallable): ...@@ -106,7 +109,10 @@ class VllmSerializableFunction(SerializableCallable):
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode) state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
state["graph_module"].recompile() state["graph_module"].recompile()
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode) state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
vllm_backend = VllmBackend(get_current_vllm_config(), state["prefix"]) is_encoder = state.get("is_encoder", False)
vllm_backend = VllmBackend(
get_current_vllm_config(), state["prefix"], is_encoder
)
def optimized_call(*example_inputs): def optimized_call(*example_inputs):
""" """
......
...@@ -170,8 +170,7 @@ class PiecewiseBackend: ...@@ -170,8 +170,7 @@ class PiecewiseBackend:
range_entry = self._find_range_for_shape(runtime_shape) range_entry = self._find_range_for_shape(runtime_shape)
assert range_entry is not None, ( assert range_entry is not None, (
f"Shape out of considered range: {runtime_shape} " f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
"[1, max_num_batched_tokens]"
) )
self._maybe_compile_for_range_entry(range_entry, args) self._maybe_compile_for_range_entry(range_entry, args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment