#include <vector>
#include <torch/serialize/tensor.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "common.h"
#include "device_tensor.h"

namespace {

template <typename DType, typename Acctype>
struct KD2Op {
  __device__ KD2Op(DeviceTensor<DType, 3> x,
                   DeviceTensor<DType, 2> c,
                   DeviceTensor<DType, 2> std)
    : X(x), C(c), STD(std) {}
  // Squared, STD-normalized residual of channel d.
  __device__ __forceinline__ Acctype operator()(int b, int i, int k, int d) {
    DType r = (X[b][i][d] - C[k][d]) / STD[k][d];
    return ScalarConvert<DType, Acctype>::to(r * r);
  }
  DeviceTensor<DType, 3> X;
  DeviceTensor<DType, 2> C;
  DeviceTensor<DType, 2> STD;
};

template <typename DType, typename Acctype>
__global__ void Encoding_Dist_Forward_kernel (
    DeviceTensor<DType, 3> KD,
    DeviceTensor<DType, 3> X,
    DeviceTensor<DType, 2> C,
    DeviceTensor<DType, 2> STD) {
  /* declarations of the variables */
  int b, k, i, D;
  /* Get the index and channels */
  b = blockIdx.z;
  k = blockIdx.x;
  i = blockIdx.y;
  D = X.getSize(2);
  /* main operation */
  KD2Op<DType, Acctype> g(X, C, STD);
  KD[b][i][k] = reduceD<Acctype>(g, b, i, k, D);
}

template <typename DType, typename Acctype>
struct EncGradXOp {
  __device__ EncGradXOp(DeviceTensor<DType, 3> gkd,
                        DeviceTensor<DType, 3> x,
                        DeviceTensor<DType, 2> c,
                        DeviceTensor<DType, 2> std)
    : GKD(gkd), X(x), C(c), STD(std) {}
  // d(KD)/d(X) for channel d, scaled by the incoming gradient GKD.
  __device__ __forceinline__ Acctype operator()(int b, int i, int k, int d) {
    return ScalarConvert<DType, Acctype>::to(
        2 * GKD[b][i][k] * (X[b][i][d] - C[k][d]) /
        (STD[k][d] * STD[k][d]));
  }
  DeviceTensor<DType, 3> GKD;
  DeviceTensor<DType, 3> X;
  DeviceTensor<DType, 2> C;
  DeviceTensor<DType, 2> STD;
};

template <typename DType, typename Acctype>
__global__ void Encoding_GradX_kernel (
    DeviceTensor<DType, 3> GKD,
    DeviceTensor<DType, 3> GX,
    DeviceTensor<DType, 3> X,
    DeviceTensor<DType, 2> C,
    DeviceTensor<DType, 2> STD) {
  /* declarations of the variables */
  int b, d, i, K;
  /* Get the index and channels */
  b = blockIdx.z;
  i = blockIdx.y;
  d = blockIdx.x;
  K = C.getSize(0);
  /* main operation */
  EncGradXOp<DType, Acctype> g(GKD, X, C, STD);
  GX[b][i][d] = reduceK<Acctype>(g, b, i, d, K);
}

template <typename DType, typename Acctype>
struct EncGradSTDOp {
  __device__ EncGradSTDOp(DeviceTensor<DType, 3> gkd,
                          DeviceTensor<DType, 3> x,
                          DeviceTensor<DType, 2> c,
                          DeviceTensor<DType, 2> std)
    : GKD(gkd), X(x), C(c), STD(std) {}
  // d(KD)/d(STD) for channel d, scaled by the incoming gradient GKD.
  __device__ __forceinline__ Acctype operator()(int b, int i, int k, int d) {
    return ScalarConvert<DType, Acctype>::to(
        -2 * GKD[b][i][k] * (X[b][i][d] - C[k][d]) * (X[b][i][d] - C[k][d]) /
        (STD[k][d] * STD[k][d] * STD[k][d]));
  }
  DeviceTensor<DType, 3> GKD;
  DeviceTensor<DType, 3> X;
  DeviceTensor<DType, 2> C;
  DeviceTensor<DType, 2> STD;
};

template <typename DType, typename Acctype>
__global__ void Encoding_GradCSTD_kernel (
    DeviceTensor<DType, 3> GKD,
    DeviceTensor<DType, 2> GC,
    DeviceTensor<DType, 2> GSTD,
    DeviceTensor<DType, 3> X,
    DeviceTensor<DType, 2> C,
    DeviceTensor<DType, 2> STD) {
  /* declarations of the variables */
  int k, d, B, N;
  /* Get the index and channels */
  d = blockIdx.x;
  k = blockIdx.y;
  B = X.getSize(0);
  N = X.getSize(1);
  /* main operation */
  EncGradXOp<DType, Acctype> g1(GKD, X, C, STD);
  EncGradSTDOp<DType, Acctype> g2(GKD, X, C, STD);
  GC[k][d] = -reduceBN<Acctype>(g1, k, d, B, N);
  GSTD[k][d] += reduceBN<Acctype>(g2, k, d, B, N);
}

template <typename DType, typename Acctype>
struct EncGradSTDXOp {
  __device__ EncGradSTDXOp(DeviceTensor<DType, 2> gstd,
                           DeviceTensor<DType, 3> x,
                           DeviceTensor<DType, 2> c,
                           DeviceTensor<DType, 2> std)
    : GSTD(gstd), X(x), C(c), STD(std) {}
  // Gradient reaching X through the batch estimate of STD.
  __device__ __forceinline__ Acctype operator()(int b, int i, int k, int d) {
    return ScalarConvert<DType, Acctype>::to(
        GSTD[k][d] * (X[b][i][d] - C[k][d]) / STD[k][d]);
  }
  DeviceTensor<DType, 2> GSTD;
  DeviceTensor<DType, 3> X;
  DeviceTensor<DType, 2> C;
  DeviceTensor<DType, 2> STD;
};

template <typename DType, typename Acctype>
__global__ void Encoding_GradSTDX_kernel (
    DeviceTensor<DType, 2> GSTD,
    DeviceTensor<DType, 3> GX,
    DeviceTensor<DType, 3> X,
    DeviceTensor<DType, 2> C,
    DeviceTensor<DType, 2> STD,
    int N) {
  /* declarations of the variables */
  int b, d, i, K;
  /* Get the index and channels */
  b = blockIdx.z;
  i = blockIdx.y;
  d = blockIdx.x;
  K = C.getSize(0);
  /* main operation */
  EncGradSTDXOp<DType, Acctype> g(GSTD, X, C, STD);
  GX[b][i][d] += reduceK<Acctype>(g, b, i, d, K) / N;
}
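
/*
 * Reference math for the operators above. From KD2Op,
 *   KD[b][i][k] = sum_d ((X[b][i][d] - C[k][d]) / STD[k][d])^2,
 * so with GKD = dL/dKD the chain rule gives
 *   dL/dX[b][i][d] =  sum_k   2 * GKD[b][i][k] * (X - C) / STD^2    (EncGradXOp, reduced over k)
 *   dL/dC[k][d]    = -sum_b,i 2 * GKD[b][i][k] * (X - C) / STD^2    (negated EncGradXOp, reduced over b,i)
 *   dL/dSTD[k][d]  = -sum_b,i 2 * GKD[b][i][k] * (X - C)^2 / STD^3  (EncGradSTDOp, reduced over b,i)
 * EncGradSTDXOp adds the extra term that flows from STD back to X when STD is
 * itself estimated from the batch in the forward pass (see
 * Encoding_Dist_Backward_CUDA below).
 */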
template <typename DType, typename Acctype>
struct AggOpV2 {
  __device__ AggOpV2(DeviceTensor<DType, 3> a,
                     DeviceTensor<DType, 3> x,
                     DeviceTensor<DType, 2> c,
                     DeviceTensor<DType, 2> std)
    : A(a), X(x), C(c), STD(std) {}
  // Assignment-weighted, STD-normalized residual for channel d.
  __device__ __forceinline__ Acctype operator()(int b, int i, int k, int d) {
    return ScalarConvert<DType, Acctype>::to(
        A[b][i][k] * (X[b][i][d] - C[k][d]) / STD[k][d]);
  }
  DeviceTensor<DType, 3> A;
  DeviceTensor<DType, 3> X;
  DeviceTensor<DType, 2> C;
  DeviceTensor<DType, 2> STD;
};

template <typename DType, typename Acctype>
__global__ void AggregateV2_Forward_kernel (
    DeviceTensor<DType, 3> E,
    DeviceTensor<DType, 3> A,
    DeviceTensor<DType, 3> X,
    DeviceTensor<DType, 2> C,
    DeviceTensor<DType, 2> STD) {
  /* declarations of the variables */
  int b, k, d, N;
  /* Get the index and channels */
  b = blockIdx.z;
  d = blockIdx.x;
  k = blockIdx.y;
  N = X.getSize(1);
  /* main operation */
  AggOpV2<DType, Acctype> g(A, X, C, STD);
  E[b][k][d] = reduceN<Acctype>(g, b, k, d, N);
}

template <typename DType, typename Acctype>
struct AggV2BackOp {
  __device__ AggV2BackOp(DeviceTensor<DType, 3> g,
                         DeviceTensor<DType, 3> x,
                         DeviceTensor<DType, 2> c,
                         DeviceTensor<DType, 2> std)
    : G(g), X(x), C(c), STD(std) {}
  // d(E)/d(A) for channel d, scaled by the incoming gradient G = dL/dE.
  __device__ __forceinline__ Acctype operator()(int b, int i, int k, int d) {
    return ScalarConvert<DType, Acctype>::to(
        G[b][k][d] * (X[b][i][d] - C[k][d]) / STD[k][d]);
  }
  DeviceTensor<DType, 3> G;
  DeviceTensor<DType, 3> X;
  DeviceTensor<DType, 2> C;
  DeviceTensor<DType, 2> STD;
};

template <typename DType, typename Acctype>
__global__ void AggregateV2_Backward_kernel (
    DeviceTensor<DType, 3> GA,
    DeviceTensor<DType, 3> GE,
    DeviceTensor<DType, 3> A,
    DeviceTensor<DType, 3> X,
    DeviceTensor<DType, 2> C,
    DeviceTensor<DType, 2> STD) {
  /* declarations of the variables */
  int b, k, i, D;
  /* Get the index and channels */
  b = blockIdx.z;
  i = blockIdx.y;
  k = blockIdx.x;
  D = GE.getSize(2);
  /* main operation */
  AggV2BackOp<DType, Acctype> g(GE, X, C, STD);
  GA[b][i][k] = reduceD<Acctype>(g, b, i, k, D);
}

} // namespace

at::Tensor Encoding_Dist_Inference_Forward_CUDA(
    const at::Tensor X_,
    const at::Tensor C_,
    const at::Tensor STD_) {
  // X \in R^{B, N, D}, C \in R^{K, D}, STD \in R^{K, D}
  auto KD_ = torch::zeros({X_.size(0), X_.size(1), C_.size(0)}, X_.options());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks(C_.size(0), X_.size(1), X_.size(0));
  dim3 threads(getNumThreads(C_.size(1)));
  // calculate the kernel distance
  AT_DISPATCH_FLOATING_TYPES(X_.type(), "Encoding_Dist_Inference_Forward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> KD = devicetensor<scalar_t, 3>(KD_);
    DeviceTensor<scalar_t, 3> X = devicetensor<scalar_t, 3>(X_);
    DeviceTensor<scalar_t, 2> C = devicetensor<scalar_t, 2>(C_);
    DeviceTensor<scalar_t, 2> STD = devicetensor<scalar_t, 2>(STD_);
    /* kernel function */
    Encoding_Dist_Forward_kernel<scalar_t, scalar_t>
      <<<blocks, threads, 0, stream>>> (KD, X, C, STD);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return KD_;
}

std::vector<at::Tensor> Encoding_Dist_Inference_Backward_CUDA(
    const at::Tensor GKD_,
    const at::Tensor KD_,
    const at::Tensor X_,
    const at::Tensor C_,
    const at::Tensor STD_) {
  auto GX_ = at::zeros_like(X_);
  auto GC_ = at::zeros_like(C_);
  auto GSTD_ = at::zeros_like(STD_);
  /* kernel function */
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks1(X_.size(2), X_.size(1), X_.size(0));
  dim3 threads1(getNumThreads(C_.size(0)));
  dim3 blocks2(C_.size(1), C_.size(0));
  dim3 threads2(getNumThreads(X_.size(1)));
  AT_DISPATCH_FLOATING_TYPES(X_.type(), "Encoding_Dist_Inference_Backward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> GKD = devicetensor<scalar_t, 3>(GKD_);
    DeviceTensor<scalar_t, 2> GSTD = devicetensor<scalar_t, 2>(GSTD_);
    DeviceTensor<scalar_t, 3> GX = devicetensor<scalar_t, 3>(GX_);
    DeviceTensor<scalar_t, 2> GC = devicetensor<scalar_t, 2>(GC_);
    DeviceTensor<scalar_t, 3> X = devicetensor<scalar_t, 3>(X_);
    DeviceTensor<scalar_t, 2> C = devicetensor<scalar_t, 2>(C_);
    DeviceTensor<scalar_t, 2> STD = devicetensor<scalar_t, 2>(STD_);
    Encoding_GradX_kernel<scalar_t, scalar_t>
      <<<blocks1, threads1, 0, stream>>> (GKD, GX, X, C, STD);
    AT_ASSERT(cudaGetLastError() == cudaSuccess);
    Encoding_GradCSTD_kernel<scalar_t, scalar_t>
      <<<blocks2, threads2, 0, stream>>> (GKD, GC, GSTD, X, C, STD);
    AT_ASSERT(cudaGetLastError() == cudaSuccess);
  }));
  return {GX_, GC_, GSTD_};
}
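
/*
 * Training-time variant: STD is estimated from the current batch. Per cluster
 * k and channel d the squared deviation is expanded as
 *   sum_{b,i} (x - c)^2 = sum x^2 - 2 * c * sum x + N * c^2,
 * which is SVar_ below. The kernel normalizes with the biased estimate
 * STD_ = sqrt(SVar_ / N + eps); SVar_ / (N - 1) is also returned as the
 * unbiased sample variance.
 */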
std::vector<at::Tensor> Encoding_Dist_Forward_CUDA(
    const at::Tensor X_,
    const at::Tensor C_,
    double eps) {
  // X \in R^{B, N, D}, C \in R^{K, D}
  auto KD_ = torch::zeros({X_.size(0), X_.size(1), C_.size(0)}, X_.options());
  int N = X_.size(0) * X_.size(1);
  // per-cluster squared deviation from E(x), E(x^2)
  auto SVar_ = (X_.pow(2).sum(0).sum(0).view({1, X_.size(2)}) -
                2 * C_ * X_.sum(0).sum(0).view({1, X_.size(2)})).expand_as(C_) +
               C_.pow(2) * N;
  auto STD_ = at::sqrt(SVar_ / N + eps);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks(C_.size(0), X_.size(1), X_.size(0));
  dim3 threads(getNumThreads(C_.size(1)));
  // calculate the kernel distance
  AT_DISPATCH_FLOATING_TYPES(X_.type(), "Encoding_Dist_Forward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> KD = devicetensor<scalar_t, 3>(KD_);
    DeviceTensor<scalar_t, 3> X = devicetensor<scalar_t, 3>(X_);
    DeviceTensor<scalar_t, 2> C = devicetensor<scalar_t, 2>(C_);
    DeviceTensor<scalar_t, 2> STD = devicetensor<scalar_t, 2>(STD_);
    /* kernel function */
    Encoding_Dist_Forward_kernel<scalar_t, scalar_t>
      <<<blocks, threads, 0, stream>>> (KD, X, C, STD);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return {KD_, STD_, SVar_ / (N - 1)};
}

std::vector<at::Tensor> Encoding_Dist_Backward_CUDA(
    const at::Tensor GKD_,
    const at::Tensor GSTD_,
    const at::Tensor KD_,
    const at::Tensor X_,
    const at::Tensor C_,
    const at::Tensor STD_) {
  auto GX_ = at::zeros_like(X_);
  auto GC_ = at::zeros_like(C_);
  auto GSTD2_ = GSTD_.clone();
  /* kernel function */
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks1(X_.size(2), X_.size(1), X_.size(0));
  dim3 threads1(getNumThreads(C_.size(0)));
  dim3 blocks2(C_.size(1), C_.size(0));
  dim3 threads2(getNumThreads(X_.size(1)));
  int N = X_.size(0) * X_.size(1);
  AT_DISPATCH_FLOATING_TYPES(X_.type(), "Encoding_Dist_Backward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> GKD = devicetensor<scalar_t, 3>(GKD_);
    DeviceTensor<scalar_t, 2> GSTD = devicetensor<scalar_t, 2>(GSTD2_);
    DeviceTensor<scalar_t, 3> GX = devicetensor<scalar_t, 3>(GX_);
    DeviceTensor<scalar_t, 2> GC = devicetensor<scalar_t, 2>(GC_);
    DeviceTensor<scalar_t, 3> X = devicetensor<scalar_t, 3>(X_);
    DeviceTensor<scalar_t, 2> C = devicetensor<scalar_t, 2>(C_);
    DeviceTensor<scalar_t, 2> STD = devicetensor<scalar_t, 2>(STD_);
    Encoding_GradX_kernel<scalar_t, scalar_t>
      <<<blocks1, threads1, 0, stream>>> (GKD, GX, X, C, STD);
    AT_ASSERT(cudaGetLastError() == cudaSuccess);
    Encoding_GradCSTD_kernel<scalar_t, scalar_t>
      <<<blocks2, threads2, 0, stream>>> (GKD, GC, GSTD, X, C, STD);
    AT_ASSERT(cudaGetLastError() == cudaSuccess);
    Encoding_GradSTDX_kernel<scalar_t, scalar_t>
      <<<blocks1, threads1, 0, stream>>> (GSTD, GX, X, C, STD, N);
    AT_ASSERT(cudaGetLastError() == cudaSuccess);
  }));
  // d_sigma/d_c
  GC_ = GC_ - GSTD2_ * (X_.mean(0).mean(0) - C_) / STD_;
  return {GX_, GC_};
}

at::Tensor AggregateV2_Forward_CUDA(
    const at::Tensor A_,
    const at::Tensor X_,
    const at::Tensor C_,
    const at::Tensor STD_) {
  auto E_ = torch::zeros({A_.size(0), C_.size(0), C_.size(1)}, A_.options());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // grid: (D, K, B); threads reduce over N
  dim3 blocks(C_.size(1), C_.size(0), X_.size(0));
  dim3 threads(getNumThreads(X_.size(1)));
  AT_DISPATCH_FLOATING_TYPES(A_.type(), "AggregateV2_Forward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> E = devicetensor<scalar_t, 3>(E_);
    DeviceTensor<scalar_t, 3> A = devicetensor<scalar_t, 3>(A_);
    DeviceTensor<scalar_t, 3> X = devicetensor<scalar_t, 3>(X_);
    DeviceTensor<scalar_t, 2> C = devicetensor<scalar_t, 2>(C_);
    DeviceTensor<scalar_t, 2> STD = devicetensor<scalar_t, 2>(STD_);
    /* kernel function */
    AggregateV2_Forward_kernel<scalar_t, scalar_t>
      <<<blocks, threads, 0, stream>>> (E, A, X, C, STD);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return E_;
}
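
/*
 * AggregateV2 backward: only dL/dA needs a custom kernel (a reduction over the
 * D channels, see AggV2BackOp). Because
 *   E[b][k][d] = sum_i A[b][i][k] * (X[b][i][d] - C[k][d]) / STD[k][d],
 * the remaining gradients reduce to plain ATen expressions, used below:
 *   dL/dX[b][i][d] =  sum_k A[b][i][k] * GE[b][k][d] / STD[k][d]   (batched matmul)
 *   dL/dC[k][d]    = -sum_b (sum_i A[b][i][k]) * GE[b][k][d] / STD[k][d]
 *   dL/dSTD[k][d]  = -sum_b GE[b][k][d] * E[b][k][d] / STD[k][d]
 */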
std::vector<at::Tensor> AggregateV2_Backward_CUDA(
    const at::Tensor GE_,
    const at::Tensor E_,
    const at::Tensor A_,
    const at::Tensor X_,
    const at::Tensor C_,
    const at::Tensor STD_) {
  auto gradA_ = at::zeros_like(A_);
  auto gradX_ = at::bmm(A_, (GE_ / STD_.unsqueeze(0)));
  auto gradC_ = -(A_.sum(1).unsqueeze(2) * GE_ / STD_.unsqueeze(0)).sum(0);
  auto gradSTD_ = -(GE_ * E_).sum(0) / STD_;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // grid: (K, N, B); threads reduce over D
  dim3 blocks(C_.size(0), X_.size(1), X_.size(0));
  dim3 threads(getNumThreads(C_.size(1)));
  AT_DISPATCH_FLOATING_TYPES(A_.type(), "AggregateV2_Backward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> GA = devicetensor<scalar_t, 3>(gradA_);
    DeviceTensor<scalar_t, 3> GE = devicetensor<scalar_t, 3>(GE_);
    DeviceTensor<scalar_t, 3> A = devicetensor<scalar_t, 3>(A_);
    DeviceTensor<scalar_t, 3> X = devicetensor<scalar_t, 3>(X_);
    DeviceTensor<scalar_t, 2> C = devicetensor<scalar_t, 2>(C_);
    DeviceTensor<scalar_t, 2> STD = devicetensor<scalar_t, 2>(STD_);
    AggregateV2_Backward_kernel<scalar_t, scalar_t>
      <<<blocks, threads, 0, stream>>> (GA, GE, A, X, C, STD);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return {gradA_, gradX_, gradC_, gradSTD_};
}
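
/*
 * Illustrative only: a minimal sketch of how these entry points could be
 * exposed as a PyTorch C++ extension. The project's real bindings live in a
 * separate translation unit; the exported names and the
 * ENCODING_BINDINGS_EXAMPLE guard below are assumptions made for this sketch,
 * not part of this file's build.
 */
#ifdef ENCODING_BINDINGS_EXAMPLE
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Python-side names are placeholders for illustration.
  m.def("encoding_dist_forward", &Encoding_Dist_Forward_CUDA,
        "Kernel distance with batch STD estimate (CUDA)");
  m.def("encoding_dist_backward", &Encoding_Dist_Backward_CUDA,
        "Backward of encoding_dist_forward (CUDA)");
  m.def("encoding_dist_inference", &Encoding_Dist_Inference_Forward_CUDA,
        "Kernel distance with fixed STD (CUDA)");
  m.def("encoding_dist_inference_backward", &Encoding_Dist_Inference_Backward_CUDA,
        "Backward of encoding_dist_inference (CUDA)");
  m.def("aggregate_v2_forward", &AggregateV2_Forward_CUDA,
        "Assignment-weighted aggregation (CUDA)");
  m.def("aggregate_v2_backward", &AggregateV2_Backward_CUDA,
        "Backward of aggregate_v2_forward (CUDA)");
}
#endif  // ENCODING_BINDINGS_EXAMPLE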