Commit 74a8266f authored by Peter Eastman's avatar Peter Eastman
Browse files

Began implementing CudaPlatform

parent 50643058
...@@ -48,8 +48,9 @@ public: ...@@ -48,8 +48,9 @@ public:
* Create a KernelImpl. * Create a KernelImpl.
* *
* @param name the name of the kernel to create * @param name the name of the kernel to create
* @param context the context the kernel will belong to
*/ */
virtual KernelImpl* createKernelImpl(std::string name, const Platform& platform) const = 0; virtual KernelImpl* createKernelImpl(std::string name, const Platform& platform, OpenMMContextImpl& context) const = 0;
virtual ~KernelFactory() { virtual ~KernelFactory() {
} }
}; };
......
...@@ -40,6 +40,7 @@ namespace OpenMM { ...@@ -40,6 +40,7 @@ namespace OpenMM {
class Kernel; class Kernel;
class KernelFactory; class KernelFactory;
class OpenMMContextImpl;
class StreamFactory; class StreamFactory;
/** /**
...@@ -83,6 +84,16 @@ public: ...@@ -83,6 +84,16 @@ public:
* different StreamFactory has not been registered for the requested stream name. * different StreamFactory has not been registered for the requested stream name.
*/ */
virtual const StreamFactory& getDefaultStreamFactory() const = 0; virtual const StreamFactory& getDefaultStreamFactory() const = 0;
/**
* This is called whenever a new OpenMMContext is created. It gives the Platform a chance to initialize
* the context and store platform-specific data in it.
*/
virtual void contextCreated(OpenMMContextImpl& context) const;
/**
* This is called whenever an OpenMMContext is deleted. It gives the Platform a chance to clean up
* any platform-specific data that was stored in it.
*/
virtual void contextDestroyed(OpenMMContextImpl& context) const;
/** /**
* Register a KernelFactory which should be used to create Kernels with a particular name. * Register a KernelFactory which should be used to create Kernels with a particular name.
* The Platform takes over ownership of the factory, and will delete it when the Platform itself * The Platform takes over ownership of the factory, and will delete it when the Platform itself
...@@ -110,19 +121,20 @@ public: ...@@ -110,19 +121,20 @@ public:
*/ */
bool supportsKernels(std::vector<std::string> kernelNames) const ; bool supportsKernels(std::vector<std::string> kernelNames) const ;
/** /**
* Create a Kernel object. If you call this method multiple times with the same name, * Create a Kernel object. If you call this method multiple times for different contexts with the same name,
* the returned Kernels are independent and do not interact with each other. This means * the returned Kernels are independent and do not interact with each other. This means
* that it is possible to have multiple simulations in progress at one time without them * that it is possible to have multiple simulations in progress at one time without them
* interfering. * interfering.
* *
* If no KernelFactory has been registered for the specified name, this will throw an exception. * If no KernelFactory has been registered for the specified name, this will throw an exception.
* *
* @param the name of the Kernel to get * @param name the name of the Kernel to get
* @param context the context for which to create a Kernel
* @return a newly created Kernel object * @return a newly created Kernel object
*/ */
Kernel createKernel(std::string name) const; Kernel createKernel(std::string name, OpenMMContextImpl& context) const;
/** /**
* Create a Stream object. If you call this method multiple times with the same name, * Create a Stream object. If you call this method multiple times for different contexts with the same name,
* the returned Streams are independent and do not interact with each other. This means * the returned Streams are independent and do not interact with each other. This means
* that it is possible to have multiple simulations in progress at one time without them * that it is possible to have multiple simulations in progress at one time without them
* interfering. * interfering.
...@@ -130,10 +142,11 @@ public: ...@@ -130,10 +142,11 @@ public:
* If a StreamFactory has been registered for the specified name, it will be used to create * If a StreamFactory has been registered for the specified name, it will be used to create
* the Stream. Otherwise, the default StreamFactory will be used. * the Stream. Otherwise, the default StreamFactory will be used.
* *
* @param the name of the Stream to get * @param name the name of the Stream to get
* @param context the context for which to create a Stream
* @return a newly created Stream object * @return a newly created Stream object
*/ */
Stream createStream(std::string name, int size, Stream::DataType type) const; Stream createStream(std::string name, int size, Stream::DataType type, OpenMMContextImpl& context) const;
/** /**
* Register a new Platform. * Register a new Platform.
*/ */
......
...@@ -50,8 +50,9 @@ public: ...@@ -50,8 +50,9 @@ public:
* @param name the name of the stream to create * @param name the name of the stream to create
* @param size the number of elements in the stream * @param size the number of elements in the stream
* @param type the data type of each element in the stream * @param type the data type of each element in the stream
* @param context the context the kernel will belong to
*/ */
virtual StreamImpl* createStreamImpl(std::string name, int size, Stream::DataType type, const Platform& platform) const = 0; virtual StreamImpl* createStreamImpl(std::string name, int size, Stream::DataType type, const Platform& platform, OpenMMContextImpl& context) const = 0;
virtual ~StreamFactory() { virtual ~StreamFactory() {
} }
}; };
......
...@@ -68,6 +68,12 @@ Platform::~Platform() { ...@@ -68,6 +68,12 @@ Platform::~Platform() {
delete *iter; delete *iter;
} }
void Platform::contextCreated(OpenMMContextImpl& context) const {
}
void Platform::contextDestroyed(OpenMMContextImpl& context) const {
}
void Platform::registerKernelFactory(std::string name, KernelFactory* factory) { void Platform::registerKernelFactory(std::string name, KernelFactory* factory) {
kernelFactories[name] = factory; kernelFactories[name] = factory;
} }
...@@ -83,16 +89,16 @@ bool Platform::supportsKernels(std::vector<std::string> kernelNames) const { ...@@ -83,16 +89,16 @@ bool Platform::supportsKernels(std::vector<std::string> kernelNames) const {
return true; return true;
} }
Kernel Platform::createKernel(std::string name) const { Kernel Platform::createKernel(std::string name, OpenMMContextImpl& context) const {
if (kernelFactories.find(name) == kernelFactories.end()) if (kernelFactories.find(name) == kernelFactories.end())
throw PlatformException("Called createKernel() on a Platform which does not support the requested kernel"); throw PlatformException("Called createKernel() on a Platform which does not support the requested kernel");
return Kernel(kernelFactories.find(name)->second->createKernelImpl(name, *this)); return Kernel(kernelFactories.find(name)->second->createKernelImpl(name, *this, context));
} }
Stream Platform::createStream(std::string name, int size, Stream::DataType type) const { Stream Platform::createStream(std::string name, int size, Stream::DataType type, OpenMMContextImpl& context) const {
if (streamFactories.find(name) == streamFactories.end()) if (streamFactories.find(name) == streamFactories.end())
return Stream(getDefaultStreamFactory().createStreamImpl(name, size, type, *this)); return Stream(getDefaultStreamFactory().createStreamImpl(name, size, type, *this, context));
return Stream(streamFactories.find(name)->second->createStreamImpl(name, size, type, *this)); return Stream(streamFactories.find(name)->second->createStreamImpl(name, size, type, *this, context));
} }
void Platform::registerPlatform(Platform* platform) { void Platform::registerPlatform(Platform* platform) {
......
...@@ -148,6 +148,14 @@ public: ...@@ -148,6 +148,14 @@ public:
* Delete all ForceImpl objects that have been created and create new ones. * Delete all ForceImpl objects that have been created and create new ones.
*/ */
void reinitialize(); void reinitialize();
/**
* Get the platform-specific data stored in this context.
*/
void* getPlatformData();
/**
* Set the platform-specific data stored in this context.
*/
void setPlatformData(void* data);
private: private:
friend class OpenMMContext; friend class OpenMMContext;
OpenMMContext& owner; OpenMMContext& owner;
...@@ -159,6 +167,7 @@ private: ...@@ -159,6 +167,7 @@ private:
Platform* platform; Platform* platform;
Stream positions, velocities, forces; Stream positions, velocities, forces;
Kernel kineticEnergyKernel; Kernel kineticEnergyKernel;
void* platformData;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -43,7 +43,7 @@ AndersenThermostatImpl::AndersenThermostatImpl(AndersenThermostat& owner) : owne ...@@ -43,7 +43,7 @@ AndersenThermostatImpl::AndersenThermostatImpl(AndersenThermostat& owner) : owne
} }
void AndersenThermostatImpl::initialize(OpenMMContextImpl& context) { void AndersenThermostatImpl::initialize(OpenMMContextImpl& context) {
kernel = context.getPlatform().createKernel(ApplyAndersenThermostatKernel::Name()); kernel = context.getPlatform().createKernel(ApplyAndersenThermostatKernel::Name(), context);
const System& system = context.getSystem(); const System& system = context.getSystem();
vector<double> masses(system.getNumAtoms()); vector<double> masses(system.getNumAtoms());
for (int i = 0; i < system.getNumAtoms(); ++i) for (int i = 0; i < system.getNumAtoms(); ++i)
......
...@@ -47,7 +47,7 @@ BrownianIntegrator::BrownianIntegrator(double temperature, double frictionCoeff, ...@@ -47,7 +47,7 @@ BrownianIntegrator::BrownianIntegrator(double temperature, double frictionCoeff,
void BrownianIntegrator::initialize(OpenMMContextImpl& contextRef) { void BrownianIntegrator::initialize(OpenMMContextImpl& contextRef) {
context = &contextRef; context = &contextRef;
kernel = context->getPlatform().createKernel(IntegrateBrownianStepKernel::Name()); kernel = context->getPlatform().createKernel(IntegrateBrownianStepKernel::Name(), contextRef);
const System& system = context->getSystem(); const System& system = context->getSystem();
vector<double> masses(system.getNumAtoms()); vector<double> masses(system.getNumAtoms());
vector<std::vector<int> > constraintIndices(system.getNumConstraints()); vector<std::vector<int> > constraintIndices(system.getNumConstraints());
......
...@@ -43,7 +43,7 @@ CMMotionRemoverImpl::CMMotionRemoverImpl(CMMotionRemover& owner) : owner(owner) ...@@ -43,7 +43,7 @@ CMMotionRemoverImpl::CMMotionRemoverImpl(CMMotionRemover& owner) : owner(owner)
} }
void CMMotionRemoverImpl::initialize(OpenMMContextImpl& context) { void CMMotionRemoverImpl::initialize(OpenMMContextImpl& context) {
kernel = context.getPlatform().createKernel(RemoveCMMotionKernel::Name()); kernel = context.getPlatform().createKernel(RemoveCMMotionKernel::Name(), context);
const System& system = context.getSystem(); const System& system = context.getSystem();
vector<double> masses(system.getNumAtoms()); vector<double> masses(system.getNumAtoms());
for (int i = 0; i < system.getNumAtoms(); ++i) for (int i = 0; i < system.getNumAtoms(); ++i)
......
...@@ -41,7 +41,7 @@ GBSAOBCForceFieldImpl::GBSAOBCForceFieldImpl(GBSAOBCForceField& owner) : owner(o ...@@ -41,7 +41,7 @@ GBSAOBCForceFieldImpl::GBSAOBCForceFieldImpl(GBSAOBCForceField& owner) : owner(o
} }
void GBSAOBCForceFieldImpl::initialize(OpenMMContextImpl& context) { void GBSAOBCForceFieldImpl::initialize(OpenMMContextImpl& context) {
kernel = context.getPlatform().createKernel(CalcGBSAOBCForceFieldKernel::Name()); kernel = context.getPlatform().createKernel(CalcGBSAOBCForceFieldKernel::Name(), context);
vector<vector<double> > atomParameters(owner.getNumAtoms()); vector<vector<double> > atomParameters(owner.getNumAtoms());
for (int i = 0; i < owner.getNumAtoms(); ++i) { for (int i = 0; i < owner.getNumAtoms(); ++i) {
double charge, radius, scalingFactor; double charge, radius, scalingFactor;
......
...@@ -47,7 +47,7 @@ LangevinIntegrator::LangevinIntegrator(double temperature, double frictionCoeff, ...@@ -47,7 +47,7 @@ LangevinIntegrator::LangevinIntegrator(double temperature, double frictionCoeff,
void LangevinIntegrator::initialize(OpenMMContextImpl& contextRef) { void LangevinIntegrator::initialize(OpenMMContextImpl& contextRef) {
context = &contextRef; context = &contextRef;
kernel = context->getPlatform().createKernel(IntegrateLangevinStepKernel::Name()); kernel = context->getPlatform().createKernel(IntegrateLangevinStepKernel::Name(), contextRef);
const System& system = context->getSystem(); const System& system = context->getSystem();
vector<double> masses(system.getNumAtoms()); vector<double> masses(system.getNumAtoms());
vector<std::vector<int> > constraintIndices(system.getNumConstraints()); vector<std::vector<int> > constraintIndices(system.getNumConstraints());
......
...@@ -45,7 +45,7 @@ using std::vector; ...@@ -45,7 +45,7 @@ using std::vector;
using std::string; using std::string;
OpenMMContextImpl::OpenMMContextImpl(OpenMMContext& owner, System& system, Integrator& integrator, Platform* platform) : OpenMMContextImpl::OpenMMContextImpl(OpenMMContext& owner, System& system, Integrator& integrator, Platform* platform) :
owner(owner), system(system), integrator(integrator), platform(platform) { owner(owner), system(system), integrator(integrator), platform(platform), platformData(NULL) {
vector<string> kernelNames; vector<string> kernelNames;
kernelNames.push_back(CalcKineticEnergyKernel::Name()); kernelNames.push_back(CalcKineticEnergyKernel::Name());
for (int i = 0; i < system.getNumForces(); ++i) { for (int i = 0; i < system.getNumForces(); ++i) {
...@@ -61,12 +61,13 @@ OpenMMContextImpl::OpenMMContextImpl(OpenMMContext& owner, System& system, Integ ...@@ -61,12 +61,13 @@ OpenMMContextImpl::OpenMMContextImpl(OpenMMContext& owner, System& system, Integ
this->platform = platform = &Platform::findPlatform(kernelNames); this->platform = platform = &Platform::findPlatform(kernelNames);
else if (!platform->supportsKernels(kernelNames)) else if (!platform->supportsKernels(kernelNames))
throw OpenMMException("Specified a Platform for an OpenMMContext which does not support all required kernels"); throw OpenMMException("Specified a Platform for an OpenMMContext which does not support all required kernels");
positions = platform->createStream("atomPositions", system.getNumAtoms(), Stream::Double3); platform->contextCreated(*this);
velocities = platform->createStream("atomVelocities", system.getNumAtoms(), Stream::Double3); positions = platform->createStream("atomPositions", system.getNumAtoms(), Stream::Double3, *this);
forces = platform->createStream("atomForces", system.getNumAtoms(), Stream::Double3); velocities = platform->createStream("atomVelocities", system.getNumAtoms(), Stream::Double3, *this);
forces = platform->createStream("atomForces", system.getNumAtoms(), Stream::Double3, *this);
double zero[] = {0.0, 0.0, 0.0}; double zero[] = {0.0, 0.0, 0.0};
velocities.fillWithValue(&zero); velocities.fillWithValue(&zero);
kineticEnergyKernel = platform->createKernel(CalcKineticEnergyKernel::Name()); kineticEnergyKernel = platform->createKernel(CalcKineticEnergyKernel::Name(), *this);
vector<double> masses(system.getNumAtoms()); vector<double> masses(system.getNumAtoms());
for (int i = 0; i < masses.size(); ++i) for (int i = 0; i < masses.size(); ++i)
masses[i] = system.getAtomMass(i); masses[i] = system.getAtomMass(i);
...@@ -79,6 +80,7 @@ OpenMMContextImpl::OpenMMContextImpl(OpenMMContext& owner, System& system, Integ ...@@ -79,6 +80,7 @@ OpenMMContextImpl::OpenMMContextImpl(OpenMMContext& owner, System& system, Integ
OpenMMContextImpl::~OpenMMContextImpl() { OpenMMContextImpl::~OpenMMContextImpl() {
for (int i = 0; i < (int) forceImpls.size(); ++i) for (int i = 0; i < (int) forceImpls.size(); ++i)
delete forceImpls[i]; delete forceImpls[i];
platform->contextDestroyed(*this);
} }
double OpenMMContextImpl::getParameter(std::string name) { double OpenMMContextImpl::getParameter(std::string name) {
...@@ -126,3 +128,11 @@ void OpenMMContextImpl::reinitialize() { ...@@ -126,3 +128,11 @@ void OpenMMContextImpl::reinitialize() {
} }
integrator.initialize(*this); integrator.initialize(*this);
} }
void* OpenMMContextImpl::getPlatformData() {
return platformData;
}
void OpenMMContextImpl::setPlatformData(void* data) {
platformData = data;
}
...@@ -45,7 +45,7 @@ StandardMMForceFieldImpl::~StandardMMForceFieldImpl() { ...@@ -45,7 +45,7 @@ StandardMMForceFieldImpl::~StandardMMForceFieldImpl() {
} }
void StandardMMForceFieldImpl::initialize(OpenMMContextImpl& context) { void StandardMMForceFieldImpl::initialize(OpenMMContextImpl& context) {
kernel = context.getPlatform().createKernel(CalcStandardMMForceFieldKernel::Name()); kernel = context.getPlatform().createKernel(CalcStandardMMForceFieldKernel::Name(), context);
vector<vector<int> > bondIndices(owner.getNumBonds()); vector<vector<int> > bondIndices(owner.getNumBonds());
vector<vector<double> > bondParameters(owner.getNumBonds()); vector<vector<double> > bondParameters(owner.getNumBonds());
vector<vector<int> > angleIndices(owner.getNumAngles()); vector<vector<int> > angleIndices(owner.getNumAngles());
......
...@@ -45,7 +45,7 @@ VerletIntegrator::VerletIntegrator(double stepSize) { ...@@ -45,7 +45,7 @@ VerletIntegrator::VerletIntegrator(double stepSize) {
void VerletIntegrator::initialize(OpenMMContextImpl& contextRef) { void VerletIntegrator::initialize(OpenMMContextImpl& contextRef) {
context = &contextRef; context = &contextRef;
kernel = context->getPlatform().createKernel(IntegrateVerletStepKernel::Name()); kernel = context->getPlatform().createKernel(IntegrateVerletStepKernel::Name(), contextRef);
const System& system = context->getSystem(); const System& system = context->getSystem();
vector<double> masses(system.getNumAtoms()); vector<double> masses(system.getNumAtoms());
vector<std::vector<int> > constraintIndices(system.getNumConstraints()); vector<std::vector<int> > constraintIndices(system.getNumConstraints());
......
...@@ -42,7 +42,7 @@ namespace OpenMM { ...@@ -42,7 +42,7 @@ namespace OpenMM {
class BrookKernelFactory : public KernelFactory { class BrookKernelFactory : public KernelFactory {
public: public:
KernelImpl* createKernelImpl(std::string name, const Platform& platform) const; KernelImpl* createKernelImpl(std::string name, const Platform& platform, OpenMMContextImpl& context) const;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -42,7 +42,7 @@ namespace OpenMM { ...@@ -42,7 +42,7 @@ namespace OpenMM {
class BrookStreamFactory : public StreamFactory { class BrookStreamFactory : public StreamFactory {
public: public:
StreamImpl* createStreamImpl(std::string name, int size, Stream::DataType type, int streamWidth, const Platform& platform) const; StreamImpl* createStreamImpl(std::string name, int size, Stream::DataType type, int streamWidth, const Platform& platform, OpenMMContextImpl& context) const;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -34,7 +34,7 @@ ...@@ -34,7 +34,7 @@
using namespace OpenMM; using namespace OpenMM;
KernelImpl* BrookKernelFactory::createKernelImpl(std::string name, const Platform& platform) const { KernelImpl* BrookKernelFactory::createKernelImpl(std::string name, const Platform& platform, OpenMMContextImpl& context) const {
if (name == CalcStandardMMForceFieldKernel::Name()) if (name == CalcStandardMMForceFieldKernel::Name())
return new BrookCalcStandardMMForceFieldKernel(name, platform); return new BrookCalcStandardMMForceFieldKernel(name, platform);
if (name == CalcGBSAOBCForceFieldKernel::Name()) if (name == CalcGBSAOBCForceFieldKernel::Name())
......
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
using namespace OpenMM; using namespace OpenMM;
StreamImpl* BrookStreamFactory::createStreamImpl(std::string name, int size, Stream::DataType type, int streamWidth, const Platform& platform) const { StreamImpl* BrookStreamFactory::createStreamImpl(std::string name, int size, Stream::DataType type, int streamWidth, const Platform& platform, OpenMMContextImpl& context) const {
switch (type) { switch (type) {
case Stream::Float: case Stream::Float:
case Stream::Float2: case Stream::Float2:
......
#ifndef OPENMM_CUDAKERNELFACTORY_H_
#define OPENMM_CUDAKERNELFACTORY_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "KernelFactory.h"
namespace OpenMM {
/**
* This KernelFactory creates all kernels for CudaPlatform.
*/
class CudaKernelFactory : public KernelFactory {
public:
KernelImpl* createKernelImpl(std::string name, const Platform& platform, OpenMMContextImpl& context) const;
};
} // namespace OpenMM
#endif /*OPENMM_CUDAKERNELFACTORY_H_*/
#ifndef OPENMM_CUDAPLATFORM_H_
#define OPENMM_CUDAPLATFORM_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "Platform.h"
#include "CudaStreamFactory.h"
namespace OpenMM {
/**
* This Platform subclass uses CUDA implementations of the OpenMM kernels to run on NVidia GPUs.
*/
class CudaPlatform : public Platform {
public:
CudaPlatform();
std::string getName() const {
return "Cuda";
}
double getSpeed() const {
return 100;
}
bool supportsDoublePrecision() const;
const StreamFactory& getDefaultStreamFactory() const;
void contextCreated(OpenMMContextImpl& context) const;
void contextDestroyed(OpenMMContextImpl& context) const;
private:
CudaStreamFactory defaultStreamFactory;
};
} // namespace OpenMM
#endif /*OPENMM_CUDAPLATFORM_H_*/
#ifndef OPENMM_CUDASTREAMFACTORY_H_
#define OPENMM_CUDASTREAMFACTORY_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "StreamFactory.h"
namespace OpenMM {
/**
* This StreamFactory creates all streams for CudaPlatform.
*/
class CudaStreamFactory : public StreamFactory {
public:
StreamImpl* createStreamImpl(std::string name, int size, Stream::DataType type, const Platform& platform, OpenMMContextImpl& context) const;
};
} // namespace OpenMM
#endif /*OPENMM_CUDASTREAMFACTORY_H_*/
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment