#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# this script removes events from tensorboard log files by group names
# it does the removal in place (so make back ups!)
#
# example:
#
# find . -name "*.tfevents*" -exec tb-remove-events-by-group.py {} "batch-size" \;
#
# which wold match any of "batch-size/batch-size", "batch-size/batch-size vs samples", etc.
#
# more than one group can be removed - use `;` as a separator:
#
# tb-remove-events-by-group.py events.out.tfevents.1 "batch-size;grad-norm"
#
# this script is derived from https://stackoverflow.com/a/60080531/9201239
#
# Important: this script requires CUDA environment.
from pathlib import Path
import os
import re
import shlex
import sys
# avoid using the GPU
os.environ['CUDA_VISIBLE_DEVICES'] = ''
# disable logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.core.util.event_pb2 import Event
def is_tag_matching_group(tag, groups_to_remove):
for group in groups_to_remove:
if tag.startswith(group):
return True
return False
def remove_events(input_file, groups_to_remove):
new_file = input_file + ".new"
# Make a record writer
with tf.io.TFRecordWriter(new_file) as writer:
# Iterate event records
for rec in tf.data.TFRecordDataset([input_file]):
# Read event
ev = Event()
ev.MergeFromString(rec.numpy())
# Check if it is a summary event
if ev.summary:
orig_values = [v for v in ev.summary.value]
filtered_values = [v for v in orig_values if not is_tag_matching_group(v.tag, groups_to_remove)]
#print(f"filtered_values={len(filtered_values)}, orig_values={len(orig_values)}")
if len(filtered_values) != len(orig_values):
# for v in orig_values:
# print(v)
del ev.summary.value[:]
ev.summary.value.extend(filtered_values)
writer.write(ev.SerializeToString())
os.rename(new_file, input_file)
def remove_events_dir(input_file, groups_to_remove):
# Write removed events
remove_events(input_file, groups_to_remove)
if __name__ == '__main__':
if len(sys.argv) != 3:
print(f'{sys.argv[0]} ',
file=sys.stderr)
sys.exit(1)
input_file, groups_to_remove = sys.argv[1:]
print(input_file, shlex.quote(groups_to_remove))
groups_to_remove = groups_to_remove.split(';')
remove_events_dir(input_file, groups_to_remove)