"examples/seq2seq/test_tatoeba_conversion.py" did not exist on "aacac8f708a9f139f4cf976e76f40be23ef68b57"
Commit e8fe6b71 authored by thomwolf's avatar thomwolf
Browse files

adapting transfo tokenizer to transposed inputs

parent 884ca81d
......@@ -356,7 +356,10 @@ class LMOrderedIterator(object):
data = self.data[beg_idx:end_idx]
target = self.data[i+1:i+1+seq_len]
return data, target, seq_len
data_out = data.transpose(0, 1).contiguous().to(self.device)
target_out = target.transpose(0, 1).contiguous().to(self.device)
return data_out, target_out, seq_len
def get_fixlen_iter(self, start=0):
for i in range(start, self.data.size(0) - 1, self.bptt):
......@@ -440,10 +443,10 @@ class LMShuffledIterator(object):
if not valid_batch:
return
data = data.to(self.device)
target = target.to(self.device)
data_out = data.transpose(0, 1).contiguous().to(self.device)
target_out = target.transpose(0, 1).contiguous().to(self.device)
yield data, target, self.bptt
yield data_out, target_out, self.bptt
n_retain = min(data.size(0), self.ext_len)
if n_retain > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment