Commit fd79c680 authored by Valentin Andrei, committed by Facebook GitHub Bot

Enforce torch.float32 for ms_deform_attn when using AMP

Reviewed By: stephenyan1231

Differential Revision: D30225977

fbshipit-source-id: 479b96acc7f90a8ee2373ab44112e21086e9d1d2
parent cb985322
@@ -16,12 +16,22 @@ import torch
 import torch.nn.functional as F
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
+from torch.cuda.amp.autocast_mode import custom_bwd, custom_fwd
 from detr import _C as MSDA
 class MSDeformAttnFunction(Function):
+    # The @custom_fwd and @custom_bwd decorators are used in this case to allow enabling of
+    # Automatic Mixed Precision when we do not have implementations of custom CUDA kernels for
+    # all the precision types.
+    #
+    # TODO: After implementing `ms_deform_attn` CUDA kernels for FP16, we can remove the
+    # custom_fwd and custom_bwd decorators
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
         ctx.im2col_step = im2col_step
         output = MSDA.ms_deform_attn_forward(
@@ -31,6 +41,7 @@ class MSDeformAttnFunction(Function):
     @staticmethod
     @once_differentiable
+    @custom_bwd
     def backward(ctx, grad_output):
         value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
         grad_value, grad_sampling_loc, grad_attn_weight = \
...
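For reference, a minimal sketch (not part of the patch) of how torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) and custom_bwd keep a custom autograd Function running in float32 inside an autocast region. The SquareFP32 op and its arithmetic are invented for illustration, and the snippet assumes a CUDA device:

    import torch
    from torch.autograd import Function
    from torch.cuda.amp import custom_bwd, custom_fwd

    class SquareFP32(Function):
        # Toy op standing in for a kernel that only supports float32.
        @staticmethod
        @custom_fwd(cast_inputs=torch.float32)
        def forward(ctx, x):
            # Inside an autocast region, x has already been cast to float32
            # and autocast is disabled for the body of forward.
            ctx.save_for_backward(x)
            return x * x

        @staticmethod
        @custom_bwd
        def backward(ctx, grad_output):
            # custom_bwd replays the autocast state of forward, so the
            # gradient math also runs in float32.
            (x,) = ctx.saved_tensors
            return 2.0 * x * grad_output

    x = torch.randn(8, device="cuda", requires_grad=True)
    with torch.cuda.amp.autocast():
        y = SquareFP32.apply(x)   # runs in float32 even though autocast is on
    assert y.dtype == torch.float32
    y.sum().backward()

Casting the inputs at the Function boundary avoids touching the CUDA kernels themselves, which is why the TODO in the diff can remove the decorators once FP16 `ms_deform_attn` kernels exist.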