Unverified Commit 51ef828f authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[torch.compile] fix sym_tensor_indices (#12191)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent df450aa5
...@@ -624,9 +624,13 @@ class VllmBackend: ...@@ -624,9 +624,13 @@ class VllmBackend:
] ]
# index of tensors that have symbolic shapes (batch size) # index of tensors that have symbolic shapes (batch size)
# for weights and static buffers, they will have concrete shapes.
# symbolic shape only happens for input tensors.
from torch.fx.experimental.symbolic_shapes import is_symbolic
self.sym_tensor_indices = [ self.sym_tensor_indices = [
i for i, x in enumerate(fake_args) i for i, x in enumerate(fake_args)
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \
any(is_symbolic(d) for d in x.size())
] ]
# compiler managed cudagraph input buffers # compiler managed cudagraph input buffers
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment