Unverified Commit 90200860 authored by Aditya Raj's avatar Aditya Raj Committed by GitHub
Browse files

[BUG FIX] [Stable Audio Pipeline] Resolve torch.Tensor.new_zeros() TypeError...


[BUG FIX] [Stable Audio Pipeline] Resolve torch.Tensor.new_zeros() TypeError in function prepare_latents caused by audio_vae_length (#10306)

[BUG FIX] [Stable Audio Pipeline] TypeError: new_zeros(): argument 'size' failed to unpack the object at pos 3 with error "type must be tuple of ints,but got float"

torch.Tensor.new_zeros() takes a single argument size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.

in function prepare_latents:
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
...
audio = initial_audio_waveforms.new_zeros(audio_shape)

audio_vae_length evaluates to float because self.transformer.config.sample_size returns a float
Co-authored-by: default avatarhlky <hlky@hlky.ac>
parent c8ee4af2
......@@ -446,7 +446,7 @@ class StableAudioPipeline(DiffusionPipeline):
f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
)
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
audio_vae_length = int(self.transformer.config.sample_size) * self.vae.hop_length
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
# check num_channels
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment