OpenDAS / Fairseq · Commits · 99493a85

Commit 99493a85 authored Dec 02, 2017 by Myle Ott
Parent: bd46c5ec

Save number of GPUs in args (and checkpoints)

Showing 1 changed file with 11 additions and 11 deletions.

train.py  (+11, -11)
@@ -53,18 +53,18 @@ def main():
     # record inferred languages in args, so that it's saved in checkpoints
     args.source_lang, args.target_lang = dataset.src, dataset.dst
 
+    if not torch.cuda.is_available():
+        raise NotImplementedError('Training on CPU is not supported')
+    args.num_gpus = torch.cuda.device_count()
+
     print(args)
     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
     for split in splits:
         print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))
 
-    if not torch.cuda.is_available():
-        raise NotImplementedError('Training on CPU is not supported')
-    num_gpus = torch.cuda.device_count()
-
     print('| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'.format(
-        num_gpus, args.max_tokens, args.max_sentences))
+        args.num_gpus, args.max_tokens, args.max_sentences))
 
     # Build model and criterion
     model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
@@ -102,11 +102,11 @@ def main():
     train_meter.start()
     while lr > args.min_lr and epoch <= max_epoch:
         # train for one epoch
-        train(args, epoch, batch_offset, trainer, dataset, max_positions_train, num_gpus)
+        train(args, epoch, batch_offset, trainer, dataset, max_positions_train)
 
         # evaluate on validate set
         for k, subset in enumerate(args.valid_subset.split(',')):
-            val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset, num_gpus)
+            val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset)
             if k == 0:
                 if not args.no_save:
                     # save checkpoint
@@ -130,7 +130,7 @@ def get_perplexity(loss):
         return float('inf')
 
 
-def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
+def train(args, epoch, batch_offset, trainer, dataset, max_positions):
     """Train the model for one epoch."""
     seed = args.seed + epoch
@@ -152,7 +152,7 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
     lr = trainer.get_lr()
     with utils.build_progress_bar(args, itr, epoch) as t:
-        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
+        for i, sample in data.skip_group_enumerator(t, args.num_gpus, batch_offset):
             loss_dict = trainer.train_step(sample)
             loss = loss_dict['loss']
             del loss_dict['loss']  # don't include in extra_meters or extra_postfix
@@ -222,7 +222,7 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
     trainer.save_checkpoint(last_filename, extra_state)
 
 
-def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
+def validate(args, epoch, trainer, dataset, max_positions, subset):
     """Evaluate the model on the validation set and return the average loss."""
     itr = dataset.eval_dataloader(
@@ -236,7 +236,7 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     prefix = 'valid on \'{}\' subset'.format(subset)
     with utils.build_progress_bar(args, itr, epoch, prefix) as t:
-        for _, sample in data.skip_group_enumerator(t, ngpus):
+        for _, sample in data.skip_group_enumerator(t, args.num_gpus):
             loss_dict = trainer.valid_step(sample)
             loss = loss_dict['loss']
             del loss_dict['loss']  # don't include in extra_meters or extra_postfix
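Taken together, the diff replaces the local num_gpus variable with args.num_gpus, so the detected GPU count is carried on the argument namespace and ends up wherever args is serialized, including checkpoints, per the commit message. Below is a minimal sketch of that pattern using a generic argparse/torch.save setup rather than fairseq's actual checkpoint format; the filename, dictionary keys, and --max-tokens flag are illustrative assumptions.

# Sketch only: store a runtime fact (the GPU count) on `args` so that saving
# `args` alongside the model preserves it; not fairseq's real checkpoint layout.
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--max-tokens', type=int, default=6000)  # illustrative flag
args = parser.parse_args([])

# Same idea as the commit: record the GPU count on args instead of in a local.
args.num_gpus = torch.cuda.device_count()

# Serialize args together with (here, empty) model state.
torch.save({'args': vars(args), 'model': {}}, 'checkpoint_last.pt')

# The GPU count used for training is now recoverable from the checkpoint.
restored = torch.load('checkpoint_last.pt')
print(restored['args']['num_gpus'])

With the count stored on args, helpers such as data.skip_group_enumerator can read it from args instead of having a separate num_gpus/ngpus parameter threaded through train() and validate(), which is exactly the simplification the diff makes.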