Commit d91e3f19 authored by peastman
Browse files

Improved logic for selecting a platform to better deal with exceptions in context initialization

parent fc1d0b5c
......@@ -39,6 +39,7 @@
#include "openmm/State.h"
#include "openmm/VirtualSite.h"
#include "openmm/Context.h"
#include <algorithm>
#include <iostream>
#include <map>
#include <utility>
......@@ -104,14 +105,40 @@ ContextImpl::ContextImpl(Context& owner, const System& system, Integrator& integ
hasInitializedForces = true;
vector<string> integratorKernels = integrator.getKernelNames();
kernelNames.insert(kernelNames.begin(), integratorKernels.begin(), integratorKernels.end());
if (platform == 0)
this->platform = platform = &Platform::findPlatform(kernelNames);
else if (!platform->supportsKernels(kernelNames))
// Select a platform to use.
vector<pair<double, Platform*> > candidatePlatforms;
if (platform == NULL) {
for (int i = 0; i < Platform::getNumPlatforms(); i++) {
Platform& p = Platform::getPlatform(i);
if (p.supportsKernels(kernelNames))
candidatePlatforms.push_back(make_pair(p.getSpeed(), &p));
}
if (candidatePlatforms.size() == 0)
throw OpenMMException("No Platform supports all the requested kernels");
sort(candidatePlatforms.begin(), candidatePlatforms.end());
}
else {
if (!platform->supportsKernels(kernelNames))
throw OpenMMException("Specified a Platform for a Context which does not support all required kernels");
candidatePlatforms.push_back(make_pair(platform->getSpeed(), platform));
}
for (int i = candidatePlatforms.size()-1; i >= 0; i--) {
try {
this->platform = platform = candidatePlatforms[i].second;
platform->contextCreated(*this, properties);
break;
}
catch (...) {
if (i > 0)
continue;
throw;
}
}
// Create and initialize kernels and other objects.
platform->contextCreated(*this, properties);
initializeForcesKernel = platform->createKernel(CalcForcesAndEnergyKernel::Name(), *this);
initializeForcesKernel.getAs<CalcForcesAndEnergyKernel>().initialize(system);
updateStateDataKernel = platform->createKernel(UpdateStateDataKernel::Name(), *this);
......
......@@ -153,9 +153,6 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking
CHECK_RESULT(cuDeviceGetAttribute(&multiprocessors, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
int numThreadBlocksPerComputeUnit = 6;
numThreadBlocks = numThreadBlocksPerComputeUnit*multiprocessors;
bonded = new CudaBondedUtilities(*this);
nonbonded = new CudaNonbondedUtilities(*this);
int numEnergyBuffers = max(numThreadBlocks*ThreadBlockSize, nonbonded->getNumEnergyBuffers());
if (useDoublePrecision) {
posq = CudaArray::create<double4>(*this, paddedNumAtoms, "posq");
velm = CudaArray::create<double4>(*this, paddedNumAtoms, "velm");
......@@ -166,9 +163,6 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking
compilationDefines["make_mixed2"] = "make_double2";
compilationDefines["make_mixed3"] = "make_double3";
compilationDefines["make_mixed4"] = "make_double4";
energyBuffer = CudaArray::create<double>(*this, numEnergyBuffers, "energyBuffer");
int pinnedBufferSize = max(paddedNumAtoms*4, numEnergyBuffers);
CHECK_RESULT(cuMemHostAlloc(&pinnedBuffer, pinnedBufferSize*sizeof(double), 0));
}
else if (useMixedPrecision) {
posq = CudaArray::create<float4>(*this, paddedNumAtoms, "posq");
......@@ -181,9 +175,6 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking
compilationDefines["make_mixed2"] = "make_double2";
compilationDefines["make_mixed3"] = "make_double3";
compilationDefines["make_mixed4"] = "make_double4";
energyBuffer = CudaArray::create<float>(*this, numEnergyBuffers, "energyBuffer");
int pinnedBufferSize = max(paddedNumAtoms*4, numEnergyBuffers);
CHECK_RESULT(cuMemHostAlloc(&pinnedBuffer, pinnedBufferSize*sizeof(double), 0));
}
else {
posq = CudaArray::create<float4>(*this, paddedNumAtoms, "posq");
......@@ -194,9 +185,6 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking
compilationDefines["make_mixed2"] = "make_float2";
compilationDefines["make_mixed3"] = "make_float3";
compilationDefines["make_mixed4"] = "make_float4";
energyBuffer = CudaArray::create<float>(*this, numEnergyBuffers, "energyBuffer");
int pinnedBufferSize = max(paddedNumAtoms*6, numEnergyBuffers);
CHECK_RESULT(cuMemHostAlloc(&pinnedBuffer, pinnedBufferSize*sizeof(float), 0));
}
posCellOffsets.resize(paddedNumAtoms, make_int4(0, 0, 0, 0));
......@@ -233,6 +221,8 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking
// Create utilities objects.
bonded = new CudaBondedUtilities(*this);
nonbonded = new CudaNonbondedUtilities(*this);
integration = new CudaIntegrationUtilities(*this, system);
expression = new CudaExpressionUtilities(*this);
}
......@@ -280,6 +270,22 @@ CudaContext::~CudaContext() {
void CudaContext::initialize() {
cuCtxSetCurrent(context);
string errorMessage = "Error initializing Context";
int numEnergyBuffers = max(numThreadBlocks*ThreadBlockSize, nonbonded->getNumEnergyBuffers());
if (useDoublePrecision) {
energyBuffer = CudaArray::create<double>(*this, numEnergyBuffers, "energyBuffer");
int pinnedBufferSize = max(paddedNumAtoms*4, numEnergyBuffers);
CHECK_RESULT(cuMemHostAlloc(&pinnedBuffer, pinnedBufferSize*sizeof(double), 0));
}
else if (useMixedPrecision) {
energyBuffer = CudaArray::create<float>(*this, numEnergyBuffers, "energyBuffer");
int pinnedBufferSize = max(paddedNumAtoms*4, numEnergyBuffers);
CHECK_RESULT(cuMemHostAlloc(&pinnedBuffer, pinnedBufferSize*sizeof(double), 0));
}
else {
energyBuffer = CudaArray::create<float>(*this, numEnergyBuffers, "energyBuffer");
int pinnedBufferSize = max(paddedNumAtoms*6, numEnergyBuffers);
CHECK_RESULT(cuMemHostAlloc(&pinnedBuffer, pinnedBufferSize*sizeof(float), 0));
}
for (int i = 0; i < numAtoms; i++) {
double mass = system.getParticleMass(i);
if (useDoublePrecision || useMixedPrecision)
......
......@@ -1460,8 +1460,9 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon
int numParticles = force.getNumParticles();
sigmaEpsilon = CudaArray::create<float2>(cu, cu.getPaddedNumAtoms(), "sigmaEpsilon");
CudaArray& posq = cu.getPosq();
float4* posqf = (float4*) cu.getPinnedBuffer();
double4* posqd = (double4*) cu.getPinnedBuffer();
vector<double4> temp(posq.getSize());
float4* posqf = (float4*) &temp[0];
double4* posqd = (double4*) &temp[0];
vector<float2> sigmaEpsilonVector(cu.getPaddedNumAtoms(), make_float2(0, 0));
vector<vector<int> > exclusionList(numParticles);
double sumSquaredCharges = 0.0;
......@@ -1486,7 +1487,7 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon
exclusionList[exclusions[i].first].push_back(exclusions[i].second);
exclusionList[exclusions[i].second].push_back(exclusions[i].first);
}
posq.upload(cu.getPinnedBuffer());
posq.upload(&temp[0]);
sigmaEpsilon->upload(sigmaEpsilonVector);
bool useCutoff = (force.getNonbondedMethod() != NonbondedForce::NoCutoff);
bool usePeriodic = (force.getNonbondedMethod() != NonbondedForce::NoCutoff && force.getNonbondedMethod() != NonbondedForce::CutoffNonPeriodic);
......@@ -2410,8 +2411,9 @@ void CudaCalcGBSAOBCForceKernel::initialize(const System& system, const GBSAOBCF
cu.addAutoclearBuffer(*bornSum);
cu.addAutoclearBuffer(*bornForce);
CudaArray& posq = cu.getPosq();
float4* posqf = (float4*) cu.getPinnedBuffer();
double4* posqd = (double4*) cu.getPinnedBuffer();
vector<double4> temp(posq.getSize());
float4* posqf = (float4*) &temp[0];
double4* posqd = (double4*) &temp[0];
vector<float2> paramsVector(cu.getPaddedNumAtoms(), make_float2(1, 1));
const double dielectricOffset = 0.009;
for (int i = 0; i < force.getNumParticles(); i++) {
......@@ -2424,7 +2426,7 @@ void CudaCalcGBSAOBCForceKernel::initialize(const System& system, const GBSAOBCF
else
posqf[i] = make_float4(0, 0, 0, (float) charge);
}
posq.upload(cu.getPinnedBuffer());
posq.upload(&temp[0]);
params->upload(paramsVector);
prefactor = -ONE_4PI_EPS0*((1.0/force.getSoluteDielectric())-(1.0/force.getSolventDielectric()));
bool useCutoff = (force.getNonbondedMethod() != GBSAOBCForce::NoCutoff);
......
......@@ -174,6 +174,7 @@ CudaPlatform::PlatformData::PlatformData(ContextImpl* context, const System& sys
searchPos = nextPos+1;
}
devices.push_back(deviceIndexProperty.substr(searchPos));
try {
for (int i = 0; i < (int) devices.size(); i++) {
if (devices[i].length() > 0) {
unsigned int deviceIndex;
......@@ -183,6 +184,14 @@ CudaPlatform::PlatformData::PlatformData(ContextImpl* context, const System& sys
}
if (contexts.size() == 0)
contexts.push_back(new CudaContext(system, -1, blocking, precisionProperty, compilerProperty, tempProperty, *this));
}
catch (...) {
// If an exception was thrown, do our best to clean up memory.
for (int i = 0; i < (int) contexts.size(); i++)
delete contexts[i];
throw;
}
stringstream deviceIndex, deviceName;
for (int i = 0; i < (int) contexts.size(); i++) {
if (i > 0) {
......
......@@ -253,8 +253,6 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
paddedNumAtoms = TileSize*((numAtoms+TileSize-1)/TileSize);
numAtomBlocks = (paddedNumAtoms+(TileSize-1))/TileSize;
numThreadBlocks = numThreadBlocksPerComputeUnit*device.getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
bonded = new OpenCLBondedUtilities(*this);
nonbonded = new OpenCLNonbondedUtilities(*this);
if (useDoublePrecision) {
posq = OpenCLArray::create<mm_double4>(*this, paddedNumAtoms, "posq");
velm = OpenCLArray::create<mm_double4>(*this, paddedNumAtoms, "velm");
......@@ -343,6 +341,8 @@ OpenCLContext::OpenCLContext(const System& system, int platformIndex, int device
// Create utilities objects.
bonded = new OpenCLBondedUtilities(*this);
nonbonded = new OpenCLNonbondedUtilities(*this);
integration = new OpenCLIntegrationUtilities(*this, system);
expression = new OpenCLExpressionUtilities(*this);
}
......
......@@ -143,6 +143,7 @@ OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& p
searchPos = nextPos+1;
}
devices.push_back(deviceIndexProperty.substr(searchPos));
try {
for (int i = 0; i < (int) devices.size(); i++) {
if (devices[i].length() > 0) {
unsigned int deviceIndex;
......@@ -152,6 +153,14 @@ OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& p
}
if (contexts.size() == 0)
contexts.push_back(new OpenCLContext(system, platformIndex, -1, precisionProperty, *this));
}
catch (...) {
// If an exception was thrown, do our best to clean up memory.
for (int i = 0; i < (int) contexts.size(); i++)
delete contexts[i];
throw;
}
stringstream deviceIndex, deviceName;
for (int i = 0; i < (int) contexts.size(); i++) {
if (i > 0) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment