Unverified Commit e3247e30 authored by Bhargav Kathivarapu's avatar Bhargav Kathivarapu Committed by GitHub
Browse files

Fix flanger and phaser for GPU (#702)



* Make `_generate_wave_table` device and dtype aware.
Signed-off-by: default avatarBhargav Kathivarapu <bhargavkathivarapu31@gmail.com>
parent 0bd91484
......@@ -1317,7 +1317,6 @@ def phaser(
delay_buf = torch.zeros(waveform.shape[0], delay_buf_len, dtype=dtype, device=device)
mod_buf_len = int(sample_rate / mod_speed + .5)
mod_buf = torch.zeros(mod_buf_len, dtype=dtype, device=device)
if sinusoidal:
wave_type = 'SINE'
......@@ -1329,7 +1328,8 @@ def phaser(
table_size=mod_buf_len,
min=1.,
max=float(delay_buf_len),
phase=math.pi / 2)
phase=math.pi / 2,
device=device)
delay_pos = 0
mod_pos = 0
......@@ -1353,7 +1353,8 @@ def _generate_wave_table(
table_size: int,
min: float,
max: float,
phase: float
phase: float,
device: torch.device
) -> Tensor:
r"""A helper fucntion for phaser. Generates a table with given parameters
......@@ -1364,18 +1365,18 @@ def _generate_wave_table(
min (float): desired min value
max (float): desired max value
phase (float): desired phase
device (torch.device): Torch device on which table must be generated
Returns:
Tensor: A 1D tensor with wave table values
"""
phase_offset = int(phase / math.pi / 2 * table_size + 0.5)
t = torch.arange(table_size).to(torch.int32)
t = torch.arange(table_size, device=device, dtype=torch.int32)
point = (t + phase_offset) % table_size
d = torch.zeros_like(point).to(torch.float64)
d = torch.zeros_like(point, device=device, dtype=torch.float64)
if wave_type == 'SINE':
d = (torch.sin(point.to(torch.float64) / table_size * 2 * math.pi) + 1) / 2
......@@ -1487,8 +1488,6 @@ def flanger(
lfo_length = int(sample_rate / speed)
lfo = torch.zeros(lfo_length, dtype=dtype, device=device)
table_min = math.floor(delay_min * sample_rate + 0.5)
table_max = delay_buf_length - 2.
......@@ -1497,13 +1496,14 @@ def flanger(
table_size=lfo_length,
min=float(table_min),
max=float(table_max),
phase=3 * math.pi / 2)
phase=3 * math.pi / 2,
device=device)
output_waveform = torch.zeros_like(waveform, dtype=dtype, device=device)
delay_buf_pos = 0
lfo_pos = 0
channel_idxs = torch.arange(0, n_channels)
channel_idxs = torch.arange(0, n_channels, device=device)
for i in range(waveform.shape[-1]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment