OpenDAS / Fairseq · Commits

Commit 7aba6084, authored Oct 11, 2017 by Myle Ott, committed by GitHub on Oct 11, 2017

Update progress_bar to be more robust to changes in tqdm (#21)

parent 2ad58885
Showing 3 changed files with 37 additions and 14 deletions.
fairseq/progress_bar.py  +34 -11
generate.py  +1 -1
train.py  +2 -2
fairseq/progress_bar.py

@@ -7,9 +7,11 @@
 #
 """
-Progress bar wrapper around tqdm which handles non-tty outputs
+Progress bar wrapper around tqdm which handles non-TTY outputs.
 """
 
+from collections import OrderedDict
+from numbers import Number
 import sys
 
 from tqdm import tqdm
@@ -26,30 +28,51 @@ class progress_bar(tqdm):
         return simple_progress_bar(cls.print_interval, *args, **kwargs)
 
 
-class simple_progress_bar(tqdm):
+class simple_progress_bar(object):
+    """A minimal replacement for tqdm in non-TTY environments."""
 
-    def __init__(self, print_interval, *args, **kwargs):
-        super(simple_progress_bar, self).__init__(*args, **kwargs)
+    def __init__(self, print_interval, iterable, desc, *_args, **_kwargs):
+        super().__init__()
         self.print_interval = print_interval
+        self.iterable = iterable
+        self.desc = desc
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *exc):
+        return False
 
     def __iter__(self):
         size = len(self.iterable)
         for i, obj in enumerate(self.iterable):
             yield obj
             if i > 0 and i % self.print_interval == 0:
-                msg = '{} {:5d} / {:d} {}\n'.format(self.desc, i, size, self.postfix)
+                msg = '{}: {:5d} / {:d} {}\n'.format(self.desc, i, size, self.postfix)
                 sys.stdout.write(msg)
                 sys.stdout.flush()
 
+    def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
+        # Sort in alphabetical order to be more deterministic
+        postfix = OrderedDict([] if ordered_dict is None else ordered_dict)
+        for key in sorted(kwargs.keys()):
+            postfix[key] = kwargs[key]
+        # Preprocess stats according to datatype
+        for key in postfix.keys():
+            # Number: limit the length of the string
+            if isinstance(postfix[key], Number):
+                postfix[key] = '{0:2.3g}'.format(postfix[key])
+            # Else for any other type, try to get the string conversion
+            elif not isinstance(postfix[key], str):
+                postfix[key] = str(postfix[key])
+            # Else if it's a string, don't need to preprocess anything
+        # Stitch together to get the final postfix
+        self.postfix = ', '.join(key + '=' + postfix[key].strip() for key in postfix.keys())
+
     @classmethod
     def write(cls, s, file=None, end="\n"):
         fp = file if file is not None else sys.stdout
         fp.write(s)
         fp.write(end)
         fp.flush()
 
     @staticmethod
     def status_printer(file):
         def print_status(s):
             pass
         return print_status
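For context, here is a minimal usage sketch of the rewritten non-TTY fallback (not part of the commit). It assumes the fairseq package from this checkout is importable; the print_interval, the iterable, and the 'epoch 001' description are illustrative. It drives simple_progress_bar the same way train.py and generate.py do: wrap an iterable, set the postfix with refresh=False, and let __iter__ write a plain status line every print_interval items.

# Usage sketch only; the values below are illustrative, not from the commit.
from fairseq.progress_bar import simple_progress_bar

data = range(1000)  # any iterable with a len() works, since __iter__ calls len()
with simple_progress_bar(print_interval=250, iterable=data, desc='epoch 001') as t:
    # refresh is accepted for tqdm compatibility but is unused by this fallback
    t.set_postfix(loss='{:.2f}'.format(2.5), refresh=False)
    for item in t:
        pass  # writes something like "epoch 001:   250 / 1000 loss=2.50" every 250 items

Because the fallback no longer subclasses tqdm, it only has to keep the set_postfix and write surface compatible, which is what makes it robust to upstream changes in tqdm.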
generate.py

@@ -135,7 +135,7 @@ def main():
             display_hypotheses(id, src, None, ref, hypos[:min(len(hypos), args.nbest)])
 
             wps_meter.update(src.size(0))
-            t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)))
+            t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)), refresh=False)
             num_sentences += 1
 
     print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
...
train.py

@@ -148,7 +148,7 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
                 ('lr', lr),
                 ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
                 ('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
-            ]))
+            ]), refresh=False)
 
             if i == 0:
                 # ignore the first mini-batch in words-per-second calculation
...
@@ -182,7 +182,7 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
             ntokens = sum(s['ntokens'] for s in sample)
             loss = trainer.valid_step(sample, criterion)
             loss_meter.update(loss, ntokens)
-            t.set_postfix(loss='{:.2f}'.format(loss_meter.avg))
+            t.set_postfix(loss='{:.2f}'.format(loss_meter.avg), refresh=False)
 
         val_loss = loss_meter.avg
         t.write(desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'
...
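The generate.py and train.py changes are the caller-side half of the same fix: every set_postfix call now passes refresh=False, which real tqdm accepts (it records the new postfix without forcing an immediate redraw) and which the new simple_progress_bar.set_postfix accepts and ignores. A standalone sketch of the TTY path, with an illustrative loop and loss value:

# Sketch only; the loop and the loss value are illustrative, not from the commit.
from tqdm import tqdm

with tqdm(range(10000), desc='demo') as t:
    for i in t:
        # refresh=False stores the postfix; tqdm redraws on its next scheduled update
        t.set_postfix(loss='{:.2f}'.format(1.0 / (i + 1)), refresh=False)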