jerrrrry / mlperf_transformer_v0.7 / Commits / 9e8a8c05

Commit 9e8a8c05, authored Oct 14, 2024 by jerrrrry

    Initial commit

Changes: 209 files in this commit. This page shows 20 changed files with 1531 additions and 0 deletions (+1531, -0).
Files changed on this page:

  implementations/pytorch/run_conversion.sh                    +7    -0
  implementations/pytorch/run_preprocessing.sh                 +34   -0
  implementations/pytorch/run_training.sh                      +88   -0
  implementations/pytorch/run_with_docker.sh                   +70   -0
  implementations/pytorch/score.py                             +61   -0
  implementations/pytorch/scripts/__init__.py                  +0    -0
  implementations/pytorch/scripts/average_checkpoints.py       +137  -0
  implementations/pytorch/scripts/build_sym_alignment.py       +99   -0
  implementations/pytorch/scripts/convert_dictionary.lua       +36   -0
  implementations/pytorch/scripts/convert_model.lua            +110  -0
  implementations/pytorch/setup.py                             +71   -0
  implementations/pytorch/test_run.sh                          +6    -0
  implementations/pytorch/tests/__init__.py                    +0    -0
  implementations/pytorch/tests/test_average_checkpoints.py    +72   -0
  implementations/pytorch/tests/test_binaries.py               +273  -0
  implementations/pytorch/tests/test_convtbc.py                +50   -0
  implementations/pytorch/tests/test_data_utils.py             +29   -0
  implementations/pytorch/tests/test_dictionary.py             +73   -0
  implementations/pytorch/tests/test_label_smoothing.py        +101  -0
  implementations/pytorch/tests/test_sequence_generator.py     +214  -0
implementations/pytorch/run_conversion.sh (new file, mode 100644)

#!/bin/bash

set -e

SEED=$1

python3 convert_utf8_to_fairseq_binary.py --data_dir /workspace/translation/examples/translation/wmt14_en_de
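A minimal invocation sketch: the seed is the only positional argument and the data paths are the fixed container locations above (the seed value here is a placeholder):

$ bash run_conversion.sh 1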
implementations/pytorch/run_preprocessing.sh (new file, mode 100644)

#!/bin/bash

set -e

SEED=$1

cd /workspace/translation

# TODO: Add SEED to process_data.py since this uses a random generator (future PR)
#export PYTHONPATH=/research/transformer/transformer:${PYTHONPATH}

# Add compliance to PYTHONPATH
# export PYTHONPATH=/mlperf/training/compliance:${PYTHONPATH}

mkdir -p /workspace/translation/examples/translation/wmt14_en_de
mkdir -p /workspace/translation/examples/translation/wmt14_en_de/utf8

cp /workspace/translation/reference_dictionary.ende.txt /workspace/translation/examples/translation/wmt14_en_de/dict.en.txt
cp /workspace/translation/reference_dictionary.ende.txt /workspace/translation/examples/translation/wmt14_en_de/dict.de.txt
sed -i "1s/^/\'<lua_index_compat>\'\n/" /workspace/translation/examples/translation/wmt14_en_de/dict.en.txt
sed -i "1s/^/\'<lua_index_compat>\'\n/" /workspace/translation/examples/translation/wmt14_en_de/dict.de.txt

# TODO: make code consistent to not look in two places (allows temporary hack above for preprocessing-vs-training)
cp /workspace/translation/reference_dictionary.ende.txt /workspace/translation/examples/translation/wmt14_en_de/utf8/dict.en.txt
cp /workspace/translation/reference_dictionary.ende.txt /workspace/translation/examples/translation/wmt14_en_de/utf8/dict.de.txt

#wget https://raw.githubusercontent.com/tensorflow/models/master/official/transformer/test_data/newstest2014.en -O /workspace/translation/examples/translation/wmt14_en_de/newstest2014.en
#wget https://raw.githubusercontent.com/tensorflow/models/master/official/transformer/test_data/newstest2014.de -O /workspace/translation/examples/translation/wmt14_en_de/newstest2014.de
cp /workspace/translation/newstest2014.en /workspace/translation/examples/translation/wmt14_en_de/newstest2014.en
cp /workspace/translation/newstest2014.de /workspace/translation/examples/translation/wmt14_en_de/newstest2014.de

python3 preprocess.py --raw_dir /raw_data/ --data_dir /workspace/translation/examples/translation/wmt14_en_de
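As with the conversion script, the seed is the only positional argument; the raw WMT corpus is expected under /raw_data/ and the preprocessed output lands in the examples/translation/wmt14_en_de directory inside the container. A hedged usage sketch:

$ bash run_preprocessing.sh 1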
implementations/pytorch/run_training.sh (new file, mode 100644)

#!/bin/bash

# Start timing
START=$(date +%s)
START_FMT=$(date +%Y-%m-%d\ %r)
echo "STARTING TIMING RUN AT ${START_FMT}"

if [[ ${WORLD_SIZE:-${SLURM_NTASKS}} -ne 1 ]]; then
    DISTRIBUTED_INIT_METHOD="--distributed-init-method env://"
else
    DISTRIBUTED_INIT_METHOD="--distributed-world-size 1"
fi

# These are scanned by train.py, so make sure they are exported
export DGXSYSTEM
export SLURM_NTASKS_PER_NODE
export SLURM_NNODES

declare -a CMD
if [ -n "${SLURM_LOCALID-}" ]; then
    # Mode 1: Slurm launched a task for each GPU and set some envvars; no need for parallel launch
    if [ "${SLURM_NTASKS}" -gt "${SLURM_JOB_NUM_NODES}" ]; then
        CMD=( './bind.sh' '--cpu=exclusive' '--' 'python' '-u' )
    else
        CMD=( 'python' '-u' )
    fi
else
    # Mode 2: Single-node Docker; need to launch tasks with Pytorch's distributed launch
    # TODO: use bind.sh instead of bind_launch.py
    #       torch.distributed.launch only accepts Python programs (not bash scripts) to exec
    CMD=( 'python' '-u' '-m' 'bind_launch' "--nsockets_per_node=${DGXNSOCKET}" \
          "--ncores_per_socket=${DGXSOCKETCORES}" "--nproc_per_node=${DGXNGPU}" )
fi

"${CMD[@]}" train.py ${DATASET_DIR} \
  --seed ${SEED} \
  --arch transformer_wmt_en_de_big_t2t \
  --share-all-embeddings \
  --optimizer adam \
  --adam-betas '(0.9, 0.997)' \
  --adam-eps "1e-9" \
  --clip-norm "0.0" \
  --lr-scheduler inverse_sqrt \
  --warmup-init-lr "0.0" \
  --warmup-updates ${WARMUP_UPDATES} \
  --lr ${LEARNING_RATE} \
  --min-lr "0.0" \
  --dropout "0.1" \
  --weight-decay "0.0" \
  --criterion label_smoothed_cross_entropy \
  --label-smoothing "0.1" \
  --max-tokens ${MAX_TOKENS} \
  --max-epoch ${NUMEPOCHS} \
  --target-bleu "25.0" \
  --ignore-case \
  --no-save \
  --update-freq 1 \
  --fp16 \
  --seq-len-multiple 2 \
  --source_lang en \
  --target_lang de \
  --bucket_growth_factor 1.035 \
  --batching_scheme "v0p5_better" \
  --batch_multiple_strategy "dynamic" \
  --fast-xentropy \
  --max-len-a 1 \
  --max-len-b 50 \
  --lenpen 0.6 \
  --no-progress-bar \
  --dataloader-num-workers 2 \
  --enable-dataloader-pin-memory \
  --multihead-attn-impl 'fast_with_lyrnrm_and_dropoutadd' \
  ${DISTRIBUTED_INIT_METHOD} \
  ${EXTRA_PARAMS} ; ret_code=$?

sleep 3
if [[ $ret_code != 0 ]]; then exit $ret_code; fi

# End timing
END=$(date +%s)
END_FMT=$(date +%Y-%m-%d\ %r)
echo "ENDING TIMING RUN AT ${END_FMT}"

# Report result
RESULT=$(( ${END} - ${START} ))
RESULT_NAME="transformer"
echo "RESULT,${RESULT_NAME},${SEED},${RESULT},${USER},${START_FMT}"
implementations/pytorch/run_with_docker.sh (new file, mode 100644)

#!/bin/bash

set -euxo pipefail

# Vars without defaults
: "${DGXSYSTEM:?DGXSYSTEM not set}"
: "${CONT:?CONT not set}"

# Vars with defaults
: "${NEXP:=5}"
: "${DATESTAMP:=$(date +'%y%m%d%H%M%S%N')}"
: "${CLEAR_CACHES:=1}"
: "${DATADIR:=/raid/datasets/xformer_v0p6/utf8}"
: "${LOGDIR:=$(pwd)/results}"

# Other vars
readonly _config_file="./config_${DGXSYSTEM}.sh"
readonly _seed_override=${SEED:-}
readonly _logfile_base="${LOGDIR}/${DATESTAMP}"
readonly _cont_name=translation
_cont_mounts=("--volume=${DATADIR}:/data" "--volume=${LOGDIR}:/results")

# Setup directories
mkdir -p "${LOGDIR}"

# Get list of envvars to pass to docker
source "${_config_file}"
mapfile -t _config_env < <(env -i bash -c ". ${_config_file} && compgen -e" | grep -E -v '^(PWD|SHLVL)')
_config_env+=(SEED)
mapfile -t _config_env < <(for v in "${_config_env[@]}"; do echo "--env=$v"; done)

# Cleanup container
cleanup_docker() {
    docker container rm -f "${_cont_name}" || true
}
cleanup_docker
trap 'set -eux; cleanup_docker' EXIT

# Setup container
nvidia-docker run --rm --init --detach \
    --net=host --uts=host --ipc=host --security-opt=seccomp=unconfined \
    --ulimit=stack=67108864 --ulimit=memlock=-1 \
    --name="${_cont_name}" "${_cont_mounts[@]}" \
    "${CONT}" sleep infinity
docker exec -it "${_cont_name}" true

# Run experiments
for _experiment_index in $(seq 1 "${NEXP}"); do
    (
        echo "Beginning trial ${_experiment_index} of ${NEXP}"

        # Print system info
        docker exec -it "${_cont_name}" python -c "
import mlperf_log_utils
from mlperf_logging.mllog import constants
mlperf_log_utils.mlperf_submission_log(constants.TRANSFORMER)"

        # Clear caches
        if [ "${CLEAR_CACHES}" -eq 1 ]; then
            sync && sudo /sbin/sysctl vm.drop_caches=3
            docker exec -it "${_cont_name}" python -c "
from mlperf_logging.mllog import constants
from mlperf_log_utils import log_event
log_event(key=constants.CACHE_CLEAR, value=True)"
        fi

        # Run experiment
        export SEED=${_seed_override:-$RANDOM}
        docker exec -it "${_config_env[@]}" "${_cont_name}" ./run_and_time.sh
    ) |& tee "${_logfile_base}_${_experiment_index}.log"
done
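A hedged launch sketch: DGXSYSTEM and CONT are the only required variables; everything else falls back to the defaults above. The system name and image tag below are placeholders, and a matching config_${DGXSYSTEM}.sh must exist next to the script. That sourced config file is also where the training hyperparameters consumed by run_training.sh (DATASET_DIR, WARMUP_UPDATES, LEARNING_RATE, MAX_TOKENS, NUMEPOCHS, DGXNGPU, DGXNSOCKET, DGXSOCKETCORES, EXTRA_PARAMS) would presumably be defined, since the whole environment from that file is forwarded into the container.

$ DGXSYSTEM=DGX1 CONT=mlperf-translation:latest DATADIR=/raid/datasets/xformer_v0p6/utf8 LOGDIR=$(pwd)/results NEXP=5 ./run_with_docker.sh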
implementations/pytorch/score.py (new file, mode 100644)

#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import argparse
import os
import sys

from fairseq import bleu, tokenizer
from fairseq.data import dictionary


def main():
    parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.')
    parser.add_argument('-s', '--sys', default='-', help='system output')
    parser.add_argument('-r', '--ref', required=True, help='references')
    parser.add_argument('-o', '--order', default=4, metavar='N', type=int,
                        help='consider ngrams up to this order')
    parser.add_argument('--ignore-case', action='store_true',
                        help='case-insensitive scoring')
    args = parser.parse_args()
    print(args)

    assert args.sys == '-' or os.path.exists(args.sys), \
        "System output file {} does not exist".format(args.sys)
    assert os.path.exists(args.ref), \
        "Reference file {} does not exist".format(args.ref)

    dict = dictionary.Dictionary()

    def readlines(fd):
        for line in fd.readlines():
            if args.ignore_case:
                yield line.lower()
            else:
                yield line

    def score(fdsys):
        with open(args.ref) as fdref:
            scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
            for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
                sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict)
                ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict)
                scorer.add(ref_tok, sys_tok)
            print(scorer.result_string(args.order))

    if args.sys == '-':
        score(sys.stdin)
    else:
        with open(args.sys, 'r') as f:
            score(f)


if __name__ == '__main__':
    main()
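A usage sketch for the scorer; the hypothesis and reference file names are placeholders. With the default --sys of '-', hypotheses are read from stdin:

$ python3 score.py --ref newstest2014.de --ignore-case < hypotheses.txt
$ python3 score.py --sys hypotheses.txt --ref newstest2014.de --order 4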
implementations/pytorch/scripts/__init__.py (new file, mode 100644; empty file)
implementations/pytorch/scripts/average_checkpoints.py (new file, mode 100644)

#!/usr/bin/env python3

import argparse
import collections
import torch
import os
import re


def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Args:
      inputs: An iterable of string paths of checkpoints to load from.

    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    for f in inputs:
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
            ),
        )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        model_params = state['model']

        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            if k not in params_dict:
                params_dict[k] = []
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            params_dict[k].append(p)

    averaged_params = collections.OrderedDict()
    # v should be a list of torch Tensor.
    for k, v in params_dict.items():
        summed_v = None
        for x in v:
            summed_v = summed_v + x if summed_v is not None else x
        averaged_params[k] = summed_v / len(v)
    new_state['model'] = averaged_params
    return new_state


def last_n_checkpoints(paths, n, update_based):
    assert len(paths) == 1
    path = paths[0]
    if update_based:
        pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
    else:
        pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
    files = os.listdir(path)

    entries = []
    for f in files:
        m = pt_regexp.fullmatch(f)
        if m is not None:
            entries.append((int(m.group(1)), m.group(0)))
    if len(entries) < n:
        raise Exception('Found {} checkpoint files but need at least {}', len(entries), n)
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]


def main():
    parser = argparse.ArgumentParser(
        description='Tool to average the params of input checkpoints to '
                    'produce a new checkpoint',
    )

    parser.add_argument(
        '--inputs',
        required=True,
        nargs='+',
        help='Input checkpoint file paths.',
    )
    parser.add_argument(
        '--output',
        required=True,
        metavar='FILE',
        help='Write the new checkpoint containing the averaged weights to this '
             'path.',
    )
    num_group = parser.add_mutually_exclusive_group()
    num_group.add_argument(
        '--num-epoch-checkpoints',
        type=int,
        help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
             'and average last this many of them.',
    )
    num_group.add_argument(
        '--num-update-checkpoints',
        type=int,
        help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
             'and average last this many of them.',
    )
    args = parser.parse_args()
    print(args)

    num = None
    is_update_based = False
    if args.num_update_checkpoints is not None:
        num = args.num_update_checkpoints
        is_update_based = True
    elif args.num_epoch_checkpoints is not None:
        num = args.num_epoch_checkpoints

    if num is not None:
        args.inputs = last_n_checkpoints(args.inputs, num, is_update_based)
        print('averaging checkpoints: ', args.inputs)

    new_state = average_checkpoints(args.inputs)
    torch.save(new_state, args.output)
    print('Finished writing averaged checkpoint to {}.'.format(args.output))


if __name__ == '__main__':
    main()
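A usage sketch; the checkpoint directory and output name are placeholders. When --num-epoch-checkpoints (or --num-update-checkpoints) is given, the single --inputs path is treated as a directory and the newest matching checkpointN.pt files are picked by last_n_checkpoints above:

$ python3 scripts/average_checkpoints.py --inputs /path/to/checkpoints --output averaged.pt --num-epoch-checkpoints 5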
implementations/pytorch/scripts/build_sym_alignment.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
"""
Use this script in order to build symmetric alignments for your translation
dataset.
This script depends on fast_align and mosesdecoder tools. You will need to
build those before running the script.

fast_align:
    github: http://github.com/clab/fast_align
    instructions: follow the instructions in README.md
mosesdecoder:
    github: http://github.com/moses-smt/mosesdecoder
    instructions: http://www.statmt.org/moses/?n=Development.GetStarted

The script produces the following files under --output_dir:
    text.joined - concatenation of lines from the source_file and the
        target_file.
    align.forward - forward pass of fast_align.
    align.backward - backward pass of fast_align.
    aligned.sym_heuristic - symmetrized alignment.
"""

import argparse
import os
from itertools import zip_longest


def main():
    parser = argparse.ArgumentParser(description='symmetric alignment builer')
    parser.add_argument('--fast_align_dir',
                        help='path to fast_align build directory')
    parser.add_argument('--mosesdecoder_dir',
                        help='path to mosesdecoder root directory')
    parser.add_argument('--sym_heuristic',
                        help='heuristic to use for symmetrization',
                        default='grow-diag-final-and')
    parser.add_argument('--source_file',
                        help='path to a file with sentences '
                             'in the source language')
    parser.add_argument('--target_file',
                        help='path to a file with sentences '
                             'in the target language')
    parser.add_argument('--output_dir',
                        help='output directory')

    args = parser.parse_args()

    fast_align_bin = os.path.join(args.fast_align_dir, 'fast_align')
    symal_bin = os.path.join(args.mosesdecoder_dir, 'bin', 'symal')
    sym_fast_align_bin = os.path.join(
        args.mosesdecoder_dir, 'scripts', 'ems',
        'support', 'symmetrize-fast-align.perl')

    # create joined file
    joined_file = os.path.join(args.output_dir, 'text.joined')
    with open(args.source_file, 'r') as src, open(args.target_file, 'r') as tgt:
        with open(joined_file, 'w') as joined:
            for s, t in zip_longest(src, tgt):
                print('{} ||| {}'.format(s.strip(), t.strip()), file=joined)

    bwd_align_file = os.path.join(args.output_dir, 'align.backward')

    # run forward alignment
    fwd_align_file = os.path.join(args.output_dir, 'align.forward')
    fwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v > {FWD}'.format(
        FASTALIGN=fast_align_bin,
        JOINED=joined_file,
        FWD=fwd_align_file)
    assert os.system(fwd_fast_align_cmd) == 0

    # run backward alignment
    bwd_align_file = os.path.join(args.output_dir, 'align.backward')
    bwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}'.format(
        FASTALIGN=fast_align_bin,
        JOINED=joined_file,
        BWD=bwd_align_file)
    assert os.system(bwd_fast_align_cmd) == 0

    # run symmetrization
    sym_out_file = os.path.join(args.output_dir, 'aligned')
    sym_cmd = '{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}'.format(
        SYMFASTALIGN=sym_fast_align_bin,
        FWD=fwd_align_file,
        BWD=bwd_align_file,
        SRC=args.source_file,
        TGT=args.target_file,
        OUT=sym_out_file,
        HEURISTIC=args.sym_heuristic,
        SYMAL=symal_bin)
    assert os.system(sym_cmd) == 0


if __name__ == '__main__':
    main()
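A usage sketch with placeholder paths; as the module docstring notes, fast_align and mosesdecoder must already be built before running:

$ python3 scripts/build_sym_alignment.py --fast_align_dir ~/fast_align/build --mosesdecoder_dir ~/mosesdecoder \
    --source_file corpus.en --target_file corpus.de --output_dir alignments/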
implementations/pytorch/scripts/convert_dictionary.lua (new file, mode 100644)

-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
-- Usage: convert_dictionary.lua <dict.th7>
require 'fairseq'
require 'torch'
require 'paths'

if #arg < 1 then
   print('usage: convert_dictionary.lua <dict.th7>')
   os.exit(1)
end
if not paths.filep(arg[1]) then
   print('error: file does not exit: ' .. arg[1])
   os.exit(1)
end

dict = torch.load(arg[1])
dst = paths.basename(arg[1]):gsub('.th7', '.txt')
assert(dst:match('.txt$'))

f = io.open(dst, 'w')
for idx, symbol in ipairs(dict.index_to_symbol) do
   if idx > dict.cutoff then
      break
   end
   f:write(symbol)
   f:write(' ')
   f:write(dict.index_to_freq[idx])
   f:write('\n')
end
f:close()
implementations/pytorch/scripts/convert_model.lua (new file, mode 100644)

-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
-- Usage: convert_model.lua <model_epoch1.th7>
require 'torch'
local fairseq = require 'fairseq'

model = torch.load(arg[1])

function find_weight_norm(container, module)
   for _, wn in ipairs(container:listModules()) do
      if torch.type(wn) == 'nn.WeightNorm' and wn.modules[1] == module then
         return wn
      end
   end
end

function push_state(dict, key, module)
   if torch.type(module) == 'nn.Linear' then
      local wn = find_weight_norm(model.module, module)
      assert(wn)
      dict[key .. '.weight_v'] = wn.v:float()
      dict[key .. '.weight_g'] = wn.g:float()
   elseif torch.type(module) == 'nn.TemporalConvolutionTBC' then
      local wn = find_weight_norm(model.module, module)
      assert(wn)
      local v = wn.v:float():view(wn.viewOut):transpose(2, 3)
      dict[key .. '.weight_v'] = v
      dict[key .. '.weight_g'] = wn.g:float():view(module.weight:size(3), 1, 1)
   else
      dict[key .. '.weight'] = module.weight:float()
   end
   if module.bias then
      dict[key .. '.bias'] = module.bias:float()
   end
end

encoder_dict = {}
decoder_dict = {}
combined_dict = {}

function encoder_state(encoder)
   luts = encoder:findModules('nn.LookupTable')
   push_state(encoder_dict, 'embed_tokens', luts[1])
   push_state(encoder_dict, 'embed_positions', luts[2])

   fcs = encoder:findModules('nn.Linear')
   assert(#fcs >= 2)
   local nInputPlane = fcs[1].weight:size(1)
   push_state(encoder_dict, 'fc1', table.remove(fcs, 1))
   push_state(encoder_dict, 'fc2', table.remove(fcs, #fcs))

   for i, module in ipairs(encoder:findModules('nn.TemporalConvolutionTBC')) do
      push_state(encoder_dict, 'convolutions.' .. tostring(i - 1), module)
      if nInputPlane ~= module.weight:size(3) / 2 then
         push_state(encoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1))
      end
      nInputPlane = module.weight:size(3) / 2
   end
   assert(#fcs == 0)
end

function decoder_state(decoder)
   luts = decoder:findModules('nn.LookupTable')
   push_state(decoder_dict, 'embed_tokens', luts[1])
   push_state(decoder_dict, 'embed_positions', luts[2])

   fcs = decoder:findModules('nn.Linear')
   local nInputPlane = fcs[1].weight:size(1)
   push_state(decoder_dict, 'fc1', table.remove(fcs, 1))
   push_state(decoder_dict, 'fc2', fcs[#fcs - 1])
   push_state(decoder_dict, 'fc3', fcs[#fcs])

   table.remove(fcs, #fcs)
   table.remove(fcs, #fcs)

   for i, module in ipairs(decoder:findModules('nn.TemporalConvolutionTBC')) do
      if nInputPlane ~= module.weight:size(3) / 2 then
         push_state(decoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1))
      end
      nInputPlane = module.weight:size(3) / 2

      local prefix = 'attention.' .. tostring(i - 1)
      push_state(decoder_dict, prefix .. '.in_projection', table.remove(fcs, 1))
      push_state(decoder_dict, prefix .. '.out_projection', table.remove(fcs, 1))
      push_state(decoder_dict, 'convolutions.' .. tostring(i - 1), module)
   end
   assert(#fcs == 0)
end

_encoder = model.module.modules[2]
_decoder = model.module.modules[3]

encoder_state(_encoder)
decoder_state(_decoder)

for k, v in pairs(encoder_dict) do
   combined_dict['encoder.' .. k] = v
end
for k, v in pairs(decoder_dict) do
   combined_dict['decoder.' .. k] = v
end

torch.save('state_dict.t7', combined_dict)
implementations/pytorch/setup.py (new file, mode 100644)

#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from setuptools import setup, find_packages, Extension
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
import sys


if sys.version_info < (3,):
    sys.exit('Sorry, Python3 is required for fairseq.')

with open('README.md') as f:
    readme = f.read()

with open('LICENSE') as f:
    license = f.read()

with open('requirements.txt') as f:
    reqs = f.read()


bleu = Extension(
    'fairseq.libbleu',
    sources=[
        'fairseq/clib/libbleu/libbleu.cpp',
        'fairseq/clib/libbleu/module.cpp',
    ],
    extra_compile_args=['-std=c++11'],
)

batch_utils_v0p5 = CppExtension(
    name='fairseq.data.batch_C_v0p5',
    sources=['fairseq/data/csrc/make_batches_v0p5.cpp'],
    extra_compile_args={
        'cxx': ['-O2',],
    }
)

batch_utils_v0p5_better = CppExtension(
    name='fairseq.data.batch_C_v0p5_better',
    sources=['fairseq/data/csrc/make_batches_v0p5_better.cpp'],
    extra_compile_args={
        'cxx': ['-O2', '--std=c++14'],
    }
)

batch_utils_v0p6 = CppExtension(
    name='fairseq.data.batch_C_v0p6',
    sources=['fairseq/data/csrc/make_batches_v0p6.cpp'],
    extra_compile_args={
        'cxx': ['-O2', '--std=c++14'],
    }
)

setup(
    name='fairseq',
    version='0.5.0',
    description='Facebook AI Research Sequence-to-Sequence Toolkit',
    long_description=readme,
    license=license,
    install_requires=reqs.strip().split('\n'),
    packages=find_packages(),
    ext_modules=[bleu, batch_utils_v0p5, batch_utils_v0p5_better, batch_utils_v0p6],
    cmdclass={'build_ext': BuildExtension},
    test_suite='tests',
)
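Because the extension modules are compiled through torch.utils.cpp_extension, PyTorch and a C++14-capable compiler need to be available before installation. A standard setuptools install sketch, assuming it is run from the implementations/pytorch directory:

$ pip install -r requirements.txt
$ pip install -e .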
implementations/pytorch/test_run.sh (new file, mode 100644)

python3 -u train.py /raw_data \
  --seed 1234 \
  --arch transformer_wmt_en_de_big_t2t \
  --share-all-embeddings \
  --optimizer adam \
  --ignore-case
implementations/pytorch/tests/__init__.py (new file, mode 100644; empty file)
implementations/pytorch/tests/test_average_checkpoints.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import collections
import os
import tempfile
import unittest

import numpy as np
import torch

from scripts.average_checkpoints import average_checkpoints


class TestAverageCheckpoints(unittest.TestCase):

    def test_average_checkpoints(self):
        params_0 = collections.OrderedDict(
            [
                ('a', torch.DoubleTensor([100.0])),
                ('b', torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
                ('c', torch.IntTensor([7, 8, 9])),
            ]
        )
        params_1 = collections.OrderedDict(
            [
                ('a', torch.DoubleTensor([1.0])),
                ('b', torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])),
                ('c', torch.IntTensor([2, 2, 2])),
            ]
        )
        params_avg = collections.OrderedDict(
            [
                ('a', torch.DoubleTensor([50.5])),
                ('b', torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])),
                # We expect truncation for integer division
                ('c', torch.IntTensor([4, 5, 5])),
            ]
        )

        fd_0, path_0 = tempfile.mkstemp()
        fd_1, path_1 = tempfile.mkstemp()
        torch.save(collections.OrderedDict([('model', params_0)]), path_0)
        torch.save(collections.OrderedDict([('model', params_1)]), path_1)

        output = average_checkpoints([path_0, path_1])['model']

        os.close(fd_0)
        os.remove(path_0)
        os.close(fd_1)
        os.remove(path_1)

        for (k_expected, v_expected), (k_out, v_out) in zip(
                params_avg.items(), output.items()):
            self.assertEqual(
                k_expected, k_out,
                'Key mismatch - expected {} but found {}. '
                '(Expected list of keys: {} vs actual list of keys: {})'.format(
                    k_expected, k_out, params_avg.keys(), output.keys()
                )
            )
            np.testing.assert_allclose(
                v_expected.numpy(),
                v_out.numpy(),
                err_msg='Tensor value mismatch for key {}'.format(k_expected)
            )


if __name__ == '__main__':
    unittest.main()
implementations/pytorch/tests/test_binaries.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import contextlib
from io import StringIO
import os
import random
import sys
import tempfile
import unittest

import torch

from fairseq import options
import preprocess
import train
import generate
import interactive
import eval_lm


class TestTranslation(unittest.TestCase):

    def test_fconv(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_fconv') as data_dir:
                create_dummy_data(data_dir)
                preprocess_translation_data(data_dir)
                train_translation_model(data_dir, 'fconv_iwslt_de_en')
                generate_main(data_dir)

    def test_raw(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_fconv_raw') as data_dir:
                create_dummy_data(data_dir)
                preprocess_translation_data(data_dir, ['--output-format', 'raw'])
                train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--raw-text'])
                generate_main(data_dir, ['--raw-text'])

    def test_fp16(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_fp16') as data_dir:
                create_dummy_data(data_dir)
                preprocess_translation_data(data_dir)
                train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--fp16'])
                generate_main(data_dir)

    def test_update_freq(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_update_freq') as data_dir:
                create_dummy_data(data_dir)
                preprocess_translation_data(data_dir)
                train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--update-freq', '3'])
                generate_main(data_dir)

    def test_lstm(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_lstm') as data_dir:
                create_dummy_data(data_dir)
                preprocess_translation_data(data_dir)
                train_translation_model(data_dir, 'lstm_wiseman_iwslt_de_en', [
                    '--encoder-layers', '2',
                    '--decoder-layers', '2',
                ])
                generate_main(data_dir)

    def test_lstm_bidirectional(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_lstm_bidirectional') as data_dir:
                create_dummy_data(data_dir)
                preprocess_translation_data(data_dir)
                train_translation_model(data_dir, 'lstm', [
                    '--encoder-layers', '2',
                    '--encoder-bidirectional',
                    '--encoder-hidden-size', '256',
                    '--decoder-layers', '2',
                ])
                generate_main(data_dir)

    def test_transformer(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_transformer') as data_dir:
                create_dummy_data(data_dir)
                preprocess_translation_data(data_dir)
                train_translation_model(data_dir, 'transformer_iwslt_de_en')
                generate_main(data_dir)


class TestStories(unittest.TestCase):

    def test_fconv_self_att_wp(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_fconv_self_att_wp') as data_dir:
                create_dummy_data(data_dir)
                preprocess_translation_data(data_dir)
                config = [
                    '--encoder-layers', '[(512, 3)] * 2',
                    '--decoder-layers', '[(512, 3)] * 2',
                    '--decoder-attention', 'True',
                    '--encoder-attention', 'False',
                    '--gated-attention', 'True',
                    '--self-attention', 'True',
                    '--project-input', 'True',
                ]
                train_translation_model(data_dir, 'fconv_self_att_wp', config)
                generate_main(data_dir)

                # fusion model
                os.rename(os.path.join(data_dir, 'checkpoint_last.pt'), os.path.join(data_dir, 'pretrained.pt'))
                config.extend([
                    '--pretrained', 'True',
                    '--pretrained-checkpoint', os.path.join(data_dir, 'pretrained.pt'),
                    '--save-dir', os.path.join(data_dir, 'fusion_model'),
                ])
                train_translation_model(data_dir, 'fconv_self_att_wp', config)


class TestLanguageModeling(unittest.TestCase):

    def test_fconv_lm(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_fconv_lm') as data_dir:
                create_dummy_data(data_dir)
                preprocess_lm_data(data_dir)
                train_language_model(data_dir, 'fconv_lm')
                eval_lm_main(data_dir)


def create_dummy_data(data_dir, num_examples=1000, maxlen=20):

    def _create_dummy_data(filename):
        data = torch.rand(num_examples * maxlen)
        data = 97 + torch.floor(26 * data).int()
        with open(os.path.join(data_dir, filename), 'w') as h:
            offset = 0
            for _ in range(num_examples):
                ex_len = random.randint(1, maxlen)
                ex_str = ' '.join(map(chr, data[offset:offset+ex_len]))
                print(ex_str, file=h)
                offset += ex_len

    _create_dummy_data('train.in')
    _create_dummy_data('train.out')
    _create_dummy_data('valid.in')
    _create_dummy_data('valid.out')
    _create_dummy_data('test.in')
    _create_dummy_data('test.out')


def preprocess_translation_data(data_dir, extra_flags=None):
    preprocess_parser = preprocess.get_parser()
    preprocess_args = preprocess_parser.parse_args(
        [
            '--source-lang', 'in',
            '--target-lang', 'out',
            '--trainpref', os.path.join(data_dir, 'train'),
            '--validpref', os.path.join(data_dir, 'valid'),
            '--testpref', os.path.join(data_dir, 'test'),
            '--thresholdtgt', '0',
            '--thresholdsrc', '0',
            '--destdir', data_dir,
        ] + (extra_flags or []),
    )
    preprocess.main(preprocess_args)


def train_translation_model(data_dir, arch, extra_flags=None):
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            '--task', 'translation',
            data_dir,
            '--save-dir', data_dir,
            '--arch', arch,
            '--optimizer', 'nag',
            '--lr', '0.05',
            '--max-tokens', '500',
            '--max-epoch', '1',
            '--no-progress-bar',
            '--distributed-world-size', '1',
            '--source-lang', 'in',
            '--target-lang', 'out',
        ] + (extra_flags or []),
    )
    train.main(train_args)


def generate_main(data_dir, extra_flags=None):
    generate_parser = options.get_generation_parser()
    generate_args = options.parse_args_and_arch(
        generate_parser,
        [
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--beam', '3',
            '--batch-size', '64',
            '--max-len-b', '5',
            '--gen-subset', 'valid',
            '--no-progress-bar',
            '--print-alignment',
        ] + (extra_flags or []),
    )

    # evaluate model in batch mode
    generate.main(generate_args)

    # evaluate model interactively
    generate_args.buffer_size = 0
    generate_args.max_sentences = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO('h e l l o\n')
    interactive.main(generate_args)
    sys.stdin = orig_stdin


def preprocess_lm_data(data_dir):
    preprocess_parser = preprocess.get_parser()
    preprocess_args = preprocess_parser.parse_args([
        '--only-source',
        '--trainpref', os.path.join(data_dir, 'train.out'),
        '--validpref', os.path.join(data_dir, 'valid.out'),
        '--testpref', os.path.join(data_dir, 'test.out'),
        '--destdir', data_dir,
    ])
    preprocess.main(preprocess_args)


def train_language_model(data_dir, arch):
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            '--task', 'language_modeling',
            data_dir,
            '--arch', arch,
            '--optimizer', 'nag',
            '--lr', '1.0',
            '--criterion', 'adaptive_loss',
            '--adaptive-softmax-cutoff', '5,10,15',
            '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
            '--decoder-embed-dim', '280',
            '--max-tokens', '500',
            '--tokens-per-sample', '500',
            '--save-dir', data_dir,
            '--max-epoch', '1',
            '--no-progress-bar',
            '--distributed-world-size', '1',
        ],
    )
    train.main(train_args)


def eval_lm_main(data_dir):
    eval_lm_parser = options.get_eval_lm_parser()
    eval_lm_args = options.parse_args_and_arch(
        eval_lm_parser,
        [
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--no-progress-bar',
        ],
    )
    eval_lm.main(eval_lm_args)


if __name__ == '__main__':
    unittest.main()
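setup.py registers test_suite='tests', and the test modules import preprocess, train, generate, interactive and eval_lm from the repository root, so the suite is presumably meant to be run from the implementations/pytorch directory with the standard unittest runner; a sketch:

$ python3 -m unittest discover tests
$ python3 -m unittest tests.test_binaries -v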
implementations/pytorch/tests/test_convtbc.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
import unittest
from fairseq.modules import ConvTBC
import torch.nn as nn


class TestConvTBC(unittest.TestCase):

    def test_convtbc(self):
        # ksz, in_channels, out_channels
        conv_tbc = ConvTBC(4, 5, kernel_size=3, padding=1)
        # out_channels, in_channels, ksz
        conv1d = nn.Conv1d(4, 5, kernel_size=3, padding=1)

        conv_tbc.weight.data.copy_(conv1d.weight.data.transpose(0, 2))
        conv_tbc.bias.data.copy_(conv1d.bias.data)

        input_tbc = torch.randn(7, 2, 4, requires_grad=True)
        input1d = input_tbc.data.transpose(0, 1).transpose(1, 2)
        input1d.requires_grad = True

        output_tbc = conv_tbc(input_tbc)
        output1d = conv1d(input1d)

        self.assertAlmostEqual(output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data)

        grad_tbc = torch.randn(output_tbc.size())
        grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous()

        output_tbc.backward(grad_tbc)
        output1d.backward(grad1d)

        self.assertAlmostEqual(conv_tbc.weight.grad.data.transpose(0, 2), conv1d.weight.grad.data)
        self.assertAlmostEqual(conv_tbc.bias.grad.data, conv1d.bias.grad.data)
        self.assertAlmostEqual(input_tbc.grad.data.transpose(0, 1).transpose(1, 2), input1d.grad.data)

    def assertAlmostEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess((t1 - t2).abs().max(), 1e-4)


if __name__ == '__main__':
    unittest.main()
implementations/pytorch/tests/test_data_utils.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import unittest

from fairseq.data import data_utils


class TestDataUtils(unittest.TestCase):

    def test_counting_iterator(self):
        x = list(range(10))
        itr = data_utils.CountingIterator(x)
        self.assertTrue(itr.has_next())
        self.assertEqual(next(itr), 0)
        self.assertEqual(next(itr), 1)
        itr.skip(3)
        self.assertEqual(next(itr), 5)
        itr.skip(3)
        self.assertEqual(next(itr), 9)
        self.assertFalse(itr.has_next())


if __name__ == '__main__':
    unittest.main()
implementations/pytorch/tests/test_dictionary.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import tempfile
import unittest

import torch

from fairseq.data import Dictionary
from fairseq.tokenizer import Tokenizer


class TestDictionary(unittest.TestCase):

    def test_finalize(self):
        txt = [
            'A B C D',
            'B C D',
            'C D',
            'D',
        ]
        ref_ids1 = list(map(torch.IntTensor, [
            [4, 5, 6, 7, 2],
            [5, 6, 7, 2],
            [6, 7, 2],
            [7, 2],
        ]))
        ref_ids2 = list(map(torch.IntTensor, [
            [7, 6, 5, 4, 2],
            [6, 5, 4, 2],
            [5, 4, 2],
            [4, 2],
        ]))

        # build dictionary
        d = Dictionary()
        for line in txt:
            Tokenizer.tokenize(line, d, add_if_not_exist=True)

        def get_ids(dictionary):
            ids = []
            for line in txt:
                ids.append(Tokenizer.tokenize(line, dictionary, add_if_not_exist=False))
            return ids

        def assertMatch(ids, ref_ids):
            for toks, ref_toks in zip(ids, ref_ids):
                self.assertEqual(toks.size(), ref_toks.size())
                self.assertEqual(0, (toks != ref_toks).sum().item())

        ids = get_ids(d)
        assertMatch(ids, ref_ids1)

        # check finalized dictionary
        d.finalize()
        finalized_ids = get_ids(d)
        assertMatch(finalized_ids, ref_ids2)

        # write to disk and reload
        with tempfile.NamedTemporaryFile(mode='w') as tmp_dict:
            d.save(tmp_dict.name)
            d = Dictionary.load(tmp_dict.name)
            reload_ids = get_ids(d)
            assertMatch(reload_ids, ref_ids2)
            assertMatch(finalized_ids, reload_ids)


if __name__ == '__main__':
    unittest.main()
implementations/pytorch/tests/test_label_smoothing.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import argparse
import copy
import unittest

import torch

from fairseq.criterions.cross_entropy import CrossEntropyCriterion
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion

import tests.utils as test_utils


class TestLabelSmoothing(unittest.TestCase):

    def setUp(self):
        # build dictionary
        self.d = test_utils.dummy_dictionary(3)
        vocab = len(self.d)
        self.assertEqual(vocab, 4 + 3)  # 4 special + 3 tokens
        self.assertEqual(self.d.pad(), 1)
        self.assertEqual(self.d.eos(), 2)
        self.assertEqual(self.d.unk(), 3)
        pad, eos, unk, w1, w2, w3 = 1, 2, 3, 4, 5, 6  # noqa: F841

        # build dataset
        self.data = [
            # the first batch item has padding
            {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, eos])},
            {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, w1, eos])},
        ]
        self.sample = next(test_utils.dummy_dataloader(self.data))

        # build model
        self.args = argparse.Namespace()
        self.args.sentence_avg = False
        self.args.probs = torch.FloatTensor([
            #      pad   eos  unk   w1   w2   w3
            [0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05],
            [0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10],
            [0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15],
        ]).unsqueeze(0).expand(2, 3, 7)  # add batch dimension
        self.task = test_utils.TestTranslationTask.setup_task(self.args, self.d, self.d)
        self.model = self.task.build_model(self.args)

    def test_nll_loss(self):
        self.args.label_smoothing = 0.1
        nll_crit = CrossEntropyCriterion(self.args, self.task)
        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
        nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
        smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
        self.assertLess(abs(nll_loss - nll_logging_output['loss']), 1e-6)
        self.assertLess(abs(nll_loss - smooth_logging_output['nll_loss']), 1e-6)

    def test_padding(self):
        self.args.label_smoothing = 0.1
        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
        loss, _, logging_output = crit(self.model, self.sample)

        def get_one_no_padding(idx):
            # create a new sample with just a single batch item so that there's
            # no padding
            sample1 = next(test_utils.dummy_dataloader([self.data[idx]]))
            args1 = copy.copy(self.args)
            args1.probs = args1.probs[idx, :, :].unsqueeze(0)
            model1 = self.task.build_model(args1)
            loss1, _, _ = crit(model1, sample1)
            return loss1

        loss1 = get_one_no_padding(0)
        loss2 = get_one_no_padding(1)
        self.assertAlmostEqual(loss, loss1 + loss2)

    def test_reduction(self):
        self.args.label_smoothing = 0.1
        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
        loss, _, logging_output = crit(self.model, self.sample, reduce=True)
        unreduced_loss, _, _ = crit(self.model, self.sample, reduce=False)
        self.assertAlmostEqual(loss, unreduced_loss.sum())

    def test_zero_eps(self):
        self.args.label_smoothing = 0.0
        nll_crit = CrossEntropyCriterion(self.args, self.task)
        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
        nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
        smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
        self.assertAlmostEqual(nll_loss, smooth_loss)

    def assertAlmostEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess((t1 - t2).abs().max(), 1e-6)


if __name__ == '__main__':
    unittest.main()
implementations/pytorch/tests/test_sequence_generator.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import argparse
import unittest

import torch

from fairseq.sequence_generator import SequenceGenerator

import tests.utils as test_utils


class TestSequenceGenerator(unittest.TestCase):

    def setUp(self):
        # construct dummy dictionary
        d = test_utils.dummy_dictionary(vocab_size=2)
        self.assertEqual(d.pad(), 1)
        self.assertEqual(d.eos(), 2)
        self.assertEqual(d.unk(), 3)
        self.eos = d.eos()
        self.w1 = 4
        self.w2 = 5

        # construct source data
        self.src_tokens = torch.LongTensor([
            [self.w1, self.w2, self.eos],
            [self.w1, self.w2, self.eos],
        ])
        self.src_lengths = torch.LongTensor([2, 2])

        args = argparse.Namespace()
        unk = 0.
        args.beam_probs = [
            # step 0:
            torch.FloatTensor([
                # eos       w1   w2
                # sentence 1:
                [0.0, unk, 0.9, 0.1],  # beam 1
                [0.0, unk, 0.9, 0.1],  # beam 2
                # sentence 2:
                [0.0, unk, 0.7, 0.3],
                [0.0, unk, 0.7, 0.3],
            ]),
            # step 1:
            torch.FloatTensor([
                # eos       w1   w2        prefix
                # sentence 1:
                [1.0, unk, 0.0, 0.0],      # w1: 0.9  (emit: w1 <eos>: 0.9*1.0)
                [0.0, unk, 0.9, 0.1],      # w2: 0.1
                # sentence 2:
                [0.25, unk, 0.35, 0.4],    # w1: 0.7  (don't emit: w1 <eos>: 0.7*0.25)
                [0.00, unk, 0.10, 0.9],    # w2: 0.3
            ]),
            # step 2:
            torch.FloatTensor([
                # eos       w1   w2        prefix
                # sentence 1:
                [0.0, unk, 0.1, 0.9],      # w2 w1: 0.1*0.9
                [0.6, unk, 0.2, 0.2],      # w2 w2: 0.1*0.1  (emit: w2 w2 <eos>: 0.1*0.1*0.6)
                # sentence 2:
                [0.60, unk, 0.4, 0.00],    # w1 w2: 0.7*0.4  (emit: w1 w2 <eos>: 0.7*0.4*0.6)
                [0.01, unk, 0.0, 0.99],    # w2 w2: 0.3*0.9
            ]),
            # step 3:
            torch.FloatTensor([
                # eos       w1   w2        prefix
                # sentence 1:
                [1.0, unk, 0.0, 0.0],      # w2 w1 w2: 0.1*0.9*0.9  (emit: w2 w1 w2 <eos>: 0.1*0.9*0.9*1.0)
                [1.0, unk, 0.0, 0.0],      # w2 w1 w1: 0.1*0.9*0.1  (emit: w2 w1 w1 <eos>: 0.1*0.9*0.1*1.0)
                # sentence 2:
                [0.1, unk, 0.5, 0.4],      # w2 w2 w2: 0.3*0.9*0.99  (emit: w2 w2 w2 <eos>: 0.3*0.9*0.99*0.1)
                [1.0, unk, 0.0, 0.0],      # w1 w2 w1: 0.7*0.4*0.4  (emit: w1 w2 w1 <eos>: 0.7*0.4*0.4*1.0)
            ]),
        ]

        task = test_utils.TestTranslationTask.setup_task(args, d, d)
        self.model = task.build_model(args)
        self.tgt_dict = task.target_dictionary

    def test_with_normalization(self):
        generator = SequenceGenerator([self.model], self.tgt_dict)
        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
        eos, w1, w2 = self.eos, self.w1, self.w2
        # sentence 1, beam 1
        self.assertHypoTokens(hypos[0][0], [w1, eos])
        self.assertHypoScore(hypos[0][0], [0.9, 1.0])
        # sentence 1, beam 2
        self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
        self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0])
        # sentence 2, beam 1
        self.assertHypoTokens(hypos[1][0], [w1, w2, w1, eos])
        self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.4, 1.0])
        # sentence 2, beam 2
        self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
        self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6])

    def test_without_normalization(self):
        # Sentence 1: unchanged from the normalized case
        # Sentence 2: beams swap order
        generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
        eos, w1, w2 = self.eos, self.w1, self.w2
        # sentence 1, beam 1
        self.assertHypoTokens(hypos[0][0], [w1, eos])
        self.assertHypoScore(hypos[0][0], [0.9, 1.0], normalized=False)
        # sentence 1, beam 2
        self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
        self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0], normalized=False)
        # sentence 2, beam 1
        self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
        self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6], normalized=False)
        # sentence 2, beam 2
        self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
        self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0], normalized=False)

    def test_with_lenpen_favoring_short_hypos(self):
        lenpen = 0.6
        generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
        eos, w1, w2 = self.eos, self.w1, self.w2
        # sentence 1, beam 1
        self.assertHypoTokens(hypos[0][0], [w1, eos])
        self.assertHypoScore(hypos[0][0], [0.9, 1.0], lenpen=lenpen)
        # sentence 1, beam 2
        self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
        self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen)
        # sentence 2, beam 1
        self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
        self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6], lenpen=lenpen)
        # sentence 2, beam 2
        self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
        self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0], lenpen=lenpen)

    def test_with_lenpen_favoring_long_hypos(self):
        lenpen = 5.0
        generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
        eos, w1, w2 = self.eos, self.w1, self.w2
        # sentence 1, beam 1
        self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
        self.assertHypoScore(hypos[0][0], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen)
        # sentence 1, beam 2
        self.assertHypoTokens(hypos[0][1], [w1, eos])
        self.assertHypoScore(hypos[0][1], [0.9, 1.0], lenpen=lenpen)
        # sentence 2, beam 1
        self.assertHypoTokens(hypos[1][0], [w1, w2, w1, eos])
        self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.4, 1.0], lenpen=lenpen)
        # sentence 2, beam 2
        self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
        self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen)

    def test_maxlen(self):
        generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2)
        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
        eos, w1, w2 = self.eos, self.w1, self.w2
        # sentence 1, beam 1
        self.assertHypoTokens(hypos[0][0], [w1, eos])
        self.assertHypoScore(hypos[0][0], [0.9, 1.0])
        # sentence 1, beam 2
        self.assertHypoTokens(hypos[0][1], [w2, w2, eos])
        self.assertHypoScore(hypos[0][1], [0.1, 0.1, 0.6])
        # sentence 2, beam 1
        self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
        self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6])
        # sentence 2, beam 2
        self.assertHypoTokens(hypos[1][1], [w2, w2, eos])
        self.assertHypoScore(hypos[1][1], [0.3, 0.9, 0.01])

    def test_no_stop_early(self):
        generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False)
        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
        eos, w1, w2 = self.eos, self.w1, self.w2
        # sentence 1, beam 1
        self.assertHypoTokens(hypos[0][0], [w1, eos])
        self.assertHypoScore(hypos[0][0], [0.9, 1.0])
        # sentence 1, beam 2
        self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
        self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0])
        # sentence 2, beam 1
        self.assertHypoTokens(hypos[1][0], [w2, w2, w2, w2, eos])
        self.assertHypoScore(hypos[1][0], [0.3, 0.9, 0.99, 0.4, 1.0])
        # sentence 2, beam 2
        self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
        self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0])

    def assertHypoTokens(self, hypo, tokens):
        self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))

    def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
        pos_scores = torch.FloatTensor(pos_probs).log()
        self.assertAlmostEqual(hypo['positional_scores'], pos_scores)
        self.assertEqual(pos_scores.numel(), hypo['tokens'].numel())
        score = pos_scores.sum()
        if normalized:
            score /= pos_scores.numel() ** lenpen
        self.assertLess(abs(score - hypo['score']), 1e-6)

    def assertAlmostEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess((t1 - t2).abs().max(), 1e-4)

    def assertTensorEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)


if __name__ == '__main__':
    unittest.main()