Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
Salience-DETR_pytorch
Commits
f8ec5a0a
Commit
f8ec5a0a
authored
Oct 29, 2024
by
luopl
Browse files
Update ms_deform_attn.py
parent
ff69c17f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
20 deletions
+23
-20
models/bricks/ms_deform_attn.py
models/bricks/ms_deform_attn.py
+23
-20
No files found.
models/bricks/ms_deform_attn.py
View file @
f8ec5a0a
...
...
@@ -14,12 +14,13 @@ from torch.utils.cpp_extension import load
_C
=
None
if
torch
.
cuda
.
is_available
():
try
:
_C
=
load
(
"MultiScaleDeformableAttention"
,
sources
=
[
f
"
{
os
.
path
.
dirname
(
__file__
)
}
/ops/cuda/ms_deform_attn_cuda.cu"
],
extra_cflags
=
[
"-O2"
],
verbose
=
True
,
)
pass
# _C = load(
# "MultiScaleDeformableAttention",
# sources=[f"{os.path.dirname(__file__)}/ops/cuda/ms_deform_attn_cuda.cu"],
# extra_cflags=["-O2"],
# verbose=True,
# )
except
Exception
as
e
:
warnings
.
warn
(
f
"Failed to load MultiScaleDeformableAttention C++ extension:
{
e
}
"
)
else
:
...
...
@@ -355,20 +356,22 @@ class MultiScaleDeformableAttention(nn.Module):
)
# the original impl for fp32 training
if
_C
is
not
None
and
value
.
is_cuda
:
output
=
MultiScaleDeformableAttnFunction
.
apply
(
value
.
to
(
torch
.
float32
),
spatial_shapes
,
level_start_index
,
sampling_locations
,
attention_weights
,
self
.
im2col_step
,
)
else
:
output
=
multi_scale_deformable_attn_pytorch
(
value
,
spatial_shapes
,
sampling_locations
,
attention_weights
)
# if _C is not None and value.is_cuda:
# output = MultiScaleDeformableAttnFunction.apply(
# value.to(torch.float32),
# spatial_shapes,
# level_start_index,
# sampling_locations,
# attention_weights,
# self.im2col_step,
# )
# else:
# output = multi_scale_deformable_attn_pytorch(
# value, spatial_shapes, sampling_locations, attention_weights
# )
output
=
multi_scale_deformable_attn_pytorch
(
value
,
spatial_shapes
,
sampling_locations
,
attention_weights
)
if
value
.
dtype
!=
torch
.
float32
:
output
=
output
.
to
(
value
.
dtype
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment