Commit 3bd769f8 authored by moto, committed by Facebook GitHub Bot

Default to float64 for cumsum in oscillator_bank (#3083)

Summary:
oscillator_bank performs a cumulative sum over a large number of elements, and float32 typically does not provide enough precision for the accumulated phase.

This PR makes the cumsum operation default to float64, so that the resulting waveform is more accurate.
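To illustrate the motivation (this sketch is not part of the commit), accumulating many small per-sample phase increments in float32 drifts measurably, while passing ``dtype=torch.float64`` to ``torch.cumsum`` — the same mechanism this PR uses — keeps the accumulation accurate:

```python
import torch

sample_rate = 44100.0
num_samples = 8 * 44100  # 8 seconds of audio
freq = 440.0

# Per-sample phase increment for a constant-frequency oscillator.
step = torch.full((num_samples,), 2.0 * torch.pi * freq / sample_rate, dtype=torch.float32)

# Accumulate in float32 vs. in float64 (via the cumsum dtype argument).
phases32 = torch.cumsum(step, dim=0)
phases64 = torch.cumsum(step, dim=0, dtype=torch.float64)

# Maximum error of the float32 accumulation against the float64 reference.
err = (phases32.to(torch.float64) - phases64).abs().max().item()
print(err)
```

By the end of the signal the accumulated phase is in the tens of thousands of radians, where a single float32 ULP is already on the order of 1e-3, so the float32 path visibly degrades the synthesized sine.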

Pull Request resolved: https://github.com/pytorch/audio/pull/3083

Reviewed By: nateanl

Differential Revision: D44257182

Pulled By: mthrok

fbshipit-source-id: a38a465d33559a415e8c744e61292f4fab64b0e1
parent 28192ff4
@@ -11,6 +11,7 @@ def oscillator_bank(
     amplitudes: torch.Tensor,
     sample_rate: float,
     reduction: str = "sum",
+    dtype: Optional[torch.dtype] = torch.float64,
 ) -> torch.Tensor:
     """Synthesize waveform from the given instantaneous frequencies and amplitudes.
@@ -38,6 +39,8 @@ def oscillator_bank(
         sample_rate (float): Sample rate
         reduction (str): Reduction to perform.
             Valid values are ``"sum"``, ``"mean"`` or ``"none"``. Default: ``"sum"``
+        dtype (torch.dtype or None, optional): The data type on which cumulative sum operation is performed.
+            Default: ``torch.float64``. Pass ``None`` to disable the casting.
     Returns:
         Tensor:
@@ -64,16 +67,11 @@ def oscillator_bank(
     )
     amplitudes = torch.where(invalid, 0.0, amplitudes)
-    # Note:
-    # In magenta/ddsp, there is an option to reduce the number of summation to reduce
-    # the accumulation error.
-    # https://github.com/magenta/ddsp/blob/7cb3c37f96a3e5b4a2b7e94fdcc801bfd556021b/ddsp/core.py#L950-L955
-    # It mentions some performance penalty.
-    # In torchaudio, a simple way to work around is to use float64.
-    # We might add angular_cumsum if it turned out to be undesirable.
     pi2 = 2.0 * torch.pi
     freqs = frequencies * pi2 / sample_rate % pi2
-    phases = torch.cumsum(freqs, dim=-2)
+    phases = torch.cumsum(freqs, dim=-2, dtype=dtype)
+    if dtype is not None and freqs.dtype != dtype:
+        phases = phases.to(freqs.dtype)
     waveform = amplitudes * torch.sin(phases)
     if reduction == "sum":