Commit da9d1627 authored by Yan Li, committed by Facebook GitHub Bot

Fix hybrid demucs tutorial for CUDA (#3017)

Summary:
Currently, running this tutorial on a CUDA device raises a few errors.

The causes are:
- The source audio waveform is never actually moved to the GPU. `Tensor.to()` is not in-place, so its return value must be assigned back to the variable; otherwise the Tensor stays on the CPU.
- For further analysis and display of the output audio, the Tensors need to be moved back from the GPU to the CPU, because some of the functions we call require CPU Tensors (e.g. `stft()` and `bss_eval_sources()`).
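
The two points above can be illustrated with a minimal sketch. It is not taken from the tutorial: the random `waveform` stands in for the loaded audio, the slice bounds are arbitrary, and the script falls back to the CPU when no CUDA device is present.

```python
import torch

# Use a CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for a loaded stereo waveform (channels, frames); not the tutorial's audio.
waveform = torch.randn(2, 16000)

# Tensor.to() returns a tensor; it does not modify `waveform` in place.
waveform.to(device)  # result discarded, so `waveform` is still on the CPU
assert waveform.device.type == "cpu"

# Rebinding the name to the returned tensor is what actually moves the data.
waveform = waveform.to(device)
assert waveform.device.type == device.type

# NumPy-based code cannot read GPU memory, so copy the slice back to the CPU first.
segment = waveform[:, 1000:2000].cpu()
segment_np = segment.numpy()
print(segment_np.shape)  # (2, 1000)
```

On a machine without CUDA both `to()` calls are no-ops, so the sketch still runs; the difference between the two calls only becomes visible when `device` is a GPU.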

Pull Request resolved: https://github.com/pytorch/audio/pull/3017

Reviewed By: mthrok

Differential Revision: D42828526

Pulled By: nateanl

fbshipit-source-id: c28bc855e79e3363a011f4a35a69aae1764e7762
parent 635d8cff
@@ -208,7 +208,7 @@ def plot_spectrogram(stft, title="Spectrogram"):
 # We download the audio file from our storage. Feel free to download another file and use audio from a specific path
 SAMPLE_SONG = download_asset("tutorial-assets/hdemucs_mix.wav")
 waveform, sample_rate = torchaudio.load(SAMPLE_SONG)  # replace SAMPLE_SONG with desired path for different song
-waveform.to(device)
+waveform = waveform.to(device)
 mixture = waveform
 # parameters
@@ -285,23 +285,19 @@ bass_original = download_asset("tutorial-assets/hdemucs_bass_segment.wav")
 vocals_original = download_asset("tutorial-assets/hdemucs_vocals_segment.wav")
 other_original = download_asset("tutorial-assets/hdemucs_other_segment.wav")
-drums_spec = audios["drums"][:, frame_start: frame_end]
+drums_spec = audios["drums"][:, frame_start: frame_end].cpu()
 drums, sample_rate = torchaudio.load(drums_original)
-drums.to(device)
-bass_spec = audios["bass"][:, frame_start: frame_end]
+bass_spec = audios["bass"][:, frame_start: frame_end].cpu()
 bass, sample_rate = torchaudio.load(bass_original)
-bass.to(device)
-vocals_spec = audios["vocals"][:, frame_start: frame_end]
+vocals_spec = audios["vocals"][:, frame_start: frame_end].cpu()
 vocals, sample_rate = torchaudio.load(vocals_original)
-vocals.to(device)
-other_spec = audios["other"][:, frame_start: frame_end]
+other_spec = audios["other"][:, frame_start: frame_end].cpu()
 other, sample_rate = torchaudio.load(other_original)
-other.to(device)
-mix_spec = mixture[:, frame_start: frame_end]
+mix_spec = mixture[:, frame_start: frame_end].cpu()
 ######################################################################
......