"pytorch_transformers/tokenization_openai.py" did not exist on "850da1cc36f95175219420365ac3fb95b483ce8d"
transfo_xl_eval.py 5.42 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer XL model evaluation script.
    Adapted from https://github.com/kimiyoung/transformer-xl.
    In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py
19
20

    This script with default values evaluates a pretrained Transformer-XL on WikiText 103
21
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import logging
import time
import math

import torch

from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
                    help='pretrained model name')
parser.add_argument('--split', type=str, default='test',
                    choices=['all', 'valid', 'test'],
                    help='which split to evaluate')
parser.add_argument('--batch_size', type=int, default=10,
                    help='batch size')
parser.add_argument('--tgt_len', type=int, default=128,
                    help='number of tokens to predict')
parser.add_argument('--ext_len', type=int, default=0,
                    help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=1600,
                    help='length of the retained previous hidden states (the memory)')
parser.add_argument('--clamp_len', type=int, default=1000,
                    help='max positional embedding index')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--work_dir', type=str, required=True,
                    help='path to the work_dir')
parser.add_argument('--no_log', action='store_true',
                    help='do not log the eval result')
parser.add_argument('--same_length', action='store_true',
                    help='set same length attention with masking')
args = parser.parse_args()
assert args.ext_len >= 0, 'extended context length must be non-negative'

device = torch.device("cuda" if args.cuda else "cpu")

# Load a pre-processed dataset
# You can also build the corpus yourself using TransfoXLCorpus methods
# The pre-processing involves computing word frequencies to prepare the adaptive input and softmax
# and tokenizing the dataset
# The pre-processed corpus is a conversion (using the conversion script) of the original Transformer-XL corpus
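#
# A minimal commented sketch of building it yourself instead, assuming a
# local WikiText-103 checkout (the path is hypothetical; build_corpus is
# assumed to tokenize the splits and build the vocabulary in this version
# of the library):
#
#   corpus = TransfoXLCorpus()
#   corpus.build_corpus('./data/wikitext-103', 'wt103')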
corpus = TransfoXLCorpus.from_pretrained(args.model_name)
ntokens = len(corpus.vocab)

va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
    device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
    device=device, ext_len=args.ext_len)

# Load a pre-trained model
model = TransfoXLModel.from_pretrained(args.model_name)
model = model.to(device)

logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
    args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))

model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
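# clamp_len caps the relative-position index used by the positional
# embeddings (positions beyond it reuse the same embedding); same_length
# gives every token an equal attention span during evaluation.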
if args.clamp_len > 0:
    model.clamp_len = args.clamp_len
if args.same_length:
    model.same_length = True

###############################################################################
# Evaluation code
###############################################################################
def evaluate(eval_iter):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_len, total_loss = 0, 0.
    start_time = time.time()
    with torch.no_grad():
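        # mems caches the hidden states of previous segments (the
        # Transformer-XL recurrence); it starts empty and is refreshed by
        # each forward pass below.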
        mems = tuple()
        for idx, (data, target, seq_len) in enumerate(eval_iter):
            ret = model(data, target, *mems)
            loss, mems = ret
            loss = loss.mean()
            total_loss += seq_len * loss.item()
            total_len += seq_len
        total_time = time.time() - start_time
    logger.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
            total_time, 1000 * total_time / (idx+1)))
    return total_loss / total_len

# Run on test data.
if args.split == 'all':
    test_loss = evaluate(te_iter)
    valid_loss = evaluate(va_iter)
elif args.split == 'valid':
    valid_loss = evaluate(va_iter)
    test_loss = None
elif args.split == 'test':
    test_loss = evaluate(te_iter)
    valid_loss = None

def format_log(loss, split):
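    # Perplexity is the exponential of the average per-token loss.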
    log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
        split, loss, math.exp(loss))
    return log_str

log_str = ''
if valid_loss is not None:
    log_str += format_log(valid_loss, 'valid')
if test_loss is not None:
    log_str += format_log(test_loss, 'test')

logger.info('=' * 100)
logger.info(log_str)
logger.info('=' * 100)