"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "1bdab9fdb19f8a8c73ed85291f9acea5bc1c7075"
Unverified Commit d41d30ab authored by moto's avatar moto Committed by GitHub
Browse files

Separate CPU and GPU tests for functions torchscript test + Fix devices in two functionals (#528)



* Separate CPU and GPU tests for functions torchscript test

* fix indentation
Co-authored-by: default avatarVincent QB <vincentqb@users.noreply.github.com>
parent d47f42a3
This diff is collapsed.
......@@ -1210,13 +1210,16 @@ def mask_along_axis_iid(
if axis != 2 and axis != 3:
raise ValueError('Only Frequency and Time masking are supported')
value = torch.rand(specgrams.shape[:2]) * mask_param
min_value = torch.rand(specgrams.shape[:2]) * (specgrams.size(axis) - value)
device = specgrams.device
dtype = specgrams.dtype
value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * mask_param
min_value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * (specgrams.size(axis) - value)
# Create broadcastable mask
mask_start = (min_value.long())[..., None, None].float()
mask_end = (min_value.long() + value.long())[..., None, None].float()
mask = torch.arange(0, specgrams.size(axis)).float()
mask_start = min_value[..., None, None]
mask_end = (min_value + value)[..., None, None]
mask = torch.arange(0, specgrams.size(axis), device=device, dtype=dtype)
# Per batch example masking
specgrams = specgrams.transpose(axis, -1)
......@@ -1298,6 +1301,8 @@ def compute_deltas(
>>> delta = compute_deltas(specgram)
>>> delta2 = compute_deltas(delta)
"""
device = specgram.device
dtype = specgram.dtype
# pack batch
shape = specgram.size()
......@@ -1312,7 +1317,7 @@ def compute_deltas(
specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)
kernel = (torch.arange(-n, n + 1, 1, device=specgram.device, dtype=specgram.dtype).repeat(specgram.shape[1], 1, 1))
kernel = torch.arange(-n, n + 1, 1, device=device, dtype=dtype).repeat(specgram.shape[1], 1, 1)
output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom
......@@ -1431,7 +1436,7 @@ def _apply_probability_distribution(
signal_scaled_dis = signal_scaled + gaussian
else:
# dtype needed for https://github.com/pytorch/pytorch/issues/32358
TPDF = torch.bartlett_window(time_size + 1, dtype=torch.float)
TPDF = torch.bartlett_window(time_size + 1, dtype=signal_scaled.dtype, device=signal_scaled.device)
TPDF = TPDF.repeat((channel_size + 1), 1)
signal_scaled_dis = signal_scaled + TPDF
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment