Unverified Commit c556d274 authored by twang's avatar twang Committed by GitHub
Browse files

[Fix] Fix log analysis for evaluation (#285)

* Fix log analysis

* Update useful_tools.md
parent 87b05bae
......@@ -7,7 +7,7 @@ You can plot loss/mAP curves given a training log file. Run `pip install seaborn
![loss curve image](../resources/loss_curve.png)
```shell
python tools/analyze_logs.py plot_curve [--keys ${KEYS}] [--title ${TITLE}] [--legend ${LEGEND}] [--backend ${BACKEND}] [--style ${STYLE}] [--out ${OUT_FILE}]
python tools/analyze_logs.py plot_curve [--keys ${KEYS}] [--title ${TITLE}] [--legend ${LEGEND}] [--backend ${BACKEND}] [--style ${STYLE}] [--out ${OUT_FILE}] [--mode ${MODE}] [--interval ${INTERVAL}]
```
Examples:
......@@ -27,7 +27,10 @@ Examples:
- Compare the bbox mAP of two runs in the same figure.
```shell
python tools/analyze_logs.py plot_curve log1.json log2.json --keys bbox_mAP --legend run1 run2
# evaluate PartA2 and second on KITTI according to Car_3D_moderate_strict
python tools/analyze_logs.py plot_curve tools/logs/PartA2.log.json tools/logs/second.log.json --keys KITTI/Car_3D_moderate_strict --legend PartA2 second --mode eval --interval 1
# evaluate PointPillars for car and 3 classes on KITTI according to Car_3D_moderate_strict
python tools/analyze_logs.py plot_curve tools/logs/pp-3class.log.json tools/logs/pp.log.json --keys KITTI/Car_3D_moderate_strict --legend pp-3class pp --mode eval --interval 2
```
You can also compute the average training speed.
......
......@@ -48,14 +48,14 @@ def plot_curve(log_dicts, args):
epochs = list(log_dict.keys())
for j, metric in enumerate(metrics):
print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
if metric not in log_dict[epochs[0]]:
if metric not in log_dict[epochs[args.interval - 1]]:
raise KeyError(
f'{args.json_logs[i]} does not contain metric {metric}')
if 'mAP' in metric:
xs = np.arange(1, max(epochs) + 1)
if args.mode == 'eval':
xs = np.arange(args.interval, max(epochs) + 1, args.interval)
ys = []
for epoch in epochs:
for epoch in epochs[args.interval - 1::args.interval]:
ys += log_dict[epoch][metric]
ax = plt.gca()
ax.set_xticks(xs)
......@@ -64,8 +64,9 @@ def plot_curve(log_dicts, args):
else:
xs = []
ys = []
num_iters_per_epoch = log_dict[epochs[0]]['iter'][-1]
for epoch in epochs:
num_iters_per_epoch = \
log_dict[epochs[args.interval-1]]['iter'][-1]
for epoch in epochs[args.interval - 1::args.interval]:
iters = log_dict[epoch]['iter']
if log_dict[epoch]['mode'][-1] == 'val':
iters = iters[:-1]
......@@ -114,6 +115,8 @@ def add_plot_parser(subparsers):
parser_plt.add_argument(
'--style', type=str, default='dark', help='style of plt')
parser_plt.add_argument('--out', type=str, default=None)
parser_plt.add_argument('--mode', type=str, default='train')
parser_plt.add_argument('--interval', type=int, default=1)
def add_time_parser(subparsers):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment