Commit f9d38796 authored by Robin Scheibler's avatar Robin Scheibler Committed by Facebook GitHub Bot
Browse files

Fixes examples/source_separation for WSJ0_2mix dataset (#2987)

Summary:
The `examples/source_separation` scripts use inconsistent keyword to indicate the WSJ0_2mix dataset. This PR does the following.

1. Use `wsj0mix` consistently as keyword indicating the WSJ0_2mix dataset
2. Corrects `args.data_dir` to `args.root_dir` in eval.py
3. Modify the parameters of `pytorch_lightning.Trainer` according to latest version (use `accelerator="gpu"` and `devices=args.num_devices`, instead of just `gpus=args.num_devices`)

Pull Request resolved: https://github.com/pytorch/audio/pull/2987

Reviewed By: xiaohui-zhang

Differential Revision: D42536992

Pulled By: nateanl

fbshipit-source-id: 10a80263ad7054b1629d8fa023676b607e633d76
parent 52b6bc3b
...@@ -31,7 +31,7 @@ def _eval(model, data_loader, device): ...@@ -31,7 +31,7 @@ def _eval(model, data_loader, device):
def cli_main(): def cli_main():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0-mix", "librimix"]) parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0mix", "librimix"])
parser.add_argument( parser.add_argument(
"--root-dir", "--root-dir",
type=Path, type=Path,
...@@ -79,7 +79,7 @@ def cli_main(): ...@@ -79,7 +79,7 @@ def cli_main():
_, _, eval_loader = _get_dataloader( _, _, eval_loader = _get_dataloader(
args.dataset, args.dataset,
args.data_dir, args.root_dir,
args.num_speakers, args.num_speakers,
args.sample_rate, args.sample_rate,
1, # batch size is set to 1 to avoid masking 1, # batch size is set to 1 to avoid masking
......
...@@ -308,7 +308,7 @@ def _get_dataloader( ...@@ -308,7 +308,7 @@ def _get_dataloader(
def cli_main(): def cli_main():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--batch-size", default=6, type=int) parser.add_argument("--batch-size", default=6, type=int)
parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0-mix", "librimix"]) parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0mix", "librimix"])
parser.add_argument( parser.add_argument(
"--root-dir", "--root-dir",
type=Path, type=Path,
...@@ -412,9 +412,10 @@ def cli_main(): ...@@ -412,9 +412,10 @@ def cli_main():
trainer = Trainer( trainer = Trainer(
default_root_dir=args.exp_dir, default_root_dir=args.exp_dir,
max_epochs=args.epochs, max_epochs=args.epochs,
gpus=args.num_gpu,
num_nodes=args.num_node, num_nodes=args.num_node,
accelerator="gpu",
strategy="ddp_find_unused_parameters_false", strategy="ddp_find_unused_parameters_false",
devices=args.num_gpu,
limit_train_batches=1.0, # Useful for fast experiment limit_train_batches=1.0, # Useful for fast experiment
gradient_clip_val=5.0, gradient_clip_val=5.0,
callbacks=callbacks, callbacks=callbacks,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment