Unverified Commit d1283981 authored by wang shiguang's avatar wang shiguang Committed by GitHub
Browse files

Parrots wrapper (#316)



* add parrots and pytorch dataloader wrapper

* fix commit
Co-authored-by: default avatarwangshiguang <wangshiguang@sensetime.com>
parent 1cb3e36a
......@@ -6,7 +6,8 @@ from .misc import (check_prerequisites, concat_list, is_list_of, is_seq_of,
requires_executable, requires_package, slice_list,
tuple_cast)
from .parrots_wrapper import (CUDA_HOME, BuildExtension, CppExtension,
CUDAExtension, SyncBatchNorm, _AdaptiveAvgPoolNd,
CUDAExtension, DataLoader, PoolDataLoader,
SyncBatchNorm, _AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm,
_ConvNd, _ConvTransposeMixin, _InstanceNorm,
_MaxPoolNd, get_build_config)
......@@ -28,5 +29,5 @@ __all__ = [
'CUDA_HOME', 'SyncBatchNorm', '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd',
'_AvgPoolNd', '_BatchNorm', '_ConvNd', '_ConvTransposeMixin',
'_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension',
'CppExtension', 'CUDAExtension'
'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader'
]
......@@ -27,6 +27,15 @@ def _get_conv():
return _ConvNd, _ConvTransposeMixin
def _get_dataloader():
if torch.__version__ == 'parrots':
from torch.utils.data import DataLoader, PoolDataLoader
else:
from torch.utils.data import DataLoader
PoolDataLoader = DataLoader
return DataLoader, PoolDataLoader
def _get_extension():
if torch.__version__ == 'parrots':
from parrots.utils.build_extension import BuildExtension, Extension
......@@ -63,6 +72,7 @@ def _get_norm():
CUDA_HOME = _get_cuda_home()
_ConvNd, _ConvTransposeMixin = _get_conv()
DataLoader, PoolDataLoader = _get_dataloader()
BuildExtension, CppExtension, CUDAExtension = _get_extension()
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment