Commit a9dc86e9 authored by lishj6

init_0905

parent 18eda5c1
import torch
from .deformable_aggregation import DeformableAggregationFunction
def deformable_aggregation_function(
    feature_maps,
    spatial_shape,
    scale_start_index,
    sampling_location,
    weights,
):
    return DeformableAggregationFunction.apply(
        feature_maps,
        spatial_shape,
        scale_start_index,
        sampling_location,
        weights,
    )
@torch.compile(mode="max-autotune-no-cudagraphs")
def feature_maps_format(feature_maps, inverse=False):
    if inverse:
        col_feats, spatial_shape, scale_start_index = feature_maps
        num_cams, num_levels = spatial_shape.shape[:2]

        split_size = spatial_shape[..., 0] * spatial_shape[..., 1]
        split_size = split_size.cpu().numpy().tolist()

        idx = 0
        cam_split = [1]
        cam_split_size = [sum(split_size[0])]
        for i in range(num_cams - 1):
            if not torch.all(spatial_shape[i] == spatial_shape[i + 1]):
                cam_split.append(0)
                cam_split_size.append(0)
            cam_split[-1] += 1
            cam_split_size[-1] += sum(split_size[i + 1])

        mc_feat = [
            x.unflatten(1, (cam_split[i], -1))
            for i, x in enumerate(col_feats.split(cam_split_size, dim=1))
        ]

        spatial_shape = spatial_shape.cpu().numpy().tolist()

        mc_ms_feat = []
        shape_index = 0
        for i, feat in enumerate(mc_feat):
            feat = list(feat.split(split_size[shape_index], dim=2))
            for j, f in enumerate(feat):
                feat[j] = f.unflatten(2, spatial_shape[shape_index][j])
                feat[j] = feat[j].permute(0, 1, 4, 2, 3)
            mc_ms_feat.append(feat)
            shape_index += cam_split[i]
        return mc_ms_feat

    if isinstance(feature_maps[0], (list, tuple)):
        formated = [feature_maps_format(x) for x in feature_maps]
        col_feats = torch.cat([x[0] for x in formated], dim=1)
        spatial_shape = torch.cat([x[1] for x in formated], dim=0)
        scale_start_index = torch.cat([x[2] for x in formated], dim=0)
        return [col_feats, spatial_shape, scale_start_index]

    bs, num_cams = feature_maps[0].shape[:2]
    spatial_shape = []

    col_feats = []
    for i, feat in enumerate(feature_maps):
        spatial_shape.append(feat.shape[-2:])
        col_feats.append(
            torch.reshape(feat, (bs, num_cams, feat.shape[2], -1))
        )

    col_feats = torch.cat(col_feats, dim=-1).permute(0, 1, 3, 2).flatten(1, 2)
    spatial_shape = [spatial_shape] * num_cams
    spatial_shape = torch.tensor(
        spatial_shape,
        dtype=torch.int64,
        device=col_feats.device,
    )
    scale_start_index = spatial_shape[..., 0] * spatial_shape[..., 1]
    scale_start_index = scale_start_index.flatten().cumsum(dim=0)
    scale_start_index = torch.cat(
        [torch.tensor([0]).to(scale_start_index), scale_start_index[:-1]]
    )
    scale_start_index = scale_start_index.reshape(num_cams, -1)

    feature_maps = [
        col_feats,
        spatial_shape,
        scale_start_index,
    ]
    return feature_maps
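
For orientation, here is a minimal usage sketch of the two helpers above. It is not part of the commit; the batch size, camera count, channel width, and pyramid shapes are assumptions chosen only for illustration.

```python
import torch

# Assumed sizes: 1 sample, 6 cameras, 256 channels, 4 pyramid levels.
bs, num_cams, c = 1, 6, 256
feature_maps = [
    torch.rand(bs, num_cams, c, 64 // s, 176 // s) for s in (1, 2, 4, 8)
]

# Pack into the columnar layout consumed by the CUDA/HIP op.
col_feats, spatial_shape, scale_start_index = feature_maps_format(feature_maps)
# col_feats:          [bs, num_cams * sum(H_i * W_i), c]
# spatial_shape:      [num_cams, num_levels, 2]  -- (H, W) per level
# scale_start_index:  [num_cams, num_levels]     -- offsets into the flattened axis

# Round-trip back to per-camera, per-level maps.
restored = feature_maps_format(
    [col_feats, spatial_shape, scale_start_index], inverse=True
)
```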
# ninja log v5
3 24593 1756893458602064377 /home/Sparse4D/projects/mmdet3d_plugin/ops/build/temp.linux-x86_64-cpython-310/src/deformable_aggregation_hip.o 71d1d5836deab680
2 24850 1756893458862064385 /home/Sparse4D/projects/mmdet3d_plugin/ops/build/temp.linux-x86_64-cpython-310/src/deformable_aggregation_cuda.o ae0f4acb73ed65bb
ninja_required_version = 1.3
cxx = c++
nvcc = /opt/dtk/bin/hipcc
cflags = -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -DWITH_CUDA -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/lib/python3.10/dist-packages/torch/include/THH -I/opt/dtk/include -I/usr/include/python3.10 -c
post_cflags = -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deformable_aggregation_ext -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++17
cuda_cflags = -DWITH_CUDA -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/lib/python3.10/dist-packages/torch/include/THH -I/opt/dtk/include -I/usr/include/python3.10 -c
cuda_post_cflags = -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deformable_aggregation_ext -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx906 --offload-arch=gfx926 --offload-arch=gfx928 --offload-arch=gfx936 -fno-gpu-rdc -std=c++17
cuda_dlink_post_cflags =
ldflags =
rule compile
command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
depfile = $out.d
deps = gcc
rule cuda_compile
command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags
build /home/Sparse4D/projects/mmdet3d_plugin/ops/build/temp.linux-x86_64-cpython-310/src/deformable_aggregation_cuda.o: cuda_compile /home/Sparse4D/projects/mmdet3d_plugin/ops/src/deformable_aggregation_cuda.hip
build /home/Sparse4D/projects/mmdet3d_plugin/ops/build/temp.linux-x86_64-cpython-310/src/deformable_aggregation_hip.o: compile /home/Sparse4D/projects/mmdet3d_plugin/ops/src/deformable_aggregation_hip.cpp
import torch
from torch.autograd.function import Function, once_differentiable
from . import deformable_aggregation_ext
class DeformableAggregationFunction(Function):
    @staticmethod
    def forward(
        ctx,
        mc_ms_feat,
        spatial_shape,
        scale_start_index,
        sampling_location,
        weights,
    ):
        # output: [bs, num_anchors, num_embeds]
        mc_ms_feat = mc_ms_feat.contiguous().float()
        spatial_shape = spatial_shape.contiguous().int()
        scale_start_index = scale_start_index.contiguous().int()
        sampling_location = sampling_location.contiguous().float()
        weights = weights.contiguous().float()
        output = deformable_aggregation_ext.deformable_aggregation_forward(
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
        )
        ctx.save_for_backward(
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
        )
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        (
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
        ) = ctx.saved_tensors
        mc_ms_feat = mc_ms_feat.contiguous().float()
        spatial_shape = spatial_shape.contiguous().int()
        scale_start_index = scale_start_index.contiguous().int()
        sampling_location = sampling_location.contiguous().float()
        weights = weights.contiguous().float()
        grad_mc_ms_feat = torch.zeros_like(mc_ms_feat)
        grad_sampling_location = torch.zeros_like(sampling_location)
        grad_weights = torch.zeros_like(weights)
        deformable_aggregation_ext.deformable_aggregation_backward(
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
            grad_output.contiguous(),
            grad_mc_ms_feat,
            grad_sampling_location,
            grad_weights,
        )
        return (
            grad_mc_ms_feat,
            None,
            None,
            grad_sampling_location,
            grad_weights,
        )
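
A sketch of how this autograd wrapper is exercised end to end, again with assumed sizes. It assumes the extension is already built and that `feature_maps_format` (defined above) is importable; only `DeformableAggregationFunction` itself comes from this file.

```python
import torch

bs, num_cams, c = 1, 6, 256
num_anchors, num_pts, num_scale, num_groups = 900, 13, 4, 8

# Columnar features produced by feature_maps_format (see above).
feature_maps = [
    torch.rand(bs, num_cams, c, 64 // s, 176 // s, device="cuda")
    for s in (1, 2, 4, 8)
]
col_feats, spatial_shape, scale_start_index = feature_maps_format(feature_maps)
col_feats.requires_grad_(True)

sampling_location = torch.rand(
    bs, num_anchors, num_pts, num_cams, 2, device="cuda", requires_grad=True
)
weights = torch.rand(
    bs, num_anchors, num_pts, num_cams, num_scale, num_groups,
    device="cuda", requires_grad=True,
)

output = DeformableAggregationFunction.apply(
    col_feats, spatial_shape, scale_start_index, sampling_location, weights
)  # [bs, num_anchors, c]
output.sum().backward()  # fills grads of col_feats, sampling_location, weights
```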
Metadata-Version: 2.1
Name: deformable-aggregation-ext
Version: 0.0.0
setup.py
deformable_aggregation_ext.egg-info/PKG-INFO
deformable_aggregation_ext.egg-info/SOURCES.txt
deformable_aggregation_ext.egg-info/dependency_links.txt
deformable_aggregation_ext.egg-info/top_level.txt
src/deformable_aggregation.cpp
src/deformable_aggregation_cuda.cu
src/deformable_aggregation_cuda.hip
src/deformable_aggregation_hip.cpp
\ No newline at end of file
import os

import torch
from setuptools import setup
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
    CUDAExtension,
)


def make_cuda_ext(
    name,
    module,
    sources,
    sources_cuda=[],
    extra_args=[],
    extra_include_path=[],
):
    define_macros = []
    extra_compile_args = {"cxx": [] + extra_args}

    if torch.cuda.is_available() or os.getenv("FORCE_CUDA", "0") == "1":
        define_macros += [("WITH_CUDA", None)]
        extension = CUDAExtension
        extra_compile_args["nvcc"] = extra_args + [
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ]
        sources += sources_cuda
    else:
        print("Compiling {} without CUDA".format(name))
        extension = CppExtension

    return extension(
        name="{}.{}".format(module, name),
        sources=[os.path.join(*module.split("."), p) for p in sources],
        include_dirs=extra_include_path,
        define_macros=define_macros,
        extra_compile_args=extra_compile_args,
    )


if __name__ == "__main__":
    setup(
        name="deformable_aggregation_ext",
        ext_modules=[
            make_cuda_ext(
                "deformable_aggregation_ext",
                module=".",
                sources=[
                    "src/deformable_aggregation.cpp",
                    "src/deformable_aggregation_cuda.cu",
                ],
            ),
        ],
        cmdclass={"build_ext": BuildExtension},
    )
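
As an aside (not part of this commit), the same sources can also be JIT-compiled for quick experiments instead of running the setuptools build; the paths below are the ones listed in `sources` above.

```python
from torch.utils.cpp_extension import load

# Sketch: JIT-compile the extension in place of `python setup.py ...`.
deformable_aggregation_ext = load(
    name="deformable_aggregation_ext",
    sources=[
        "src/deformable_aggregation.cpp",
        "src/deformable_aggregation_cuda.cu",
    ],
    verbose=True,
)
```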
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
void deformable_aggregation(
float* output,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
);
/* feat: bs, num_feat, c */
/* _spatial_shape: cam, scale, 2 */
/* _scale_start_index: cam, scale */
/* _sampling_location: bs, anchor, pts, cam, 2 */
/* _weights: bs, anchor, pts, cam, scale, group */
/* output: bs, anchor, c */
/* kernel: bs, anchor, pts, c */
at::Tensor deformable_aggregation_forward(
const at::Tensor &_mc_ms_feat,
const at::Tensor &_spatial_shape,
const at::Tensor &_scale_start_index,
const at::Tensor &_sampling_location,
const at::Tensor &_weights
) {
at::DeviceGuard guard(_mc_ms_feat.device());
const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat));
int batch_size = _mc_ms_feat.size(0);
int num_feat = _mc_ms_feat.size(1);
int num_embeds = _mc_ms_feat.size(2);
int num_cams = _spatial_shape.size(0);
int num_scale = _spatial_shape.size(1);
int num_anchors = _sampling_location.size(1);
int num_pts = _sampling_location.size(2);
int num_groups = _weights.size(5);
const float* mc_ms_feat = _mc_ms_feat.data_ptr<float>();
const int* spatial_shape = _spatial_shape.data_ptr<int>();
const int* scale_start_index = _scale_start_index.data_ptr<int>();
const float* sampling_location = _sampling_location.data_ptr<float>();
const float* weights = _weights.data_ptr<float>();
auto output = at::zeros({batch_size, num_anchors, num_embeds}, _mc_ms_feat.options());
deformable_aggregation(
output.data_ptr<float>(),
mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
return output;
}
void deformable_aggregation_grad(
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
const float* grad_output,
float* grad_mc_ms_feat,
float* grad_sampling_location,
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
);
void deformable_aggregation_backward(
const at::Tensor &_mc_ms_feat,
const at::Tensor &_spatial_shape,
const at::Tensor &_scale_start_index,
const at::Tensor &_sampling_location,
const at::Tensor &_weights,
const at::Tensor &_grad_output,
at::Tensor &_grad_mc_ms_feat,
at::Tensor &_grad_sampling_location,
at::Tensor &_grad_weights
) {
at::DeviceGuard guard(_mc_ms_feat.device());
const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat));
int batch_size = _mc_ms_feat.size(0);
int num_feat = _mc_ms_feat.size(1);
int num_embeds = _mc_ms_feat.size(2);
int num_cams = _spatial_shape.size(0);
int num_scale = _spatial_shape.size(1);
int num_anchors = _sampling_location.size(1);
int num_pts = _sampling_location.size(2);
int num_groups = _weights.size(5);
const float* mc_ms_feat = _mc_ms_feat.data_ptr<float>();
const int* spatial_shape = _spatial_shape.data_ptr<int>();
const int* scale_start_index = _scale_start_index.data_ptr<int>();
const float* sampling_location = _sampling_location.data_ptr<float>();
const float* weights = _weights.data_ptr<float>();
const float* grad_output = _grad_output.data_ptr<float>();
float* grad_mc_ms_feat = _grad_mc_ms_feat.data_ptr<float>();
float* grad_sampling_location = _grad_sampling_location.data_ptr<float>();
float* grad_weights = _grad_weights.data_ptr<float>();
deformable_aggregation_grad(
mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights,
grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"deformable_aggregation_forward",
&deformable_aggregation_forward,
"deformable_aggregation_forward"
);
m.def(
"deformable_aggregation_backward",
&deformable_aggregation_backward,
"deformable_aggregation_backward"
);
}
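
To summarize what the bound forward op computes (a paraphrase of the kernels further down; the notation is mine, not from the source):

```latex
% Forward aggregation, with C channels split into G groups:
\[
\mathrm{output}[b, a, c] \;=\;
\sum_{p}\sum_{k}\sum_{s}
  w\bigl[b, a, p, k, s, \lfloor c / (C/G) \rfloor\bigr]\cdot
  \mathrm{bilinear}\bigl(\mathrm{feat}[b, \cdot, c];\, H_{k,s}, W_{k,s},\, \mathrm{loc}[b, a, p, k]\bigr)
\]
% where p, k, s range over sampling points, cameras, and scales; samples whose
% normalized location falls outside (0, 1) are skipped; and level (k, s) occupies
% rows scale_start_index[k, s] .. scale_start_index[k, s] + H_{k,s} * W_{k,s} - 1
% of the flattened feature axis.
```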
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
void deformable_aggregation(
float* output,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
);
/* feat: bs, num_feat, c */
/* _spatial_shape: cam, scale, 2 */
/* _scale_start_index: cam, scale */
/* _sampling_location: bs, anchor, pts, cam, 2 */
/* _weights: bs, anchor, pts, cam, scale, group */
/* output: bs, anchor, c */
/* kernel: bs, anchor, pts, c */
at::Tensor deformable_aggregation_forward(
const at::Tensor &_mc_ms_feat,
const at::Tensor &_spatial_shape,
const at::Tensor &_scale_start_index,
const at::Tensor &_sampling_location,
const at::Tensor &_weights
) {
at::DeviceGuard guard(_mc_ms_feat.device());
const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat));
int batch_size = _mc_ms_feat.size(0);
int num_feat = _mc_ms_feat.size(1);
int num_embeds = _mc_ms_feat.size(2);
int num_cams = _spatial_shape.size(0);
int num_scale = _spatial_shape.size(1);
int num_anchors = _sampling_location.size(1);
int num_pts = _sampling_location.size(2);
int num_groups = _weights.size(5);
const float* mc_ms_feat = _mc_ms_feat.data_ptr<float>();
const int* spatial_shape = _spatial_shape.data_ptr<int>();
const int* scale_start_index = _scale_start_index.data_ptr<int>();
const float* sampling_location = _sampling_location.data_ptr<float>();
const float* weights = _weights.data_ptr<float>();
auto output = at::zeros({batch_size, num_anchors, num_embeds}, _mc_ms_feat.options());
deformable_aggregation(
output.data_ptr<float>(),
mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
return output;
}
void deformable_aggregation_grad(
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
const float* grad_output,
float* grad_mc_ms_feat,
float* grad_sampling_location,
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
);
void deformable_aggregation_backward(
const at::Tensor &_mc_ms_feat,
const at::Tensor &_spatial_shape,
const at::Tensor &_scale_start_index,
const at::Tensor &_sampling_location,
const at::Tensor &_weights,
const at::Tensor &_grad_output,
at::Tensor &_grad_mc_ms_feat,
at::Tensor &_grad_sampling_location,
at::Tensor &_grad_weights
) {
at::DeviceGuard guard(_mc_ms_feat.device());
const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat));
int batch_size = _mc_ms_feat.size(0);
int num_feat = _mc_ms_feat.size(1);
int num_embeds = _mc_ms_feat.size(2);
int num_cams = _spatial_shape.size(0);
int num_scale = _spatial_shape.size(1);
int num_anchors = _sampling_location.size(1);
int num_pts = _sampling_location.size(2);
int num_groups = _weights.size(5);
const float* mc_ms_feat = _mc_ms_feat.data_ptr<float>();
const int* spatial_shape = _spatial_shape.data_ptr<int>();
const int* scale_start_index = _scale_start_index.data_ptr<int>();
const float* sampling_location = _sampling_location.data_ptr<float>();
const float* weights = _weights.data_ptr<float>();
const float* grad_output = _grad_output.data_ptr<float>();
float* grad_mc_ms_feat = _grad_mc_ms_feat.data_ptr<float>();
float* grad_sampling_location = _grad_sampling_location.data_ptr<float>();
float* grad_weights = _grad_weights.data_ptr<float>();
deformable_aggregation_grad(
mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights,
grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"deformable_aggregation_forward",
&deformable_aggregation_forward,
"deformable_aggregation_forward"
);
m.def(
"deformable_aggregation_backward",
&deformable_aggregation_backward,
"deformable_aggregation_backward"
);
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <THC/THCAtomics.cuh>
#include <iostream>
#include <stdlib.h>
__device__ float bilinear_sampling(
const float *&bottom_data, const int &height, const int &width,
const int &num_embeds, const float &h_im, const float &w_im,
const int &base_ptr
) {
const int h_low = floorf(h_im);
const int w_low = floorf(w_im);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1 - lh, hw = 1 - lw;
const int w_stride = num_embeds;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
float v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
}
float v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
}
float v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
}
float v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
}
const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
struct float2_t{
float a;
float b;
};
__forceinline__ __device__
float2_t warp_reduce_sum(float2_t val, int max = 32) {
for (int offset = max; offset > 0; offset >>= 1) {
val.a += __shfl_down(val.a, offset);
val.b += __shfl_down(val.b, offset);
}
return val;
}
template <int blocksize>
__forceinline__ __device__
float2_t block_reduce_sum(float2_t val, float2_t* shared) {
const int lid = threadIdx.x % 64;
const int wid = threadIdx.x / 64;
constexpr int share_size = blocksize / 64;
val = warp_reduce_sum(val);
if constexpr (blocksize == 64) return val;
if (lid == 0 && wid < share_size) {
shared[wid] = val;
}
__syncthreads();
if (wid == 0 && lid < share_size) {
val = shared[lid];
val = warp_reduce_sum(val, share_size / 2);
}
return val;
}
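// Note (added commentary, not from the original source): warp_reduce_sum folds
// the two accumulators of a float2_t across one 64-lane AMD wavefront with
// __shfl_down (offsets 32, 16, ..., 1), so lane 0 ends up holding the
// wavefront-wide sum. block_reduce_sum then stages each wavefront's partial
// sum in shared memory (blocksize / 64 entries) and lets wavefront 0 reduce
// those partials, so thread 0 of the block holds the block-wide sum. The
// specialized gradient kernel below relies on this: a 256-thread block covers
// the 256 channels of a single (batch, anchor, point, camera, scale) slot, so
// all its threads share one sampling location and the location gradient can be
// accumulated once per block instead of once per thread.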
template <int blocksize>
__device__ void bilinear_sampling_grad_sp(
const float *&bottom_data, const float &weight,
const int &height, const int &width,
const int &num_embeds, const float &h_im, const float &w_im,
const int &base_ptr,
const float &grad_output,
float *&grad_mc_ms_feat, float *grad_sampling_location, float *grad_weights,
float2_t* s_data) {
const int h_low = floorf(h_im);
const int w_low = floorf(w_im);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1 - lh, hw = 1 - lw;
const int w_stride = num_embeds;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const float top_grad_mc_ms_feat = grad_output * weight;
float grad_h_weight = 0, grad_w_weight = 0;
const int valid1 = (h_low >= 0 && w_low >= 0);
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
float v1 = valid1 ? bottom_data[ptr1] : 0.0f;
if (valid1) {
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#endif
}
const int valid2 = (h_low >= 0 && w_high <= width - 1);
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
float v2 = valid2 ? bottom_data[ptr2] : 0.0f;
if (valid2) {
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#endif
}
const int valid3 = (h_high <= height - 1 && w_low >= 0);
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
float v3 = valid3 ? bottom_data[ptr3] : 0.0f;
if (valid3) {
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#endif
}
const int valid4 = (h_high <= height - 1 && w_high <= width - 1);
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
float v4 = valid4 ? bottom_data[ptr4] : 0.0f;
if (valid4) {
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#endif
}
grad_h_weight += (-hw * v1) + (-lw * v2) + ( hw * v3) + ( lw * v4);
grad_w_weight += (-hh * v1) + ( hh * v2) + (-lh * v3) + ( lh * v4);
float2_t spl;
spl.a = width * grad_w_weight * top_grad_mc_ms_feat;
spl.b = height * grad_h_weight * top_grad_mc_ms_feat;
spl = block_reduce_sum<blocksize>(spl, s_data);
const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
float wei = grad_output * val;
for (int offset=16; offset>=1; offset >>= 1) {
wei += __shfl_down(wei, offset);
}
#ifdef __gfx936__
// __builtin_amdgcn_global_atomic_fadd_f32(grad_weights, grad_output * val);
if (threadIdx.x % 32 == 0) {
// __builtin_amdgcn_global_atomic_fadd_f32(grad_weights, wei);
*grad_weights += wei;
}
if (threadIdx.x ==0) {
__builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location, spl.a);
__builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location + 1, spl.b);
}
#else
atomicAdd(grad_weights, grad_output * val);
atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
#endif
}
__device__ void bilinear_sampling_grad(
const float *&bottom_data, const float &weight,
const int &height, const int &width,
const int &num_embeds, const float &h_im, const float &w_im,
const int &base_ptr,
const float &grad_output,
float *&grad_mc_ms_feat, float *grad_sampling_location, float *grad_weights) {
const int h_low = floorf(h_im);
const int w_low = floorf(w_im);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1 - lh, hw = 1 - lw;
const int w_stride = num_embeds;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const float top_grad_mc_ms_feat = grad_output * weight;
float grad_h_weight = 0, grad_w_weight = 0;
float v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#endif
}
float v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
// atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#endif
}
float v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
// atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#endif
}
float v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
// atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#endif
}
const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_weights, grad_output * val);
__builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
__builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
#else
atomicAdd(grad_weights, grad_output * val);
atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
#endif
}
__global__ void deformable_aggregation_kernel(
const int64_t num_kernels,
float* output,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_kernels) return;
const float weight = *(weights + idx / (num_embeds / num_groups));
const int channel_index = idx % num_embeds;
idx /= num_embeds;
const int scale_index = idx % num_scale;
idx /= num_scale;
const int cam_index = idx % num_cams;
idx /= num_cams;
const int pts_index = idx % num_pts;
idx /= num_pts;
int anchor_index = idx % num_anchors;
idx /= num_anchors;
const int batch_index = idx % batch_size;
idx /= batch_size;
anchor_index = batch_index * num_anchors + anchor_index;
const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;
const float loc_w = sample_location[loc_offset];
if (loc_w <= 0 || loc_w >= 1) return;
const float loc_h = sample_location[loc_offset + 1];
if (loc_h <= 0 || loc_h >= 1) return;
int cam_scale_index = cam_index * num_scale + scale_index;
const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;
cam_scale_index = cam_scale_index << 1;
const int h = spatial_shape[cam_scale_index];
const int w = spatial_shape[cam_scale_index + 1];
const float h_im = loc_h * h - 0.5;
const float w_im = loc_w * w - 0.5;
// atomicAdd(
// output + anchor_index * num_embeds + channel_index,
// bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight
// );
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(output + anchor_index * num_embeds + channel_index, bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight);
#else
atomicAdd(output + anchor_index * num_embeds + channel_index, bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight);
#endif
}
template <int blocksize>
__global__ void deformable_aggregation_grad_kernel_sp(
const int64_t num_kernels,
const float* mc_ms_feat, // [bs, num_feat, channel]
const int* spatial_shape, // [cam, scale, 2]
const int* scale_start_index, // [cam, scale]
const float* sample_location, // [bs, anchor, pts, cam, 2(x, y)]
const float* weights, // [bs, anchor, pts, cam, scale, group]
const float* grad_output, // [bs, anchor, c]
float* grad_mc_ms_feat, // same as feat
float* grad_sampling_location, // same as sampling location
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
extern __shared__ float2_t s_data[];
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_kernels) return;
const int weights_ptr = idx / (num_embeds / num_groups);
const int channel_index = idx % num_embeds;
idx /= num_embeds;
const int scale_index = idx % num_scale;
idx /= num_scale;
const int cam_index = idx % num_cams;
idx /= num_cams;
const int pts_index = idx % num_pts;
idx /= num_pts;
int anchor_index = idx % num_anchors;
idx /= num_anchors;
const int batch_index = idx % batch_size;
idx /= batch_size;
anchor_index = batch_index * num_anchors + anchor_index;
const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;
const float loc_w = sample_location[loc_offset];
if (loc_w <= 0 || loc_w >= 1) return;
const float loc_h = sample_location[loc_offset + 1];
if (loc_h <= 0 || loc_h >= 1) return;
const float grad = grad_output[anchor_index*num_embeds + channel_index];
int cam_scale_index = cam_index * num_scale + scale_index;
const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;
cam_scale_index = cam_scale_index << 1;
const int h = spatial_shape[cam_scale_index];
const int w = spatial_shape[cam_scale_index + 1];
const float h_im = loc_h * h - 0.5;
const float w_im = loc_w * w - 0.5;
/* atomicAdd( */
/* output + anchor_index * num_embeds + channel_index, */
/* bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight */
/* ); */
const float weight = weights[weights_ptr];
float *grad_weights_ptr = grad_weights + weights_ptr;
float *grad_location_ptr = grad_sampling_location + loc_offset;
bilinear_sampling_grad_sp<blocksize>(
mc_ms_feat, weight, h, w, num_embeds, h_im, w_im,
value_offset,
grad,
grad_mc_ms_feat, grad_location_ptr, grad_weights_ptr,
s_data
);
}
__global__ void deformable_aggregation_grad_kernel(
const int64_t num_kernels,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
const float* grad_output,
float* grad_mc_ms_feat,
float* grad_sampling_location,
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_kernels) return;
const int weights_ptr = idx / (num_embeds / num_groups);
const int channel_index = idx % num_embeds;
idx /= num_embeds;
const int scale_index = idx % num_scale;
idx /= num_scale;
const int cam_index = idx % num_cams;
idx /= num_cams;
const int pts_index = idx % num_pts;
idx /= num_pts;
int anchor_index = idx % num_anchors;
idx /= num_anchors;
const int batch_index = idx % batch_size;
idx /= batch_size;
anchor_index = batch_index * num_anchors + anchor_index;
const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;
const float loc_w = sample_location[loc_offset];
if (loc_w <= 0 || loc_w >= 1) return;
const float loc_h = sample_location[loc_offset + 1];
if (loc_h <= 0 || loc_h >= 1) return;
const float grad = grad_output[anchor_index*num_embeds + channel_index];
int cam_scale_index = cam_index * num_scale + scale_index;
const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;
cam_scale_index = cam_scale_index << 1;
const int h = spatial_shape[cam_scale_index];
const int w = spatial_shape[cam_scale_index + 1];
const float h_im = loc_h * h - 0.5;
const float w_im = loc_w * w - 0.5;
/* atomicAdd( */
/* output + anchor_index * num_embeds + channel_index, */
/* bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight */
/* ); */
const float weight = weights[weights_ptr];
float *grad_weights_ptr = grad_weights + weights_ptr;
float *grad_location_ptr = grad_sampling_location + loc_offset;
bilinear_sampling_grad(
mc_ms_feat, weight, h, w, num_embeds, h_im, w_im,
value_offset,
grad,
grad_mc_ms_feat, grad_location_ptr, grad_weights_ptr
);
}
void deformable_aggregation(
float* output,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
const int64_t num_kernels = (int64_t)batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;
deformable_aggregation_kernel
<<<(int)ceil(((double)num_kernels/128)), 128>>>(
num_kernels, output,
mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
}
void deformable_aggregation_grad(
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
const float* grad_output,
float* grad_mc_ms_feat,
float* grad_sampling_location,
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
const int64_t num_kernels = (int64_t)batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;
if (num_embeds != 256 || ((num_embeds / num_groups) != 32)) {
deformable_aggregation_grad_kernel
<<<(int)ceil(((double)num_kernels/128)), 128>>>(
num_kernels,
mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
} else {
int blk_dim = 256;
deformable_aggregation_grad_kernel_sp<256>
<<<(int)ceil(((double)num_kernels/blk_dim)), blk_dim, blk_dim * 2 * sizeof(float)>>>(
num_kernels,
mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
}
}
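
For reference, a small sketch (assumed sizes, not from the source) of the dispatch arithmetic used by the two launchers above: the specialized gradient kernel is taken only when num_embeds is 256 and each group spans 32 channels, and it is launched with a 256-thread block plus dynamic shared memory.

```python
import math

# Assumed problem sizes for illustration.
batch_size, num_anchors, num_pts = 1, 900, 13
num_cams, num_scale, num_embeds, num_groups = 6, 4, 256, 8

num_kernels = (
    batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale
)
use_specialized = num_embeds == 256 and num_embeds // num_groups == 32

block = 256 if use_specialized else 128
grid = math.ceil(num_kernels / block)
# Dynamic shared memory: blk_dim * 2 * sizeof(float), as in the launch above.
shared_bytes = block * 2 * 4 if use_specialized else 0
print(num_kernels, grid, block, shared_bytes)
```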
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <THC/THCAtomics.cuh>
#include <iostream>
#include <stdlib.h>
__device__ float bilinear_sampling(
const float *&bottom_data, const int &height, const int &width,
const int &num_embeds, const float &h_im, const float &w_im,
const int &base_ptr
) {
const int h_low = floorf(h_im);
const int w_low = floorf(w_im);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1 - lh, hw = 1 - lw;
const int w_stride = num_embeds;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
float v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
}
float v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
}
float v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
}
float v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
}
const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
__device__ void bilinear_sampling_grad(
const float *&bottom_data, const float &weight,
const int &height, const int &width,
const int &num_embeds, const float &h_im, const float &w_im,
const int &base_ptr,
const float &grad_output,
float *&grad_mc_ms_feat, float *grad_sampling_location, float *grad_weights) {
const int h_low = floorf(h_im);
const int w_low = floorf(w_im);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1 - lh, hw = 1 - lw;
const int w_stride = num_embeds;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const float top_grad_mc_ms_feat = grad_output * weight;
float grad_h_weight = 0, grad_w_weight = 0;
float v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
}
float v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
}
float v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
}
float v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
}
const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
atomicAdd(grad_weights, grad_output * val);
atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
}
__global__ void deformable_aggregation_kernel(
const int num_kernels,
float* output,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_kernels) return;
const float weight = *(weights + idx / (num_embeds / num_groups));
const int channel_index = idx % num_embeds;
idx /= num_embeds;
const int scale_index = idx % num_scale;
idx /= num_scale;
const int cam_index = idx % num_cams;
idx /= num_cams;
const int pts_index = idx % num_pts;
idx /= num_pts;
int anchor_index = idx % num_anchors;
idx /= num_anchors;
const int batch_index = idx % batch_size;
idx /= batch_size;
anchor_index = batch_index * num_anchors + anchor_index;
const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;
const float loc_w = sample_location[loc_offset];
if (loc_w <= 0 || loc_w >= 1) return;
const float loc_h = sample_location[loc_offset + 1];
if (loc_h <= 0 || loc_h >= 1) return;
int cam_scale_index = cam_index * num_scale + scale_index;
const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;
cam_scale_index = cam_scale_index << 1;
const int h = spatial_shape[cam_scale_index];
const int w = spatial_shape[cam_scale_index + 1];
const float h_im = loc_h * h - 0.5;
const float w_im = loc_w * w - 0.5;
atomicAdd(
output + anchor_index * num_embeds + channel_index,
bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight
);
}
__global__ void deformable_aggregation_grad_kernel(
const int num_kernels,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
const float* grad_output,
float* grad_mc_ms_feat,
float* grad_sampling_location,
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_kernels) return;
const int weights_ptr = idx / (num_embeds / num_groups);
const int channel_index = idx % num_embeds;
idx /= num_embeds;
const int scale_index = idx % num_scale;
idx /= num_scale;
const int cam_index = idx % num_cams;
idx /= num_cams;
const int pts_index = idx % num_pts;
idx /= num_pts;
int anchor_index = idx % num_anchors;
idx /= num_anchors;
const int batch_index = idx % batch_size;
idx /= batch_size;
anchor_index = batch_index * num_anchors + anchor_index;
const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;
const float loc_w = sample_location[loc_offset];
if (loc_w <= 0 || loc_w >= 1) return;
const float loc_h = sample_location[loc_offset + 1];
if (loc_h <= 0 || loc_h >= 1) return;
const float grad = grad_output[anchor_index*num_embeds + channel_index];
int cam_scale_index = cam_index * num_scale + scale_index;
const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;
cam_scale_index = cam_scale_index << 1;
const int h = spatial_shape[cam_scale_index];
const int w = spatial_shape[cam_scale_index + 1];
const float h_im = loc_h * h - 0.5;
const float w_im = loc_w * w - 0.5;
/* atomicAdd( */
/* output + anchor_index * num_embeds + channel_index, */
/* bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight */
/* ); */
const float weight = weights[weights_ptr];
float *grad_weights_ptr = grad_weights + weights_ptr;
float *grad_location_ptr = grad_sampling_location + loc_offset;
bilinear_sampling_grad(
mc_ms_feat, weight, h, w, num_embeds, h_im, w_im,
value_offset,
grad,
grad_mc_ms_feat, grad_location_ptr, grad_weights_ptr
);
}
void deformable_aggregation(
float* output,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
const int num_kernels = batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;
deformable_aggregation_kernel
<<<(int)ceil(((double)num_kernels/128)), 128>>>(
num_kernels, output,
mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
}
void deformable_aggregation_grad(
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
const float* grad_output,
float* grad_mc_ms_feat,
float* grad_sampling_location,
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
const int num_kernels = batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;
deformable_aggregation_grad_kernel
<<<(int)ceil(((double)num_kernels/128)), 128>>>(
num_kernels,
mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
}
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include <ATen/ATen.h>
#include <ATen/hip/HIPContext.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <THH/THHAtomics.cuh>
#include <iostream>
#include <stdlib.h>
__device__ float bilinear_sampling(
const float *&bottom_data, const int &height, const int &width,
const int &num_embeds, const float &h_im, const float &w_im,
const int &base_ptr
) {
const int h_low = floorf(h_im);
const int w_low = floorf(w_im);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1 - lh, hw = 1 - lw;
const int w_stride = num_embeds;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
float v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
}
float v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
}
float v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
}
float v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
}
const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
struct float2_t{
float a;
float b;
};
__forceinline__ __device__
float2_t warp_reduce_sum(float2_t val, int max = 32) {
for (int offset = max; offset > 0; offset >>= 1) {
val.a += __shfl_down(val.a, offset);
val.b += __shfl_down(val.b, offset);
}
return val;
}
template <int blocksize>
__forceinline__ __device__
float2_t block_reduce_sum(float2_t val, float2_t* shared) {
const int lid = threadIdx.x % 64;
const int wid = threadIdx.x / 64;
constexpr int share_size = blocksize / 64;
val = warp_reduce_sum(val);
if constexpr (blocksize == 64) return val;
if (lid == 0 && wid < share_size) {
shared[wid] = val;
}
__syncthreads();
if (wid == 0 && lid < share_size) {
val = shared[lid];
val = warp_reduce_sum(val, share_size / 2);
}
return val;
}
template <int blocksize>
__device__ void bilinear_sampling_grad_sp(
const float *&bottom_data, const float &weight,
const int &height, const int &width,
const int &num_embeds, const float &h_im, const float &w_im,
const int &base_ptr,
const float &grad_output,
float *&grad_mc_ms_feat, float *grad_sampling_location, float *grad_weights,
float2_t* s_data) {
const int h_low = floorf(h_im);
const int w_low = floorf(w_im);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1 - lh, hw = 1 - lw;
const int w_stride = num_embeds;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const float top_grad_mc_ms_feat = grad_output * weight;
float grad_h_weight = 0, grad_w_weight = 0;
const int valid1 = (h_low >= 0 && w_low >= 0);
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
float v1 = valid1 ? bottom_data[ptr1] : 0.0f;
if (valid1) {
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#endif
}
const int valid2 = (h_low >= 0 && w_high <= width - 1);
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
float v2 = valid2 ? bottom_data[ptr2] : 0.0f;
if (valid2) {
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#endif
}
const int valid3 = (h_high <= height - 1 && w_low >= 0);
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
float v3 = valid3 ? bottom_data[ptr3] : 0.0f;
if (valid3) {
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#endif
}
const int valid4 = (h_high <= height - 1 && w_high <= width - 1);
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
float v4 = valid4 ? bottom_data[ptr4] : 0.0f;
if (valid4) {
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#endif
}
grad_h_weight += (-hw * v1) + (-lw * v2) + ( hw * v3) + ( lw * v4);
grad_w_weight += (-hh * v1) + ( hh * v2) + (-lh * v3) + ( lh * v4);
float2_t spl;
spl.a = width * grad_w_weight * top_grad_mc_ms_feat;
spl.b = height * grad_h_weight * top_grad_mc_ms_feat;
spl = block_reduce_sum<blocksize>(spl, s_data);
const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
float wei = grad_output * val;
for (int offset=16; offset>=1; offset >>= 1) {
wei += __shfl_down(wei, offset);
}
#ifdef __gfx936__
// __builtin_amdgcn_global_atomic_fadd_f32(grad_weights, grad_output * val);
if (threadIdx.x % 32 == 0) {
// __builtin_amdgcn_global_atomic_fadd_f32(grad_weights, wei);
*grad_weights += wei;
}
if (threadIdx.x ==0) {
__builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location, spl.a);
__builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location + 1, spl.b);
}
#else
atomicAdd(grad_weights, grad_output * val);
atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
#endif
}
__device__ void bilinear_sampling_grad(
const float *&bottom_data, const float &weight,
const int &height, const int &width,
const int &num_embeds, const float &h_im, const float &w_im,
const int &base_ptr,
const float &grad_output,
float *&grad_mc_ms_feat, float *grad_sampling_location, float *grad_weights) {
const int h_low = floorf(h_im);
const int w_low = floorf(w_im);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1 - lh, hw = 1 - lw;
const int w_stride = num_embeds;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const float top_grad_mc_ms_feat = grad_output * weight;
float grad_h_weight = 0, grad_w_weight = 0;
float v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#endif
}
float v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
// atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#endif
}
float v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
// atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#endif
}
float v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
// atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#endif
}
const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_weights, grad_output * val);
__builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
__builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
#else
atomicAdd(grad_weights, grad_output * val);
atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
#endif
}
__global__ void deformable_aggregation_kernel(
const int64_t num_kernels,
float* output,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_kernels) return;
const float weight = *(weights + idx / (num_embeds / num_groups));
const int channel_index = idx % num_embeds;
idx /= num_embeds;
const int scale_index = idx % num_scale;
idx /= num_scale;
const int cam_index = idx % num_cams;
idx /= num_cams;
const int pts_index = idx % num_pts;
idx /= num_pts;
int anchor_index = idx % num_anchors;
idx /= num_anchors;
const int batch_index = idx % batch_size;
idx /= batch_size;
anchor_index = batch_index * num_anchors + anchor_index;
const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;
const float loc_w = sample_location[loc_offset];
if (loc_w <= 0 || loc_w >= 1) return;
const float loc_h = sample_location[loc_offset + 1];
if (loc_h <= 0 || loc_h >= 1) return;
int cam_scale_index = cam_index * num_scale + scale_index;
const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;
cam_scale_index = cam_scale_index << 1;
const int h = spatial_shape[cam_scale_index];
const int w = spatial_shape[cam_scale_index + 1];
const float h_im = loc_h * h - 0.5;
const float w_im = loc_w * w - 0.5;
// atomicAdd(
// output + anchor_index * num_embeds + channel_index,
// bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight
// );
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(output + anchor_index * num_embeds + channel_index, bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight);
#else
atomicAdd(output + anchor_index * num_embeds + channel_index, bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight);
#endif
}
template <int blocksize>
__global__ void deformable_aggregation_grad_kernel_sp(
const int64_t num_kernels,
const float* mc_ms_feat, // [bs, num_feat, channel]
const int* spatial_shape, // [cam, scale, 2]
const int* scale_start_index, // [cam, scale]
const float* sample_location, // [bs, anchor, pts, cam, 2(x, y)]
const float* weights, // [bs, anchor, pts, cam, scale, group]
const float* grad_output, // [bs, anchor, c]
float* grad_mc_ms_feat, // same as feat
float* grad_sampling_location, // same as sampling location
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
extern __shared__ float2_t s_data[];
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_kernels) return;
const int weights_ptr = idx / (num_embeds / num_groups);
const int channel_index = idx % num_embeds;
idx /= num_embeds;
const int scale_index = idx % num_scale;
idx /= num_scale;
const int cam_index = idx % num_cams;
idx /= num_cams;
const int pts_index = idx % num_pts;
idx /= num_pts;
int anchor_index = idx % num_anchors;
idx /= num_anchors;
const int batch_index = idx % batch_size;
idx /= batch_size;
anchor_index = batch_index * num_anchors + anchor_index;
const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;
const float loc_w = sample_location[loc_offset];
if (loc_w <= 0 || loc_w >= 1) return;
const float loc_h = sample_location[loc_offset + 1];
if (loc_h <= 0 || loc_h >= 1) return;
const float grad = grad_output[anchor_index*num_embeds + channel_index];
int cam_scale_index = cam_index * num_scale + scale_index;
const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;
cam_scale_index = cam_scale_index << 1;
const int h = spatial_shape[cam_scale_index];
const int w = spatial_shape[cam_scale_index + 1];
const float h_im = loc_h * h - 0.5;
const float w_im = loc_w * w - 0.5;
/* atomicAdd( */
/* output + anchor_index * num_embeds + channel_index, */
/* bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight */
/* ); */
const float weight = weights[weights_ptr];
float *grad_weights_ptr = grad_weights + weights_ptr;
float *grad_location_ptr = grad_sampling_location + loc_offset;
bilinear_sampling_grad_sp<blocksize>(
mc_ms_feat, weight, h, w, num_embeds, h_im, w_im,
value_offset,
grad,
grad_mc_ms_feat, grad_location_ptr, grad_weights_ptr,
s_data
);
}
__global__ void deformable_aggregation_grad_kernel(
const int64_t num_kernels,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
const float* grad_output,
float* grad_mc_ms_feat,
float* grad_sampling_location,
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_kernels) return;
const int weights_ptr = idx / (num_embeds / num_groups);
const int channel_index = idx % num_embeds;
idx /= num_embeds;
const int scale_index = idx % num_scale;
idx /= num_scale;
const int cam_index = idx % num_cams;
idx /= num_cams;
const int pts_index = idx % num_pts;
idx /= num_pts;
int anchor_index = idx % num_anchors;
idx /= num_anchors;
const int batch_index = idx % batch_size;
idx /= batch_size;
anchor_index = batch_index * num_anchors + anchor_index;
const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;
const float loc_w = sample_location[loc_offset];
if (loc_w <= 0 || loc_w >= 1) return;
const float loc_h = sample_location[loc_offset + 1];
if (loc_h <= 0 || loc_h >= 1) return;
const float grad = grad_output[anchor_index*num_embeds + channel_index];
int cam_scale_index = cam_index * num_scale + scale_index;
const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;
cam_scale_index = cam_scale_index << 1;
const int h = spatial_shape[cam_scale_index];
const int w = spatial_shape[cam_scale_index + 1];
const float h_im = loc_h * h - 0.5;
const float w_im = loc_w * w - 0.5;
/* atomicAdd( */
/* output + anchor_index * num_embeds + channel_index, */
/* bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight */
/* ); */
const float weight = weights[weights_ptr];
float *grad_weights_ptr = grad_weights + weights_ptr;
float *grad_location_ptr = grad_sampling_location + loc_offset;
bilinear_sampling_grad(
mc_ms_feat, weight, h, w, num_embeds, h_im, w_im,
value_offset,
grad,
grad_mc_ms_feat, grad_location_ptr, grad_weights_ptr
);
}
void deformable_aggregation(
float* output,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
const int64_t num_kernels = (int64_t)batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;
hipLaunchKernelGGL(( deformable_aggregation_kernel)
, dim3((int)ceil(((double)num_kernels/128))), dim3(128), 0, 0,
num_kernels, output,
mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
}
void deformable_aggregation_grad(
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
const float* grad_output,
float* grad_mc_ms_feat,
float* grad_sampling_location,
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
const int64_t num_kernels = (int64_t)batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;
if (num_embeds != 256 || ((num_embeds / num_groups) != 32)) {
hipLaunchKernelGGL(( deformable_aggregation_grad_kernel)
, dim3((int)ceil(((double)num_kernels/128))), dim3(128), 0, 0,
num_kernels,
mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
} else {
int blk_dim = 256;
hipLaunchKernelGGL(( deformable_aggregation_grad_kernel_sp<256>)
, dim3((int)ceil(((double)num_kernels/blk_dim))), dim3(blk_dim), blk_dim * 2 * sizeof(float), 0,
num_kernels,
mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
}
}