Unverified Commit d2f7486c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

convert torch.split return to list in RAFT (#7597)

parent 689ff292
......@@ -450,7 +450,12 @@ class RaftStereo(nn.Module):
hidden_state, context = torch.split(context_outs[i], [hidden_dims[i], context_out_channels[i]], dim=1)
hidden_states.append(torch.tanh(hidden_state))
contexts.append(
torch.split(context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1)
# mypy is technically correct here. The return type of `torch.split` was incorrectly annotated with
# `List[int]` although it should have been `Tuple[Tensor, ...]`. However, the latter is not supported by
# JIT and thus we have to keep the wrong annotation here and silence mypy.
torch.split( # type: ignore[arg-type]
context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1
)
)
_, Cf, Hf, Wf = fmap1.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment