Commit 41b18fd8 authored by zhe chen

Use pre-commit to reformat code


parent ff20ea39
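The change itself is mechanical: it re-runs the repository's formatting hooks over the tree. For reference, the usual pre-commit workflow looks like the sketch below; the concrete hook set (yapf, isort, flake8, end-of-file-fixer, and so on) lives in the repository's `.pre-commit-config.yaml`, which is not part of this diff, so the hook names are assumptions.

```bash
# Typical pre-commit workflow (sketch; the exact hooks are configured in
# .pre-commit-config.yaml and are not shown in this commit).
pip install pre-commit
pre-commit install          # register the git hook once per clone
pre-commit run --all-files  # apply all hooks to the whole tree, as this commit does
```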
@@ -6,14 +6,13 @@ from typing import List, Tuple, Union
import mmcv
import numpy as np
from mmdet3d.core.bbox import points_cam2img
from mmdet3d.datasets import NuScenesDataset
from nuscenes.nuscenes import NuScenes
from nuscenes.utils.geometry_utils import view_points
from pyquaternion import Quaternion
from shapely.geometry import MultiPoint, box
from mmdet3d.core.bbox import points_cam2img
from mmdet3d.datasets import NuScenesDataset
nus_categories = ('car', 'truck', 'trailer', 'bus', 'construction_vehicle',
                  'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone',
                  'barrier')
@@ -324,9 +323,9 @@ def obtain_sensor2top(nusc,
    l2e_r_s_mat = Quaternion(l2e_r_s).rotation_matrix
    e2g_r_s_mat = Quaternion(e2g_r_s).rotation_matrix
    R = (l2e_r_s_mat.T @ e2g_r_s_mat.T) @ (
        np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T)
    T = (l2e_t_s @ e2g_r_s_mat.T + e2g_t_s) @ (
        np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T)
    T -= e2g_t @ (np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T
                  ) + l2e_t @ np.linalg.inv(l2e_r_mat).T
    sweep['sensor2lidar_rotation'] = R.T  # points @ R.T + T
@@ -420,8 +419,8 @@ def get_2d_boxes(nusc,
    sd_rec = nusc.get('sample_data', sample_data_token)
    assert sd_rec[
        'sensor_modality'] == 'camera', 'Error: get_2d_boxes only works' \
        ' for camera sample_data!'
    if not sd_rec['is_key_frame']:
        raise ValueError(
            'The 2D re-projections are available only for keyframes.')
@@ -532,7 +531,7 @@ def get_2d_boxes(nusc,
def post_process_coords(
        corner_coords: List, imsize: Tuple[int, int] = (1600, 900)
) -> Union[Tuple[float, float, float, float], None]:
    """Get the intersection of the convex hull of the reprojected bbox corners
    and the image canvas, return None if no intersection.
...
@@ -181,7 +181,7 @@ class S3DISSegData(object):
        self.ignore_index = len(self.cat_ids)
        self.cat_id2class = np.ones((self.all_ids.shape[0],), dtype=np.int) * \
            self.ignore_index
        for i, cat_id in enumerate(self.cat_ids):
            self.cat_id2class[cat_id] = i
@@ -221,7 +221,7 @@ class S3DISSegData(object):
        """
        num_classes = len(self.cat_ids)
        num_point_all = []
        label_weight = np.zeros((num_classes + 1,))  # ignore_index
        for data_info in self.data_infos:
            label = self._convert_to_label(
                osp.join(self.data_root, data_info['pts_semantic_mask_path']))
...
@@ -231,7 +231,7 @@ class ScanNetSegData(object):
        self.ignore_index = len(self.cat_ids)
        self.cat_id2class = np.ones((self.all_ids.shape[0],), dtype=np.int) * \
            self.ignore_index
        for i, cat_id in enumerate(self.cat_ids):
            self.cat_id2class[cat_id] = i
@@ -273,7 +273,7 @@ class ScanNetSegData(object):
        """
        num_classes = len(self.cat_ids)
        num_point_all = []
        label_weight = np.zeros((num_classes + 1,))  # ignore_index
        for data_info in self.data_infos:
            label = self._convert_to_label(
                osp.join(self.data_root, data_info['pts_semantic_mask_path']))
...
@@ -47,7 +47,7 @@ class SUNRGBDInstance(object):
        # z_size (height) in our depth coordinate system,
        # l corresponds to the size along the x axis
        self.size = np.array([data[9], data[8], data[10]]) * 2
        self.orientation = np.zeros((3,))
        self.orientation[0] = data[11]
        self.orientation[1] = data[12]
        self.heading_angle = np.arctan2(self.orientation[1],
@@ -187,12 +187,12 @@ class SUNRGBDData(object):
                obj.box2d.reshape(1, 4) for obj in obj_list
                if obj.classname in self.cat2label.keys()
            ],
                axis=0)
            annotations['location'] = np.concatenate([
                obj.centroid.reshape(1, 3) for obj in obj_list
                if obj.classname in self.cat2label.keys()
            ],
                axis=0)
            annotations['dimensions'] = 2 * np.array([
                [obj.length, obj.width, obj.height] for obj in obj_list
                if obj.classname in self.cat2label.keys()
...
@@ -141,8 +141,8 @@ class Waymo2KITTI(object):
        """
        for img in frame.images:
            img_path = f'{self.image_save_dir}{str(img.name - 1)}/' + \
                f'{self.prefix}{str(file_idx).zfill(3)}' + \
                f'{str(frame_idx).zfill(3)}.jpg'
            with open(img_path, 'wb') as fp:
                fp.write(img.image)
@@ -171,7 +171,7 @@ class Waymo2KITTI(object):
                self.cart_to_homo(T_front_cam_to_ref) @ T_vehicle_to_cam
            if camera.name == 1:  # FRONT = 1, see dataset.proto for details
                self.T_velo_to_front_cam = Tr_velo_to_cam.copy()
            Tr_velo_to_cam = Tr_velo_to_cam[:3, :].reshape((12,))
            Tr_velo_to_cams.append([f'{i:e}' for i in Tr_velo_to_cam])
            # intrinsic parameters
@@ -189,11 +189,11 @@ class Waymo2KITTI(object):
        # camera 0 is unknown in the proto
        for i in range(5):
            calib_context += 'P' + str(i) + ': ' + \
                ' '.join(camera_calibs[i]) + '\n'
        calib_context += 'R0_rect' + ': ' + ' '.join(R0_rect) + '\n'
        for i in range(5):
            calib_context += 'Tr_velo_to_cam_' + str(i) + ': ' + \
                ' '.join(Tr_velo_to_cams[i]) + '\n'
        with open(
                f'{self.calib_save_dir}/{self.prefix}' +
@@ -253,7 +253,7 @@ class Waymo2KITTI(object):
            (points, intensity, elongation, mask_indices))
        pc_path = f'{self.point_cloud_save_dir}/{self.prefix}' + \
            f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.bin'
        point_cloud.astype(np.float32).tofile(pc_path)
    def save_label(self, frame, file_idx, frame_idx):
@@ -321,7 +321,7 @@ class Waymo2KITTI(object):
            # project bounding box to the virtual reference frame
            pt_ref = self.T_velo_to_front_cam @ \
                np.array([x, y, z, 1]).reshape((4, 1))
            x, y, z, _ = pt_ref.flatten().tolist()
            rotation_y = -obj.box.heading - np.pi / 2
@@ -333,13 +333,13 @@ class Waymo2KITTI(object):
                alpha = -10
            line = my_type + \
                ' {} {} {} {} {} {} {} {} {} {} {} {} {} {}\n'.format(
                    round(truncated, 2), occluded, round(alpha, 2),
                    round(bounding_box[0], 2), round(bounding_box[1], 2),
                    round(bounding_box[2], 2), round(bounding_box[3], 2),
                    round(height, 2), round(width, 2), round(length, 2),
                    round(x, 2), round(y, 2), round(z, 2),
                    round(rotation_y, 2))
            if self.save_track_id:
                line_all = line[:-1] + ' ' + name + ' ' + track_id + '\n'
...
@@ -13,12 +13,12 @@ except ImportError:
def mmdet3d2torchserve(
    config_file: str,
    checkpoint_file: str,
    output_folder: str,
    model_name: str,
    model_version: str = '1.0',
    force: bool = False,
):
    """Converts MMDetection3D model (config + checkpoint) to TorchServe `.mar`.
@@ -83,8 +83,8 @@ def parse_args():
        type=str,
        default=None,
        help='If not None, used for naming the `{model_name}.mar`'
        'file that will be created under `output_folder`.'
        'If None, `{Path(checkpoint_file).stem}` will be used.')
    parser.add_argument(
        '--model-version',
        type=str,
...
@@ -4,10 +4,9 @@ import os
import numpy as np
import torch
from ts.torch_handler.base_handler import BaseHandler
from mmdet3d.apis import inference_detector, init_model
from mmdet3d.core.points import get_points_type
from ts.torch_handler.base_handler import BaseHandler
class MMdet3dHandler(BaseHandler):
...
@@ -2,7 +2,6 @@ from argparse import ArgumentParser
import numpy as np
import requests
from mmdet3d.apis import inference_detector, init_model
...
@@ -7,7 +7,6 @@ from pathlib import Path
import mmcv
import numpy as np
from mmcv import Config, DictAction, mkdir_or_exist
from mmdet3d.core.bbox import (Box3DMode, CameraInstance3DBoxes, Coord3DMode,
                               DepthInstance3DBoxes, LiDARInstance3DBoxes)
from mmdet3d.core.visualizer import (show_multi_modality_result, show_result,
@@ -42,17 +41,17 @@ def parse_args():
        '--online',
        action='store_true',
        help='Whether to perform online visualization. Note that you often '
        'need a monitor to do so.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args
...
@@ -3,9 +3,8 @@ import argparse
import torch
from mmcv.runner import save_checkpoint
from torch import nn as nn
from mmdet3d.apis import init_model
from torch import nn as nn
def fuse_conv_bn(conv, bn):
...
@@ -3,7 +3,6 @@ import argparse
import mmcv
from mmcv import Config
from mmdet3d.datasets import build_dataset
...
@@ -5,7 +5,6 @@ import tempfile
import torch
from mmcv import Config
from mmcv.runner import load_state_dict
from mmdet3d.models import build_detector
@@ -129,13 +128,13 @@ def main():
    EXTRACT_KEYS = {
        'rpn_head.conv_pred.conv_cls.weight':
        ('rpn_head.conv_pred.conv_out.weight', [(0, 2), (-NUM_CLASSES, -1)]),
        'rpn_head.conv_pred.conv_cls.bias':
        ('rpn_head.conv_pred.conv_out.bias', [(0, 2), (-NUM_CLASSES, -1)]),
        'rpn_head.conv_pred.conv_reg.weight':
        ('rpn_head.conv_pred.conv_out.weight', [(2, -NUM_CLASSES)]),
        'rpn_head.conv_pred.conv_reg.bias':
        ('rpn_head.conv_pred.conv_out.bias', [(2, -NUM_CLASSES)])
    }
    # Delete some useless keys
...
@@ -5,7 +5,6 @@ import tempfile
import torch
from mmcv import Config
from mmcv.runner import load_state_dict
from mmdet3d.models import build_detector
@@ -105,13 +104,13 @@ def main():
    EXTRACT_KEYS = {
        'bbox_head.conv_pred.conv_cls.weight':
        ('bbox_head.conv_pred.conv_out.weight', [(0, 2), (-NUM_CLASSES, -1)]),
        'bbox_head.conv_pred.conv_cls.bias':
        ('bbox_head.conv_pred.conv_out.bias', [(0, 2), (-NUM_CLASSES, -1)]),
        'bbox_head.conv_pred.conv_reg.weight':
        ('bbox_head.conv_pred.conv_out.weight', [(2, -NUM_CLASSES)]),
        'bbox_head.conv_pred.conv_reg.bias':
        ('bbox_head.conv_pred.conv_out.bias', [(2, -NUM_CLASSES)])
    }
    # Delete some useless keys
...
@@ -4,14 +4,13 @@ import os
import warnings
import mmcv
import mmdet
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)
import mmdet
from mmdet3d.apis import single_gpu_test
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_model
@@ -43,31 +42,31 @@ def parse_args():
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase'
        'the inference speed')
    parser.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
        '(only applicable to non-distributed testing)')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without perform evaluation. It is'
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where results will be saved')
@@ -78,7 +77,7 @@ def parse_args():
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
@@ -89,24 +88,24 @@ def parse_args():
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function (deprecate), '
        'change to --eval-options instead.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
@@ -131,7 +130,7 @@ def main():
    args = parse_args()
    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results / save the results) with the argument "--out", "--eval"'
         ', "--format-only", "--show" or "--show-dir"')
@@ -248,8 +247,8 @@ def main():
        eval_kwargs = cfg.get('evaluation', {}).copy()
        # hard-code way to remove EvalHook args
        for key in [
                'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                'rule'
        ]:
            eval_kwargs.pop(key, None)
        eval_kwargs.update(dict(metric=args.eval, **kwargs))
...
# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import division
import argparse
import copy
import os
@@ -12,7 +13,6 @@ import torch
import torch.distributed as dist
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmdet import __version__ as mmdet_version
from mmdet3d import __version__ as mmdet3d_version
from mmdet3d.apis import init_random_seed, train_model
@@ -49,19 +49,19 @@ def parse_args():
        '--gpus',
        type=int,
        help='(Deprecated, please use --gpu-id) number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--diff-seed',
@@ -76,18 +76,18 @@ def parse_args():
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file (deprecate), '
        'change to --cfg-options instead.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
...
@@ -4,7 +4,6 @@ from os import path as osp
import mmcv
import numpy as np
from mmdet3d.core.bbox import limit_period
@@ -61,7 +60,6 @@ def update_outdoor_dbinfos(root_dir, out_dir, pkl_files):
def update_nuscenes_or_lyft_infos(root_dir, out_dir, pkl_files):
    print(f'{pkl_files} will be modified because '
          f'of the refactor of the LIDAR coordinate system.')
    if root_dir == out_dir:
@@ -89,7 +87,7 @@ def update_nuscenes_or_lyft_infos(root_dir, out_dir, pkl_files):
    parser = argparse.ArgumentParser(description='Arg parser for data coords '
                                     'update due to coords sys refactor.')
    parser.add_argument('dataset', metavar='kitti', help='name of the dataset')
    parser.add_argument(
        '--root-dir',
...
@@ -3,14 +3,16 @@
This folder contains the implementation of InternImage for image classification.

<!-- TOC -->

- [Install](#install)
- [Data Preparation](#data-preparation)
- [Evaluation](#evaluation)
- [Training from Scratch on ImageNet-1K](#training-from-scratch-on-imagenet-1k)
- [Manage Jobs with Slurm](#manage-jobs-with-slurm)
- [Training with Deepspeed](#training-with-deepspeed)
- [Extracting Intermediate Features](#extracting-intermediate-features)
- [Export](#export)

<!-- TOC -->

## Usage
@@ -36,6 +38,7 @@ conda activate internimage
- Install `PyTorch>=1.10.0` and `torchvision>=0.9.0` with `CUDA>=10.2`:

  For example, to install torch==1.11 with CUDA==11.3:

  ```bash
  pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
  ```
@@ -55,20 +58,24 @@ pip install opencv-python termcolor yacs pyyaml scipy
  ```
- Compiling CUDA operators

  ```bash
  cd ./ops_dcnv3
  sh ./make.sh
  # unit test (should see all checking is True)
  python test.py
  ```
- You can also install the operator using the pre-built `.whl` files:
  [DCNv3-1.0-whl](https://github.com/OpenGVLab/InternImage/releases/tag/whl_files)
### Data Preparation

We use the standard ImageNet dataset; you can download it from http://image-net.org/. We provide the following two ways to
load data:

- For the standard folder dataset, move validation images into labeled sub-folders. The file structure should look like:

  ```bash
  $ tree data
  imagenet
@@ -90,13 +97,15 @@ load data:
  │   ├── img6.jpeg
  │   └── ...
  └── ...
  ```
- To avoid slow reads from a massive number of small files, we also support zipped ImageNet, which includes
  four files:

  - `train.zip`, `val.zip`: which store the zipped folders for the train and validation splits.
  - `train.txt`, `val.txt`: which store the relative path in the corresponding zip file and the ground-truth
    label. Make sure the data folder looks like this:

  ```bash
  $ tree data
@@ -106,14 +115,14 @@ load data:
  ├── train.zip
  ├── val_map.txt
  └── val.zip

  $ head -n 5 meta_data/val.txt
  ILSVRC2012_val_00000001.JPEG 65
  ILSVRC2012_val_00000002.JPEG 970
  ILSVRC2012_val_00000003.JPEG 230
  ILSVRC2012_val_00000004.JPEG 809
  ILSVRC2012_val_00000005.JPEG 516

  $ head -n 5 meta_data/train.txt
  n01440764/n01440764_10026.JPEG 0
  n01440764/n01440764_10027.JPEG 0
@@ -121,6 +130,7 @@ load data:
  n01440764/n01440764_10040.JPEG 0
  n01440764/n01440764_10042.JPEG 0
  ```
- For the ImageNet-22K dataset, make a folder named `fall11_whole` and move all images to labeled sub-folders in this
  folder. Then download the train-val split
  file ([ILSVRC2011fall_whole_map_train.txt](https://github.com/SwinTransformer/storage/releases/download/v2.0.1/ILSVRC2011fall_whole_map_train.txt)
@@ -144,7 +154,7 @@ To evaluate a pretrained `InternImage` on ImageNet val, run:
```bash
python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --eval \
    --cfg <config-file> --resume <checkpoint> --data-path <imagenet-path>
```

For example, to evaluate the `InternImage-B` with a single GPU:
@@ -161,7 +171,7 @@ python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.p
To train an `InternImage` on ImageNet from scratch, run:

```bash
python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py \
    --cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
```
@@ -187,13 +197,13 @@ GPUS=8 sh train_in1k.sh <partition> <job-name> configs/internimage_s_1k_224.yaml
GPUS=8 sh train_in1k.sh <partition> <job-name> configs/internimage_xl_22kto1k_384.yaml --resume internimage_xl_22kto1k_384.pth --eval
```

<!--
### Test pretrained model on ImageNet-22K

For example, to evaluate the `InternImage-L-22k`:

```bash
python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py \
    --cfg configs/internimage_xl_22k_192to384.yaml --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory>] \
    --resume internimage_xl_22k_192to384.pth --eval
``` -->
@@ -220,15 +230,15 @@ pip install deepspeed==0.8.3
Then you could launch the training on a Slurm system with 8 GPUs as follows (tiny and huge as examples).
The default ZeRO stage is 1; it can be configured via the command-line argument `--zero-stage`.

```
GPUS=8 GPUS_PER_NODE=8 sh train_in1k_deepspeed.sh vc_research_4 train configs/internimage_t_1k_224.yaml --batch-size 128 --accumulation-steps 4
GPUS=8 GPUS_PER_NODE=8 sh train_in1k_deepspeed.sh vc_research_4 train configs/internimage_t_1k_224.yaml --batch-size 128 --accumulation-steps 4 --eval --resume ckpt.pth
GPUS=8 GPUS_PER_NODE=8 sh train_in1k_deepspeed.sh vc_research_4 train configs/internimage_t_1k_224.yaml --batch-size 128 --accumulation-steps 4 --eval --resume deepspeed_ckpt_dir
GPUS=8 GPUS_PER_NODE=8 sh train_in1k_deepspeed.sh vc_research_4 train configs/internimage_h_22kto1k_640.yaml --batch-size 16 --accumulation-steps 4 --pretrained ckpt/internimage_h_jointto22k_384.pth
GPUS=8 GPUS_PER_NODE=8 sh train_in1k_deepspeed.sh vc_research_4 train configs/internimage_h_22kto1k_640.yaml --batch-size 16 --accumulation-steps 4 --pretrained ckpt/internimage_h_jointto22k_384.pth --zero-stage 3
```

🤗 **Huggingface Accelerate Integration of Deepspeed**

Optionally, you could use our [Huggingface accelerate](https://github.com/huggingface/accelerate) integration to use DeepSpeed; a generic launch sketch follows.
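The sketch below only illustrates the usual Accelerate+DeepSpeed launch flow; the entry script and flags that this repository actually uses are not shown in this diff, so `<training-script>` and its arguments are placeholders.

```bash
# Generic sketch (assumption, not taken from this repo's docs):
accelerate config                              # answer the prompts once, enabling DeepSpeed
accelerate launch --use_deepspeed --zero_stage 1 --num_processes 8 \
    <training-script> --cfg configs/internimage_t_1k_224.yaml --batch-size 128
```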
@@ -250,12 +260,12 @@ Here is the reference GPU memory cost for InternImage-H with 8 GPUs.
- total batch size = 512, i.e. a batch size of 16 per GPU with gradient accumulation steps = 4.

| Resolution | DeepSpeed | CPU offloading | Memory |
| ---------- | --------- | -------------- | ------ |
| 640        | zero1     | False          | 22572  |
| 640        | zero3     | False          | 20000  |
| 640        | zero3     | True           | 19144  |
| 384        | zero1     | False          | 16000  |
| 384        | zero3     | True           | 11928  |
**Convert Checkpoints**
@@ -272,7 +282,7 @@ Then, you could use `best.pth` as usual, e.g., `model.load_state_dict(torch.load
### Extracting Intermediate Features

To extract the features of an intermediate layer, you could use `extract_feature.py`.

For example, to extract features of `b.png` from the layers `patch_embed` and `levels.0.downsample` and save them to `b.pth`:
@@ -280,16 +290,16 @@ For example, extract features of `b.png` from layers `patch_embed` and `levels.0
python extract_feature.py --cfg configs/internimage_t_1k_224.yaml --img b.png --keys patch_embed levels.0.downsample --save --resume internimage_t_1k_224.pth
```

### Export

To export `InternImage-T` from PyTorch to ONNX, run:

```shell
python export.py --model_name internimage_t_1k_224 --ckpt_dir /path/to/ckpt/dir --onnx
```

To export `InternImage-T` from PyTorch to TensorRT, run:

```shell
python export.py --model_name internimage_t_1k_224 --ckpt_dir /path/to/ckpt/dir --trt
```
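After export, it can be handy to sanity-check the ONNX file before deploying it. The snippet below is a minimal sketch, assuming the export wrote `internimage_t_1k_224.onnx` and that the graph expects a 1x3x224x224 float input; if the exported graph keeps the custom DCNv3 operator, a plain onnxruntime session may not be able to execute it, in which case the TensorRT path with the matching plugin is the one to use.

```bash
# Minimal sanity check of the exported model (file name, input shape and op
# coverage in vanilla onnxruntime are assumptions, not facts from this diff).
python - <<'EOF'
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('internimage_t_1k_224.onnx',
                            providers=['CPUExecutionProvider'])
inp = sess.get_inputs()[0]
x = np.random.rand(1, 3, 224, 224).astype(np.float32)
out = sess.run(None, {inp.name: x})[0]
print('output shape:', out.shape)
EOF
```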
@@ -5,6 +5,7 @@
# --------------------------------------------------------
import os

import yaml
from yacs.config import CfgNode as CN
@@ -82,7 +83,6 @@ _C.MODEL.INTERN_IMAGE.CENTER_FEATURE_SCALE = False
_C.MODEL.INTERN_IMAGE.REMOVE_CENTER = False
# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
...
@@ -17,4 +17,4 @@
    "steps_per_print": "inf",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto"
}
\ No newline at end of file
@@ -18,4 +18,4 @@
    "steps_per_print": "inf",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto"
}
\ No newline at end of file