"git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "1a47fa08adb9d1aafd8454c43c0265f952ccbda0"
encoding_kernel.c 3.53 KB
Newer Older
Hang Zhang's avatar
init  
Hang Zhang committed
1
2
3
4
5
6
7
8
9
10
11
/*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 * Created by: Hang Zhang
 * ECE Department, Rutgers University
 * Email: zhang.hang@rutgers.edu
 * Copyright (c) 2017
 *
 * This source code is licensed under the MIT-style license found in the
 * LICENSE file in the root directory of this source tree 
 *+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 */
#ifndef THC_GENERIC_FILE
Hang Zhang's avatar
tested  
Hang Zhang committed
12
#define THC_GENERIC_FILE "generic/encoding_kernel.c"
Hang Zhang's avatar
init  
Hang Zhang committed
13
14
15
16
17
18
#else

/*
 * Aggregating kernel function.
 *
 * Each in-range thread computes a single element of E:
 *
 *   E[b][k][d] = sum_{i = 0 .. A.getSize(1)-1}  A[b][i][k] * R[b][i][k][d]
 *
 * Thread mapping:
 *   blockIdx.z                 -> b (first index of E, A and R)
 *   x dimension of grid/block  -> d, bounded by E.getSize(2)
 *   y dimension of grid/block  -> k, bounded by E.getSize(1)
 *
 * Threads whose (k, d) falls outside E leave without touching memory, so
 * the grid may be larger than the output. No shared memory is used.
 */
__global__ void Encoding_(Aggregate_Forward_kernel) (
	THCDeviceTensor<real, 3> E,
	THCDeviceTensor<real, 3> A,
	THCDeviceTensor<real, 4> R)
{
	/* coordinates of the output element handled by this thread */
	const int batch = blockIdx.z;
	const int col   = blockIdx.x * blockDim.x + threadIdx.x;
	const int code  = blockIdx.y * blockDim.y + threadIdx.y;

	/* only threads that land on an element of E do any work */
	if (code < E.getSize(1) && col < E.getSize(2)) {
		/* number of terms in the reduction (second extent of A) */
		const int count = A.getSize(1);
		real acc = 0;
		for (int n = 0; n < count; ++n) {
			/* both operands are only read here, so go through ldg() */
			acc += A[batch][n][code].ldg() * R[batch][n][code][col].ldg();
		}
		E[batch][code][col] = acc;
	}
}

Hang Zhang's avatar
backend  
Hang Zhang committed
41
42
/*
 * Aggregating the residuals with assignment weights.
 *
 * Launches Aggregate_Forward_kernel on the current THC stream to fill
 *
 *   E[b][k][d] = sum_i A[b][i][k] * R[b][i][k][d]
 *
 * Arguments:
 *   state - THC state; supplies the current stream
 *   E_    - output, 3-D
 *   A_    - input,  3-D; its extents must be (E.size0, N, E.size1)
 *   R_    - input,  4-D; its extents must be (E.size0, N, E.size1, E.size2)
 *
 * Raises a THError when the ranks are wrong or when the extents listed
 * above disagree (the kernel would otherwise index past the end of A or R).
 * The launch is asynchronous; only launch-configuration errors are
 * reported here through cudaGetLastError().
 */
void Encoding_(Aggregate_Forward)(THCState *state, THCTensor *E_, 
							THCTensor *A_, THCTensor *R_)
{
	/* Check the GPU index and tensor dims */
	THCTensor_(checkGPU)(state, 3, E_, A_, R_);
	if (THCTensor_(nDimension)(state, E_) != 3 ||
			THCTensor_(nDimension)(state, A_) != 3 ||
			THCTensor_(nDimension)(state, R_) != 4)
		THError("Encoding: incorrect input dims. \n");
	/* Device tensors */
	THCDeviceTensor<real, 3> E = devicetensor<3>(state, E_);
	THCDeviceTensor<real, 3> A = devicetensor<3>(state, A_);
	THCDeviceTensor<real, 4> R = devicetensor<4>(state, R_);
	/* The kernel indexes A[b][i][k] and R[b][i][k][d] using the extents of
	 * E and A.getSize(1); reject shapes that would make it read out of
	 * bounds or silently ignore part of R. */
	if (A.getSize(0) != E.getSize(0) || A.getSize(2) != E.getSize(1) ||
			R.getSize(0) != E.getSize(0) || R.getSize(1) != A.getSize(1) ||
			R.getSize(2) != E.getSize(1) || R.getSize(3) != E.getSize(2))
		THError("Encoding: incompatible tensor sizes. \n");
	/* Nothing to write for an empty output; a zero-sized grid dimension
	 * would also be rejected as an invalid launch configuration. */
	if (E.getSize(0) == 0 || E.getSize(1) == 0 || E.getSize(2) == 0)
		return;
	/* kernel function */
	cudaStream_t stream = THCState_getCurrentStream(state);
	const int tile = 16;
	dim3 threads(tile, tile);
	/* ceil-div: just enough tile x tile blocks to cover E's last two dims */
	dim3 blocks((E.getSize(2) + tile - 1) / tile,
							(E.getSize(1) + tile - 1) / tile,
							E.getSize(0));
	Encoding_(Aggregate_Forward_kernel)<<<blocks, threads, 0, stream>>>(E, A, R);
	THCudaCheck(cudaGetLastError());
}

Hang Zhang's avatar
backend  
Hang Zhang committed
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
/*
 * Aggregating backward kernel function.
 *
 * Each in-range thread computes a single element of G:
 *
 *   G[b][i][k] = sum_{d = 0 .. L.getSize(2)-1}  L[b][k][d] * R[b][i][k][d]
 *
 * Thread mapping:
 *   blockIdx.z                 -> b (first index of G, L and R)
 *   x dimension of grid/block  -> k, bounded by G.getSize(2)
 *   y dimension of grid/block  -> i, bounded by G.getSize(1)
 *
 * Threads whose (i, k) falls outside G leave without touching memory, so
 * the grid may be larger than the output. No shared memory is used.
 */
__global__ void Encoding_(Aggregate_Backward_kernel) (
	THCDeviceTensor<real, 3> G,
	THCDeviceTensor<real, 3> L,
	THCDeviceTensor<real, 4> R)
{
	/* coordinates of the output element handled by this thread */
	const int batch = blockIdx.z;
	const int code  = blockIdx.x * blockDim.x + threadIdx.x;
	const int row   = blockIdx.y * blockDim.y + threadIdx.y;

	/* only threads that land on an element of G do any work */
	if (row < G.getSize(1) && code < G.getSize(2)) {
		/* number of terms in the reduction (last extent of L) */
		const int width = L.getSize(2);
		real acc = 0;
		for (int c = 0; c < width; ++c) {
			/* both operands are only read here, so go through ldg() */
			acc += L[batch][code][c].ldg() * R[batch][row][code][c].ldg();
		}
		G[batch][row][code] = acc;
	}
}

/*
 * Aggregate backward to assignment weights.
 *
 * Launches Aggregate_Backward_kernel on the current THC stream to fill
 *
 *   G[b][i][k] = sum_d L[b][k][d] * R[b][i][k][d]
 *
 * Arguments:
 *   state - THC state; supplies the current stream
 *   G_    - output, 3-D
 *   L_    - input,  3-D; its extents must be (G.size0, G.size2, D)
 *   R_    - input,  4-D; its extents must be (G.size0, G.size1, G.size2, D)
 *
 * Raises a THError when the ranks are wrong or when the extents listed
 * above disagree (the kernel would otherwise index past the end of L or R).
 * The launch is asynchronous; only launch-configuration errors are
 * reported here through cudaGetLastError().
 */
void Encoding_(Aggregate_Backward)(THCState *state, THCTensor *G_, 
							THCTensor *L_, THCTensor *R_)
{
	/* Check the GPU index and tensor dims */
	THCTensor_(checkGPU)(state, 3, G_, L_, R_);
	if (THCTensor_(nDimension)(state, G_) != 3 ||
			THCTensor_(nDimension)(state, L_) != 3 ||
			THCTensor_(nDimension)(state, R_) != 4)
		THError("Encoding: incorrect input dims. \n");
	/* Device tensors */
	THCDeviceTensor<real, 3> G = devicetensor<3>(state, G_);
	THCDeviceTensor<real, 3> L = devicetensor<3>(state, L_);
	THCDeviceTensor<real, 4> R = devicetensor<4>(state, R_);
	/* The kernel indexes L[b][k][d] and R[b][i][k][d] using the extents of
	 * G and L.getSize(2); reject shapes that would make it read out of
	 * bounds or silently ignore part of R. */
	if (L.getSize(0) != G.getSize(0) || L.getSize(1) != G.getSize(2) ||
			R.getSize(0) != G.getSize(0) || R.getSize(1) != G.getSize(1) ||
			R.getSize(2) != G.getSize(2) || R.getSize(3) != L.getSize(2))
		THError("Encoding: incompatible tensor sizes. \n");
	/* Nothing to write for an empty output; a zero-sized grid dimension
	 * would also be rejected as an invalid launch configuration. */
	if (G.getSize(0) == 0 || G.getSize(1) == 0 || G.getSize(2) == 0)
		return;
	/* kernel function */
	cudaStream_t stream = THCState_getCurrentStream(state);
	const int tile = 16;
	dim3 threads(tile, tile);
	/* ceil-div: just enough tile x tile blocks to cover G's last two dims */
	dim3 blocks((G.getSize(2) + tile - 1) / tile,
							(G.getSize(1) + tile - 1) / tile,
							G.getSize(0));
	Encoding_(Aggregate_Backward_kernel)<<<blocks, threads, 0, stream>>>(G, L, R);
	THCudaCheck(cudaGetLastError());
}

Hang Zhang's avatar
init  
Hang Zhang committed
117
#endif