Unverified Commit 4497eb59 authored by Shaoshuai Shi's avatar Shaoshuai Shi Committed by GitHub
Browse files

abstract load_data_to_gpu function in train/test (#116)

* abstract load_data_to_gpu function in train/test
parent be0507ce
...@@ -144,22 +144,26 @@ class DatasetTemplate(torch_data.Dataset): ...@@ -144,22 +144,26 @@ class DatasetTemplate(torch_data.Dataset):
ret = {} ret = {}
for key, val in data_dict.items(): for key, val in data_dict.items():
if key in ['voxels', 'voxel_num_points']: try:
ret[key] = np.concatenate(val, axis=0) if key in ['voxels', 'voxel_num_points']:
elif key in ['points', 'voxel_coords']: ret[key] = np.concatenate(val, axis=0)
coors = [] elif key in ['points', 'voxel_coords']:
for i, coor in enumerate(val): coors = []
coor_pad = np.pad(coor, ((0, 0), (1, 0)), mode='constant', constant_values=i) for i, coor in enumerate(val):
coors.append(coor_pad) coor_pad = np.pad(coor, ((0, 0), (1, 0)), mode='constant', constant_values=i)
ret[key] = np.concatenate(coors, axis=0) coors.append(coor_pad)
elif key in ['gt_boxes']: ret[key] = np.concatenate(coors, axis=0)
max_gt = max([len(x) for x in val]) elif key in ['gt_boxes']:
batch_gt_boxes3d = np.zeros((batch_size, max_gt, val[0].shape[-1]), dtype=np.float32) max_gt = max([len(x) for x in val])
for k in range(batch_size): batch_gt_boxes3d = np.zeros((batch_size, max_gt, val[0].shape[-1]), dtype=np.float32)
batch_gt_boxes3d[k, :val[k].__len__(), :] = val[k] for k in range(batch_size):
ret[key] = batch_gt_boxes3d batch_gt_boxes3d[k, :val[k].__len__(), :] = val[k]
else: ret[key] = batch_gt_boxes3d
ret[key] = np.stack(val, axis=0) else:
ret[key] = np.stack(val, axis=0)
except:
print('Error in collate_batch: key=%s' % key)
raise TypeError
ret['batch_size'] = batch_size ret['batch_size'] = batch_size
return ret return ret
...@@ -11,16 +11,20 @@ def build_network(model_cfg, num_class, dataset): ...@@ -11,16 +11,20 @@ def build_network(model_cfg, num_class, dataset):
return model return model
def load_data_to_gpu(batch_dict):
for key, val in batch_dict.items():
if not isinstance(val, np.ndarray):
continue
if key in ['frame_id', 'metadata', 'calib', 'image_shape']:
continue
batch_dict[key] = torch.from_numpy(val).float().cuda()
def model_fn_decorator(): def model_fn_decorator():
ModelReturn = namedtuple('ModelReturn', ['loss', 'tb_dict', 'disp_dict']) ModelReturn = namedtuple('ModelReturn', ['loss', 'tb_dict', 'disp_dict'])
def model_func(model, batch_dict): def model_func(model, batch_dict):
for key, val in batch_dict.items(): load_data_to_gpu(batch_dict)
if not isinstance(val, np.ndarray):
continue
if key in ['frame_id']:
continue
batch_dict[key] = torch.from_numpy(val).float().cuda()
ret_dict, tb_dict, disp_dict = model(batch_dict) ret_dict, tb_dict, disp_dict = model(batch_dict)
loss = ret_dict['loss'].mean() loss = ret_dict['loss'].mean()
......
...@@ -4,6 +4,7 @@ import pickle ...@@ -4,6 +4,7 @@ import pickle
import numpy as np import numpy as np
import torch import torch
from pcdet.utils import common_utils from pcdet.utils import common_utils
from pcdet.models import load_data_to_gpu
def statistics_info(cfg, ret_dict, metric, disp_dict): def statistics_info(cfg, ret_dict, metric, disp_dict):
...@@ -51,13 +52,7 @@ def eval_one_epoch(cfg, model, dataloader, epoch_id, logger, dist_test=False, sa ...@@ -51,13 +52,7 @@ def eval_one_epoch(cfg, model, dataloader, epoch_id, logger, dist_test=False, sa
progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval', dynamic_ncols=True) progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval', dynamic_ncols=True)
start_time = time.time() start_time = time.time()
for i, batch_dict in enumerate(dataloader): for i, batch_dict in enumerate(dataloader):
for key, val in batch_dict.items(): load_data_to_gpu(batch_dict)
if not isinstance(val, np.ndarray):
continue
if key in ['frame_id', 'calib', 'image_shape']:
continue
batch_dict[key] = torch.from_numpy(val).float().cuda()
with torch.no_grad(): with torch.no_grad():
pred_dicts, ret_dict = model(batch_dict) pred_dicts, ret_dict = model(batch_dict)
disp_dict = {} disp_dict = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment