Commit f8ec5a0a authored by luopl
Browse files

Update ms_deform_attn.py

parent ff69c17f
...@@ -14,12 +14,13 @@ from torch.utils.cpp_extension import load ...@@ -14,12 +14,13 @@ from torch.utils.cpp_extension import load
_C = None _C = None
if torch.cuda.is_available(): if torch.cuda.is_available():
try: try:
_C = load( pass
"MultiScaleDeformableAttention", # _C = load(
sources=[f"{os.path.dirname(__file__)}/ops/cuda/ms_deform_attn_cuda.cu"], # "MultiScaleDeformableAttention",
extra_cflags=["-O2"], # sources=[f"{os.path.dirname(__file__)}/ops/cuda/ms_deform_attn_cuda.cu"],
verbose=True, # extra_cflags=["-O2"],
) # verbose=True,
# )
except Exception as e: except Exception as e:
warnings.warn(f"Failed to load MultiScaleDeformableAttention C++ extension: {e}") warnings.warn(f"Failed to load MultiScaleDeformableAttention C++ extension: {e}")
else: else:
...@@ -355,20 +356,22 @@ class MultiScaleDeformableAttention(nn.Module): ...@@ -355,20 +356,22 @@ class MultiScaleDeformableAttention(nn.Module):
) )
# the original impl for fp32 training # the original impl for fp32 training
if _C is not None and value.is_cuda: # if _C is not None and value.is_cuda:
output = MultiScaleDeformableAttnFunction.apply( # output = MultiScaleDeformableAttnFunction.apply(
value.to(torch.float32), # value.to(torch.float32),
spatial_shapes, # spatial_shapes,
level_start_index, # level_start_index,
sampling_locations, # sampling_locations,
attention_weights, # attention_weights,
self.im2col_step, # self.im2col_step,
) # )
else: # else:
output = multi_scale_deformable_attn_pytorch( # output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights # value, spatial_shapes, sampling_locations, attention_weights
) # )
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights
)
if value.dtype != torch.float32: if value.dtype != torch.float32:
output = output.to(value.dtype) output = output.to(value.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment