Commit f8ec5a0a authored by luopl

Update ms_deform_attn.py: disable JIT compilation of the MultiScaleDeformableAttention CUDA extension and always fall back to the pure-PyTorch implementation

parent ff69c17f
@@ -14,12 +14,13 @@ from torch.utils.cpp_extension import load
 _C = None
 if torch.cuda.is_available():
     try:
-        _C = load(
-            "MultiScaleDeformableAttention",
-            sources=[f"{os.path.dirname(__file__)}/ops/cuda/ms_deform_attn_cuda.cu"],
-            extra_cflags=["-O2"],
-            verbose=True,
-        )
+        pass
+        # _C = load(
+        #     "MultiScaleDeformableAttention",
+        #     sources=[f"{os.path.dirname(__file__)}/ops/cuda/ms_deform_attn_cuda.cu"],
+        #     extra_cflags=["-O2"],
+        #     verbose=True,
+        # )
     except Exception as e:
         warnings.warn(f"Failed to load MultiScaleDeformableAttention C++ extension: {e}")
 else:
@@ -355,20 +356,22 @@ class MultiScaleDeformableAttention(nn.Module):
         )
         # the original impl for fp32 training
-        if _C is not None and value.is_cuda:
-            output = MultiScaleDeformableAttnFunction.apply(
-                value.to(torch.float32),
-                spatial_shapes,
-                level_start_index,
-                sampling_locations,
-                attention_weights,
-                self.im2col_step,
-            )
-        else:
-            output = multi_scale_deformable_attn_pytorch(
-                value, spatial_shapes, sampling_locations, attention_weights
-            )
+        # if _C is not None and value.is_cuda:
+        #     output = MultiScaleDeformableAttnFunction.apply(
+        #         value.to(torch.float32),
+        #         spatial_shapes,
+        #         level_start_index,
+        #         sampling_locations,
+        #         attention_weights,
+        #         self.im2col_step,
+        #     )
+        # else:
+        #     output = multi_scale_deformable_attn_pytorch(
+        #         value, spatial_shapes, sampling_locations, attention_weights
+        #     )
+        output = multi_scale_deformable_attn_pytorch(
+            value, spatial_shapes, sampling_locations, attention_weights
+        )
         if value.dtype != torch.float32:
             output = output.to(value.dtype)
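With the CUDA extension disabled, every forward pass now goes through multi_scale_deformable_attn_pytorch. For reference, the sketch below shows how that fallback is implemented in the Deformable-DETR / mmcv code it typically derives from: per-level bilinear sampling with F.grid_sample, followed by an attention-weighted sum over all sampled points. This is an assumption about this repository's version, which may differ in details (e.g. the align_corners setting).

import torch
import torch.nn.functional as F

def multi_scale_deformable_attn_pytorch(value, spatial_shapes,
                                        sampling_locations, attention_weights):
    # Shapes assumed here (standard Deformable-DETR convention):
    # value:              (bs, num_keys, num_heads, head_dims), all levels flattened
    # spatial_shapes:     (num_levels, 2) holding (H, W) per level
    # sampling_locations: (bs, num_queries, num_heads, num_levels, num_points, 2) in [0, 1]
    # attention_weights:  (bs, num_queries, num_heads, num_levels, num_points)
    bs, _, num_heads, head_dims = value.shape
    _, num_queries, _, num_levels, num_points, _ = sampling_locations.shape
    # split the flattened value tensor back into per-level feature maps
    value_list = value.split([H * W for H, W in spatial_shapes], dim=1)
    # grid_sample expects coordinates in [-1, 1]
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (H, W) in enumerate(spatial_shapes):
        # (bs, H*W, num_heads, head_dims) -> (bs*num_heads, head_dims, H, W)
        value_l = value_list[level].flatten(2).transpose(1, 2).reshape(
            bs * num_heads, head_dims, H, W)
        # (bs, num_queries, num_heads, num_points, 2)
        #   -> (bs*num_heads, num_queries, num_points, 2)
        grid_l = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # bilinear sampling -> (bs*num_heads, head_dims, num_queries, num_points)
        sampling_value_list.append(F.grid_sample(
            value_l, grid_l, mode="bilinear",
            padding_mode="zeros", align_corners=False))
    # weights -> (bs*num_heads, 1, num_queries, num_levels*num_points)
    attention_weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points)
    # weighted sum over all sampled points, then fold heads back into the channel dim
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2)
              * attention_weights).sum(-1).view(bs, num_heads * head_dims, num_queries)
    # (bs, num_queries, num_heads*head_dims)
    return output.transpose(1, 2).contiguous()

Because this path uses only standard tensor ops, it runs on any device and dtype, at the cost of materializing all sampled values instead of fusing the gather and the weighted sum as the CUDA kernel does.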