Unverified Commit c556d274 authored by twang's avatar twang Committed by GitHub
Browse files

[Fix] Fix log analysis for evaluation (#285)

* Fix log analysis

* Update useful_tools.md
parent 87b05bae
...@@ -7,7 +7,7 @@ You can plot loss/mAP curves given a training log file. Run `pip install seaborn ...@@ -7,7 +7,7 @@ You can plot loss/mAP curves given a training log file. Run `pip install seaborn
![loss curve image](../resources/loss_curve.png) ![loss curve image](../resources/loss_curve.png)
```shell ```shell
python tools/analyze_logs.py plot_curve [--keys ${KEYS}] [--title ${TITLE}] [--legend ${LEGEND}] [--backend ${BACKEND}] [--style ${STYLE}] [--out ${OUT_FILE}] python tools/analyze_logs.py plot_curve [--keys ${KEYS}] [--title ${TITLE}] [--legend ${LEGEND}] [--backend ${BACKEND}] [--style ${STYLE}] [--out ${OUT_FILE}] [--mode ${MODE}] [--interval ${INTERVAL}]
``` ```
Examples: Examples:
...@@ -27,7 +27,10 @@ Examples: ...@@ -27,7 +27,10 @@ Examples:
- Compare the bbox mAP of two runs in the same figure. - Compare the bbox mAP of two runs in the same figure.
```shell ```shell
python tools/analyze_logs.py plot_curve log1.json log2.json --keys bbox_mAP --legend run1 run2 # evaluate PartA2 and second on KITTI according to Car_3D_moderate_strict
python tools/analyze_logs.py plot_curve tools/logs/PartA2.log.json tools/logs/second.log.json --keys KITTI/Car_3D_moderate_strict --legend PartA2 second --mode eval --interval 1
# evaluate PointPillars for car and 3 classes on KITTI according to Car_3D_moderate_strict
python tools/analyze_logs.py plot_curve tools/logs/pp-3class.log.json tools/logs/pp.log.json --keys KITTI/Car_3D_moderate_strict --legend pp-3class pp --mode eval --interval 2
``` ```
You can also compute the average training speed. You can also compute the average training speed.
......
...@@ -48,14 +48,14 @@ def plot_curve(log_dicts, args): ...@@ -48,14 +48,14 @@ def plot_curve(log_dicts, args):
epochs = list(log_dict.keys()) epochs = list(log_dict.keys())
for j, metric in enumerate(metrics): for j, metric in enumerate(metrics):
print(f'plot curve of {args.json_logs[i]}, metric is {metric}') print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
if metric not in log_dict[epochs[0]]: if metric not in log_dict[epochs[args.interval - 1]]:
raise KeyError( raise KeyError(
f'{args.json_logs[i]} does not contain metric {metric}') f'{args.json_logs[i]} does not contain metric {metric}')
if 'mAP' in metric: if args.mode == 'eval':
xs = np.arange(1, max(epochs) + 1) xs = np.arange(args.interval, max(epochs) + 1, args.interval)
ys = [] ys = []
for epoch in epochs: for epoch in epochs[args.interval - 1::args.interval]:
ys += log_dict[epoch][metric] ys += log_dict[epoch][metric]
ax = plt.gca() ax = plt.gca()
ax.set_xticks(xs) ax.set_xticks(xs)
...@@ -64,8 +64,9 @@ def plot_curve(log_dicts, args): ...@@ -64,8 +64,9 @@ def plot_curve(log_dicts, args):
else: else:
xs = [] xs = []
ys = [] ys = []
num_iters_per_epoch = log_dict[epochs[0]]['iter'][-1] num_iters_per_epoch = \
for epoch in epochs: log_dict[epochs[args.interval-1]]['iter'][-1]
for epoch in epochs[args.interval - 1::args.interval]:
iters = log_dict[epoch]['iter'] iters = log_dict[epoch]['iter']
if log_dict[epoch]['mode'][-1] == 'val': if log_dict[epoch]['mode'][-1] == 'val':
iters = iters[:-1] iters = iters[:-1]
...@@ -114,6 +115,8 @@ def add_plot_parser(subparsers): ...@@ -114,6 +115,8 @@ def add_plot_parser(subparsers):
parser_plt.add_argument( parser_plt.add_argument(
'--style', type=str, default='dark', help='style of plt') '--style', type=str, default='dark', help='style of plt')
parser_plt.add_argument('--out', type=str, default=None) parser_plt.add_argument('--out', type=str, default=None)
parser_plt.add_argument('--mode', type=str, default='train')
parser_plt.add_argument('--interval', type=int, default=1)
def add_time_parser(subparsers): def add_time_parser(subparsers):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment