squad_QSL.py 3.75 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# coding=utf-8
# Copyright 2021 Arm Limited and affiliates.
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
import mlperf_loadgen as lg
from create_squad_data import read_squad_examples, convert_examples_to_features
from transformers import BertTokenizer
import os
import sys

# Put the current working directory at the front of the module search path.
# NOTE(review): this statement runs after the imports above, so it can only
# influence imports performed later — confirm that ordering is intended.
sys.path.insert(0, os.getcwd())


# Feature-conversion settings. Each value is forwarded to
# convert_examples_to_features under the keyword argument of the same name
# when the feature cache has to be rebuilt from the raw examples.
# (Presumably these are token counts — verify against create_squad_data.)

max_seq_length = 384
max_query_length = 64
doc_stride = 128


class SQuAD_v1_QSL:
    """Query sample library (QSL) holding SQuAD v1.1 evaluation features.

    On construction the features are read from a pickle cache when one
    exists at ``cache_path``; otherwise they are produced with
    ``read_squad_examples`` / ``convert_examples_to_features`` and then
    written to ``cache_path``. The sample counts and the load/unload
    callbacks are registered with LoadGen through ``lg.ConstructQSL``.

    Attributes set by ``__init__``:
        eval_features: the loaded or freshly converted feature list.
        count: total sample count passed to ``lg.ConstructQSL``.
        perf_count: performance sample count passed to ``lg.ConstructQSL``.
        qsl: the object returned by ``lg.ConstructQSL``.
    """

    def __init__(
        self,
        total_count_override=None,
        perf_count_override=None,
        cache_path="eval_features.pickle",
    ):
        """Load (or build and cache) the features and construct the QSL.

        Args:
            total_count_override: Total sample count to report. When falsy
                (``None`` or ``0``) the number of features is used.
            perf_count_override: Performance sample count to report. When
                falsy, the total count is used.
            cache_path: Path of the pickle file used as the feature cache.
        """
        print("Constructing QSL...")
        self.eval_features = self._load_or_build_features(cache_path)

        # Overrides are taken as given; they are not checked against the
        # number of features actually available.
        self.count = total_count_override or len(self.eval_features)
        self.perf_count = perf_count_override or self.count
        self.qsl = lg.ConstructQSL(
            self.count,
            self.perf_count,
            self.load_query_samples,
            self.unload_query_samples,
        )
        print("Finished constructing QSL.")

    @classmethod
    def _load_or_build_features(cls, cache_path):
        """Return the features from the cache, building the cache if absent."""
        if os.path.exists(cache_path):
            print("Loading cached features from '%s'..." % cache_path)
            # SECURITY NOTE: unpickling can execute arbitrary code. Only
            # point cache_path at files produced by this class.
            with open(cache_path, "rb") as cache_file:
                return pickle.load(cache_file)

        print(
            "No cached features at '%s'... converting from examples..." % cache_path
        )
        eval_features = cls._build_features()

        print("Caching features at '%s'..." % cache_path)
        cls._write_cache(eval_features, cache_path)
        return eval_features

    @staticmethod
    def _build_features():
        """Convert the SQuAD examples into a new list of features.

        The vocabulary and dataset locations come from the ``VOCAB_FILE``
        and ``DATASET_FILE`` environment variables; an unset or empty
        variable selects the default path under ``build/data``.
        """
        print("Creating tokenizer...")
        vocab_file = (
            os.environ.get("VOCAB_FILE")
            or "build/data/bert_tf_v1_1_large_fp32_384_v2/vocab.txt"
        )
        tokenizer = BertTokenizer(vocab_file)

        print("Reading examples...")
        dataset_file = os.environ.get("DATASET_FILE") or "build/data/dev-v1.1.json"
        eval_examples = read_squad_examples(
            input_file=dataset_file,
            is_training=False,
            version_2_with_negative=False,
        )

        print("Converting examples to features...")
        eval_features = []
        convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=max_seq_length,
            doc_stride=doc_stride,
            max_query_length=max_query_length,
            is_training=False,
            # Every produced feature is handed to this callback.
            output_fn=eval_features.append,
            verbose_logging=False,
        )
        return eval_features

    @staticmethod
    def _write_cache(eval_features, cache_path):
        """Pickle ``eval_features`` to ``cache_path`` without exposing a partial file.

        The data is first written to a temporary sibling file and then
        moved over ``cache_path`` with ``os.replace``. An interrupted run
        therefore never leaves a truncated pickle that a later run would
        mistake for a valid cache.
        """
        tmp_path = "%s.%d.tmp" % (cache_path, os.getpid())
        try:
            with open(tmp_path, "wb") as cache_file:
                pickle.dump(eval_features, cache_file)
            os.replace(tmp_path, cache_path)
        except BaseException:
            # Best-effort cleanup of the partial temporary file; the
            # original error is what the caller needs to see.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def load_query_samples(self, sample_list):
        """Load callback for LoadGen; a no-op.

        All features are already held in ``self.eval_features``.
        """
        pass

    def unload_query_samples(self, sample_list):
        """Unload callback for LoadGen; a no-op."""
        pass

    def get_features(self, sample_id):
        """Return the feature stored at index ``sample_id``."""
        return self.eval_features[sample_id]

    def __del__(self):
        print("Finished destroying QSL.")


def get_squad_QSL(
    total_count_override=None,
    perf_count_override=None,
    cache_path="eval_features.pickle",
):
    """Construct and return a ``SQuAD_v1_QSL``.

    Args:
        total_count_override: Forwarded to ``SQuAD_v1_QSL``.
        perf_count_override: Forwarded to ``SQuAD_v1_QSL``.
        cache_path: Path of the pickled feature cache, forwarded to
            ``SQuAD_v1_QSL``. The default matches the class default, so
            existing two-argument callers behave as before.

    Returns:
        The newly constructed ``SQuAD_v1_QSL`` instance.
    """
    return SQuAD_v1_QSL(
        total_count_override=total_count_override,
        perf_count_override=perf_count_override,
        cache_path=cache_path,
    )