OpenDAS / MMCV · Commit 86a38aa3 (unverified)
Authored Jul 27, 2023 by qipengh; committed via GitHub on Jul 27, 2023.
[Feature] Support modulated_deform_conv and deform_conv with cambricon MLU backend (#2823)
parent 987d34b0

Showing 8 changed files with 255 additions and 34 deletions (+255 −34):
- docs/en/understand_mmcv/ops.md (+2 −2)
- docs/zh_cn/understand_mmcv/ops.md (+2 −2)
- mmcv/ops/__init__.py (+7 −0)
- mmcv/ops/deform_conv.py (+61 −0)
- mmcv/ops/modulated_deform_conv.py (+66 −0)
- tests/test_ops/test_deform_conv.py (+73 −15)
- tests/test_ops/test_masked_conv2d.py (+4 −0)
- tests/test_ops/test_modulated_deform_conv.py (+40 −15)
docs/en/understand_mmcv/ops.md

@@ -18,7 +18,7 @@ We implement common ops used in detection, segmentation, etc.
 | ConvexIoU                    |     | √    |     |     |        |
 | CornerPool                   |     | √    |     |     |        |
 | Correlation                  |     | √    |     |     |        |
-| Deformable Convolution v1/v2 | √   | √    |     |     | √      |
+| Deformable Convolution v1/v2 | √   | √    | √   |     | √      |
 | Deformable RoIPool           |     | √    | √   |     | √      |
 | DiffIoURotated               |     | √    | √   |     |        |
 | DynamicScatter               |     | √    | √   |     |        |
@@ -32,7 +32,7 @@ We implement common ops used in detection, segmentation, etc.
 | MaskedConv                   |     | √    | √   |     | √      |
 | MergeCells                   |     | √    |     |     |        |
 | MinAreaPolygon               |     | √    |     |     |        |
-| ModulatedDeformConv2d        | √   | √    |     |     | √      |
+| ModulatedDeformConv2d        | √   | √    | √   |     | √      |
 | MultiScaleDeformableAttn     |     | √    | √   |     |        |
 | NMS                          | √   | √    | √   |     | √      |
 | NMSRotated                   | √   | √    | √   |     | √      |
docs/zh_cn/understand_mmcv/ops.md

@@ -18,7 +18,7 @@ MMCV 提供了检测、分割等任务中常用的算子
 | ConvexIoU                    |     | √    |     |     |        |
 | CornerPool                   |     | √    |     |     |        |
 | Correlation                  |     | √    |     |     |        |
-| Deformable Convolution v1/v2 | √   | √    |     |     | √      |
+| Deformable Convolution v1/v2 | √   | √    | √   |     | √      |
 | Deformable RoIPool           |     | √    | √   |     | √      |
 | DiffIoURotated               |     | √    | √   |     |        |
 | DynamicScatter               |     | √    | √   |     |        |
@@ -32,7 +32,7 @@ MMCV 提供了检测、分割等任务中常用的算子
 | MaskedConv                   |     | √    | √   |     | √      |
 | MergeCells                   |     | √    |     |     |        |
 | MinAreaPolygon               |     | √    |     |     |        |
-| ModulatedDeformConv2d        | √   | √    |     |     | √      |
+| ModulatedDeformConv2d        | √   | √    | √   |     | √      |
 | MultiScaleDeformableAttn     |     | √    | √   |     |        |
 | NMS                          | √   | √    | √   |     | √      |
 | NMSRotated                   | √   | √    | √   |     | √      |
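
The two table edits above record the user-visible change: DCN v1/v2 gain a √ in the MLU column. Because the MLU classes below register under the existing 'DCN'/'DCNv2' keys with force=True, existing configs need no changes. A minimal dispatch sketch, assuming mmcv 2.x built with its CPU ops and using `build_conv_layer` from `mmcv.cnn` (my example, not part of this commit; on an MLU build the same key resolves to `DeformConv2dPack_MLU`):

```python
# Hypothetical usage sketch: 'DCN' is the registry key shared by
# DeformConv2dPack and, on MLU machines, DeformConv2dPack_MLU.
import torch
from mmcv.cnn import build_conv_layer

layer = build_conv_layer(
    dict(type='DCN', deform_groups=1),  # same cfg on CPU/CUDA/MLU
    in_channels=3, out_channels=8, kernel_size=3, padding=1)
x = torch.randn(2, 3, 16, 16)
print(type(layer).__name__, layer(x).shape)
```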
mmcv/ops/__init__.py

 # Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import IS_MLU_AVAILABLE
 from .active_rotated_filter import active_rotated_filter
 from .assign_score_withk import assign_score_withk
 from .ball_query import ball_query
@@ -109,3 +110,9 @@ __all__ = [
     'PrRoIPool', 'prroi_pool', 'bias_act', 'filtered_lrelu', 'conv2d',
     'conv_transpose2d', 'filter2d', 'upsample2d', 'BezierAlign', 'bezier_align'
 ]
+
+if IS_MLU_AVAILABLE:
+    from .deform_conv import DeformConv2dPack_MLU  # noqa:F401
+    from .modulated_deform_conv import \
+        ModulatedDeformConv2dPack_MLU  # noqa:F401
+    __all__.extend(['ModulatedDeformConv2dPack_MLU', 'DeformConv2dPack_MLU'])
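
Since the MLU wrappers are only exported when an MLU is detected, downstream code that wants a single import path can mirror the same guard. A small sketch (my example, not from the commit):

```python
# Hedged sketch: pick the packed DCN class that exists on this machine;
# IS_MLU_AVAILABLE is the same flag used in the diff above.
from mmcv.utils import IS_MLU_AVAILABLE

if IS_MLU_AVAILABLE:
    from mmcv.ops import DeformConv2dPack_MLU as DCNPack
else:
    from mmcv.ops import DeformConv2dPack as DCNPack

conv = DCNPack(in_channels=3, out_channels=8, kernel_size=3, padding=1)
```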
mmcv/ops/deform_conv.py

@@ -12,6 +12,7 @@ from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair, _single

+from mmcv.utils import IS_MLU_AVAILABLE
 from ..utils import ext_loader
 from .modulated_deform_conv import ModulatedDeformConv2dFunction

@@ -438,3 +439,63 @@ class DeformConv2dPack(DeformConv2d):
         super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                       strict, missing_keys, unexpected_keys,
                                       error_msgs)
+
+
+if IS_MLU_AVAILABLE:
+    import torchvision
+    from mmengine.utils import digit_version
+    from torchvision.ops import deform_conv2d as tv_deform_conv2d
+
+    @MODELS.register_module('DCN', force=True)
+    class DeformConv2dPack_MLU(DeformConv2d):
+        """This class is the DCN implementation of the MLU device. The MLU
+        backend support of the operator has been implemented in torchvision.
+        The mmcv registration mechanism is used for multiplexing here. The
+        torchvision implementation of DCN is called.
+
+        Args:
+            in_channels (int): Same as nn.Conv2d.
+            out_channels (int): Same as nn.Conv2d.
+            kernel_size (int or tuple[int]): Same as nn.Conv2d.
+            stride (int): Same as nn.Conv2d, while tuple is not supported.
+            padding (int): Same as nn.Conv2d, while tuple is not supported.
+            dilation (int): Same as nn.Conv2d, while tuple is not supported.
+            groups (int): Same as nn.Conv2d.
+            bias (bool or str): If specified as `auto`, it will be decided by
+                the norm_cfg. Bias will be set as True if norm_cfg is None,
+                otherwise False.
+            im2col_step (int): Number of samples processed by
+                im2col_cuda_kernel per call. It will work when ``batch_size``
+                > ``im2col_step``, but ``batch_size`` must be divisible by
+                ``im2col_step``. Default: 32. `New in version 1.7.2.
+                Currently not supported on MLU devices.`
+        """
+
+        def __init__(self, *args, **kwargs):
+            assert digit_version(torchvision.__version__) >= digit_version(
+                '0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
+            super().__init__(*args, **kwargs)
+            self.conv_offset = nn.Conv2d(
+                self.in_channels,
+                self.deform_groups * 2 * self.kernel_size[0] *
+                self.kernel_size[1],
+                kernel_size=self.kernel_size,
+                stride=_pair(self.stride),
+                padding=_pair(self.padding),
+                dilation=_pair(self.dilation),
+                bias=True)
+            self.init_offset()
+
+        def init_offset(self):
+            self.conv_offset.weight.data.zero_()
+            self.conv_offset.bias.data.zero_()
+
+        def forward(self, x: Tensor) -> Tensor:  # type: ignore
+            cur_im2col_step = min(self.im2col_step, x.size(0))
+            assert (x.size(0) % cur_im2col_step
+                    ) == 0, 'batch size must be divisible by im2col_step'
+            offset = self.conv_offset(x)
+            x = x.type_as(offset)
+            weight = self.weight.type_as(x)
+            return tv_deform_conv2d(x, offset, weight, None, self.stride,
+                                    self.padding, self.dilation)
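
The forward path above is thin: `conv_offset` predicts `2 * deform_groups * kH * kW` offset channels (a dy/dx pair per kernel tap per group), and `torchvision.ops.deform_conv2d` does the deformable sampling. A standalone sketch of the same call on CPU (illustration only; an MLU build dispatches the identical call to the Cambricon kernel):

```python
# Shape sketch of the DeformConv2dPack_MLU.forward contract, using
# torchvision directly. Zero-initialized offsets reduce to a plain conv.
import torch
import torch.nn as nn
from torchvision.ops import deform_conv2d

in_ch, out_ch, kh, kw, deform_groups = 3, 8, 3, 3, 1
weight = torch.randn(out_ch, in_ch, kh, kw)
# one (dy, dx) pair per deformable group per kernel location
conv_offset = nn.Conv2d(in_ch, deform_groups * 2 * kh * kw,
                        kernel_size=(kh, kw), padding=1, bias=True)
nn.init.zeros_(conv_offset.weight)
nn.init.zeros_(conv_offset.bias)

x = torch.randn(2, in_ch, 16, 16)
offset = conv_offset(x)
out = deform_conv2d(x, offset, weight, None, (1, 1), (1, 1), (1, 1))
print(out.shape)  # torch.Size([2, 8, 16, 16])
```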
mmcv/ops/modulated_deform_conv.py

@@ -11,6 +11,7 @@ from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair, _single

+from mmcv.utils import IS_MLU_AVAILABLE
 from ..utils import ext_loader

 ext_module = ext_loader.load_ext(
@@ -358,3 +359,68 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
         super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                       strict, missing_keys, unexpected_keys,
                                       error_msgs)
+
+
+if IS_MLU_AVAILABLE:
+    import torchvision
+    from mmengine.utils import digit_version
+    from torchvision.ops import deform_conv2d as tv_deform_conv2d
+
+    @MODELS.register_module('DCNv2', force=True)
+    class ModulatedDeformConv2dPack_MLU(ModulatedDeformConv2d):
+        """This class is the DCNv2 implementation of the MLU device.
+
+        The MLU backend support of the operator has been implemented
+        in torchvision. The mmcv registration mechanism is used for
+        multiplexing here. The torchvision implementation of DCNv2 is called.
+
+        Args:
+            in_channels (int): Same as nn.Conv2d.
+            out_channels (int): Same as nn.Conv2d.
+            kernel_size (int or tuple[int]): Same as nn.Conv2d.
+            stride (int): Same as nn.Conv2d, while tuple is not supported.
+            padding (int): Same as nn.Conv2d, while tuple is not supported.
+            dilation (int): Same as nn.Conv2d, while tuple is not supported.
+            groups (int): Same as nn.Conv2d.
+            bias (bool or str): If specified as `auto`, it will be decided by
+                the norm_cfg. Bias will be set as True if norm_cfg is None,
+                otherwise False.
+        """
+
+        def __init__(self, *args, **kwargs):
+            assert digit_version(torchvision.__version__) >= digit_version(
+                '0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
+            super().__init__(*args, **kwargs)
+            self.conv_offset = nn.Conv2d(
+                self.in_channels,
+                self.deform_groups * 3 * self.kernel_size[0] *
+                self.kernel_size[1],
+                kernel_size=self.kernel_size,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                bias=True)
+            self.init_weights()
+
+        def init_weights(self):
+            super().init_weights()
+            if hasattr(self, 'conv_offset'):
+                self.conv_offset.weight.data.zero_()
+                self.conv_offset.bias.data.zero_()
+
+        def forward(self, x):
+            out = self.conv_offset(x)
+            o1, o2, mask = torch.chunk(out, 3, dim=1)
+            offset = torch.cat((o1, o2), dim=1)
+            mask = torch.sigmoid(mask)
+            x = x.type_as(offset)
+            weight = self.weight.type_as(x)
+            mask = mask.type_as(x)
+            return tv_deform_conv2d(
+                x,
+                offset,
+                weight,
+                bias=self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                mask=mask)
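
Here a single `conv_offset` convolution packs offsets and modulation together: `3 * deform_groups * kH * kW` channels, split by `torch.chunk` into two offset thirds and one mask third, with the mask squashed into (0, 1) by a sigmoid. A shape-only sketch of that packing (my illustration, not commit code):

```python
# Shape sketch of the DCNv2 offset/mask packing used in forward() above.
import torch
from torchvision.ops import deform_conv2d

dg, kh, kw = 1, 3, 3                            # deform_groups, kernel
n, c, h, w = 2, 3, 8, 8
out = torch.randn(n, dg * 3 * kh * kw, h, w)    # stands in for conv_offset(x)

o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)             # (n, 2 * dg * kh * kw, h, w)
mask = torch.sigmoid(mask)                      # (n, dg * kh * kw, h, w)

weight = torch.randn(4, c, kh, kw)
x = torch.randn(n, c, h, w)
y = deform_conv2d(x, offset, weight, bias=None, stride=(1, 1),
                  padding=(1, 1), dilation=(1, 1), mask=mask)
print(y.shape)  # torch.Size([2, 4, 8, 8])
```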
tests/test_ops/test_deform_conv.py

@@ -5,6 +5,11 @@ import torch
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION

+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+
+if IS_MLU_AVAILABLE:
+    torch.backends.cnnl.allow_tf32 = False
+
 try:
     # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
     # would be imported and used; we should test if our modules support it.
@@ -45,6 +50,9 @@ class TestDeformconv:
                          im2col_step=2):
         if not torch.cuda.is_available() and device == 'cuda':
             pytest.skip('test requires GPU')
-        from mmcv.ops import DeformConv2dPack
+        if device == 'mlu':
+            from mmcv.ops import DeformConv2dPack_MLU as DeformConv2dPack
+        else:
+            from mmcv.ops import DeformConv2dPack
         c_in = 1
         c_out = 1
@@ -69,6 +77,8 @@ class TestDeformconv:
             torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
         if device == 'cuda':
             model.cuda()
+        elif device == 'mlu':
+            model.mlu()
         model.type(dtype)
         out = model(x)
@@ -108,6 +118,7 @@ class TestDeformconv:
     def _test_amp_deformconv(self,
                              input_dtype,
                              threshold=1e-3,
+                             device='cuda',
                              batch_size=10,
                              im2col_step=2):
         """The function to test amp released on pytorch 1.6.0.
@@ -120,15 +131,18 @@ class TestDeformconv:
             input_dtype: torch.float or torch.half.
             threshold: the same as above function.
         """
-        if not torch.cuda.is_available():
+        if not torch.cuda.is_available() and device == 'cuda':
             return
-        from mmcv.ops import DeformConv2dPack
+        if device == 'mlu':
+            from mmcv.ops import DeformConv2dPack_MLU as DeformConv2dPack
+        else:
+            from mmcv.ops import DeformConv2dPack
         c_in = 1
         c_out = 1
         repeated_input = np.repeat(input, batch_size, axis=0)
         repeated_gt_out = np.repeat(gt_out, batch_size, axis=0)
         repeated_gt_x_grad = np.repeat(gt_x_grad, batch_size, axis=0)
-        x = torch.Tensor(repeated_input).cuda().type(input_dtype)
+        x = torch.Tensor(repeated_input).to(device).type(input_dtype)
         x.requires_grad = True
         model = DeformConv2dPack(
             in_channels=c_in,
@@ -143,7 +157,10 @@ class TestDeformconv:
             torch.Tensor(offset_bias).reshape(8))
         model.weight.data = torch.nn.Parameter(
             torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
-        model.cuda()
+        if device == 'cuda':
+            model.cuda()
+        elif device == 'mlu':
+            model.mlu()
         out = model(x)
         out.backward(torch.ones_like(out))
@@ -177,24 +194,65 @@ class TestDeformconv:
         with pytest.raises(AssertionError):
             model = DeformConv2d(3, 4, 3, groups=3)

-    def test_deformconv(self):
-        self._test_deformconv(torch.double, device='cpu')
-        self._test_deformconv(torch.float, device='cpu', threshold=1e-1)
-        self._test_deformconv(torch.double)
-        self._test_deformconv(torch.float)
-        self._test_deformconv(torch.half, threshold=1e-1)
+    @pytest.mark.parametrize('device, threshold', [
+        ('cpu', 1e-1),
+        pytest.param(
+            'cuda',
+            1e-3,
+            marks=pytest.mark.skipif(
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'mlu',
+            1e-3,
+            marks=pytest.mark.skipif(
+                not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    ])
+    def test_deformconv_float(self, device, threshold):
+        self._test_deformconv(torch.float, device=device, threshold=threshold)
         # test batch_size < im2col_step
-        self._test_deformconv(torch.float, batch_size=1, im2col_step=2)
+        self._test_deformconv(
+            torch.float, batch_size=1, im2col_step=2, device=device)
         # test bach_size % im2col_step != 0
         with pytest.raises(
                 AssertionError,
                 match='batch size must be divisible by im2col_step'):
-            self._test_deformconv(torch.float, batch_size=10, im2col_step=3)
+            self._test_deformconv(
+                torch.float, batch_size=10, im2col_step=3, device=device)
+
+    @pytest.mark.parametrize('device', [
+        'cpu',
+        pytest.param(
+            'cuda',
+            marks=pytest.mark.skipif(
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'mlu',
+            marks=pytest.mark.skipif(
+                not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    ])
+    def test_deformconv_double(self, device):
+        self._test_deformconv(torch.double, device=device)
+
+    @pytest.mark.parametrize('device, threshold', [
+        pytest.param(
+            'cuda',
+            1e-1,
+            marks=pytest.mark.skipif(
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'mlu',
+            1e-1,
+            marks=pytest.mark.skipif(
+                not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    ])
+    def test_deformconv_half(self, device, threshold):
+        self._test_deformconv(torch.half, device=device, threshold=threshold)
         # test amp when torch version >= '1.6.0', the type of
         # input data for deformconv might be torch.float or torch.half
         if (TORCH_VERSION != 'parrots'
                 and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
             with autocast(enabled=True):
-                self._test_amp_deformconv(torch.float, 1e-1)
-                self._test_amp_deformconv(torch.half, 1e-1)
+                self._test_amp_deformconv(
+                    torch.float, device=device, threshold=threshold)
+                self._test_amp_deformconv(
+                    torch.half, device=device, threshold=threshold)
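
The refactor replaces one monolithic `test_deformconv` with per-dtype tests parametrized over devices, each guarded by a `skipif` mark so the suite degrades gracefully on machines without CUDA or an MLU. The same pattern, reduced to a self-contained sketch (the availability flags here are illustrative stand-ins for mmcv's, not its API):

```python
# Minimal reproduction of the parametrize-plus-skipif pattern above.
import pytest
import torch

IS_CUDA_AVAILABLE = torch.cuda.is_available()
IS_MLU_AVAILABLE = getattr(torch, 'is_mlu_available', lambda: False)()

@pytest.mark.parametrize('device, threshold', [
    ('cpu', 1e-1),
    pytest.param('cuda', 1e-3, marks=pytest.mark.skipif(
        not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param('mlu', 1e-3, marks=pytest.mark.skipif(
        not IS_MLU_AVAILABLE, reason='requires MLU support')),
])
def test_conv_matches_reference(device, threshold):
    x = torch.ones(1, 1, 4, 4, device=device)
    conv = torch.nn.Conv2d(1, 1, 2, bias=False).to(device)
    torch.nn.init.ones_(conv.weight)
    ref = torch.full((1, 1, 3, 3), 4., device=device)
    assert torch.allclose(conv(x), ref, atol=threshold)
```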
tests/test_ops/test_masked_conv2d.py

@@ -5,6 +5,10 @@ import torch
 from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

+if IS_MLU_AVAILABLE:
+    torch.backends.cnnl.allow_tf32 = False
+    torch.backends.mlu.matmul.allow_tf32 = False
+

 class TestMaskedConv2d:
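
TF32 trades mantissa bits for speed in convolution and matmul kernels, so numeric tests that compare against full-precision references pin it off; the `cnnl`/`mlu.matmul` flags above are the Cambricon counterparts of PyTorch's stock CUDA switches. For reference, the CUDA analogue (not part of this commit):

```python
# CUDA analogue of the TF32 pinning above, using stock PyTorch flags.
import torch

torch.backends.cudnn.allow_tf32 = False        # convolutions
torch.backends.cuda.matmul.allow_tf32 = False  # matrix multiplies
```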
tests/test_ops/test_modulated_deform_conv.py

@@ -7,7 +7,7 @@ import torch
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION

-from mmcv.utils import IS_CUDA_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

 try:
     # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
@@ -44,7 +44,12 @@ class TestMdconv:
     def _test_mdconv(self, dtype=torch.float, device='cuda'):
         if not torch.cuda.is_available() and device == 'cuda':
             pytest.skip('test requires GPU')
-        from mmcv.ops import ModulatedDeformConv2dPack
+        if device == 'mlu':
+            from mmcv.ops import \
+                ModulatedDeformConv2dPack_MLU as ModulatedDeformConv2dPack
+        else:
+            from mmcv.ops import ModulatedDeformConv2dPack
         input = torch.tensor(input_t, dtype=dtype, device=device)
         input.requires_grad = True
@@ -55,10 +60,7 @@ class TestMdconv:
             stride=1,
             padding=1,
             deform_groups=1,
-            bias=False)
-        if device == 'cuda':
-            dcn.cuda()
+            bias=False).to(device)
         dcn.weight.data.fill_(1.)
         dcn.type(dtype)
@@ -75,7 +77,7 @@ class TestMdconv:
         assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(),
                               dcn_offset_b_grad, 1e-2)

-    def _test_amp_mdconv(self, input_dtype=torch.float):
+    def _test_amp_mdconv(self, input_dtype=torch.float, device='cuda'):
         """The function to test amp released on pytorch 1.6.0.

         The type of input data might be torch.float or torch.half,
@@ -85,10 +87,15 @@ class TestMdconv:
         Args:
             input_dtype: torch.float or torch.half.
         """
-        if not torch.cuda.is_available():
+        if not torch.cuda.is_available() and device == 'cuda':
             return
-        from mmcv.ops import ModulatedDeformConv2dPack
-        input = torch.tensor(input_t).cuda().type(input_dtype)
+        if device == 'mlu':
+            from mmcv.ops import \
+                ModulatedDeformConv2dPack_MLU as ModulatedDeformConv2dPack
+        else:
+            from mmcv.ops import ModulatedDeformConv2dPack
+        input = torch.tensor(input_t).to(device).type(input_dtype)
         input.requires_grad = True

         dcn = ModulatedDeformConv2dPack(
@@ -98,7 +105,7 @@ class TestMdconv:
             stride=1,
             padding=1,
             deform_groups=1,
-            bias=False).cuda()
+            bias=False).to(device)
         dcn.weight.data.fill_(1.)
         output = dcn(input)
         output.sum().backward()
@@ -119,6 +126,10 @@ class TestMdconv:
             'cuda',
             marks=pytest.mark.skipif(
                 not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'mlu',
+            marks=pytest.mark.skipif(
+                not IS_MLU_AVAILABLE, reason='requires MLU support')),
     ])
     def test_mdconv_float(self, device):
         self._test_mdconv(dtype=torch.float, device=device)
@@ -129,16 +140,30 @@ class TestMdconv:
             'cuda',
             marks=pytest.mark.skipif(
                 not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'mlu',
+            marks=pytest.mark.skipif(
+                not IS_MLU_AVAILABLE, reason='requires MLU support')),
     ])
     def test_mdconv_double(self, device):
         self._test_mdconv(dtype=torch.double, device=device)

-    def test_mdconv_half(self):
-        self._test_mdconv(torch.half)
+    @pytest.mark.parametrize('device', [
+        pytest.param(
+            'cuda',
+            marks=pytest.mark.skipif(
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'mlu',
+            marks=pytest.mark.skipif(
+                not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    ])
+    def test_mdconv_half(self, device):
+        self._test_mdconv(torch.half, device=device)
         # test amp when torch version >= '1.6.0', the type of
         # input data for mdconv might be torch.float or torch.half
         if (TORCH_VERSION != 'parrots'
                 and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
             with autocast(enabled=True):
-                self._test_amp_mdconv(torch.float)
-                self._test_amp_mdconv(torch.half)
+                self._test_amp_mdconv(torch.float, device=device)
+                self._test_amp_mdconv(torch.half, device=device)
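
Because `ModulatedDeformConv2dPack_MLU` is the torchvision DCNv2 kernel behind mmcv's interface, its arithmetic can be cross-checked against mmcv's own implementation when both are fed the same offsets, mask, and weights. A sketch of such a check on CPU, assuming mmcv is built with its CPU ops (this mirrors, but is not part of, the tests above):

```python
# Hedged cross-check sketch: mmcv's DCNv2 vs. the torchvision call that
# the MLU wrapper forwards to; both should agree within float tolerance.
import torch
from torchvision.ops import deform_conv2d
from mmcv.ops import ModulatedDeformConv2dPack

dcn = ModulatedDeformConv2dPack(
    3, 4, kernel_size=3, padding=1, deform_groups=1, bias=False)
x = torch.randn(1, 3, 8, 8)

out = dcn.conv_offset(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
ref = deform_conv2d(x, offset, dcn.weight, bias=None, stride=(1, 1),
                    padding=(1, 1), dilation=(1, 1), mask=mask)
assert torch.allclose(dcn(x), ref, atol=1e-5)
```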