Fairseq (OpenDAS) commit c6d6256b
Authored Nov 11, 2017 by Myle Ott

    Add `--log-format` option and JSON logger

Parent: 50fdf591

Showing 5 changed files with 204 additions and 86 deletions (+204, -86):

    fairseq/options.py        +2    -0
    fairseq/progress_bar.py   +149  -42
    fairseq/utils.py          +13   -1
    generate.py               +4    -5
    train.py                  +36   -38
fairseq/options.py

@@ -18,6 +18,8 @@ def get_parser(desc):
     parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
     parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
                         help='log progress every N updates (when progress bar is disabled)')
+    parser.add_argument('--log-format', default='tqdm', help='log format to use',
+                        choices=['json', 'none', 'simple', 'tqdm'])
     parser.add_argument('--seed', default=1, type=int, metavar='N',
                         help='pseudo random number generator seed')
     return parser
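A quick sanity check of the new option, as a minimal sketch (the parser description and the argument values here are illustrative, not from the commit):

    from fairseq import options

    parser = options.get_parser('Trainer')
    args = parser.parse_args(['--log-format', 'json', '--log-interval', '500'])
    assert args.log_format == 'json' and args.log_interval == 500
    # Formats outside the choices list are rejected by argparse:
    # parser.parse_args(['--log-format', 'csv'])  # exits with a usage error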
fairseq/progress_bar.py

@@ -7,35 +7,29 @@
 #
 """
-Progress bar wrapper around tqdm which handles non-TTY outputs.
+Wrapper around various loggers and progress bars (e.g., tqdm).
 """

 from collections import OrderedDict
+import json
 from numbers import Number
 import sys

 from tqdm import tqdm

+from fairseq.meters import AverageMeter
+

-class progress_bar(tqdm):
-    enabled = sys.stderr.isatty()
-    print_interval = 1000
-
-    def __new__(cls, *args, **kwargs):
-        if cls.enabled:
-            return tqdm(*args, **kwargs)
-        else:
-            return simple_progress_bar(cls.print_interval, *args, **kwargs)
-
-
-class simple_progress_bar(object):
-    """A minimal replacement for tqdm in non-TTY environments."""
-
-    def __init__(self, print_interval, iterable, desc=None, *_args, **_kwargs):
-        super().__init__()
-        self.print_interval = print_interval
+class progress_bar(object):
+    """Abstract class for progress bars."""
+    def __init__(self, iterable, epoch=None, prefix=None):
         self.iterable = iterable
-        self.desc = desc
+        self.epoch = epoch
+        self.prefix = ''
+        if epoch is not None:
+            self.prefix += f'| epoch {epoch:03d}'
+        if prefix is not None:
+            self.prefix += f' | {prefix}'

     def __enter__(self):
         return self
@@ -44,36 +38,149 @@ class simple_progress_bar(object):
         return False

     def __iter__(self):
-        size = len(self.iterable)
-        for i, obj in enumerate(self.iterable):
-            yield obj
-            if i > 0 and i % self.print_interval == 0:
-                desc = '' if self.desc is None else '{}: '.format(self.desc)
-                msg = '{}{:5d} / {:d} {}\n'.format(desc, i, size, self.postfix)
-                sys.stdout.write(msg)
-                sys.stdout.flush()
+        raise NotImplementedError
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        raise NotImplementedError
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        raise NotImplementedError

-    def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
-        # Sort in alphabetical order to be more deterministic
-        postfix = OrderedDict([] if ordered_dict is None else ordered_dict)
-        for key in sorted(kwargs.keys()):
-            postfix[key] = kwargs[key]
+    def _str_commas(self, stats):
+        return ', '.join(key + '=' + stats[key].strip()
+                         for key in stats.keys())
+
+    def _str_pipes(self, stats):
+        return ' | '.join(key + ' ' + stats[key].strip()
+                          for key in stats.keys())
+
+    def _format_stats(self, stats):
+        postfix = OrderedDict(stats)
         # Preprocess stats according to datatype
         for key in postfix.keys():
             # Number: limit the length of the string
             if isinstance(postfix[key], Number):
-                postfix[key] = '{0:2.3g}'.format(postfix[key])
+                postfix[key] = '{:g}'.format(postfix[key])
+            # Meter: display both current and average value
+            elif isinstance(postfix[key], AverageMeter):
+                postfix[key] = '{:.2f} ({:.2f})'.format(postfix[key].val, postfix[key].avg)
             # Else for any other type, try to get the string conversion
             elif not isinstance(postfix[key], str):
                 postfix[key] = str(postfix[key])
             # Else if it's a string, don't need to preprocess anything
-        # Stitch together to get the final postfix
-        self.postfix = ', '.join(key + '=' + postfix[key].strip()
-                                 for key in postfix.keys())
+        return postfix

-    @classmethod
-    def write(cls, s, file=None, end="\n"):
-        fp = file if file is not None else sys.stdout
-        fp.write(s)
-        fp.write(end)
-        fp.flush()
+
+class json_progress_bar(progress_bar):
+    """Log output in JSON format."""
+
+    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
+        super().__init__(iterable, epoch, prefix)
+        self.log_interval = log_interval
+        self.stats = None
+
+    def __iter__(self):
+        size = float(len(self.iterable))
+        for i, obj in enumerate(self.iterable):
+            yield obj
+            if self.stats is not None and i > 0 and \
+                    self.log_interval is not None and i % self.log_interval == 0:
+                update = self.epoch + float(i / size) if self.epoch is not None else None
+                stats = self._format_stats(self.stats, epoch=self.epoch, update=update)
+                print("sweep_log: " + json.dumps(stats))
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        self.stats = stats
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        self.stats = stats
+        stats = self._format_stats(self.stats, epoch=self.epoch)
+        print("sweep_log: " + json.dumps(stats))
+
+    def _format_stats(self, stats, epoch=None, update=None):
+        postfix = OrderedDict()
+        if epoch is not None:
+            postfix['epoch'] = epoch
+        if update is not None:
+            postfix['update'] = update
+        # Preprocess stats according to datatype
+        for key in stats.keys():
+            # Meter: display both current and average value
+            if isinstance(stats[key], AverageMeter):
+                postfix[key] = stats[key].val
+                postfix[key + '_avg'] = stats[key].avg
+            else:
+                postfix[key] = stats[key]
+        return postfix
+
+
+class noop_progress_bar(progress_bar):
+    """No logging."""
+
+    def __init__(self, iterable, epoch=None, prefix=None):
+        super().__init__(iterable, epoch, prefix)
+
+    def __iter__(self):
+        for obj in self.iterable:
+            yield obj
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        pass
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        pass
+
+
+class simple_progress_bar(progress_bar):
+    """A minimal logger for non-TTY environments."""
+
+    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
+        super().__init__(iterable, epoch, prefix)
+        self.log_interval = log_interval
+        self.stats = None
+
+    def __iter__(self):
+        size = len(self.iterable)
+        for i, obj in enumerate(self.iterable):
+            yield obj
+            if self.stats is not None and i > 0 and \
+                    self.log_interval is not None and i % self.log_interval == 0:
+                postfix = self._str_commas(self.stats)
+                print(f'{self.prefix}:  {i:5d} / {size:d} {postfix}')
+                sys.stdout.flush()
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        self.stats = self._format_stats(stats)
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        postfix = self._str_pipes(self._format_stats(stats))
+        print(f'{self.prefix} | {postfix}')
+        sys.stdout.flush()
+
+
+class tqdm_progress_bar(progress_bar):
+    """Log to tqdm."""
+
+    def __init__(self, iterable, epoch=None, prefix=None):
+        super().__init__(iterable, epoch, prefix)
+        self.tqdm = tqdm(iterable, self.prefix, leave=False)
+
+    def __iter__(self):
+        return iter(self.tqdm)
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        self.tqdm.set_postfix(self._format_stats(stats), refresh=False)
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        postfix = self._str_pipes(self._format_stats(stats))
+        self.tqdm.write(f'{self.tqdm.desc} | {postfix}')
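To make the new interface concrete, here is a small hand-run sketch of the JSON logger (not part of the commit; the stat values are made up and the printed numbers are approximate):

    from fairseq.progress_bar import json_progress_bar

    bar = json_progress_bar(range(2001), epoch=1, log_interval=1000)
    for i in bar:
        bar.log({'loss': 2.5, 'wps': 31000})  # buffered; flushed every log_interval updates
    # printed during iteration, e.g. at i=1000:
    #   sweep_log: {"epoch": 1, "update": 1.49975..., "loss": 2.5, "wps": 31000}
    bar.print({'train loss': 2.31, 'train ppl': 4.96})
    #   sweep_log: {"epoch": 1, "train loss": 2.31, "train ppl": 4.96}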
fairseq/utils.py

@@ -14,7 +14,7 @@ import traceback
 from torch.autograd import Variable
 from torch.serialization import default_restore_location

-from fairseq import criterions, data, models, tokenizer
+from fairseq import criterions, data, models, progress_bar, tokenizer


 def parse_args_and_arch(parser):

@@ -36,6 +36,18 @@ def build_criterion(args, src_dict, dst_dict):
         return criterions.CrossEntropyCriterion(args, dst_dict)


+def build_progress_bar(args, iterator, epoch=None, prefix=None):
+    if args.log_format == 'json':
+        bar = progress_bar.json_progress_bar(iterator, epoch, prefix, args.log_interval)
+    elif args.log_format == 'none':
+        bar = progress_bar.noop_progress_bar(iterator, epoch, prefix)
+    elif args.log_format == 'tqdm':
+        bar = progress_bar.tqdm_progress_bar(iterator, epoch, prefix)
+    else:
+        bar = progress_bar.simple_progress_bar(iterator, epoch, prefix, args.log_interval)
+    return bar
+
+
 def torch_persistent_save(*args, **kwargs):
     for i in range(3):
         try:
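The factory keys off `args.log_format`, so any object with `log_format` and `log_interval` attributes will do; a sketch using a bare Namespace in place of parsed command-line args:

    from argparse import Namespace
    from fairseq import utils

    args = Namespace(log_format='json', log_interval=1000)
    bar = utils.build_progress_bar(args, range(100), epoch=2, prefix='valid')
    print(type(bar).__name__)  # -> json_progress_bar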
generate.py

@@ -11,7 +11,6 @@ import torch
 from fairseq import bleu, data, options, tokenizer, utils
 from fairseq.meters import StopwatchMeter, TimeMeter
-from fairseq.progress_bar import progress_bar
 from fairseq.sequence_generator import SequenceGenerator

@@ -27,10 +26,10 @@ def main():
     options.add_generation_args(parser)

     args = parser.parse_args()
+    if args.no_progress_bar:
+        args.log_format = 'none'
     print(args)

-    if args.no_progress_bar:
-        progress_bar.enabled = False
     use_cuda = torch.cuda.is_available() and not args.cpu

     # Load dataset

@@ -74,7 +73,7 @@ def main():
         args.gen_subset, max_sentences=args.batch_size, max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     num_sentences = 0
-    with progress_bar(itr, smoothing=0, leave=False) as t:
+    with utils.build_progress_bar(args, itr) as t:
         wps_meter = TimeMeter()
         gen_timer = StopwatchMeter()
         translations = translator.generate_batched_itr(

@@ -119,7 +118,7 @@ def main():
                 scorer.add(target_tokens, hypo_tokens)

             wps_meter.update(src_tokens.size(0))
-            t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)), refresh=False)
+            t.log({'wps': round(wps_meter.avg)})
             num_sentences += 1

     print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
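Where generation used to push a pre-formatted string into tqdm's postfix, it now hands raw values to whichever logger `--log-format` selected. A sketch of what each backend does with the same call (the wps value is illustrative):

    t.log({'wps': 31000})
    # tqdm_progress_bar   -> self.tqdm.set_postfix({'wps': '31000'}, refresh=False)
    # simple_progress_bar -> buffers the stats; prints '... wps=31000' every log_interval batches
    # json_progress_bar   -> buffers the stats; prints 'sweep_log: {...}' every log_interval batches
    # noop_progress_bar   -> discards the stats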
train.py

@@ -15,7 +15,6 @@ import math
 from fairseq import data, options, utils
 from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
 from fairseq.multiprocessing_trainer import MultiprocessingTrainer
-from fairseq.progress_bar import progress_bar


 def main():

@@ -38,8 +37,7 @@ def main():
     args = utils.parse_args_and_arch(parser)

     if args.no_progress_bar:
-        progress_bar.enabled = False
-        progress_bar.print_interval = args.log_interval
+        args.log_format = 'simple'

     if not os.path.exists(args.save_dir):
         os.makedirs(args.save_dir)

@@ -124,7 +122,7 @@ def main():
 def get_perplexity(loss):
     try:
-        return math.pow(2, loss)
+        return round(math.pow(2, loss), 2)
     except OverflowError:
         return float('inf')

@@ -149,9 +147,8 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
     clip_meter = AverageMeter()  # % of updates clipped
     extra_meters = collections.defaultdict(lambda: AverageMeter())

-    desc = '| epoch {:03d}'.format(epoch)
     lr = trainer.get_lr()
-    with progress_bar(itr, desc, leave=False) as t:
+    with utils.build_progress_bar(args, itr, epoch) as t:
         for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
             loss_dict = trainer.train_step(sample)
             loss = loss_dict['loss']

@@ -168,16 +165,16 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
             extra_postfix = []
             for k, v in loss_dict.items():
                 extra_meters[k].update(v)
-                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
-            t.set_postfix(collections.OrderedDict([
-                ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
-                ('wps', '{:5d}'.format(round(wps_meter.avg))),
-                ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
-                ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
+                extra_postfix.append((k, extra_meters[k].avg))
+            t.log(collections.OrderedDict([
+                ('loss', loss_meter),
+                ('wps', round(wps_meter.avg)),
+                ('wpb', round(wpb_meter.avg)),
+                ('bsz', round(bsz_meter.avg)),
                 ('lr', lr),
-                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
-            ] + extra_postfix), refresh=False)
+                ('clip', '{:.0%}'.format(clip_meter.avg)),
+            ] + extra_postfix))

             if i == 0:
                 # ignore the first mini-batch in words-per-second calculation

@@ -185,17 +182,19 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
         if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
             save_checkpoint(trainer, args, epoch, i + 1)

-        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
-            loss_meter.avg, get_perplexity(loss_meter.avg))
-        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format(
-            round(wps_meter.elapsed_time), round(wps_meter.avg), round(wpb_meter.avg))
-        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format(
-            round(bsz_meter.avg), lr, clip_meter.avg * 100)
-        fmt += ''.join(
-            ' | {} {:.4f}'.format(k, meter.avg)
-            for k, meter in extra_meters.items()
-        )
-        t.write(fmt)
+        t.print(collections.OrderedDict([
+            ('train loss', round(loss_meter.avg, 2)),
+            ('train ppl', get_perplexity(loss_meter.avg)),
+            ('s/checkpoint', round(wps_meter.elapsed_time)),
+            ('words/s', round(wps_meter.avg)),
+            ('words/batch', round(wpb_meter.avg)),
+            ('bsz', round(bsz_meter.avg)),
+            ('lr', lr),
+            ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
+        ] + [
+            (k, meter.avg)
+            for k, meter in extra_meters.items()
+        ]))


 def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):

@@ -232,8 +231,8 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     loss_meter = AverageMeter()
     extra_meters = collections.defaultdict(lambda: AverageMeter())

-    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
-    with progress_bar(itr, desc, leave=False) as t:
+    prefix = 'valid on \'{}\' subset'.format(subset)
+    with utils.build_progress_bar(args, itr, epoch, prefix) as t:
         for _, sample in data.skip_group_enumerator(t, ngpus):
             loss_dict = trainer.valid_step(sample)
             loss = loss_dict['loss']

@@ -245,23 +244,22 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
             extra_postfix = []
             for k, v in loss_dict.items():
                 extra_meters[k].update(v)
-                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
-            t.set_postfix(collections.OrderedDict([
-                ('loss', '{:.2f}'.format(loss_meter.avg)),
-            ] + extra_postfix), refresh=False)
+                extra_postfix.append((k, extra_meters[k].avg))
+            t.log(collections.OrderedDict([
+                ('valid loss', round(loss_meter.avg, 2)),
+            ] + extra_postfix))

-        val_loss = loss_meter.avg
-        fmt = desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
-            val_loss, get_perplexity(val_loss))
-        fmt += ''.join(
-            ' | {} {:.4f}'.format(k, meter.avg)
-            for k, meter in extra_meters.items()
-        )
-        t.write(fmt)
+        t.print(collections.OrderedDict([
+            ('valid loss', round(loss_meter.avg, 2)),
+            ('valid ppl', get_perplexity(loss_meter.avg)),
+        ] + [
+            (k, meter.avg)
+            for k, meter in extra_meters.items()
+        ]))

     # update and return the learning rate
-    return val_loss
+    return loss_meter.avg


 if __name__ == '__main__':
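Formatting is now owned by the logger rather than by train.py, so the same `t.print` call renders differently per backend. With the simple logger, for example, a call like the one above would print roughly (values illustrative):

    t.print(collections.OrderedDict([('train loss', 2.31), ('train ppl', 4.96), ('lr', 0.25)]))
    # -> | epoch 001 | train loss 2.31 | train ppl 4.96 | lr 0.25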