#!/usr/bin/env python
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer XL model evaluation script.
    Adapted from https://github.com/kimiyoung/transformer-xl.
    In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py

    This script with default values evaluates a pretrained Transformer-XL on WikiText 103
"""

import argparse
import logging
import math
import time

import torch

from transformers import TransfoXLCorpus, TransfoXLLMHeadModel

# Configure root logging once at import time so every module logger
# inherits the same timestamped format.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

logger = logging.getLogger(__name__)


def main():
    """Evaluate a pre-trained Transformer-XL language model on WikiText 103.

    Command-line driven: loads the pre-processed corpus and the pre-trained
    model selected by ``--model_name``, runs the requested split(s), and logs
    the resulting loss and perplexity.
    """
    parser = argparse.ArgumentParser(description="PyTorch Transformer Language Model")
    parser.add_argument("--model_name", type=str, default="transfo-xl-wt103", help="pretrained model name")
    parser.add_argument(
        "--split", type=str, default="test", choices=["all", "valid", "test"], help="which split to evaluate"
    )
    parser.add_argument("--batch_size", type=int, default=10, help="batch size")
    parser.add_argument("--tgt_len", type=int, default=128, help="number of tokens to predict")
    parser.add_argument("--ext_len", type=int, default=0, help="length of the extended context")
    parser.add_argument("--mem_len", type=int, default=1600, help="length of the retained previous heads")
    parser.add_argument("--clamp_len", type=int, default=1000, help="max positional embedding index")
    # NOTE: help text previously read "CUA" instead of "CUDA" (typo fix).
    parser.add_argument("--no_cuda", action="store_true", help="Do not use CUDA even though CUDA is available")
    parser.add_argument("--work_dir", type=str, required=True, help="path to the work_dir")
    parser.add_argument("--no_log", action="store_true", help="do not log the eval result")
    parser.add_argument("--same_length", action="store_true", help="set same length attention with masking")
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()
    # Explicit validation instead of `assert`, which is stripped under `python -O`.
    if args.ext_len < 0:
        parser.error("extended context length must be non-negative")

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    logger.info("device: {}".format(device))

    # Load a pre-processed dataset
    # You can also build the corpus yourself using TransfoXLCorpus methods
    # The pre-processing involves computing word frequencies to prepare the Adaptive input and SoftMax
    # and tokenizing the dataset
    # The pre-processed corpus is a conversion (using the conversion script)
    corpus = TransfoXLCorpus.from_pretrained(args.model_name)

    va_iter = corpus.get_iterator("valid", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)
    te_iter = corpus.get_iterator("test", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)

    # Load a pre-trained model
    model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
    model.to(device)

    logger.info(
        "Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}".format(
            args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len
        )
    )

    model.reset_memory_length(args.mem_len)
    if args.clamp_len > 0:
        model.clamp_len = args.clamp_len
    if args.same_length:
        model.same_length = True

    ###############################################################################
    # Evaluation code
    ###############################################################################
    def evaluate(eval_iter):
        """Return the per-token average loss of `model` over `eval_iter`.

        Raises:
            ValueError: if the iterator yields no data (previously this path
                crashed with an unhelpful NameError/ZeroDivisionError).
        """
        # Turn on evaluation mode which disables dropout.
        model.eval()
        total_len, total_loss = 0, 0.0
        start_time = time.time()
        with torch.no_grad():
            mems = None  # memory state is threaded from one segment to the next
            idx = -1  # sentinel so the timing log below is safe on an empty iterator
            for idx, (data, target, seq_len) in enumerate(eval_iter):
                ret = model(data, lm_labels=target, mems=mems)
                loss, _, mems = ret
                loss = loss.mean()
                # Weight each segment's mean loss by its length so the final
                # average is per token, not per segment.
                total_loss += seq_len * loss.item()
                total_len += seq_len
            total_time = time.time() - start_time
        if total_len == 0:
            raise ValueError("evaluation iterator produced no data")
        logger.info("Time : {:.2f}s, {:.2f}ms/segment".format(total_time, 1000 * total_time / (idx + 1)))
        return total_loss / total_len

    # Run on test data.
    if args.split == "all":
        test_loss = evaluate(te_iter)
        valid_loss = evaluate(va_iter)
    elif args.split == "valid":
        valid_loss = evaluate(va_iter)
        test_loss = None
    elif args.split == "test":
        test_loss = evaluate(te_iter)
        valid_loss = None

    def format_log(loss, split):
        """Format one split's loss and perplexity for the summary line."""
        log_str = "| {0} loss {1:5.2f} | {0} ppl {2:9.3f} ".format(split, loss, math.exp(loss))
        return log_str

    log_str = ""
    if valid_loss is not None:
        log_str += format_log(valid_loss, "valid")
    if test_loss is not None:
        log_str += format_log(test_loss, "test")

    logger.info("=" * 100)
    logger.info(log_str)
    logger.info("=" * 100)

# Standard script guard: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()