Commit ed6f3dd2 authored by Shaoshuai Shi

Support merge_all_iters_to_one_epoch to avoid a deadlock when switching epochs; evaluate the last 10 epochs after training.
parent 9fdb4435
@@ -34,6 +34,8 @@ class DatasetTemplate(torch_data.Dataset):
         self.grid_size = self.data_processor.grid_size
         self.voxel_size = self.data_processor.voxel_size
+        self.total_epochs = 0
+        self._merge_all_iters_to_one_epoch = False

     @property
     def mode(self):
@@ -65,6 +67,13 @@ class DatasetTemplate(torch_data.Dataset):
         """
+    def merge_all_iters_to_one_epoch(self, merge=True, epochs=None):
+        if merge:
+            self._merge_all_iters_to_one_epoch = True
+            self.total_epochs = epochs
+        else:
+            self._merge_all_iters_to_one_epoch = False
+
     def __len__(self):
         raise NotImplementedError
...
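Why this helps: restarting the DataLoader iterator at every epoch boundary tears down and re-forks its worker processes each time, which is where the hang the commit message refers to can occur; merging all iterations into one long "epoch" spawns the workers once for the whole run. A minimal, self-contained sketch of the dataset-side pattern these two attributes and the setter implement (ToyDataset is hypothetical, not part of this commit; the KittiDataset hunk below applies the same idea to kitti_infos):

# Minimal sketch of the merged-epoch pattern (hypothetical ToyDataset,
# mirroring the DatasetTemplate additions above).
from torch.utils.data import Dataset

class ToyDataset(Dataset):
    def __init__(self, samples):
        self.samples = samples
        self.total_epochs = 0
        self._merge_all_iters_to_one_epoch = False

    def merge_all_iters_to_one_epoch(self, merge=True, epochs=None):
        # With merge=True, one DataLoader "epoch" spans `epochs` passes over
        # the data, so workers are never recreated at an epoch boundary.
        if merge:
            self._merge_all_iters_to_one_epoch = True
            self.total_epochs = epochs
        else:
            self._merge_all_iters_to_one_epoch = False

    def __len__(self):
        if self._merge_all_iters_to_one_epoch:
            return len(self.samples) * self.total_epochs
        return len(self.samples)

    def __getitem__(self, index):
        if self._merge_all_iters_to_one_epoch:
            index = index % len(self.samples)  # wrap back into the real range
        return self.samples[index]

dataset = ToyDataset(list(range(4)))
dataset.merge_all_iters_to_one_epoch(merge=True, epochs=3)
assert len(dataset) == 12   # one "epoch" now spans 3 real passes
assert dataset[9] == 1      # 9 % 4 == 1: the index wraps around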
@@ -330,10 +330,16 @@ class KittiDataset(DatasetTemplate):
         return ap_result_str, ap_dict

     def __len__(self):
+        if self._merge_all_iters_to_one_epoch:
+            return len(self.kitti_infos) * self.total_epochs
+
         return len(self.kitti_infos)

     def __getitem__(self, index):
         # index = 4
+        if self._merge_all_iters_to_one_epoch:
+            index = index % len(self.kitti_infos)
+
         info = copy.deepcopy(self.kitti_infos[index])
         sample_idx = info['point_cloud']['lidar_idx']
...
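A quick sanity check of the numbers, assuming the standard KITTI train split of 3712 infos and an 80-epoch run (both values are assumptions, not stated in the diff):

# Merged length and index wrap-around (3712 and 80 are assumed values).
num_infos, epochs = 3712, 80
assert num_infos * epochs == 296960   # the merged __len__
assert 150000 % num_infos == 1520    # global iteration index -> real sample index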
@@ -96,8 +96,9 @@ def repeat_eval_ckpt(model, test_loader, args, eval_output_dir, logger, ckpt_dir
         cur_epoch_id, cur_ckpt = get_no_evaluated_ckpt(ckpt_dir, ckpt_record_file, args)
         if cur_epoch_id == -1 or int(float(cur_epoch_id)) < args.start_epoch:
             wait_second = 30
-            print('Wait %s seconds for next check (progress: %.1f / %d minutes): %s \r'
-                  % (wait_second, total_time * 1.0 / 60, args.max_waiting_mins, ckpt_dir), end='', flush=True)
+            if cfg.LOCAL_RANK == 0:
+                print('Wait %s seconds for next check (progress: %.1f / %d minutes): %s \r'
+                      % (wait_second, total_time * 1.0 / 60, args.max_waiting_mins, ckpt_dir), end='', flush=True)
             time.sleep(wait_second)
             total_time += 30
             if total_time > args.max_waiting_mins * 60 and (first_eval is False):
...
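The new cfg.LOCAL_RANK guard makes only rank 0 print the wait-progress line, so N DDP processes no longer interleave identical messages on the console. The same guard generalizes to any per-process side effect; a generic sketch (not code from this repository):

# Rank-0-only printing for torch.distributed jobs (a generic sketch).
import torch.distributed as dist

def rank0_print(*args, **kwargs):
    # Fall back to rank 0 when no process group is initialized (single-GPU run).
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    if rank == 0:
        print(*args, **kwargs)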
@@ -9,7 +9,7 @@ from pcdet.models import build_network, model_fn_decorator
 from train_utils.optimization import build_optimizer, build_scheduler
 from train_utils.train_utils import train_model
 import torch.distributed as dist
+from test import repeat_eval_ckpt
 from pathlib import Path
 import argparse
 import datetime
@@ -129,9 +129,8 @@ def main():
         model = nn.parallel.DistributedDataParallel(model, device_ids=[cfg.LOCAL_RANK % torch.cuda.device_count()])
     logger.info(model)

+    total_iters_each_epoch = len(train_loader) if not args.merge_all_iters_to_one_epoch else len(train_loader) // args.epochs
     lr_scheduler, lr_warmup_scheduler = build_scheduler(
-        optimizer, total_iters_each_epoch=len(train_loader), total_epochs=args.epochs,
+        optimizer, total_iters_each_epoch=total_iters_each_epoch, total_epochs=args.epochs,
         last_epoch=last_epoch, optim_cfg=cfg.OPTIMIZATION
     )
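The new total_iters_each_epoch expression compensates for the merged dataset length: with merging enabled, len(train_loader) already spans every epoch, so dividing by args.epochs recovers the per-epoch iteration count that build_scheduler needs to lay out the learning-rate schedule. A worked example with assumed numbers:

# Per-epoch iteration count when iterations are merged (all numbers are
# illustrative; drop_last effects ignored for simplicity).
num_samples, batch_size, epochs = 3712, 4, 80
iters_one_pass = num_samples // batch_size        # 928 iterations per real epoch
merged_loader_len = iters_one_pass * epochs       # 74240 = len(train_loader) when merged
assert merged_loader_len // epochs == iters_one_pass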
@@ -161,6 +160,26 @@ def main():
     logger.info('**********************End training %s/%s(%s)**********************\n\n\n'
                 % (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))

+    logger.info('**********************Start evaluation %s/%s(%s)**********************' %
+                (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))
+    test_set, test_loader, sampler = build_dataloader(
+        dataset_cfg=cfg.DATA_CONFIG,
+        class_names=cfg.CLASS_NAMES,
+        batch_size=args.batch_size,
+        dist=dist_train, workers=args.workers, logger=logger, training=False
+    )
+
+    eval_output_dir = output_dir / 'eval' / 'eval_with_train'
+    eval_output_dir.mkdir(parents=True, exist_ok=True)
+    args.start_epoch = max(args.epochs - 10, 0)  # only evaluate the last 10 epochs
+
+    repeat_eval_ckpt(
+        model.module if dist_train else model,
+        test_loader, args, eval_output_dir, logger, ckpt_dir,
+        dist_test=dist_train
+    )
+
+    logger.info('**********************End evaluation %s/%s(%s)**********************' %
+                (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))


 if __name__ == '__main__':
     main()
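For clarity, the args.start_epoch override feeds the filter shown earlier in repeat_eval_ckpt: any checkpoint whose epoch id is below start_epoch is skipped, so only the final checkpoints get evaluated. A small illustration (epoch numbering assumed 1-based; not from the diff):

# Which checkpoint epochs pass the start_epoch filter in repeat_eval_ckpt
# (illustrative; assumes checkpoints are numbered 1..epochs).
epochs = 80
start_epoch = max(epochs - 10, 0)                      # 70
kept = [e for e in range(1, epochs + 1) if e >= start_epoch]
print(kept)  # [70, 71, ..., 80] -- roughly the last 10 checkpoints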