Unverified Commit d1283981 authored by wang shiguang's avatar wang shiguang Committed by GitHub
Browse files

Parrots wrapper (#316)



* add parrots and pytorch dataloader wrapper

* fix commit
Co-authored-by: default avatarwangshiguang <wangshiguang@sensetime.com>
parent 1cb3e36a
...@@ -6,7 +6,8 @@ from .misc import (check_prerequisites, concat_list, is_list_of, is_seq_of, ...@@ -6,7 +6,8 @@ from .misc import (check_prerequisites, concat_list, is_list_of, is_seq_of,
requires_executable, requires_package, slice_list, requires_executable, requires_package, slice_list,
tuple_cast) tuple_cast)
from .parrots_wrapper import (CUDA_HOME, BuildExtension, CppExtension, from .parrots_wrapper import (CUDA_HOME, BuildExtension, CppExtension,
CUDAExtension, SyncBatchNorm, _AdaptiveAvgPoolNd, CUDAExtension, DataLoader, PoolDataLoader,
SyncBatchNorm, _AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm,
_ConvNd, _ConvTransposeMixin, _InstanceNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm,
_MaxPoolNd, get_build_config) _MaxPoolNd, get_build_config)
...@@ -28,5 +29,5 @@ __all__ = [ ...@@ -28,5 +29,5 @@ __all__ = [
'CUDA_HOME', 'SyncBatchNorm', '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', 'CUDA_HOME', 'SyncBatchNorm', '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd',
'_AvgPoolNd', '_BatchNorm', '_ConvNd', '_ConvTransposeMixin', '_AvgPoolNd', '_BatchNorm', '_ConvNd', '_ConvTransposeMixin',
'_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension', '_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension',
'CppExtension', 'CUDAExtension' 'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader'
] ]
...@@ -27,6 +27,15 @@ def _get_conv(): ...@@ -27,6 +27,15 @@ def _get_conv():
return _ConvNd, _ConvTransposeMixin return _ConvNd, _ConvTransposeMixin
def _get_dataloader():
if torch.__version__ == 'parrots':
from torch.utils.data import DataLoader, PoolDataLoader
else:
from torch.utils.data import DataLoader
PoolDataLoader = DataLoader
return DataLoader, PoolDataLoader
def _get_extension(): def _get_extension():
if torch.__version__ == 'parrots': if torch.__version__ == 'parrots':
from parrots.utils.build_extension import BuildExtension, Extension from parrots.utils.build_extension import BuildExtension, Extension
...@@ -63,6 +72,7 @@ def _get_norm(): ...@@ -63,6 +72,7 @@ def _get_norm():
CUDA_HOME = _get_cuda_home() CUDA_HOME = _get_cuda_home()
_ConvNd, _ConvTransposeMixin = _get_conv() _ConvNd, _ConvTransposeMixin = _get_conv()
DataLoader, PoolDataLoader = _get_dataloader()
BuildExtension, CppExtension, CUDAExtension = _get_extension() BuildExtension, CppExtension, CUDAExtension = _get_extension()
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm() _BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool() _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment