Commit 54b7eec9 authored by Peter Eastman
Browse files

Reduced memory use for random numbers, eliminated unnecessary calculation when...

Reduced memory use for random numbers, eliminated unnecessary calculation when generating random numbers
parent 9b083022
......@@ -798,7 +798,7 @@ void CudaIntegrationUtilities::initRandomNumberGenerator(unsigned int randomNumb
// Create the random number arrays.
lastSeed = randomNumberSeed;
random = CudaArray::create<float4>(context, 32*context.getPaddedNumAtoms(), "random");
random = CudaArray::create<float4>(context, 4*context.getPaddedNumAtoms(), "random");
randomSeed = CudaArray::create<int4>(context, context.getNumThreadBlocks()*CudaContext::ThreadBlockSize, "randomSeed");
randomPos = random->getSize();
......
......@@ -8,7 +8,7 @@ extern "C" __global__ void generateRandomNumbers(int numValues, float4* __restri
while (index < numValues) {
float4 value;
// Generate first value.
// Generate first two values.
state.x = state.x * 69069 + 1;
state.y ^= state.y << 13;
......@@ -32,8 +32,9 @@ extern "C" __global__ void generateRandomNumbers(int numValues, float4* __restri
carry = k >> 30;
float x2 = (float)(state.x + state.y + state.w) / (float)0xffffffff;
value.x = x1 * COS(2.0f * 3.14159265f * x2);
value.y = x1 * SIN(2.0f * 3.14159265f * x2);
// Generate second value.
// Generate next two values.
state.x = state.x * 69069 + 1;
state.y ^= state.y << 13;
......@@ -56,57 +57,8 @@ extern "C" __global__ void generateRandomNumbers(int numValues, float4* __restri
state.w = m;
carry = k >> 30;
float x4 = (float)(state.x + state.y + state.w) / (float)0xffffffff;
value.y = x3 * COS(2.0f * 3.14159265f * x4);
// Generate third value.
state.x = state.x * 69069 + 1;
state.y ^= state.y << 13;
state.y ^= state.y >> 17;
state.y ^= state.y << 5;
k = (state.z >> 2) + (state.w >> 3) + (carry >> 2);
m = state.w + state.w + state.z + carry;
state.z = state.w;
state.w = m;
carry = k >> 30;
float x5 = (float)max(state.x + state.y + state.w, 0x00000001u) / (float)0xffffffff;
state.x = state.x * 69069 + 1;
state.y ^= state.y << 13;
state.y ^= state.y >> 17;
state.y ^= state.y << 5;
x5 = SQRT(-2.0f * LOG(x5));
k = (state.z >> 2) + (state.w >> 3) + (carry >> 2);
m = state.w + state.w + state.z + carry;
state.z = state.w;
state.w = m;
carry = k >> 30;
float x6 = (float)(state.x + state.y + state.w) / (float)0xffffffff;
value.z = x5 * COS(2.0f * 3.14159265f * x6);
// Generate fourth value.
state.x = state.x * 69069 + 1;
state.y ^= state.y << 13;
state.y ^= state.y >> 17;
state.y ^= state.y << 5;
k = (state.z >> 2) + (state.w >> 3) + (carry >> 2);
m = state.w + state.w + state.z + carry;
state.z = state.w;
state.w = m;
carry = k >> 30;
float x7 = (float)max(state.x + state.y + state.w, 0x00000001u) / (float)0xffffffff;
state.x = state.x * 69069 + 1;
state.y ^= state.y << 13;
state.y ^= state.y >> 17;
state.y ^= state.y << 5;
x7 = SQRT(-2.0f * LOG(x7));
k = (state.z >> 2) + (state.w >> 3) + (carry >> 2);
m = state.w + state.w + state.z + carry;
state.z = state.w;
state.w = m;
carry = k >> 30;
float x8 = (float)(state.x + state.y + state.w) / (float)0xffffffff;
value.w = x7 * COS(2.0f * 3.14159265f * x8);
value.z = x3 * COS(2.0f * 3.14159265f * x4);
value.w = x3 * SIN(2.0f * 3.14159265f * x4);
// Record the values.
......
......@@ -895,7 +895,7 @@ void OpenCLIntegrationUtilities::initRandomNumberGenerator(unsigned int randomNu
// Create the random number arrays.
lastSeed = randomNumberSeed;
random = OpenCLArray::create<mm_float4>(context, 32*context.getPaddedNumAtoms(), "random");
random = OpenCLArray::create<mm_float4>(context, 4*context.getPaddedNumAtoms(), "random");
randomSeed = OpenCLArray::create<mm_int4>(context, context.getNumThreadBlocks()*OpenCLContext::ThreadBlockSize, "randomSeed");
randomPos = random->getSize();
......
......@@ -9,7 +9,7 @@ __kernel void generateRandomNumbers(int numValues, __global float4* restrict ran
while (index < numValues) {
float4 value;
// Generate first value.
// Generate first two values.
state.x = state.x * 69069 + 1;
state.y ^= state.y << 13;
......@@ -33,8 +33,9 @@ __kernel void generateRandomNumbers(int numValues, __global float4* restrict ran
carry = k >> 30;
float x2 = (float)(state.x + state.y + state.w) / (float)0xffffffff;
value.x = x1 * cos(2.0f * 3.14159265f * x2);
value.y = x1 * sin(2.0f * 3.14159265f * x2);
// Generate second value.
// Generate next two values.
state.x = state.x * 69069 + 1;
state.y ^= state.y << 13;
......@@ -57,57 +58,8 @@ __kernel void generateRandomNumbers(int numValues, __global float4* restrict ran
state.w = m;
carry = k >> 30;
float x4 = (float)(state.x + state.y + state.w) / (float)0xffffffff;
value.y = x3 * cos(2.0f * 3.14159265f * x4);
// Generate third value.
state.x = state.x * 69069 + 1;
state.y ^= state.y << 13;
state.y ^= state.y >> 17;
state.y ^= state.y << 5;
k = (state.z >> 2) + (state.w >> 3) + (carry >> 2);
m = state.w + state.w + state.z + carry;
state.z = state.w;
state.w = m;
carry = k >> 30;
float x5 = (float)max(state.x + state.y + state.w, 0x00000001u) / (float)0xffffffff;
state.x = state.x * 69069 + 1;
state.y ^= state.y << 13;
state.y ^= state.y >> 17;
state.y ^= state.y << 5;
x5 = SQRT(-2.0f * LOG(x5));
k = (state.z >> 2) + (state.w >> 3) + (carry >> 2);
m = state.w + state.w + state.z + carry;
state.z = state.w;
state.w = m;
carry = k >> 30;
float x6 = (float)(state.x + state.y + state.w) / (float)0xffffffff;
value.z = x5 * cos(2.0f * 3.14159265f * x6);
// Generate fourth value.
state.x = state.x * 69069 + 1;
state.y ^= state.y << 13;
state.y ^= state.y >> 17;
state.y ^= state.y << 5;
k = (state.z >> 2) + (state.w >> 3) + (carry >> 2);
m = state.w + state.w + state.z + carry;
state.z = state.w;
state.w = m;
carry = k >> 30;
float x7 = (float)max(state.x + state.y + state.w, 0x00000001u) / (float)0xffffffff;
state.x = state.x * 69069 + 1;
state.y ^= state.y << 13;
state.y ^= state.y >> 17;
state.y ^= state.y << 5;
x7 = SQRT(-2.0f * LOG(x7));
k = (state.z >> 2) + (state.w >> 3) + (carry >> 2);
m = state.w + state.w + state.z + carry;
state.z = state.w;
state.w = m;
carry = k >> 30;
float x8 = (float)(state.x + state.y + state.w) / (float)0xffffffff;
value.w = x7 * cos(2.0f * 3.14159265f * x8);
value.z = x3 * cos(2.0f * 3.14159265f * x4);
value.w = x3 * sin(2.0f * 3.14159265f * x4);
// Record the values.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment