import torch
from torch.autograd.function import Function, once_differentiable

from . import deformable_aggregation_ext


def _prepare_kernel_args(
    mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights
):
    """Return the five kernel arguments as contiguous tensors of fixed dtype.

    Float-valued inputs (features, sampling locations, weights) are cast to
    float32 and the index-like inputs (spatial shape, scale start index) are
    cast to int32. These are the dtypes the extension kernels are called with.
    ``.contiguous()`` / ``.float()`` / ``.int()`` return the input unchanged
    when it already satisfies the requirement.
    """
    return (
        mc_ms_feat.contiguous().float(),
        spatial_shape.contiguous().int(),
        scale_start_index.contiguous().int(),
        sampling_location.contiguous().float(),
        weights.contiguous().float(),
    )


class DeformableAggregationFunction(Function):
    """Autograd bridge to the ``deformable_aggregation_ext`` kernels.

    Gradients are produced for ``mc_ms_feat``, ``sampling_location`` and
    ``weights``. ``spatial_shape`` and ``scale_start_index`` receive ``None``.
    Exact tensor shapes are defined by the extension. NOTE(review): presumably
    multi-camera multi-scale features plus per-point sampling locations and
    weights; confirm against the kernel sources.
    """

    @staticmethod
    def forward(
        ctx,
        mc_ms_feat,
        spatial_shape,
        scale_start_index,
        sampling_location,
        weights,
    ):
        """Run the forward kernel and stash its (normalized) inputs for backward."""
        kernel_args = _prepare_kernel_args(
            mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights
        )
        # output: [bs, num_pts, num_embeds]
        output = deformable_aggregation_ext.deformable_aggregation_forward(*kernel_args)
        # Save the already-normalized tensors so backward sees the same inputs.
        ctx.save_for_backward(*kernel_args)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute input gradients with the backward kernel (no double-backward)."""
        feat, shapes, starts, locations, attn_weights = _prepare_kernel_args(
            *ctx.saved_tensors
        )

        # The extension's return value is discarded; it is expected to write
        # the gradients into these zero-initialised buffers in place.
        grad_feat = torch.zeros_like(feat)
        grad_locations = torch.zeros_like(locations)
        grad_attn_weights = torch.zeros_like(attn_weights)

        deformable_aggregation_ext.deformable_aggregation_backward(
            feat,
            shapes,
            starts,
            locations,
            attn_weights,
            grad_output.contiguous(),
            grad_feat,
            grad_locations,
            grad_attn_weights,
        )

        # spatial_shape and scale_start_index are integer metadata: no gradient.
        return grad_feat, None, None, grad_locations, grad_attn_weights