Unverified Commit 38bce1d4 authored by Basile Van Hoorick's avatar Basile Van Hoorick Committed by GitHub
Browse files

Make `pos` optional to avoid crashing `PerceiverModel` operation (#15972)

Updates `PerceiverAudioPreprocessor` `forward()` implementation to match most other preprocessors / postprocessors
parent cec89e1a
...@@ -3264,7 +3264,7 @@ class PerceiverAudioPreprocessor(AbstractPreprocessor): ...@@ -3264,7 +3264,7 @@ class PerceiverAudioPreprocessor(AbstractPreprocessor):
return inputs_with_pos, inputs return inputs_with_pos, inputs
def forward(self, inputs, pos, network_input_is_1d: bool = True): def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
inputs = torch.reshape(inputs, [inputs.shape[0], -1, self.samples_per_patch]) inputs = torch.reshape(inputs, [inputs.shape[0], -1, self.samples_per_patch])
inputs, inputs_without_pos = self._build_network_inputs(inputs, pos) inputs, inputs_without_pos = self._build_network_inputs(inputs, pos)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment