Unverified Commit 9078c0b9 authored by yangarbiter's avatar yangarbiter Committed by GitHub
Browse files

Add style checks on example files on CI (#1667)

parent 16f3b2f9
...@@ -29,7 +29,7 @@ set +e ...@@ -29,7 +29,7 @@ set +e
exit_status=0 exit_status=0
printf "\x1b[34mRunning flake8:\x1b[0m\n" printf "\x1b[34mRunning flake8:\x1b[0m\n"
flake8 torchaudio test build_tools/setup_helpers docs/source/conf.py flake8 torchaudio test build_tools/setup_helpers docs/source/conf.py examples
status=$? status=$?
exit_status="$((exit_status+status))" exit_status="$((exit_status+status))"
if [ "${status}" -ne 0 ]; then if [ "${status}" -ne 0 ]; then
......
from . import utils, vad from . import utils, vad
__all__ = ['utils', 'vad']
...@@ -11,7 +11,7 @@ example: python parse_voxforge.py voxforge/de/Helge-20150608-aku ...@@ -11,7 +11,7 @@ example: python parse_voxforge.py voxforge/de/Helge-20150608-aku
... ...
Dataset can be obtained from http://www.repository.voxforge1.org/downloads/de/Trunk/Audio/Main/16kHz_16bit/ Dataset can be obtained from http://www.repository.voxforge1.org/downloads/de/Trunk/Audio/Main/16kHz_16bit/
""" """ # noqa: E501
import os import os
import argparse import argparse
from pathlib import Path from pathlib import Path
......
...@@ -2,3 +2,5 @@ from . import ( ...@@ -2,3 +2,5 @@ from . import (
train, train,
trainer, trainer,
) )
__all__ = ['train', 'trainer']
...@@ -63,7 +63,7 @@ def _parse_args(args): ...@@ -63,7 +63,7 @@ def _parse_args(args):
group.add_argument( group.add_argument(
"--batch-size", "--batch-size",
type=int, type=int,
help=f"Batch size. (default: 16 // world_size)", help="Batch size. (default: 16 // world_size)",
) )
group = parser.add_argument_group("Training Options") group = parser.add_argument_group("Training Options")
group.add_argument( group.add_argument(
...@@ -223,7 +223,7 @@ def train(args): ...@@ -223,7 +223,7 @@ def train(args):
optimizer.load_state_dict(checkpoint["optimizer"]) optimizer.load_state_dict(checkpoint["optimizer"])
else: else:
dist_utils.synchronize_params( dist_utils.synchronize_params(
str(args.save_dir / f"tmp.pt"), device, model, optimizer str(args.save_dir / "tmp.pt"), device, model, optimizer
) )
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
...@@ -258,7 +258,7 @@ def train(args): ...@@ -258,7 +258,7 @@ def train(args):
debug=args.debug, debug=args.debug,
) )
log_path = args.save_dir / f"log.csv" log_path = args.save_dir / "log.csv"
_write_header(log_path, args) _write_header(log_path, args)
dist_utils.write_csv_on_master( dist_utils.write_csv_on_master(
log_path, log_path,
......
...@@ -3,3 +3,5 @@ from . import ( ...@@ -3,3 +3,5 @@ from . import (
dist_utils, dist_utils,
metrics, metrics,
) )
__all__ = ['dataset', 'dist_utils', 'metrics']
from . import utils, wsj0mix from . import utils, wsj0mix
__all__ = ['utils', 'wsj0mix']
...@@ -43,7 +43,6 @@ def _fix_num_frames(sample: wsj0mix.SampleType, target_num_frames: int, random_s ...@@ -43,7 +43,6 @@ def _fix_num_frames(sample: wsj0mix.SampleType, target_num_frames: int, random_s
return mix, src, mask return mix, src, mask
def collate_fn_wsj0mix_train(samples: List[wsj0mix.SampleType], sample_rate, duration): def collate_fn_wsj0mix_train(samples: List[wsj0mix.SampleType], sample_rate, duration):
target_num_frames = int(duration * sample_rate) target_num_frames = int(duration * sample_rate)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment