truncate_accuracy_log.py 10.3 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
"""
Tool to truncate the mlperf_log_accuracy.json
"""

from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import hashlib
import logging
import os
import re
import sys
import shutil


# Configure logging at import time and create the logger used by the rest of
# this script.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("main")

# An accuracy log smaller than this many bytes that already has a recorded
# hash is treated as "already truncated" by truncate_results_dir().
MAX_ACCURACY_LOG_SIZE = 10 * 1024
# Amount of data truncate_file() keeps from the start and again from the end
# of an accuracy log.
VIEWABLE_SIZE = 4096

# Usage examples shown as the epilog of the --help output.
HELP_TEXT = """
You can run this tool in 2 ways:

1. pick an existing submission directory and create a brand new submission tree with the truncated
    mlperf_log_accuracy.json files. The original submission directory is not modified.

    python tools/submission/truncate_accuracy_log.py --input ORIGINAL_SUBMISSION_DIRECTORY --submitter MY_ORG \\
        --output NEW_SUBMISSION_DIRECTORY

2. pick an existing submission directory and a backup location for files that are going to be modified.
    The tool will copy files that are modified into the backup directory and then modify the existing
    submission directory.

    python tools/submission/truncate_accuracy_log.py --input ROOT_OF_SUBMISSION_DIRECTORY --submitter MY_ORG \\
        --backup MY_SUPER_SAFE_STORAGE
"""


def get_args(argv=None):
    """Parse and validate the command line.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is what the script entry
            point relies on.

    Returns:
        The parsed ``argparse.Namespace`` with the attributes ``input``,
        ``output``, ``submitter``, ``backup`` and ``scenarios_to_skip``.

    Exits:
        With status 1, after printing the help text, when neither
        ``--output`` nor ``--backup`` was given: the tool needs somewhere to
        keep the untruncated logs.
    """
    parser = argparse.ArgumentParser(
        description="Truncate mlperf_log_accuracy.json files.",
        # Keep the hand-formatted usage examples in the epilog intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=HELP_TEXT,
    )
    parser.add_argument(
        "--input",
        required=True,
        help="original submission directory")
    parser.add_argument("--output", help="new submission directory")
    parser.add_argument(
        "--submitter",
        required=True,
        help="filter to submitter")
    parser.add_argument(
        "--backup",
        help="directory to store the original accuracy log")
    parser.add_argument(
        "--scenarios-to-skip",
        help="Delimited list input of scenarios to skip. i.e. if you only have Offline results, pass in 'Server'",
        type=str,
    )

    args = parser.parse_args(argv)
    if not args.output and not args.backup:
        parser.print_help()
        sys.exit(1)

    return args


def list_dir(*path):
    """Return the names of the directories directly inside the joined path.

    The positional arguments are combined with ``os.path.join``; only the
    entry names (not full paths) are returned, in ``os.listdir`` order.
    """
    root = os.path.join(*path)
    subdirs = []
    for entry in os.listdir(root):
        if os.path.isdir(os.path.join(root, entry)):
            subdirs.append(entry)
    return subdirs


def list_files(*path):
    """Return the names of the regular files directly inside the joined path.

    The positional arguments are combined with ``os.path.join``; only the
    entry names (not full paths) are returned, in ``os.listdir`` order.
    """
    root = os.path.join(*path)
    files = []
    for entry in os.listdir(root):
        if os.path.isfile(os.path.join(root, entry)):
            files.append(entry)
    return files


def split_path(m):
    """Split a path string into its components on either slash style.

    Both backslashes and forward slashes act as separators; consecutive
    separators yield empty components, exactly as ``str.split`` would.
    """
    return re.split(r"[\\/]", m)


def get_hash(fname):
    """Return the SHA-256 hex digest of the contents of ``fname``.

    The file is read in binary mode in 4096-byte chunks so that large logs
    are never loaded into memory at once.
    """
    digest = hashlib.sha256()
    with open(fname, "rb") as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()


def truncate_file(fname, viewable_size=None):
    """Truncate a file in place, keeping only its head and its tail.

    The file is rewritten as: the first ``viewable_size`` bytes, a
    ``"\\n\\n...\\n\\n"`` marker, then the last ``viewable_size`` bytes.

    Args:
        fname: Path of the file to truncate.
        viewable_size: Number of bytes to keep from each end. Defaults to the
            module-level ``VIEWABLE_SIZE``.

    Files that would not get smaller (head, marker and tail together are at
    least as large as the file) are left untouched; otherwise the head and
    tail windows could overlap and the "truncated" file would contain
    duplicated data and grow.
    """
    if viewable_size is None:
        viewable_size = VIEWABLE_SIZE
    marker = b"\n\n...\n\n"

    size = os.stat(fname).st_size
    if size <= 2 * viewable_size + len(marker):
        return

    # Binary mode: the offsets are byte offsets taken from st_size, and a
    # text-mode seek to an arbitrary position may land inside a multi-byte
    # character or be affected by newline translation.
    with open(fname, "rb") as src:
        head = src.read(viewable_size)
        src.seek(size - viewable_size, 0)
        tail = src.read(viewable_size)
    with open(fname, "wb") as dst:
        dst.write(head)
        dst.write(marker)
        dst.write(tail)

def copy_submission_dir(src, dst, filter_submitter):
    """Copy the submitter trees of the known divisions from ``src`` to ``dst``.

    Only the ``closed``, ``open`` and ``network`` division directories are
    considered. When ``filter_submitter`` is truthy, only the submitter
    directory with exactly that name is copied; otherwise every submitter
    directory is copied. Each tree is copied with ``shutil.copytree``.
    """
    known_divisions = ("closed", "open", "network")
    for division in list_dir(src):
        if division in known_divisions:
            division_src = os.path.join(src, division)
            for submitter in list_dir(division_src):
                if not filter_submitter or submitter == filter_submitter:
                    shutil.copytree(
                        os.path.join(division_src, submitter),
                        os.path.join(dst, division, submitter),
                    )


