Commit 9e8a8c05 authored by jerrrrry

Initial commit
#!/bin/bash
set -e
SEED=$1
python3 convert_utf8_to_fairseq_binary.py --data_dir /workspace/translation/examples/translation/wmt14_en_de
#!/bin/bash
set -e
SEED=$1
cd /workspace/translation
# TODO: Add SEED to process_data.py since this uses a random generator (future PR)
#export PYTHONPATH=/research/transformer/transformer:${PYTHONPATH}
# Add compliance to PYTHONPATH
# export PYTHONPATH=/mlperf/training/compliance:${PYTHONPATH}
mkdir -p /workspace/translation/examples/translation/wmt14_en_de
mkdir -p /workspace/translation/examples/translation/wmt14_en_de/utf8
cp /workspace/translation/reference_dictionary.ende.txt /workspace/translation/examples/translation/wmt14_en_de/dict.en.txt
cp /workspace/translation/reference_dictionary.ende.txt /workspace/translation/examples/translation/wmt14_en_de/dict.de.txt
sed -i "1s/^/\'<lua_index_compat>\'\n/" /workspace/translation/examples/translation/wmt14_en_de/dict.en.txt
sed -i "1s/^/\'<lua_index_compat>\'\n/" /workspace/translation/examples/translation/wmt14_en_de/dict.de.txt
# TODO: make code consistent to not look in two places (allows temporary hack above for preprocessing-vs-training)
cp /workspace/translation/reference_dictionary.ende.txt /workspace/translation/examples/translation/wmt14_en_de/utf8/dict.en.txt
cp /workspace/translation/reference_dictionary.ende.txt /workspace/translation/examples/translation/wmt14_en_de/utf8/dict.de.txt
#wget https://raw.githubusercontent.com/tensorflow/models/master/official/transformer/test_data/newstest2014.en -O /workspace/translation/examples/translation/wmt14_en_de/newstest2014.en
#wget https://raw.githubusercontent.com/tensorflow/models/master/official/transformer/test_data/newstest2014.de -O /workspace/translation/examples/translation/wmt14_en_de/newstest2014.de
cp /workspace/translation/newstest2014.en /workspace/translation/examples/translation/wmt14_en_de/newstest2014.en
cp /workspace/translation/newstest2014.de /workspace/translation/examples/translation/wmt14_en_de/newstest2014.de
python3 preprocess.py --raw_dir /raw_data/ --data_dir /workspace/translation/examples/translation/wmt14_en_de
#!/bin/bash
# Start timing
START=$(date +%s)
START_FMT=$(date +%Y-%m-%d\ %r)
echo "STARTING TIMING RUN AT ${START_FMT}"
if [[ ${WORLD_SIZE:-${SLURM_NTASKS}} -ne 1 ]]; then
DISTRIBUTED_INIT_METHOD="--distributed-init-method env://"
else
DISTRIBUTED_INIT_METHOD="--distributed-world-size 1"
fi
# These are scanned by train.py, so make sure they are exported
export DGXSYSTEM
export SLURM_NTASKS_PER_NODE
export SLURM_NNODES
declare -a CMD
if [ -n "${SLURM_LOCALID-}" ]; then
# Mode 1: Slurm launched a task for each GPU and set some envvars; no need for parallel launch
if [ "${SLURM_NTASKS}" -gt "${SLURM_JOB_NUM_NODES}" ]; then
CMD=( './bind.sh' '--cpu=exclusive' '--' 'python' '-u' )
else
CMD=( 'python' '-u' )
fi
else
# Mode 2: Single-node Docker; need to launch tasks with Pytorch's distributed launch
# TODO: use bind.sh instead of bind_launch.py
# torch.distributed.launch only accepts Python programs (not bash scripts) to exec
CMD=( 'python' '-u' '-m' 'bind_launch' "--nsockets_per_node=${DGXNSOCKET}" \
"--ncores_per_socket=${DGXSOCKETCORES}" "--nproc_per_node=${DGXNGPU}" )
fi
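# For illustration only: in Mode 2 the assembled command expands to roughly the following,
# where the socket/core/GPU counts come from DGXNSOCKET/DGXSOCKETCORES/DGXNGPU in the
# environment (the numbers shown are placeholders, not values from this repo):
#   python -u -m bind_launch --nsockets_per_node=2 --ncores_per_socket=20 --nproc_per_node=8 train.py ...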
"${CMD[@]}" train.py ${DATASET_DIR} \
--seed ${SEED} \
--arch transformer_wmt_en_de_big_t2t \
--share-all-embeddings \
--optimizer adam \
--adam-betas '(0.9, 0.997)' \
--adam-eps "1e-9" \
--clip-norm "0.0" \
--lr-scheduler inverse_sqrt \
--warmup-init-lr "0.0" \
--warmup-updates ${WARMUP_UPDATES} \
--lr ${LEARNING_RATE} \
--min-lr "0.0" \
--dropout "0.1" \
--weight-decay "0.0" \
--criterion label_smoothed_cross_entropy \
--label-smoothing "0.1" \
--max-tokens ${MAX_TOKENS} \
--max-epoch ${NUMEPOCHS} \
--target-bleu "25.0" \
--ignore-case \
--no-save \
--update-freq 1 \
--fp16 \
--seq-len-multiple 2 \
--source_lang en \
--target_lang de \
--bucket_growth_factor 1.035 \
--batching_scheme "v0p5_better" \
--batch_multiple_strategy "dynamic" \
--fast-xentropy \
--max-len-a 1 \
--max-len-b 50 \
--lenpen 0.6 \
--no-progress-bar \
--dataloader-num-workers 2 \
--enable-dataloader-pin-memory \
--multihead-attn-impl 'fast_with_lyrnrm_and_dropoutadd' \
${DISTRIBUTED_INIT_METHOD} \
${EXTRA_PARAMS} ; ret_code=$?
sleep 3
if [[ $ret_code != 0 ]]; then exit $ret_code; fi
# End timing
END=$(date +%s)
END_FMT=$(date +%Y-%m-%d\ %r)
echo "ENDING TIMING RUN AT ${END_FMT}"
# Report result
RESULT=$(( ${END} - ${START} ))
RESULT_NAME="transformer"
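# For reference, the line below emits one CSV-style record; an illustrative example
# (all values are placeholders):
#   RESULT,transformer,1234,2873,jdoe,2019-06-01 10:15:42 AM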
echo "RESULT,${RESULT_NAME},${SEED},${RESULT},${USER},${START_FMT}"
#!/bin/bash
set -euxo pipefail
# Vars without defaults
: "${DGXSYSTEM:?DGXSYSTEM not set}"
: "${CONT:?CONT not set}"
# Vars with defaults
: "${NEXP:=5}"
: "${DATESTAMP:=$(date +'%y%m%d%H%M%S%N')}"
: "${CLEAR_CACHES:=1}"
: "${DATADIR:=/raid/datasets/xformer_v0p6/utf8}"
: "${LOGDIR:=$(pwd)/results}"
# Other vars
readonly _config_file="./config_${DGXSYSTEM}.sh"
readonly _seed_override=${SEED:-}
readonly _logfile_base="${LOGDIR}/${DATESTAMP}"
readonly _cont_name=translation
_cont_mounts=("--volume=${DATADIR}:/data" "--volume=${LOGDIR}:/results")
# Setup directories
mkdir -p "${LOGDIR}"
# Get list of envvars to pass to docker
source "${_config_file}"
mapfile -t _config_env < <(env -i bash -c ". ${_config_file} && compgen -e" | grep -E -v '^(PWD|SHLVL)')
_config_env+=(SEED)
mapfile -t _config_env < <(for v in "${_config_env[@]}"; do echo "--env=$v"; done)
# Cleanup container
cleanup_docker() {
docker container rm -f "${_cont_name}" || true
}
cleanup_docker
trap 'set -eux; cleanup_docker' EXIT
# Setup container
nvidia-docker run --rm --init --detach \
--net=host --uts=host --ipc=host --security-opt=seccomp=unconfined \
--ulimit=stack=67108864 --ulimit=memlock=-1 \
--name="${_cont_name}" "${_cont_mounts[@]}" \
"${CONT}" sleep infinity
docker exec -it "${_cont_name}" true
# Run experiments
for _experiment_index in $(seq 1 "${NEXP}"); do
(
echo "Beginning trial ${_experiment_index} of ${NEXP}"
# Print system info
docker exec -it "${_cont_name}" python -c "
import mlperf_log_utils
from mlperf_logging.mllog import constants
mlperf_log_utils.mlperf_submission_log(constants.TRANSFORMER)"
# Clear caches
if [ "${CLEAR_CACHES}" -eq 1 ]; then
sync && sudo /sbin/sysctl vm.drop_caches=3
docker exec -it "${_cont_name}" python -c "
from mlperf_logging.mllog import constants
from mlperf_log_utils import log_event
log_event(key=constants.CACHE_CLEAR, value=True)"
fi
# Run experiment
export SEED=${_seed_override:-$RANDOM}
docker exec -it "${_config_env[@]}" "${_cont_name}" ./run_and_time.sh
) |& tee "${_logfile_base}_${_experiment_index}.log"
done
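# Example launch of this wrapper (a sketch; the container image and system name are
# placeholders and the script filename is an assumption, not taken from this commit):
#   CONT=mlperf-translation:latest DGXSYSTEM=DGX1 DATADIR=/raid/datasets/xformer_v0p6/utf8 \
#   LOGDIR=$(pwd)/results NEXP=1 bash run_with_docker.sh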
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import argparse
import os
import sys
from fairseq import bleu, tokenizer
from fairseq.data import dictionary
def main():
parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.')
parser.add_argument('-s', '--sys', default='-', help='system output')
parser.add_argument('-r', '--ref', required=True, help='references')
parser.add_argument('-o', '--order', default=4, metavar='N',
type=int, help='consider ngrams up to this order')
parser.add_argument('--ignore-case', action='store_true',
help='case-insensitive scoring')
args = parser.parse_args()
print(args)
assert args.sys == '-' or os.path.exists(args.sys), \
"System output file {} does not exist".format(args.sys)
assert os.path.exists(args.ref), \
"Reference file {} does not exist".format(args.ref)
dict = dictionary.Dictionary()
def readlines(fd):
for line in fd.readlines():
if args.ignore_case:
yield line.lower()
else:
yield line
def score(fdsys):
with open(args.ref) as fdref:
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict)
ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict)
scorer.add(ref_tok, sys_tok)
print(scorer.result_string(args.order))
if args.sys == '-':
score(sys.stdin)
else:
with open(args.sys, 'r') as f:
score(f)
if __name__ == '__main__':
main()
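# Example invocations (a sketch; file names are placeholders and this module's filename
# is assumed to be score.py):
#   python3 score.py --sys hypotheses.txt --ref newstest2014.de --ignore-case
#   cat hypotheses.txt | python3 score.py --ref newstest2014.de   # read system output from stdin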
#!/usr/bin/env python3
import argparse
import collections
import torch
import os
import re
def average_checkpoints(inputs):
"""Loads checkpoints from inputs and returns a model with averaged weights.
Args:
inputs: An iterable of string paths of checkpoints to load from.
Returns:
A dict of string keys mapping to various values. The 'model' key
from the returned dict should correspond to an OrderedDict mapping
string parameter names to torch Tensors.
"""
params_dict = collections.OrderedDict()
params_keys = None
new_state = None
for f in inputs:
state = torch.load(
f,
map_location=(
lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
),
)
# Copies over the settings from the first checkpoint
if new_state is None:
new_state = state
model_params = state['model']
model_params_keys = list(model_params.keys())
if params_keys is None:
params_keys = model_params_keys
elif params_keys != model_params_keys:
raise KeyError(
'For checkpoint {}, expected list of params: {}, '
'but found: {}'.format(f, params_keys, model_params_keys)
)
for k in params_keys:
if k not in params_dict:
params_dict[k] = []
p = model_params[k]
if isinstance(p, torch.HalfTensor):
p = p.float()
params_dict[k].append(p)
averaged_params = collections.OrderedDict()
# v should be a list of torch Tensor.
for k, v in params_dict.items():
summed_v = None
for x in v:
summed_v = summed_v + x if summed_v is not None else x
averaged_params[k] = summed_v / len(v)
new_state['model'] = averaged_params
return new_state
def last_n_checkpoints(paths, n, update_based):
assert len(paths) == 1
path = paths[0]
if update_based:
pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
else:
pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
files = os.listdir(path)
entries = []
for f in files:
m = pt_regexp.fullmatch(f)
if m is not None:
entries.append((int(m.group(1)), m.group(0)))
if len(entries) < n:
raise Exception('Found {} checkpoint files but need at least {}'.format(len(entries), n))
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
def main():
parser = argparse.ArgumentParser(
description='Tool to average the params of input checkpoints to '
'produce a new checkpoint',
)
parser.add_argument(
'--inputs',
required=True,
nargs='+',
help='Input checkpoint file paths.',
)
parser.add_argument(
'--output',
required=True,
metavar='FILE',
help='Write the new checkpoint containing the averaged weights to this '
'path.',
)
num_group = parser.add_mutually_exclusive_group()
num_group.add_argument(
'--num-epoch-checkpoints',
type=int,
help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
'and average last this many of them.',
)
num_group.add_argument(
'--num-update-checkpoints',
type=int,
help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
'and average last this many of them.',
)
args = parser.parse_args()
print(args)
num = None
is_update_based = False
if args.num_update_checkpoints is not None:
num = args.num_update_checkpoints
is_update_based = True
elif args.num_epoch_checkpoints is not None:
num = args.num_epoch_checkpoints
if num is not None:
args.inputs = last_n_checkpoints(args.inputs, num, is_update_based)
print('averaging checkpoints: ', args.inputs)
new_state = average_checkpoints(args.inputs)
torch.save(new_state, args.output)
print('Finished writing averaged checkpoint to {}.'.format(args.output))
if __name__ == '__main__':
main()
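# Example invocation (a sketch; the checkpoint directory is a placeholder and this module's
# filename is assumed to be average_checkpoints.py):
#   python3 average_checkpoints.py --inputs /path/to/checkpoints \
#       --num-epoch-checkpoints 5 --output averaged.pt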
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
"""
Use this script in order to build symmetric alignments for your translation
dataset.
This script depends on fast_align and mosesdecoder tools. You will need to
build those before running the script.
fast_align:
github: http://github.com/clab/fast_align
instructions: follow the instructions in README.md
mosesdecoder:
github: http://github.com/moses-smt/mosesdecoder
instructions: http://www.statmt.org/moses/?n=Development.GetStarted
The script produces the following files under --output_dir:
text.joined - concatenation of lines from the source_file and the
target_file.
align.forward - forward pass of fast_align.
align.backward - backward pass of fast_align.
aligned.sym_heuristic - symmetrized alignment.
"""
import argparse
import os
from itertools import zip_longest
def main():
parser = argparse.ArgumentParser(description='symmetric alignment builder')
parser.add_argument('--fast_align_dir',
help='path to fast_align build directory')
parser.add_argument('--mosesdecoder_dir',
help='path to mosesdecoder root directory')
parser.add_argument('--sym_heuristic',
help='heuristic to use for symmetrization',
default='grow-diag-final-and')
parser.add_argument('--source_file',
help='path to a file with sentences '
'in the source language')
parser.add_argument('--target_file',
help='path to a file with sentences '
'in the target language')
parser.add_argument('--output_dir',
help='output directory')
args = parser.parse_args()
fast_align_bin = os.path.join(args.fast_align_dir, 'fast_align')
symal_bin = os.path.join(args.mosesdecoder_dir, 'bin', 'symal')
sym_fast_align_bin = os.path.join(
args.mosesdecoder_dir, 'scripts', 'ems',
'support', 'symmetrize-fast-align.perl')
# create joined file
joined_file = os.path.join(args.output_dir, 'text.joined')
with open(args.source_file, 'r') as src, open(args.target_file, 'r') as tgt:
with open(joined_file, 'w') as joined:
for s, t in zip_longest(src, tgt):
print('{} ||| {}'.format(s.strip(), t.strip()), file=joined)
# run forward alignment
fwd_align_file = os.path.join(args.output_dir, 'align.forward')
fwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v > {FWD}'.format(
FASTALIGN=fast_align_bin,
JOINED=joined_file,
FWD=fwd_align_file)
assert os.system(fwd_fast_align_cmd) == 0
# run backward alignment
bwd_align_file = os.path.join(args.output_dir, 'align.backward')
bwd_fast_align_cmd = '{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}'.format(
FASTALIGN=fast_align_bin,
JOINED=joined_file,
BWD=bwd_align_file)
assert os.system(bwd_fast_align_cmd) == 0
# run symmetrization
sym_out_file = os.path.join(args.output_dir, 'aligned')
sym_cmd = '{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}'.format(
SYMFASTALIGN=sym_fast_align_bin,
FWD=fwd_align_file,
BWD=bwd_align_file,
SRC=args.source_file,
TGT=args.target_file,
OUT=sym_out_file,
HEURISTIC=args.sym_heuristic,
SYMAL=symal_bin
)
assert os.system(sym_cmd) == 0
if __name__ == '__main__':
main()
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
-- Usage: convert_dictionary.lua <dict.th7>
require 'fairseq'
require 'torch'
require 'paths'
if #arg < 1 then
print('usage: convert_dictionary.lua <dict.th7>')
os.exit(1)
end
if not paths.filep(arg[1]) then
print('error: file does not exist: ' .. arg[1])
os.exit(1)
end
dict = torch.load(arg[1])
dst = paths.basename(arg[1]):gsub('.th7', '.txt')
assert(dst:match('.txt$'))
f = io.open(dst, 'w')
for idx, symbol in ipairs(dict.index_to_symbol) do
if idx > dict.cutoff then
break
end
f:write(symbol)
f:write(' ')
f:write(dict.index_to_freq[idx])
f:write('\n')
end
f:close()
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
-- Usage: convert_model.lua <model_epoch1.th7>
require 'torch'
local fairseq = require 'fairseq'
model = torch.load(arg[1])
function find_weight_norm(container, module)
for _, wn in ipairs(container:listModules()) do
if torch.type(wn) == 'nn.WeightNorm' and wn.modules[1] == module then
return wn
end
end
end
function push_state(dict, key, module)
if torch.type(module) == 'nn.Linear' then
local wn = find_weight_norm(model.module, module)
assert(wn)
dict[key .. '.weight_v'] = wn.v:float()
dict[key .. '.weight_g'] = wn.g:float()
elseif torch.type(module) == 'nn.TemporalConvolutionTBC' then
local wn = find_weight_norm(model.module, module)
assert(wn)
local v = wn.v:float():view(wn.viewOut):transpose(2, 3)
dict[key .. '.weight_v'] = v
dict[key .. '.weight_g'] = wn.g:float():view(module.weight:size(3), 1, 1)
else
dict[key .. '.weight'] = module.weight:float()
end
if module.bias then
dict[key .. '.bias'] = module.bias:float()
end
end
encoder_dict = {}
decoder_dict = {}
combined_dict = {}
function encoder_state(encoder)
luts = encoder:findModules('nn.LookupTable')
push_state(encoder_dict, 'embed_tokens', luts[1])
push_state(encoder_dict, 'embed_positions', luts[2])
fcs = encoder:findModules('nn.Linear')
assert(#fcs >= 2)
local nInputPlane = fcs[1].weight:size(1)
push_state(encoder_dict, 'fc1', table.remove(fcs, 1))
push_state(encoder_dict, 'fc2', table.remove(fcs, #fcs))
for i, module in ipairs(encoder:findModules('nn.TemporalConvolutionTBC')) do
push_state(encoder_dict, 'convolutions.' .. tostring(i - 1), module)
if nInputPlane ~= module.weight:size(3) / 2 then
push_state(encoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1))
end
nInputPlane = module.weight:size(3) / 2
end
assert(#fcs == 0)
end
function decoder_state(decoder)
luts = decoder:findModules('nn.LookupTable')
push_state(decoder_dict, 'embed_tokens', luts[1])
push_state(decoder_dict, 'embed_positions', luts[2])
fcs = decoder:findModules('nn.Linear')
local nInputPlane = fcs[1].weight:size(1)
push_state(decoder_dict, 'fc1', table.remove(fcs, 1))
push_state(decoder_dict, 'fc2', fcs[#fcs - 1])
push_state(decoder_dict, 'fc3', fcs[#fcs])
table.remove(fcs, #fcs)
table.remove(fcs, #fcs)
for i, module in ipairs(decoder:findModules('nn.TemporalConvolutionTBC')) do
if nInputPlane ~= module.weight:size(3) / 2 then
push_state(decoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1))
end
nInputPlane = module.weight:size(3) / 2
local prefix = 'attention.' .. tostring(i - 1)
push_state(decoder_dict, prefix .. '.in_projection', table.remove(fcs, 1))
push_state(decoder_dict, prefix .. '.out_projection', table.remove(fcs, 1))
push_state(decoder_dict, 'convolutions.' .. tostring(i - 1), module)
end
assert(#fcs == 0)
end
_encoder = model.module.modules[2]
_decoder = model.module.modules[3]
encoder_state(_encoder)
decoder_state(_decoder)
for k, v in pairs(encoder_dict) do
combined_dict['encoder.' .. k] = v
end
for k, v in pairs(decoder_dict) do
combined_dict['decoder.' .. k] = v
end
torch.save('state_dict.t7', combined_dict)
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from setuptools import setup, find_packages, Extension
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
import sys
if sys.version_info < (3,):
sys.exit('Sorry, Python3 is required for fairseq.')
with open('README.md') as f:
readme = f.read()
with open('LICENSE') as f:
license = f.read()
with open('requirements.txt') as f:
reqs = f.read()
bleu = Extension(
'fairseq.libbleu',
sources=[
'fairseq/clib/libbleu/libbleu.cpp',
'fairseq/clib/libbleu/module.cpp',
],
extra_compile_args=['-std=c++11'],
)
batch_utils_v0p5 = CppExtension(
name='fairseq.data.batch_C_v0p5',
sources=['fairseq/data/csrc/make_batches_v0p5.cpp'],
extra_compile_args={
'cxx': ['-O2',],
}
)
batch_utils_v0p5_better = CppExtension(
name='fairseq.data.batch_C_v0p5_better',
sources=['fairseq/data/csrc/make_batches_v0p5_better.cpp'],
extra_compile_args={
'cxx': ['-O2', '--std=c++14'],
}
)
batch_utils_v0p6 = CppExtension(
name='fairseq.data.batch_C_v0p6',
sources=['fairseq/data/csrc/make_batches_v0p6.cpp'],
extra_compile_args={
'cxx': ['-O2', '--std=c++14'],
}
)
setup(
name='fairseq',
version='0.5.0',
description='Facebook AI Research Sequence-to-Sequence Toolkit',
long_description=readme,
license=license,
install_requires=reqs.strip().split('\n'),
packages=find_packages(),
ext_modules=[bleu, batch_utils_v0p5, batch_utils_v0p5_better, batch_utils_v0p6],
cmdclass={
'build_ext': BuildExtension
},
test_suite='tests',
)
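# Typical local install (an assumption about the intended workflow, not stated in this file;
# BuildExtension compiles the C++ batching/BLEU extensions declared above during the install):
#   pip install -e .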
python3 -u train.py /raw_data \
--seed 1234 \
--arch transformer_wmt_en_de_big_t2t \
--share-all-embeddings \
--optimizer adam \
--ignore-case
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import collections
import os
import tempfile
import unittest
import numpy as np
import torch
from scripts.average_checkpoints import average_checkpoints
class TestAverageCheckpoints(unittest.TestCase):
def test_average_checkpoints(self):
params_0 = collections.OrderedDict(
[
('a', torch.DoubleTensor([100.0])),
('b', torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
('c', torch.IntTensor([7, 8, 9])),
]
)
params_1 = collections.OrderedDict(
[
('a', torch.DoubleTensor([1.0])),
('b', torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])),
('c', torch.IntTensor([2, 2, 2])),
]
)
params_avg = collections.OrderedDict(
[
('a', torch.DoubleTensor([50.5])),
('b', torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])),
# We expect truncation for integer division
('c', torch.IntTensor([4, 5, 5])),
]
)
fd_0, path_0 = tempfile.mkstemp()
fd_1, path_1 = tempfile.mkstemp()
torch.save(collections.OrderedDict([('model', params_0)]), path_0)
torch.save(collections.OrderedDict([('model', params_1)]), path_1)
output = average_checkpoints([path_0, path_1])['model']
os.close(fd_0)
os.remove(path_0)
os.close(fd_1)
os.remove(path_1)
for (k_expected, v_expected), (k_out, v_out) in zip(
params_avg.items(), output.items()):
self.assertEqual(
k_expected, k_out, 'Key mismatch - expected {} but found {}. '
'(Expected list of keys: {} vs actual list of keys: {})'.format(
k_expected, k_out, params_avg.keys(), output.keys()
)
)
np.testing.assert_allclose(
v_expected.numpy(),
v_out.numpy(),
err_msg='Tensor value mismatch for key {}'.format(k_expected)
)
if __name__ == '__main__':
unittest.main()
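# To run just this test file (the tests/ module path is an assumption based on the
# 'scripts.average_checkpoints' import above):
#   python3 -m unittest tests.test_average_checkpoints -v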
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import contextlib
from io import StringIO
import os
import random
import sys
import tempfile
import unittest
import torch
from fairseq import options
import preprocess
import train
import generate
import interactive
import eval_lm
class TestTranslation(unittest.TestCase):
def test_fconv(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_fconv') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'fconv_iwslt_de_en')
generate_main(data_dir)
def test_raw(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_fconv_raw') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir, ['--output-format', 'raw'])
train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--raw-text'])
generate_main(data_dir, ['--raw-text'])
def test_fp16(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_fp16') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--fp16'])
generate_main(data_dir)
def test_update_freq(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_update_freq') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--update-freq', '3'])
generate_main(data_dir)
def test_lstm(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_lstm') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'lstm_wiseman_iwslt_de_en', [
'--encoder-layers', '2',
'--decoder-layers', '2',
])
generate_main(data_dir)
def test_lstm_bidirectional(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_lstm_bidirectional') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'lstm', [
'--encoder-layers', '2',
'--encoder-bidirectional',
'--encoder-hidden-size', '256',
'--decoder-layers', '2',
])
generate_main(data_dir)
def test_transformer(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_transformer') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'transformer_iwslt_de_en')
generate_main(data_dir)
class TestStories(unittest.TestCase):
def test_fconv_self_att_wp(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_fconv_self_att_wp') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
config = [
'--encoder-layers', '[(512, 3)] * 2',
'--decoder-layers', '[(512, 3)] * 2',
'--decoder-attention', 'True',
'--encoder-attention', 'False',
'--gated-attention', 'True',
'--self-attention', 'True',
'--project-input', 'True',
]
train_translation_model(data_dir, 'fconv_self_att_wp', config)
generate_main(data_dir)
# fusion model
os.rename(os.path.join(data_dir, 'checkpoint_last.pt'), os.path.join(data_dir, 'pretrained.pt'))
config.extend([
'--pretrained', 'True',
'--pretrained-checkpoint', os.path.join(data_dir, 'pretrained.pt'),
'--save-dir', os.path.join(data_dir, 'fusion_model'),
])
train_translation_model(data_dir, 'fconv_self_att_wp', config)
class TestLanguageModeling(unittest.TestCase):
def test_fconv_lm(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_fconv_lm') as data_dir:
create_dummy_data(data_dir)
preprocess_lm_data(data_dir)
train_language_model(data_dir, 'fconv_lm')
eval_lm_main(data_dir)
def create_dummy_data(data_dir, num_examples=1000, maxlen=20):
def _create_dummy_data(filename):
data = torch.rand(num_examples * maxlen)
data = 97 + torch.floor(26 * data).int()
with open(os.path.join(data_dir, filename), 'w') as h:
offset = 0
for _ in range(num_examples):
ex_len = random.randint(1, maxlen)
ex_str = ' '.join(map(chr, data[offset:offset+ex_len]))
print(ex_str, file=h)
offset += ex_len
_create_dummy_data('train.in')
_create_dummy_data('train.out')
_create_dummy_data('valid.in')
_create_dummy_data('valid.out')
_create_dummy_data('test.in')
_create_dummy_data('test.out')
def preprocess_translation_data(data_dir, extra_flags=None):
preprocess_parser = preprocess.get_parser()
preprocess_args = preprocess_parser.parse_args(
[
'--source-lang', 'in',
'--target-lang', 'out',
'--trainpref', os.path.join(data_dir, 'train'),
'--validpref', os.path.join(data_dir, 'valid'),
'--testpref', os.path.join(data_dir, 'test'),
'--thresholdtgt', '0',
'--thresholdsrc', '0',
'--destdir', data_dir,
] + (extra_flags or []),
)
preprocess.main(preprocess_args)
def train_translation_model(data_dir, arch, extra_flags=None):
train_parser = options.get_training_parser()
train_args = options.parse_args_and_arch(
train_parser,
[
'--task', 'translation',
data_dir,
'--save-dir', data_dir,
'--arch', arch,
'--optimizer', 'nag',
'--lr', '0.05',
'--max-tokens', '500',
'--max-epoch', '1',
'--no-progress-bar',
'--distributed-world-size', '1',
'--source-lang', 'in',
'--target-lang', 'out',
] + (extra_flags or []),
)
train.main(train_args)
def generate_main(data_dir, extra_flags=None):
generate_parser = options.get_generation_parser()
generate_args = options.parse_args_and_arch(
generate_parser,
[
data_dir,
'--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--beam', '3',
'--batch-size', '64',
'--max-len-b', '5',
'--gen-subset', 'valid',
'--no-progress-bar',
'--print-alignment',
] + (extra_flags or []),
)
# evaluate model in batch mode
generate.main(generate_args)
# evaluate model interactively
generate_args.buffer_size = 0
generate_args.max_sentences = None
orig_stdin = sys.stdin
sys.stdin = StringIO('h e l l o\n')
interactive.main(generate_args)
sys.stdin = orig_stdin
def preprocess_lm_data(data_dir):
preprocess_parser = preprocess.get_parser()
preprocess_args = preprocess_parser.parse_args([
'--only-source',
'--trainpref', os.path.join(data_dir, 'train.out'),
'--validpref', os.path.join(data_dir, 'valid.out'),
'--testpref', os.path.join(data_dir, 'test.out'),
'--destdir', data_dir,
])
preprocess.main(preprocess_args)
def train_language_model(data_dir, arch):
train_parser = options.get_training_parser()
train_args = options.parse_args_and_arch(
train_parser,
[
'--task', 'language_modeling',
data_dir,
'--arch', arch,
'--optimizer', 'nag',
'--lr', '1.0',
'--criterion', 'adaptive_loss',
'--adaptive-softmax-cutoff', '5,10,15',
'--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
'--decoder-embed-dim', '280',
'--max-tokens', '500',
'--tokens-per-sample', '500',
'--save-dir', data_dir,
'--max-epoch', '1',
'--no-progress-bar',
'--distributed-world-size', '1',
],
)
train.main(train_args)
def eval_lm_main(data_dir):
eval_lm_parser = options.get_eval_lm_parser()
eval_lm_args = options.parse_args_and_arch(
eval_lm_parser,
[
data_dir,
'--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--no-progress-bar',
],
)
eval_lm.main(eval_lm_args)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import unittest
from fairseq.modules import ConvTBC
import torch.nn as nn
class TestConvTBC(unittest.TestCase):
def test_convtbc(self):
# ksz, in_channels, out_channels
conv_tbc = ConvTBC(4, 5, kernel_size=3, padding=1)
# out_channels, in_channels, ksz
conv1d = nn.Conv1d(4, 5, kernel_size=3, padding=1)
conv_tbc.weight.data.copy_(conv1d.weight.data.transpose(0, 2))
conv_tbc.bias.data.copy_(conv1d.bias.data)
input_tbc = torch.randn(7, 2, 4, requires_grad=True)
input1d = input_tbc.data.transpose(0, 1).transpose(1, 2)
input1d.requires_grad = True
output_tbc = conv_tbc(input_tbc)
output1d = conv1d(input1d)
self.assertAlmostEqual(output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data)
grad_tbc = torch.randn(output_tbc.size())
grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous()
output_tbc.backward(grad_tbc)
output1d.backward(grad1d)
self.assertAlmostEqual(conv_tbc.weight.grad.data.transpose(0, 2), conv1d.weight.grad.data)
self.assertAlmostEqual(conv_tbc.bias.grad.data, conv1d.bias.grad.data)
self.assertAlmostEqual(input_tbc.grad.data.transpose(0, 1).transpose(1, 2), input1d.grad.data)
def assertAlmostEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertLess((t1 - t2).abs().max(), 1e-4)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import unittest
from fairseq.data import data_utils
class TestDataUtils(unittest.TestCase):
def test_counting_iterator(self):
x = list(range(10))
itr = data_utils.CountingIterator(x)
self.assertTrue(itr.has_next())
self.assertEqual(next(itr), 0)
self.assertEqual(next(itr), 1)
itr.skip(3)
self.assertEqual(next(itr), 5)
itr.skip(3)
self.assertEqual(next(itr), 9)
self.assertFalse(itr.has_next())
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import tempfile
import unittest
import torch
from fairseq.data import Dictionary
from fairseq.tokenizer import Tokenizer
class TestDictionary(unittest.TestCase):
def test_finalize(self):
txt = [
'A B C D',
'B C D',
'C D',
'D',
]
ref_ids1 = list(map(torch.IntTensor, [
[4, 5, 6, 7, 2],
[5, 6, 7, 2],
[6, 7, 2],
[7, 2],
]))
ref_ids2 = list(map(torch.IntTensor, [
[7, 6, 5, 4, 2],
[6, 5, 4, 2],
[5, 4, 2],
[4, 2],
]))
# build dictionary
d = Dictionary()
for line in txt:
Tokenizer.tokenize(line, d, add_if_not_exist=True)
def get_ids(dictionary):
ids = []
for line in txt:
ids.append(Tokenizer.tokenize(line, dictionary, add_if_not_exist=False))
return ids
def assertMatch(ids, ref_ids):
for toks, ref_toks in zip(ids, ref_ids):
self.assertEqual(toks.size(), ref_toks.size())
self.assertEqual(0, (toks != ref_toks).sum().item())
ids = get_ids(d)
assertMatch(ids, ref_ids1)
# check finalized dictionary
d.finalize()
finalized_ids = get_ids(d)
assertMatch(finalized_ids, ref_ids2)
# write to disk and reload
with tempfile.NamedTemporaryFile(mode='w') as tmp_dict:
d.save(tmp_dict.name)
d = Dictionary.load(tmp_dict.name)
reload_ids = get_ids(d)
assertMatch(reload_ids, ref_ids2)
assertMatch(finalized_ids, reload_ids)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import argparse
import copy
import unittest
import torch
from fairseq.criterions.cross_entropy import CrossEntropyCriterion
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
import tests.utils as test_utils
class TestLabelSmoothing(unittest.TestCase):
def setUp(self):
# build dictionary
self.d = test_utils.dummy_dictionary(3)
vocab = len(self.d)
self.assertEqual(vocab, 4 + 3) # 4 special + 3 tokens
self.assertEqual(self.d.pad(), 1)
self.assertEqual(self.d.eos(), 2)
self.assertEqual(self.d.unk(), 3)
pad, eos, unk, w1, w2, w3 = 1, 2, 3, 4, 5, 6 # noqa: F841
# build dataset
self.data = [
# the first batch item has padding
{'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, eos])},
{'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, w1, eos])},
]
self.sample = next(test_utils.dummy_dataloader(self.data))
# build model
self.args = argparse.Namespace()
self.args.sentence_avg = False
self.args.probs = torch.FloatTensor([
# pad eos unk w1 w2 w3
[0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05],
[0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10],
[0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15],
]).unsqueeze(0).expand(2, 3, 7) # add batch dimension
self.task = test_utils.TestTranslationTask.setup_task(self.args, self.d, self.d)
self.model = self.task.build_model(self.args)
def test_nll_loss(self):
self.args.label_smoothing = 0.1
nll_crit = CrossEntropyCriterion(self.args, self.task)
smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
self.assertLess(abs(nll_loss - nll_logging_output['loss']), 1e-6)
self.assertLess(abs(nll_loss - smooth_logging_output['nll_loss']), 1e-6)
def test_padding(self):
self.args.label_smoothing = 0.1
crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
loss, _, logging_output = crit(self.model, self.sample)
def get_one_no_padding(idx):
# create a new sample with just a single batch item so that there's
# no padding
sample1 = next(test_utils.dummy_dataloader([self.data[idx]]))
args1 = copy.copy(self.args)
args1.probs = args1.probs[idx, :, :].unsqueeze(0)
model1 = self.task.build_model(args1)
loss1, _, _ = crit(model1, sample1)
return loss1
loss1 = get_one_no_padding(0)
loss2 = get_one_no_padding(1)
self.assertAlmostEqual(loss, loss1 + loss2)
def test_reduction(self):
self.args.label_smoothing = 0.1
crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
loss, _, logging_output = crit(self.model, self.sample, reduce=True)
unreduced_loss, _, _ = crit(self.model, self.sample, reduce=False)
self.assertAlmostEqual(loss, unreduced_loss.sum())
def test_zero_eps(self):
self.args.label_smoothing = 0.0
nll_crit = CrossEntropyCriterion(self.args, self.task)
smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
self.assertAlmostEqual(nll_loss, smooth_loss)
def assertAlmostEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertLess((t1 - t2).abs().max(), 1e-6)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import argparse
import unittest
import torch
from fairseq.sequence_generator import SequenceGenerator
import tests.utils as test_utils
class TestSequenceGenerator(unittest.TestCase):
def setUp(self):
# construct dummy dictionary
d = test_utils.dummy_dictionary(vocab_size=2)
self.assertEqual(d.pad(), 1)
self.assertEqual(d.eos(), 2)
self.assertEqual(d.unk(), 3)
self.eos = d.eos()
self.w1 = 4
self.w2 = 5
# construct source data
self.src_tokens = torch.LongTensor([
[self.w1, self.w2, self.eos],
[self.w1, self.w2, self.eos],
])
self.src_lengths = torch.LongTensor([2, 2])
args = argparse.Namespace()
unk = 0.
args.beam_probs = [
# step 0:
torch.FloatTensor([
# eos  unk  w1   w2
# sentence 1:
[0.0, unk, 0.9, 0.1], # beam 1
[0.0, unk, 0.9, 0.1], # beam 2
# sentence 2:
[0.0, unk, 0.7, 0.3],
[0.0, unk, 0.7, 0.3],
]),
# step 1:
torch.FloatTensor([
# eos  unk  w1   w2    prefix
# sentence 1:
[1.0, unk, 0.0, 0.0], # w1: 0.9 (emit: w1 <eos>: 0.9*1.0)
[0.0, unk, 0.9, 0.1], # w2: 0.1
# sentence 2:
[0.25, unk, 0.35, 0.4], # w1: 0.7 (don't emit: w1 <eos>: 0.7*0.25)
[0.00, unk, 0.10, 0.9], # w2: 0.3
]),
# step 2:
torch.FloatTensor([
# eos  unk  w1   w2    prefix
# sentence 1:
[0.0, unk, 0.1, 0.9], # w2 w1: 0.1*0.9
[0.6, unk, 0.2, 0.2], # w2 w2: 0.1*0.1 (emit: w2 w2 <eos>: 0.1*0.1*0.6)
# sentence 2:
[0.60, unk, 0.4, 0.00], # w1 w2: 0.7*0.4 (emit: w1 w2 <eos>: 0.7*0.4*0.6)
[0.01, unk, 0.0, 0.99], # w2 w2: 0.3*0.9
]),
# step 3:
torch.FloatTensor([
# eos  unk  w1   w2    prefix
# sentence 1:
[1.0, unk, 0.0, 0.0], # w2 w1 w2: 0.1*0.9*0.9 (emit: w2 w1 w2 <eos>: 0.1*0.9*0.9*1.0)
[1.0, unk, 0.0, 0.0], # w2 w1 w1: 0.1*0.9*0.1 (emit: w2 w1 w1 <eos>: 0.1*0.9*0.1*1.0)
# sentence 2:
[0.1, unk, 0.5, 0.4], # w2 w2 w2: 0.3*0.9*0.99 (emit: w2 w2 w2 <eos>: 0.3*0.9*0.99*0.1)
[1.0, unk, 0.0, 0.0], # w1 w2 w1: 0.7*0.4*0.4 (emit: w1 w2 w1 <eos>: 0.7*0.4*0.4*1.0)
]),
]
task = test_utils.TestTranslationTask.setup_task(args, d, d)
self.model = task.build_model(args)
self.tgt_dict = task.target_dictionary
def test_with_normalization(self):
generator = SequenceGenerator([self.model], self.tgt_dict)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0])
# sentence 1, beam 2
self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0])
# sentence 2, beam 1
self.assertHypoTokens(hypos[1][0], [w1, w2, w1, eos])
self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.4, 1.0])
# sentence 2, beam 2
self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6])
def test_without_normalization(self):
# Sentence 1: unchanged from the normalized case
# Sentence 2: beams swap order
generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0], normalized=False)
# sentence 1, beam 2
self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0], normalized=False)
# sentence 2, beam 1
self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6], normalized=False)
# sentence 2, beam 2
self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0], normalized=False)
def test_with_lenpen_favoring_short_hypos(self):
lenpen = 0.6
generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0], lenpen=lenpen)
# sentence 1, beam 2
self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen)
# sentence 2, beam 1
self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6], lenpen=lenpen)
# sentence 2, beam 2
self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0], lenpen=lenpen)
def test_with_lenpen_favoring_long_hypos(self):
lenpen = 5.0
generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
self.assertHypoScore(hypos[0][0], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen)
# sentence 1, beam 2
self.assertHypoTokens(hypos[0][1], [w1, eos])
self.assertHypoScore(hypos[0][1], [0.9, 1.0], lenpen=lenpen)
# sentence 2, beam 1
self.assertHypoTokens(hypos[1][0], [w1, w2, w1, eos])
self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.4, 1.0], lenpen=lenpen)
# sentence 2, beam 2
self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen)
def test_maxlen(self):
generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0])
# sentence 1, beam 2
self.assertHypoTokens(hypos[0][1], [w2, w2, eos])
self.assertHypoScore(hypos[0][1], [0.1, 0.1, 0.6])
# sentence 2, beam 1
self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6])
# sentence 2, beam 2
self.assertHypoTokens(hypos[1][1], [w2, w2, eos])
self.assertHypoScore(hypos[1][1], [0.3, 0.9, 0.01])
def test_no_stop_early(self):
generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0])
# sentence 1, beam 2
self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0])
# sentence 2, beam 1
self.assertHypoTokens(hypos[1][0], [w2, w2, w2, w2, eos])
self.assertHypoScore(hypos[1][0], [0.3, 0.9, 0.99, 0.4, 1.0])
# sentence 2, beam 2
self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0])
def assertHypoTokens(self, hypo, tokens):
self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))
def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
pos_scores = torch.FloatTensor(pos_probs).log()
self.assertAlmostEqual(hypo['positional_scores'], pos_scores)
self.assertEqual(pos_scores.numel(), hypo['tokens'].numel())
score = pos_scores.sum()
if normalized:
score /= pos_scores.numel()**lenpen
self.assertLess(abs(score - hypo['score']), 1e-6)
def assertAlmostEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertLess((t1 - t2).abs().max(), 1e-4)
def assertTensorEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertEqual(t1.ne(t2).long().sum(), 0)
if __name__ == '__main__':
unittest.main()