# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer XL model evaluation script.
    Adapted from https://github.com/kimiyoung/transformer-xl.
    In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py

    This script with default values evaluates a pretrained Transformer-XL on WikiText 103
"""

import argparse
import logging
import math
import time

import torch

from transformers import TransfoXLCorpus, TransfoXLLMHeadModel


# Module-wide logging configuration: timestamped records at INFO level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)

def main():
    """Evaluate a pretrained Transformer-XL on the valid/test split(s) of its corpus.

    Parses command-line arguments, loads the pre-processed corpus and pretrained
    model selected by ``--model_name``, runs evaluation on the split(s) chosen by
    ``--split``, and logs the resulting loss and perplexity.
    """
    parser = argparse.ArgumentParser(description="PyTorch Transformer Language Model")
    parser.add_argument("--model_name", type=str, default="transfo-xl-wt103", help="pretrained model name")
    parser.add_argument(
        "--split", type=str, default="test", choices=["all", "valid", "test"], help="which split to evaluate"
    )
    parser.add_argument("--batch_size", type=int, default=10, help="batch size")
    parser.add_argument("--tgt_len", type=int, default=128, help="number of tokens to predict")
    parser.add_argument("--ext_len", type=int, default=0, help="length of the extended context")
    parser.add_argument("--mem_len", type=int, default=1600, help="length of the retained previous heads")
    parser.add_argument("--clamp_len", type=int, default=1000, help="max positional embedding index")
    parser.add_argument("--no_cuda", action="store_true", help="Do not use CUDA even though CUDA is available")
    parser.add_argument("--work_dir", type=str, required=True, help="path to the work_dir")
    parser.add_argument("--no_log", action="store_true", help="do not log the eval result")
    parser.add_argument("--same_length", action="store_true", help="set same length attention with masking")
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if args.ext_len < 0:
        raise ValueError("extended context length must be non-negative")

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    logger.info("device: %s", device)

    # Load a pre-processed dataset
    # You can also build the corpus yourself using TransfoXLCorpus methods
    # The pre-processing involves computing word frequencies to prepare the Adaptive input and SoftMax
    # and tokenizing the dataset
    # The pre-processed corpus is a conversion (using the conversion script)
    corpus = TransfoXLCorpus.from_pretrained(args.model_name)

    va_iter = corpus.get_iterator("valid", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)
    te_iter = corpus.get_iterator("test", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)

    # Load a pre-trained model
    model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
    model = model.to(device)

    logger.info(
        "Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}".format(
            args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len
        )
    )

    model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    if args.clamp_len > 0:
        model.clamp_len = args.clamp_len
    if args.same_length:
        model.same_length = True

    ###############################################################################
    # Evaluation code
    ###############################################################################
    def evaluate(eval_iter):
        """Run the model over `eval_iter` and return the token-averaged loss."""
        # Turn on evaluation mode which disables dropout.
        model.eval()
        total_len, total_loss = 0, 0.0
        n_segments = 0  # counted explicitly so an empty iterator fails cleanly
        start_time = time.time()
        with torch.no_grad():
            mems = None  # Transformer-XL memory, carried across segments
            for data, target, seq_len in eval_iter:
                ret = model(data, lm_labels=target, mems=mems)
                loss, _, mems = ret
                loss = loss.mean()
                total_loss += seq_len * loss.item()
                total_len += seq_len
                n_segments += 1
            total_time = time.time() - start_time
        # The original referenced the loop index after the loop; with an empty
        # iterator that raised NameError (and then ZeroDivisionError below).
        if n_segments == 0:
            raise ValueError("evaluation iterator yielded no segments")
        logger.info("Time : {:.2f}s, {:.2f}ms/segment".format(total_time, 1000 * total_time / n_segments))
        return total_loss / total_len

    # Run on the requested split(s); unused splits are left as None.
    if args.split == "all":
        test_loss = evaluate(te_iter)
        valid_loss = evaluate(va_iter)
    elif args.split == "valid":
        valid_loss = evaluate(va_iter)
        test_loss = None
    elif args.split == "test":
        test_loss = evaluate(te_iter)
        valid_loss = None

    def format_log(loss, split):
        """Format one split's loss and perplexity for the summary log line."""
        log_str = "| {0} loss {1:5.2f} | {0} ppl {2:9.3f} ".format(split, loss, math.exp(loss))
        return log_str

    log_str = ""
    if valid_loss is not None:
        log_str += format_log(valid_loss, "valid")
    if test_loss is not None:
        log_str += format_log(test_loss, "test")

    logger.info("=" * 100)
    logger.info(log_str)
    logger.info("=" * 100)
if __name__ == "__main__":
    main()