OpenDAS / Fairseq, commit e734b0fa
Authored Sep 14, 2017 by Sergey Edunov

Initial commit

Changes: 46; this page shows 6 changed files with 561 additions and 0 deletions (+561, -0).
scripts/build_sym_alignment.py    +99   -0
scripts/convert_dictionary.lua    +36   -0
scripts/convert_model.lua         +110  -0
setup.py                          +71   -0
tests/test_label_smoothing.py     +35   -0
train.py                          +210  -0
scripts/build_sym_alignment.py  0 → 100644
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
"""
Use this script in order to build symmetric alignments for your translation
dataset.
This script depends on fast_align and mosesdecoder tools. You will need to
build those before running the script.

fast_align:
    github: http://github.com/clab/fast_align
    instructions: follow the instructions in README.md
mosesdecoder:
    github: http://github.com/moses-smt/mosesdecoder
    instructions: http://www.statmt.org/moses/?n=Development.GetStarted

The script produces the following files under --output_dir:
    text.joined - concatenation of lines from the source_file and the
        target_file.
    align.forward - forward pass of fast_align.
    align.backward - backward pass of fast_align.
    aligned.sym_heuristic - symmetrized alignment.
"""

import argparse
import os
from itertools import zip_longest


def main():
    parser = argparse.ArgumentParser(description='symmetric alignment builder')
    parser.add_argument('--fast_align_dir',
                        help='path to fast_align build directory')
    parser.add_argument('--mosesdecoder_dir',
                        help='path to mosesdecoder root directory')
    parser.add_argument('--sym_heuristic',
                        help='heuristic to use for symmetrization',
                        default='grow-diag-final-and')
    parser.add_argument('--source_file',
                        help='path to a file with sentences '
                             'in the source language')
    parser.add_argument('--target_file',
                        help='path to a file with sentences '
                             'in the target language')
    parser.add_argument('--output_dir',
                        help='output directory')
    args = parser.parse_args()

    fast_align_bin = os.path.join(args.fast_align_dir, 'fast_align')
    symal_bin = os.path.join(args.mosesdecoder_dir, 'bin', 'symal')
    sym_fast_align_bin = os.path.join(
        args.mosesdecoder_dir, 'scripts', 'ems',
        'support', 'symmetrize-fast-align.perl')

    # create joined file
    joined_file = os.path.join(args.output_dir, 'text.joined')
    with open(args.source_file, 'r') as src, open(args.target_file, 'r') as tgt:
        with open(joined_file, 'w') as joined:
            for s, t in zip_longest(src, tgt):
                print('{} ||| {}'.format(s.strip(), t.strip()), file=joined)

    bwd_align_file = os.path.join(args.output_dir, 'align.backward')

    # run forward alignment
    fwd_align_file = os.path.join(args.output_dir, 'align.forward')
    fwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v > {FWD}'.format(
        FASTALIGN=fast_align_bin,
        JOINED=joined_file,
        FWD=fwd_align_file)
    assert os.system(fwd_fast_align_cmd) == 0

    # run backward alignment
    bwd_align_file = os.path.join(args.output_dir, 'align.backward')
    bwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}'.format(
        FASTALIGN=fast_align_bin,
        JOINED=joined_file,
        BWD=bwd_align_file)
    assert os.system(bwd_fast_align_cmd) == 0

    # run symmetrization
    sym_out_file = os.path.join(args.output_dir, 'aligned')
    sym_cmd = '{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}'.format(
        SYMFASTALIGN=sym_fast_align_bin,
        FWD=fwd_align_file,
        BWD=bwd_align_file,
        SRC=args.source_file,
        TGT=args.target_file,
        OUT=sym_out_file,
        HEURISTIC=args.sym_heuristic,
        SYMAL=symal_bin)
    assert os.system(sym_cmd) == 0


if __name__ == '__main__':
    main()
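For reference, a minimal sketch of how this script might be driven from Python. The flag names come from the argparse definitions above; every path below is a placeholder and has to point at your own data and at local fast_align / mosesdecoder builds.

import subprocess

# Hypothetical invocation of scripts/build_sym_alignment.py; all paths are
# placeholders, and fast_align / mosesdecoder must already be built.
subprocess.run([
    'python', 'scripts/build_sym_alignment.py',
    '--fast_align_dir', '/path/to/fast_align/build',
    '--mosesdecoder_dir', '/path/to/mosesdecoder',
    '--source_file', 'corpus.src',
    '--target_file', 'corpus.tgt',
    '--output_dir', 'alignments',
    '--sym_heuristic', 'grow-diag-final-and',
], check=True)
# Per the docstring, the output directory should then contain text.joined,
# align.forward, align.backward and the symmetrized alignment file.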
scripts/convert_dictionary.lua  0 → 100644
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
-- Usage: convert_dictionary.lua <dict.th7>

require 'fairseq'
require 'torch'
require 'paths'

if #arg < 1 then
    print('usage: convert_dictionary.lua <dict.th7>')
    os.exit(1)
end
if not paths.filep(arg[1]) then
    print('error: file does not exist: ' .. arg[1])
    os.exit(1)
end

dict = torch.load(arg[1])
dst = paths.basename(arg[1]):gsub('.th7', '.txt')
assert(dst:match('.txt$'))

f = io.open(dst, 'w')
for idx, symbol in ipairs(dict.index_to_symbol) do
    if idx > dict.cutoff then
        break
    end
    f:write(symbol)
    f:write(' ')
    f:write(dict.index_to_freq[idx])
    f:write('\n')
end
f:close()
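As a small illustration of what the conversion produces, here is a hedged Python sketch that reads the resulting text file back. It assumes only the format written by the loop above, one "symbol frequency" pair per line; the file name dict.txt is a placeholder.

# Sketch: parse a dictionary file written by convert_dictionary.lua.
# Assumes one "symbol frequency" pair per line; 'dict.txt' is a placeholder.
def read_converted_dictionary(path):
    symbols, freqs = [], []
    with open(path, encoding='utf-8') as f:
        for line in f:
            symbol, freq = line.rstrip('\n').rsplit(' ', 1)
            symbols.append(symbol)
            freqs.append(int(freq))
    return symbols, freqs

symbols, freqs = read_converted_dictionary('dict.txt')
print('{} symbols, first entry: {} ({})'.format(len(symbols), symbols[0], freqs[0]))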
scripts/convert_model.lua  0 → 100644
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
-- Usage: convert_model.lua <model_epoch1.th7>

require 'torch'
local fairseq = require 'fairseq'

model = torch.load(arg[1])

function find_weight_norm(container, module)
    for _, wn in ipairs(container:listModules()) do
        if torch.type(wn) == 'nn.WeightNorm' and wn.modules[1] == module then
            return wn
        end
    end
end

function push_state(dict, key, module)
    if torch.type(module) == 'nn.Linear' then
        local wn = find_weight_norm(model.module, module)
        assert(wn)
        dict[key .. '.weight_v'] = wn.v:float()
        dict[key .. '.weight_g'] = wn.g:float()
    elseif torch.type(module) == 'nn.TemporalConvolutionTBC' then
        local wn = find_weight_norm(model.module, module)
        assert(wn)
        local v = wn.v:float():view(wn.viewOut):transpose(2, 3)
        dict[key .. '.weight_v'] = v
        dict[key .. '.weight_g'] = wn.g:float():view(module.weight:size(3), 1, 1)
    else
        dict[key .. '.weight'] = module.weight:float()
    end
    if module.bias then
        dict[key .. '.bias'] = module.bias:float()
    end
end

encoder_dict = {}
decoder_dict = {}
combined_dict = {}

function encoder_state(encoder)
    luts = encoder:findModules('nn.LookupTable')
    push_state(encoder_dict, 'embed_tokens', luts[1])
    push_state(encoder_dict, 'embed_positions', luts[2])

    fcs = encoder:findModules('nn.Linear')
    assert(#fcs >= 2)
    local nInputPlane = fcs[1].weight:size(1)
    push_state(encoder_dict, 'fc1', table.remove(fcs, 1))
    push_state(encoder_dict, 'fc2', table.remove(fcs, #fcs))

    for i, module in ipairs(encoder:findModules('nn.TemporalConvolutionTBC')) do
        push_state(encoder_dict, 'convolutions.' .. tostring(i - 1), module)
        if nInputPlane ~= module.weight:size(3) / 2 then
            push_state(encoder_dict, 'projections.' .. tostring(i - 1),
                table.remove(fcs, 1))
        end
        nInputPlane = module.weight:size(3) / 2
    end
    assert(#fcs == 0)
end

function decoder_state(decoder)
    luts = decoder:findModules('nn.LookupTable')
    push_state(decoder_dict, 'embed_tokens', luts[1])
    push_state(decoder_dict, 'embed_positions', luts[2])

    fcs = decoder:findModules('nn.Linear')
    local nInputPlane = fcs[1].weight:size(1)
    push_state(decoder_dict, 'fc1', table.remove(fcs, 1))
    push_state(decoder_dict, 'fc2', fcs[#fcs - 1])
    push_state(decoder_dict, 'fc3', fcs[#fcs])
    table.remove(fcs, #fcs)
    table.remove(fcs, #fcs)

    for i, module in ipairs(decoder:findModules('nn.TemporalConvolutionTBC')) do
        if nInputPlane ~= module.weight:size(3) / 2 then
            push_state(decoder_dict, 'projections.' .. tostring(i - 1),
                table.remove(fcs, 1))
        end
        nInputPlane = module.weight:size(3) / 2

        local prefix = 'attention.' .. tostring(i - 1)
        push_state(decoder_dict, prefix .. '.in_projection', table.remove(fcs, 1))
        push_state(decoder_dict, prefix .. '.out_projection', table.remove(fcs, 1))
        push_state(decoder_dict, 'convolutions.' .. tostring(i - 1), module)
    end
    assert(#fcs == 0)
end

_encoder = model.module.modules[2]
_decoder = model.module.modules[3]

encoder_state(_encoder)
decoder_state(_decoder)

for k, v in pairs(encoder_dict) do
    combined_dict['encoder.' .. k] = v
end
for k, v in pairs(decoder_dict) do
    combined_dict['decoder.' .. k] = v
end

torch.save('state_dict.t7', combined_dict)
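A hedged sketch of how the resulting state_dict.t7 might be inspected from the PyTorch side. It relies on the legacy Lua reader torch.utils.serialization.load_lua that shipped with early PyTorch releases (since removed), and it assumes the saved table deserializes to a mapping from parameter names (encoder.*, decoder.*) to float tensors.

# Sketch: list the parameters stored in state_dict.t7 by convert_model.lua.
# Assumes torch.utils.serialization.load_lua is available (old PyTorch) and
# that the loaded object behaves like a dict of name -> tensor.
from torch.utils.serialization import load_lua

state = load_lua('state_dict.t7')
for name in sorted(state.keys()):
    print(name, tuple(state[name].size()))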
setup.py  0 → 100644
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

from setuptools import setup, find_packages, Extension
from setuptools.command.build_py import build_py
import sys
from torch.utils.ffi import create_extension


if sys.version_info < (3,):
    sys.exit('Sorry, Python3 is required for fairseq.')

with open('README.md') as f:
    readme = f.read()

with open('LICENSE') as f:
    license = f.read()

with open('requirements.txt') as f:
    reqs = f.read()

bleu = Extension(
    'fairseq.libbleu',
    sources=[
        'fairseq/clib/libbleu/libbleu.cpp',
        'fairseq/clib/libbleu/module.cpp',
    ],
    extra_compile_args=['-std=c++11'],
)

conv_tbc = create_extension(
    'fairseq.temporal_convolution_tbc',
    relative_to='fairseq',
    headers=['fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.h'],
    sources=['fairseq/clib/temporal_convolution_tbc/temporal_convolution_tbc.cpp'],
    define_macros=[('WITH_CUDA', None)],
    with_cuda=True,
    extra_compile_args=['-std=c++11'],
)


class build_py_hook(build_py):
    def run(self):
        conv_tbc.build()
        build_py.run(self)


setup(
    name='fairseq',
    version='0.1.0',
    description='Facebook AI Research Sequence-to-Sequence Toolkit',
    long_description=readme,
    license=license,
    install_requires=reqs.strip().split('\n'),
    packages=find_packages(),
    ext_modules=[bleu],

    # build and install PyTorch extensions
    package_data={
        'fairseq': ['temporal_convolution_tbc/*.so'],
    },
    include_package_data=True,
    cmdclass={
        'build_py': build_py_hook,
    },
)
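After installation (for example with python setup.py install), one would expect the two extensions declared above to be importable. A hedged sanity-check sketch, using only the module names from the Extension and create_extension calls in this file:

# Sketch: check that the compiled extensions declared in setup.py resolve.
# Module names come from setup.py above; whether they resolve depends on a
# successful build, which this sketch does not verify beyond import lookup.
import importlib.util

for module_name in ('fairseq.libbleu', 'fairseq.temporal_convolution_tbc'):
    spec = importlib.util.find_spec(module_name)
    print(module_name, 'found' if spec is not None else 'missing')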
tests/test_label_smoothing.py  0 → 100644
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import torch
import unittest

from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropy
from torch.autograd import Variable, gradcheck

torch.set_default_tensor_type('torch.DoubleTensor')


class TestLabelSmoothing(unittest.TestCase):

    def test_label_smoothing(self):
        input = Variable(torch.randn(3, 5), requires_grad=True)
        idx = torch.rand(3) * 4
        target = Variable(idx.long())
        criterion = LabelSmoothedCrossEntropy()
        self.assertTrue(gradcheck(
            lambda x, y: criterion.apply(x, y, 0.1, 2, None), (input, target)
        ))
        weights = torch.ones(5)
        weights[2] = 0
        self.assertTrue(gradcheck(
            lambda x, y: criterion.apply(x, y, 0.1, None, weights), (input, target)))
        self.assertTrue(gradcheck(
            lambda x, y: criterion.apply(x, y, 0.1, None, None), (input, target)))


if __name__ == '__main__':
    unittest.main()
train.py  0 → 100644
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import collections
import os
import torch
import math

from fairseq import bleu, data, options, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.multiprocessing_trainer import MultiprocessingTrainer
from fairseq.progress_bar import progress_bar
from fairseq.sequence_generator import SequenceGenerator


def main():
    parser = options.get_parser('Trainer')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('--max-tokens', default=6000, type=int, metavar='N',
                              help='maximum number of tokens in a batch')
    dataset_args.add_argument('--train-subset', default='train', metavar='SPLIT',
                              choices=['train', 'valid', 'test'],
                              help='data subset to use for training (train, valid, test)')
    dataset_args.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                              help='comma separated list of data subsets '
                                   'to use for validation (train, valid, valid1, test, test1)')
    dataset_args.add_argument('--test-subset', default='test', metavar='SPLIT',
                              help='comma separated list of data subsets '
                                   'to use for testing (train, valid, test)')
    options.add_optimization_args(parser)
    options.add_checkpoint_args(parser)
    options.add_model_args(parser)

    args = parser.parse_args()
    print(args)

    if args.no_progress_bar:
        progress_bar.enabled = False
        progress_bar.print_interval = args.log_interval

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    torch.manual_seed(args.seed)

    # Load dataset
    dataset = data.load_with_check(args.data, args.source_lang, args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    for split in dataset.splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    num_gpus = torch.cuda.device_count()

    print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))

    # Build model
    print('| model {}'.format(args.arch))
    model = utils.build_model(args, dataset)
    criterion = utils.build_criterion(args, dataset)

    # Start multiprocessing
    trainer = MultiprocessingTrainer(args, model)

    # Load the latest checkpoint if one is available
    epoch, batch_offset = trainer.load_checkpoint(
        os.path.join(args.save_dir, args.restore_file))

    # Train until the learning rate gets too small
    val_loss = None
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch <= max_epoch:
        # train for one epoch
        train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus)

        # evaluate on validate set
        for k, subset in enumerate(args.valid_subset.split(',')):
            val_loss = validate(args, epoch, trainer, criterion, dataset, subset, num_gpus)
            if k == 0:
                if not args.no_save:
                    # save checkpoint
                    trainer.save_checkpoint(args, epoch, 0, val_loss)

                # only use first validation loss to update the learning schedule
                lr = trainer.lr_step(val_loss, epoch)

        epoch += 1
        batch_offset = 0
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))

    # Generate on test set and compute BLEU score
    for beam in [1, 5, 10, 20]:
        for subset in args.test_subset.split(','):
            scorer = score_test(args, trainer.get_model(), dataset, subset, beam,
                                cuda_device=(0 if num_gpus > 0 else None))
            print('| Test on {} with beam={}: {}'.format(subset, beam, scorer.result_string()))

    # Stop multiprocessing
    trainer.stop()


def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
    """Train the model for one epoch."""
    itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
                             max_tokens=args.max_tokens, seed=args.seed, epoch=epoch,
                             max_positions=args.max_positions,
                             sample_without_replacement=args.sample_without_replacement)
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()    # sentences per batch
    wpb_meter = AverageMeter()    # words per batch
    wps_meter = TimeMeter()       # words per second
    clip_meter = AverageMeter()   # % of updates clipped
    gnorm_meter = AverageMeter()  # gradient norm

    desc = '| epoch {:03d}'.format(epoch)
    lr = trainer.get_lr()
    with progress_bar(itr, desc, leave=False) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss, grad_norm = trainer.train_step(sample, criterion)

            ntokens = sum(s['ntokens'] for s in sample)
            src_size = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, ntokens)
            bsz_meter.update(src_size)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if grad_norm > args.clip_norm else 0)
            gnorm_meter.update(grad_norm)

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
                ('wps', '{:5d}'.format(round(wps_meter.avg))),
                ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
                ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                ('lr', lr),
                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
                ('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
            ]))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                trainer.save_checkpoint(args, epoch, i + 1)

        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'
        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'
        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}% | gnorm {:.4f}'
        t.write(fmt.format(loss_meter.avg, math.pow(2, loss_meter.avg),
                           round(wps_meter.elapsed_time),
                           round(wps_meter.avg),
                           round(wpb_meter.avg),
                           round(bsz_meter.avg),
                           lr, clip_meter.avg * 100,
                           gnorm_meter.avg))


def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""
    itr = dataset.dataloader(subset, batch_size=None, max_tokens=args.max_tokens,
                             max_positions=args.max_positions)
    loss_meter = AverageMeter()

    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
    with progress_bar(itr, desc, leave=False) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            ntokens = sum(s['ntokens'] for s in sample)
            loss = trainer.valid_step(sample, criterion)
            loss_meter.update(loss, ntokens)
            t.set_postfix(loss='{:.2f}'.format(loss_meter.avg))

        val_loss = loss_meter.avg
        t.write(desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
            val_loss, math.pow(2, val_loss)))

    # update and return the learning rate
    return val_loss


def score_test(args, model, dataset, subset, beam, cuda_device):
    """Evaluate the model on the test set and return the BLEU scorer."""
    translator = SequenceGenerator([model], dataset.dst_dict, beam_size=beam)
    if torch.cuda.is_available():
        translator.cuda()

    scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(),
                         dataset.dst_dict.unk())
    itr = dataset.dataloader(subset, batch_size=4, max_positions=args.max_positions)
    for _, _, ref, hypos in translator.generate_batched_itr(itr, cuda_device=cuda_device):
        scorer.add(ref.int().cpu(), hypos[0]['tokens'].int().cpu())
    return scorer


if __name__ == '__main__':
    main()