Commit ed47cb6c authored by thomwolf
Browse files

fixing transfo eval script

parent 97392643
...@@ -111,7 +111,7 @@ def evaluate(eval_iter): ...@@ -111,7 +111,7 @@ def evaluate(eval_iter):
mems = tuple() mems = tuple()
for idx, (data, target, seq_len) in enumerate(eval_iter): for idx, (data, target, seq_len) in enumerate(eval_iter):
ret = model(data, target, *mems) ret = model(data, target, *mems)
loss, mems = ret[0], ret[1:] loss, mems = ret
loss = loss.mean() loss = loss.mean()
total_loss += seq_len * loss.item() total_loss += seq_len * loss.item()
total_len += seq_len total_len += seq_len
......
...@@ -1215,7 +1215,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -1215,7 +1215,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
# So, have to initialize size(0) mems inside the model forward. # So, have to initialize size(0) mems inside the model forward.
# Moreover, have to return new_mems to allow nn.DataParallel to piece # Moreover, have to return new_mems to allow nn.DataParallel to piece
# them together. # them together.
if not mems: mems = self.init_mems(data) if not mems:
mems = self.init_mems(data)
hidden, new_mems = self._forward(data, mems=mems) hidden, new_mems = self._forward(data, mems=mems)
if target is None: if target is None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment