"vscode:/vscode.git/clone" did not exist on "7c2e91c4e0414d076033d038b3c4acf236198cf3"
Unverified Commit db8d4a4a authored by Chauncey's avatar Chauncey Committed by GitHub
Browse files

[BugFix][Graph] fix: handle empty sym_shape_indices in PiecewiseBackend. (#39395)


Signed-off-by: default avatarchaunceyjiang <chaunceyjiang@gmail.com>
parent fc701c80
......@@ -222,3 +222,47 @@ def test_model_specialization_with_evaluate_guards(
torch.randn(1, 10).cuda(),
is_01_specialization=True,
)
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_piecewise_backend_empty_sym_shape_indices():
"""Test that PiecewiseBackend handles empty sym_shape_indices correctly.
When all inputs have static shapes (no torch.SymInt), sym_shape_indices
will be empty. The fix in PiecewiseBackend.__call__ handles this case
by using the first compiled range_entry.
"""
gc.collect()
torch.accelerator.empty_cache()
torch.accelerator.synchronize()
# Use small max_model_len and max_num_batched_tokens to encourage
# static shape compilation with empty sym_shape_indices
llm = LLM(
model="Qwen/Qwen3-0.6B",
max_model_len=512,
max_num_batched_tokens=1,
compilation_config={
"mode": CompilationMode.VLLM_COMPILE,
"dynamic_shapes_config": {
"type": DynamicShapesType.BACKED.value,
},
},
)
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
# Generate with static shape inputs
output = llm.generate("Hello, my name is", sampling_params=sampling_params)
result = output[0].outputs[0].text
assert len(result) > 0, "Should generate non-empty output"
# Generate again to verify compilation works with empty sym_shape_indices
output = llm.generate("The capital of France is", sampling_params=sampling_params)
result = output[0].outputs[0].text
assert len(result) > 0, "Should generate non-empty output on second run"
del llm
gc.collect()
torch.accelerator.empty_cache()
torch.accelerator.synchronize()
......@@ -354,12 +354,22 @@ class PiecewiseBackend:
return None
def __call__(self, *args: Any) -> Any:
runtime_shape = args[self.sym_shape_indices[0]]
range_entry = self._find_range_for_shape(runtime_shape)
if self.sym_shape_indices:
runtime_shape = args[self.sym_shape_indices[0]]
range_entry = self._find_range_for_shape(runtime_shape)
assert range_entry is not None, (
f"Shape: {runtime_shape} out of considered ranges: "
f"{self.compile_ranges}"
)
else:
# All inputs have static shapes; use the only compiled range_entry
compiled_entries = [re for re in self.range_entries.values() if re.compiled]
assert len(compiled_entries) == 1, (
f"Expected exactly one compiled range_entry for static shape "
f"compilation, but found {len(compiled_entries)}"
)
range_entry = compiled_entries[0]
assert range_entry is not None, (
f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
)
assert range_entry.compiled, (
"All ranges should be compiled or loaded up front in "
"PiecewiseBackend.__init__. "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment