Unverified Commit 560c082e authored by moto's avatar moto Committed by GitHub
Browse files

[Fbsync] Lint fix (#1726)

parent 4915524f
...@@ -253,4 +253,4 @@ python inference.py --checkpoint-path ${model_path} \ ...@@ -253,4 +253,4 @@ python inference.py --checkpoint-path ${model_path} \
--input-text "Hello world!" \ --input-text "Hello world!" \
--text-preprocessor english_characters \ --text-preprocessor english_characters \
--output-path "./outputs.wav" --output-path "./outputs.wav"
``` ```
\ No newline at end of file
...@@ -5,6 +5,9 @@ import numpy as np ...@@ -5,6 +5,9 @@ import numpy as np
from torchaudio.functional import rnnt_loss from torchaudio.functional import rnnt_loss
CPU_DEVICE = torch.device("cpu")
class _NumpyTransducer(torch.autograd.Function): class _NumpyTransducer(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(
...@@ -240,7 +243,7 @@ def get_basic_data(device): ...@@ -240,7 +243,7 @@ def get_basic_data(device):
def get_B1_T10_U3_D4_data( def get_B1_T10_U3_D4_data(
random=False, random=False,
dtype=torch.float32, dtype=torch.float32,
device=torch.device("cpu"), device=CPU_DEVICE,
): ):
B, T, U, D = 2, 10, 3, 4 B, T, U, D = 2, 10, 3, 4
...@@ -263,7 +266,7 @@ def get_B1_T10_U3_D4_data( ...@@ -263,7 +266,7 @@ def get_B1_T10_U3_D4_data(
return data return data
def get_B1_T2_U3_D5_data(dtype=torch.float32, device=torch.device("cpu")): def get_B1_T2_U3_D5_data(dtype=torch.float32, device=CPU_DEVICE):
logits = torch.tensor( logits = torch.tensor(
[ [
0.1, 0.1,
...@@ -360,7 +363,7 @@ def get_B1_T2_U3_D5_data(dtype=torch.float32, device=torch.device("cpu")): ...@@ -360,7 +363,7 @@ def get_B1_T2_U3_D5_data(dtype=torch.float32, device=torch.device("cpu")):
return data, ref_costs, ref_gradients return data, ref_costs, ref_gradients
def get_B2_T4_U3_D3_data(dtype=torch.float32, device=torch.device("cpu")): def get_B2_T4_U3_D3_data(dtype=torch.float32, device=CPU_DEVICE):
# Test from D21322854 # Test from D21322854
logits = torch.tensor( logits = torch.tensor(
[ [
...@@ -550,7 +553,7 @@ def get_random_data( ...@@ -550,7 +553,7 @@ def get_random_data(
max_D=40, max_D=40,
blank=-1, blank=-1,
dtype=torch.float32, dtype=torch.float32,
device=torch.device("cpu"), device=CPU_DEVICE,
seed=None, seed=None,
): ):
if seed is not None: if seed is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment