Unverified Commit 560c082e authored by moto's avatar moto Committed by GitHub
Browse files

[Fbsync] Lint fix (#1726)

parent 4915524f
......@@ -253,4 +253,4 @@ python inference.py --checkpoint-path ${model_path} \
--input-text "Hello world!" \
--text-preprocessor english_characters \
--output-path "./outputs.wav"
```
\ No newline at end of file
```
......@@ -5,6 +5,9 @@ import numpy as np
from torchaudio.functional import rnnt_loss
CPU_DEVICE = torch.device("cpu")
class _NumpyTransducer(torch.autograd.Function):
@staticmethod
def forward(
......@@ -240,7 +243,7 @@ def get_basic_data(device):
def get_B1_T10_U3_D4_data(
random=False,
dtype=torch.float32,
device=torch.device("cpu"),
device=CPU_DEVICE,
):
B, T, U, D = 2, 10, 3, 4
......@@ -263,7 +266,7 @@ def get_B1_T10_U3_D4_data(
return data
def get_B1_T2_U3_D5_data(dtype=torch.float32, device=torch.device("cpu")):
def get_B1_T2_U3_D5_data(dtype=torch.float32, device=CPU_DEVICE):
logits = torch.tensor(
[
0.1,
......@@ -360,7 +363,7 @@ def get_B1_T2_U3_D5_data(dtype=torch.float32, device=torch.device("cpu")):
return data, ref_costs, ref_gradients
def get_B2_T4_U3_D3_data(dtype=torch.float32, device=torch.device("cpu")):
def get_B2_T4_U3_D3_data(dtype=torch.float32, device=CPU_DEVICE):
# Test from D21322854
logits = torch.tensor(
[
......@@ -550,7 +553,7 @@ def get_random_data(
max_D=40,
blank=-1,
dtype=torch.float32,
device=torch.device("cpu"),
device=CPU_DEVICE,
seed=None,
):
if seed is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment