"docs/git@developer.sourcefind.cn:change/sglang.git" did not exist on "bfadb5ea5fbd7197d72d517da8885433ef190953"
Unverified commit 68102441, authored by Xiang Xu and committed by GitHub
Browse files

[Fix] Fix bugs in `analyze_logs` (#2184)

* Update analyze_logs.py

* fix lint

* fix minor bugs in plotting evaluation metric

* fix

* fix docs
parent fd013a46
...@@ -17,18 +17,20 @@ def cal_train_time(log_dicts, args): ...@@ -17,18 +17,20 @@ def cal_train_time(log_dicts, args):
all_times.append(log_dict[epoch]['time']) all_times.append(log_dict[epoch]['time'])
else: else:
all_times.append(log_dict[epoch]['time'][1:]) all_times.append(log_dict[epoch]['time'][1:])
all_times = np.array(all_times) if not all_times:
epoch_ave_time = all_times.mean(-1) raise KeyError(
'Please reduce the log interval in the config so that '
'interval is less than iterations of one epoch.')
epoch_ave_time = np.array(list(map(lambda x: np.mean(x), all_times)))
slowest_epoch = epoch_ave_time.argmax() slowest_epoch = epoch_ave_time.argmax()
fastest_epoch = epoch_ave_time.argmin() fastest_epoch = epoch_ave_time.argmin()
std_over_epoch = epoch_ave_time.std() std_over_epoch = epoch_ave_time.std()
print(f'slowest epoch {slowest_epoch + 1}, ' print(f'slowest epoch {slowest_epoch + 1}, '
f'average time is {epoch_ave_time[slowest_epoch]:.4f}') f'average time is {epoch_ave_time[slowest_epoch]:.4f} s/iter')
print(f'fastest epoch {fastest_epoch + 1}, ' print(f'fastest epoch {fastest_epoch + 1}, '
f'average time is {epoch_ave_time[fastest_epoch]:.4f}') f'average time is {epoch_ave_time[fastest_epoch]:.4f} s/iter')
print(f'time std over epochs is {std_over_epoch:.4f}') print(f'time std over epochs is {std_over_epoch:.4f}')
print(f'average iter time: {np.mean(all_times):.4f} s/iter') print(f'average iter time: {np.mean(epoch_ave_time):.4f} s/iter\n')
print()
def plot_curve(log_dicts, args): def plot_curve(log_dicts, args):
...@@ -50,56 +52,41 @@ def plot_curve(log_dicts, args): ...@@ -50,56 +52,41 @@ def plot_curve(log_dicts, args):
epochs = list(log_dict.keys()) epochs = list(log_dict.keys())
for j, metric in enumerate(metrics): for j, metric in enumerate(metrics):
print(f'plot curve of {args.json_logs[i]}, metric is {metric}') print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
if metric not in log_dict[epochs[args.interval - 1]]: if metric not in log_dict[epochs[int(args.eval_interval) - 1]]:
if args.eval:
raise KeyError(
f'{args.json_logs[i]} does not contain metric '
f'{metric}. Please check if "--no-validate" is '
'specified when you trained the model. Or check '
f'if the eval_interval {args.eval_interval} in args '
'is equal to the `eval_interval` during training.')
raise KeyError( raise KeyError(
f'{args.json_logs[i]} does not contain metric {metric}') f'{args.json_logs[i]} does not contain metric {metric}. '
'Please reduce the log interval in the config so that '
if args.mode == 'eval': 'interval is less than iterations of one epoch.')
if min(epochs) == args.interval:
x0 = args.interval if args.eval:
else: xs = []
# if current training is resumed from previous checkpoint
# we lost information in early epochs
# `xs` should start according to `min(epochs)`
if min(epochs) % args.interval == 0:
x0 = min(epochs)
else:
# find the first epoch that do eval
x0 = min(epochs) + args.interval - \
min(epochs) % args.interval
xs = np.arange(x0, max(epochs) + 1, args.interval)
ys = [] ys = []
for epoch in epochs[args.interval - 1::args.interval]: for epoch in epochs:
ys += log_dict[epoch][metric] ys += log_dict[epoch][metric]
if log_dict[epoch][metric]:
# if training is aborted before eval of the last epoch xs += [epoch]
# `xs` and `ys` will have different length and cause an error
# check if `ys[-1]` is empty here
if not log_dict[epoch][metric]:
xs = xs[:-1]
ax = plt.gca()
ax.set_xticks(xs)
plt.xlabel('epoch') plt.xlabel('epoch')
plt.plot(xs, ys, label=legend[i * num_metrics + j], marker='o') plt.plot(xs, ys, label=legend[i * num_metrics + j], marker='o')
else: else:
xs = [] xs = []
ys = [] ys = []
num_iters_per_epoch = \ for epoch in epochs:
log_dict[epochs[args.interval-1]]['iter'][-1] iters = log_dict[epoch]['step']
for epoch in epochs[args.interval - 1::args.interval]: xs.append(np.array(iters))
iters = log_dict[epoch]['iter']
if log_dict[epoch]['mode'][-1] == 'val':
iters = iters[:-1]
xs.append(
np.array(iters) + (epoch - 1) * num_iters_per_epoch)
ys.append(np.array(log_dict[epoch][metric][:len(iters)])) ys.append(np.array(log_dict[epoch][metric][:len(iters)]))
xs = np.concatenate(xs) xs = np.concatenate(xs)
ys = np.concatenate(ys) ys = np.concatenate(ys)
plt.xlabel('iter') plt.xlabel('iter')
plt.plot( plt.plot(
xs, ys, label=legend[i * num_metrics + j], linewidth=0.5) xs, ys, label=legend[i * num_metrics + j], linewidth=0.5)
plt.legend() plt.legend()
if args.title is not None: if args.title is not None:
plt.title(args.title) plt.title(args.title)
if args.out is None: if args.out is None:
...@@ -124,6 +111,15 @@ def add_plot_parser(subparsers): ...@@ -124,6 +111,15 @@ def add_plot_parser(subparsers):
nargs='+', nargs='+',
default=['mAP_0.25'], default=['mAP_0.25'],
help='the metric that you want to plot') help='the metric that you want to plot')
parser_plt.add_argument(
'--eval',
action='store_true',
help='whether to plot evaluation metric')
parser_plt.add_argument(
'--eval-interval',
type=str,
default='1',
help='the eval interval when training')
parser_plt.add_argument('--title', type=str, help='title of figure') parser_plt.add_argument('--title', type=str, help='title of figure')
parser_plt.add_argument( parser_plt.add_argument(
'--legend', '--legend',
...@@ -136,8 +132,6 @@ def add_plot_parser(subparsers): ...@@ -136,8 +132,6 @@ def add_plot_parser(subparsers):
parser_plt.add_argument( parser_plt.add_argument(
'--style', type=str, default='dark', help='style of plt') '--style', type=str, default='dark', help='style of plt')
parser_plt.add_argument('--out', type=str, default=None) parser_plt.add_argument('--out', type=str, default=None)
parser_plt.add_argument('--mode', type=str, default='train')
parser_plt.add_argument('--interval', type=int, default=1)
def add_time_parser(subparsers): def add_time_parser(subparsers):
...@@ -174,17 +168,28 @@ def load_json_logs(json_logs): ...@@ -174,17 +168,28 @@ def load_json_logs(json_logs):
for json_log, log_dict in zip(json_logs, log_dicts): for json_log, log_dict in zip(json_logs, log_dicts):
with open(json_log, 'r') as log_file: with open(json_log, 'r') as log_file:
epoch = 1 epoch = 1
for line in log_file: for i, line in enumerate(log_file):
log = json.loads(line.strip()) log = json.loads(line.strip())
val_flag = False
# skip lines only contains one key # skip lines only contains one key
if not len(log) > 1: if not len(log) > 1:
continue continue
if epoch not in log_dict: if epoch not in log_dict:
log_dict[epoch] = defaultdict(list) log_dict[epoch] = defaultdict(list)
for k, v in log.items(): for k, v in log.items():
log_dict[epoch][k].append(v) if '/' in k:
log_dict[epoch][k.split('/')[-1]].append(v)
val_flag = True
elif val_flag:
continue
else:
log_dict[epoch][k].append(v)
if 'epoch' in log.keys(): if 'epoch' in log.keys():
epoch = log['epoch'] + 1 epoch = log['epoch']
return log_dicts return log_dicts
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment