Unverified Commit 4497eb59 authored by Shaoshuai Shi's avatar Shaoshuai Shi Committed by GitHub
Browse files

abstract load_data_to_gpu function in train/test (#116)

* abstract load_data_to_gpu function in train/test
parent be0507ce
......@@ -144,6 +144,7 @@ class DatasetTemplate(torch_data.Dataset):
ret = {}
for key, val in data_dict.items():
try:
if key in ['voxels', 'voxel_num_points']:
ret[key] = np.concatenate(val, axis=0)
elif key in ['points', 'voxel_coords']:
......@@ -160,6 +161,9 @@ class DatasetTemplate(torch_data.Dataset):
ret[key] = batch_gt_boxes3d
else:
ret[key] = np.stack(val, axis=0)
except:
print('Error in collate_batch: key=%s' % key)
raise TypeError
ret['batch_size'] = batch_size
return ret
......@@ -11,16 +11,20 @@ def build_network(model_cfg, num_class, dataset):
return model
def model_fn_decorator():
ModelReturn = namedtuple('ModelReturn', ['loss', 'tb_dict', 'disp_dict'])
def model_func(model, batch_dict):
def load_data_to_gpu(batch_dict):
for key, val in batch_dict.items():
if not isinstance(val, np.ndarray):
continue
if key in ['frame_id']:
if key in ['frame_id', 'metadata', 'calib', 'image_shape']:
continue
batch_dict[key] = torch.from_numpy(val).float().cuda()
def model_fn_decorator():
ModelReturn = namedtuple('ModelReturn', ['loss', 'tb_dict', 'disp_dict'])
def model_func(model, batch_dict):
load_data_to_gpu(batch_dict)
ret_dict, tb_dict, disp_dict = model(batch_dict)
loss = ret_dict['loss'].mean()
......
......@@ -4,6 +4,7 @@ import pickle
import numpy as np
import torch
from pcdet.utils import common_utils
from pcdet.models import load_data_to_gpu
def statistics_info(cfg, ret_dict, metric, disp_dict):
......@@ -51,13 +52,7 @@ def eval_one_epoch(cfg, model, dataloader, epoch_id, logger, dist_test=False, sa
progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval', dynamic_ncols=True)
start_time = time.time()
for i, batch_dict in enumerate(dataloader):
for key, val in batch_dict.items():
if not isinstance(val, np.ndarray):
continue
if key in ['frame_id', 'calib', 'image_shape']:
continue
batch_dict[key] = torch.from_numpy(val).float().cuda()
load_data_to_gpu(batch_dict)
with torch.no_grad():
pred_dicts, ret_dict = model(batch_dict)
disp_dict = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment