Unverified Commit 9185eee8 authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

Remove runner, parallel, engine and device (#2216)

* Remove runner, parallel, engine and device

* fix format

* remove outdated docs
parent 19a02415
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import numpy as np
import popart
import poptorch
import torch
import torch.nn as nn
from mmcv.utils import Registry
def _options_assigner(cfg, options_node):
# set popart.options by config
# cfg: dict, python data type
# options_node: python module or function
if isinstance(cfg, dict):
for key in cfg:
_options_assigner(cfg[key], getattr(options_node, key))
elif isinstance(cfg, (int, float, str, list)):
if callable(options_node):
options_node(cfg)
else:
error_msg = f'options_node type {type(options_node)} not supported'
raise NotImplementedError(error_msg)
else:
error_msg = f'cfg type {type(cfg)} not supported'
raise NotImplementedError(error_msg)
def cfg2options(cfg):
"""Parse dictionary to ipu options.
Args:
cfg (dict): A dictionary of ipu settings.
Returns:
dict[str, poptorch.Options]: Training options and inference options
of IPU.
"""
# set ipu options for inference and training by config
train_cfg = cfg.pop('train_cfg', {})
eval_cfg = cfg.pop('eval_cfg', {})
eval_cfg['replicationFactor'] = 1 # eval mode only use one replica
eval_cfg['executionStrategy'] = 'ShardedExecution'
# overwrite default ipu cfg with specified train cfgs
training_ipu_cfg = {**cfg, **train_cfg}
# overwrite default ipu cfg with specified eval cfgs
inference_ipu_cfg = {**cfg, **eval_cfg}
ipu_options = {
'training': _cast_to_options(training_ipu_cfg),
'inference': _cast_to_options(inference_ipu_cfg)
}
# TODO configure these codes
ipu_options['training']._Popart.set('disableGradAccumulationTensorStreams',
True)
ipu_options['training']._Popart.set(
'accumulateOuterFragmentSettings.schedule',
int(popart.AccumulateOuterFragmentSchedule.OverlapMemoryOptimized))
ipu_options['training'].Precision.enableStochasticRounding(True)
return ipu_options
def _cast_to_options(cfg):
# If it cannot be directly assigned, use if statement to parse it,
# and if it can be directly assigned, use _options_assigner to assign
options = poptorch.Options()
if 'availableMemoryProportion' in cfg:
available_memory_proportion = cfg.pop('availableMemoryProportion')
mem_props = {}
for i, mem_prop in enumerate(available_memory_proportion):
mem_props[f'IPU{i}'] = mem_prop
options.setAvailableMemoryProportion(mem_props)
if 'executionStrategy' in cfg:
execution_strategy = cfg.pop('executionStrategy')
if execution_strategy == 'SameAsIpu':
options.setExecutionStrategy(
poptorch.PipelinedExecution(
getattr(poptorch.AutoStage, execution_strategy)))
elif execution_strategy == 'ShardedExecution':
options.setExecutionStrategy(poptorch.ShardedExecution())
else:
raise NotImplementedError(
'executionStrategy should be "SameAsIpu" or "ShardedExecution"'
f', but got {execution_strategy}')
if 'partialsType' in cfg:
partials_type = cfg.pop('partialsType')
options.Precision.setPartialsType(getattr(
torch, partials_type)) # half or float
_options_assigner(cfg, options)
return options
def model_sharding(model, split_edges):
"""split models in-place into multi-IPUs.
Args:
model (nn.Module): The target model to be split.
split_edges (list of dict): Model layer names or layer numbers
of split edge. Each item of ``split_edges`` is a dictionary,
which may contain the following key-pairs:
- layer_to_call: PyTorch module to assign to the block
- user_id (optional): A user defined identifier for the block.
- ipu_id: The id of the IPU to run on.
Examples:
>>> split_edges = [
... dict(layer_to_call='model.conv1', ipu_id=0),
... dict(layer_to_call='model.conv3', ipu_id=1)]
>>> sharding_model = model_sharding(torch_model, split_edges)
Returns:
nn.Module: Split model.
"""
if len(split_edges) == 0:
return model
assert isinstance(split_edges, list)
spilt_edges_dict = {edge['layer_to_call']: edge for edge in split_edges}
for idx, (name, module) in enumerate(model.named_modules()):
if idx in spilt_edges_dict and name in spilt_edges_dict:
raise ValueError(
'The same layer is referenced twice while doing model'
f' partition: idx is {idx} and name is {name}')
edge = spilt_edges_dict.pop(name, None)
edge = spilt_edges_dict.pop(idx, edge)
if edge is not None:
poptorch.BeginBlock(module, edge.get('user_id', name),
edge['ipu_id'])
# ensure all split_edges are used
if len(spilt_edges_dict) > 0:
split_edge_names = list(spilt_edges_dict.keys())
raise RuntimeError(
f'split_edges: {split_edge_names} are not contained in the model')
return model
def recomputation_checkpoint(model: nn.Module, module_names: list):
"""Annotates the output of a module to be checkpointed instead of
recomputed.
If recomputation mode is enabled, ipu will release the activations of
the middle layers to save memory. During the backward of gradient,
the activation of the middle layer will be recalculated again.
This function is used to declare the activations of some intermediate
layers that need to be saved in order to skip the recomputation of
some layers.
Args:
model (nn.Module): The target model to apply recomputation
checkpoint.
module_names (list): Layer names of module.
"""
def recompute_outputs(module, inputs, outputs):
if isinstance(outputs, tuple):
return tuple(poptorch.recomputationCheckpoint(y) for y in outputs)
else:
return poptorch.recomputationCheckpoint(outputs)
for name, module in model.named_modules():
if name in module_names:
module.register_forward_hook(recompute_outputs)
module_names.remove(name)
# check all module_names are used
assert len(module_names) == 0,\
f'recomputed nodes: {module_names} are not contained in the model'
def compare_ndarray(featA, featB, rtol=1e-3, atol=1e-5):
"""Align data between two activations or weights."""
try:
np.testing.assert_allclose(featA, featB, rtol=rtol, atol=atol)
except AssertionError as e:
print(e)
def build_from_cfg_with_wrapper(cfg,
registry,
wrapper_func=None,
default_args=None):
"""Build a module from config dict and wrap module with "wrapper_func".
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
wrapper_func (function): Used to wrap class
Returns:
object: The constructed object.
"""
if not isinstance(cfg, dict):
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
if 'type' not in cfg:
if default_args is None or 'type' not in default_args:
raise KeyError(
'`cfg` or `default_args` must contain the key "type", '
f'but got {cfg}\n{default_args}')
if not isinstance(registry, Registry):
raise TypeError('registry must be an mmcv.Registry object, '
f'but got {type(registry)}')
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {type(default_args)}')
args = cfg.copy()
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
obj_type = args.pop('type')
if isinstance(obj_type, str):
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
if wrapper_func is None:
wrapped_obj_cls = obj_cls
else:
wrapped_obj_cls = wrapper_func(obj_cls)
try:
return wrapped_obj_cls(**args)
except Exception as e:
# Normal TypeError does not print class name.
raise type(e)(f'{wrapped_obj_cls.__name__}: {e}')
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import MLUDataParallel
from .distributed import MLUDistributedDataParallel
__all__ = ['MLUDataParallel', 'MLUDistributedDataParallel']
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union
import torch
def scatter(input: Union[List, torch.Tensor], devices: List) -> List:
"""scatter copies tensor to MLU directly."""
if isinstance(input, list):
outputs = [scatter(_input, devices) for _input in input]
return outputs
elif isinstance(input, torch.Tensor):
output = input.contiguous()
return output.to('mlu') if devices != [-1] else output
else:
raise Exception(f'Unknown type {type(input)}.')
class Scatter:
@staticmethod
def forward(target_mlus, input):
outputs = scatter(input, target_mlus)
return tuple(outputs) if isinstance(outputs, list) else (outputs, )
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel import MMDataParallel
from .scatter_gather import scatter_kwargs
class MLUDataParallel(MMDataParallel):
"""The MLUDataParallel module that supports DataContainer.
MLUDataParallel is a class inherited from MMDataParall, which supports
MLU training and inference only.
The main differences with MMDataParallel:
- It only supports single-card of MLU, and only use first card to
run training and inference.
- It uses direct host-to-device copy instead of stream-background
scatter.
.. warning::
MLUDataParallel only supports single MLU training, if you need to
train with multiple MLUs, please use MLUDistributedDataParallel
instead. If you have multiple MLUs, you can set the environment
variable ``MLU_VISIBLE_DEVICES=0`` (or any other card number(s))
to specify the running device.
Args:
module (:class:`nn.Module`): Module to be encapsulated.
dim (int): Dimension used to scatter the data. Defaults to 0.
"""
def __init__(self, *args, dim=0, **kwargs):
super().__init__(*args, dim=dim, **kwargs)
self.device_ids = [0]
self.src_device_obj = torch.device('mlu:0')
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.parallel import MMDistributedDataParallel
from .scatter_gather import scatter_kwargs
class MLUDistributedDataParallel(MMDistributedDataParallel):
"""The DDP module supports DataContainer.
MLUDDP has one difference from MMDDP which moves data to MLU with coping
instead of scattering.
"""
def to_kwargs(self, inputs, kwargs, device_id):
# Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
# to move all tensors to device_id
return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel.data_container import DataContainer
from ._functions import Scatter
def scatter(inputs, target_mlus, dim=0):
"""Scatter inputs to target mlu.
The only difference from original :func:`scatter` is to add support for
:type:`~mmcv.parallel.DataContainer`.
"""
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
if target_mlus != [-1]:
obj = obj.to('mlu')
return [obj]
else:
# for CPU inference we use self-implemented scatter
return Scatter.forward(target_mlus, obj)
if isinstance(obj, DataContainer):
if obj.cpu_only:
return obj.data
else:
return Scatter.forward(target_mlus, obj.data)
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
out = list(map(list, zip(*map(scatter_map, obj))))
return out
if isinstance(obj, dict) and len(obj) > 0:
out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return out
return [obj for targets in target_mlus]
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None
def scatter_kwargs(inputs, kwargs, target_mlus, dim=0):
"""Scatter with support for kwargs dictionary."""
inputs = scatter(inputs, target_mlus, dim) if inputs else []
kwargs = scatter(kwargs, target_mlus, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import MPSDataParallel
__all__ = ['MPSDataParallel']
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel import MMDataParallel
from ..scatter_gather import scatter_kwargs
class MPSDataParallel(MMDataParallel):
"""The MPSDataParallel module that supports DataContainer.
MPSDataParallel is a class inherited from MMDataParall, which supports
MPS training and inference only.
The main differences with MMDataParallel:
- It only supports single-card of MPS, and only use first card to
run training and inference.
- It uses direct host-to-device copy instead of stream-background
scatter.
Args:
module (:class:`nn.Module`): Module to be encapsulated.
dim (int): Dimension used to scatter the data. Defaults to 0.
"""
def __init__(self, *args, dim=0, **kwargs):
super().__init__(*args, dim=dim, **kwargs)
self.device_ids = [0]
self.src_device_obj = torch.device('mps:0')
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel.data_container import DataContainer
from mmcv.utils import deprecated_api_warning
from ._functions import Scatter
from .utils import get_device
@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter(inputs, target_devices, dim=0):
"""Scatter inputs to target devices.
The only difference from original :func:`scatter` is to add support for
:type:`~mmcv.parallel.DataContainer`.
"""
current_device = get_device()
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
if target_devices != [-1]:
obj = obj.to(current_device)
return [obj]
else:
# for CPU inference we use self-implemented scatter
return Scatter.forward(target_devices, obj)
if isinstance(obj, DataContainer):
if obj.cpu_only:
return obj.data
else:
return Scatter.forward(target_devices, obj.data)
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
out = list(map(list, zip(*map(scatter_map, obj))))
return out
if isinstance(obj, dict) and len(obj) > 0:
out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return out
return [obj for _ in target_devices]
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None
@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter_kwargs(inputs, kwargs, target_devices, dim=0):
"""Scatter with support for kwargs dictionary."""
inputs = scatter(inputs, target_devices, dim) if inputs else []
kwargs = scatter(kwargs, target_devices, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
def get_device() -> str:
"""Returns the currently existing device type.
Returns:
str: cuda | mlu | mps | cpu.
"""
if IS_CUDA_AVAILABLE:
return 'cuda'
elif IS_MLU_AVAILABLE:
return 'mlu'
elif IS_MPS_AVAILABLE:
return 'mps'
else:
return 'cpu'
# Copyright (c) OpenMMLab. All rights reserved.
from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test,
single_gpu_test)
__all__ = [
'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
'single_gpu_test'
]
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import pickle
import shutil
import tempfile
import time
from typing import Optional
import mmengine
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import DataLoader
import mmcv
from mmcv.runner import get_dist_info
def single_gpu_test(model: nn.Module, data_loader: DataLoader) -> list:
"""Test model with a single gpu.
This method tests model with a single gpu and displays test progress bar.
Args:
model (nn.Module): Model to be tested.
data_loader (nn.Dataloader): Pytorch data loader.
Returns:
list: The prediction results.
"""
model.eval()
results = []
dataset = data_loader.dataset
prog_bar = mmcv.ProgressBar(len(dataset))
for data in data_loader:
with torch.no_grad():
result = model(return_loss=False, **data)
results.extend(result)
# Assume result has the same length of batch_size
# refer to https://github.com/open-mmlab/mmcv/issues/985
batch_size = len(result)
for _ in range(batch_size):
prog_bar.update()
return results
def multi_gpu_test(model: nn.Module,
data_loader: DataLoader,
tmpdir: Optional[str] = None,
gpu_collect: bool = False) -> Optional[list]:
"""Test model with multiple gpus.
This method tests model with multiple gpus and collects the results
under two different modes: gpu and cpu modes. By setting
``gpu_collect=True``, it encodes results to gpu tensors and use gpu
communication for results collection. On cpu mode it saves the results on
different gpus to ``tmpdir`` and collects them by the rank 0 worker.
Args:
model (nn.Module): Model to be tested.
data_loader (nn.Dataloader): Pytorch data loader.
tmpdir (str): Path of directory to save the temporary results from
different gpus under cpu mode.
gpu_collect (bool): Option to use either gpu or cpu to collect results.
Returns:
list: The prediction results.
"""
model.eval()
results = []
dataset = data_loader.dataset
rank, world_size = get_dist_info()
if rank == 0:
prog_bar = mmcv.ProgressBar(len(dataset))
time.sleep(2) # This line can prevent deadlock problem in some cases.
for i, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=False, **data)
results.extend(result)
if rank == 0:
batch_size = len(result)
batch_size_all = batch_size * world_size
if batch_size_all + prog_bar.completed > len(dataset):
batch_size_all = len(dataset) - prog_bar.completed
for _ in range(batch_size_all):
prog_bar.update()
# collect results from all ranks
if gpu_collect:
result_from_ranks = collect_results_gpu(results, len(dataset))
else:
result_from_ranks = collect_results_cpu(results, len(dataset), tmpdir)
return result_from_ranks
def collect_results_cpu(result_part: list,
size: int,
tmpdir: Optional[str] = None) -> Optional[list]:
"""Collect results under cpu mode.
On cpu mode, this function will save the results on different gpus to
``tmpdir`` and collect them by the rank 0 worker.
Args:
result_part (list): Result list containing result parts
to be collected.
size (int): Size of the results, commonly equal to length of
the results.
tmpdir (str | None): temporal directory for collected results to
store. If set to None, it will create a random temporal directory
for it.
Returns:
list: The collected results.
"""
rank, world_size = get_dist_info()
# create a tmp dir if it is not specified
if tmpdir is None:
MAX_LEN = 512
# 32 is whitespace
dir_tensor = torch.full((MAX_LEN, ),
32,
dtype=torch.uint8,
device='cuda')
if rank == 0:
mmcv.mkdir_or_exist('.dist_test')
tmpdir = tempfile.mkdtemp(dir='.dist_test')
tmpdir = torch.tensor(
bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
dir_tensor[:len(tmpdir)] = tmpdir
dist.broadcast(dir_tensor, 0)
tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
else:
mmcv.mkdir_or_exist(tmpdir)
# dump the part result to the dir
part_file = osp.join(tmpdir, f'part_{rank}.pkl') # type: ignore
mmengine.dump(result_part, part_file)
dist.barrier()
# collect all parts
if rank != 0:
return None
else:
# load results of all parts from tmp dir
part_list = []
for i in range(world_size):
part_file = osp.join(tmpdir, f'part_{i}.pkl') # type: ignore
part_result = mmengine.load(part_file)
# When data is severely insufficient, an empty part_result
# on a certain gpu could makes the overall outputs empty.
if part_result:
part_list.append(part_result)
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
# remove tmp dir
shutil.rmtree(tmpdir) # type: ignore
return ordered_results
def collect_results_gpu(result_part: list, size: int) -> Optional[list]:
"""Collect results under gpu mode.
On gpu mode, this function will encode results to gpu tensors and use gpu
communication for results collection.
Args:
result_part (list): Result list containing result parts
to be collected.
size (int): Size of the results, commonly equal to length of
the results.
Returns:
list: The collected results.
"""
rank, world_size = get_dist_info()
# dump result part to tensor with pickle
part_tensor = torch.tensor(
bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
# gather all result part tensor shape
shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
shape_list = [shape_tensor.clone() for _ in range(world_size)]
dist.all_gather(shape_list, shape_tensor)
# padding result part tensor to max length
shape_max = torch.tensor(shape_list).max()
part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
part_send[:shape_tensor[0]] = part_tensor
part_recv_list = [
part_tensor.new_zeros(shape_max) for _ in range(world_size)
]
# gather all result part
dist.all_gather(part_recv_list, part_send)
if rank == 0:
part_list = []
for recv, shape in zip(part_recv_list, shape_list):
part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
# When data is severely insufficient, an empty part_result
# on a certain gpu could makes the overall outputs empty.
if part_result:
part_list.append(part_result)
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
return ordered_results
else:
return None
{
"resnet50_caffe": "detectron/resnet50_caffe",
"resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr",
"resnet101_caffe": "detectron/resnet101_caffe",
"resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr"
}
{
"vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth",
"vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth",
"vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth",
"vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth",
"vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth",
"vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth",
"vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth",
"vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth",
"resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth",
"resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth",
"resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth",
"resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth",
"resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth",
"resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth",
"resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth",
"resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth",
"resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth",
"resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth",
"resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth",
"resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth",
"se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth",
"se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth",
"resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth",
"resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth",
"resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth",
"resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth",
"shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth",
"shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth",
"mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth",
"mobilenet_v3_small": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_small-8427ecf0.pth",
"mobilenet_v3_large": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth",
"repvgg_A0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_3rdparty_4xb64-coslr-120e_in1k_20210909-883ab98c.pth",
"repvgg_A1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_3rdparty_4xb64-coslr-120e_in1k_20210909-24003a24.pth",
"repvgg_A2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_3rdparty_4xb64-coslr-120e_in1k_20210909-97d7695a.pth",
"repvgg_B0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_3rdparty_4xb64-coslr-120e_in1k_20210909-446375f4.pth",
"repvgg_B1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_3rdparty_4xb64-coslr-120e_in1k_20210909-750cdf67.pth",
"repvgg_B1g2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_3rdparty_4xb64-coslr-120e_in1k_20210909-344f6422.pth",
"repvgg_B1g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_3rdparty_4xb64-coslr-120e_in1k_20210909-d4c1a642.pth",
"repvgg_B2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_3rdparty_4xb64-coslr-120e_in1k_20210909-bd6b937c.pth",
"repvgg_B2g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-7b7955f0.pth",
"repvgg_B3": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-dda968bf.pth",
"repvgg_B3g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-4e54846a.pth",
"repvgg_D2se": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-D2se_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-cf3139b7.pth",
"res2net101_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth",
"res2net50_w14": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth",
"res2net50_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth",
"swin_tiny": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth",
"swin_small": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth",
"swin_base": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224_22kto1k-f967f799.pth",
"swin_large": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window7_224_22kto1k-5f0996db.pth",
"t2t_vit_t_14": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth",
"t2t_vit_t_19": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth",
"t2t_vit_t_24": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth",
"tnt_small": "https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth",
"vit_base_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-98e8652b.pth",
"vit_base_p32": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p32_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-9cea8599.pth",
"vit_large_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-large-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-b20ba619.pth"
}
{
"vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth",
"detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth",
"detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth",
"detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth",
"detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth",
"detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth",
"resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth",
"resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth",
"resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth",
"contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth",
"detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth",
"detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth",
"jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth",
"jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth",
"jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth",
"jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth",
"jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth",
"jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth",
"msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth",
"msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth",
"msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth",
"msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth",
"msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth",
"bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth",
"kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth",
"kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth",
"res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth",
"regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth",
"regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth",
"regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth",
"regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth",
"regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth",
"regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth",
"regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth",
"regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth",
"resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth",
"resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth",
"resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth",
"mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth",
"mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth",
"mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth",
"contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth",
"contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth",
"resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth",
"resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth",
"resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth",
"darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth",
"mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth"
}
{
"alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
"densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",
"efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
"efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
"efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
"efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
"efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
"efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
"efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
"googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",
"inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
"mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
"mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
"regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
"regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
"regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
"regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
"regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
"regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
"regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
"regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
"regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
"regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
"regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
"regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
"regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
"regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
"resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
"resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
"resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
"resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
"resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
"resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
"resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
"wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
"wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
"shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
"shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
"shufflenetv2_x1.5": null,
"shufflenetv2_x2.0": null,
"squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
"squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
"vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
"vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
}
......@@ -4,7 +4,6 @@ import torch
from torch import Tensor
from torch import nn as nn
from mmcv.runner import force_fp32
from .furthest_point_sample import (furthest_point_sample,
furthest_point_sample_with_dist)
......@@ -91,7 +90,6 @@ class PointsSampler(nn.Module):
self.samplers.append(get_sampler_cls(fps_mod)())
self.fp16_enabled = False
@force_fp32()
def forward(self, points_xyz: Tensor, features: Tensor) -> Tensor:
"""
Args:
......@@ -102,6 +100,11 @@ class PointsSampler(nn.Module):
Returns:
torch.Tensor: (B, npoint, sample_num) Indices of sampled points.
"""
if points_xyz.dtype == torch.half:
points_xyz = points_xyz.to(torch.float32)
if features.dtype == torch.half:
features = features.to(torch.float32)
indices = []
last_fps_end_index = 0
for fps_sample_range, sampler, npoint in zip(
......
# Copyright (c) OpenMMLab. All rights reserved.
from .collate import collate
from .data_container import DataContainer
from .data_parallel import MMDataParallel
from .distributed import MMDistributedDataParallel
from .registry import MODULE_WRAPPERS
from .scatter_gather import scatter, scatter_kwargs
from .utils import is_module_wrapper
__all__ = [
'collate', 'DataContainer', 'MMDataParallel', 'MMDistributedDataParallel',
'scatter', 'scatter_kwargs', 'is_module_wrapper', 'MODULE_WRAPPERS'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch
from torch import Tensor
from torch.nn.parallel._functions import _get_stream
def scatter(input: Union[List, Tensor],
devices: List,
streams: Optional[List] = None) -> Union[List, Tensor]:
"""Scatters tensor across multiple GPUs."""
if streams is None:
streams = [None] * len(devices)
if isinstance(input, list):
chunk_size = (len(input) - 1) // len(devices) + 1
outputs = [
scatter(input[i], [devices[i // chunk_size]],
[streams[i // chunk_size]]) for i in range(len(input))
]
return outputs
elif isinstance(input, Tensor):
output = input.contiguous()
# TODO: copy to a pinned buffer first (if copying from CPU)
stream = streams[0] if output.numel() > 0 else None
if devices != [-1]:
with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
output = output.cuda(devices[0], non_blocking=True)
return output
else:
raise Exception(f'Unknown type {type(input)}.')
def synchronize_stream(output: Union[List, Tensor], devices: List,
streams: List) -> None:
if isinstance(output, list):
chunk_size = len(output) // len(devices)
for i in range(len(devices)):
for j in range(chunk_size):
synchronize_stream(output[i * chunk_size + j], [devices[i]],
[streams[i]])
elif isinstance(output, Tensor):
if output.numel() != 0:
with torch.cuda.device(devices[0]):
main_stream = torch.cuda.current_stream()
main_stream.wait_stream(streams[0])
output.record_stream(main_stream)
else:
raise Exception(f'Unknown type {type(output)}.')
def get_input_device(input: Union[List, Tensor]) -> int:
if isinstance(input, list):
for item in input:
input_device = get_input_device(item)
if input_device != -1:
return input_device
return -1
elif isinstance(input, Tensor):
return input.get_device() if input.is_cuda else -1
else:
raise Exception(f'Unknown type {type(input)}.')
class Scatter:
@staticmethod
def forward(target_gpus: List[int], input: Union[List, Tensor]) -> tuple:
input_device = get_input_device(input)
streams = None
if input_device == -1 and target_gpus != [-1]:
# Perform CPU to GPU copies in a background stream
streams = [_get_stream(device) for device in target_gpus]
outputs = scatter(input, target_gpus, streams)
# Synchronize with the copy stream
if streams is not None:
synchronize_stream(outputs, target_gpus, streams)
return tuple(outputs) if isinstance(outputs, list) else (outputs, )
# Copyright (c) OpenMMLab. All rights reserved.
from collections.abc import Mapping, Sequence
import torch
import torch.nn.functional as F
from torch.utils.data.dataloader import default_collate
from .data_container import DataContainer
def collate(batch: Sequence, samples_per_gpu: int = 1):
"""Puts each data field into a tensor/DataContainer with outer dimension
batch size.
Extend default_collate to add support for
:type:`~mmcv.parallel.DataContainer`. There are 3 cases.
1. cpu_only = True, e.g., meta data
2. cpu_only = False, stack = True, e.g., images tensors
3. cpu_only = False, stack = False, e.g., gt bboxes
"""
if not isinstance(batch, Sequence):
raise TypeError(f'{batch.dtype} is not supported.')
if isinstance(batch[0], DataContainer):
stacked = []
if batch[0].cpu_only:
for i in range(0, len(batch), samples_per_gpu):
stacked.append(
[sample.data for sample in batch[i:i + samples_per_gpu]])
return DataContainer(
stacked, batch[0].stack, batch[0].padding_value, cpu_only=True)
elif batch[0].stack:
for i in range(0, len(batch), samples_per_gpu):
assert isinstance(batch[i].data, torch.Tensor)
if batch[i].pad_dims is not None:
ndim = batch[i].dim()
assert ndim > batch[i].pad_dims
max_shape = [0 for _ in range(batch[i].pad_dims)]
for dim in range(1, batch[i].pad_dims + 1):
max_shape[dim - 1] = batch[i].size(-dim)
for sample in batch[i:i + samples_per_gpu]:
for dim in range(0, ndim - batch[i].pad_dims):
assert batch[i].size(dim) == sample.size(dim)
for dim in range(1, batch[i].pad_dims + 1):
max_shape[dim - 1] = max(max_shape[dim - 1],
sample.size(-dim))
padded_samples = []
for sample in batch[i:i + samples_per_gpu]:
pad = [0 for _ in range(batch[i].pad_dims * 2)]
for dim in range(1, batch[i].pad_dims + 1):
pad[2 * dim -
1] = max_shape[dim - 1] - sample.size(-dim)
padded_samples.append(
F.pad(
sample.data, pad, value=sample.padding_value))
stacked.append(default_collate(padded_samples))
elif batch[i].pad_dims is None:
stacked.append(
default_collate([
sample.data
for sample in batch[i:i + samples_per_gpu]
]))
else:
raise ValueError(
'pad_dims should be either None or integers (1-3)')
else:
for i in range(0, len(batch), samples_per_gpu):
stacked.append(
[sample.data for sample in batch[i:i + samples_per_gpu]])
return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
elif isinstance(batch[0], Sequence):
transposed = zip(*batch)
return [collate(samples, samples_per_gpu) for samples in transposed]
elif isinstance(batch[0], Mapping):
return {
key: collate([d[key] for d in batch], samples_per_gpu)
for key in batch[0]
}
else:
return default_collate(batch)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment