Commit b1f510fa authored by Sean Kim's avatar Sean Kim Committed by Facebook GitHub Bot
Browse files

Add dimension and shape check (#2563)

Summary:
Don't allow users to input incorrect dimensions

Pull Request resolved: https://github.com/pytorch/audio/pull/2563

Reviewed By: carolineechen

Differential Revision: D38074360

Pulled By: skim0514

fbshipit-source-id: 7bcae515706eb358ca6f68c50c7c0ccace1c3f95
parent 6cee56ab
...@@ -516,8 +516,14 @@ class HDemucs(torch.nn.Module): ...@@ -516,8 +516,14 @@ class HDemucs(torch.nn.Module):
output tensor split into sources of shape `(batch_size, num_sources, channel, num_frames)` output tensor split into sources of shape `(batch_size, num_sources, channel, num_frames)`
""" """
if len(input.shape) == 2: if input.ndim != 3:
input = input.view(1, self.audio_channels, -1) raise ValueError(f"Expected 3D tensor with dimensions (batch, channel, frames). Found: {input.shape}")
if input.shape[1] != self.audio_channels:
raise ValueError(
f"The channel dimension of input Tensor must match `audio_channels` of HDemucs model. "
f"Found:{input.shape[1]}."
)
x = input x = input
length = x.shape[-1] length = x.shape[-1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment