Commit 6cd76995 authored by thomwolf's avatar thomwolf
Browse files

update transfo xl example

parent 1320e4ec
...@@ -28,7 +28,7 @@ import math ...@@ -28,7 +28,7 @@ import math
import torch import torch
from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',
...@@ -79,7 +79,7 @@ def main(): ...@@ -79,7 +79,7 @@ def main():
device=device, ext_len=args.ext_len) device=device, ext_len=args.ext_len)
# Load a pre-trained model # Load a pre-trained model
model = TransfoXLModel.from_pretrained(args.model_name) model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
model = model.to(device) model = model.to(device)
logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format( logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
...@@ -139,4 +139,4 @@ def main(): ...@@ -139,4 +139,4 @@ def main():
logger.info('=' * 100) logger.info('=' * 100)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
\ No newline at end of file
...@@ -169,7 +169,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module): ...@@ -169,7 +169,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
if i == 0: if i == 0:
if target is not None: if target is not None:
logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
else: else:
out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]] out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment