OpenDAS / Torchaudio
Unverified commit ad899151, authored Nov 03, 2021 by nateanl, committed by GitHub Nov 03, 2021

Support multi-node training for source separation pipeline (#1968)

parent 15ab315c
Showing 2 changed files with 15 additions and 3 deletions:

- examples/source_separation/README.md (+8, −3)
- examples/source_separation/lightning_train.py (+7, −0)
examples/source_separation/README.md
```diff
@@ -59,7 +59,10 @@ srun wrapper.sh $@
 num_speakers=2
 this_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" > /dev/null 2>&1 && pwd)"
 exp_dir="/checkpoint/${USER}/exp/"
 dataset_dir="/dataset/Libri${num_speakers}mix//wav8k/min"
 root_dir="/dataset/"  # The directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored.
+num_gpu=2  # The number of GPUs used on one node.
+num_node=1  # The number of nodes used on the cluster.
+batch_size=6  # The batch size per GPU.
 mkdir -p "${exp_dir}"
```
````diff
@@ -68,9 +71,11 @@ python -u \
   "${this_dir}/lightning_train.py" \
   --num-speakers "${num_speakers}" \
   --sample-rate 8000 \
   --data-dir "${dataset_dir}" \
   --root-dir "${root_dir}" \
   --exp-dir "${exp_dir}" \
-  --batch-size $((16 / SLURM_NTASKS))
+  --num-gpu ${num_gpu} \
+  --num-node ${num_node} \
+  --batch-size ${batch_size} \
 ```
 </details>
````
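The README's new variables set the batch size per GPU, so the effective (global) batch size under DDP is that number multiplied by the total process count, `num_gpu * num_node`. A minimal sketch of that arithmetic (the helper name is illustrative, not part of the commit):

```python
def global_batch_size(batch_size: int, num_gpu: int, num_node: int) -> int:
    """Effective batch size under DDP: each of the num_gpu * num_node
    processes consumes batch_size samples per optimizer step."""
    return batch_size * num_gpu * num_node


# With the README's defaults: 6 per GPU, 2 GPUs per node, 1 node.
print(global_batch_size(6, 2, 1))  # -> 12
```

This replaces the previous `$((16 / SLURM_NTASKS))` scheme, where the global batch size was fixed and divided among the SLURM tasks.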
examples/source_separation/lightning_train.py
```diff
@@ -391,6 +391,12 @@ def cli_main():
         type=int,
         help="The number of GPUs for training. (default: 1)",
     )
+    parser.add_argument(
+        "--num-node",
+        default=1,
+        type=int,
+        help="The number of nodes in the cluster for training. (default: 1)",
+    )
     parser.add_argument(
         "--num-workers",
         default=4,
```
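In isolation, the CLI change is a plain `argparse` addition: `--num-node` sits next to the existing `--num-gpu` flag, both integers defaulting to 1. A self-contained sketch (only the two relevant flags shown):

```python
import argparse

# Minimal reproduction of the two multi-node flags: ints, both defaulting to 1.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--num-gpu",
    default=1,
    type=int,
    help="The number of GPUs for training. (default: 1)",
)
parser.add_argument(
    "--num-node",
    default=1,
    type=int,
    help="The number of nodes in the cluster for training. (default: 1)",
)

# argparse maps the dashed flags to underscored attributes.
args = parser.parse_args(["--num-gpu", "2", "--num-node", "4"])
print(args.num_gpu, args.num_node)  # -> 2 4
```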
```diff
@@ -447,6 +453,7 @@ def cli_main():
         default_root_dir=args.exp_dir,
         max_epochs=args.epochs,
         gpus=args.num_gpu,
+        num_nodes=args.num_node,
         accelerator="ddp",
         plugins=DDPPlugin(find_unused_parameters=False),  # make sure there is no unused params
         limit_train_batches=1.0,  # Useful for fast experiment
```
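With `gpus` and `num_nodes` both set, DDP spawns one process per GPU on every node: the world size is their product, and each process gets a global rank derived from its node index and local (per-node) rank. A hedged sketch of that layout (illustrative arithmetic, not Lightning's internals):

```python
def ddp_layout(num_gpu: int, num_node: int):
    """Enumerate (global_rank, node_index, local_rank) for every DDP process.

    World size is num_gpu * num_node; global ranks are assigned
    node-major, i.e. node * num_gpu + local_rank.
    """
    world_size = num_gpu * num_node
    ranks = [
        (node * num_gpu + local, node, local)
        for node in range(num_node)
        for local in range(num_gpu)
    ]
    return world_size, ranks


# 2 GPUs per node on 2 nodes -> 4 processes.
ws, ranks = ddp_layout(2, 2)
print(ws)     # -> 4
print(ranks)  # -> [(0, 0, 0), (1, 0, 1), (2, 1, 0), (3, 1, 1)]
```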