Commit 27aa52fb authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Fix style issue (#3410)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3410

Differential Revision: D46496786

Pulled By: mthrok

fbshipit-source-id: e517b273c40b340f39ce7db7ab1be1c3eb5f2059
parent 23e756af
......@@ -27,5 +27,5 @@ def ensemble(args):
os.path.join(args.exp_dir, args.experiment_name, f"epoch={n}.ckpt")
for n in range(args.epochs - 10, args.epochs)
]
model_path = os.path.join(args.exp_dir, args.experiment_name, f"model_avg_10.pth")
model_path = os.path.join(args.exp_dir, args.experiment_name, "model_avg_10.pth")
torch.save({"state_dict": average_checkpoints(last)}, model_path)
import os
import random
import torch
import torchaudio
from lrs3 import LRS3
from pytorch_lightning import LightningDataModule
......
import itertools
import logging
import math
from collections import namedtuple
......
import itertools
import logging
import math
from collections import namedtuple
......
import os
from pathlib import Path
from typing import Tuple, Union
import torch
import torchaudio
import torchvision
from torch import Tensor
from torch.utils.data import Dataset
......@@ -15,7 +11,7 @@ def _load_list(args, *filenames):
for filename in filenames:
filepath = os.path.join(os.path.dirname(args.dataset_path), filename)
for line in open(filepath).read().splitlines():
dataset_name, rel_path, input_length = line.split(",")[0], line.split(",")[1], line.split(",")[2]
rel_path, input_length = line.split(",")[1:3]
path = os.path.normpath(os.path.join(args.dataset_path, rel_path[:-4] + ".mp4"))
length.append(int(input_length))
output.append(path)
......
from torchaudio.models.rnnt import emformer_rnnt_model
# https://pytorch.org/audio/master/_modules/torchaudio/models/rnnt.html#emformer_rnnt_base
def emformer_rnnt():
return emformer_rnnt_model(
......
import logging
import os
import pathlib
from argparse import ArgumentParser
import sentencepiece as spm
......@@ -19,7 +18,7 @@ def get_trainer(args):
monitor="monitoring_step",
mode="max",
save_last=True,
filename=f"{{epoch}}",
filename="{{epoch}}",
save_top_k=10,
)
lr_monitor = LearningRateMonitor(logging_interval="step")
......
import json
import math
import random
from functools import partial
from typing import List
import sentencepiece as spm
import torch
import torchaudio
import torchvision
from data_module import LRS3DataModule
from lightning import Batch
......
......@@ -2,7 +2,6 @@ import itertools
from functools import partial
import torch
import torchaudio
from parameterized import parameterized
from torchaudio._backend.utils import get_load_func
from torchaudio_unittest.common_utils import (
......
import io
import itertools
import tarfile
import torch
import torchaudio
from parameterized import parameterized
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import sox_io_backend
from torchaudio_unittest.common_utils import (
get_asset_path,
get_wav_data,
HttpServerMixin,
load_wav,
nested_params,
PytorchTestCase,
save_wav,
skipIfNoExec,
skipIfNoModule,
skipIfNoSox,
sox_utils,
TempDirMixin,
......@@ -25,10 +19,6 @@ from torchaudio_unittest.common_utils import (
from .common import name_func
if _mod_utils.is_module_available("requests"):
import requests
class LoadTestBase(TempDirMixin, PytorchTestCase):
def assert_format(
self,
......
import io
import itertools
from parameterized import parameterized
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment