Unverified Commit 51ef828f authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[torch.compile] fix sym_tensor_indices (#12191)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent df450aa5
......@@ -624,9 +624,13 @@ class VllmBackend:
]
# index of tensors that have symbolic shapes (batch size)
# for weights and static buffers, they will have concrete shapes.
# symbolic shape only happens for input tensors.
from torch.fx.experimental.symbolic_shapes import is_symbolic
self.sym_tensor_indices = [
i for i, x in enumerate(fake_args)
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \
any(is_symbolic(d) for d in x.size())
]
# compiler managed cudagraph input buffers
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment