Unverified Commit 678bb248 authored by Josh Devins's avatar Josh Devins Committed by GitHub
Browse files

Make assertions only if actually chunking forward (#13598)

This moves the assertion on checking input dimensions into a block that will only be called if the function is actually going to do chunking forward. This is often not the case at inference time and PyTorch tracing a model with this assertion in it leads to a tracing warning.

TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  input_tensor.shape[chunk_dim] == tensor_shape for input_tensor in input_tensors
parent 4a320f6c
......@@ -2294,10 +2294,6 @@ def apply_chunking_to_forward(
"""
assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
tensor_shape = input_tensors[0].shape[chunk_dim]
assert all(
input_tensor.shape[chunk_dim] == tensor_shape for input_tensor in input_tensors
), "All input tenors have to be of the same shape"
# inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility
num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
......@@ -2308,6 +2304,14 @@ def apply_chunking_to_forward(
)
if chunk_size > 0:
tensor_shape = input_tensors[0].shape[chunk_dim]
for input_tensor in input_tensors:
if input_tensor.shape[chunk_dim] != tensor_shape:
raise ValueError(
f"All input tenors have to be of the same shape: {tensor_shape}, "
f"found shape {input_tensor.shape[chunk_dim]}"
)
if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment