Commit cfd2a3a0 authored by Alexei Baevski, committed by Myle Ott

core changes to support latte collab

parent fbe8ce65
@@ -42,11 +42,14 @@ class AdaptiveLoss(FairseqCriterion):
adaptive_softmax = model.decoder.adaptive_softmax
net_output = model(**sample['net_input'])
target = model.get_targets(sample, net_output).view(-1)
orig_target = model.get_targets(sample, net_output)
bsz = target.size(0)
nsentences = orig_target.size(0)
orig_target = orig_target.view(-1)
logits, target = adaptive_softmax(net_output[0], target)
bsz = orig_target.size(0)
logits, target = adaptive_softmax(net_output[0], orig_target)
assert len(target) == len(logits)
loss = net_output[0].new(1 if reduce else bsz).zero_()
@@ -57,11 +60,13 @@ class AdaptiveLoss(FairseqCriterion):
loss += F.cross_entropy(logits[i], target[i], size_average=False, ignore_index=self.padding_idx,
reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
orig = utils.strip_pad(orig_target, self.padding_idx)
ntokens = orig.numel()
sample_size = sample['target'].size(0) if self.args.sentence_avg else ntokens
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
return loss, sample_size, logging_output
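The hunk above stops reading the token count from sample['ntokens'] and instead counts the non-padding entries of the flattened target, so loss normalization and the logged ntokens match what the adaptive softmax actually scored. A minimal sketch of that counting step, not part of the commit (strip_pad is written out inline; fairseq's utils.strip_pad performs the same filtering):

import torch

def strip_pad(tensor, pad_idx):
    # keep only the entries that are not the padding symbol
    return tensor[tensor.ne(pad_idx)]

pad_idx = 1                                     # assumed padding index, for illustration
orig_target = torch.tensor([5, 9, 2, 1, 1])     # flattened targets, 1 = pad
ntokens = strip_pad(orig_target, pad_idx).numel()
print(ntokens)                                  # 3: padding no longer inflates the count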
@@ -5,7 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from .dictionary import Dictionary
from .dictionary import Dictionary, TruncatedDictionary
from .fairseq_dataset import FairseqDataset
from .indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
from .language_pair_dataset import LanguagePairDataset
@@ -199,3 +199,20 @@ class Dictionary(object):
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
t[-1] = self.eos()
return t
class TruncatedDictionary(object):
def __init__(self, wrapped_dict, length):
self.__class__ = type(wrapped_dict.__class__.__name__,
(self.__class__, wrapped_dict.__class__), {})
self.__dict__ = wrapped_dict.__dict__
self.wrapped_dict = wrapped_dict
self.length = min(len(self.wrapped_dict), length)
def __len__(self):
return self.length
def __getitem__(self, i):
if i < self.length:
return self.wrapped_dict[i]
return self.wrapped_dict.unk()
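TruncatedDictionary above exposes only the first length entries of the wrapped dictionary and maps every higher index to the unknown symbol; the __class__/__dict__ reassignment appears intended to let the wrapper still pass isinstance checks against the wrapped dictionary's class while sharing its attributes. A small usage sketch, not part of the commit, with a made-up stand-in dictionary:

class ToyDictionary:
    def __init__(self, symbols):
        self.symbols = symbols
    def __len__(self):
        return len(self.symbols)
    def __getitem__(self, i):
        return self.symbols[i]
    def unk(self):
        return '<unk>'

d = ToyDictionary(['<pad>', '</s>', '<unk>', 'the', 'cat', 'sat'])
t = TruncatedDictionary(d, 4)   # keep indices 0..3 of the full dictionary
print(len(t))                   # 4
print(t[3])                     # 'the'
print(t[5])                     # '<unk>': out-of-range ids collapse to unk()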
@@ -61,6 +61,7 @@ def collate(
'src_lengths': src_lengths,
},
'target': target,
'nsentences': samples[0]['source'].size(0),
}
if prev_output_tokens is not None:
batch['net_input']['prev_output_tokens'] = prev_output_tokens
@@ -9,27 +9,39 @@ import numpy as np
import torch
from . import data_utils, FairseqDataset
from typing import List
def collate(samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}
def merge(key):
def merge(key, is_list=False):
if is_list:
res = []
for i in range(len(samples[0][key])):
res.append(data_utils.collate_tokens(
[s[key][i] for s in samples], pad_idx, eos_idx, left_pad=False,
))
return res
else:
return data_utils.collate_tokens(
[s[key] for s in samples], pad_idx, eos_idx, left_pad=False,
)
is_target_list = isinstance(samples[0]['target'], list)
return {
'id': torch.LongTensor([s['id'] for s in samples]),
'ntokens': sum(len(s['target']) for s in samples),
'ntokens': sum(len(s['source']) for s in samples),
'net_input': {
'src_tokens': merge('source'),
'src_lengths': torch.LongTensor([
s['source'].numel() for s in samples
]),
},
'target': merge('target'),
'target': merge('target', is_target_list),
'nsentences': samples[0]['source'].size(0),
}
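When a sample's 'target' is a list (one tensor per target type), merge('target', is_list=True) above pads each element of the list separately, so the batch carries one padded target tensor per target type. A rough, self-contained sketch of that behaviour, not part of the commit; pad_to_max is a hypothetical stand-in for data_utils.collate_tokens with left_pad=False:

import torch

def pad_to_max(tensors, pad_idx):
    # right-pad 1-D tensors to the length of the longest one
    max_len = max(t.numel() for t in tensors)
    out = tensors[0].new_full((len(tensors), max_len), pad_idx)
    for i, t in enumerate(tensors):
        out[i, :t.numel()] = t
    return out

samples = [
    {'target': [torch.tensor([4, 5, 2]), torch.tensor([9, 4, 5])]},
    {'target': [torch.tensor([7, 2]), torch.tensor([3, 7])]},
]
batched = [
    pad_to_max([s['target'][i] for s in samples], pad_idx=1)
    for i in range(len(samples[0]['target']))
]
print([b.shape for b in batched])   # [torch.Size([2, 3]), torch.Size([2, 3])]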
@@ -45,19 +57,75 @@ class MonolingualDataset(FairseqDataset):
Default: ``True``
"""
def __init__(self, dataset, sizes, vocab, shuffle=True):
def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle,
targets=None):
self.dataset = dataset
self.sizes = np.array(sizes)
self.vocab = vocab
self.vocab = src_vocab
self.tgt_vocab = tgt_vocab
self.add_eos_for_other_targets = add_eos_for_other_targets
self.shuffle = shuffle
assert targets is None or all(
t in {'self', 'future', 'past'} for t in targets), "targets must be none or one of 'self', 'future', 'past'"
if targets is not None and len(targets) == 0:
targets = None
self.targets = targets
def __getitem__(self, index):
source, target = self.dataset[index]
source, future_target, past_target = self.dataset[index]
source, target = self._make_source_target(source, future_target, past_target)
return {'id': index, 'source': source, 'target': target}
def __len__(self):
return len(self.dataset)
def _make_source_target(self, source, future_target, past_target):
if self.targets is not None:
target = []
if self.add_eos_for_other_targets and (('self' in self.targets) or ('past' in self.targets)) \
and source[-1] != self.vocab.eos():
# append eos at the end of source
source = torch.cat([source, source.new([self.vocab.eos()])])
if 'future' in self.targets:
future_target = torch.cat([future_target, future_target.new([self.vocab.pad()])])
if 'past' in self.targets:
# first token is before the start of sentence which is only used in "none" break mode when
# add_eos_for_other_targets is False
past_target = torch.cat([past_target.new([self.vocab.pad()]), past_target[1:], source[-2, None]])
for t in self.targets:
if t == 'self':
target.append(source)
elif t == 'future':
target.append(future_target)
elif t == 'past':
target.append(past_target)
else:
raise Exception('invalid target ' + t)
if len(target) == 1:
target = target[0]
else:
target = future_target
return source, self._filter_vocab(target)
def _filter_vocab(self, target):
if len(self.tgt_vocab) != len(self.vocab):
def _filter(target):
mask = target.ge(len(self.tgt_vocab))
if mask.any():
target[mask] = self.tgt_vocab.unk()
return target
if isinstance(target, list):
return [_filter(t) for t in target]
return _filter(target)
return target
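_filter_vocab above is what makes a truncated output dictionary usable: any target id that falls outside the (smaller) output vocabulary is replaced by its unknown index, so the output softmax never has to produce an id it cannot represent. A minimal sketch of the masking, not part of the commit, with made-up sizes and indices:

import torch

tgt_vocab_size = 1000                        # assumed truncated output vocabulary size
unk = 3                                      # assumed unk index
target = torch.tensor([5, 1500, 42, 2300])   # ids drawn from the full input vocabulary

mask = target.ge(tgt_vocab_size)             # ids the output dictionary cannot produce
target[mask] = unk
print(target.tolist())                       # [5, 3, 42, 3]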
def collater(self, samples):
"""Merge a list of samples to form a mini-batch.
@@ -86,8 +154,10 @@ class MonolingualDataset(FairseqDataset):
if isinstance(max_positions, float) or isinstance(max_positions, int):
tgt_len = min(tgt_len, max_positions)
bsz = num_tokens // tgt_len
target = self.vocab.dummy_sentence(tgt_len + 1)
source, target = target[:-1], target[1:]
target = self.vocab.dummy_sentence(tgt_len + 2)
source, past_target, future_target = target[1:-1], target[2:], target[:-2]
source, target = self._make_source_target(source, past_target, future_target)
return self.collater([
{'id': i, 'source': source, 'target': target}
for i in range(bsz)
@@ -29,11 +29,13 @@ class TokenBlockDataset(torch.utils.data.Dataset):
include_targets: return next tokens as targets
"""
def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False):
def __init__(self, tokens, sizes, block_size, pad, eos, break_mode=None, include_targets=False):
super().__init__()
self.tokens = tokens
self.total_size = len(tokens)
self.pad = pad
self.eos = eos
self.include_targets = include_targets
self.slice_indices = []
@@ -81,12 +83,18 @@ class TokenBlockDataset(torch.utils.data.Dataset):
if self.include_targets:
# target is the sentence, for source, rotate item one token to the left (would start with eos)
# past target is rotated to the left by 2 (padded if its first)
if s == 0:
source = np.concatenate([self.tokens[-1:], self.tokens[0:e - 1]])
source = np.concatenate([[self.eos], self.tokens[0:e - 1]])
past_target = np.concatenate([[self.pad, self.eos], self.tokens[0:e - 2]])
else:
source = self.tokens[s - 1:e - 1]
if s == 1:
past_target = np.concatenate([[self.eos], self.tokens[0:e - 2]])
else:
past_target = self.tokens[s - 2:e - 2]
return torch.LongTensor(source), item
return torch.LongTensor(source), item, torch.LongTensor(past_target)
return item
def __len__(self):
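With the pad and eos symbols passed in, each block can now also return a 'past' target: the token sequence rotated one step further to the right than the source, padded where it would reach before the start of the data. A toy check of the three arrays produced for the first block (s == 0), not part of the commit, with made-up token ids:

import torch

pad, eos = 1, 2                             # assumed special symbol ids
tokens = torch.tensor([10, 11, 12, eos])    # w1 w2 w3 </s>

target = tokens                                                  # item: predict the next token
source = torch.cat([tokens.new([eos]), tokens[:-1]])             # </s> w1 w2 w3
past_target = torch.cat([tokens.new([pad, eos]), tokens[:-2]])   # <pad> </s> w1 w2

print(source.tolist())       # [2, 10, 11, 12]
print(target.tolist())       # [10, 11, 12, 2]
print(past_target.tolist())  # [1, 2, 10, 11]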
@@ -65,6 +65,9 @@ class BaseFairseqModel(nn.Module):
def upgrade_state_dict(self, state_dict):
"""Upgrade old state dicts to work with newer code."""
self.upgrade_state_dict_named(state_dict, '')
def upgrade_state_dict_named(self, state_dict, name):
assert state_dict is not None
def do_upgrade(m, prefix):
@@ -79,7 +82,7 @@ class BaseFairseqModel(nn.Module):
c.upgrade_state_dict(state_dict)
do_upgrade(c, name)
do_upgrade(self, '')
do_upgrade(self, name)
def make_generation_fast_(self, **kwargs):
"""Optimize model for faster generation."""
@@ -196,3 +199,7 @@ class FairseqLanguageModel(BaseFairseqModel):
def max_positions(self):
"""Maximum length supported by the model."""
return self.decoder.max_positions()
@property
def supported_targets(self):
return {'future'}
@@ -213,7 +213,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
else:
embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())
decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True, final_norm=False)
decoder = TransformerDecoder(args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False)
return TransformerLanguageModel(decoder)
@@ -442,6 +442,8 @@ class TransformerDecoder(FairseqIncrementalDecoder):
x = x.transpose(0, 1)
attn = None
inner_states = [x]
# decoder layers
for layer in self.layers:
x, attn = layer(
@@ -449,7 +451,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
encoder_out['encoder_out'] if encoder_out is not None else None,
encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
incremental_state,
self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,
)
inner_states.append(x)
if self.normalize:
x = self.layer_norm(x)
@@ -467,7 +471,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
else:
x = F.linear(x, self.embed_out)
return x, attn
return x, {'attn': attn, 'inner_states': inner_states}
def max_positions(self):
"""Maximum output length supported by the decoder."""
@@ -475,6 +479,14 @@ class TransformerDecoder(FairseqIncrementalDecoder):
return self.max_target_positions
return min(self.max_target_positions, self.embed_positions.max_positions())
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
if self._future_mask.size(0) < dim:
self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
return self._future_mask[:dim, :dim]
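buffered_future_mask above caches an upper-triangular matrix of -inf which, added to the attention scores, removes all probability mass on future positions after the softmax; the cache is rebuilt only when the device changes or a longer sequence arrives. A standalone sketch of the mask itself, not part of the commit (plain torch.triu in place of utils.fill_with_neg_inf):

import torch

dim = 4
future_mask = torch.triu(torch.full((dim, dim), float('-inf')), diagonal=1)
print(future_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])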
def upgrade_state_dict(self, state_dict):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
@@ -615,7 +627,8 @@ class TransformerDecoderLayer(nn.Module):
self.final_layer_norm = LayerNorm(self.embed_dim)
self.need_attn = True
def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, self_attn_mask=None,
self_attn_padding_mask=None):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -631,9 +644,10 @@ class TransformerDecoderLayer(nn.Module):
query=x,
key=x,
value=x,
mask_future_timesteps=True,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
@@ -728,7 +742,6 @@ def base_lm_architecture(args):
# The model training is not stable without this
args.decoder_normalize_before = True
@register_model_architecture('transformer_lm', 'transformer_lm_big')
def transformer_lm_big(args):
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
@@ -740,7 +753,7 @@ def transformer_lm_big(args):
@register_model_architecture('transformer_lm', 'transformer_lm_wiki103')
def transformer_lm_wiki103(args):
args.dropout = getattr(args, 'dropout', 0.3)
base_lm_architecture(args)
transformer_lm_big(args)
@register_model_architecture('transformer_lm', 'transformer_lm_gbw')
@@ -18,23 +18,31 @@ class MultiheadAttention(nn.Module):
See "Attention Is All You Need" for more details.
"""
def __init__(self, embed_dim, num_heads, dropout=0., bias=True):
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim**-0.5
self._mask = None
self.scaling = self.head_dim ** -0.5
self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
if bias:
self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
else:
self.register_parameter('in_proj_bias', None)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
if add_bias_kv:
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self.reset_parameters()
def reset_parameters(self):
@@ -43,15 +51,18 @@ class MultiheadAttention(nn.Module):
if self.in_proj_bias is not None:
nn.init.constant_(self.in_proj_bias, 0.)
nn.init.constant_(self.out_proj.bias, 0.)
if self.bias_k is not None:
nn.init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
nn.init.xavier_normal_(self.bias_v)
def forward(self, query, key, value, mask_future_timesteps=False,
key_padding_mask=None, incremental_state=None,
need_weights=True, static_kv=False):
def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
need_weights=True, static_kv=False, attn_mask=None):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
query, key and value. Future timesteps can be masked with the
`mask_future_timesteps` argument. Padding elements can be excluded from
query, key and value. Timesteps can be masked by supplying a T x T mask in the
`attn_mask` argument. Padding elements can be excluded from
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
@@ -103,24 +114,40 @@ class MultiheadAttention(nn.Module):
saved_state['prev_value'] = v
self._set_input_buffer(incremental_state, saved_state)
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
src_len = k.size(0)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
q = q.contiguous().view(tgt_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
k = k.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
v = v.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
k = k.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
v = v.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
if self.add_zero_attn:
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat([key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
# only apply masking at training time (when incremental state is None)
if mask_future_timesteps and incremental_state is None:
assert query.size() == key.size(), \
'mask_future_timesteps only applies to self-attention'
attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
if attn_mask is not None:
attn_weights += attn_mask.unsqueeze(0)
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
@@ -129,6 +156,7 @@ class MultiheadAttention(nn.Module):
float('-inf'),
).type_as(attn_weights) # FP16 support: cast to float and back
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
@@ -156,10 +184,10 @@ class MultiheadAttention(nn.Module):
return self._in_proj(query, end=self.embed_dim)
def in_proj_k(self, key):
return self._in_proj(key, start=self.embed_dim, end=2*self.embed_dim)
return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
def in_proj_v(self, value):
return self._in_proj(value, start=2*self.embed_dim)
return self._in_proj(value, start=2 * self.embed_dim)
def _in_proj(self, input, start=0, end=None):
weight = self.in_proj_weight
@@ -169,14 +197,6 @@ class MultiheadAttention(nn.Module):
bias = bias[start:end]
return F.linear(input, weight, bias)
def buffered_mask(self, tensor):
dim = tensor.size(-1)
if self._mask is None:
self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
if self._mask.size(0) < dim:
self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
return self._mask[:dim, :dim]
def reorder_incremental_state(self, incremental_state, new_order):
"""Reorder buffered internal state (for incremental generation)."""
input_buffer = self._get_input_buffer(incremental_state)
@@ -106,8 +106,9 @@ class FP16Optimizer(optim.FairseqOptimizer):
for p in self.params:
if not p.requires_grad:
continue
numel = p.grad.data.numel()
self.fp32_params.grad.data[offset:offset+numel].copy_(p.grad.data.view(-1))
grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
numel = grad_data.numel()
self.fp32_params.grad.data[offset:offset+numel].copy_(grad_data.view(-1))
offset += numel
# correct for dynamic loss scaler
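The guarded line above handles parameters that received no gradient at all in the backward pass (p.grad is None), substituting zeros so the flattened FP32 gradient buffer keeps its fixed layout; this matters once a model has several prediction heads and not every parameter is used in every step. A tiny illustration of the substitution, not part of the commit:

import torch

p = torch.nn.Parameter(torch.randn(3, 2))   # a parameter never touched by the loss
assert p.grad is None

grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
print(grad_data.shape, float(grad_data.sum()))   # torch.Size([3, 2]) 0.0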
@@ -507,7 +507,11 @@ class SequenceGenerator(object):
decoder_out = list(model.decoder(tokens, encoder_out))
decoder_out[0] = decoder_out[0][:, -1, :]
attn = decoder_out[1]
if type(attn) is dict:
attn = attn['attn']
if attn is not None:
if type(attn) is dict:
attn = attn['attn']
attn = attn[:, -1, :]
probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
return probs, attn
@@ -13,7 +13,7 @@ from torch.utils.data import ConcatDataset
from fairseq.data import (
Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
MonolingualDataset, TokenBlockDataset,
MonolingualDataset, TokenBlockDataset, TruncatedDictionary
)
from . import FairseqTask, register_task
@@ -25,7 +25,14 @@ class LanguageModelingTask(FairseqTask):
Train a language model.
Args:
dictionary (Dictionary): the dictionary for the language model
dictionary (Dictionary): the dictionary for the input of the language model
output_dictionary (Dictionary): the dictionary for the output of the language model.
In most cases it will be the same as dictionary, but could possibly be a more limited
version of the dictionary (if --output-dictionary-size is used).
targets (List[str]): list of the target types that the language model should predict.
Can be one of "self", "future", and "past". Defaults to "future".
.. note::
@@ -55,10 +62,23 @@ class LanguageModelingTask(FairseqTask):
help='max number of tokens per sample for LM dataset')
parser.add_argument('--raw-text', default=False, action='store_true',
help='load raw text dataset')
def __init__(self, args, dictionary):
parser.add_argument('--output-dictionary-size', default=-1, type=int,
help='limit the size of output dictionary')
parser.add_argument('--self-target', action='store_true',
help='include self target')
parser.add_argument('--future-target', action='store_true',
help='include future target')
parser.add_argument('--past-target', action='store_true',
help='include past target')
def __init__(self, args, dictionary, output_dictionary, targets=None):
super().__init__(args)
self.dictionary = dictionary
self.output_dictionary = output_dictionary
if targets is None:
targets = ['future']
self.targets = targets
@classmethod
def setup_task(cls, args, **kwargs):
@@ -69,7 +89,36 @@ class LanguageModelingTask(FairseqTask):
"""
dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(dictionary)))
return cls(args, dictionary)
output_dictionary = dictionary
if args.output_dictionary_size >= 0:
output_dictionary = TruncatedDictionary(dictionary, args.output_dictionary_size)
# upgrade old checkpoints
if hasattr(args, 'exclude_self_target'):
args.self_target = not args.exclude_self_target
targets = []
if args.self_target:
targets.append('self')
if args.future_target:
targets.append('future')
if args.past_target:
targets.append('past')
if len(targets) == 0:
# standard language modeling
targets = ['future']
return cls(args, dictionary, output_dictionary, targets=targets)
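setup_task above assembles the target list from the new flags and optionally truncates the output dictionary, falling back to plain next-token ('future') prediction when nothing is requested. A hedged example invocation; the data path and hyperparameters are placeholders, only the flags themselves come from this commit, and --self-target / --past-target additionally require a model whose supported_targets includes them (the base FairseqLanguageModel above declares only 'future'):

python train.py /path/to/data-bin \
    --task language_modeling --arch transformer_lm_big \
    --future-target --output-dictionary-size 30000 \
    --sample-break-mode eos --tokens-per-sample 512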
def build_model(self, args):
model = super().build_model(args)
for target in self.targets:
if target not in model.supported_targets:
raise ValueError('Unsupported language modeling target: {}'.format(target))
return model
def load_dataset(self, split, combine=False):
"""Load a given dataset split.
@@ -98,8 +147,8 @@ class LanguageModelingTask(FairseqTask):
loaded_datasets.append(
TokenBlockDataset(
tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
include_targets=True
tokens, ds.sizes, self.args.tokens_per_sample, pad=self.dictionary.pad(), eos=self.dictionary.eos(),
break_mode=self.args.sample_break_mode, include_targets=True,
))
print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
@@ -114,10 +163,16 @@ class LanguageModelingTask(FairseqTask):
dataset = ConcatDataset(loaded_datasets)
sizes = np.concatenate([ds.sizes for ds in loaded_datasets])
self.datasets[split] = MonolingualDataset(dataset, sizes, self.dictionary, shuffle=False)
add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'
self.datasets[split] = MonolingualDataset(
dataset, sizes, self.dictionary, self.output_dictionary,
add_eos_for_other_targets=add_eos_for_other_targets, shuffle=False,
targets=self.targets,
)
@property
def target_dictionary(self):
"""Return the :class:`~fairseq.data.Dictionary` for the language
model."""
return self.dictionary
return self.output_dictionary
@@ -40,7 +40,7 @@ def mock_dict():
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
tokens = torch.LongTensor(list(range(epoch_size)))
tokens_ds = data.TokenBlockDataset(tokens, [len(tokens)], 1, include_targets=False)
tokens_ds = data.TokenBlockDataset(tokens, sizes=[len(tokens)], block_size=1, pad=0, eos=1, include_targets=False)
trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
epoch_itr = data.EpochBatchIterator(