"docs/source/en/api/schedulers/multistep_dpm_solver.md" did not exist on "75d53cc83966b4046e5a329ddf7baa6aa24f52e2"
Commit c0f05c10 authored by hepj

Update transformer code

parent c056df78
# Contributing to FAIR Sequence-to-Sequence Toolkit (PyTorch)
We want to make contributing to this project as easy and transparent as
possible.
## Pull Requests
We actively welcome your pull requests.
1. Fork the repo and create your branch from `master`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
## Coding Style
We try to follow the PEP 8 style guidelines and encourage you to as well.
## License
By contributing to FAIR Sequence-to-Sequence Toolkit, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.05-py3
FROM ${FROM_IMAGE_NAME}
WORKDIR /workspace
#RUN git clone https://github.com/NVIDIA/apex \
# && cd apex \
# && pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
# Install Python dependencies
RUN pip install --no-cache-dir \
sacrebleu \
sentencepiece
RUN pip install jupyter
ARG DEBIAN_FRONTEND=noninteractive
RUN apt-get update
RUN apt-get install -y -q cmake pkg-config protobuf-compiler libprotobuf-dev libgoogle-perftools-dev
RUN git clone https://github.com/google/sentencepiece.git /workspace/sentencepiece
RUN cd /workspace/sentencepiece \
&& git checkout d4dd947 \
&& mkdir build \
&& cd build \
&& cmake .. \
&& make -j 8 \
&& make install \
&& ldconfig -v
ENV PYTHONPATH=/workspace/translation/examples/translation/subword-nmt/
WORKDIR /workspace/translation
RUN git clone https://github.com/rsennrich/subword-nmt.git /workspace/translation/examples/translation/subword-nmt/
RUN git clone https://github.com/NVIDIA/cutlass.git && cd cutlass && git checkout ed2ed4d6 && cd ..
COPY . .
RUN pip install -e .
RUN pip install git+https://github.com/NVIDIA/dllogger@v0.1.0#egg=dllogger
BSD License
For fairseq software
Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
Copyright (c) 2019-present, NVIDIA CORPORATION. All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name Facebook nor the names of its contributors may be used to
endorse or promote products derived from this software without specific
prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Transformer PyTorch
This repository includes software from https://github.com/facebookresearch/fairseq
licensed under the BSD License.
Additional Grant of Patent Rights Version 2
"Software" means the fairseq software distributed by Facebook, Inc.
Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software
("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable
(subject to the termination provision below) license under any Necessary
Claims, to make, have made, use, sell, offer to sell, import, and otherwise
transfer the Software. For avoidance of doubt, no license is granted under
Facebook’s rights in any patent claims that are infringed by (i) modifications
to the Software made by you or any third party or (ii) the Software in
combination with any software or other technology.
The license granted hereunder will terminate, automatically and without notice,
if you (or any of your subsidiaries, corporate affiliates or agents) initiate
directly or indirectly, or take a direct financial interest in, any Patent
Assertion: (i) against Facebook or any of its subsidiaries or corporate
affiliates, (ii) against any party if such Patent Assertion arises in whole or
in part from any software, technology, product or service of Facebook or any of
its subsidiaries or corporate affiliates, or (iii) against any party relating
to the Software. Notwithstanding the foregoing, if Facebook or any of its
subsidiaries or corporate affiliates files a lawsuit alleging patent
infringement against you in the first instance, and you respond by filing a
patent infringement counterclaim in that lawsuit against that party that is
unrelated to the Software, the license granted hereunder will not terminate
under section (i) of this paragraph due to such counterclaim.
A "Necessary Claim" is a claim of a patent owned by Facebook that is
necessarily infringed by the Software standing alone.
A "Patent Assertion" is any lawsuit or other action alleging direct, indirect,
or contributory infringement or inducement to infringe any patent, including a
cross-claim or counterclaim.
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import socket
import subprocess
from train import main as single_process_main
from fairseq import distributed_utils, options
def main(args):
if args.distributed_init_method is None and args.distributed_port > 0:
# We can determine the init method automatically for Slurm.
node_list = os.environ.get('SLURM_JOB_NODELIST')
if node_list is not None:
try:
hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
args.distributed_init_method = 'tcp://{host}:{port}'.format(
host=hostnames.split()[0].decode('utf-8'),
port=args.distributed_port)
args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
args.device_id = int(os.environ.get('SLURM_LOCALID'))
except subprocess.CalledProcessError as e: # scontrol failed
raise e
except FileNotFoundError as e: # Slurm is not installed
pass
if args.distributed_init_method is None:
raise ValueError('--distributed-init-method or --distributed-port '
'must be specified for distributed training')
args.distributed_rank = distributed_utils.distributed_init(args)
args.device_id = int(os.environ.get('LOCAL_RANK', args.local_rank))
print('| initialized host {} as rank {} and device id {}'.format(socket.gethostname(), args.distributed_rank, args.device_id))
single_process_main(args)
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
main(args)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from .multiprocessing_pdb import pdb
__all__ = ['pdb']
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
class CrossEntropyCriterion(_Loss):
def __init__(self, args):
super().__init__()
self.padding_idx = args.padding_idx
def forward(self, norm_probs, target, reduce=True):
"""Compute the loss for the given sample.
"""
lprobs = norm_probs.view(-1, norm_probs.size(-1))
target = target.view(-1)
        loss = F.nll_loss(lprobs, target, reduction='sum' if reduce else 'none',
                          ignore_index=self.padding_idx)
return loss
class LabelSmoothedCrossEntropyCriterion(_Loss):
def __init__(self, args):
super().__init__()
self.eps = args.label_smoothing
self.padding_idx = args.padding_idx
def forward(self, norm_probs, target, reduce=True):
"""Compute the loss for the given sample.
"""
target = target.view(-1, 1)
lprobs = norm_probs.view(-1, norm_probs.size(-1))
non_pad_mask = target.ne(self.padding_idx)
nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
if reduce:
nll_loss = nll_loss.sum()
smooth_loss = smooth_loss.sum()
eps_i = self.eps / lprobs.size(-1)
loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
return loss
CRITERION_REGISTRY = {
'label_smoothed_cross_entropy' : LabelSmoothedCrossEntropyCriterion,
'cross_entropy' : CrossEntropyCriterion,
}
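A minimal usage sketch of the registry (not part of the original file): both criteria expect normalized log-probabilities, so apply log_softmax to model logits first. The Namespace fields below are hypothetical stand-ins for what the training argument parser provides.

import torch
import torch.nn.functional as F
from argparse import Namespace

# hypothetical args; real values come from the training parser
args = Namespace(label_smoothing=0.1, padding_idx=1)
criterion = CRITERION_REGISTRY['label_smoothed_cross_entropy'](args)

logits = torch.randn(2, 5, 8)             # (batch, time, vocab)
lprobs = F.log_softmax(logits, dim=-1)    # criteria expect log-probabilities
target = torch.randint(2, 8, (2, 5))      # token ids that avoid padding_idx
loss = criterion(lprobs, target)          # scalar: summed loss over tokens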
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dictionary import Dictionary
from .indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset # noqa: F401
from .language_pair_dataset import LanguagePairDataset, load_dataset_splits
from .data_utils import EpochBatchIterator
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>  // std::max, std::copy, std::max_element
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
namespace at { namespace native {
namespace {
bool is_batch_full(int64_t num_tokens, int64_t max_tokens, int64_t max_sentences, int64_t batch_length){
if (batch_length == 0){
return false;
} else if (batch_length == max_sentences || num_tokens > max_tokens){
return true;
} else {
return false;
}
}
}
std::vector<std::vector<int64_t> > make_batches(py::array_t<int64_t> src_lengths, py::array_t<int64_t> tgt_lengths, py::array_t<int64_t> idx_list, int64_t max_tokens, int64_t max_sentences, uint64_t bsz_mult, int64_t max_len){
std::vector<std::vector<int64_t> > batches;
auto src_l = src_lengths.unchecked<1>();
auto tgt_l = tgt_lengths.unchecked<1>();
auto idx_l = idx_list.unchecked<1>();
AT_ASSERTM(src_l.shape(0) == tgt_l.shape(0), "tgt_list and src_list should have the same shape");
AT_ASSERTM(idx_l.shape(0) == tgt_l.shape(0), "idx_list and tgt_list should have the same shape");
ssize_t nelem = src_l.shape(0);
int64_t sample_len =0;
std::vector<int64_t> sample_lens;
std::vector<int64_t> batch;
for (ssize_t i=0; i < nelem; i++){
int64_t idx = idx_l(i);
int64_t sample_num_tokens = std::max(src_l(idx), tgt_l(idx));
if (sample_num_tokens > max_len) continue;
sample_len = std::max(sample_len, sample_num_tokens);
sample_lens.push_back(sample_num_tokens);
int64_t num_tokens = (batch.size() + 1) * sample_len;
if (is_batch_full(num_tokens, max_tokens, max_sentences, batch.size())){
int64_t mode_len = std::max(batch.size() / bsz_mult * bsz_mult, batch.size() % bsz_mult);
        std::vector<int64_t> new_batch;
        new_batch.reserve(batch.size() - mode_len);
std::copy(batch.begin()+mode_len, batch.end(), std::back_inserter(new_batch));
batch.erase(batch.begin()+mode_len, batch.end());
sample_lens.erase(sample_lens.begin(), sample_lens.begin()+mode_len);
        // sample_lens is guaranteed non-empty here, so max_element is safe
sample_len = *std::max_element(sample_lens.begin(), sample_lens.end());
batches.push_back(batch);
batch = new_batch;
}
batch.push_back(idx);
}
if (batch.size() > 0) batches.push_back(batch);
return batches;
}
}}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
m.def("make_batches", &at::native::make_batches);
}
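Once `pip install -e .` builds this extension as fairseq.data.batch_C (the real call site is EpochBatchIterator in data_utils.py), it can be exercised directly. A small sketch with illustrative lengths; the binding declares no keyword names, so all arguments are positional:

import numpy as np
import fairseq.data.batch_C as batch_C

src_lengths = np.array([5, 9, 3, 7], dtype=np.int64)
tgt_lengths = np.array([6, 8, 4, 7], dtype=np.int64)
indices = np.arange(4, dtype=np.int64)

# positional: src, tgt, indices, max_tokens, max_sentences, bsz_mult, max_len
batches = batch_C.make_batches(src_lengths, tgt_lengths, indices, 16, 2, 1, 1024)
# -> list of index lists, each batch capped at 16 tokens and 2 sentences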
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import itertools
import os
import numpy as np
import torch
import fairseq.data.batch_C
import sys
from .dictionary import Dictionary
def infer_language_pair(path):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
src, dst = None, None
for filename in os.listdir(path):
parts = filename.split('.')
if len(parts) >= 3 and len(parts[1].split('-')) == 2:
return parts[1].split('-')
return src, dst
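For example, given a hypothetical data directory containing files such as train.en-de.en.idx:

# the middle filename component encodes the language pair
src, tgt = infer_language_pair('/path/to/wmt14_en_de')  # -> 'en', 'de'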
def load_dictionaries(args):
if args.source_lang is None or args.target_lang is None:
args.source_lang, args.target_lang = infer_language_pair(args.data)
if args.source_lang is None or args.target_lang is None:
raise Exception('Could not infer language pair, please provide it explicitly')
# load dictionaries
src_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.source_lang)))
tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.target_lang)))
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
args.src_vocab_size = len(src_dict)
args.tgt_vocab_size = len(tgt_dict)
args.padding_idx = src_dict.pad()
print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
print('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))
return src_dict, tgt_dict
class ShardedIterator(object):
"""A sharded wrapper around an iterable (padded to length)."""
def __init__(self, iterable, num_shards, shard_id, fill_value=None):
if shard_id < 0 or shard_id >= num_shards:
raise ValueError('shard_id must be between 0 and num_shards')
self._sharded_len = len(iterable) // num_shards
if len(iterable) % num_shards > 0:
self._sharded_len += 1
self.itr = itertools.zip_longest(
range(self._sharded_len),
itertools.islice(iterable, shard_id, len(iterable), num_shards),
fillvalue=fill_value,
)
def __len__(self):
return self._sharded_len
def __iter__(self):
return self
def __next__(self):
return next(self.itr)[1]
class CountingIterator(object):
"""Wrapper around an iterable that maintains the iteration count."""
def __init__(self, iterable):
self.iterable = iterable
self.count = 0
self.itr = iter(self)
def __len__(self):
return len(self.iterable)
def __iter__(self):
for x in self.iterable:
self.count += 1
yield x
def __next__(self):
return next(self.itr)
def has_next(self):
return self.count < len(self)
def skip(self, num_to_skip):
next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
return self
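A quick illustration of how the two wrappers behave (toy values):

# shard 0 of 3 over ten items: every third element, padded to a common length
shard = ShardedIterator(list(range(10)), num_shards=3, shard_id=0, fill_value=None)
print(list(shard))                 # [0, 3, 6, 9]

# CountingIterator tracks progress; EpochBatchIterator uses this to resume
itr = CountingIterator(list(range(5)))
next(itr)
next(itr)
print(itr.count, itr.has_next())   # 2 True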
def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False, pad_sequence=1):
"""Convert a list of 1d tensors into a padded 2d tensor."""
orig_size = max(v.size(0) for v in values)
size = 0
if pad_sequence > 1:
size = orig_size // pad_sequence * pad_sequence
if orig_size % pad_sequence > 0:
size += pad_sequence
else:
size = orig_size
res = values[0].new(len(values), size).fill_(pad_idx)
def copy_tensor(src, dst):
assert dst.numel() == src.numel()
if move_eos_to_beginning:
assert src[-1] == eos_idx
dst[0] = eos_idx
dst[1:] = src[:-1]
else:
dst.copy_(src)
for i, v in enumerate(values):
copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
return res
def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False, pad_sequence=1):
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False):
return collate_tokens(
[s[key] for s in samples],
pad_idx, eos_idx, left_pad, move_eos_to_beginning,
pad_sequence,
)
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=left_pad_source)
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = None
target = None
if samples[0].get('target', None) is not None:
target = merge('target', left_pad=left_pad_target)
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
'target',
left_pad=left_pad_target,
move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
target = target.index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
else:
ntokens = sum(len(s['source']) for s in samples)
return {
'id': id,
'ntokens': ntokens,
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
'prev_output_tokens': prev_output_tokens,
},
'target': target,
}
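A worked example of the padding helper (toy ids; pad=1 and eos=2 are illustrative):

import torch

pad, eos = 1, 2
seqs = [torch.LongTensor([4, 5, 2]), torch.LongTensor([6, 2])]

print(collate_tokens(seqs, pad, eos, left_pad=True))
# tensor([[4, 5, 2],
#         [1, 6, 2]])

# decoder inputs: eos moved to the front, everything else shifted right
print(collate_tokens(seqs, pad, eos, left_pad=False, move_eos_to_beginning=True))
# tensor([[2, 4, 5],
#         [2, 6, 1]])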
def get_dummy_batch(num_tokens, src_dict, tgt_dict, src_len=128, tgt_len=128,
left_pad_source=True, left_pad_target=False, pad_sequence=1):
bsz = num_tokens // max(src_len, tgt_len)
dummy_samples = [
{
'id': i,
'source': src_dict.dummy_sentence(src_len),
'target': tgt_dict.dummy_sentence(tgt_len) if tgt_dict is not None else None,
}
for i in range(bsz)
]
return collate(
dummy_samples, pad_idx=src_dict.pad(), eos_idx=src_dict.eos(),
left_pad_source=left_pad_source, left_pad_target=left_pad_target,
pad_sequence=pad_sequence,
)
class EpochBatchIterator(object):
"""Iterate over a FairseqDataset and yield batches bucketed by size.
Batches may contain sequences of different lengths. This iterator can be
reused across multiple epochs with the next_epoch_itr() method.
Args:
dataset: a FairseqDataset
max_tokens: max number of tokens in each batch
max_sentences: max number of sentences in each batch
max_positions: max sentence length supported by the model
required_batch_size_multiple: require batch size to be a multiple of N
seed: seed for random number generator for reproducibility
num_shards: shard the data iterator into N shards
shard_id: which shard of the data iterator to return
"""
def __init__(
self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
required_batch_size_multiple=1, seed=1,
num_shards=1, shard_id=0, epoch=0
):
self.dataset = dataset
self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
self.max_positions = max_positions
self.bsz_mult = required_batch_size_multiple
self.seed = seed
self.num_shards = num_shards
self.shard_id = shard_id
self.epoch = epoch
self._cur_epoch_itr = None
self._next_epoch_itr = None
with numpy_seed(self.seed):
indices = self.dataset.ordered_indices(self.seed, self.epoch)
        # the C++ batching extension needs concrete integers rather than float('Inf')
        max_sentences = max_sentences if max_sentences is not None else sys.maxsize
        max_positions_num = 1024
        max_tokens = max_tokens if max_tokens is not None else sys.maxsize
        # workaround: the extension cannot take a None argument, so reuse
        # src_sizes when there are no target sizes
        tgt_sizes = self.dataset.tgt_sizes if self.dataset.tgt_sizes is not None else self.dataset.src_sizes
batches = fairseq.data.batch_C.make_batches(
self.dataset.src_sizes, tgt_sizes, indices, max_tokens,
max_sentences, self.bsz_mult, max_positions_num)
self.frozen_batches = tuple(batches)
def __len__(self):
return len(self.frozen_batches)
def next_epoch_itr(self, shuffle=True):
"""Shuffle batches and return a new iterator over the dataset."""
if self._next_epoch_itr is not None:
self._cur_epoch_itr = self._next_epoch_itr
self._next_epoch_itr = None
else:
self.epoch += 1
self._cur_epoch_itr = self._get_iterator_for_epoch(self.epoch, shuffle)
return self._cur_epoch_itr
def end_of_epoch(self):
return not self._cur_epoch_itr.has_next()
@property
def iterations_in_epoch(self):
if self._cur_epoch_itr is not None:
return self._cur_epoch_itr.count
elif self._next_epoch_itr is not None:
return self._next_epoch_itr.count
return 0
def state_dict(self):
return {
'epoch': self.epoch,
'iterations_in_epoch': self.iterations_in_epoch,
}
def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch']
itr_pos = state_dict.get('iterations_in_epoch', 0)
if itr_pos > 0:
# fast-forward epoch iterator
itr = self._get_iterator_for_epoch(self.epoch, state_dict.get('shuffle', True))
if itr_pos < len(itr):
self._next_epoch_itr = itr.skip(itr_pos)
def _get_iterator_for_epoch(self, epoch, shuffle):
if shuffle:
# set seed based on the seed and epoch number so that we get
# reproducible results when resuming from checkpoints
with numpy_seed(self.seed + epoch):
batches = list(self.frozen_batches) # copy
np.random.shuffle(batches)
else:
batches = self.frozen_batches
return CountingIterator(torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.dataset.collater,
num_workers=1,
batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
))
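A hedged resume sketch, assuming `dataset` is a LanguagePairDataset as built in language_pair_dataset.py:

# sketch: persist iterator progress inside a training checkpoint
epoch_itr = EpochBatchIterator(dataset, max_tokens=4096, seed=1)
itr = epoch_itr.next_epoch_itr()
next(itr)                                # ... one training step ...
state = epoch_itr.state_dict()           # {'epoch': 1, 'iterations_in_epoch': 1}

# after restoring a checkpoint, fast-forward to the saved position
epoch_itr.load_state_dict(state)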
@contextlib.contextmanager
def numpy_seed(seed):
"""Context manager which seeds the NumPy PRNG with the specified seed and
restores the state afterward"""
if seed is None:
yield
return
state = np.random.get_state()
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)
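For example, the manager restores the global NumPy PRNG state on exit:

with numpy_seed(42):
    a = np.random.rand(3)
with numpy_seed(42):
    b = np.random.rand(3)
assert np.allclose(a, b)   # identical draws under the same seed
np.random.rand()           # outer PRNG state is unaffected by the seeded blocks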
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.utils.data
class FairseqDataset(torch.utils.data.Dataset):
"""A dataset that provides helpers for batching."""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
raise NotImplementedError
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
raise NotImplementedError
def ordered_indices(self, seed=None, epoch=0):
"""Ordered indices for batching."""
raise NotImplementedError
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
raise NotImplementedError
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import struct
import numpy as np
import torch
from fairseq.tokenizer import Tokenizer
def read_longs(f, n):
a = np.empty(n, dtype=np.int64)
f.readinto(a)
return a
def write_longs(f, a):
f.write(np.array(a, dtype=np.int64))
dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float32,
    7: np.double,
}
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
def index_file_path(prefix_path):
return prefix_path + '.idx'
def data_file_path(prefix_path):
return prefix_path + '.bin'
class IndexedDataset(torch.utils.data.Dataset):
"""Loader for TorchNet IndexedDataset"""
def __init__(self, path, fix_lua_indexing=False):
super().__init__()
self.fix_lua_indexing = fix_lua_indexing
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
assert magic == b'TNTIDX\x00\x00'
version = f.read(8)
assert struct.unpack('<Q', version) == (1,)
code, self.element_size = struct.unpack('<QQ', f.read(16))
self.dtype = dtypes[code]
self.size, self.s = struct.unpack('<QQ', f.read(16))
self.dim_offsets = read_longs(f, self.size + 1)
self.data_offsets = read_longs(f, self.size + 1)
self.sizes = read_longs(f, self.s)
self.read_data(path)
def read_data(self, path):
self.data_file = open(data_file_path(path), 'rb', buffering=0)
def check_index(self, i):
if i < 0 or i >= self.size:
raise IndexError('index out of range')
def __del__(self):
self.data_file.close()
def __getitem__(self, i):
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
item = torch.from_numpy(a).long()
if self.fix_lua_indexing:
item -= 1 # subtract 1 for 0-based indexing
return item
def __len__(self):
return self.size
@staticmethod
def exists(path):
return (
os.path.exists(index_file_path(path)) and
os.path.exists(data_file_path(path))
)
class IndexedInMemoryDataset(IndexedDataset):
"""Loader for TorchNet IndexedDataset, keeps all the data in memory"""
def read_data(self, path):
self.data_file = open(data_file_path(path), 'rb')
self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
self.data_file.readinto(self.buffer)
self.data_file.close()
if self.fix_lua_indexing:
self.buffer -= 1 # subtract 1 for 0-based indexing
def __del__(self):
pass
def __getitem__(self, i):
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
return torch.from_numpy(a).long()
class IndexedRawTextDataset(IndexedDataset):
"""Takes a text file as input and binarizes it in memory at instantiation.
Original lines are also kept in memory"""
def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
self.tokens_list = []
self.lines = []
self.sizes = []
self.append_eos = append_eos
self.reverse_order = reverse_order
self.read_data(path, dictionary)
self.size = len(self.tokens_list)
def read_data(self, path, dictionary):
with open(path, 'r') as f:
for line in f:
self.lines.append(line.strip('\n'))
tokens = Tokenizer.tokenize(
line, dictionary, add_if_not_exist=False,
append_eos=self.append_eos, reverse_order=self.reverse_order,
).long()
self.tokens_list.append(tokens)
self.sizes.append(len(tokens))
self.sizes = np.array(self.sizes)
def __getitem__(self, i):
self.check_index(i)
return self.tokens_list[i]
def get_original_text(self, i):
self.check_index(i)
return self.lines[i]
def __del__(self):
pass
def __len__(self):
return self.size
@staticmethod
def exists(path):
return os.path.exists(path)
class IndexedDatasetBuilder(object):
    element_sizes = {
        np.uint8: 1,
        np.int8: 1,
        np.int16: 2,
        np.int32: 4,
        np.int64: 8,
        np.float32: 4,
        np.double: 8
    }
def __init__(self, out_file, dtype=np.int32):
self.out_file = open(out_file, 'wb')
self.dtype = dtype
self.data_offsets = [0]
self.dim_offsets = [0]
self.sizes = []
self.element_size = self.element_sizes[self.dtype]
def add_item(self, tensor):
        # +1 for Lua compatibility
        nbytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype))
        self.data_offsets.append(self.data_offsets[-1] + nbytes // self.element_size)
for s in tensor.size():
self.sizes.append(s)
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
def finalize(self, index_file):
self.out_file.close()
index = open(index_file, 'wb')
index.write(b'TNTIDX\x00\x00')
index.write(struct.pack('<Q', 1))
index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
write_longs(index, self.dim_offsets)
write_longs(index, self.data_offsets)
write_longs(index, self.sizes)
index.close()
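A round-trip sketch with hypothetical file names:

# write two int32 sequences, then read them back with Lua-index correction
builder = IndexedDatasetBuilder('toy.bin', dtype=np.int32)
builder.add_item(torch.IntTensor([3, 4, 5]))
builder.add_item(torch.IntTensor([6, 7]))
builder.finalize('toy.idx')

ds = IndexedDataset('toy', fix_lua_indexing=True)
print(len(ds), ds[0])   # 2 tensor([3, 4, 5]) -- builder adds 1, loader subtracts it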
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from torch.utils.data import Dataset, ConcatDataset
from . import data_utils
import itertools
import os
import sys
from fairseq.data import IndexedInMemoryDataset, IndexedRawTextDataset
class LanguagePairDataset(Dataset):
"""A pair of torch.utils.data.Datasets."""
def __init__(
self, src, src_sizes, src_dict,
tgt=None, tgt_sizes=None, tgt_dict=None,
left_pad_source=True, left_pad_target=False,
max_source_positions=1024, max_target_positions=1024,
pad_sequence=1, shuffle=True,
):
if tgt_dict is not None:
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
self.src = src
self.tgt = tgt
self.src_sizes = np.array(src_sizes)
self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
self.src_dict = src_dict
self.tgt_dict = tgt_dict
self.left_pad_source = left_pad_source
self.left_pad_target = left_pad_target
self.max_source_positions = max_source_positions
self.max_target_positions = max_target_positions
self.pad_sequence = pad_sequence
self.shuffle = shuffle
print("| Sentences are being padded to multiples of: {}".format(self.pad_sequence), file=sys.stderr)
def __getitem__(self, index):
return {
'id': index,
'source': self.src[index],
'target': self.tgt[index] if self.tgt is not None else None,
}
def __len__(self):
return len(self.src)
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
return data_utils.collate(
samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
pad_sequence=self.pad_sequence,
)
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
orig_size = max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
assert self.pad_sequence > 0, "Padding multiple has to be greater than 0"
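        # round the raw length up to the nearest multiple of pad_sequence,
        # mirroring the padding applied in data_utils.collate_tokens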
size = 0
if self.pad_sequence > 1:
size = orig_size // self.pad_sequence * self.pad_sequence
if orig_size % self.pad_sequence > 0:
size += self.pad_sequence
else:
size = orig_size
return size
def ordered_indices(self, seed=None, epoch=1):
"""Ordered indices for batching."""
if self.shuffle:
indices = np.random.RandomState(seed + epoch).permutation(len(self))
else:
indices = np.arange(len(self))
if self.tgt_sizes is not None:
indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
max_source_positions, max_target_positions = self._get_max_positions(max_positions)
return (
self.src_sizes[index] <= max_source_positions and
(self.tgt_sizes is None or self.tgt_sizes[index] <= max_target_positions)
)
def _get_max_positions(self, max_positions):
if max_positions is None:
return self.max_source_positions, self.max_target_positions
assert len(max_positions) == 2
max_src_pos, max_tgt_pos = max_positions
return min(self.max_source_positions, max_src_pos), min(self.max_target_positions, max_tgt_pos)
def load_dataset(args, datasets, split, src_dict, tgt_dict, combine=False):
"""Load a dataset split."""
def split_exists(split, src, tgt, lang):
filename = os.path.join(args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
if args.raw_text and IndexedRawTextDataset.exists(filename):
return True
elif not args.raw_text and IndexedInMemoryDataset.exists(filename):
return True
return False
def indexed_dataset(path, dictionary):
if args.raw_text:
return IndexedRawTextDataset(path, dictionary)
elif IndexedInMemoryDataset.exists(path):
return IndexedInMemoryDataset(path, fix_lua_indexing=True)
return None
src_datasets = []
tgt_datasets = []
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
# infer langcode
src, tgt = args.source_lang, args.target_lang
if split_exists(split_k, src, tgt, src):
prefix = os.path.join(args.data, '{}.{}-{}.'.format(split_k, src, tgt))
elif split_exists(split_k, tgt, src, src):
prefix = os.path.join(args.data, '{}.{}-{}.'.format(split_k, tgt, src))
else:
if k > 0:
break
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, args.data))
src_datasets.append(indexed_dataset(prefix + src, src_dict))
tgt_datasets.append(indexed_dataset(prefix + tgt, tgt_dict))
print('| {} {} {} examples'.format(args.data, split_k, len(src_datasets[-1])))
if not combine:
break
assert len(src_datasets) == len(tgt_datasets)
if len(src_datasets) == 1:
src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
src_sizes = src_dataset.sizes
tgt_sizes = tgt_dataset.sizes
else:
src_dataset = ConcatDataset(src_datasets)
tgt_dataset = ConcatDataset(tgt_datasets)
src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])
datasets[split] = LanguagePairDataset(
src_dataset, src_sizes, src_dict,
tgt_dataset, tgt_sizes, tgt_dict,
left_pad_source=args.left_pad_source,
left_pad_target=args.left_pad_target,
max_source_positions=args.max_source_positions,
max_target_positions=args.max_target_positions,
pad_sequence=args.pad_sequence,
)
def load_dataset_splits(args, splits, src_dict, tgt_dict):
datasets = {}
for split in splits:
if split == 'train':
load_dataset(args, datasets, split, src_dict, tgt_dict, combine=True)
else:
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
try:
load_dataset(args, datasets, split_k, src_dict, tgt_dict, combine=False)
except FileNotFoundError as e:
if k > 0:
break
raise e
return datasets
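End-to-end wiring sketch; `args` is assumed to come from the training parser in fairseq.options, and the field names mirror those used above:

from fairseq.data import load_dataset_splits
from fairseq.data.data_utils import load_dictionaries, EpochBatchIterator

src_dict, tgt_dict = load_dictionaries(args)
datasets = load_dataset_splits(args, ['train', 'valid'], src_dict, tgt_dict)
epoch_itr = EpochBatchIterator(datasets['train'], max_tokens=4096)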
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import torch
class TokenBlockDataset(torch.utils.data.Dataset):
"""Break a 1d tensor of tokens into blocks.
The blocks are fetched from the original tensor so no additional memory is allocated.
Args:
tokens: 1d tensor of tokens to break into blocks
sizes: sentence lengths (required for 'complete' and 'eos')
block_size: maximum block size (ignored in 'eos' break mode)
break_mode: Mode used for breaking tokens. Values can be one of:
- 'none': break tokens into equally sized blocks (up to block_size)
- 'complete': break tokens into blocks (up to block_size) such that
blocks contains complete sentences, although block_size may be
exceeded if some sentences exceed block_size
- 'eos': each block contains one sentence (block_size is ignored)
include_targets: return next tokens as targets
"""
def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False):
super().__init__()
self.tokens = tokens
self.total_size = len(tokens)
self.include_targets = include_targets
self.slice_indices = []
if break_mode is None or break_mode == 'none':
length = math.ceil(len(tokens) / block_size)
def block_at(i):
start = i * block_size
end = min(start + block_size, len(tokens))
return (start, end)
self.slice_indices = [block_at(i) for i in range(length)]
elif break_mode == 'complete':
assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens))
tok_idx = 0
sz_idx = 0
curr_size = 0
while sz_idx < len(sizes):
if curr_size + sizes[sz_idx] <= block_size or curr_size == 0:
curr_size += sizes[sz_idx]
sz_idx += 1
else:
self.slice_indices.append((tok_idx, tok_idx + curr_size))
tok_idx += curr_size
curr_size = 0
if curr_size > 0:
self.slice_indices.append((tok_idx, tok_idx + curr_size))
elif break_mode == 'eos':
assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens))
curr = 0
for sz in sizes:
                # skip samples with just one token (which would be only the eos token)
if sz > 1:
self.slice_indices.append((curr, curr + sz))
curr += sz
else:
raise ValueError('Invalid break_mode: ' + break_mode)
self.sizes = np.array([e - s for s, e in self.slice_indices])
def __getitem__(self, index):
s, e = self.slice_indices[index]
item = torch.LongTensor(self.tokens[s:e])
if self.include_targets:
            # target is the block itself; source is the same block shifted one
            # token to the right (the first block wraps around and starts with
            # the final token, which is assumed to be eos)
if s == 0:
source = np.concatenate([self.tokens[-1:], self.tokens[0:e - 1]])
else:
source = self.tokens[s - 1:e - 1]
return torch.LongTensor(source), item
return item
def __len__(self):
return len(self.slice_indices)
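A toy illustration of the 'complete' break mode (token ids and eos=2 are made up):

tokens = [10, 11, 2, 12, 13, 14, 2]   # two sentences, each ending in eos
sizes = [3, 4]

ds = TokenBlockDataset(tokens, sizes, block_size=4, break_mode='complete')
print([ds[i].tolist() for i in range(len(ds))])
# [[10, 11, 2], [12, 13, 14, 2]] -- sentence boundaries are respected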