tb-remove-events-by-group.py 2.53 KB
Newer Older
hepj987's avatar
hepj987 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# this script removes events from tensorboard log files by group names
# it does the removal in place (so make backups!)
#
# example:
#
#  find . -name "*.tfevents*" -exec tb-remove-events-by-group.py {} "batch-size" \;
#
# which would match any of "batch-size/batch-size", "batch-size/batch-size vs samples", etc.
#
# more than one group can be removed - use `;` as a separator:
#
#  tb-remove-events-by-group.py events.out.tfevents.1 "batch-size;grad-norm"
#
# this script is derived from https://stackoverflow.com/a/60080531/9201239
#
# Important: this script requires CUDA environment.

from pathlib import Path
import os
import re
import shlex
import sys

# avoid using the GPU
# NOTE: both environment variables below are assigned before tensorflow is
# imported - keep that order (presumably they are only honoured when they are
# already set at the time tensorflow initializes - TODO confirm)
os.environ['CUDA_VISIBLE_DEVICES'] = ''
# disable logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.core.util.event_pb2 import Event


def is_tag_matching_group(tag, groups_to_remove):
    """Return True if ``tag`` starts with any of the names in ``groups_to_remove``.

    Matching is a plain string-prefix test, so the group ``"batch-size"``
    matches ``"batch-size/batch-size"`` as well as
    ``"batch-size/batch-size vs samples"``.

    Empty group names are ignored. They typically come from a stray ``;`` in
    the command line argument (e.g. ``"batch-size;"``) and, being a prefix of
    every string, would otherwise match - and thus remove - every tag.

    Args:
        tag: the tag of a summary value.
        groups_to_remove: iterable of group-name prefixes.

    Returns:
        bool: True if ``tag`` starts with at least one non-empty group name.
    """
    # str.startswith accepts a tuple of candidates; an empty tuple never matches
    prefixes = tuple(group for group in groups_to_remove if group)
    return tag.startswith(prefixes)


def remove_events(input_file, groups_to_remove):
    """Rewrite the tensorboard events file ``input_file`` without the given groups.

    Every record of ``input_file`` is copied to ``<input_file>.new``. On the
    way, summary values whose tag matches one of ``groups_to_remove`` (see
    ``is_tag_matching_group``) are dropped. An event is always written back,
    even when all of its summary values were dropped. Once every record has
    been written, the new file takes the place of the original one.

    Args:
        input_file: path (``str`` or path-like) of the events file to rewrite.
        groups_to_remove: iterable of group-name prefixes whose summary
            values are to be removed.

    Raises:
        Any error raised while reading, writing or replacing the file is
        propagated. The partially written ``.new`` file is deleted first and
        ``input_file`` is left as it was.
    """
    input_file = os.fspath(input_file)
    new_file = input_file + ".new"
    try:
        # Make a record writer
        with tf.io.TFRecordWriter(new_file) as writer:
            # Iterate event records
            for rec in tf.data.TFRecordDataset([input_file]):
                # Read event
                ev = Event()
                ev.MergeFromString(rec.numpy())
                # Only summary events carry tagged values that can be filtered
                if ev.HasField("summary"):
                    orig_values = list(ev.summary.value)
                    filtered_values = [
                        v for v in orig_values
                        if not is_tag_matching_group(v.tag, groups_to_remove)
                    ]
                    # Touch the event only if something was actually removed
                    if len(filtered_values) != len(orig_values):
                        del ev.summary.value[:]
                        ev.summary.value.extend(filtered_values)
                writer.write(ev.SerializeToString())
        # os.replace overwrites an existing target on every platform, whereas
        # os.rename fails on Windows when the target already exists
        os.replace(new_file, input_file)
    except BaseException:
        # do not leave a half-written temporary file next to the original
        try:
            os.remove(new_file)
        except OSError:
            pass
        raise

def remove_events_dir(input_file, groups_to_remove):
    """Remove the events of ``groups_to_remove`` from ``input_file``.

    Thin wrapper that forwards both arguments unchanged to ``remove_events``.
    """
    remove_events(input_file=input_file, groups_to_remove=groups_to_remove)

if __name__ == '__main__':
    # exactly two positional arguments are expected after the script name:
    # the events file and the `;`-separated list of groups to remove
    cli_args = sys.argv[1:]
    if len(cli_args) != 2:
        usage = f'{sys.argv[0]} <input file> <tags to remove>'
        print(usage, file=sys.stderr)
        sys.exit(1)
    input_file = cli_args[0]
    groups_spec = cli_args[1]
    # echo what is about to be processed, with the group spec shell-quoted
    print(input_file, shlex.quote(groups_spec))
    remove_events_dir(input_file, groups_spec.split(';'))