OpenDAS / MMCV · Commit a0939977

[Feature] Support MultiScaleDeformableAttn with cambricon MLU backend

Authored Nov 16, 2022 by ZShaopeng; committed Nov 23, 2022 by Zaida Zhou.
Parent: 193de43b

Showing 7 changed files with 1393 additions and 43 deletions:

  docs/en/understand_mmcv/ops.md                           +1    -1
  docs/zh_cn/understand_mmcv/ops.md                        +1    -1
  mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp           +33   -0
  mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu   +853  -0
  mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp         +420  -0
  mmcv/ops/multi_scale_deform_attn.py                      +5    -3
  tests/test_ops/test_ms_deformable_attn.py                +80   -38
docs/en/understand_mmcv/ops.md

The English ops table gains a check mark in the MLU column for MultiScaleDeformableAttn:

@@ -33,7 +33,7 @@ We implement common ops used in detection, segmentation, etc.
 | MergeCells | | √ | | |
 | MinAreaPolygon | | √ | | |
 | ModulatedDeformConv2d | √ | √ | | |
-| MultiScaleDeformableAttn | | √ | | |
+| MultiScaleDeformableAttn | | √ | √ | |
 | NMS | √ | √ | √ | |
 | NMSRotated | √ | √ | | |
 | NMSQuadri | √ | √ | | |
docs/zh_cn/understand_mmcv/ops.md

The Chinese-language ops table receives the same update; its hunk header, "MMCV 提供了检测、分割等任务中常用的算子", translates to "MMCV provides common operators used in detection, segmentation, and other tasks":

@@ -33,7 +33,7 @@ MMCV 提供了检测、分割等任务中常用的算子
 | MergeCells | | √ | | |
 | MinAreaPolygon | | √ | | |
 | ModulatedDeformConv2d | √ | √ | | |
-| MultiScaleDeformableAttn | | √ | | |
+| MultiScaleDeformableAttn | | √ | √ | |
 | NMS | √ | √ | √ | |
 | NMSRotated | √ | √ | | |
 | NMSQuadri | √ | √ | | |
mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp

@@ -362,4 +362,37 @@ __mlu_func__ inline void convertFloat2half(half *dst, float *src,
 #endif
 }
+
+/*!
+ * @brief recursiveSumPool.
+ * @param[in,out] dst
+ *   Pointer to NRAM that stores the input and output data.
+ * @param[in] low_dim
+ *   Which is the number of low dim.
+ * @param[in] high_dim
+ *   Which is the number of high dim.
+ * @param[in] kernel_limit
+ *   Which is the high_dim of sumpool per time.
+ ******************************************************************************/
+template <typename T>
+__mlu_func__ void recursiveSumPool(T *dst, int low_dim, int high_dim,
+                                   int kernel_limit) {
+  for (; high_dim > 1;) {
+    int repeat_s = high_dim / kernel_limit;
+    int remain_s = high_dim % kernel_limit;
+    if (remain_s) {
+      __bang_sumpool((T *)dst, (T *)dst, low_dim, 1, remain_s, 1, remain_s, 1,
+                     1);
+    }
+    if (repeat_s) {
+      __bang_sumpool((T *)dst + (remain_s > 0 ? low_dim : 0),
+                     (T *)dst + remain_s * low_dim, low_dim,
+                     kernel_limit * repeat_s, 1, kernel_limit, 1, 1,
+                     kernel_limit);
+    }
+    high_dim = repeat_s + (bool)remain_s;
+  }
+  return;
+}
+
 #endif  // COMMON_MLU_HELPER_HPP_
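The helper above collapses high_dim contiguous blocks of low_dim elements into a single block, but the hardware sum-pool primitive (__bang_sumpool) reduces at most kernel_limit blocks per call, so it loops: each pass folds the leftover blocks into one and every full group of kernel_limit blocks into one, until a single block remains. Below is a NumPy sketch of the same reduction pattern, my own illustration for readability rather than part of the commit:

import numpy as np

def recursive_sum_pool(buf: np.ndarray, kernel_limit: int) -> np.ndarray:
    """Sum `high_dim` rows of length `low_dim` down to one row, reducing
    at most `kernel_limit` rows per pass, mirroring the loop structure of
    the MLU helper. Assumes kernel_limit > 1, as the helper itself does."""
    high_dim, low_dim = buf.shape
    while high_dim > 1:
        repeat_s = high_dim // kernel_limit  # full groups of kernel_limit rows
        remain_s = high_dim % kernel_limit   # leftover rows at the front
        parts = []
        if remain_s:
            # leftover rows collapse into a single row first
            parts.append(buf[:remain_s].sum(axis=0))
        for r in range(repeat_s):
            # each full group of kernel_limit rows collapses into one row
            start = remain_s + r * kernel_limit
            parts.append(buf[start:start + kernel_limit].sum(axis=0))
        buf = np.stack(parts)
        high_dim = repeat_s + bool(remain_s)
    return buf[0]

# e.g. recursive_sum_pool(np.ones((10, 4)), kernel_limit=4) -> [10., 10., 10., 10.]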
mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu (new file, mode 100644, +853 lines; diff collapsed in the original view, not reproduced here)
mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp (new file, mode 100644, +420 lines; diff collapsed in the original view, not reproduced here)
mmcv/ops/multi_scale_deform_attn.py

@@ -12,6 +12,7 @@ from mmengine.registry import MODELS
 from mmengine.utils import deprecated_api_warning
 from torch.autograd.function import Function, once_differentiable

+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
 from ..utils import ext_loader

 ext_module = ext_loader.load_ext(
@@ -26,7 +27,7 @@ class MultiScaleDeformableAttnFunction(Function):
                 sampling_locations: torch.Tensor,
                 attention_weights: torch.Tensor,
                 im2col_step: torch.Tensor) -> torch.Tensor:
-        """GPU version of multi-scale deformable attention.
+        """GPU/MLU version of multi-scale deformable attention.

         Args:
             value (torch.Tensor): The value has shape
@@ -63,7 +64,7 @@ class MultiScaleDeformableAttnFunction(Function):
     @staticmethod
     @once_differentiable
     def backward(ctx, grad_output: torch.Tensor) -> tuple:
-        """GPU version of backward function.
+        """GPU/MLU version of backward function.

         Args:
             grad_output (torch.Tensor): Gradient of output tensor of forward.
@@ -346,7 +347,8 @@ class MultiScaleDeformableAttention(BaseModule):
             raise ValueError(
                 f'Last dim of reference_points must be'
                 f' 2 or 4, but get {reference_points.shape[-1]} instead.')
-        if torch.cuda.is_available() and value.is_cuda:
+        if ((IS_CUDA_AVAILABLE and value.is_cuda)
+                or (IS_MLU_AVAILABLE and value.is_mlu)):
             output = MultiScaleDeformableAttnFunction.apply(
                 value, spatial_shapes, level_start_index, sampling_locations,
                 attention_weights, self.im2col_step)
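The dispatch change in the last hunk is the user-visible core of the commit: a value tensor living on an MLU device now reaches the compiled extension instead of falling back to the pure-PyTorch implementation. A minimal usage sketch follows — my own illustration, not part of the commit — assuming a Cambricon-enabled PyTorch build where IS_MLU_AVAILABLE is True and tensors accept .to('mlu') and expose .is_mlu, as the diff and the tests below imply:

import torch

from mmcv.ops import MultiScaleDeformableAttention
from mmcv.utils import IS_MLU_AVAILABLE

# Pick the MLU when the build supports it; otherwise fall back to CPU,
# where forward() routes to multi_scale_deformable_attn_pytorch instead
# of the fused extension op. Constructor arguments follow the test file.
device = torch.device('mlu' if IS_MLU_AVAILABLE else 'cpu')
msda = MultiScaleDeformableAttention(embed_dims=3, num_levels=2,
                                     num_heads=3).to(device)
msda.init_weights()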
tests/test_ops/test_ms_deformable_attn.py

@@ -5,6 +5,7 @@ import torch
 from mmcv.ops.multi_scale_deform_attn import (
     MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction,
     multi_scale_deformable_attn_pytorch)
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

 _USING_PARROTS = True
 try:
@@ -14,22 +15,25 @@ except ImportError:
     _USING_PARROTS = False


-@pytest.mark.parametrize('device_type', [
+@pytest.mark.parametrize('device', [
     'cpu',
     pytest.param(
         'cuda:0',
         marks=pytest.mark.skipif(
-            not torch.cuda.is_available(), reason='requires CUDA support'))
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'mlu',
+        marks=pytest.mark.skipif(
+            not IS_MLU_AVAILABLE, reason='requires MLU support'))
 ])
-def test_multiscale_deformable_attention(device_type):
+def test_multiscale_deformable_attention(device):
     with pytest.raises(ValueError):
         # embed_dims must be divisible by num_heads,
         MultiScaleDeformableAttention(
             embed_dims=256,
             num_heads=7,
         )
-    device = torch.device(device_type)
+    device = torch.device(device)
     msda = MultiScaleDeformableAttention(
         embed_dims=3, num_levels=2, num_heads=3)
     msda.init_weights()
@@ -70,20 +74,19 @@ def test_forward_multi_scale_deformable_attn_pytorch():
         attention_weights.double()).detach()


 @pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
+    not IS_CUDA_AVAILABLE, reason='requires CUDA support')
 def test_forward_equal_with_pytorch_double():
     N, M, D = 1, 2, 2
     Lq, L, P = 2, 2, 2
-    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
+    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
     level_start_index = torch.cat((shapes.new_zeros(
         (1, )), shapes.prod(1).cumsum(0)[:-1]))
     S = sum((H * W).item() for H, W in shapes)

     torch.manual_seed(3)
-    value = torch.rand(N, S, M, D).cuda() * 0.01
-    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
-    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+    value = torch.rand(N, S, M, D) * 0.01
+    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
+    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
     attention_weights /= attention_weights.sum(
         -1, keepdim=True).sum(
             -2, keepdim=True)
@@ -93,8 +96,9 @@ def test_forward_equal_with_pytorch_double():
         attention_weights.double()).detach().cpu()

     output_cuda = MultiScaleDeformableAttnFunction.apply(
-        value.double(), shapes, level_start_index, sampling_locations.double(),
-        attention_weights.double(), im2col_step).detach().cpu()
+        value.cuda().double(), shapes.cuda(), level_start_index.cuda(),
+        sampling_locations.cuda().double(),
+        attention_weights.cuda().double(), im2col_step).detach().cpu()
     assert torch.allclose(output_cuda, output_pytorch)
     max_abs_err = (output_cuda - output_pytorch).abs().max()
     max_rel_err = ((output_cuda - output_pytorch).abs() /
@@ -103,20 +107,28 @@ def test_forward_equal_with_pytorch_double():
     assert max_rel_err < 1e-15


-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-def test_forward_equal_with_pytorch_float():
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'mlu',
+        marks=pytest.mark.skipif(
+            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+])
+def test_forward_equal_with_pytorch_float(device):
     N, M, D = 1, 2, 2
     Lq, L, P = 2, 2, 2
-    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
+    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
     level_start_index = torch.cat((shapes.new_zeros(
         (1, )), shapes.prod(1).cumsum(0)[:-1]))
     S = sum((H * W).item() for H, W in shapes)

     torch.manual_seed(3)
-    value = torch.rand(N, S, M, D).cuda() * 0.01
-    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
-    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+    value = torch.rand(N, S, M, D) * 0.01
+    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
+    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
     attention_weights /= attention_weights.sum(
         -1, keepdim=True).sum(
             -2, keepdim=True)
@@ -124,19 +136,37 @@ def test_forward_equal_with_pytorch_float():
     output_pytorch = multi_scale_deformable_attn_pytorch(
         value, shapes, sampling_locations, attention_weights).detach().cpu()

-    output_cuda = MultiScaleDeformableAttnFunction.apply(
-        value, shapes, level_start_index, sampling_locations,
-        attention_weights, im2col_step).detach().cpu()
-    assert torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
-    max_abs_err = (output_cuda - output_pytorch).abs().max()
-    max_rel_err = ((output_cuda - output_pytorch).abs() /
+    output_device = MultiScaleDeformableAttnFunction.apply(
+        value.to(device), shapes.to(device), level_start_index.to(device),
+        sampling_locations.to(device), attention_weights.to(device),
+        im2col_step).detach().cpu()
+    assert torch.allclose(output_device, output_pytorch, rtol=1e-2, atol=1e-3)
+    max_abs_err = (output_device - output_pytorch).abs().max()
+    max_rel_err = ((output_device - output_pytorch).abs() /
                    output_pytorch.abs()).max()
     assert max_abs_err < 1e-9
     assert max_rel_err < 1e-6


-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'mlu',
+        marks=pytest.mark.skipif(
+            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+])
+@pytest.mark.parametrize('dtype', [
+    torch.float,
+    pytest.param(
+        torch.double,
+        marks=pytest.mark.skipif(
+            IS_MLU_AVAILABLE,
+            reason='MLU does not support for 64-bit floating point')),
+    torch.half
+])
 @pytest.mark.parametrize('channels', [
     4,
     30,
@@ -146,20 +176,22 @@ def test_forward_equal_with_pytorch_float():
     1025,
 ])
 def test_gradient_numerical(channels,
+                            device,
+                            dtype,
                             grad_value=True,
                             grad_sampling_loc=True,
                             grad_attn_weight=True):
     N, M, _ = 1, 2, 2
     Lq, L, P = 2, 2, 2
-    shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).cuda()
+    shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).to(device)
     level_start_index = torch.cat((shapes.new_zeros(
         (1, )), shapes.prod(1).cumsum(0)[:-1]))
     S = sum((H * W).item() for H, W in shapes)

-    value = torch.rand(N, S, M, channels).cuda() * 0.01
-    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
-    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+    value = torch.rand(N, S, M, channels).to(device) * 0.01
+    sampling_locations = torch.rand(N, Lq, M, L, P, 2).to(device)
+    attention_weights = torch.rand(N, Lq, M, L, P).to(device) + 1e-5
     attention_weights /= attention_weights.sum(
         -1, keepdim=True).sum(
             -2, keepdim=True)
@@ -170,13 +202,23 @@ def test_gradient_numerical(channels,
     value.requires_grad = grad_value
     sampling_locations.requires_grad = grad_sampling_loc
     attention_weights.requires_grad = grad_attn_weight
+    if device == 'cuda':
+        dtype = torch.double
+        eps = 1e-6
+    elif device == 'mlu':
+        dtype = torch.float
+        eps = 1e-4
     if _USING_PARROTS:
         assert gradcheck(
-            func, (value.double(), shapes, level_start_index,
-                   sampling_locations.double(), attention_weights.double(),
-                   im2col_step),
-            no_grads=[shapes, level_start_index])
+            func, (value.to(dtype), shapes, level_start_index,
+                   sampling_locations.to(dtype), attention_weights.to(dtype),
+                   im2col_step),
+            no_grads=[shapes, level_start_index],
+            eps=eps)
     else:
-        assert gradcheck(
-            func, (value.double(), shapes, level_start_index,
-                   sampling_locations.double(),
-                   attention_weights.double(), im2col_step))
+        assert gradcheck(
+            func, (value.to(dtype), shapes, level_start_index,
+                   sampling_locations.to(dtype), attention_weights.to(dtype),
+                   im2col_step),
+            eps=eps,
+            atol=1e-2)
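One detail worth noting in the last hunk: the MLU branch runs gradcheck in float32 (the torch.double parameter is skipped on MLU and dtype is forced to torch.float), so the commit relaxes the finite-difference step to eps=1e-4 and the tolerance to atol=1e-2. A central difference computed in single precision loses roughly half its significant digits to cancellation, so too small a step drowns in rounding noise. A minimal CPU-only sketch of the same trade-off — my own illustration, not part of the commit:

import torch
from torch.autograd import gradcheck

# gradcheck compares analytic gradients against central differences
# (f(x + eps) - f(x - eps)) / (2 * eps). The default eps=1e-6 is tuned
# for float64 inputs; for float32 both the step and the tolerance must
# be relaxed, which is what the MLU branch above does.
x = torch.rand(4, dtype=torch.float, requires_grad=True)
assert gradcheck(torch.sin, (x,), eps=1e-4, atol=1e-2)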