/*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 * Created by: Hang Zhang
 * ECE Department, Rutgers University
 * Email: zhang.hang@rutgers.edu
 * Copyright (c) 2017
 *
 * This source code is licensed under the MIT-style license found in the
 * LICENSE file in the root directory of this source tree 
 *+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 */
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/encoding_kernel.c"
#else

__global__ void Encoding_(Aggregate_Forward_kernel) (
	THCDeviceTensor<real, 3> E,
	THCDeviceTensor<real, 3> A,
	THCDeviceTensor<real, 4> R)
/*
 * Aggregate forward kernel: E[b][k][d] = sum_i A[b][i][k] * R[b][i][k][d].
 * One thread per output element; launch layout (see host wrapper):
 *   grid.z = batch b, (grid.x, block.x) = dim d, (grid.y, block.y) = codeword k.
 */
{
	/* map this thread onto one (batch, codeword, dim) output cell */
	const int batch = blockIdx.z;
	const int dim   = blockIdx.x * blockDim.x + threadIdx.x;
	const int code  = blockIdx.y * blockDim.y + threadIdx.y;
	/* grid tail guard: extra threads past the output extent do nothing */
	if (dim >= E.getSize(2) || code >= E.getSize(1)) return;
	/* accumulate assignment-weighted residuals over the N inputs */
	const int num = A.getSize(1);
	real acc = 0;
	for (int n = 0; n < num; ++n) {
		acc += A[batch][n][code].ldg() * R[batch][n][code][dim].ldg();
	}
	E[batch][code][dim] = acc;
}

void Encoding_(Aggregate_Forward)(THCState *state, THCTensor *E_, 
							THCTensor *A_, THCTensor *R_)
/*
 * Aggregate forward: E[b][k][d] = sum_i A[b][i][k] * R[b][i][k][d].
 * E_ : BxKxD output, A_ : BxNxK assignment weights, R_ : BxNxKxD residuals
 * (shapes inferred from the dim checks and the kernel's indexing).
 * Raises a TH error if any tensor has the wrong number of dimensions.
 */
{
	/* Check the GPU index and tensor dims */
	THCTensor_(checkGPU)(state, 3, E_, A_, R_);
	if (THCTensor_(nDimension)(state, E_) != 3 ||
			THCTensor_(nDimension)(state, A_) != 3 ||
			THCTensor_(nDimension)(state, R_) != 4)
		THError("Encoding: incorrect input dims. \n");
	/* Device tensors */
	THCDeviceTensor<real, 3> E = devicetensor<3>(state, E_);
	THCDeviceTensor<real, 3> A = devicetensor<3>(state, A_);
	THCDeviceTensor<real, 4> R = devicetensor<4>(state, R_);
	/* kernel function: 16x16 threads over (d, k), one grid layer per batch.
	 * Ceil-division avoids the old size/16+1 pattern, which launched a
	 * fully-idle extra block whenever the extent was a multiple of 16. */
	cudaStream_t stream = THCState_getCurrentStream(state);
	dim3 threads(16, 16);
	dim3 blocks((E.getSize(2) + threads.x - 1) / threads.x,
							(E.getSize(1) + threads.y - 1) / threads.y,
							E.getSize(0));
	Encoding_(Aggregate_Forward_kernel)<<<blocks, threads, 0, stream>>>(E, A, R);
	/* surface launch-configuration errors immediately */
	THCudaCheck(cudaGetLastError());
}

/*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/
__global__ void Encoding_(Aggregate_Backward_kernel) (
	THCDeviceTensor<real, 3> GA,
	THCDeviceTensor<real, 4> GR,
	THCDeviceTensor<real, 3> L,
	THCDeviceTensor<real, 3> A,
	THCDeviceTensor<real, 4> R)
/*
 * Aggregate backward kernel. Per (b, i, k) thread:
 *   GR[b][i][k][d] = L[b][k][d] * A[b][i][k]          (dl/dR)
 *   GA[b][i][k]    = sum_d L[b][k][d] * R[b][i][k][d] (dl/dA)
 * Launch layout: grid.z = batch b, (x) = codeword k, (y) = input i.
 * L (dl/dE), A and R are read-only here; GA and GR are outputs.
 */
{
  /* declarations of the variables */
  int b, k, d, i, D;
	real sum, l, a;
  /* Get the index and channels */ 
  b = blockIdx.z;
  i = blockIdx.y * blockDim.y + threadIdx.y;
  k = blockIdx.x * blockDim.x + threadIdx.x;
	D = L.getSize(2);
	/* boundary check for output G \in R^{BxNxKxD} */
	if (k >= GR.getSize(2) || i >= GR.getSize(1))	return;
	/* main operation: A[b][i][k] is loop-invariant, load it once through
	 * the read-only cache; likewise read each L[b][k][d] exactly once and
	 * reuse it for both gradients (the original fetched it twice). */
	a = A[b][i][k].ldg();
	sum = 0;
	for(d=0; d<D; d++) {
		l = L[b][k][d].ldg();
		GR[b][i][k][d] = l * a;
		sum += l * R[b][i][k][d].ldg();
	}
	GA[b][i][k] = sum;
}

void Encoding_(Aggregate_Backward)(THCState *state, THCTensor *GA_, 
 	THCTensor *GR_, THCTensor *L_, THCTensor *A_, THCTensor *R_)
/*
 * Aggregate backward to the assignment weights and residuals:
 *   GR (dl/dR, BxNxKxD) and GA (dl/dA, BxNxK) from L (dl/dE), A, R
 * (shapes inferred from the dim checks and the kernel's indexing).
 * Raises a TH error if any tensor has the wrong number of dimensions.
 */
{
	/* Check the GPU index and tensor dims*/
	THCTensor_(checkGPU)(state, 5, GA_, GR_, L_, A_, R_);
	if (THCTensor_(nDimension)(state, GA_) != 3 ||
			THCTensor_(nDimension)(state, GR_) != 4 ||
			THCTensor_(nDimension)(state, L_)  != 3 ||
			THCTensor_(nDimension)(state, A_)  != 3 ||
			THCTensor_(nDimension)(state, R_)  != 4)
		THError("Encoding: incorrect input dims. \n");
	/* Device tensors */
	THCDeviceTensor<real, 3> GA = devicetensor<3>(state, GA_);
	THCDeviceTensor<real, 4> GR = devicetensor<4>(state, GR_);
	THCDeviceTensor<real, 3> L = devicetensor<3>(state, L_);
	THCDeviceTensor<real, 3> A = devicetensor<3>(state, A_);
	THCDeviceTensor<real, 4> R = devicetensor<4>(state, R_);
	/* kernel function: 16x16 threads over (k, i), one grid layer per batch.
	 * Ceil-division avoids the old size/16+1 pattern, which launched a
	 * fully-idle extra block whenever the extent was a multiple of 16. */
	cudaStream_t stream = THCState_getCurrentStream(state);
	dim3 threads(16, 16);
	dim3 blocks((GA.getSize(2) + threads.x - 1) / threads.x,
							(GA.getSize(1) + threads.y - 1) / threads.y,
							GA.getSize(0));
	Encoding_(Aggregate_Backward_kernel)<<<blocks, threads, 0, stream>>>(GA,
					GR, L, A, R);
	/* surface launch-configuration errors immediately */
	THCudaCheck(cudaGetLastError());
}
#endif