Commit 40e15362 authored by hanbao, committed by hubertlu-tw

add customized fused index multiplication op (#1438)


Co-authored-by: Han Bao <hbao@nvidia.com>
parent 96850dfa
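For context: the fused op performs an indexed row-wise multiply, out[i, :] = in1[idx1[i], :] * in2[i, :]. A minimal eager-PyTorch sketch of the same semantics (illustration only; the function name below is hypothetical and not part of the extension):

import torch

def index_mul_2d_reference(in1: torch.Tensor, in2: torch.Tensor, idx1: torch.Tensor) -> torch.Tensor:
    # Gather the rows of in1 selected by idx1, then multiply elementwise with in2.
    # in1: (n1, fea_dim); in2 and the result: (n2, fea_dim); idx1: (n2,) int64.
    return in1[idx1] * in2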
#include <torch/torch.h>
#include <vector>
#include <cstdint>
void index_mul_2d_float_forward_cuda(at::Tensor &out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1);
void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1);
void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out,
at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &grad_grad_in1,
const at::Tensor &grad_grad_in2,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1);
void index_mul_2d_half_forward_cuda(at::Tensor &out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1);
void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1);
void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out,
at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &grad_grad_in1,
const at::Tensor &grad_grad_in2,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1);
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
void index_mul_2d_float_forward(
at::Tensor &out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1)
{
return index_mul_2d_float_forward_cuda(out, in1, in2, idx1);
}
void index_mul_2d_float_backward(
at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1)
{
return index_mul_2d_float_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1);
}
void index_mul_2d_float_backward_backward(
at::Tensor &grad_grad_out,
at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &grad_grad_in1,
const at::Tensor &grad_grad_in2,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1)
{
return index_mul_2d_float_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out, grad_grad_in1, grad_grad_in2, in1, in2, idx1);
}
void index_mul_2d_half_forward(
at::Tensor &out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1)
{
return index_mul_2d_half_forward_cuda(out, in1, in2, idx1);
}
void index_mul_2d_half_backward(
at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1)
{
return index_mul_2d_half_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1);
}
void index_mul_2d_half_backward_backward(
at::Tensor &grad_grad_out,
at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &grad_grad_in1,
const at::Tensor &grad_grad_in2,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1)
{
return index_mul_2d_half_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out, grad_grad_in1, grad_grad_in2, in1, in2, idx1);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("float_forward", &index_mul_2d_float_forward,
"index mul float calculation forward (CUDA)");
m.def("float_backward", &index_mul_2d_float_backward,
"index mul float calculation backward (CUDA)");
m.def("float_backward_backward", &index_mul_2d_float_backwrad_backward,
"index mul float calculation backward backward (CUDA)");
m.def("half_forward", &index_mul_2d_half_forward,
"index mul half calculation forward (CUDA)");
m.def("half_backward", &index_mul_2d_half_backward,
"index mul half calculation backward (CUDA)");
m.def("half_backward_backward", &index_mul_2d_half_backwrad_backward,
"index mul half calculation backward backward (CUDA)");
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Atomic.cuh>
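// Specialized forward kernel for fea_dim == 64: each thread loads/stores one float4
// (4 consecutive floats), so the 16-thread x-dimension covers a full 64-element row
// with vectorized 16-byte accesses.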
__global__ void index_mul_2d_float_dim64(
float *out,
const float *in1,
const float *in2,
const int64_t *idx1,
const int64_t size)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
constexpr int fea_dim = 64;
if (start_idx < size) {
int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx;
int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx;
float4 res, src1, src2;
src1 = reinterpret_cast<const float4 *>(in1)[vec_idx1];
src2 = reinterpret_cast<const float4 *>(in2)[vec_idx2];
res.x = src1.x * src2.x;
res.y = src1.y * src2.y;
res.z = src1.z * src2.z;
res.w = src1.w * src2.w;
reinterpret_cast<float4 *>(out)[vec_idx2] = res;
}
}
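// Generic forward kernel: each (block, threadIdx.y) pair handles one output row and
// threadIdx.x strides across the feature dimension.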
__global__ void index_mul_2d_float(
float *out,
const float *in1,
const float *in2,
const int64_t *idx1,
const int64_t size,
const int64_t fea_dim)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
const int stride = blockDim.x;
if (start_idx < size) {
int64_t vec_idx1 = (idx1[start_idx] * fea_dim);
int64_t vec_idx2 = (start_idx * fea_dim);
for (int i = tidx; i < fea_dim; i += stride) {
out[vec_idx2 + i] = in1[vec_idx1 + i] * in2[vec_idx2 + i];
}
}
}
__global__ void index_mul_2d_half(
at::Half *out,
const at::Half *in1,
const at::Half *in2,
const int64_t *idx1,
const int64_t size,
const int64_t fea_dim)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
const int stride = blockDim.x;
if (start_idx < size) {
int64_t vec_idx1 = (idx1[start_idx] * fea_dim);
int64_t vec_idx2 = (start_idx * fea_dim);
for (int i = tidx; i < fea_dim; i += stride) {
out[vec_idx2 + i] = at::Half(static_cast<float>(in1[vec_idx1 + i]) * static_cast<float>(in2[vec_idx2 + i]));
}
}
}
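// Backward kernels: for out[i] = in1[idx1[i]] * in2[i],
// grad_in2[i] = grad_out[i] * in1[idx1[i]] is written directly, while
// grad_in1[idx1[i]] += grad_out[i] * in2[i] is accumulated with atomics because
// several rows of in2 may map to the same row of in1.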
__global__ void index_mul_2d_grad_float_dim64(
float *grad_in1,
float *grad_in2,
const float *grad_out,
const float *in1,
const float *in2,
const int64_t *idx1,
const int64_t size)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
constexpr int fea_dim = 64;
if (start_idx < size) {
int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx;
int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx;
float4 src_in1, src_in2, src_grad_out, dst_grad_in2;
src_grad_out = reinterpret_cast<const float4 *>(grad_out)[vec_idx2];
src_in1 = reinterpret_cast<const float4 *>(in1)[vec_idx1];
src_in2 = reinterpret_cast<const float4 *>(in2)[vec_idx2];
int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4;
gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_out.x * src_in2.x);
gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_out.y * src_in2.y);
gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_out.z * src_in2.z);
gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_out.w * src_in2.w);
dst_grad_in2.x = src_grad_out.x * src_in1.x;
dst_grad_in2.y = src_grad_out.y * src_in1.y;
dst_grad_in2.z = src_grad_out.z * src_in1.z;
dst_grad_in2.w = src_grad_out.w * src_in1.w;
reinterpret_cast<float4 *>(grad_in2)[vec_idx2] = dst_grad_in2;
}
}
__global__ void index_mul_2d_grad_float(
float *grad_in1,
float *grad_in2,
const float *grad_out,
const float *in1,
const float *in2,
const int64_t *idx1,
const int64_t size,
const int64_t fea_dim)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
const int stride = blockDim.x;
if (start_idx < size) {
int64_t vec_idx1 = idx1[start_idx] * fea_dim;
int64_t vec_idx2 = start_idx * fea_dim;
for (int i = tidx; i < fea_dim; i += stride) {
float src_in1 = in1[vec_idx1 + i];
float src_in2 = in2[vec_idx2 + i];
float src_grad_out = grad_out[vec_idx2 + i];
grad_in2[vec_idx2 + i] = src_grad_out * src_in1;
gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_out * src_in2);
}
}
}
__global__ void index_mul_2d_grad_half(
at::Half *grad_in1,
at::Half *grad_in2,
const at::Half *grad_out,
const at::Half *in1,
const at::Half *in2,
const int64_t *idx1,
const int64_t size,
const int64_t fea_dim)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
const int stride = blockDim.x;
if (start_idx < size) {
int64_t vec_idx1 = idx1[start_idx] * fea_dim;
int64_t vec_idx2 = start_idx * fea_dim;
for (int i = tidx; i < fea_dim; i += stride) {
float src_in1 = static_cast<float>(in1[vec_idx1 + i]);
float src_in2 = static_cast<float>(in2[vec_idx2 + i]);
float src_grad_out = static_cast<float>(grad_out[vec_idx2 + i]);
grad_in2[vec_idx2 + i] = at::Half(src_grad_out * src_in1);
gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_out * src_in2));
}
}
}
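// Double-backward kernels: with upstream grads grad_grad_in1 / grad_grad_in2,
// grad_grad_out[i] = grad_grad_in1[idx1[i]] * in2[i] + grad_grad_in2[i] * in1[idx1[i]],
// grad_in2[i]      = grad_grad_in1[idx1[i]] * grad_out[i], and
// grad_in1[idx1[i]] += grad_grad_in2[i] * grad_out[i] (atomic, as above).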
__global__ void index_mul_2d_grad_grad_float_dim64(
float *grad_grad_out,
float *grad_in1,
float *grad_in2,
const float *grad_out,
const float *grad_grad_in1,
const float *grad_grad_in2,
const float *in1,
const float *in2,
const int64_t *idx1,
const int64_t size)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
constexpr int fea_dim = 64;
if (start_idx < size) {
int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx;
int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx;
float4 src_grad_grad_in1, src_in1, src_grad_grad_in2, src_in2, src_grad_out;
float4 dst_grad_grad_out, dst_grad_in2;
src_grad_grad_in1 = reinterpret_cast<const float4 *>(grad_grad_in1)[vec_idx1];
src_in1 = reinterpret_cast<const float4 *>(in1)[vec_idx1];
src_grad_grad_in2 = reinterpret_cast<const float4 *>(grad_grad_in2)[vec_idx2];
src_in2 = reinterpret_cast<const float4 *>(in2)[vec_idx2];
dst_grad_grad_out.x = src_grad_grad_in1.x * src_in2.x + src_grad_grad_in2.x * src_in1.x;
dst_grad_grad_out.y = src_grad_grad_in1.y * src_in2.y + src_grad_grad_in2.y * src_in1.y;
dst_grad_grad_out.z = src_grad_grad_in1.z * src_in2.z + src_grad_grad_in2.z * src_in1.z;
dst_grad_grad_out.w = src_grad_grad_in1.w * src_in2.w + src_grad_grad_in2.w * src_in1.w;
reinterpret_cast<float4 *>(grad_grad_out)[vec_idx2] = dst_grad_grad_out;
src_grad_out = reinterpret_cast<const float4 *>(grad_out)[vec_idx2];
int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4;
gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_grad_in2.x * src_grad_out.x);
gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_grad_in2.y * src_grad_out.y);
gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_grad_in2.z * src_grad_out.z);
gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_grad_in2.w * src_grad_out.w);
dst_grad_in2.x = src_grad_grad_in1.x * src_grad_out.x;
dst_grad_in2.y = src_grad_grad_in1.y * src_grad_out.y;
dst_grad_in2.z = src_grad_grad_in1.z * src_grad_out.z;
dst_grad_in2.w = src_grad_grad_in1.w * src_grad_out.w;
reinterpret_cast<float4 *>(grad_in2)[vec_idx2] = dst_grad_in2;
}
}
__global__ void index_mul_2d_grad_grad_float(
float *grad_grad_out,
float *grad_in1,
float *grad_in2,
const float *grad_out,
const float *grad_grad_in1,
const float *grad_grad_in2,
const float *in1,
const float *in2,
const int64_t *idx1,
const int64_t size,
const int64_t fea_dim)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
const int stride = blockDim.x;
if (start_idx < size) {
int64_t vec_idx1 = idx1[start_idx] * fea_dim;
int64_t vec_idx2 = start_idx * fea_dim;
for (int i = tidx; i < fea_dim; i += stride) {
float src_grad_grad_in1 = grad_grad_in1[vec_idx1 + i];
float src_grad_grad_in2 = grad_grad_in2[vec_idx2 + i];
float src_in1 = in1[vec_idx1 + i];
float src_in2 = in2[vec_idx2 + i];
float src_grad_out = grad_out[vec_idx2 + i];
grad_grad_out[vec_idx2 + i] = src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1;
grad_in2[vec_idx2 + i] = src_grad_grad_in1 * src_grad_out;
gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_grad_in2 * src_grad_out);
}
}
}
__global__ void index_mul_2d_grad_grad_half(
at::Half *grad_grad_out,
at::Half *grad_in1,
at::Half *grad_in2,
const at::Half *grad_out,
const at::Half *grad_grad_in1,
const at::Half *grad_grad_in2,
const at::Half *in1,
const at::Half *in2,
const int64_t *idx1,
const int64_t size,
const int64_t fea_dim)
{
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int bidx = blockIdx.x;
const int start_idx = bidx * blockDim.y + tidy;
const int stride = blockDim.x;
if (start_idx < size) {
int64_t vec_idx1 = idx1[start_idx] * fea_dim;
int64_t vec_idx2 = start_idx * fea_dim;
for (int i = tidx; i < fea_dim; i += stride) {
float src_grad_grad_in1 = static_cast<float>(grad_grad_in1[vec_idx1 + i]);
float src_grad_grad_in2 = static_cast<float>(grad_grad_in2[vec_idx2 + i]);
float src_in1 = static_cast<float>(in1[vec_idx1 + i]);
float src_in2 = static_cast<float>(in2[vec_idx2 + i]);
float src_grad_out = static_cast<float>(grad_out[vec_idx2 + i]);
grad_grad_out[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1);
grad_in2[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_grad_out);
gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_grad_in2 * src_grad_out));
}
}
}
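// Host-side launchers: the float paths dispatch to the vectorized dim64 kernels when
// fea_dim == 64 and fall back to the generic strided kernels otherwise; the half paths
// always use the generic kernels.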
void index_mul_2d_float_forward_cuda(at::Tensor &out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1) {
const int64_t size = in2.size(0);
const int64_t fea_dim = in2.size(1);
if (size <= 0) {
return;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (fea_dim == 64) {
const int BLOCK_THREADS_DIMX = 16;
const int BLOCK_THREADS_DIMY = 16;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
index_mul_2d_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
out.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(),
idx1.data_ptr<int64_t>(), size);
} else {
const int BLOCK_THREADS_DIMX = 32;
const int BLOCK_THREADS_DIMY = 8;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
index_mul_2d_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
out.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(),
idx1.data_ptr<int64_t>(), size, fea_dim);
}
AT_CUDA_CHECK(cudaGetLastError());
}
void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1) {
const int64_t size = in2.size(0);
const int64_t fea_dim = in2.size(1);
if (size <= 0) {
return;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (fea_dim == 64) {
const int BLOCK_THREADS_DIMX = 16;
const int BLOCK_THREADS_DIMY = 16;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
index_mul_2d_grad_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(),
in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);
} else {
const int BLOCK_THREADS_DIMX = 32;
const int BLOCK_THREADS_DIMY = 8;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
index_mul_2d_grad_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(),
in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);
}
AT_CUDA_CHECK(cudaGetLastError());
}
void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out,
at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &grad_grad_in1,
const at::Tensor &grad_grad_in2,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1) {
const int64_t size = in2.size(0);
const int64_t fea_dim = in2.size(1);
if (size <= 0) {
return;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (fea_dim == 64) {
const int BLOCK_THREADS_DIMX = 16;
const int BLOCK_THREADS_DIMY = 16;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
index_mul_2d_grad_grad_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(), grad_grad_in2.data_ptr<float>(),
in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);
} else {
const int BLOCK_THREADS_DIMX = 32;
const int BLOCK_THREADS_DIMY = 8;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
index_mul_2d_grad_grad_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(), grad_grad_in2.data_ptr<float>(),
in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);
}
AT_CUDA_CHECK(cudaGetLastError());
}
void index_mul_2d_half_forward_cuda(at::Tensor &out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1) {
const int64_t size = in2.size(0);
const int64_t fea_dim = in2.size(1);
if (size <= 0) {
return;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int BLOCK_THREADS_DIMX = 32;
const int BLOCK_THREADS_DIMY = 8;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
index_mul_2d_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
out.data_ptr<at::Half>(), in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(),
idx1.data_ptr<int64_t>(), size, fea_dim);
AT_CUDA_CHECK(cudaGetLastError());
}
void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1) {
const int64_t size = in2.size(0);
const int64_t fea_dim = in2.size(1);
if (size <= 0) {
return;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int BLOCK_THREADS_DIMX = 32;
const int BLOCK_THREADS_DIMY = 8;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
index_mul_2d_grad_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
grad_in1.data_ptr<at::Half>(), grad_in2.data_ptr<at::Half>(), grad_out.data_ptr<at::Half>(),
in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size, fea_dim);
AT_CUDA_CHECK(cudaGetLastError());
}
void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out,
at::Tensor &grad_in1,
at::Tensor &grad_in2,
const at::Tensor &grad_out,
const at::Tensor &grad_grad_in1,
const at::Tensor &grad_grad_in2,
const at::Tensor &in1,
const at::Tensor &in2,
const at::Tensor &idx1) {
const int64_t size = in2.size(0);
const int64_t fea_dim = in2.size(1);
if (size <= 0) {
return;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int BLOCK_THREADS_DIMX = 32;
const int BLOCK_THREADS_DIMY = 8;
const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
index_mul_2d_grad_grad_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
grad_grad_out.data_ptr<at::Half>(), grad_in1.data_ptr<at::Half>(), grad_in2.data_ptr<at::Half>(),
grad_out.data_ptr<at::Half>(), grad_grad_in1.data_ptr<at::Half>(), grad_grad_in2.data_ptr<at::Half>(),
in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size, fea_dim);
AT_CUDA_CHECK(cudaGetLastError());
}
from .index_mul_2d import index_mul_2d
import torch
import fused_index_mul_2d
class IndexMul2d_(torch.autograd.Function):
'''
Currently only supports indexing along dimension 0 of a 2-dimensional tensor.
The shape of the indexed in1 must match in2; this kernel does not support broadcasting.
The datatype must be float32 or float16.
'''
@staticmethod
def forward(ctx, in1: torch.Tensor, in2: torch.Tensor, idx1: torch.Tensor) -> torch.Tensor:
assert in2.size(0) == idx1.size(0)
if ((in1.dtype != torch.float32 and in1.dtype != torch.half) or in2.dtype != in1.dtype):
raise RuntimeError("input1'dtype and input2's dtype must be fp32 or fp16. And input type must be same")
if (in1.dim() != 2 or in2.dim() != 2):
raise RuntimeError("in1 and in2 must be 2-dimension tensor.")
if (idx1.dim() != 1):
raise RuntimeError("idx1 must be 1-dimension tensor.")
if not in1.is_contiguous():
in1 = in1.contiguous()
if not in2.is_contiguous():
in2 = in2.contiguous()
if not idx1.is_contiguous():
idx1 = idx1.contiguous()
assert in1.is_contiguous()
assert in2.is_contiguous()
assert idx1.is_contiguous()
out = torch.empty_like(in2)
if (in1.dtype == torch.float32):
fused_index_mul_2d.float_forward(
out,
in1,
in2,
idx1)
elif (in1.dtype == torch.half):
fused_index_mul_2d.half_forward(
out,
in1,
in2,
idx1)
ctx.for_backwards = (in1, in2, idx1)
return out
@staticmethod
def backward(ctx, grad_out):
in1, in2, idx1 = ctx.for_backwards
grad_in1, grad_in2 = index_mul_2d_backward(in1, in2, idx1, grad_out)
return grad_in1, grad_in2, None
class IndexMul2dBackward_(torch.autograd.Function):
@staticmethod
def forward(ctx, in1: torch.Tensor, in2: torch.Tensor, idx1: torch.Tensor,
grad_out: torch.Tensor):
if not in1.is_contiguous():
in1 = in1.contiguous()
if not in2.is_contiguous():
in2 = in2.contiguous()
if not idx1.is_contiguous():
idx1 = idx1.contiguous()
if not grad_out.is_contiguous():
grad_out = grad_out.contiguous()
assert in1.is_contiguous()
assert in2.is_contiguous()
assert idx1.is_contiguous()
assert grad_out.is_contiguous()
grad_in1 = torch.zeros_like(in1)
grad_in2 = torch.empty_like(in2)
if (in1.dtype == torch.float32):
fused_index_mul_2d.float_backward(
grad_in1,
grad_in2,
grad_out,
in1,
in2,
idx1)
elif (in1.dtype == torch.half):
fused_index_mul_2d.half_backward(
grad_in1,
grad_in2,
grad_out,
in1,
in2,
idx1)
ctx.for_backwards = (in1, in2, idx1, grad_out)
return grad_in1, grad_in2
@staticmethod
def backward(ctx, grad_grad_in1, grad_grad_in2):
if not grad_grad_in1.is_contiguous():
grad_grad_in1 = grad_grad_in1.contiguous()
if not grad_grad_in2.is_contiguous():
grad_grad_in2 = grad_grad_in2.contiguous()
assert grad_grad_in1.is_contiguous()
assert grad_grad_in2.is_contiguous()
in1, in2, idx1, grad_out = ctx.for_backwards
grad_in1 = torch.zeros_like(in1)
grad_in2 = torch.empty_like(in2)
grad_grad_out = torch.empty_like(grad_out)
if (in1.dtype == torch.float32):
fused_index_mul_2d.float_backward_backward(
grad_grad_out,
grad_in1,
grad_in2,
grad_out,
grad_grad_in1,
grad_grad_in2,
in1,
in2,
idx1)
elif (in1.dtype == torch.half):
fused_index_mul_2d.half_backward_backward(
grad_grad_out,
grad_in1,
grad_in2,
grad_out,
grad_grad_in1,
grad_grad_in2,
in1,
in2,
idx1)
return grad_in1, grad_in2, None, grad_grad_out
index_mul_2d = IndexMul2d_.apply
index_mul_2d_backward = IndexMul2dBackward_.apply
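A minimal usage sketch (illustration only; assumes apex was built with the index_mul_2d extension and a CUDA device is available):

import torch
from apex.contrib.index_mul_2d import index_mul_2d

in1 = torch.randn(1000, 64, device="cuda", requires_grad=True)
in2 = torch.randn(50000, 64, device="cuda", requires_grad=True)
idx1 = torch.randint(0, in1.size(0), (in2.size(0),), device="cuda")

out = index_mul_2d(in1, in2, idx1)  # same result as in1[idx1] * in2
out.pow(2).mean().backward()        # gradients flow to in1 and in2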
import random
import unittest
import torch
import torch.nn.functional as F
HAS_INDEX_MUL_2D = None
try:
from apex.contrib.index_mul_2d import index_mul_2d
except ImportError:
HAS_INDEX_MUL_2D = False
else:
HAS_INDEX_MUL_2D = True
@unittest.skipIf(not HAS_INDEX_MUL_2D, "`apex.contrib.index_mul_2d` is not found.")
class IndexMul2dTest(unittest.TestCase):
def setUp(self, seed=0):
torch.manual_seed(seed)
self.input1_size = random.randint(1, 1000)
self.input2_size = random.randint(1, 100000)
self.feature_size = random.randint(1, 256)
self.input1_float = torch.randn(size=(self.input1_size, self.feature_size),).cuda()
self.input2_float = torch.randn(size=(self.input2_size, self.feature_size),).cuda()
self.index1 = torch.randint(low=0, high=self.input1_size, size=(self.input2_size,)).cuda()
self.input1_float_ = self.input1_float.clone()
self.input2_float_ = self.input2_float.clone()
self.input1_float.requires_grad_()
self.input1_float_.requires_grad_()
self.input2_float.requires_grad_()
self.input2_float_.requires_grad_()
self.input1_half = torch.randn(size=(self.input1_size, self.feature_size),).cuda().half()
self.input2_half = torch.randn(size=(self.input2_size, self.feature_size),).cuda().half()
self.input1_half_ = self.input1_half.clone()
self.input2_half_ = self.input2_half.clone()
self.input1_half.requires_grad_()
self.input2_half.requires_grad_()
self.input1_half_.requires_grad_()
self.input2_half_.requires_grad_()
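# Each test runs the fused op and the eager expression in1[idx1] * in2 through the same
# loss (energy plus a squared-force term) and compares the resulting input gradients,
# which exercises the forward, backward, and double-backward kernels.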
def test_index_mul_float(self):
out = index_mul_2d(self.input1_float, self.input2_float, self.index1)
energy = (out.float()**2).sum() / out.numel()
force = torch.autograd.grad(
energy,
self.input1_float,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
loss = (out.float()**2).sum() / out.numel() + (force.float()**2).sum()
loss.backward()
out_ = self.input1_float_[self.index1] * self.input2_float_
energy_ = (out_.float()**2).sum() / out.numel()
force_ = torch.autograd.grad(
energy_,
self.input1_float_,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
loss = (out_.float()**2).sum() / out_.numel() + (force_.float()**2).sum()
loss.backward()
self.assertTrue(torch.allclose(self.input1_float, self.input1_float_, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(self.input2_float, self.input2_float_, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(self.input1_float.grad, self.input1_float_.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(self.input2_float.grad, self.input2_float_.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
def test_index_mul_half(self):
out = index_mul_2d(self.input1_half, self.input2_half, self.index1)
energy = (out.float()**2).sum() / out.numel()
force = torch.autograd.grad(
energy,
self.input1_half,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
loss = (out.float()**2).sum() / out.numel() + (force.float()**2).sum()
loss.backward()
out_ = self.input1_half_[self.index1] * self.input2_half_
energy_ = (out_.float()**2).sum() / out.numel()
force_ = torch.autograd.grad(
energy_,
self.input1_half_,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
loss = (out_.float()**2).sum() / out_.numel() + (force_.float()**2).sum()
loss.backward()
self.assertTrue(torch.allclose(self.input1_half, self.input1_half_, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(self.input2_half, self.input2_half_, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(self.input1_half.grad, self.input1_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
self.assertTrue(torch.allclose(self.input2_half.grad, self.input2_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
if __name__ == '__main__':
unittest.main()
@@ -307,6 +307,23 @@ if "--xentropy" in sys.argv or "--cuda_ext" in sys.argv:
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
if "--index_mul_2d" in sys.argv:
if "--index_mul_2d" in sys.argv:
sys.argv.remove("--index_mul_2d")
ext_modules.append(
CUDAExtension(
name='fused_index_mul_2d',
sources=[
'apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp',
'apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu',
],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={
'cxx': ['-O3'] + version_dependent_macros,
'nvcc':(['-O3', '--use_fast_math', '--ftz=false'] if not IS_ROCM_PYTORCH else ['-O3']) + version_dependent_macros,
},
)
)
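# With this entry in place, the extension is built by passing the flag through to setup.py,
# e.g. `python setup.py install --index_mul_2d`, or via pip's --global-option="--index_mul_2d"
# as for apex's other contrib extensions (the exact invocation varies by version; an assumption).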
if "--deprecated_fused_adam" in sys.argv or "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension