
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <THC/THCAtomics.cuh>

#include <iostream>
#include <stdlib.h>


// Bilinearly interpolates one value from `bottom_data` at the fractional
// position (h_im, w_im).
//
// The element for integer cell (r, c) is read at
//   bottom_data[base_ptr + (r * width + c) * num_embeds].
// Each of the four surrounding cells contributes only when it passes the
// bound check below; a failing cell contributes 0 and is never dereferenced.
//
// Only the lower bound is tested for the low row/column and only the upper
// bound for the high row/column. Callers in this file pass
// h_im in (-0.5, height - 0.5) and w_im in (-0.5, width - 0.5), for which the
// untested bounds cannot be violated.
__device__ float bilinear_sampling(
    const float *&bottom_data, const int &height, const int &width,
    const int &num_embeds, const float &h_im, const float &w_im,
    const int &base_ptr
) {
  // Integer cell enclosing the sample position.
  const int row_lo = floorf(h_im);
  const int col_lo = floorf(w_im);
  const int row_hi = row_lo + 1;
  const int col_hi = col_lo + 1;

  // Fractional distance from the low cell and its complement.
  const float frac_h = h_im - row_lo;
  const float frac_w = w_im - col_lo;
  const float rem_h = 1 - frac_h;
  const float rem_w = 1 - frac_w;

  // Element offsets of the two rows and two columns.
  const int col_step = num_embeds;
  const int row_step = width * col_step;
  const int row_lo_off = row_lo * row_step;
  const int row_hi_off = row_lo_off + row_step;
  const int col_lo_off = col_lo * col_step;
  const int col_hi_off = col_lo_off + col_step;

  const bool row_lo_ok = row_lo >= 0;
  const bool row_hi_ok = row_hi <= height - 1;
  const bool col_lo_ok = col_lo >= 0;
  const bool col_hi_ok = col_hi <= width - 1;

  // The conditional operator only evaluates the load when the cell is valid.
  const float s_ll = (row_lo_ok && col_lo_ok)
      ? bottom_data[row_lo_off + col_lo_off + base_ptr] : 0.0f;
  const float s_lh = (row_lo_ok && col_hi_ok)
      ? bottom_data[row_lo_off + col_hi_off + base_ptr] : 0.0f;
  const float s_hl = (row_hi_ok && col_lo_ok)
      ? bottom_data[row_hi_off + col_lo_off + base_ptr] : 0.0f;
  const float s_hh = (row_hi_ok && col_hi_ok)
      ? bottom_data[row_hi_off + col_hi_off + base_ptr] : 0.0f;

  // Interpolation coefficients of the four cells.
  const float c_ll = rem_h * rem_w;
  const float c_lh = rem_h * frac_w;
  const float c_hl = frac_h * rem_w;
  const float c_hh = frac_h * frac_w;

  return (c_ll * s_ll + c_lh * s_lh + c_hl * s_hl + c_hh * s_hh);
}

// Two float accumulators that are reduced together by warp_reduce_sum and
// block_reduce_sum.
struct float2_t {
  float a, b;
};

// Shuffle-based tree sum of both members of `val`.
//
// Starting at lane distance `max` and halving down to 1, each lane adds the
// value held by the lane `offset` positions above it. Callers in this file
// only consume the result held by lane 0 of the group being reduced.
//
// NOTE(review): the default of 32 makes the first step reach 32 lanes away,
// which presumably targets a 64-lane wavefront - confirm for the build target.
__forceinline__ __device__
float2_t warp_reduce_sum(float2_t val, int max = 32) {
  int offset = max;
  while (offset > 0) {
    val.a += __shfl_down(val.a, offset);
    val.b += __shfl_down(val.b, offset);
    offset >>= 1;
  }
  return val;
}

// Sums `val` over the threads of a block, treating every 64 consecutive
// threadIdx.x values as one shuffle group.
//
// Step 1: each 64-thread group is reduced with warp_reduce_sum.
// Step 2: the first lane of every group stores its partial sum in `shared`.
// Step 3: after a block barrier, the first `blocksize / 64` threads of group 0
//         reload those partials and reduce them with a second shuffle pass.
//
// `shared` must provide at least `blocksize / 64` elements. The function
// contains __syncthreads(), so all threads of the block have to reach it.
// Only the value returned to thread 0 is the complete block sum; callers in
// this file use only that one.
template <int blocksize>
__forceinline__ __device__
float2_t block_reduce_sum(float2_t val, float2_t* shared) {
  constexpr int kGroupSize = 64;
  constexpr int kNumGroups = blocksize / kGroupSize;

  const int lane = threadIdx.x % kGroupSize;
  const int group = threadIdx.x / kGroupSize;

  val = warp_reduce_sum(val);

  if constexpr (blocksize == kGroupSize) {
    // A single group: the shuffle reduction is already the block result.
    return val;
  } else {
    const bool publishes_partial = (lane == 0) && (group < kNumGroups);
    if (publishes_partial) {
      shared[group] = val;
    }

    // Make every group's partial visible before group 0 reads them back.
    __syncthreads();

    const bool combines_partials = (group == 0) && (lane < kNumGroups);
    if (combines_partials) {
      val = shared[lane];
      val = warp_reduce_sum(val, kNumGroups / 2);
    }

    return val;
  }
}

// Backward pass of one bilinear sample, specialised for the launch shape used
// by deformable_aggregation_grad_kernel_sp.
//
// For the sample at (h_im, w_im) this accumulates:
//   * grad_mc_ms_feat[ptr_k]      += w_k * grad_output * weight   (k = 1..4)
//   * grad_weights[0]             += grad_output * interpolated value
//   * grad_sampling_location[0]   += width  * d(value)/d(w_im) * grad_output * weight
//   * grad_sampling_location[1]   += height * d(value)/d(h_im) * grad_output * weight
// A corner that fails its bound check is treated as 0 and is not touched.
//
// On gfx936 the location gradient is first summed over the whole block
// (block_reduce_sum) and the weight gradient over each aligned group of 32
// threads, so that far fewer global updates are issued. This relies on the
// caller guaranteeing that
//   * every thread of the block reaches this function (block_reduce_sum
//     contains __syncthreads()),
//   * all threads of the block share the same grad_sampling_location, and
//   * every aligned group of 32 threads owns one grad_weights element that no
//     other thread in the grid writes (it is updated without an atomic).
// The launcher in this file selects this path only for num_embeds == 256 and
// num_embeds / num_groups == 32 with 256-thread blocks.
//
// On every other target the function falls back to one atomicAdd per thread
// for each output, matching bilinear_sampling_grad. Previously the
// grad_mc_ms_feat updates were compiled only for gfx936, so the feature
// gradient was silently dropped elsewhere.
template <int blocksize>
__device__ void bilinear_sampling_grad_sp(
    const float *&bottom_data, const float &weight,
    const int &height, const int &width,
    const int &num_embeds, const float &h_im, const float &w_im,
    const int &base_ptr,
    const float &grad_output,
    float *&grad_mc_ms_feat, float *grad_sampling_location, float *grad_weights,
    float2_t* s_data) {
  // Integer cell enclosing the sample and the fractional position inside it.
  const int h_low = floorf(h_im);
  const int w_low = floorf(w_im);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const float lh = h_im - h_low;
  const float lw = w_im - w_low;
  const float hh = 1 - lh, hw = 1 - lw;

  // Element offsets of the two rows / columns inside the feature buffer.
  const int w_stride = num_embeds;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;

  const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  // Upstream gradient reaching the sampled feature value.
  const float top_grad_mc_ms_feat = grad_output * weight;
  float grad_h_weight = 0, grad_w_weight = 0;

  // Corner 1: (h_low, w_low).
  const int valid1 = (h_low >= 0 && w_low >= 0);
  const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
  float v1 = valid1 ? bottom_data[ptr1] : 0.0f;
  if (valid1) {
#ifdef __gfx936__
    __builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#else
    atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#endif
  }

  // Corner 2: (h_low, w_high).
  const int valid2 = (h_low >= 0 && w_high <= width - 1);
  const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
  float v2 = valid2 ? bottom_data[ptr2] : 0.0f;
  if (valid2) {
#ifdef __gfx936__
    __builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#else
    atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#endif
  }

  // Corner 3: (h_high, w_low).
  const int valid3 = (h_high <= height - 1 && w_low >= 0);
  const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
  float v3 = valid3 ? bottom_data[ptr3] : 0.0f;
  if (valid3) {
#ifdef __gfx936__
    __builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#else
    atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#endif
  }

  // Corner 4: (h_high, w_high).
  const int valid4 = (h_high <= height - 1 && w_high <= width - 1);
  const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
  float v4 = valid4 ? bottom_data[ptr4] : 0.0f;
  if (valid4) {
#ifdef __gfx936__
    __builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#else
    atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#endif
  }

  // Derivatives of the interpolated value with respect to h_im and w_im.
  grad_h_weight += (-hw * v1) + (-lw * v2) + ( hw * v3) + ( lw * v4);
  grad_w_weight += (-hh * v1) + ( hh * v2) + (-lh * v3) + ( lh * v4);

  // Interpolated feature value, needed for the gradient of `weight`.
  const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

#ifdef __gfx936__
  // Block-wide sum of the location gradient; the total ends up in thread 0.
  float2_t spl;
  spl.a = width * grad_w_weight * top_grad_mc_ms_feat;
  spl.b = height * grad_h_weight * top_grad_mc_ms_feat;

  spl = block_reduce_sum<blocksize>(spl, s_data);

  // Sum of the weight gradient over each aligned group of 32 threads; the
  // total ends up in the first thread of the group.
  float wei = grad_output * val;

  for (int offset = 16; offset >= 1; offset >>= 1) {
    wei += __shfl_down(wei, offset);
  }

  if (threadIdx.x % 32 == 0) {
    // Deliberately not atomic: see the ownership precondition in the header.
    *grad_weights += wei;
  }

  if (threadIdx.x == 0) {
    __builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location, spl.a);
    __builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location + 1, spl.b);
  }
#else
  // The reductions above are specific to gfx936; here every thread publishes
  // its own contribution, so the scratch buffer is not needed.
  (void)s_data;
  atomicAdd(grad_weights, grad_output * val);
  atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
  atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
#endif
}


// Backward pass of one bilinear sample (generic path, one atomic per update).
//
// For the sample at (h_im, w_im) this atomically accumulates:
//   * into grad_mc_ms_feat: coefficient * grad_output * weight for each of the
//     four surrounding cells that passes its bound check,
//   * into grad_weights[0]: grad_output * interpolated value,
//   * into grad_sampling_location[0] / [1]: the derivative of the interpolated
//     value with respect to w_im / h_im, scaled by width / height and by
//     grad_output * weight.
// Cells that fail the bound check contribute 0 and are never dereferenced.
// On gfx936 the AMD global float-add builtin is used, atomicAdd otherwise.
__device__ void bilinear_sampling_grad(
    const float *&bottom_data, const float &weight,
    const int &height, const int &width,
    const int &num_embeds, const float &h_im, const float &w_im,
    const int &base_ptr,
    const float &grad_output,
    float *&grad_mc_ms_feat, float *grad_sampling_location, float *grad_weights) {
  // Integer cell enclosing the sample position.
  const int row_lo = floorf(h_im);
  const int col_lo = floorf(w_im);
  const int row_hi = row_lo + 1;
  const int col_hi = col_lo + 1;

  // Fractional distance from the low cell and its complement.
  const float frac_h = h_im - row_lo;
  const float frac_w = w_im - col_lo;
  const float rem_h = 1 - frac_h;
  const float rem_w = 1 - frac_w;

  // Element offsets of the two rows and two columns.
  const int col_step = num_embeds;
  const int row_step = width * col_step;
  const int row_lo_off = row_lo * row_step;
  const int row_hi_off = row_lo_off + row_step;
  const int col_lo_off = col_lo * col_step;
  const int col_hi_off = col_lo_off + col_step;

  // Interpolation coefficients of the four cells.
  const float c_ll = rem_h * rem_w;
  const float c_lh = rem_h * frac_w;
  const float c_hl = frac_h * rem_w;
  const float c_hh = frac_h * frac_w;

  // Upstream gradient reaching the sampled feature value.
  const float feat_grad = grad_output * weight;
  float d_row = 0;
  float d_col = 0;

  const bool row_lo_ok = row_lo >= 0;
  const bool row_hi_ok = row_hi <= height - 1;
  const bool col_lo_ok = col_lo >= 0;
  const bool col_hi_ok = col_hi <= width - 1;

  // Cell (row_lo, col_lo).
  float s_ll = 0;
  if (row_lo_ok && col_lo_ok) {
    const int at = row_lo_off + col_lo_off + base_ptr;
    s_ll = bottom_data[at];
    d_row -= rem_w * s_ll;
    d_col -= rem_h * s_ll;
    #ifdef __gfx936__
      __builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + at, c_ll * feat_grad);
    #else
      atomicAdd(grad_mc_ms_feat + at, c_ll * feat_grad);
    #endif
  }

  // Cell (row_lo, col_hi).
  float s_lh = 0;
  if (row_lo_ok && col_hi_ok) {
    const int at = row_lo_off + col_hi_off + base_ptr;
    s_lh = bottom_data[at];
    d_row -= frac_w * s_lh;
    d_col += rem_h * s_lh;
    #ifdef __gfx936__
      __builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + at, c_lh * feat_grad);
    #else
      atomicAdd(grad_mc_ms_feat + at, c_lh * feat_grad);
    #endif
  }

  // Cell (row_hi, col_lo).
  float s_hl = 0;
  if (row_hi_ok && col_lo_ok) {
    const int at = row_hi_off + col_lo_off + base_ptr;
    s_hl = bottom_data[at];
    d_row += rem_w * s_hl;
    d_col -= frac_h * s_hl;
    #ifdef __gfx936__
      __builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + at, c_hl * feat_grad);
    #else
      atomicAdd(grad_mc_ms_feat + at, c_hl * feat_grad);
    #endif
  }

  // Cell (row_hi, col_hi).
  float s_hh = 0;
  if (row_hi_ok && col_hi_ok) {
    const int at = row_hi_off + col_hi_off + base_ptr;
    s_hh = bottom_data[at];
    d_row += frac_w * s_hh;
    d_col += frac_h * s_hh;
    #ifdef __gfx936__
      __builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + at, c_hh * feat_grad);
    #else
      atomicAdd(grad_mc_ms_feat + at, c_hh * feat_grad);
    #endif
  }

  // Interpolated feature value, needed for the gradient of `weight`.
  const float sampled = (c_ll * s_ll + c_lh * s_lh + c_hl * s_hl + c_hh * s_hh);
  #ifdef __gfx936__
    __builtin_amdgcn_global_atomic_fadd_f32(grad_weights, grad_output * sampled);
    __builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location, width * d_col * feat_grad);
    __builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location + 1, height * d_row * feat_grad);
  #else
    atomicAdd(grad_weights, grad_output * sampled);
    atomicAdd(grad_sampling_location, width * d_col * feat_grad);
    atomicAdd(grad_sampling_location + 1, height * d_row * feat_grad);
  #endif
}


// Forward kernel: one thread per (batch, anchor, pts, cam, scale, channel)
// tuple, with channel varying fastest in the flat thread index.
//
// Each thread bilinearly samples one channel of one camera/scale feature map
// at the normalised location stored in `sample_location`, multiplies it by the
// matching entry of `weights`, and atomically adds the product to
//   output[(batch * num_anchors + anchor) * num_embeds + channel].
// Because results are accumulated, `output` presumably has to be
// zero-initialised by the caller - confirm at the call site.
//
// Buffer indexing as used below:
//   mc_ms_feat        [batch, num_feat, num_embeds]
//   spatial_shape     [cam, scale, 2]  stored as (h, w)
//   scale_start_index [cam, scale]     first row of that map inside num_feat
//   sample_location   [batch, anchor, pts, cam, 2]  stored as (x, y)
//   weights           flat thread index / (num_embeds / num_groups)
//
// Launch: 1-D grid of 1-D blocks covering at least `num_kernels` threads.
__global__ void deformable_aggregation_kernel(
    const int64_t num_kernels,
    float* output,
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
) {
    // Widen before multiplying: blockIdx.x * blockDim.x is a 32-bit unsigned
    // product and would wrap for grids of 2^32 or more threads.
    int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= num_kernels) return;

    // Consecutive runs of (num_embeds / num_groups) channels share one weight.
    const float weight = *(weights + idx / (num_embeds / num_groups));

    // Peel the flat index apart, fastest-varying dimension first.
    const int channel_index = idx % num_embeds;
    idx /= num_embeds;
    const int scale_index = idx % num_scale;
    idx /= num_scale;

    const int cam_index = idx % num_cams;
    idx /= num_cams;
    const int pts_index = idx % num_pts;
    idx /= num_pts;

    int anchor_index = idx % num_anchors;
    idx /= num_anchors;
    const int batch_index = idx % batch_size;
    idx /= batch_size;

    // From here on anchor_index is the flattened (batch, anchor) index.
    anchor_index = batch_index * num_anchors + anchor_index;
    const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;

    // Samples outside the open unit square contribute nothing.
    const float loc_w = sample_location[loc_offset];
    if (loc_w <= 0 || loc_w >= 1) return;
    const float loc_h = sample_location[loc_offset + 1];
    if (loc_h <= 0 || loc_h >= 1) return;

    int cam_scale_index = cam_index * num_scale + scale_index;
    const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;

    cam_scale_index = cam_scale_index << 1;
    const int h = spatial_shape[cam_scale_index];
    const int w = spatial_shape[cam_scale_index + 1];

    // Normalised location -> pixel coordinates with a half-pixel shift.
    const float h_im = loc_h * h - 0.5;
    const float w_im = loc_w * w - 0.5;

    #ifdef __gfx936__
        __builtin_amdgcn_global_atomic_fadd_f32(output + anchor_index * num_embeds + channel_index, bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight);
    #else
        atomicAdd(output + anchor_index * num_embeds + channel_index, bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight);
    #endif
}


// Backward kernel, specialised path that reduces gradients inside the block
// before writing them (see bilinear_sampling_grad_sp).
//
// One thread per (batch, anchor, pts, cam, scale, channel) tuple, with channel
// varying fastest in the flat thread index.
//
// Preconditions (the launcher in this file enforces them by using this kernel
// only for num_embeds == 256, num_embeds / num_groups == 32 and 256-thread
// blocks):
//   * blockDim.x == blocksize == num_embeds, so all threads of a block differ
//     only in channel_index and therefore share loc_offset. This makes both
//     early returns below block-uniform, which is required because
//     bilinear_sampling_grad_sp executes __syncthreads().
//   * Dynamic shared memory holds at least blocksize / 64 float2_t elements.
template <int blocksize>
__global__ void deformable_aggregation_grad_kernel_sp(
    const int64_t num_kernels,
    const float* mc_ms_feat,       // indexed as [bs, num_feat, num_embeds]
    const int* spatial_shape,      // [cam, scale, 2], stored as (h, w)
    const int* scale_start_index,  // [cam, scale]
    const float* sample_location,  // [bs, anchor, pts, cam, 2], stored as (x, y)
    const float* weights,          // indexed as [bs, anchor, pts, cam, scale, group]
    const float* grad_output,      // [bs, anchor, num_embeds]
    float* grad_mc_ms_feat,        // indexed like mc_ms_feat
    float* grad_sampling_location, // indexed like sample_location
    float* grad_weights,           // indexed like weights
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
) {
    extern __shared__ float2_t s_data[];

    // Widen before multiplying: blockIdx.x * blockDim.x is a 32-bit unsigned
    // product and would wrap for grids of 2^32 or more threads.
    int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= num_kernels) return;

    // Consecutive runs of (num_embeds / num_groups) channels share one weight.
    const int weights_ptr = idx / (num_embeds / num_groups);

    // Peel the flat index apart, fastest-varying dimension first.
    const int channel_index = idx % num_embeds;
    idx /= num_embeds;
    const int scale_index = idx % num_scale;
    idx /= num_scale;

    const int cam_index = idx % num_cams;
    idx /= num_cams;
    const int pts_index = idx % num_pts;
    idx /= num_pts;

    int anchor_index = idx % num_anchors;
    idx /= num_anchors;
    const int batch_index = idx % batch_size;
    idx /= batch_size;

    // From here on anchor_index is the flattened (batch, anchor) index.
    anchor_index = batch_index * num_anchors + anchor_index;

    const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;

    // Samples outside the open unit square contribute nothing.
    const float loc_w = sample_location[loc_offset];
    if (loc_w <= 0 || loc_w >= 1) return;
    const float loc_h = sample_location[loc_offset + 1];
    if (loc_h <= 0 || loc_h >= 1) return;

    const float grad = grad_output[anchor_index*num_embeds + channel_index];

    int cam_scale_index = cam_index * num_scale + scale_index;
    const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;

    cam_scale_index = cam_scale_index << 1;
    const int h = spatial_shape[cam_scale_index];
    const int w = spatial_shape[cam_scale_index + 1];

    // Normalised location -> pixel coordinates with a half-pixel shift.
    const float h_im = loc_h * h - 0.5;
    const float w_im = loc_w * w - 0.5;

    const float weight = weights[weights_ptr];
    float *grad_weights_ptr = grad_weights + weights_ptr;
    float *grad_location_ptr = grad_sampling_location + loc_offset;
    bilinear_sampling_grad_sp<blocksize>(
        mc_ms_feat, weight, h, w, num_embeds, h_im, w_im,
        value_offset,
        grad,
        grad_mc_ms_feat, grad_location_ptr, grad_weights_ptr,
        s_data
    );
}


// Backward kernel, generic path: one thread per
// (batch, anchor, pts, cam, scale, channel) tuple, with channel varying
// fastest in the flat thread index. Every thread publishes its gradient
// contributions with individual atomic adds (see bilinear_sampling_grad).
//
// Buffer indexing as used below:
//   mc_ms_feat / grad_mc_ms_feat          [batch, num_feat, num_embeds]
//   spatial_shape                          [cam, scale, 2]  stored as (h, w)
//   scale_start_index                      [cam, scale]
//   sample_location / grad_sampling_location
//                                          [batch, anchor, pts, cam, 2]  (x, y)
//   weights / grad_weights                 flat thread index / (num_embeds / num_groups)
//   grad_output                            [batch, anchor, num_embeds]
// The gradient buffers are accumulated into, so they presumably have to be
// zero-initialised by the caller - confirm at the call site.
//
// Launch: 1-D grid of 1-D blocks covering at least `num_kernels` threads.
__global__ void deformable_aggregation_grad_kernel(
    const int64_t num_kernels,
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    const float* grad_output,
    float* grad_mc_ms_feat,
    float* grad_sampling_location,
    float* grad_weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
) {
    // Widen before multiplying: blockIdx.x * blockDim.x is a 32-bit unsigned
    // product and would wrap for grids of 2^32 or more threads.
    int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= num_kernels) return;

    // Consecutive runs of (num_embeds / num_groups) channels share one weight.
    const int weights_ptr = idx / (num_embeds / num_groups);

    // Peel the flat index apart, fastest-varying dimension first.
    const int channel_index = idx % num_embeds;
    idx /= num_embeds;
    const int scale_index = idx % num_scale;
    idx /= num_scale;

    const int cam_index = idx % num_cams;
    idx /= num_cams;
    const int pts_index = idx % num_pts;
    idx /= num_pts;

    int anchor_index = idx % num_anchors;
    idx /= num_anchors;
    const int batch_index = idx % batch_size;
    idx /= batch_size;

    // From here on anchor_index is the flattened (batch, anchor) index.
    anchor_index = batch_index * num_anchors + anchor_index;
    const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;

    // Samples outside the open unit square contribute nothing.
    const float loc_w = sample_location[loc_offset];
    if (loc_w <= 0 || loc_w >= 1) return;
    const float loc_h = sample_location[loc_offset + 1];
    if (loc_h <= 0 || loc_h >= 1) return;

    const float grad = grad_output[anchor_index*num_embeds + channel_index];

    int cam_scale_index = cam_index * num_scale + scale_index;
    const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;

    cam_scale_index = cam_scale_index << 1;
    const int h = spatial_shape[cam_scale_index];
    const int w = spatial_shape[cam_scale_index + 1];

    // Normalised location -> pixel coordinates with a half-pixel shift.
    const float h_im = loc_h * h - 0.5;
    const float w_im = loc_w * w - 0.5;

    const float weight = weights[weights_ptr];
    float *grad_weights_ptr = grad_weights + weights_ptr;
    float *grad_location_ptr = grad_sampling_location + loc_offset;
    bilinear_sampling_grad(
        mc_ms_feat, weight, h, w, num_embeds, h_im, w_im,
        value_offset,
        grad,
        grad_mc_ms_feat, grad_location_ptr, grad_weights_ptr
    );
}


// Host launcher for the forward pass.
//
// Launches deformable_aggregation_kernel with one thread per
// (batch, anchor, pts, cam, scale, channel) tuple in blocks of 128 threads.
// All pointers are passed straight to the kernel and therefore have to be
// device-accessible. The launch is asynchronous and its status is not checked
// here.
//
// NOTE(review): the kernel is launched on the default stream; if callers run
// on a different stream this may need the caller's stream - confirm.
void deformable_aggregation(
    float* output,
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
) {
    const int64_t num_kernels = (int64_t)batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;

    // Nothing to do for an empty workload; a grid size of 0 would be an
    // invalid launch configuration.
    if (num_kernels <= 0) return;

    constexpr int kThreadsPerBlock = 128;
    // Integer ceil-division; avoids the round trip through double.
    const int num_blocks = static_cast<int>((num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock);

    deformable_aggregation_kernel
        <<<num_blocks, kThreadsPerBlock>>>(
        num_kernels, output,
        mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
        batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
    );
}


// Host launcher for the backward pass.
//
// Chooses between two kernels:
//   * num_embeds == 256 and num_embeds / num_groups == 32:
//     deformable_aggregation_grad_kernel_sp<256> with 256-thread blocks and
//     256 * 2 * sizeof(float) bytes of dynamic shared memory. With this shape
//     one block covers exactly the channels of one sample, which the
//     block-level reduction in that kernel relies on.
//   * any other shape: deformable_aggregation_grad_kernel with 128-thread
//     blocks.
// All pointers are passed straight to the kernels and therefore have to be
// device-accessible. The launch is asynchronous and its status is not checked
// here.
//
// NOTE(review): the kernels are launched on the default stream; if callers run
// on a different stream this may need the caller's stream - confirm.
void deformable_aggregation_grad(
  const float* mc_ms_feat,
  const int* spatial_shape,
  const int* scale_start_index,
  const float* sample_location,
  const float* weights,
  const float* grad_output,
  float* grad_mc_ms_feat,
  float* grad_sampling_location,
  float* grad_weights,
  int batch_size,
  int num_cams,
  int num_feat,
  int num_embeds,
  int num_scale,
  int num_anchors,
  int num_pts,
  int num_groups
) {
    const int64_t num_kernels =(int64_t)batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;

    // Nothing to do for an empty workload; a grid size of 0 would be an
    // invalid launch configuration.
    if (num_kernels <= 0) return;

    if (num_embeds != 256 || ((num_embeds / num_groups) != 32)) {
        constexpr int kThreadsPerBlock = 128;
        // Integer ceil-division; avoids the round trip through double.
        const int num_blocks = static_cast<int>((num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock);

        deformable_aggregation_grad_kernel
            <<<num_blocks, kThreadsPerBlock>>>(
            num_kernels,
            mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
            grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
            batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
        );
    } else {
        // Block size and template argument must stay identical, so both come
        // from the same constant.
        constexpr int kSpThreadsPerBlock = 256;
        const int num_blocks = static_cast<int>((num_kernels + kSpThreadsPerBlock - 1) / kSpThreadsPerBlock);

        deformable_aggregation_grad_kernel_sp<kSpThreadsPerBlock>
            <<<num_blocks, kSpThreadsPerBlock, kSpThreadsPerBlock * 2 * sizeof(float)>>>(
            num_kernels,
            mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
            grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
            batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
        );
    }
}
