Unverified Commit a805cd9d authored by wang shiguang's avatar wang shiguang Committed by GitHub
Browse files

wrapper extension (#279)



* wrapper extension

* fix
Co-authored-by: default avatarwangshiguang <wangshiguang@sensetime.com>
parent 693f61af
...@@ -5,7 +5,8 @@ from .misc import (check_prerequisites, concat_list, is_list_of, is_seq_of, ...@@ -5,7 +5,8 @@ from .misc import (check_prerequisites, concat_list, is_list_of, is_seq_of,
is_str, is_tuple_of, iter_cast, list_cast, is_str, is_tuple_of, iter_cast, list_cast,
requires_executable, requires_package, slice_list, requires_executable, requires_package, slice_list,
tuple_cast) tuple_cast)
from .parrots_wrapper import (CUDA_HOME, SyncBatchNorm, _AdaptiveAvgPoolNd, from .parrots_wrapper import (CUDA_HOME, BuildExtension, CppExtension,
CUDAExtension, SyncBatchNorm, _AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm,
_ConvNd, _ConvTransposeMixin, _InstanceNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm,
_MaxPoolNd, get_build_config) _MaxPoolNd, get_build_config)
...@@ -26,5 +27,6 @@ __all__ = [ ...@@ -26,5 +27,6 @@ __all__ = [
'Registry', 'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'Registry', 'build_from_cfg', 'Timer', 'TimerError', 'check_time',
'CUDA_HOME', 'SyncBatchNorm', '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', 'CUDA_HOME', 'SyncBatchNorm', '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd',
'_AvgPoolNd', '_BatchNorm', '_ConvNd', '_ConvTransposeMixin', '_AvgPoolNd', '_BatchNorm', '_ConvNd', '_ConvTransposeMixin',
'_InstanceNorm', '_MaxPoolNd', 'get_build_config' '_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension',
'CppExtension', 'CUDAExtension'
] ]
from functools import partial
import torch import torch
...@@ -25,6 +27,17 @@ def _get_conv(): ...@@ -25,6 +27,17 @@ def _get_conv():
return _ConvNd, _ConvTransposeMixin return _ConvNd, _ConvTransposeMixin
def _get_extension():
if torch.__version__ == 'parrots':
from parrots.utils.build_extension import BuildExtension, Extension
CppExtension = partial(Extension, cuda=False)
CUDAExtension = partial(Extension, cuda=True)
else:
from torch.utils.cpp_extension import (BuildExtension, CppExtension,
CUDAExtension)
return BuildExtension, CppExtension, CUDAExtension
def _get_pool(): def _get_pool():
if torch.__version__ == 'parrots': if torch.__version__ == 'parrots':
from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd, from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
...@@ -50,6 +63,7 @@ def _get_norm(): ...@@ -50,6 +63,7 @@ def _get_norm():
CUDA_HOME = _get_cuda_home() CUDA_HOME = _get_cuda_home()
_ConvNd, _ConvTransposeMixin = _get_conv() _ConvNd, _ConvTransposeMixin = _get_conv()
BuildExtension, CppExtension, CUDAExtension = _get_extension()
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm() _BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool() _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment