Commit 9e8a8c05 authored by jerrrrry's avatar jerrrrry

Initial commit

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import Counter
import os
import torch
# MLPerf compliant dictionary
class Dictionary(object):
"""A mapping from symbols to consecutive integers"""
def __init__(self, pad='<pad>_', eos='<EOS>_'):
self.pad_word, self.eos_word = pad, eos
self.symbols = []
self.count = []
self.indices = {}
# dictionary indexing starts at 1 for consistency with Lua
# Commented out and hard-coded since pad and eos are in the dictionary files already
self.add_symbol('<lua_index_compat>')
self.pad_index = 1
self.eos_index = 2
#self.pad_index = self.add_symbol(pad)
#self.eos_index = self.add_symbol(eos)
#self.add_symbol('<bypass_unk>')
self.nspecial = 3
def __eq__(self, other):
return self.indices == other.indices
    def __getitem__(self, idx):
        if idx < len(self.symbols):
            return self.symbols[idx]
        # this dictionary has no <unk> symbol, so an out-of-range index is a hard error
        raise IndexError('index {} is out of range for dictionary of size {}'.format(idx, len(self.symbols)))
def __len__(self):
"""Returns the number of symbols in the dictionary"""
return len(self.symbols)
    def index(self, sym):
        """Returns the index of the specified symbol"""
        if sym in self.indices:
            return self.indices[sym]
        # this dictionary has no <unk> symbol, so an unknown symbol is a hard error
        raise KeyError('symbol {!r} is not in the dictionary'.format(sym))
def string(self, tensor, bpe_symbol=None):
"""Helper for converting a tensor of token indices to a string.
Can optionally remove BPE symbols or escape <unk> words.
"""
if torch.is_tensor(tensor) and tensor.dim() == 2:
return '\n'.join(self.string(t) for t in tensor)
def token_string(i):
return self[i]
sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
if bpe_symbol is not None:
sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
return sent
def add_symbol(self, word, n=1):
"""Adds a word to the dictionary"""
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(n)
return idx
def update(self, new_dict):
"""Updates counts from new dictionary."""
for word in new_dict.symbols:
idx2 = new_dict.indices[word]
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + new_dict.count[idx2]
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(new_dict.count[idx2])
def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
"""Sort symbols by frequency in descending order, ignoring special ones.
Args:
- threshold defines the minimum word count
- nwords defines the total number of words in the final dictionary,
including special symbols
- padding_factor can be used to pad the dictionary size to be a
multiple of 8, which is important on some hardware (e.g., Nvidia
Tensor Cores).
"""
if nwords <= 0:
nwords = len(self)
new_indices = dict(zip(self.symbols[:self.nspecial], range(self.nspecial)))
new_symbols = self.symbols[:self.nspecial]
new_count = self.count[:self.nspecial]
c = Counter(dict(zip(self.symbols[self.nspecial:], self.count[self.nspecial:])))
for symbol, count in c.most_common(nwords - self.nspecial):
if count >= threshold:
new_indices[symbol] = len(new_symbols)
new_symbols.append(symbol)
new_count.append(count)
else:
break
threshold_nwords = len(new_symbols)
if padding_factor > 1:
i = 0
while threshold_nwords % padding_factor != 0:
symbol = 'madeupword{:04d}'.format(i)
new_indices[symbol] = len(new_symbols)
new_symbols.append(symbol)
new_count.append(0)
i += 1
threshold_nwords += 1
assert len(new_symbols) % padding_factor == 0
assert len(new_symbols) == len(new_indices)
self.count = list(new_count)
self.symbols = list(new_symbols)
self.indices = new_indices
def pad(self):
"""Helper to get index of pad symbol"""
return self.pad_index
def eos(self):
"""Helper to get index of end-of-sentence symbol"""
return self.eos_index
@classmethod
def load(cls, f, ignore_utf_errors=False):
"""Loads the dictionary from a text file with the format:
```
<symbol0>
<symbol1>
...
```
"""
if isinstance(f, str):
try:
if not ignore_utf_errors:
with open(f, 'r', encoding='utf-8') as fd:
return cls.load(fd)
else:
with open(f, 'r', encoding='utf-8', errors='ignore') as fd:
return cls.load(fd)
except FileNotFoundError as fnfe:
raise fnfe
except Exception:
raise Exception("Incorrect encoding detected in {}, please rebuild the dataset".format(f))
d = cls()
for line in f.readlines():
word = line.strip()[1:-1] ## Remove the single quotes
count = 1
d.indices[word] = len(d.symbols)
d.symbols.append(word)
d.count.append(count)
n_pad_tokens_on_end = 33712 - len(d.symbols)
#assert n_pad_tokens_on_end == 3 ## DEBUG: remove later, sanity check
for i in range(n_pad_tokens_on_end):
pad_str = '<pad000' + str(i) + '>'
d.indices[pad_str] = len(d.symbols)
d.symbols.append(pad_str)
d.count.append(1)
return d
def save(self, f):
"""Stores dictionary into a text file"""
if isinstance(f, str):
os.makedirs(os.path.dirname(f), exist_ok=True)
with open(f, 'w', encoding='utf-8') as fd:
return self.save(fd)
for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
print('{} {}'.format(symbol, count), file=f)
def dummy_sentence(self, length):
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
t[-1] = self.eos()
return t
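
# Illustrative usage sketch, not part of the original file: how the MLPerf-compliant
# Dictionary above is typically exercised. The tiny in-memory vocabulary below is a
# hypothetical stand-in for the real, pre-built dictionary file.
def _example_mlperf_dictionary():
    import io
    # load() expects one single-quoted symbol per line, with the pad and eos
    # entries already present at positions 1 and 2
    vocab = io.StringIO("'<pad>_'\n'<EOS>_'\n'hello'\n'world'\n")
    d = Dictionary.load(vocab)
    assert d.index('hello') == 3 and d.index('world') == 4
    # load() pads the symbol table out to the hard-coded 33712 entries with <pad000i> fillers
    assert len(d) == 33712
    # string() maps indices back to symbols, skipping the eos index (2)
    t = torch.LongTensor([3, 4, d.eos()])
    assert d.string(t) == 'hello world'
    return d
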
class Dictionary_fairseq(object):
"""A mapping from symbols to consecutive integers"""
def __init__(self, pad='<pad>', eos='</s>', unk='<unk>'):
self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
self.symbols = []
self.count = []
self.indices = {}
# dictionary indexing starts at 1 for consistency with Lua
self.add_symbol('<Lua heritage>')
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
self.nspecial = len(self.symbols)
def __eq__(self, other):
return self.indices == other.indices
def __getitem__(self, idx):
if idx < len(self.symbols):
return self.symbols[idx]
return self.unk_word
def __len__(self):
"""Returns the number of symbols in the dictionary"""
return len(self.symbols)
def index(self, sym):
"""Returns the index of the specified symbol"""
if sym in self.indices:
return self.indices[sym]
return self.unk_index
def string(self, tensor, bpe_symbol=None, escape_unk=False):
"""Helper for converting a tensor of token indices to a string.
Can optionally remove BPE symbols or escape <unk> words.
"""
if torch.is_tensor(tensor) and tensor.dim() == 2:
return '\n'.join(self.string(t) for t in tensor)
def token_string(i):
if i == self.unk():
return self.unk_string(escape_unk)
else:
return self[i]
sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
if bpe_symbol is not None:
sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
return sent
def unk_string(self, escape=False):
"""Return unknown string, optionally escaped as: <<unk>>"""
if escape:
return '<{}>'.format(self.unk_word)
else:
return self.unk_word
def add_symbol(self, word, n=1):
"""Adds a word to the dictionary"""
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(n)
return idx
def update(self, new_dict):
"""Updates counts from new dictionary."""
for word in new_dict.symbols:
idx2 = new_dict.indices[word]
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + new_dict.count[idx2]
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(new_dict.count[idx2])
def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
"""Sort symbols by frequency in descending order, ignoring special ones.
Args:
- threshold defines the minimum word count
- nwords defines the total number of words in the final dictionary,
including special symbols
- padding_factor can be used to pad the dictionary size to be a
multiple of 8, which is important on some hardware (e.g., Nvidia
Tensor Cores).
"""
if nwords <= 0:
nwords = len(self)
new_indices = dict(zip(self.symbols[:self.nspecial], range(self.nspecial)))
new_symbols = self.symbols[:self.nspecial]
new_count = self.count[:self.nspecial]
c = Counter(dict(zip(self.symbols[self.nspecial:], self.count[self.nspecial:])))
for symbol, count in c.most_common(nwords - self.nspecial):
if count >= threshold:
new_indices[symbol] = len(new_symbols)
new_symbols.append(symbol)
new_count.append(count)
else:
break
threshold_nwords = len(new_symbols)
if padding_factor > 1:
i = 0
while threshold_nwords % padding_factor != 0:
symbol = 'madeupword{:04d}'.format(i)
new_indices[symbol] = len(new_symbols)
new_symbols.append(symbol)
new_count.append(0)
i += 1
threshold_nwords += 1
assert len(new_symbols) % padding_factor == 0
assert len(new_symbols) == len(new_indices)
self.count = list(new_count)
self.symbols = list(new_symbols)
self.indices = new_indices
def pad(self):
"""Helper to get index of pad symbol"""
return self.pad_index
def eos(self):
"""Helper to get index of end-of-sentence symbol"""
return self.eos_index
def unk(self):
"""Helper to get index of unk symbol"""
return self.unk_index
@classmethod
def load(cls, f, ignore_utf_errors=False):
"""Loads the dictionary from a text file with the format:
```
<symbol0> <count0>
<symbol1> <count1>
...
```
"""
if isinstance(f, str):
try:
if not ignore_utf_errors:
with open(f, 'r', encoding='utf-8') as fd:
return cls.load(fd)
else:
with open(f, 'r', encoding='utf-8', errors='ignore') as fd:
return cls.load(fd)
except FileNotFoundError as fnfe:
raise fnfe
except Exception:
raise Exception("Incorrect encoding detected in {}, please "
"rebuild the dataset".format(f))
d = cls()
for line in f.readlines():
idx = line.rfind(' ')
word = line[:idx]
count = int(line[idx+1:])
d.indices[word] = len(d.symbols)
d.symbols.append(word)
d.count.append(count)
return d
def save(self, f):
"""Stores dictionary into a text file"""
if isinstance(f, str):
os.makedirs(os.path.dirname(f), exist_ok=True)
with open(f, 'w', encoding='utf-8') as fd:
return self.save(fd)
for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
print('{} {}'.format(symbol, count), file=f)
def dummy_sentence(self, length):
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
t[-1] = self.eos()
return t
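
# Illustrative usage sketch, not part of the original file: building a fairseq-style
# dictionary from raw tokens and finalizing it so its size is a multiple of 8.
def _example_fairseq_dictionary():
    d = Dictionary_fairseq()
    for tok in ['the', 'cat', 'sat', 'the', 'the']:
        d.add_symbol(tok)
    # finalize() sorts the non-special symbols by count and appends madeupwordNNNN
    # fillers until the vocabulary size is a multiple of padding_factor
    d.finalize(padding_factor=8)
    assert len(d) % 8 == 0
    assert d.index('the') == d.nspecial        # the most frequent real token comes first
    assert d.index('missing') == d.unk()       # unknown symbols map to <unk>
    return d
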
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.utils.data
class FairseqDataset(torch.utils.data.Dataset):
"""A dataset that provides helpers for batching."""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
raise NotImplementedError
def get_dummy_batch(self, num_tokens, max_positions):
"""Return a dummy batch with a given number of tokens."""
raise NotImplementedError
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
raise NotImplementedError
def ordered_indices(self, seed=None):
"""Ordered indices for batching."""
raise NotImplementedError
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
raise NotImplementedError
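
# Illustrative sketch, not part of the original file: a minimal FairseqDataset
# subclass wrapping a plain list of equal-length 1-D LongTensors. get_dummy_batch()
# is omitted here and would still raise NotImplementedError.
class _ListDataset(FairseqDataset):
    def __init__(self, tensors):
        self.tensors = tensors

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(self.tensors)

    def collater(self, samples):
        # naive collation: stack equal-length samples into one (batch, length) tensor
        return torch.stack(samples)

    def num_tokens(self, index):
        return self.tensors[index].numel()

    def ordered_indices(self, seed=None):
        # shortest-first ordering keeps padding waste low when batching
        return sorted(range(len(self)), key=self.num_tokens)

    def valid_size(self, index, max_positions):
        return self.num_tokens(index) <= max_positions
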
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
import torch
from . import data_utils, FairseqDataset
def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False, bsz_mult=8, seq_len_multiple=1):
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
eos_idx,
left_pad,
move_eos_to_beginning,
bsz_mult,
seq_len_multiple
)
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=left_pad_source)
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
prev_output_tokens = None
target = None
if samples[0].get('target', None) is not None:
target = merge('target', left_pad=left_pad_target)
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
'target',
left_pad=left_pad_target,
move_eos_to_beginning=True,
)
ntokens = sum(len(s['target']) for s in samples)
else:
ntokens = sum(len(s['source']) for s in samples)
return {
'id': id,
'ntokens': ntokens,
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
'prev_output_tokens': prev_output_tokens,
},
'target': target,
}
class LanguagePairDataset(FairseqDataset):
"""A pair of torch.utils.data.Datasets."""
def __init__(
self,
src,
src_sizes,
src_dict,
tgt=None,
tgt_sizes=None,
tgt_dict=None,
left_pad_source=True,
left_pad_target=False,
max_source_positions=256,
max_target_positions=256,
seq_len_multiple=1,
shuffle=True
):
if tgt_dict is not None:
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
self.src = src
self.tgt = tgt
self.src_sizes = np.array(src_sizes)
self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
self.src_dict = src_dict
self.tgt_dict = tgt_dict
self.left_pad_source = left_pad_source
self.left_pad_target = left_pad_target
self.max_source_positions = max_source_positions
self.max_target_positions = max_target_positions
self.seq_len_multiple = seq_len_multiple
self.shuffle = shuffle
print("| Sentences are being padded to multiples of: {}".format(self.seq_len_multiple))
def __getitem__(self, index):
return {
'id': index,
'source': self.src[index],
'target': self.tgt[index] if self.tgt is not None else None,
}
def __len__(self):
return len(self.src)
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
return collate(
samples,
pad_idx=self.src_dict.pad(),
eos_idx=self.src_dict.eos(),
left_pad_source=self.left_pad_source,
left_pad_target=self.left_pad_target,
bsz_mult=8,
seq_len_multiple=self.seq_len_multiple,
)
def get_dummy_batch(self, max_tokens_per_batch, max_positions, src_len=256, tgt_len=256):
max_source_positions, max_target_positions = self._get_max_positions(max_positions)
src_len, tgt_len = min(src_len, max_source_positions), min(tgt_len, max_target_positions)
n_seq_per_batch_based_on_longest_seq = max_tokens_per_batch // max(src_len, tgt_len)
return self.collater([
{
'id': i,
'source': self.src_dict.dummy_sentence(src_len),
'target': self.tgt_dict.dummy_sentence(tgt_len) if self.tgt_dict is not None else None,
}
for i in range(n_seq_per_batch_based_on_longest_seq)
])
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching.
Args:
index: points to the sequence pair
"""
n_tok_per_seq = max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
assert self.seq_len_multiple > 0, "Padding multiple has to be greater than 0"
n_tok_per_seq = (n_tok_per_seq + self.seq_len_multiple - 1) // self.seq_len_multiple * self.seq_len_multiple # Padded seq len, rounded up to next multiple
return n_tok_per_seq
def ordered_indices(self, seed=None):
"""Ordered indices for batching."""
if self.shuffle:
indices = np.random.RandomState(seed).permutation(len(self))
else:
indices = np.arange(len(self))
if self.tgt_sizes is not None:
indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
max_source_positions, max_target_positions = self._get_max_positions(max_positions)
return (
self.src_sizes[index] <= max_source_positions
and (self.tgt_sizes is None or self.tgt_sizes[index] <= max_target_positions)
)
def _get_max_positions(self, max_positions):
if max_positions is None:
return self.max_source_positions, self.max_target_positions
assert len(max_positions) == 2
max_src_pos, max_tgt_pos = max_positions
return min(self.max_source_positions, max_src_pos), min(self.max_target_positions, max_tgt_pos)
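
# Illustrative sketch, not part of the original file: how seq_len_multiple affects
# num_tokens(). Lengths are rounded up to the next multiple so token budgeting
# matches the padded sequences produced by collate(). src_dict is left as None
# purely because collater() is never called here.
def _example_num_tokens_rounding():
    src = [torch.LongTensor([4, 5, 6, 2])]     # one source sentence of true length 4
    ds = LanguagePairDataset(src, src_sizes=[4], src_dict=None, seq_len_multiple=8)
    assert ds.num_tokens(0) == 8               # 4 rounded up to the next multiple of 8
    return ds
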
def collater_isolated(samples, seq_len_multiple, left_pad_source, left_pad_target):
"""Merge a list of samples to form a mini-batch."""
return collate(
samples,
pad_idx=1,
eos_idx=2,
left_pad_source=left_pad_source,
left_pad_target=left_pad_target,
bsz_mult=8,
seq_len_multiple=seq_len_multiple,
)
def get_dummy_batch_isolated(max_tokens_per_batch, max_positions, seq_len_multiple):
'''Creates a dummy batch'''
max_source_positions, max_target_positions = max_positions[0], max_positions[1]
src_len, tgt_len = max_source_positions, max_target_positions
n_seq_per_batch_based_on_longest_seq = max_tokens_per_batch // max(src_len, tgt_len)
nspecial = 3
ntok_alloc = 33712
eos_id = 2
dummy_seq_src = torch.Tensor(src_len).uniform_(nspecial + 1, ntok_alloc).long()
dummy_seq_src[-1] = eos_id
dummy_seq_tgt = torch.Tensor(tgt_len).uniform_(nspecial + 1, ntok_alloc).long()
dummy_seq_tgt[-1] = eos_id
return collater_isolated([
{
'id': i,
'source': dummy_seq_src,
'target': dummy_seq_tgt
}
for i in range(n_seq_per_batch_based_on_longest_seq)
],
seq_len_multiple,
left_pad_source=True,
left_pad_target=False,
)
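
# Illustrative sketch, not part of the original file: inspecting the dummy batch built
# above. With a 1024-token budget and 32-token sequences, 1024 // 32 = 32 identical
# source/target pairs are collated into one batch.
def _example_dummy_batch():
    batch = get_dummy_batch_isolated(max_tokens_per_batch=1024,
                                     max_positions=(32, 32),
                                     seq_len_multiple=1)
    assert batch['ntokens'] == 32 * 32
    assert set(batch['net_input']) == {'src_tokens', 'src_lengths', 'prev_output_tokens'}
    return batch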