Unverified commit aa753ec0, authored by Shaoshuai Shi, committed by GitHub

Update to OpenPCDet v0.6 #1087

 Merge pull request #1087 from sshaoshuai/dev_v0.6
parents 7e8bbe26 beb249e5
CLASS_NAMES: ['Vehicle', 'Pedestrian', 'Cyclist']

DATA_CONFIG:
    _BASE_CONFIG_: cfgs/dataset_configs/waymo_dataset_multiframe.yaml
    SEQUENCE_CONFIG:
        ENABLED: True
        SAMPLE_OFFSET: [-1, 0]

    TRAIN_WITH_SPEED: False

    DATA_AUGMENTOR:
        DISABLE_AUG_LIST: ['placeholder']
        AUG_CONFIG_LIST:
            - NAME: gt_sampling
              USE_ROAD_PLANE: False
              DB_INFO_PATH:
                  - waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1_multiframe_-4_to_0.pkl
              USE_SHARED_MEMORY: False  # set it to True to speed up (it costs about 50 GB of shared memory)
              DB_DATA_PATH:
                  - waymo_processed_data_v0_5_0_gt_database_train_sampled_1_multiframe_-4_to_0_global.npy

              PREPARE: {
                  filter_by_min_points: ['Vehicle:5', 'Pedestrian:5', 'Cyclist:5'],
                  filter_by_difficulty: [-1],
              }

              SAMPLE_GROUPS: ['Vehicle:15', 'Pedestrian:10', 'Cyclist:10']
              NUM_POINT_FEATURES: 6
              REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
              LIMIT_WHOLE_SCENE: True
              FILTER_OBJ_POINTS_BY_TIMESTAMP: True
              TIME_RANGE: [0.1, 0.0]  # 0.1s-0.0s indicates 2 frames

            - NAME: random_world_flip
              ALONG_AXIS_LIST: ['x', 'y']

            - NAME: random_world_rotation
              WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]

            - NAME: random_world_scaling
              WORLD_SCALE_RANGE: [0.95, 1.05]

MODEL:
    NAME: PVRCNNPlusPlus

    VFE:
        NAME: MeanVFE

    BACKBONE_3D:
        NAME: VoxelResBackBone8x

    MAP_TO_BEV:
        NAME: HeightCompression
        NUM_BEV_FEATURES: 256

    BACKBONE_2D:
        NAME: BaseBEVBackbone
        LAYER_NUMS: [5, 5]
        LAYER_STRIDES: [1, 2]
        NUM_FILTERS: [128, 256]
        UPSAMPLE_STRIDES: [1, 2]
        NUM_UPSAMPLE_FILTERS: [256, 256]

    DENSE_HEAD:
        NAME: CenterHead
        CLASS_AGNOSTIC: False

        CLASS_NAMES_EACH_HEAD: [
            ['Vehicle', 'Pedestrian', 'Cyclist']
        ]

        SHARED_CONV_CHANNEL: 64
        USE_BIAS_BEFORE_NORM: True
        NUM_HM_CONV: 2
        SEPARATE_HEAD_CFG:
            HEAD_ORDER: ['center', 'center_z', 'dim', 'rot']
            HEAD_DICT: {
                'center': {'out_channels': 2, 'num_conv': 2},
                'center_z': {'out_channels': 1, 'num_conv': 2},
                'dim': {'out_channels': 3, 'num_conv': 2},
                'rot': {'out_channels': 2, 'num_conv': 2},
            }

        TARGET_ASSIGNER_CONFIG:
            FEATURE_MAP_STRIDE: 8
            NUM_MAX_OBJS: 500
            GAUSSIAN_OVERLAP: 0.1
            MIN_RADIUS: 2

        LOSS_CONFIG:
            LOSS_WEIGHTS: {
                'cls_weight': 1.0,
                'loc_weight': 2.0,
                'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
            }

        POST_PROCESSING:
            SCORE_THRESH: 0.1
            POST_CENTER_LIMIT_RANGE: [-75.2, -75.2, -2, 75.2, 75.2, 4]
            MAX_OBJ_PER_SAMPLE: 500
            NMS_CONFIG:
                NMS_TYPE: nms_gpu
                NMS_THRESH: 0.7
                NMS_PRE_MAXSIZE: 4096
                NMS_POST_MAXSIZE: 500

    PFE:
        NAME: VoxelSetAbstraction
        POINT_SOURCE: raw_points
        NUM_KEYPOINTS: 4096
        NUM_OUTPUT_FEATURES: 90
        SAMPLE_METHOD: SPC
        SPC_SAMPLING:
            NUM_SECTORS: 6
            SAMPLE_RADIUS_WITH_ROI: 1.6

        FEATURES_SOURCE: ['bev', 'x_conv3', 'x_conv4', 'raw_points']
        SA_LAYER:
            raw_points:
                NAME: VectorPoolAggregationModuleMSG
                NUM_GROUPS: 2
                LOCAL_AGGREGATION_TYPE: local_interpolation
                NUM_REDUCED_CHANNELS: 2
                NUM_CHANNELS_OF_LOCAL_AGGREGATION: 32
                MSG_POST_MLPS: [32]
                FILTER_NEIGHBOR_WITH_ROI: True
                RADIUS_OF_NEIGHBOR_WITH_ROI: 2.4

                GROUP_CFG_0:
                    NUM_LOCAL_VOXEL: [2, 2, 2]
                    MAX_NEIGHBOR_DISTANCE: 0.2
                    NEIGHBOR_NSAMPLE: -1
                    POST_MLPS: [32, 32]
                GROUP_CFG_1:
                    NUM_LOCAL_VOXEL: [3, 3, 3]
                    MAX_NEIGHBOR_DISTANCE: 0.4
                    NEIGHBOR_NSAMPLE: -1
                    POST_MLPS: [32, 32]

            x_conv3:
                DOWNSAMPLE_FACTOR: 4
                INPUT_CHANNELS: 64
                NAME: VectorPoolAggregationModuleMSG
                NUM_GROUPS: 2
                LOCAL_AGGREGATION_TYPE: local_interpolation
                NUM_REDUCED_CHANNELS: 32
                NUM_CHANNELS_OF_LOCAL_AGGREGATION: 32
                MSG_POST_MLPS: [128]
                FILTER_NEIGHBOR_WITH_ROI: True
                RADIUS_OF_NEIGHBOR_WITH_ROI: 4.0

                GROUP_CFG_0:
                    NUM_LOCAL_VOXEL: [3, 3, 3]
                    MAX_NEIGHBOR_DISTANCE: 1.2
                    NEIGHBOR_NSAMPLE: -1
                    POST_MLPS: [64, 64]
                GROUP_CFG_1:
                    NUM_LOCAL_VOXEL: [3, 3, 3]
                    MAX_NEIGHBOR_DISTANCE: 2.4
                    NEIGHBOR_NSAMPLE: -1
                    POST_MLPS: [64, 64]

            x_conv4:
                DOWNSAMPLE_FACTOR: 8
                INPUT_CHANNELS: 64
                NAME: VectorPoolAggregationModuleMSG
                NUM_GROUPS: 2
                LOCAL_AGGREGATION_TYPE: local_interpolation
                NUM_REDUCED_CHANNELS: 32
                NUM_CHANNELS_OF_LOCAL_AGGREGATION: 32
                MSG_POST_MLPS: [128]
                FILTER_NEIGHBOR_WITH_ROI: True
                RADIUS_OF_NEIGHBOR_WITH_ROI: 6.4

                GROUP_CFG_0:
                    NUM_LOCAL_VOXEL: [3, 3, 3]
                    MAX_NEIGHBOR_DISTANCE: 2.4
                    NEIGHBOR_NSAMPLE: -1
                    POST_MLPS: [64, 64]
                GROUP_CFG_1:
                    NUM_LOCAL_VOXEL: [3, 3, 3]
                    MAX_NEIGHBOR_DISTANCE: 4.8
                    NEIGHBOR_NSAMPLE: -1
                    POST_MLPS: [64, 64]

    POINT_HEAD:
        NAME: PointHeadSimple
        CLS_FC: [256, 256]
        CLASS_AGNOSTIC: True
        USE_POINT_FEATURES_BEFORE_FUSION: True
        TARGET_CONFIG:
            GT_EXTRA_WIDTH: [0.2, 0.2, 0.2]
        LOSS_CONFIG:
            LOSS_REG: smooth-l1
            LOSS_WEIGHTS: {
                'point_cls_weight': 1.0,
            }

    ROI_HEAD:
        NAME: PVRCNNHead
        CLASS_AGNOSTIC: True

        SHARED_FC: [256, 256]
        CLS_FC: [256, 256]
        REG_FC: [256, 256]
        DP_RATIO: 0.3

        NMS_CONFIG:
            TRAIN:
                NMS_TYPE: nms_gpu
                MULTI_CLASSES_NMS: False
                NMS_PRE_MAXSIZE: 9000
                NMS_POST_MAXSIZE: 512
                NMS_THRESH: 0.8
            TEST:
                NMS_TYPE: nms_gpu
                MULTI_CLASSES_NMS: False
                NMS_PRE_MAXSIZE: 1024
                NMS_POST_MAXSIZE: 100
                NMS_THRESH: 0.7
                SCORE_THRESH: 0.1

                # NMS_PRE_MAXSIZE: 4096
                # NMS_POST_MAXSIZE: 500
                # NMS_THRESH: 0.85

        ROI_GRID_POOL:
            GRID_SIZE: 6

            NAME: VectorPoolAggregationModuleMSG
            NUM_GROUPS: 2
            LOCAL_AGGREGATION_TYPE: voxel_random_choice
            NUM_REDUCED_CHANNELS: 30
            NUM_CHANNELS_OF_LOCAL_AGGREGATION: 32
            MSG_POST_MLPS: [128]

            GROUP_CFG_0:
                NUM_LOCAL_VOXEL: [3, 3, 3]
                MAX_NEIGHBOR_DISTANCE: 0.8
                NEIGHBOR_NSAMPLE: 32
                POST_MLPS: [64, 64]
            GROUP_CFG_1:
                NUM_LOCAL_VOXEL: [3, 3, 3]
                MAX_NEIGHBOR_DISTANCE: 1.6
                NEIGHBOR_NSAMPLE: 32
                POST_MLPS: [64, 64]

        TARGET_CONFIG:
            BOX_CODER: ResidualCoder
            ROI_PER_IMAGE: 128
            FG_RATIO: 0.5

            SAMPLE_ROI_BY_EACH_CLASS: True
            CLS_SCORE_TYPE: roi_iou

            CLS_FG_THRESH: 0.75
            CLS_BG_THRESH: 0.25
            CLS_BG_THRESH_LO: 0.1
            HARD_BG_RATIO: 0.8

            REG_FG_THRESH: 0.55

        LOSS_CONFIG:
            CLS_LOSS: BinaryCrossEntropy
            REG_LOSS: smooth-l1
            CORNER_LOSS_REGULARIZATION: True
            LOSS_WEIGHTS: {
                'rcnn_cls_weight': 1.0,
                'rcnn_reg_weight': 1.0,
                'rcnn_corner_weight': 1.0,
                'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
            }

    POST_PROCESSING:
        RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
        SCORE_THRESH: 0.1
        OUTPUT_RAW_SCORE: False

        EVAL_METRIC: waymo

        NMS_CONFIG:
            MULTI_CLASSES_NMS: False
            NMS_TYPE: nms_gpu
            NMS_THRESH: 0.7
            NMS_PRE_MAXSIZE: 4096
            NMS_POST_MAXSIZE: 500

OPTIMIZATION:
    BATCH_SIZE_PER_GPU: 2
    NUM_EPOCHS: 30

    OPTIMIZER: adam_onecycle
    LR: 0.01
    WEIGHT_DECAY: 0.001
    MOMENTUM: 0.9

    MOMS: [0.95, 0.85]
    PCT_START: 0.4
    DIV_FACTOR: 10
    DECAY_STEP_LIST: [35, 45]
    LR_DECAY: 0.1
    LR_CLIP: 0.0000001

    LR_WARMUP: False
    WARMUP_EPOCH: 1

    GRAD_NORM_CLIP: 10
\ No newline at end of file
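The file above is a newly added multiframe PV-RCNN++ config for Waymo. As a quick sanity check it can be parsed standalone with PyYAML (OpenPCDet itself loads it through cfg_from_yaml_file, which also merges in the _BASE_CONFIG_ dataset config); this is a minimal sketch, and the local filename is a hypothetical placeholder:

# Minimal sketch: parse the new config with PyYAML and inspect a few fields.
# 'pv_rcnn_plusplus_multiframe.yaml' is a hypothetical filename for the file above.
import yaml

with open('pv_rcnn_plusplus_multiframe.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['MODEL']['NAME'])                    # PVRCNNPlusPlus
print(cfg['DATA_CONFIG']['SEQUENCE_CONFIG'])   # {'ENABLED': True, 'SAMPLE_OFFSET': [-1, 0]}
print(cfg['OPTIMIZATION']['NUM_EPOCHS'])       # 30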
@@ -3,6 +3,7 @@ import pickle as pkl
 from pathlib import Path
 import tqdm
 import copy
+import os


 def create_integrated_db_with_infos(args, root_path):
@@ -13,8 +14,8 @@ def create_integrated_db_with_infos(args, root_path):
     """
     # prepare
-    db_infos_path = root_path / args.src_db_info
-    db_info_global_path = str(db_infos_path)[:-4] + '_global' + '.pkl'
+    db_infos_path = args.src_db_info
+    db_info_global_path = db_infos_path
     global_db_path = root_path / (args.new_db_name + '.npy')

     db_infos = pkl.load(open(db_infos_path, 'rb'))
@@ -29,6 +30,12 @@ def create_integrated_db_with_infos(args, root_path):
         obj_points = np.fromfile(str(obj_path), dtype=np.float32).reshape(
             [-1, args.num_point_features])
         num_points = obj_points.shape[0]
+        if num_points != info['num_points_in_gt']:
+            obj_points = np.fromfile(str(obj_path), dtype=np.float64).reshape([-1, args.num_point_features])
+            num_points = obj_points.shape[0]
+            obj_points = obj_points.astype(np.float32)
+            assert num_points == info['num_points_in_gt']
+
         db_info_global[category][idx]['global_data_offset'] = (start_idx, start_idx + num_points)
         start_idx += num_points
         global_db_list.append(obj_points)
@@ -65,17 +72,14 @@ if __name__ == '__main__':
     import argparse

     parser = argparse.ArgumentParser(description='arg parser')
-    parser.add_argument('--root_path', type=str, default=None, help='specify the root path')
-    parser.add_argument('--src_db_info', type=str, default='waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1.pkl', help='')
-    parser.add_argument('--new_db_name', type=str, default='waymo_processed_data_v0_5_0_gt_database_train_sampled_1_global', help='')
-    parser.add_argument('--num_point_features', type=int, default=5,
-                        help='number of feature channels for points')
-    parser.add_argument('--class_name', type=str, default='Vehicle',
-                        help='category name for verification')
+    parser.add_argument('--src_db_info', type=str, default='../../data/waymo/waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1_multiframe_-4_to_0_tail_parallel.pkl', help='')
+    parser.add_argument('--new_db_name', type=str, default='waymo_processed_data_v0_5_0_gt_database_train_sampled_1_multiframe_-4_to_0_tail_parallel_global', help='')
+    parser.add_argument('--num_point_features', type=int, default=6, help='number of feature channels for points')
+    parser.add_argument('--class_name', type=str, default='Vehicle', help='category name for verification')

     args = parser.parse_args()
-    root_path = Path(args.root_path)
+    root_path = Path(os.path.dirname(args.src_db_info))

     db_infos_global, whole_db = create_integrated_db_with_infos(args, root_path)

     # simple verify
...
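The core of this script is the global_data_offset bookkeeping: every object's points are appended to one flat array saved as a single .npy file, and each per-object info dict records its (start, end) row range (the new float64 fallback re-reads objects whose on-disk dtype does not match the expected float32 layout). A minimal sketch of how a consumer could slice one object back out, assuming the array was written with np.save; file names here are hypothetical placeholders:

# Sketch (assumed usage): recover one GT object's points from the integrated
# database via the (start, end) offsets written by the loop above.
import numpy as np
import pickle as pkl

whole_db = np.load('gt_database_global.npy')           # flat (N, num_point_features) array
db_infos = pkl.load(open('dbinfos_global.pkl', 'rb'))

info = db_infos['Vehicle'][0]
start, end = info['global_data_offset']
obj_points = whole_db[start:end]                       # this object's points only
assert obj_points.shape[0] == info['num_points_in_gt']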
@@ -44,6 +44,12 @@ def parse_config():
     parser.add_argument('--num_epochs_to_eval', type=int, default=0, help='number of checkpoints to be evaluated')
     parser.add_argument('--save_to_file', action='store_true', default=False, help='')
+    parser.add_argument('--use_tqdm_to_record', action='store_true', default=False, help='if True, the intermediate losses will not be logged to file, only tqdm will be used')
+    parser.add_argument('--logger_iter_interval', type=int, default=50, help='')
+    parser.add_argument('--ckpt_save_time_interval', type=int, default=300, help='in terms of seconds')
+    parser.add_argument('--wo_gpu_stat', action='store_true', help='')

     args = parser.parse_args()

     cfg_from_yaml_file(args.cfg_file, cfg)
@@ -131,13 +137,19 @@ def main():
         it, start_epoch = model.load_params_with_optimizer(args.ckpt, to_cpu=dist_train, optimizer=optimizer, logger=logger)
         last_epoch = start_epoch + 1
     else:
-        ckpt_list = glob.glob(str(ckpt_dir / '*checkpoint_epoch_*.pth'))
+        ckpt_list = glob.glob(str(ckpt_dir / '*.pth'))
         if len(ckpt_list) > 0:
             ckpt_list.sort(key=os.path.getmtime)
-            it, start_epoch = model.load_params_with_optimizer(
-                ckpt_list[-1], to_cpu=dist_train, optimizer=optimizer, logger=logger
-            )
-            last_epoch = start_epoch + 1
+            while len(ckpt_list) > 0:
+                try:
+                    it, start_epoch = model.load_params_with_optimizer(
+                        ckpt_list[-1], to_cpu=dist_train, optimizer=optimizer, logger=logger
+                    )
+                    last_epoch = start_epoch + 1
+                    break
+                except:
+                    ckpt_list = ckpt_list[:-1]

     model.train()  # before wrap to DistributedDataParallel to support fixed some parameters
     if dist_train:
@@ -152,6 +164,7 @@ def main():
     # -----------------------start training---------------------------
     logger.info('**********************Start training %s/%s(%s)**********************'
                 % (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))
+
     train_model(
         model,
         optimizer,
@@ -169,7 +182,12 @@ def main():
         lr_warmup_scheduler=lr_warmup_scheduler,
         ckpt_save_interval=args.ckpt_save_interval,
         max_ckpt_save_num=args.max_ckpt_save_num,
-        merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch
+        merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
+        logger=logger,
+        logger_iter_interval=args.logger_iter_interval,
+        ckpt_save_time_interval=args.ckpt_save_time_interval,
+        use_logger_to_record=not args.use_tqdm_to_record,
+        show_gpu_stat=not args.wo_gpu_stat
     )

     if hasattr(train_set, 'use_shared_memory') and train_set.use_shared_memory:
...
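The resume logic in train.py is now robust to a corrupted newest checkpoint (for example one truncated when a job was killed mid-save): it walks backwards through the mtime-sorted *.pth list until one loads. The same pattern in isolation, as a sketch; find_loadable_checkpoint is an illustrative helper, not an OpenPCDet API:

# Sketch of the fallback-resume pattern introduced above.
import glob
import os

def find_loadable_checkpoint(ckpt_dir, try_load):
    """Return (result, path) for the newest checkpoint that try_load accepts.

    try_load is any callable that raises on a corrupted/incomplete file.
    """
    ckpt_list = sorted(glob.glob(os.path.join(ckpt_dir, '*.pth')), key=os.path.getmtime)
    while ckpt_list:
        try:
            return try_load(ckpt_list[-1]), ckpt_list[-1]
        except Exception:   # unreadable checkpoint: drop it, try the next-newest
            ckpt_list.pop()
    return None, None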
@@ -9,10 +9,15 @@ from pcdet.utils import common_utils, commu_utils


 def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, accumulated_iter, optim_cfg,
-                    rank, tbar, total_it_each_epoch, dataloader_iter, tb_log=None, leave_pbar=False):
+                    rank, tbar, total_it_each_epoch, dataloader_iter, tb_log=None, leave_pbar=False,
+                    use_logger_to_record=False, logger=None, logger_iter_interval=50, cur_epoch=None,
+                    total_epochs=None, ckpt_save_dir=None, ckpt_save_time_interval=300, show_gpu_stat=False):
     if total_it_each_epoch == len(train_loader):
         dataloader_iter = iter(train_loader)

+    ckpt_save_cnt = 1
+    start_it = accumulated_iter % total_it_each_epoch
     if rank == 0:
         pbar = tqdm.tqdm(total=total_it_each_epoch, leave=leave_pbar, desc='train', dynamic_ncols=True)
         data_time = common_utils.AverageMeter()
@@ -20,7 +25,7 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
         forward_time = common_utils.AverageMeter()

     end = time.time()
-    for cur_it in range(total_it_each_epoch):
+    for cur_it in range(start_it, total_it_each_epoch):
         try:
             batch = next(dataloader_iter)
         except StopIteration:
@@ -66,21 +71,54 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
             data_time.update(avg_data_time)
             forward_time.update(avg_forward_time)
             batch_time.update(avg_batch_time)
+
             disp_dict.update({
                 'loss': loss.item(), 'lr': cur_lr, 'd_time': f'{data_time.val:.2f}({data_time.avg:.2f})',
                 'f_time': f'{forward_time.val:.2f}({forward_time.avg:.2f})', 'b_time': f'{batch_time.val:.2f}({batch_time.avg:.2f})'
             })

-            pbar.update()
-            pbar.set_postfix(dict(total_it=accumulated_iter))
-            tbar.set_postfix(disp_dict)
-            tbar.refresh()
+            if use_logger_to_record:
+                if accumulated_iter % logger_iter_interval == 0 or cur_it == start_it or cur_it + 1 == total_it_each_epoch:
+                    trained_time_past_all = tbar.format_dict['elapsed']
+                    second_each_iter = pbar.format_dict['elapsed'] / max(cur_it - start_it + 1, 1.0)
+
+                    trained_time_each_epoch = pbar.format_dict['elapsed']
+                    remaining_second_each_epoch = second_each_iter * (total_it_each_epoch - cur_it)
+                    remaining_second_all = second_each_iter * ((total_epochs - cur_epoch) * total_it_each_epoch - cur_it)
+
+                    disp_str = ', '.join([f'{key}={val}' for key, val in disp_dict.items() if key != 'lr'])
+                    disp_str += f', lr={disp_dict["lr"]}'
+                    batch_size = batch.get('batch_size', None)
+                    logger.info(f'epoch: {cur_epoch}/{total_epochs}, acc_iter={accumulated_iter}, cur_iter={cur_it}/{total_it_each_epoch}, batch_size={batch_size}, '
+                                f'time_cost(epoch): {tbar.format_interval(trained_time_each_epoch)}/{tbar.format_interval(remaining_second_each_epoch)}, '
+                                f'time_cost(all): {tbar.format_interval(trained_time_past_all)}/{tbar.format_interval(remaining_second_all)}, '
+                                f'{disp_str}')
+
+                    if show_gpu_stat and accumulated_iter % (3 * logger_iter_interval) == 0:
+                        # To show the GPU utilization, please install gpustat through "pip install gpustat"
+                        gpu_info = os.popen('gpustat').read()
+                        logger.info(gpu_info)
+            else:
+                pbar.update()
+                pbar.set_postfix(dict(total_it=accumulated_iter))
+                tbar.set_postfix(disp_dict)
+                # tbar.refresh()

             if tb_log is not None:
                 tb_log.add_scalar('train/loss', loss, accumulated_iter)
                 tb_log.add_scalar('meta_data/learning_rate', cur_lr, accumulated_iter)
                 for key, val in tb_dict.items():
                     tb_log.add_scalar('train/' + key, val, accumulated_iter)

+            # save intermediate ckpt every {ckpt_save_time_interval} seconds
+            time_past_this_epoch = pbar.format_dict['elapsed']
+            if time_past_this_epoch // ckpt_save_time_interval >= ckpt_save_cnt:
+                ckpt_name = ckpt_save_dir / 'latest_model'
+                save_checkpoint(
+                    checkpoint_state(model, optimizer, cur_epoch, accumulated_iter), filename=ckpt_name,
+                )
+                logger.info(f'Save latest model to {ckpt_name}')
+                ckpt_save_cnt += 1

     if rank == 0:
         pbar.close()
     return accumulated_iter

@@ -89,7 +127,8 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
 def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_cfg,
                 start_epoch, total_epochs, start_iter, rank, tb_log, ckpt_save_dir, train_sampler=None,
                 lr_warmup_scheduler=None, ckpt_save_interval=1, max_ckpt_save_num=50,
-                merge_all_iters_to_one_epoch=False):
+                merge_all_iters_to_one_epoch=False,
+                use_logger_to_record=False, logger=None, logger_iter_interval=None, ckpt_save_time_interval=None, show_gpu_stat=False):
     accumulated_iter = start_iter
     with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True, leave=(rank == 0)) as tbar:
         total_it_each_epoch = len(train_loader)
@@ -115,7 +154,13 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_
                 rank=rank, tbar=tbar, tb_log=tb_log,
                 leave_pbar=(cur_epoch + 1 == total_epochs),
                 total_it_each_epoch=total_it_each_epoch,
-                dataloader_iter=dataloader_iter
+                dataloader_iter=dataloader_iter,
+                cur_epoch=cur_epoch, total_epochs=total_epochs,
+                use_logger_to_record=use_logger_to_record,
+                logger=logger, logger_iter_interval=logger_iter_interval,
+                ckpt_save_dir=ckpt_save_dir, ckpt_save_time_interval=ckpt_save_time_interval,
+                show_gpu_stat=show_gpu_stat
             )

             # save trained model
...
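Two of the train_utils changes work together to make long Waymo epochs restartable: 'latest_model' is rewritten every ckpt_save_time_interval seconds inside the epoch, and start_it = accumulated_iter % total_it_each_epoch lets a resumed run skip the iterations already finished in the interrupted epoch. A self-contained sketch of the timing half; TimedCheckpointer is illustrative, not part of OpenPCDet:

# Sketch of the "save latest_model every N seconds" logic above.
import time

class TimedCheckpointer:
    def __init__(self, interval_s=300):
        self.interval_s = interval_s       # mirrors ckpt_save_time_interval
        self.start = time.time()
        self.save_cnt = 1                  # mirrors ckpt_save_cnt in the diff

    def maybe_save(self, save_fn):
        # Save once each time another full interval has elapsed, even if a
        # single iteration took longer than one interval.
        elapsed = time.time() - self.start
        if elapsed // self.interval_s >= self.save_cnt:
            save_fn()                      # e.g. save_checkpoint(checkpoint_state(...))
            self.save_cnt += 1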