run_transfo_xl.py 6.19 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
# coding=utf-8
thomwolf's avatar
thomwolf committed
2
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
thomwolf's avatar
thomwolf committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer XL model evaluation script.
    Adapted from https://github.com/kimiyoung/transformer-xl.
    In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py

    This script with default values evaluates a pretrained Transformer-XL on WikiText 103
"""
Aymeric Augustin's avatar
Aymeric Augustin committed
22

thomwolf's avatar
thomwolf committed
23
24
25
26

import argparse
import logging
import math
Aymeric Augustin's avatar
Aymeric Augustin committed
27
import time
thomwolf's avatar
thomwolf committed
28
29
30

import torch

Aymeric Augustin's avatar
Aymeric Augustin committed
31
32
from transformers import TransfoXLCorpus, TransfoXLLMHeadModel, TransfoXLTokenizer

thomwolf's avatar
thomwolf committed
33

34
35
36
# Configure logging once, at import time: every record is emitted with a
# timestamp, its level and the name of the logger that produced it.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

# Module-level logger used by the rest of this script.
logger = logging.getLogger(__name__)

39

thomwolf's avatar
thomwolf committed
40
def main():
    """Evaluate a pretrained Transformer-XL language model from the command line.

    Parses the command-line options, loads the pre-processed corpus and the
    pretrained model named by ``--model_name``, evaluates the requested
    split(s) and logs the average loss of each one together with
    ``exp(loss)`` (reported as "ppl").

    Invalid option values terminate the program through ``parser.error``.
    """
    parser = argparse.ArgumentParser(description="PyTorch Transformer Language Model")
    parser.add_argument("--model_name", type=str, default="transfo-xl-wt103", help="pretrained model name")
    parser.add_argument(
        "--split", type=str, default="test", choices=["all", "valid", "test"], help="which split to evaluate"
    )
    parser.add_argument("--batch_size", type=int, default=10, help="batch size")
    parser.add_argument("--tgt_len", type=int, default=128, help="number of tokens to predict")
    parser.add_argument("--ext_len", type=int, default=0, help="length of the extended context")
    parser.add_argument("--mem_len", type=int, default=1600, help="length of the retained previous heads")
    parser.add_argument("--clamp_len", type=int, default=1000, help="max positional embedding index")
    parser.add_argument("--no_cuda", action="store_true", help="Do not use CUDA even though CUDA is available")
    # NOTE(review): --work_dir and --no_log are parsed but never read in this function.
    parser.add_argument("--work_dir", type=str, required=True, help="path to the work_dir")
    parser.add_argument("--no_log", action="store_true", help="do not log the eval result")
    parser.add_argument("--same_length", action="store_true", help="set same length attention with masking")
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()

    # Validate explicitly instead of with ``assert``, which is stripped under ``python -O``.
    if args.ext_len < 0:
        parser.error("extended context length must be non-negative")

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        # The option is parsed as a string (so that "" can mean "disabled"),
        # but the debugger address needs a numeric port.
        try:
            server_port = int(args.server_port)
        except ValueError:
            parser.error("--server_port must be an integer")

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    logger.info("device: %s", device)

    # Load a pre-processed dataset
    # You can also build the corpus yourself using TransfoXLCorpus methods
    # The pre-processing involves computing word frequencies to prepare the Adaptive input and SoftMax
    # and tokenizing the dataset
    # The pre-processed corpus is a conversion (using the conversion script)
    # NOTE(review): the tokenizer is not used below; the load is kept so that
    # any side effect of ``from_pretrained`` is preserved - confirm before removing.
    tokenizer = TransfoXLTokenizer.from_pretrained(args.model_name)  # noqa: F841
    corpus = TransfoXLCorpus.from_pretrained(args.model_name)

    va_iter = corpus.get_iterator("valid", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)
    te_iter = corpus.get_iterator("test", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)

    # Load a pre-trained model
    model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
    model = model.to(device)

    logger.info(
        "Evaluating with bsz %s tgt_len %s ext_len %s mem_len %s clamp_len %s",
        args.batch_size,
        args.tgt_len,
        args.ext_len,
        args.mem_len,
        args.clamp_len,
    )

    model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    if args.clamp_len > 0:
        model.clamp_len = args.clamp_len
    if args.same_length:
        model.same_length = True

    ###############################################################################
    # Evaluation code
    ###############################################################################
    def evaluate(eval_iter):
        """Return the average loss over ``eval_iter``, weighted by segment length.

        Each item of ``eval_iter`` is unpacked as ``(data, target, seq_len)``;
        the memories returned by the model for one segment are fed back in
        with the next one.

        Raises:
            ValueError: if ``eval_iter`` yields nothing to score.
        """
        # Turn on evaluation mode which disables dropout.
        model.eval()
        total_len, total_loss = 0, 0.0
        num_segments = 0
        start_time = time.time()
        with torch.no_grad():
            mems = None
            for data, target, seq_len in eval_iter:
                loss, _, mems = model(data, lm_labels=target, mems=mems)
                loss = loss.mean()
                total_loss += seq_len * loss.item()
                total_len += seq_len
                num_segments += 1
        total_time = time.time() - start_time

        # An empty iterator would otherwise surface as a division by zero.
        if num_segments == 0 or total_len == 0:
            raise ValueError("evaluation iterator produced no tokens to score")

        logger.info("Time : %.2fs, %.2fms/segment", total_time, 1000 * total_time / num_segments)
        return total_loss / total_len

    def format_log(loss, split):
        """Format the loss of ``split`` and its exponential for the summary line."""
        return "| {0} loss {1:5.2f} | {0} ppl {2:9.3f} ".format(split, loss, math.exp(loss))

    # Run on the requested data; with "all" the test split is evaluated first.
    test_loss = evaluate(te_iter) if args.split in ("all", "test") else None
    valid_loss = evaluate(va_iter) if args.split in ("all", "valid") else None

    # The summary always lists the validation result before the test result.
    log_str = ""
    if valid_loss is not None:
        log_str += format_log(valid_loss, "valid")
    if test_loss is not None:
        log_str += format_log(test_loss, "test")

    logger.info("=" * 100)
    logger.info(log_str)
    logger.info("=" * 100)

thomwolf's avatar
thomwolf committed
144

145
# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()