Commit 27aa52fb authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Fix style issue (#3410)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3410

Differential Revision: D46496786

Pulled By: mthrok

fbshipit-source-id: e517b273c40b340f39ce7db7ab1be1c3eb5f2059
parent 23e756af
...@@ -27,5 +27,5 @@ def ensemble(args): ...@@ -27,5 +27,5 @@ def ensemble(args):
os.path.join(args.exp_dir, args.experiment_name, f"epoch={n}.ckpt") os.path.join(args.exp_dir, args.experiment_name, f"epoch={n}.ckpt")
for n in range(args.epochs - 10, args.epochs) for n in range(args.epochs - 10, args.epochs)
] ]
model_path = os.path.join(args.exp_dir, args.experiment_name, f"model_avg_10.pth") model_path = os.path.join(args.exp_dir, args.experiment_name, "model_avg_10.pth")
torch.save({"state_dict": average_checkpoints(last)}, model_path) torch.save({"state_dict": average_checkpoints(last)}, model_path)
import os
import random import random
import torch import torch
import torchaudio
from lrs3 import LRS3 from lrs3 import LRS3
from pytorch_lightning import LightningDataModule from pytorch_lightning import LightningDataModule
......
import itertools import itertools
import logging
import math import math
from collections import namedtuple from collections import namedtuple
......
import itertools import itertools
import logging
import math import math
from collections import namedtuple from collections import namedtuple
......
import os import os
from pathlib import Path
from typing import Tuple, Union
import torch
import torchaudio import torchaudio
import torchvision import torchvision
from torch import Tensor
from torch.utils.data import Dataset from torch.utils.data import Dataset
...@@ -15,7 +11,7 @@ def _load_list(args, *filenames): ...@@ -15,7 +11,7 @@ def _load_list(args, *filenames):
for filename in filenames: for filename in filenames:
filepath = os.path.join(os.path.dirname(args.dataset_path), filename) filepath = os.path.join(os.path.dirname(args.dataset_path), filename)
for line in open(filepath).read().splitlines(): for line in open(filepath).read().splitlines():
dataset_name, rel_path, input_length = line.split(",")[0], line.split(",")[1], line.split(",")[2] rel_path, input_length = line.split(",")[1:3]
path = os.path.normpath(os.path.join(args.dataset_path, rel_path[:-4] + ".mp4")) path = os.path.normpath(os.path.join(args.dataset_path, rel_path[:-4] + ".mp4"))
length.append(int(input_length)) length.append(int(input_length))
output.append(path) output.append(path)
......
from torchaudio.models.rnnt import emformer_rnnt_model from torchaudio.models.rnnt import emformer_rnnt_model
# https://pytorch.org/audio/master/_modules/torchaudio/models/rnnt.html#emformer_rnnt_base # https://pytorch.org/audio/master/_modules/torchaudio/models/rnnt.html#emformer_rnnt_base
def emformer_rnnt(): def emformer_rnnt():
return emformer_rnnt_model( return emformer_rnnt_model(
......
import logging import logging
import os import os
import pathlib
from argparse import ArgumentParser from argparse import ArgumentParser
import sentencepiece as spm import sentencepiece as spm
...@@ -19,7 +18,7 @@ def get_trainer(args): ...@@ -19,7 +18,7 @@ def get_trainer(args):
monitor="monitoring_step", monitor="monitoring_step",
mode="max", mode="max",
save_last=True, save_last=True,
filename=f"{{epoch}}", filename="{{epoch}}",
save_top_k=10, save_top_k=10,
) )
lr_monitor = LearningRateMonitor(logging_interval="step") lr_monitor = LearningRateMonitor(logging_interval="step")
......
import json
import math
import random import random
from functools import partial
from typing import List from typing import List
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torchaudio
import torchvision import torchvision
from data_module import LRS3DataModule from data_module import LRS3DataModule
from lightning import Batch from lightning import Batch
......
...@@ -2,7 +2,6 @@ import itertools ...@@ -2,7 +2,6 @@ import itertools
from functools import partial from functools import partial
import torch import torch
import torchaudio
from parameterized import parameterized from parameterized import parameterized
from torchaudio._backend.utils import get_load_func from torchaudio._backend.utils import get_load_func
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
......
import io
import itertools import itertools
import tarfile
import torch import torch
import torchaudio
from parameterized import parameterized from parameterized import parameterized
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import sox_io_backend from torchaudio.backend import sox_io_backend
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
get_asset_path, get_asset_path,
get_wav_data, get_wav_data,
HttpServerMixin,
load_wav, load_wav,
nested_params, nested_params,
PytorchTestCase, PytorchTestCase,
save_wav, save_wav,
skipIfNoExec, skipIfNoExec,
skipIfNoModule,
skipIfNoSox, skipIfNoSox,
sox_utils, sox_utils,
TempDirMixin, TempDirMixin,
...@@ -25,10 +19,6 @@ from torchaudio_unittest.common_utils import ( ...@@ -25,10 +19,6 @@ from torchaudio_unittest.common_utils import (
from .common import name_func from .common import name_func
if _mod_utils.is_module_available("requests"):
import requests
class LoadTestBase(TempDirMixin, PytorchTestCase): class LoadTestBase(TempDirMixin, PytorchTestCase):
def assert_format( def assert_format(
self, self,
......
import io
import os import os
import torch import torch
......
import io
import itertools import itertools
from parameterized import parameterized from parameterized import parameterized
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment