knlist.br 5.99 KB
Newer Older
Mark Friedrichs's avatar
Mark Friedrichs committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237

/****************************************************************
* This file is part of the gpu acceleration library for gromacs.
* Author: V. Vishal
* Copyright (C) Pande Group, Stanford, 2006
*****************************************************************/

/* Order N^2 neighbor-list construction.
 *
 * One output pixel per i atom (wpos is the i atom's texel in posq). The
 * kernel scans j atoms in increasing linear index (row * AtomStrWidth +
 * column) and records each j that is not marked in excl and lies strictly
 * inside the cutoff (|r_j - r_i|^2 < cutoff2). Neighbors are stored as
 * float linear atom indices, four per output: curpass0.xyzw, curpass1.xyzw,
 * curpass2.xyzw, curpass3.xyzw -- at most 16 per pass. Unused slots are -1.
 *
 * The scan is resumable across passes: storing the 16th neighbor (in
 * curpass3.w) stops the scan, so curpass3.w holds the linear index of the
 * last j handled. The next pass, given curpass3 back as prevpass3, resumes
 * at linear index prevpass3.w + 1. If a pass reaches the end of the atom
 * texture before finding 16 neighbors, curpass3.w stays -1; a following
 * pass with first < 0 then does no work and emits all -1.
 * NOTE(review): the first pass (first > 0) presumably needs prevpass3.w == -1
 * so that the scan starts at j = 0 -- confirm against the host code.
 *
 * This is a brute-force search that ignores groups of atoms, so it only
 * works for force fields without charge groups; with charge groups,
 * appropriate masks would have to be passed in. It is a simplified kernel
 * for testing O(N) kernel speeds and will most likely prove inefficient.
 */

kernel void knborsearch(
		float first,          //Positive on the first pass (constructing the first 16); negative on continuation passes.
		iter float2 wpos<>,   //pixel position of output (== texel of the i atom)
		float AtomStrHeight,  //height of the atom texture (rows)
		float AtomStrWidth,   //width of the atom texture (columns)
		float cutoff2,      //square of the cutoff
		float natoms,       //number of atoms. NOTE(review): unused -- the scan covers the whole AtomStrHeight x AtomStrWidth texture, including any padding texels.
		float excl[][],    //exclusions, one float per pair, gathered at (linear i, linear j): 0 means not excluded, 1 means excluded.
		float4 posq[][],    //atom positions (xyz, read here) / charges (presumably w, not read here)
		float4 prevpass3<>, //Last output texture of previous pass; .w is the linear j index to resume after
		out float4 curpass0<>,  //Neighbor slots 0-3 of the current pass
		out float4 curpass1<>,  //Neighbor slots 4-7
		out float4 curpass2<>,  //Neighbor slots 8-11
		out float4 curpass3<>  //Neighbor slots 12-15; used as prevpass3 in the next pass
		){
	/*For this kernel, wpos == iatom*/
	float2 iind;       //texel of the i atom
	float2 jind;       //texel of the current j atom
	float3 ipos, jpos, dr;
	float r2;
	float listptr; //Where in the 16-chunk are we now.
	float jlinind;     //linear index of the current j atom
	float breakflag; //positive means keep looping, negative means stop
	float4 exclconst;  //NOTE(review): assigned below but never read
	float2 exclind;    //gather address into excl: (linear i, linear j)
	float exclusions;

	exclconst = float4( 2.0f, 3.0f, 5.0f, 7.0f );
	
	iind = wpos;

	//Linear index of the i atom is the x coordinate into excl
	exclind.x = iind.x + iind.y * AtomStrWidth;

	//Fetch the i atom's position
	ipos = posq[ iind ].xyz;

	//Resume the j scan just after the last j handled by the previous pass,
	//converting the linear index into (column, row) texel coordinates
	jlinind = prevpass3.w + 1;
	jind.y = floor( jlinind / AtomStrWidth );
	jind.x = fmod( jlinind, AtomStrWidth );
	exclind.y = jlinind;

	//All outputs are initialized to -1 (no neighbor) below

	listptr = 0.0f;
	breakflag = 1.0f;

	//if we already finished, do nothing
	if ( first < 0.0f && prevpass3.w < 0.0f )
		breakflag = -1.0f;

	//set to -1 to indicate no neighbor
	//just to save a separate set of init calls
	curpass0 = float4( -1.0f, -1.0f, -1.0f, -1.0f );
	curpass1 = curpass0;
	curpass2 = curpass0;
	curpass3 = curpass0;
	
	//Row-by-row scan of the atom texture starting at (jind.x, jind.y);
	//stops early once 16 neighbors have been stored (breakflag < 0)
	while ( jind.y < AtomStrHeight && breakflag > 0.0f ) {
		while ( jind.x < AtomStrWidth && breakflag > 0.0f ) {

			//First see if this pair is excluded
			//NOTE(review): i == j is only skipped if excl marks that pair;
			//otherwise r2 == 0 < cutoff2 and the atom lists itself.
			exclusions = excl[ exclind ];

			if ( exclusions < 0.5f ) {
				jpos = posq[ jind ].xyz;

				dr = jpos - ipos;
				r2 = dot( dr, dr );

				//If it is inside the cutoff
				if ( r2 < cutoff2 ) {
					//Store jlinind in slot listptr (0..15) of curpass0..3
					//Figure out where to put it
					//We are allowed 4 nested conditionals
					//(presumably a shader-compiler limit -- confirm)
					//We can play with the structuring of these
					if ( listptr < 0.5f )
						curpass0.x = jlinind;
					else if ( listptr < 1.5f )
						curpass0.y = jlinind;
					else if ( listptr < 2.5f )
						curpass0.z = jlinind;
					else if ( listptr < 3.5f )
						curpass0.w = jlinind;
					else if ( listptr < 4.5f )
						curpass1.x = jlinind;
					else if ( listptr < 5.5f )
						curpass1.y = jlinind;
					else if ( listptr < 6.5f )
						curpass1.z = jlinind;
					else if ( listptr < 7.5f )
						curpass1.w = jlinind;
					else if ( listptr < 8.5f )
						curpass2.x = jlinind;
					else if ( listptr < 9.5f )
						curpass2.y = jlinind;
					else if ( listptr < 10.5f )
						curpass2.z = jlinind;
					else if ( listptr < 11.5f )
						curpass2.w = jlinind;
					else if ( listptr < 12.5f )
						curpass3.x = jlinind;
					else if ( listptr < 13.5f )
						curpass3.y = jlinind;
					else if ( listptr < 14.5f ) {
						curpass3.z = jlinind;
					}
					else if ( listptr < 15.5f ) {
						//We're done for this pass: the 16th slot is filled,
						//and curpass3.w is where the next pass resumes
						curpass3.w = jlinind;
						breakflag = -1.0f;
					}
					listptr += 1.0f;
				}
			}

			//Advance to the next j; exclind.y tracks jlinind
			jlinind += 1.0f;
			exclind.y += 1.0f;
			jind.x  += 1.0f;
		}
		//Wrap to the first column of the next texture row
		jind.x = 0.0f;
		jind.y += 1.0f;
	}
	
}

// Precompute the Lennard-Jones pair parameters for the eight neighbors
// listed in nlist0/nlist1, so the force kernel can read them directly
// instead of doing an indirect fetch (and a few flops) per pair:
//   sig = sigma_i + sigma_j      (combined by addition)
//   eps = epsilon_i * epsilon_j  (combined by multiplication)
// The charge product is not precomputed this way because the charges have
// to be fetched together with the positions anyway.
//
// Neighbor entries are float linear atom indices into a texture of width
// AtomStrWidth; each is split into (column, row) before gathering sigeps.
// NOTE(review): neighbor slots holding -1 (the "no neighbor" marker of the
// neighbor-search kernel) map to row -1 here, so those outputs are
// presumably meaningless and must be masked by the consumer -- confirm.
kernel void knl_precompute_sigeps(
		float AtomStrWidth,
		iter float2 wpos<>,
		float2 sigeps[][], //x=sigma, y=epsilon
		float4 nlist0<>,
		float4 nlist1<>,
		out float4 sig0<>,
		out float4 eps0<>,
		out float4 sig1<>,
		out float4 eps1<>
	   	)
{
	float2 own;                 // (sigma, epsilon) of this atom
	float4 row, col;            // texel coordinates of four neighbors
	float2 texel;               // gather address for one neighbor
	float2 nbA, nbB, nbC, nbD;  // (sigma, epsilon) of four neighbors

	own = sigeps[ wpos ];

	/* ---- neighbor slots 0..3 (nlist0) ---- */
	row = floor( nlist0 / AtomStrWidth );
	col = nlist0 - row * AtomStrWidth;

	texel = float2( col.x, row.x );
	nbA = sigeps[ texel ];
	texel = float2( col.y, row.y );
	nbB = sigeps[ texel ];
	texel = float2( col.z, row.z );
	nbC = sigeps[ texel ];
	texel = float2( col.w, row.w );
	nbD = sigeps[ texel ];

	sig0 = own.x + float4( nbA.x, nbB.x, nbC.x, nbD.x );
	eps0 = own.y * float4( nbA.y, nbB.y, nbC.y, nbD.y );

	/* ---- neighbor slots 4..7 (nlist1) ---- */
	row = floor( nlist1 / AtomStrWidth );
	col = nlist1 - row * AtomStrWidth;

	texel = float2( col.x, row.x );
	nbA = sigeps[ texel ];
	texel = float2( col.y, row.y );
	nbB = sigeps[ texel ];
	texel = float2( col.z, row.z );
	nbC = sigeps[ texel ];
	texel = float2( col.w, row.w );
	nbD = sigeps[ texel ];

	sig1 = own.x + float4( nbA.x, nbB.x, nbC.x, nbD.x );
	eps1 = own.y * float4( nbA.y, nbB.y, nbC.y, nbD.y );
}