Commit e8fe6b71 authored by thomwolf's avatar thomwolf
Browse files

adapting transfo tokenizer to transposed inputs

parent 884ca81d
...@@ -356,7 +356,10 @@ class LMOrderedIterator(object): ...@@ -356,7 +356,10 @@ class LMOrderedIterator(object):
data = self.data[beg_idx:end_idx] data = self.data[beg_idx:end_idx]
target = self.data[i+1:i+1+seq_len] target = self.data[i+1:i+1+seq_len]
return data, target, seq_len data_out = data.transpose(0, 1).contiguous().to(self.device)
target_out = target.transpose(0, 1).contiguous().to(self.device)
return data_out, target_out, seq_len
def get_fixlen_iter(self, start=0): def get_fixlen_iter(self, start=0):
for i in range(start, self.data.size(0) - 1, self.bptt): for i in range(start, self.data.size(0) - 1, self.bptt):
...@@ -440,10 +443,10 @@ class LMShuffledIterator(object): ...@@ -440,10 +443,10 @@ class LMShuffledIterator(object):
if not valid_batch: if not valid_batch:
return return
data = data.to(self.device) data_out = data.transpose(0, 1).contiguous().to(self.device)
target = target.to(self.device) target_out = target.transpose(0, 1).contiguous().to(self.device)
yield data, target, self.bptt yield data_out, target_out, self.bptt
n_retain = min(data.size(0), self.ext_len) n_retain = min(data.size(0), self.ext_len)
if n_retain > 0: if n_retain > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment