Commit e734b0fa authored by Sergey Edunov's avatar Sergey Edunov
Browse files

Initial commit

parents
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Checkpoints
checkpoints
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# dotenv
.env
# virtualenv
.venv
venv/
ENV/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
BSD License
For fairseq software
Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name Facebook nor the names of its contributors may be used to
endorse or promote products derived from this software without specific
prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Additional Grant of Patent Rights Version 2
"Software" means the fairseq software distributed by Facebook, Inc.
Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software
("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable
(subject to the termination provision below) license under any Necessary
Claims, to make, have made, use, sell, offer to sell, import, and otherwise
transfer the Software. For avoidance of doubt, no license is granted under
Facebook’s rights in any patent claims that are infringed by (i) modifications
to the Software made by you or any third party or (ii) the Software in
combination with any software or other technology.
The license granted hereunder will terminate, automatically and without notice,
if you (or any of your subsidiaries, corporate affiliates or agents) initiate
directly or indirectly, or take a direct financial interest in, any Patent
Assertion: (i) against Facebook or any of its subsidiaries or corporate
affiliates, (ii) against any party if such Patent Assertion arises in whole or
in part from any software, technology, product or service of Facebook or any of
its subsidiaries or corporate affiliates, or (iii) against any party relating
to the Software. Notwithstanding the foregoing, if Facebook or any of its
subsidiaries or corporate affiliates files a lawsuit alleging patent
infringement against you in the first instance, and you respond by filing a
patent infringement counterclaim in that lawsuit against that party that is
unrelated to the Software, the license granted hereunder will not terminate
under section (i) of this paragraph due to such counterclaim.
A "Necessary Claim" is a claim of a patent owned by Facebook that is
necessarily infringed by the Software standing alone.
A "Patent Assertion" is any lawsuit or other action alleging direct, indirect,
or contributory infringement or inducement to infringe any patent, including a
cross-claim or counterclaim.
# Introduction
FAIR Sequence-to-Sequence Toolkit (PyTorch)
This is a PyTorch version of [fairseq](https://github.com/facebookresearch/fairseq), a sequence-to-sequence learning toolkit from Facebook AI Research. The original authors of this reimplementation are (in no particular order) Sergey Edunov, Myle Ott, and Sam Gross. The toolkit implements the fully convolutional model described in [Convolutional Sequence to Sequence Learning](https://arxiv.org/abs/1705.03122). The toolkit features multi-GPU training on a single machine as well as fast beam search generation on both CPU and GPU. We provide pre-trained models for English to French and English to German translation.
![Model](fairseq.gif)
# Citation
If you use the code in your paper, then please cite it as:
```
@inproceedings{gehring2017convs2s,
author = {Gehring, Jonas, and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N},
title = "{Convolutional Sequence to Sequence Learning}",
booktitle = {Proc. of ICML},
year = 2017,
}
```
# Requirements and Installation
* A computer running macOS or Linux
* For training new models, you'll also need a NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* Python version 3.6
* A [PyTorch installation](http://pytorch.org/)
Currently fairseq-py requires PyTorch from the GitHub repository. There are multiple ways of installing it.
We suggest using [Miniconda3](https://conda.io/miniconda.html) and the following instructions.
* Install Miniconda3 from https://conda.io/miniconda.html create and activate python 3 environment.
```
conda install gcc numpy cudnn nccl
conda install magma-cuda80 -c soumith
pip install cmake
pip install cffi
git clone https://github.com/pytorch/pytorch.git
cd pytorch
git reset --hard a03e5cb40938b6b3f3e6dbddf9cff8afdff72d1b
git submodule update --init
pip install -r requirements.txt
NO_DISTRIBUTED=1 python setup.py install
```
Install fairseq by cloning the GitHub repository and by running
```
pip install -r requirements.txt
python setup.py build
python setup.py develop
```
The following command-line tools are available:
* `python preprocess.py`: Data pre-processing: build vocabularies and binarize training data
* `python train.py`: Train a new model on one or multiple GPUs
* `python generate.py`: Translate pre-processed data with a trained model
* `python generate.py -i`: Translate raw text with a trained model
* `python score.py`: BLEU scoring of generated translations against reference translations
# Quick Start
## Evaluating Pre-trained Models [TO BE ADAPTED]
First, download a pre-trained model along with its vocabularies:
```
$ curl https://s3.amazonaws.com/fairseq-py/models/wmt14.en-fr.fconv-py.tar.bz2 | tar xvjf -
```
This model uses a [Byte Pair Encoding (BPE) vocabulary](https://arxiv.org/abs/1508.07909), so we'll have to apply the encoding to the source text before it can be translated.
This can be done with the [apply_bpe.py](https://github.com/rsennrich/subword-nmt/blob/master/apply_bpe.py) script using the `wmt14.en-fr.fconv-cuda/bpecodes` file.
`@@` is used as a continuation marker and the original text can be easily recovered with e.g. `sed s/@@ //g` or by passing the `--remove-bpe` flag to `generate.py`.
Prior to BPE, input text needs to be tokenized using `tokenizer.perl` from [mosesdecoder](https://github.com/moses-smt/mosesdecoder).
Let's use `python generate.py -i` to generate translations.
Here, we use a beam size of 5:
```
$ MODEL_DIR=wmt14.en-fr.fconv-py
$ python generate.py -i \
--path $MODEL_DIR/model.pt $MODEL_DIR \
--beam 5
| [en] dictionary: 44206 types
| [fr] dictionary: 44463 types
| model fconv_wmt_en_fr
| loaded checkpoint /private/home/edunov/wmt14.en-fr.fconv-py/model.pt (epoch 37)
> Why is it rare to discover new marine mam@@ mal species ?
S Why is it rare to discover new marine mam@@ mal species ?
O Why is it rare to discover new marine mam@@ mal species ?
H -0.08662842959165573 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins ?
A 0 1 3 3 5 6 6 10 8 8 8 11 12
```
This generation script produces four types of outputs: a line prefixed with *S* shows the supplied source sentence after applying the vocabulary; *O* is a copy of the original source sentence; *H* is the hypothesis along with an average log-likelihood; and *A* is the attention maxima for each word in the hypothesis, including the end-of-sentence marker which is omitted from the text.
Check [below](#pre-trained-models) for a full list of pre-trained models available.
## Training a New Model
### Data Pre-processing
The fairseq source distribution contains an example pre-processing script for
the IWSLT 2014 German-English corpus.
Pre-process and binarize the data as follows:
```
$ cd data/
$ bash prepare-iwslt14.sh
$ cd ..
$ TEXT=data/iwslt14.tokenized.de-en
$ python preprocess.py --source-lang de --target-lang en \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--thresholdtgt 3 --thresholdsrc 3 --destdir data-bin/iwslt14.tokenized.de-en
```
This will write binarized data that can be used for model training to `data-bin/iwslt14.tokenized.de-en`.
### Training
Use `python train.py` to train a new model.
Here a few example settings that work well for the IWSLT 2014 dataset:
```
$ mkdir -p trainings/fconv
$ CUDA_VISIBLE_DEVICES=0 python train.py data-bin/iwslt14.tokenized.de-en \
--lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
--encoder-layers "[(256, 3)] * 4" --decoder-layers "[(256, 3)] * 3" \
--encoder-embed-dim 256 --decoder-embed-dim 256 --save-dir trainings/fconv
```
By default, `python train.py` will use all available GPUs on your machine.
Use the [CUDA_VISIBLE_DEVICES](http://acceleware.com/blog/cudavisibledevices-masking-gpus) environment variable to select specific GPUs and/or to change the number of GPU devices that will be used.
Also note that the batch size is specified in terms of the maximum number of tokens per batch (`--max-tokens`).
You may need to use a smaller value depending on the available GPU memory on your system.
### Generation
Once your model is trained, you can generate translations using `python generate.py` **(for binarized data)** or `python generate.py -i` **(for raw text)**:
```
$ python generate.py data-bin/iwslt14.tokenized.de-en \
--path trainings/fconv/checkpoint_best.pt \
--batch-size 128 --beam 5
| [de] dictionary: 35475 types
| [en] dictionary: 24739 types
| data-bin/iwslt14.tokenized.de-en test 6750 examples
| model fconv
| loaded checkpoint trainings/fconv/checkpoint_best.pt
S-721 danke .
T-721 thank you .
...
```
To generate translations with only a CPU, use the `--cpu` flag.
BPE continuation markers can be removed with the `--remove-bpe` flag.
# Pre-trained Models
We provide the following pre-trained fully convolutional sequence-to-sequence models:
* [wmt14.en-fr.fconv-py.tar.bz2](https://s3.amazonaws.com/faiseq-py/models/wmt14.en-fr.fconv-py.tar.bz2): Pre-trained model for [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) including vocabularies
* [wmt14.en-de.fconv-py.tar.bz2](https://s3.amazonaws.com/faiseq-py/models/wmt14.en-de.fconv-py.tar.bz2): Pre-trained model for [WMT14 English-German](https://nlp.stanford.edu/projects/nmt) including vocabularies
In addition, we provide pre-processed and binarized test sets for the models above:
* [wmt14.en-fr.newstest2014.tar.bz2](https://s3.amazonaws.com/fairseq-py/data/wmt14.en-fr.newstest2014.tar.bz2): newstest2014 test set for WMT14 English-French
* [wmt14.en-fr.ntst1213.tar.bz2](https://s3.amazonaws.com/fairseq-py/data/wmt14.en-fr.ntst1213.tar.bz2): newstest2012 and newstest2013 test sets for WMT14 English-French
* [wmt14.en-de.newstest2014.tar.bz2](https://s3.amazonaws.com/fairseq-py/data/wmt14.en-de.newstest2014.tar.bz2): newstest2014 test set for WMT14 English-German
Generation with the binarized test sets can be run in batch mode as follows, e.g. for English-French on a GTX-1080ti:
```
$ curl https://s3.amazonaws.com/faiseq-py/models/wmt14.en-fr.fconv-py.tar.bz2 | tar xvjf - -C data-bin
$ curl https://s3.amazonaws.com/fairseq-py/data/wmt14.en-fr.newstest2014.tar.bz2 | tar xvjf - -C data-bin
$ python generate.py data-bin/wmt14.en-fr.newstest2014 \
--path data-bin/wmt14.en-fr.fconv-py/model.pt \
--beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out
...
| Translated 3003 sentences (95451 tokens) in 136.3s (700.49 tokens/s)
| Timings: setup 0.1s (0.1%), encoder 1.9s (1.4%), decoder 108.9s (79.9%), search_results 0.0s (0.0%), search_prune 12.5s (9.2%)
| BLEU4 = 43.43, 68.2/49.2/37.4/28.8 (BP=0.996, ratio=1.004, sys_len=92087, ref_len=92448)
# Word-level BLEU scoring:
$ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
TODO: update scores
BLEU4 = 40.55, 67.6/46.5/34.0/25.3 (BP=1.000, ratio=0.998, sys_len=81369, ref_len=81194)
```
# Join the fairseq community
* Facebook page: https://www.facebook.com/groups/fairseq.users
* Google group: https://groups.google.com/forum/#!forum/fairseq-users
# License
fairseq is BSD-licensed.
The license applies to the pre-trained models as well.
We also provide an additional patent grant.
#!/usr/bin/env bash
#
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git
SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
LC=$SCRIPTS/tokenizer/lowercase.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
URL="https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz"
GZ=de-en.tgz
if [ ! -d "$SCRIPTS" ]; then
echo "Please set SCRIPTS variable correctly to point to Moses scripts."
exit
fi
src=de
tgt=en
lang=de-en
prep=iwslt14.tokenized.de-en
tmp=$prep/tmp
orig=orig
mkdir -p $orig $tmp $prep
echo "Downloading data from ${URL}..."
cd $orig
wget "$URL"
if [ -f $GZ ]; then
echo "Data successfully downloaded."
else
echo "Data not successfully downloaded."
exit
fi
tar zxvf $GZ
cd ..
echo "pre-processing train data..."
for l in $src $tgt; do
f=train.tags.$lang.$l
tok=train.tags.$lang.tok.$l
cat $orig/$lang/$f | \
grep -v '<url>' | \
grep -v '<talkid>' | \
grep -v '<keywords>' | \
sed -e 's/<title>//g' | \
sed -e 's/<\/title>//g' | \
sed -e 's/<description>//g' | \
sed -e 's/<\/description>//g' | \
perl $TOKENIZER -threads 8 -l $l > $tmp/$tok
echo ""
done
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 175
for l in $src $tgt; do
perl $LC < $tmp/train.tags.$lang.clean.$l > $tmp/train.tags.$lang.$l
done
echo "pre-processing valid/test data..."
for l in $src $tgt; do
for o in `ls $orig/$lang/IWSLT14.TED*.$l.xml`; do
fname=${o##*/}
f=$tmp/${fname%.*}
echo $o $f
grep '<seg id' $o | \
sed -e 's/<seg id="[0-9]*">\s*//g' | \
sed -e 's/\s*<\/seg>\s*//g' | \
sed -e "s/\’/\'/g" | \
perl $TOKENIZER -threads 8 -l $l | \
perl $LC > $f
echo ""
done
done
echo "creating train, valid, test..."
for l in $src $tgt; do
awk '{if (NR%23 == 0) print $0; }' $tmp/train.tags.de-en.$l > $prep/valid.$l
awk '{if (NR%23 != 0) print $0; }' $tmp/train.tags.de-en.$l > $prep/train.$l
cat $tmp/IWSLT14.TED.dev2010.de-en.$l \
$tmp/IWSLT14.TEDX.dev2012.de-en.$l \
$tmp/IWSLT14.TED.tst2010.de-en.$l \
$tmp/IWSLT14.TED.tst2011.de-en.$l \
$tmp/IWSLT14.TED.tst2012.de-en.$l \
> $prep/test.$l
done
fairseq.gif

2.54 MB

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from .multiprocessing_pdb import pdb
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import ctypes
import math
import torch
try:
from fairseq import libbleu
except ImportError as e:
import sys
sys.stderr.write('ERROR: missing libbleu.so. run `python setup.py install`\n')
raise e
C = ctypes.cdll.LoadLibrary(libbleu.__file__)
class BleuStat(ctypes.Structure):
_fields_ = [
('reflen', ctypes.c_size_t),
('predlen', ctypes.c_size_t),
('match1', ctypes.c_size_t),
('count1', ctypes.c_size_t),
('match2', ctypes.c_size_t),
('count2', ctypes.c_size_t),
('match3', ctypes.c_size_t),
('count3', ctypes.c_size_t),
('match4', ctypes.c_size_t),
('count4', ctypes.c_size_t),
]
class Scorer(object):
def __init__(self, pad, eos, unk):
self.stat = BleuStat()
self.pad = pad
self.eos = eos
self.unk = unk
self.reset()
def reset(self, one_init=False):
if one_init:
C.bleu_one_init(ctypes.byref(self.stat))
else:
C.bleu_zero_init(ctypes.byref(self.stat))
def add(self, ref, pred):
if not isinstance(ref, torch.IntTensor):
raise TypeError('ref must be a torch.IntTensor (got {})'
.format(type(ref)))
if not isinstance(pred, torch.IntTensor):
raise TypeError('pred must be a torch.IntTensor(got {})'
.format(type(pred)))
assert self.unk > 0, 'unknown token index must be >0'
rref = ref.clone()
rref.apply_(lambda x: x if x != self.unk else -x)
rref = rref.contiguous().view(-1)
pred = pred.contiguous().view(-1)
C.bleu_add(
ctypes.byref(self.stat),
ctypes.c_size_t(rref.size(0)),
ctypes.c_void_p(rref.data_ptr()),
ctypes.c_size_t(pred.size(0)),
ctypes.c_void_p(pred.data_ptr()),
ctypes.c_int(self.pad),
ctypes.c_int(self.eos))
def score(self, order=4):
psum = sum(math.log(p) if p > 0 else float('-Inf')
for p in self.precision()[:order])
return self.brevity() * math.exp(psum / order) * 100
def precision(self):
def ratio(a, b):
return a / b if b > 0 else 0
return [
ratio(self.stat.match1, self.stat.count1),
ratio(self.stat.match2, self.stat.count2),
ratio(self.stat.match3, self.stat.count3),
ratio(self.stat.match4, self.stat.count4),
]
def brevity(self):
r = self.stat.reflen / self.stat.predlen
return min(1, math.exp(1 - r))
def result_string(self, order=4):
assert order <= 4, "BLEU scores for order > 4 aren't supported"
fmt = 'BLEU{} = {:2.2f}, {:2.1f}'
for i in range(1, order):
fmt += '/{:2.1f}'
fmt += ' (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})'
bleup = [p * 100 for p in self.precision()[:order]]
return fmt.format(order, self.score(order=order), *bleup,
self.brevity(), self.stat.reflen/self.stat.predlen,
self.stat.predlen, self.stat.reflen)
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <map>
#include <array>
#include <cstring>
#include <cstdio>
typedef struct
{
size_t reflen;
size_t predlen;
size_t match1;
size_t count1;
size_t match2;
size_t count2;
size_t match3;
size_t count3;
size_t match4;
size_t count4;
} bleu_stat;
// left trim (remove pad)
void bleu_ltrim(size_t* len, int** sent, int pad) {
size_t start = 0;
while(start < *len) {
if (*(*sent + start) != pad) { break; }
start++;
}
*sent += start;
*len -= start;
}
// right trim remove (eos)
void bleu_rtrim(size_t* len, int** sent, int pad, int eos) {
size_t end = *len - 1;
while (end > 0) {
if (*(*sent + end) != eos && *(*sent + end) != pad) { break; }
end--;
}
*len = end + 1;
}
// left and right trim
void bleu_trim(size_t* len, int** sent, int pad, int eos) {
bleu_ltrim(len, sent, pad);
bleu_rtrim(len, sent, pad, eos);
}
size_t bleu_hash(int len, int* data) {
size_t h = 14695981039346656037ul;
size_t prime = 0x100000001b3;
char* b = (char*) data;
size_t blen = sizeof(int) * len;
while (blen-- > 0) {
h ^= *b++;
h *= prime;
}
return h;
}
void bleu_addngram(
size_t *ntotal, size_t *nmatch, size_t n,
size_t reflen, int* ref, size_t predlen, int* pred) {
if (predlen < n) { return; }
predlen = predlen - n + 1;
(*ntotal) += predlen;
if (reflen < n) { return; }
reflen = reflen - n + 1;
std::map<size_t, size_t> count;
while (predlen > 0) {
size_t w = bleu_hash(n, pred++);
count[w]++;
predlen--;
}
while (reflen > 0) {
size_t w = bleu_hash(n, ref++);
if (count[w] > 0) {
(*nmatch)++;
count[w] -=1;
}
reflen--;
}
}
extern "C" {
void bleu_zero_init(bleu_stat* stat) {
std::memset(stat, 0, sizeof(bleu_stat));
}
void bleu_one_init(bleu_stat* stat) {
bleu_zero_init(stat);
stat->count1 = 1;
stat->count2 = 1;
stat->count3 = 1;
stat->count4 = 1;
stat->match1 = 1;
stat->match2 = 1;
stat->match3 = 1;
stat->match4 = 1;
}
void bleu_add(
bleu_stat* stat,
size_t reflen, int* ref, size_t predlen, int* pred, int pad, int eos) {
bleu_trim(&reflen, &ref, pad, eos);
bleu_trim(&predlen, &pred, pad, eos);
stat->reflen += reflen;
stat->predlen += predlen;
bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred);
bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred);
bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred);
bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred);
}
}
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <Python.h>
static PyMethodDef method_def[] = {
{NULL, NULL, 0, NULL}
};
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"libbleu", /* name of module */
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
method_def
};
#if PY_MAJOR_VERSION == 2
PyMODINIT_FUNC init_libbleu()
#else
PyMODINIT_FUNC PyInit_libbleu()
#endif
{
PyObject *m = PyModule_Create(&module_def);
if (!m) {
return NULL;
}
return m;
}
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <stdio.h>
#include <string.h>
#include <stdexcept>
#include <ATen/ATen.h>
using at::Tensor;
extern THCState* state;
at::Type& getDataType(const char* dtype) {
if (strcmp(dtype, "torch.cuda.FloatTensor") == 0) {
return at::getType(at::kCUDA, at::kFloat);
} else if (strcmp(dtype, "torch.FloatTensor") == 0) {
return at::getType(at::kCPU, at::kFloat);
} else {
throw std::runtime_error(std::string("Unsupported data type: ") + dtype);
}
}
inline at::Tensor t(at::Type& type, void* i) {
return type.unsafeTensorFromTH(i, true);
}
extern "C" void TemporalConvolutionTBC_forward(
const char* dtype,
void* _input,
void* _output,
void* _weight,
void* _bias)
{
auto& type = getDataType(dtype);
Tensor input = t(type, _input);
Tensor output = t(type, _output);
Tensor weight = t(type, _weight);
Tensor bias = t(type, _bias);
auto input_size = input.sizes();
auto output_size = output.sizes();
auto ilen = input_size[0];
auto batchSize = input_size[1];
auto inputPlanes = input_size[2];
auto outputPlanes = output_size[2];
auto olen = output_size[0];
auto kw = weight.sizes()[0];
int pad = (olen - ilen + kw - 1) / 2;
// input * weights + bias -> output_features
output.copy_(bias.expand(output.sizes()));
for (int k = 0; k < kw; k++) {
int iShift = std::max(0, k - pad);
int oShift = std::max(0, pad - k);
int t = std::min(ilen + pad - k, olen) - oShift;
// Note: gemm assumes column-major matrices
// input is l*m (row-major)
// weight is m*r (row-major)
// output is l*r (row-major)
if (t > 0) {
auto W = weight[k];
auto I = input.narrow(0, iShift, t).view({t * batchSize, inputPlanes});
auto O = output.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
at::addmm_out(1, O, 1, I, W, O);
}
}
}
extern "C" void TemporalConvolutionTBC_backward(
const char* dtype,
void* _dOutput,
void* _dInput,
void* _dWeight,
void* _dBias,
void* _input,
void* _weight)
{
auto& type = getDataType(dtype);
Tensor dOutput = t(type, _dOutput);
Tensor dInput = t(type, _dInput);
Tensor dWeight = t(type, _dWeight);
Tensor dBias = t(type, _dBias);
Tensor input = t(type, _input);
Tensor weight = t(type, _weight);
auto input_size = input.sizes();
auto output_size = dOutput.sizes();
auto ilen = input_size[0];
auto batchSize = input_size[1];
auto inputPlanes = input_size[2];
auto outputPlanes = output_size[2];
auto olen = output_size[0];
auto kw = weight.sizes()[0];
int pad = (olen - ilen + kw - 1) / 2;
for (int k = 0; k < kw; k++) {
int iShift = std::max(0, k - pad);
int oShift = std::max(0, pad - k);
int t = std::min(ilen + pad - k, olen) - oShift;
// dOutput * T(weight) -> dInput
if (t > 0) {
auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
auto dI = dInput.narrow(0, iShift, t).view({t * batchSize, inputPlanes});
at::addmm_out(1, dI, 1, dO, weight[k].t(), dI);
}
}
for (int k = 0; k < kw; k++) {
int iShift = std::max(0, k - pad);
int oShift = std::max(0, pad - k);
int t = std::min(ilen + pad - k, olen) - oShift;
// T(input) * dOutput -> dWeight
if (t > 0) {
auto dW = dWeight[k];
auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
auto I = input.narrow(0, iShift, t).view({t * batchSize, inputPlanes}).t();
at::addmm_out(1, dW, 1, I, dO, dW);
}
}
auto tmp = dOutput.sum(0, false);
at::sum_out(tmp, 0, dBias);
}
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
void TemporalConvolutionTBC_forward(
const char* dtype,
void* input,
void* output,
void* weight,
void* bias);
void TemporalConvolutionTBC_backward(
const char* dtype,
void* _dOutput,
void* _dInput,
void* _dWeight,
void* _dBias,
void* _input,
void* _weight);
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from .cross_entropy import CrossEntropyCriterion
from .fairseq_criterion import FairseqCriterion
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
__all__ = [
'CrossEntropyCriterion',
'LabelSmoothedCrossEntropyCriterion',
]
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math
import torch.nn.functional as F
from .fairseq_criterion import FairseqCriterion
class CrossEntropyCriterion(FairseqCriterion):
def __init__(self, padding_idx):
super().__init__()
self.padding_idx = padding_idx
def prepare(self, samples):
self.denom = sum(s['ntokens'] if s else 0 for s in samples)
def forward(self, net_output, sample):
input = net_output.view(-1, net_output.size(-1))
target = sample['target'].view(-1)
loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
return loss / self.denom
def aggregate(self, losses):
return sum(losses) / math.log(2)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from torch.nn.modules.loss import _Loss
class FairseqCriterion(_Loss):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def prepare(self, samples):
"""Prepare criterion for DataParallel training."""
raise NotImplementedError
def forward(self, net_output, sample):
"""Compute the loss for the given sample and network output."""
raise NotImplementedError
def aggregate(self, losses):
"""Aggregate losses from DataParallel training.
Takes a list of losses as input (as returned by forward) and
aggregates them into the total loss for the mini-batch.
"""
raise NotImplementedError
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math
import torch
from torch.autograd.variable import Variable
import torch.nn.functional as F
from .fairseq_criterion import FairseqCriterion
class LabelSmoothedCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, input, target, eps, padding_idx, weights):
grad_input = input.new(input.size()).zero_()
target = target.view(target.size(0), 1)
grad_input = grad_input.scatter_(grad_input.dim() - 1, target, eps - 1)
norm = grad_input.size(-1)
if weights is not None:
norm = weights.sum()
grad_input.mul(weights.view(1, weights.size(0)).expand_as(grad_input))
if padding_idx is not None:
norm -= 1 if weights is None else weights[padding_idx]
grad_input.select(grad_input.dim() - 1, padding_idx).fill_(0)
grad_input = grad_input.add(-eps / norm)
ctx.grad_input = grad_input
return input.new([grad_input.view(-1).dot(input.view(-1))])
@staticmethod
def backward(ctx, grad):
return Variable(ctx.grad_input, volatile=True) * grad, None, None, None, None
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
def __init__(self, eps, padding_idx=None, weights=None):
super().__init__()
self.eps = eps
self.padding_idx = padding_idx
self.weights = weights
def prepare(self, samples):
self.denom = sum(s['ntokens'] if s else 0 for s in samples)
def forward(self, net_output, sample):
input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
target = sample['target'].view(-1)
loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
return loss / self.denom
def aggregate(self, losses):
return sum(losses) / math.log(2)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import contextlib
import itertools
import numpy as np
import os
import torch
import torch.utils.data
from fairseq.dictionary import Dictionary
from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset
def load_with_check(path, src=None, dst=None):
"""Loads the train, valid, and test sets from the specified folder
and check that training files exist."""
def find_language_pair(files):
for filename in files:
parts = filename.split('.')
if parts[0] == 'train' and parts[-1] == 'idx':
return parts[1].split('-')
def train_file_exists(src, dst):
filename = 'train.{0}-{1}.{0}.idx'.format(src, dst)
return os.path.exists(os.path.join(path, filename))
if src is None and dst is None:
# find language pair automatically
src, dst = find_language_pair(os.listdir(path))
elif train_file_exists(src, dst):
# check for src-dst langcode
pass
elif train_file_exists(dst, src):
# check for dst-src langcode
src, dst = dst, src
else:
raise ValueError('training file not found for {}-{}'.format(src, dst))
dataset = load(path, src, dst)
return dataset
def load(path, src, dst):
"""Loads the train, valid, and test sets from the specified folder."""
langcode = '{}-{}'.format(src, dst)
def fmt_path(fmt, *args):
return os.path.join(path, fmt.format(*args))
src_dict = Dictionary.load(fmt_path('dict.{}.txt', src))
dst_dict = Dictionary.load(fmt_path('dict.{}.txt', dst))
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
for split in ['train', 'valid', 'test']:
for k in itertools.count():
prefix = "{}{}".format(split, k if k > 0 else '')
src_path = fmt_path('{}.{}.{}', prefix, langcode, src)
if not IndexedInMemoryDataset.exists(src_path):
break
dataset.splits[prefix] = LanguagePairDataset(
IndexedInMemoryDataset(src_path),
IndexedInMemoryDataset(fmt_path('{}.{}.{}', prefix, langcode, dst)),
padding_value=dataset.src_dict.pad(),
eos=dataset.src_dict.eos(),
)
return dataset
class LanguageDatasets(object):
def __init__(self, src, dst, src_dict, dst_dict):
self.src = src
self.dst = dst
self.src_dict = src_dict
self.dst_dict = dst_dict
self.splits = {}
def dataloader(self, split, batch_size=1, num_workers=0,
max_tokens=None, seed=None, epoch=1,
sample_without_replacement=0, max_positions=1024):
dataset = self.splits[split]
if split.startswith('train'):
with numpy_seed(seed):
batch_sampler = shuffled_batches_by_size(
dataset.src, dataset.dst,
max_tokens=max_tokens, epoch=epoch,
sample=sample_without_replacement,
max_positions=max_positions)
elif split.startswith('valid'):
batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, dst=dataset.dst,
max_positions=max_positions))
else:
batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, max_positions=max_positions))
return torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
pin_memory=torch.cuda.is_available(),
collate_fn=PaddingCollater(self.src_dict.pad()),
batch_sampler=batch_sampler)
def skip_group_enumerator(it, ngpus, offset=0):
res = []
idx = 0
for i, sample in enumerate(it):
if i < offset:
continue
res.append(sample)
if len(res) >= ngpus:
yield (i, res)
res = []
idx = i + 1
if len(res) > 0:
yield (idx, res)
class PaddingCollater(object):
def __init__(self, padding_value=1):
self.padding_value = padding_value
def __call__(self, samples):
def merge(key, pad_begin):
return self.merge_with_pad([s[key] for s in samples], pad_begin)
ntokens = sum(len(s['target']) for s in samples)
return {
'id': torch.LongTensor([s['id'].item() for s in samples]),
'input_tokens': merge('input_tokens', pad_begin=True),
'input_positions': merge('input_positions', pad_begin=True),
'target': merge('target', pad_begin=True),
'src_tokens': merge('src_tokens', pad_begin=False),
'src_positions': merge('src_positions', pad_begin=False),
'ntokens': ntokens,
}
def merge_with_pad(self, values, pad_begin):
size = max(v.size(0) for v in values)
res = values[0].new(len(values), size).fill_(self.padding_value)
for i, v in enumerate(values):
if pad_begin:
res[i][size-len(v):].copy_(v)
else:
res[i][:len(v)].copy_(v)
return res
class LanguagePairDataset(object):
def __init__(self, src, dst, padding_value=1, eos=2):
self.src = src
self.dst = dst
self.padding_value = padding_value
self.eos = eos
def __getitem__(self, i):
src = self.src[i].long() - 1
target = self.dst[i].long() - 1
input = target.new(target.size())
input[0] = self.eos
input[1:].copy_(target[:-1])
return {
'id': i,
'input_tokens': input,
'input_positions': self.make_positions(input),
'target': target,
'src_tokens': src,
'src_positions': self.make_positions(src),
}
def make_positions(self, x):
start = self.padding_value + 1
return torch.arange(start, start + len(x)).type_as(x)
def __len__(self):
return len(self.src)
def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positions=1024):
"""Returns batches of indices sorted by size. Sequences of different lengths
are not allowed in the same batch."""
assert isinstance(src, IndexedDataset)
assert dst is None or isinstance(dst, IndexedDataset)
if max_tokens is None:
max_tokens = float('Inf')
sizes = src.sizes
indices = np.argsort(sizes, kind='mergesort')
if dst is not None:
sizes = np.maximum(sizes, dst.sizes)
batch = []
def yield_batch(next_idx, num_tokens):
if len(batch) == 0:
return False
if len(batch) == batch_size:
return True
if sizes[batch[0]] != sizes[next_idx]:
return True
if num_tokens >= max_tokens:
return True
return False
cur_max_size = 0
for idx in indices:
# - 2 here stems from make_positions() where we offset positions
# by padding_value + 1
if src.sizes[idx] < 2 or \
(dst is not None and dst.sizes[idx] < 2) or \
sizes[idx] > max_positions - 2:
raise Exception("Unable to handle input id {} of "
"size {} / {}.".format(idx, src.sizes[idx], dst.sizes[idx]))
if yield_batch(idx, cur_max_size * (len(batch) + 1)):
yield batch
batch = []
cur_max_size = 0
batch.append(idx)
cur_max_size = max(cur_max_size, sizes[idx])
if len(batch) > 0:
yield batch
def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_positions=1024):
"""Returns batches of indices, bucketed by size and then shuffled. Batches
may contain sequences of different lengths."""
assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
if max_tokens is None:
max_tokens = float('Inf')
indices = np.random.permutation(len(src))
# sort by sizes
indices = indices[np.argsort(dst.sizes[indices], kind='mergesort')]
indices = indices[np.argsort(src.sizes[indices], kind='mergesort')]
def make_batches():
batch = []
sample_len = 0
ignored = []
for idx in indices:
# - 2 here stems from make_positions() where we offset positions
# by padding_value + 1
if src.sizes[idx] < 2 or dst.sizes[idx] < 2 or \
src.sizes[idx] > max_positions - 2 or \
dst.sizes[idx] > max_positions - 2:
ignored.append(idx)
continue
sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
if len(batch) > 0 and (len(batch) + 1) * sample_len > max_tokens:
yield batch
batch = []
sample_len = max(src.sizes[idx], dst.sizes[idx])
batch.append(idx)
if len(batch) > 0:
yield batch
if len(ignored) > 0:
print("Warning! {} samples are either too short or too long "
"and will be ignored, sample ids={}".format(len(ignored), ignored))
batches = list(make_batches())
np.random.shuffle(batches)
if sample:
offset = (epoch - 1) * sample
while offset > len(batches):
np.random.shuffle(batches)
offset -= len(batches)
result = batches[offset:(offset + sample)]
while len(result) < sample:
np.random.shuffle(batches)
result += batches[:(sample - len(result))]
assert len(result) == sample, \
"batch length is not correct {}".format(len(result))
batches = result
else:
for i in range(epoch - 1):
np.random.shuffle(batches)
return batches
@contextlib.contextmanager
def numpy_seed(seed):
"""Context manager which seeds the NumPy PRNG with the specified seed and
restores the state afterward"""
if seed is None:
yield
return
state = np.random.get_state()
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math
import torch
class Dictionary(object):
"""A mapping from symbols to consecutive integers"""
def __init__(self, pad='<pad>', eos='</s>', unk='<unk>'):
self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
self.symbols = []
self.count = []
self.indices = {}
self.add_symbol('<Lua heritage>')
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
self.nspecial = len(self.symbols)
def __getitem__(self, idx):
if idx < len(self.symbols):
return self.symbols[idx]
return self.unk_word
def __len__(self):
"""Returns the number of symbols in the dictionary"""
return len(self.symbols)
def index(self, sym):
"""Returns the index of the specified symbol"""
if sym in self.indices:
return self.indices[sym]
return self.unk_index
def string(self, tensor):
if torch.is_tensor(tensor) and tensor.dim() == 2:
sentences = [self.string(line) for line in tensor]
return '\n'.join(sentences)
eos = self.eos()
return ' '.join([self[i] for i in tensor if i != eos])
def add_symbol(self, word, n=1):
"""Adds a word to the dictionary"""
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(n)
return idx
def finalize(self):
"""Sort symbols by frequency in descending order, ignoring special ones."""
self.count, self.symbols = zip(
*sorted(zip(self.count, self.symbols),
key=(lambda x: math.inf if self.indices[x[1]] < self.nspecial else x[0]),
reverse=True)
)
def pad(self):
"""Helper to get index of pad symbol"""
return self.pad_index
def eos(self):
"""Helper to get index of end-of-sentence symbol"""
return self.eos_index
def unk(self):
"""Helper to get index of unk symbol"""
return self.unk_index
@staticmethod
def load(f):
"""Loads the dictionary from a text file with the format:
```
<symbol0> <count0>
<symbol1> <count1>
...
```
"""
if isinstance(f, str):
with open(f, 'r') as fd:
return Dictionary.load(fd)
d = Dictionary()
for line in f.readlines():
idx = line.rfind(' ')
word = line[:idx]
count = int(line[idx+1:])
d.indices[word] = len(d.symbols)
d.symbols.append(word)
d.count.append(count)
return d
def save(self, f, threshold=3, nwords=-1):
"""Stores dictionary into a text file"""
if isinstance(f, str):
with open(f, 'w') as fd:
return self.save(fd, threshold, nwords)
cnt = 0
for i, t in enumerate(zip(self.symbols, self.count)):
if i >= self.nspecial and t[1] >= threshold \
and (nwords < 0 or cnt < nwords):
print('{} {}'.format(t[0], t[1]), file=f)
cnt += 1
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import numpy as np
import os
import struct
import torch
def read_longs(f, n):
a = np.empty(n, dtype=np.int64)
f.readinto(a)
return a
def write_longs(f, a):
f.write(np.array(a, dtype=np.int64))
dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float,
7: np.double,
}
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
class IndexedDataset(object):
"""Loader for TorchNet IndexedDataset"""
def __init__(self, path):
with open(path + '.idx', 'rb') as f:
magic = f.read(8)
assert magic == b'TNTIDX\x00\x00'
version = f.read(8)
assert struct.unpack('<Q', version) == (1,)
code, self.element_size = struct.unpack('<QQ', f.read(16))
self.dtype = dtypes[code]
self.size, self.s = struct.unpack('<QQ', f.read(16))
self.dim_offsets = read_longs(f, self.size + 1)
self.data_offsets = read_longs(f, self.size + 1)
self.sizes = read_longs(f, self.s)
self.read_data(path)
def read_data(self, path):
self.data_file = open(path + '.bin', 'rb', buffering=0)
def __del__(self):
self.data_file.close()
def __getitem__(self, i):
if i < 0 or i >= self.size:
raise IndexError('index out of range')
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
return torch.from_numpy(a)
def __len__(self):
return self.size
@staticmethod
def exists(path):
return os.path.exists(path + '.idx')
class IndexedInMemoryDataset(IndexedDataset):
"""Loader for TorchNet IndexedDataset, keeps all the data in memory"""
def read_data(self, path):
self.data_file = open(path + '.bin', 'rb')
self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
self.data_file.readinto(self.buffer)
self.data_file.close()
def __del__(self):
pass
def __getitem__(self, i):
if i < 0 or i >= self.size:
raise IndexError('index out of range')
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
return torch.from_numpy(a)
class IndexedDatasetBuilder(object):
element_sizes = {
np.uint8: 1,
np.int8: 1,
np.int16: 2,
np.int32: 4,
np.int64: 8,
np.float: 4,
np.double: 8
}
def __init__(self, out_file, dtype=np.int32):
self.out_file = open(out_file, 'wb')
self.dtype = dtype
self.data_offsets = [0]
self.dim_offsets = [0]
self.sizes = []
self.element_size = self.element_sizes[self.dtype]
def add_item(self, tensor):
# +1 for Lua compatibility
bytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype))
self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
for s in tensor.size():
self.sizes.append(s)
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
def finalize(self, index_file):
self.out_file.close()
index = open(index_file, 'wb')
index.write(b'TNTIDX\x00\x00')
index.write(struct.pack('<Q', 1))
index.write(struct.pack('<QQ', code(self.dtype),
self.element_size))
index.write(struct.pack('<QQ', len(self.data_offsets) - 1,
len(self.sizes)))
write_longs(index, self.dim_offsets)
write_longs(index, self.data_offsets)
write_longs(index, self.sizes)
index.close()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import time
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class TimeMeter(object):
"""Computes the average occurence of some event per second"""
def __init__(self):
self.reset()
def reset(self):
self.start = time.time()
self.n = 0
def update(self, val=1):
self.n += val
@property
def avg(self):
delta = time.time() - self.start
return self.n / delta
@property
def elapsed_time(self):
return time.time() - self.start
class StopwatchMeter(object):
"""Computes the sum/avg duration of some event in seconds"""
def __init__(self):
self.reset()
def start(self):
self.start_time = time.time()
def stop(self, n=1):
if self.start_time is not None:
delta = time.time() - self.start_time
self.sum += delta
self.n += n
self.start_time = None
def reset(self):
self.sum = 0
self.n = 0
self.start_time = None
@property
def avg(self):
return self.sum / self.n
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment