OpenDAS / Fairseq

Commit c6d6256b, authored Nov 11, 2017 by Myle Ott

Add `--log-format` option and JSON logger

Parent: 50fdf591

Showing 5 changed files with 204 additions and 86 deletions (+204 -86):
fairseq/options.py       +2   -0
fairseq/progress_bar.py  +149 -42
fairseq/utils.py         +13  -1
generate.py              +4   -5
train.py                 +36  -38
fairseq/options.py
@@ -18,6 +18,8 @@ def get_parser(desc):
     parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
     parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
                         help='log progress every N updates (when progress bar is disabled)')
+    parser.add_argument('--log-format', default='tqdm', help='log format to use',
+                        choices=['json', 'none', 'simple', 'tqdm'])
     parser.add_argument('--seed', default=1, type=int, metavar='N',
                         help='pseudo random number generator seed')
     return parser
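For reference, the new flag can be exercised directly through the parser; a minimal sketch, assuming options.get_parser works standalone (the description string and argument values are illustrative, not part of the commit):

    from fairseq import options

    # parse only the new logging options; all other arguments left at defaults
    parser = options.get_parser('Example')
    args = parser.parse_args(['--log-format', 'json', '--log-interval', '100'])
    assert args.log_format == 'json' and args.log_interval == 100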
fairseq/progress_bar.py
@@ -7,35 +7,29 @@
 #
 """
-Progress bar wrapper around tqdm which handles non-TTY outputs.
+Wrapper around various loggers and progress bars (e.g., tqdm).
 """

 from collections import OrderedDict
+import json
 from numbers import Number
 import sys

 from tqdm import tqdm

 from fairseq.meters import AverageMeter


-class progress_bar(tqdm):
-    enabled = sys.stderr.isatty()
-    print_interval = 1000
-
-    def __new__(cls, *args, **kwargs):
-        if cls.enabled:
-            return tqdm(*args, **kwargs)
-        else:
-            return simple_progress_bar(cls.print_interval, *args, **kwargs)
-
-
-class simple_progress_bar(object):
-    """A minimal replacement for tqdm in non-TTY environments."""
-
-    def __init__(self, print_interval, iterable, desc=None, *_args, **_kwargs):
-        super().__init__()
-        self.print_interval = print_interval
+class progress_bar(object):
+    """Abstract class for progress bars."""
+    def __init__(self, iterable, epoch=None, prefix=None):
         self.iterable = iterable
-        self.desc = desc
+        self.epoch = epoch
+        self.prefix = ''
+        if epoch is not None:
+            self.prefix += f'| epoch {epoch:03d}'
+        if prefix is not None:
+            self.prefix += f' | {prefix}'

     def __enter__(self):
         return self
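A minimal sketch of how the new base class composes its logging prefix, assuming the class above is in scope (standalone illustration, not part of the diff):

    # with both epoch and prefix set, the two pieces concatenate as in __init__ above
    bar = progress_bar(range(100), epoch=3, prefix='valid')
    assert bar.prefix == '| epoch 003 | valid'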
@@ -44,36 +38,149 @@ class simple_progress_bar(object):
         return False

     def __iter__(self):
-        size = len(self.iterable)
-        for i, obj in enumerate(self.iterable):
-            yield obj
-            if i > 0 and i % self.print_interval == 0:
-                desc = '' if self.desc is None else '{}: '.format(self.desc)
-                msg = '{}{:5d} / {:d} {}\n'.format(desc, i, size, self.postfix)
-                sys.stdout.write(msg)
-                sys.stdout.flush()
+        raise NotImplementedError

-    def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
-        # Sort in alphabetical order to be more deterministic
-        postfix = OrderedDict([] if ordered_dict is None else ordered_dict)
-        for key in sorted(kwargs.keys()):
-            postfix[key] = kwargs[key]
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        raise NotImplementedError
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        raise NotImplementedError
+
+    def _str_commas(self, stats):
+        return ', '.join(key + '=' + stats[key].strip()
+                         for key in stats.keys())
+
+    def _str_pipes(self, stats):
+        return ' | '.join(key + ' ' + stats[key].strip()
+                          for key in stats.keys())
+
+    def _format_stats(self, stats):
+        postfix = OrderedDict(stats)
         # Preprocess stats according to datatype
         for key in postfix.keys():
             # Number: limit the length of the string
             if isinstance(postfix[key], Number):
-                postfix[key] = '{0:2.3g}'.format(postfix[key])
+                postfix[key] = '{:g}'.format(postfix[key])
             # Meter: display both current and average value
             elif isinstance(postfix[key], AverageMeter):
                 postfix[key] = '{:.2f} ({:.2f})'.format(
                     postfix[key].val, postfix[key].avg)
             # Else for any other type, try to get the string conversion
             elif not isinstance(postfix[key], str):
                 postfix[key] = str(postfix[key])
             # Else if it's a string, don't need to preprocess anything
-        # Stitch together to get the final postfix
-        self.postfix = ', '.join(key + '=' + postfix[key].strip()
-                                 for key in postfix.keys())
-
-    @classmethod
-    def write(cls, s, file=None, end="\n"):
-        fp = file if file is not None else sys.stdout
-        fp.write(s)
-        fp.write(end)
-        fp.flush()
+        return postfix
+
+
+class json_progress_bar(progress_bar):
+    """Log output in JSON format."""
+
+    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
+        super().__init__(iterable, epoch, prefix)
+        self.log_interval = log_interval
+        self.stats = None
+
+    def __iter__(self):
+        size = float(len(self.iterable))
+        for i, obj in enumerate(self.iterable):
+            yield obj
+            if self.stats is not None and i > 0 and \
+                    self.log_interval is not None and i % self.log_interval == 0:
+                update = self.epoch + float(i / size) if self.epoch is not None else None
+                stats = self._format_stats(self.stats, epoch=self.epoch, update=update)
+                print("sweep_log: " + json.dumps(stats))
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        self.stats = stats
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        stats = self._format_stats(self.stats, epoch=self.epoch)
+        print("sweep_log: " + json.dumps(stats))
+
+    def _format_stats(self, stats, epoch=None, update=None):
+        postfix = OrderedDict()
+        if epoch is not None:
+            postfix['epoch'] = epoch
+        if update is not None:
+            postfix['update'] = update
+        # Preprocess stats according to datatype
+        for key in stats.keys():
+            # Meter: display both current and average value
+            if isinstance(stats[key], AverageMeter):
+                postfix[key] = stats[key].val
+                postfix[key + '_avg'] = stats[key].avg
+            else:
+                postfix[key] = stats[key]
+        return postfix
+
+
+class noop_progress_bar(progress_bar):
+    """No logging."""
+
+    def __init__(self, iterable, epoch=None, prefix=None):
+        super().__init__(iterable, epoch, prefix)
+
+    def __iter__(self):
+        for obj in self.iterable:
+            yield obj
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        pass
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        pass
+
+
+class simple_progress_bar(progress_bar):
+    """A minimal logger for non-TTY environments."""
+
+    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
+        super().__init__(iterable, epoch, prefix)
+        self.log_interval = log_interval
+        self.stats = None
+
+    def __iter__(self):
+        size = len(self.iterable)
+        for i, obj in enumerate(self.iterable):
+            yield obj
+            if self.stats is not None and i > 0 and \
+                    self.log_interval is not None and i % self.log_interval == 0:
+                postfix = self._str_commas(self.stats)
+                print(f'{self.prefix}:  {i:5d} / {size:d} {postfix}')
+                sys.stdout.flush()
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        self.stats = self._format_stats(stats)
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        postfix = self._str_pipes(self._format_stats(stats))
+        print(f'{self.prefix} | {postfix}')
+        sys.stdout.flush()
+
+
+class tqdm_progress_bar(progress_bar):
+    """Log to tqdm."""
+
+    def __init__(self, iterable, epoch=None, prefix=None):
+        super().__init__(iterable, epoch, prefix)
+        self.tqdm = tqdm(iterable, self.prefix, leave=False)
+
+    def __iter__(self):
+        return iter(self.tqdm)
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        self.tqdm.set_postfix(self._format_stats(stats), refresh=False)
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        postfix = self._str_pipes(self._format_stats(stats))
+        self.tqdm.write(f'{self.tqdm.desc} | {postfix}')
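A sketch of the JSON logger in isolation, assuming fairseq.progress_bar is importable (the stats values are made up; output shown roughly):

    from fairseq import progress_bar

    bar = progress_bar.json_progress_bar(range(2000), epoch=1, log_interval=1000)
    for _ in bar:
        bar.log({'loss': 2.5})
    # at i == 1000 the iterator prints roughly:
    #   sweep_log: {"epoch": 1, "update": 1.5, "loss": 2.5}
    bar.print({'loss': 2.5})  # end-of-epoch line, built from the last logged stats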
fairseq/utils.py

@@ -14,7 +14,7 @@ import traceback
 from torch.autograd import Variable
 from torch.serialization import default_restore_location

-from fairseq import criterions, data, models, tokenizer
+from fairseq import criterions, data, models, progress_bar, tokenizer


 def parse_args_and_arch(parser):

@@ -36,6 +36,18 @@ def build_criterion(args, src_dict, dst_dict):
     return criterions.CrossEntropyCriterion(args, dst_dict)


+def build_progress_bar(args, iterator, epoch=None, prefix=None):
+    if args.log_format == 'json':
+        bar = progress_bar.json_progress_bar(iterator, epoch, prefix, args.log_interval)
+    elif args.log_format == 'none':
+        bar = progress_bar.noop_progress_bar(iterator, epoch, prefix)
+    elif args.log_format == 'tqdm':
+        bar = progress_bar.tqdm_progress_bar(iterator, epoch, prefix)
+    else:
+        bar = progress_bar.simple_progress_bar(iterator, epoch, prefix, args.log_interval)
+    return bar
+
+
 def torch_persistent_save(*args, **kwargs):
     for i in range(3):
         try:
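The factory can also be driven without the full CLI; a sketch under the assumption that only the two new options matter here (the Namespace fields mirror them):

    import argparse
    from fairseq import utils

    # a stand-in for parsed command-line args
    args = argparse.Namespace(log_format='simple', log_interval=500)
    bar = utils.build_progress_bar(args, range(10000), epoch=2, prefix='train')
    # bar is a simple_progress_bar whose prefix reads '| epoch 002 | train'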
generate.py

@@ -11,7 +11,6 @@ import torch
 from fairseq import bleu, data, options, tokenizer, utils
 from fairseq.meters import StopwatchMeter, TimeMeter
-from fairseq.progress_bar import progress_bar
 from fairseq.sequence_generator import SequenceGenerator

@@ -27,10 +26,10 @@ def main():
     options.add_generation_args(parser)

     args = parser.parse_args()
+    if args.no_progress_bar:
+        args.log_format = 'none'
     print(args)

-    if args.no_progress_bar:
-        progress_bar.enabled = False
     use_cuda = torch.cuda.is_available() and not args.cpu

     # Load dataset

@@ -74,7 +73,7 @@ def main():
         args.gen_subset, max_sentences=args.batch_size, max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     num_sentences = 0
-    with progress_bar(itr, smoothing=0, leave=False) as t:
+    with utils.build_progress_bar(args, itr) as t:
         wps_meter = TimeMeter()
         gen_timer = StopwatchMeter()
         translations = translator.generate_batched_itr(

@@ -119,7 +118,7 @@ def main():
                 scorer.add(target_tokens, hypo_tokens)

             wps_meter.update(src_tokens.size(0))
-            t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)), refresh=False)
+            t.log({'wps': round(wps_meter.avg)})
             num_sentences += 1

     print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
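The replacement call works uniformly for whichever bar --log-format selects; a standalone sketch, assuming fairseq is importable (meter values are illustrative):

    import argparse
    from fairseq import utils
    from fairseq.meters import TimeMeter

    args = argparse.Namespace(log_format='none', log_interval=1000)
    wps_meter = TimeMeter()
    with utils.build_progress_bar(args, range(8)) as t:
        for batch in t:
            wps_meter.update(32)
            # a no-op under 'none'; rendered as a postfix or JSON under other formats
            t.log({'wps': round(wps_meter.avg)})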
train.py

@@ -15,7 +15,6 @@ import math
 from fairseq import data, options, utils
 from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
 from fairseq.multiprocessing_trainer import MultiprocessingTrainer
-from fairseq.progress_bar import progress_bar


 def main():

@@ -38,8 +37,7 @@ def main():
     args = utils.parse_args_and_arch(parser)
     if args.no_progress_bar:
-        progress_bar.enabled = False
-        progress_bar.print_interval = args.log_interval
+        args.log_format = 'simple'

     if not os.path.exists(args.save_dir):
         os.makedirs(args.save_dir)

@@ -124,7 +122,7 @@ def main():
 def get_perplexity(loss):
     try:
-        return math.pow(2, loss)
+        return round(math.pow(2, loss), 2)
     except OverflowError:
         return float('inf')

@@ -149,9 +147,8 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
     clip_meter = AverageMeter()  # % of updates clipped
     extra_meters = collections.defaultdict(lambda: AverageMeter())

-    desc = '| epoch {:03d}'.format(epoch)
     lr = trainer.get_lr()
-    with progress_bar(itr, desc, leave=False) as t:
+    with utils.build_progress_bar(args, itr, epoch) as t:
         for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
             loss_dict = trainer.train_step(sample)
             loss = loss_dict['loss']

@@ -168,16 +165,16 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
             extra_postfix = []
             for k, v in loss_dict.items():
                 extra_meters[k].update(v)
-                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
+                extra_postfix.append((k, extra_meters[k].avg))

-            t.set_postfix(collections.OrderedDict([
-                ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
-                ('wps', '{:5d}'.format(round(wps_meter.avg))),
-                ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
-                ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
-                ('lr', lr),
-                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
-            ] + extra_postfix), refresh=False)
+            t.log(collections.OrderedDict([
+                ('loss', loss_meter),
+                ('wps', round(wps_meter.avg)),
+                ('wpb', round(wpb_meter.avg)),
+                ('bsz', round(bsz_meter.avg)),
+                ('lr', lr),
+                ('clip', '{:.0%}'.format(clip_meter.avg)),
+            ] + extra_postfix))

             if i == 0:
                 # ignore the first mini-batch in words-per-second calculation
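Passing the AverageMeter itself, rather than a pre-formatted string, lets each bar pick its own rendering: the base _format_stats shows "val (avg)", while json_progress_bar emits separate loss and loss_avg fields. A standalone sketch of that behavior (meter values are made up):

    from fairseq.meters import AverageMeter

    meter = AverageMeter()
    meter.update(2.35)
    meter.update(2.59)
    # tqdm/simple rendering: '{:.2f} ({:.2f})'.format(meter.val, meter.avg) -> '2.59 (2.47)'
    # json rendering: {'loss': meter.val, 'loss_avg': meter.avg}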
@@ -185,17 +182,19 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
             if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                 save_checkpoint(trainer, args, epoch, i + 1)

-    fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
-        loss_meter.avg, get_perplexity(loss_meter.avg))
-    fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format(
-        round(wps_meter.elapsed_time), round(wps_meter.avg), round(wpb_meter.avg))
-    fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format(
-        round(bsz_meter.avg), lr, clip_meter.avg * 100)
-    fmt += ''.join(' | {} {:.4f}'.format(k, meter.avg)
-                   for k, meter in extra_meters.items())
-    t.write(fmt)
+        t.print(collections.OrderedDict([
+            ('train loss', round(loss_meter.avg, 2)),
+            ('train ppl', get_perplexity(loss_meter.avg)),
+            ('s/checkpoint', round(wps_meter.elapsed_time)),
+            ('words/s', round(wps_meter.avg)),
+            ('words/batch', round(wpb_meter.avg)),
+            ('bsz', round(bsz_meter.avg)),
+            ('lr', lr),
+            ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
+        ] + [(k, meter.avg) for k, meter in extra_meters.items()]))


 def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):

@@ -232,8 +231,8 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     loss_meter = AverageMeter()
     extra_meters = collections.defaultdict(lambda: AverageMeter())

-    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
-    with progress_bar(itr, desc, leave=False) as t:
+    prefix = 'valid on \'{}\' subset'.format(subset)
+    with utils.build_progress_bar(args, itr, epoch, prefix) as t:
         for _, sample in data.skip_group_enumerator(t, ngpus):
             loss_dict = trainer.valid_step(sample)
             loss = loss_dict['loss']

@@ -245,23 +244,22 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
             extra_postfix = []
             for k, v in loss_dict.items():
                 extra_meters[k].update(v)
-                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
+                extra_postfix.append((k, extra_meters[k].avg))

-            t.set_postfix(collections.OrderedDict([
-                ('loss', '{:.2f}'.format(loss_meter.avg)),
-            ] + extra_postfix), refresh=False)
+            t.log(collections.OrderedDict([
+                ('valid loss', round(loss_meter.avg, 2)),
+            ] + extra_postfix))

-    val_loss = loss_meter.avg
-    fmt = desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
-        val_loss, get_perplexity(val_loss))
-    fmt += ''.join(' | {} {:.4f}'.format(k, meter.avg)
-                   for k, meter in extra_meters.items())
-    t.write(fmt)
+        t.print(collections.OrderedDict([
+            ('valid loss', round(loss_meter.avg, 2)),
+            ('valid ppl', get_perplexity(loss_meter.avg)),
+        ] + [(k, meter.avg) for k, meter in extra_meters.items()]))

     # update and return the learning rate
-    return val_loss
+    return loss_meter.avg


 if __name__ == '__main__':