def truncate_results_dir(filter_submitter, backup, scenarios_to_skip):
    """Hash, back up and truncate the accuracy logs below the current directory.

    Walks ``./<division>/<submitter>/{results,compliance}/<system>/<model>/
    <scenario>[/<test>]`` relative to the current working directory and, for
    every ``accuracy/mlperf_log_accuracy.json`` it decides to process:

    * copies the log to ``<backup>/<same relative path>`` when ``backup`` is
      set,
    * appends a ``hash=<sha256>`` line for the untruncated log to the
      ``accuracy.txt`` next to it,
    * truncates the log in place with ``truncate_file``.

    Args:
        filter_submitter: When truthy, only the submitter directory with
            exactly this name is processed.
        backup: Directory that receives copies of the original logs, or a
            falsy value to skip the backup step.
        scenarios_to_skip: Container of scenario directory names to ignore.

    Problems with individual logs are reported through the module logger and
    that log is skipped; nothing is returned.
    """
    for division in list_dir("."):
        # we are looking at ./$division, ie ./closed
        if division not in ["closed", "open", "network"]:
            continue

        for submitter in list_dir(division):
            # we are looking at ./$division/$submitter, ie ./closed/mlperf_org
            if filter_submitter and submitter != filter_submitter:
                continue

            # process results
            for directory in ["results", "compliance"]:

                log_path = os.path.join(division, submitter, directory)
                if not os.path.exists(log_path):
                    log.error("no submission in %s", log_path)
                    continue

                for system_desc in list_dir(log_path):
                    for model in list_dir(log_path, system_desc):
                        for scenario in list_dir(log_path, system_desc, model):
                            if scenario in scenarios_to_skip:
                                continue
                            for test in list_dir(
                                log_path, system_desc, model, scenario
                            ):

                                # For "results" the accuracy directory sits
                                # directly under the scenario; for
                                # "compliance" it sits under the test
                                # directory.
                                # NOTE(review): for "results", `name` does not
                                # depend on `test`, so every `continue` below
                                # re-examines the same log once per scenario
                                # sub-directory and may log the same message
                                # more than once.
                                name = os.path.join(
                                    log_path, system_desc, model, scenario
                                )
                                if directory == "compliance":
                                    name = os.path.join(
                                        log_path, system_desc, model, scenario, test
                                    )

                                hash_val = None
                                acc_path = os.path.join(name, "accuracy")
                                acc_log = os.path.join(
                                    acc_path, "mlperf_log_accuracy.json"
                                )
                                acc_txt = os.path.join(
                                    acc_path, "accuracy.txt")

                                # only TEST01 has an accuracy log
                                if directory == "compliance" and test != "TEST01":
                                    continue
                                if not os.path.exists(acc_log):
                                    log.error("%s missing", acc_log)
                                    continue
                                if (
                                    not os.path.exists(acc_txt)
                                    and directory == "compliance"
                                ):
                                    # compliance test directory will not have
                                    # an accuracy.txt file by default
                                    log.info(
                                        "no accuracy.txt in compliance directory %s",
                                        acc_path,
                                    )
                                else:
                                    if not os.path.exists(acc_txt):
                                        log.error(
                                            "%s missing, generate to continue", acc_txt
                                        )
                                        continue
                                    # Look for a hash recorded by a previous
                                    # run of this tool; the first matching
                                    # line wins.
                                    with open(acc_txt, "r", encoding="utf-8") as f:
                                        for line in f:
                                            m = re.match(
                                                r"^hash=([\w\d]+)$", line)
                                            if m:
                                                hash_val = m.group(1)
                                                break
                                # A recorded hash plus a small log means the
                                # log was already truncated: do not hash the
                                # truncated content again.
                                size = os.stat(acc_log).st_size
                                if hash_val and size < MAX_ACCURACY_LOG_SIZE:
                                    log.info(
                                        "%s already has hash and size seems truncated",
                                        acc_path,
                                    )
                                    continue

                                if backup:
                                    backup_dir = os.path.join(
                                        backup, name, "accuracy")
                                    os.makedirs(backup_dir, exist_ok=True)
                                    dst = os.path.join(
                                        backup,
                                        name,
                                        "accuracy",
                                        "mlperf_log_accuracy.json",
                                    )
                                    # Never overwrite an existing backup; the
                                    # log is left untouched in that case.
                                    if os.path.exists(dst):
                                        log.error(
                                            "not processing %s because %s already exist",
                                            acc_log,
                                            dst,
                                        )
                                        continue
                                    shutil.copy(acc_log, dst)

                                # get to work
                                # The hash is taken before truncation, so it
                                # describes the original log. Append mode
                                # creates accuracy.txt if it does not exist.
                                hash_val = get_hash(acc_log)
                                with open(acc_txt, "a", encoding="utf-8") as f:
                                    f.write("\nhash={0}\n".format(hash_val))
                                truncate_file(acc_log)
                                log.info("%s truncated", acc_log)

                                # No need to iterate on compliance test
                                # subdirectories in the results folder
                                if directory == "results":
                                    break


def main():
    """Run the tool: optionally copy the submission, then truncate its logs.

    With ``--output`` the selected submitter's tree is first copied to the
    new directory (which must not exist yet) and the copy is truncated;
    otherwise the ``--input`` tree is modified in place. The process changes
    its working directory to the tree being truncated.

    Returns:
        0 on success. Exits with status 1 if the output directory already
        exists.
    """
    args = get_args()

    # Resolve the backup location before the chdir below. Otherwise a
    # relative --backup path would be interpreted relative to the submission
    # directory and the "backup" would end up inside the tree being modified.
    backup = os.path.abspath(args.backup) if args.backup else None

    src_dir = args.input
    if args.output:
        if os.path.exists(args.output):
            print("output directory already exists")
            sys.exit(1)
        os.makedirs(args.output)
        copy_submission_dir(args.input, args.output, args.submitter)
        src_dir = args.output

    os.chdir(src_dir)

    # Tolerate whitespace and empty entries, e.g. "Server, Offline" or a
    # trailing comma, so every listed scenario actually matches.
    scenarios_to_skip = [
        scenario.strip()
        for scenario in (args.scenarios_to_skip or "").split(",")
        if scenario.strip()
    ]

    # truncate results directory
    truncate_results_dir(args.submitter, backup, scenarios_to_skip)

    backup_location = args.output or backup
    log.info(
        "Make sure you keep a backup of %s in case mlperf wants to see the original accuracy logs",
        backup_location,
    )

    return 0


# Run the tool only when executed as a script; main()'s return value is used
# as the process exit status.
if __name__ == "__main__":
    sys.exit(main())