Commit 0b5d58d7 authored by Charlles Abreu's avatar Charlles Abreu
Browse files

Conflict resolution in TestSplineFilter.cpp

parents 9026dbe7 b0d13582
...@@ -40,6 +40,8 @@ ...@@ -40,6 +40,8 @@
#include "SimTKOpenMMRealType.h" #include "SimTKOpenMMRealType.h"
#include "sfmt/SFMT.h" #include "sfmt/SFMT.h"
#include <iostream> #include <iostream>
#include <algorithm>
#include <numeric>
#include <vector> #include <vector>
using namespace OpenMM; using namespace OpenMM;
...@@ -47,7 +49,7 @@ using namespace std; ...@@ -47,7 +49,7 @@ using namespace std;
const double TOL = 1e-5; const double TOL = 1e-5;
void testVVSingleBond() { void testSingleBond() {
System system; System system;
system.addParticle(2.0); system.addParticle(2.0);
system.addParticle(2.0); system.addParticle(2.0);
...@@ -82,11 +84,11 @@ void testVVSingleBond() { ...@@ -82,11 +84,11 @@ void testVVSingleBond() {
ASSERT_EQUAL_TOL(10.0, context.getState(0).getTime(), 1e-5); ASSERT_EQUAL_TOL(10.0, context.getState(0).getTime(), 1e-5);
} }
void testVVConstraints() { void testConstraints() {
const int numParticles = 8; const int numParticles = 8;
const int numConstraints = 5; const int numConstraints = 5;
System system; System system;
NoseHooverIntegrator integrator(0.001); NoseHooverIntegrator integrator(0.0005);
integrator.setConstraintTolerance(1e-5); integrator.setConstraintTolerance(1e-5);
NonbondedForce* forceField = new NonbondedForce(); NonbondedForce* forceField = new NonbondedForce();
for (int i = 0; i < numParticles; ++i) { for (int i = 0; i < numParticles; ++i) {
...@@ -135,10 +137,10 @@ void testVVConstraints() { ...@@ -135,10 +137,10 @@ void testVVConstraints() {
} }
} }
void testVVConstrainedClusters() { void testConstrainedClusters() {
const int numParticles = 7; const int numParticles = 7;
System system; System system;
NoseHooverIntegrator integrator(0.001); NoseHooverIntegrator integrator(0.0005);
integrator.setConstraintTolerance(1e-5); integrator.setConstraintTolerance(1e-5);
NonbondedForce* forceField = new NonbondedForce(); NonbondedForce* forceField = new NonbondedForce();
for (int i = 0; i < numParticles; ++i) { for (int i = 0; i < numParticles; ++i) {
...@@ -197,7 +199,7 @@ void testVVConstrainedClusters() { ...@@ -197,7 +199,7 @@ void testVVConstrainedClusters() {
} }
} }
void testVVConstrainedMasslessParticles() { void testConstrainedMasslessParticles() {
System system; System system;
system.addParticle(0.0); system.addParticle(0.0);
system.addParticle(1.0); system.addParticle(1.0);
...@@ -297,17 +299,235 @@ void testInitialTemperature() { ...@@ -297,17 +299,235 @@ void testInitialTemperature() {
ASSERT_USUALLY_EQUAL_TOL(targetTemperature, temperature, 0.01); ASSERT_USUALLY_EQUAL_TOL(targetTemperature, temperature, 0.01);
} }
void testHarmonicOscillator() {
const double mass = 1.0;
double temperature = 300;
double frequency = 1;
double mts = 1, ys = 1, chain_length = 3;
System system;
system.addParticle(mass);
vector<Vec3> positions(1);
positions[0] = Vec3(0.5,0.5,0.5);
vector<Vec3> velocities(1);
velocities[0] = Vec3(0, 0, 0);
auto harmonic_restraint = new CustomExternalForce("0.5*(x^2+y^2+z^2)");
harmonic_restraint->addParticle(0);
system.addForce(harmonic_restraint);
NoseHooverIntegrator integrator(0.001);
integrator.addThermostat(temperature, frequency, chain_length, mts, ys);
Context context(system, integrator, platform);
context.setPositions(positions);
context.setVelocities(velocities);
double mean_temperature=0;
// equilibration
integrator.step(2000);
for (size_t i=0; i < 2500; i++){
integrator.step(10);
State state = context.getState(State::Energy | State::Positions | State::Velocities);
double kinetic_energy = state.getKineticEnergy();
double temp = kinetic_energy/(0.5*3*BOLTZ);
mean_temperature = (i*mean_temperature + temp)/(i+1);
double PE = state.getPotentialEnergy();
double time = state.getTime();
double energy = kinetic_energy + PE + integrator.computeHeatBathEnergy();
}
ASSERT_EQUAL_TOL(temperature, mean_temperature, 0.02);
}
int makeDimerBox(System& system, std::vector<Vec3>& positions, bool constrain=true, int numMolecules=20, double bondLength=0.1){
double boxLength = 2; // nm
Vec3 a(boxLength, 0.0, 0.0);
Vec3 b(0.0, boxLength, 0.0);
Vec3 c(0.0, 0.0, boxLength);
double mass = 20;
double bondForceConstant = 30000; //0.001;
int numDOF = 0;
NonbondedForce* forceField = new NonbondedForce();
HarmonicBondForce* bondForce = new HarmonicBondForce();
for(int molecule = 0; molecule < numMolecules; ++molecule) {
int particle1 = system.addParticle(mass);
int particle2 = system.addParticle(mass);
forceField->addParticle(0.0, 0.1, 1.0);
forceField->addParticle(0.0, 0.1, 1.0);
forceField->addException(particle1, particle2, 0, 0, 0);
bondForce->addBond(particle1, particle2, bondLength, bondForceConstant);
numDOF += 6;
if (constrain) {
system.addConstraint(particle1, particle2, bondLength);
numDOF -= 1;
}
}
forceField->setCutoffDistance(.99*boxLength/2);
forceField->setSwitchingDistance(.88*boxLength/2);
forceField->setUseSwitchingFunction(true);
forceField->setUseDispersionCorrection(false);
forceField->setNonbondedMethod(NonbondedForce::CutoffPeriodic);
system.addForce(forceField);
system.addForce(bondForce);
system.setDefaultPeriodicBoxVectors(a, b, c);
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numMolecules; i++) {
while (true) {
Vec3 pos = Vec3(boxLength*genrand_real2(sfmt), boxLength*genrand_real2(sfmt), boxLength*genrand_real2(sfmt));
Vec3 pos1 = pos + Vec3(0,0, bondLength/2);
Vec3 pos2 = pos + Vec3(0,0,-bondLength/2);
double minDist = 2*boxLength;
for (int j = 0; j < i; j++) {
Vec3 delta = pos1-positions[j];
minDist = std::min(minDist, sqrt(delta.dot(delta)));
delta = pos2-positions[j];
minDist = std::min(minDist, sqrt(delta.dot(delta)));
}
if (minDist > 0.15) {
positions[2*i+0] = pos1;
positions[2*i+1] = pos2;
break;
}
}
}
return numDOF;
}
void testDimerBox(bool constrain=true) {
// Check conservation of system + bath energy for a harmonic oscillator
int numMolecules = 20;
double bondLength = 0.1;
double bondLengthSquared = bondLength * bondLength;
System system;
std::vector<Vec3> positions(numMolecules*2);
int numDOF = makeDimerBox(system, positions, constrain, numMolecules, bondLength);
bool simpleConstruct = true;
double temperature = 300; // kelvin
double collisionFrequency = 200; // 1/ps
int numMTS = 3;
int numYS = 3;
int chainLength = 5;
auto integrator = simpleConstruct ? NoseHooverIntegrator(temperature, collisionFrequency, 0.001, chainLength, numMTS, numYS)
: NoseHooverIntegrator(0.001);
if (!simpleConstruct)
integrator.addThermostat(temperature, collisionFrequency, chainLength, numMTS, numYS);
Context context(system, integrator, platform);
context.setPositions(positions);
context.setVelocitiesToTemperature(temperature);
int nSteps = 5000;
double mean_temp = 0.0;
std::vector<double> energies(nSteps);
for (int i = 0; i < nSteps; ++i) {
integrator.step(1);
State state = context.getState(State::Energy | (constrain ? State::Positions : 0));
if (constrain) {
auto positions = state.getPositions();
for(int i = 0; i < numMolecules; ++i) {
Vec3 delta = positions[2*i+1] - positions[2*i];
double dR2 = delta.dot(delta);
ASSERT_EQUAL_TOL(bondLengthSquared, dR2, 1e-4);
}
}
double KE = state.getKineticEnergy();
double PE = state.getPotentialEnergy();
double time = state.getTime();
double instantaneous_temperature = 2 * KE / (BOLTZ * numDOF);
mean_temp = (i*mean_temp + instantaneous_temperature)/(i+1);
double energy = KE + PE + integrator.computeHeatBathEnergy();
energies[i] = energy;
}
double sum = std::accumulate(energies.begin(), energies.end(), 0.0);
double mean = sum / energies.size();
double sq_sum = std::inner_product(energies.begin(), energies.end(), energies.begin(), 0.0);
double std = std::sqrt(sq_sum / energies.size() - mean * mean);
double relative_std = std / mean;
// Check mean temperature
ASSERT_USUALLY_EQUAL_TOL(temperature, mean_temp, 1e-2);
// Check fluctuation of conserved (total bath + system) energy
ASSERT_USUALLY_EQUAL_TOL(relative_std, 0, 5e-3);
}
void testCheckpoints() {
// Create a system with Drude-like particles to be thermostated as a pair, as well as another
// particle to be thermostated independently, to test all integrator features.
double timeStep = 0.001;
NoseHooverIntegrator integrator(timeStep), newIntegrator(timeStep);
System system;
double mass = 1;
system.addParticle(8*mass);
system.addParticle(mass);
system.addParticle(5*mass);
HarmonicBondForce* force = new HarmonicBondForce();
force->addBond(0, 1, 0.1, 50.0);
force->addBond(0, 2, 0.1, 50.0);
system.addForce(force);
double kineticEnergy = 1e6;
double temperature=300, collisionFrequency=1, chainLength=3, numMTS=3, numYS=3;
chainLength = 10;
integrator.addSubsystemThermostat(std::vector<int>{2}, std::vector<std::pair<int,int>>{{0,1}}, temperature, collisionFrequency, temperature, collisionFrequency,
chainLength, numMTS, numYS);
newIntegrator.addSubsystemThermostat(std::vector<int>{2}, std::vector<std::pair<int,int>>{{0,1}}, temperature, collisionFrequency, temperature, collisionFrequency,
chainLength, numMTS, numYS);
Context context(system, integrator, platform);
Context newContext(system, newIntegrator, platform);
std::vector<Vec3> positions(3);
std::vector<Vec3> velocities(3);
positions[1] = {0.1, 0.0, 0.0};
velocities[1] = {0.1,0.2,-0.2};
positions[2] = {-0.1, 0.001, 0.001};
velocities[2] = {-0.1,0.2,-0.2};
context.setPositions(positions);
context.setVelocities(velocities);
// Run a short simulation and checkpoint..
integrator.step(500);
std::stringstream checkpoint;
context.createCheckpoint(checkpoint);
// Now continue the simulation
integrator.step(5);
// And try the same, starting from the checkpoint
newContext.loadCheckpoint(checkpoint);
newIntegrator.step(5);
State state1 = context.getState(State::Positions | State::Velocities);
State state2 = newContext.getState(State::Positions | State::Velocities);
ASSERT_EQUAL_VEC(state1.getPositions()[0], state2.getPositions()[0], 1e-6);
ASSERT_EQUAL_VEC(state1.getPositions()[1], state2.getPositions()[1], 1e-6);
ASSERT_EQUAL_VEC(state1.getVelocities()[0], state2.getVelocities()[0], 1e-6);
ASSERT_EQUAL_VEC(state1.getVelocities()[1], state2.getVelocities()[1], 1e-6);
}
void testAPIChangeNumParticles() {
bool constrain = true;
int numMolecules = 20;
double bondLength = 0.1;
double bondLengthSquared = bondLength * bondLength;
System system;
std::vector<Vec3> positions(numMolecules*2);
int numDOF = makeDimerBox(system, positions, constrain, numMolecules, bondLength);
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
try { try {
initializeTests(argc, argv); initializeTests(argc, argv);
testVVSingleBond(); // Underlying integrator tests
testVVConstraints(); testSingleBond();
testVVConstrainedClusters(); testConstraints();
testVVConstrainedMasslessParticles(); testConstrainedClusters();
testConstrainedMasslessParticles();
testThreeParticleVirtualSite(); testThreeParticleVirtualSite();
testInitialTemperature(); testInitialTemperature();
// Thermostat tests
testHarmonicOscillator();
bool constrain;
constrain = false; testDimerBox(constrain);
constrain = true; testDimerBox(constrain);
testCheckpoints();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2019 Stanford University and the Authors. *
* Authors: Andreas Krämer and Andrew C. Simmonett *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "openmm/internal/AssertionUtilities.h"
#include "openmm/NoseHooverChain.h"
#include "openmm/NoseHooverIntegrator.h"
#include "openmm/Context.h"
#include "openmm/State.h"
#include "openmm/HarmonicBondForce.h"
#include "openmm/NonbondedForce.h"
#include "openmm/CustomExternalForce.h"
#include "openmm/System.h"
#include "SimTKOpenMMRealType.h"
#include "sfmt/SFMT.h"
#include <iostream>
#include <sstream>
#include <iomanip>
#include <vector>
#include <algorithm>
#include <numeric>
using namespace OpenMM;
using namespace std;
void testHarmonicOscillator() {
const double mass = 1.0;
double temperature = 300;
double frequency = 1;
double mts = 1, ys = 1, chain_length = 3;
System system;
system.addParticle(mass);
vector<Vec3> positions(1);
positions[0] = Vec3(0.5,0.5,0.5);
vector<Vec3> velocities(1);
velocities[0] = Vec3(0, 0, 0);
auto harmonic_restraint = new CustomExternalForce("0.5*(x^2+y^2+z^2)");
harmonic_restraint->addParticle(0);
system.addForce(harmonic_restraint);
NoseHooverIntegrator integrator(0.001);
integrator.addThermostat(temperature, frequency, chain_length, mts, ys);
Context context(system, integrator, platform);
context.setPositions(positions);
context.setVelocities(velocities);
double mean_temperature=0;
// equilibration
integrator.step(2000);
for (size_t i=0; i < 2500; i++){
integrator.step(10);
State state = context.getState(State::Energy | State::Positions | State::Velocities);
double kinetic_energy = state.getKineticEnergy();
double temp = kinetic_energy/(0.5*3*BOLTZ);
mean_temperature = (i*mean_temperature + temp)/(i+1);
double PE = state.getPotentialEnergy();
double time = state.getTime();
double energy = kinetic_energy + PE + integrator.computeHeatBathEnergy();
}
ASSERT_EQUAL_TOL(temperature, mean_temperature, 0.02);
}
int makeDimerBox(System& system, std::vector<Vec3>& positions, bool constrain=true, int numMolecules=20, double bondLength=0.1){
double boxLength = 2; // nm
Vec3 a(boxLength, 0.0, 0.0);
Vec3 b(0.0, boxLength, 0.0);
Vec3 c(0.0, 0.0, boxLength);
double mass = 20;
double bondForceConstant = 30000; //0.001;
int numDOF = 0;
NonbondedForce* forceField = new NonbondedForce();
HarmonicBondForce* bondForce = new HarmonicBondForce();
for(int molecule = 0; molecule < numMolecules; ++molecule) {
int particle1 = system.addParticle(mass);
int particle2 = system.addParticle(mass);
forceField->addParticle(0.0, 0.1, 1.0);
forceField->addParticle(0.0, 0.1, 1.0);
forceField->addException(particle1, particle2, 0, 0, 0);
bondForce->addBond(particle1, particle2, bondLength, bondForceConstant);
numDOF += 6;
if (constrain) {
system.addConstraint(particle1, particle2, bondLength);
numDOF -= 1;
}
}
forceField->setCutoffDistance(.99*boxLength/2);
forceField->setSwitchingDistance(.88*boxLength/2);
forceField->setUseSwitchingFunction(true);
forceField->setUseDispersionCorrection(false);
forceField->setNonbondedMethod(NonbondedForce::CutoffPeriodic);
system.addForce(forceField);
system.addForce(bondForce);
system.setDefaultPeriodicBoxVectors(a, b, c);
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numMolecules; i++) {
while (true) {
Vec3 pos = Vec3(boxLength*genrand_real2(sfmt), boxLength*genrand_real2(sfmt), boxLength*genrand_real2(sfmt));
Vec3 pos1 = pos + Vec3(0,0, bondLength/2);
Vec3 pos2 = pos + Vec3(0,0,-bondLength/2);
double minDist = 2*boxLength;
for (int j = 0; j < i; j++) {
Vec3 delta = pos1-positions[j];
minDist = std::min(minDist, sqrt(delta.dot(delta)));
delta = pos2-positions[j];
minDist = std::min(minDist, sqrt(delta.dot(delta)));
}
if (minDist > 0.15) {
positions[2*i+0] = pos1;
positions[2*i+1] = pos2;
break;
}
}
}
return numDOF;
}
void testDimerBox(bool constrain=true) {
// Check conservation of system + bath energy for a harmonic oscillator
int numMolecules = 20;
double bondLength = 0.1;
double bondLengthSquared = bondLength * bondLength;
System system;
std::vector<Vec3> positions(numMolecules*2);
int numDOF = makeDimerBox(system, positions, constrain, numMolecules, bondLength);
bool simpleConstruct = true;
double temperature = 300; // kelvin
double collisionFrequency = 200; // 1/ps
int numMTS = 3;
int numYS = 3;
int chainLength = 5;
auto integrator = simpleConstruct ? NoseHooverIntegrator(temperature, collisionFrequency, 0.001, chainLength, numMTS, numYS)
: NoseHooverIntegrator(0.001);
if (!simpleConstruct)
integrator.addThermostat(temperature, collisionFrequency, chainLength, numMTS, numYS);
Context context(system, integrator, platform);
context.setPositions(positions);
context.setVelocitiesToTemperature(temperature);
int nSteps = 5000;
double mean_temp = 0.0;
std::vector<double> energies(nSteps);
for (int i = 0; i < nSteps; ++i) {
integrator.step(1);
State state = context.getState(State::Energy | (constrain ? State::Positions : 0));
if (constrain) {
auto positions = state.getPositions();
for(int i = 0; i < numMolecules; ++i) {
Vec3 delta = positions[2*i+1] - positions[2*i];
double dR2 = delta.dot(delta);
ASSERT_EQUAL_TOL(bondLengthSquared, dR2, 1e-4);
}
}
double KE = state.getKineticEnergy();
double PE = state.getPotentialEnergy();
double time = state.getTime();
double instantaneous_temperature = 2 * KE / (BOLTZ * numDOF);
mean_temp = (i*mean_temp + instantaneous_temperature)/(i+1);
double energy = KE + PE + integrator.computeHeatBathEnergy();
energies[i] = energy;
}
double sum = std::accumulate(energies.begin(), energies.end(), 0.0);
double mean = sum / energies.size();
double sq_sum = std::inner_product(energies.begin(), energies.end(), energies.begin(), 0.0);
double std = std::sqrt(sq_sum / energies.size() - mean * mean);
double relative_std = std / mean;
// Check mean temperature
ASSERT_USUALLY_EQUAL_TOL(temperature, mean_temp, 1e-2);
// Check fluctuation of conserved (total bath + system) energy
ASSERT_USUALLY_EQUAL_TOL(relative_std, 0, 1e-3);
}
void testCheckpoints() {
// Create a system with Drude-like particles to be thermostated as a pair, as well as another
// particle to be thermostated independently, to test all integrator features.
double timeStep = 0.001;
NoseHooverIntegrator integrator(timeStep), newIntegrator(timeStep);
System system;
double mass = 1;
system.addParticle(8*mass);
system.addParticle(mass);
system.addParticle(5*mass);
HarmonicBondForce* force = new HarmonicBondForce();
force->addBond(0, 1, 0.1, 50.0);
force->addBond(0, 2, 0.1, 50.0);
system.addForce(force);
double kineticEnergy = 1e6;
double temperature=300, collisionFrequency=1, chainLength=3, numMTS=3, numYS=3;
chainLength = 10;
integrator.addSubsystemThermostat(std::vector<int>{2}, std::vector<std::pair<int,int>>{{0,1}}, temperature, collisionFrequency, temperature, collisionFrequency,
chainLength, numMTS, numYS);
newIntegrator.addSubsystemThermostat(std::vector<int>{2}, std::vector<std::pair<int,int>>{{0,1}}, temperature, collisionFrequency, temperature, collisionFrequency,
chainLength, numMTS, numYS);
Context context(system, integrator, platform);
Context newContext(system, newIntegrator, platform);
std::vector<Vec3> positions(3);
std::vector<Vec3> velocities(3);
positions[1] = {0.1, 0.0, 0.0};
velocities[1] = {0.1,0.2,-0.2};
positions[2] = {-0.1, 0.001, 0.001};
velocities[2] = {-0.1,0.2,-0.2};
context.setPositions(positions);
context.setVelocities(velocities);
// Run a short simulation and checkpoint..
integrator.step(500);
std::stringstream checkpoint;
context.createCheckpoint(checkpoint);
// Now continue the simulation
integrator.step(5);
// And try the same, starting from the checkpoint
newContext.loadCheckpoint(checkpoint);
newIntegrator.step(5);
State state1 = context.getState(State::Positions | State::Velocities);
State state2 = newContext.getState(State::Positions | State::Velocities);
ASSERT_EQUAL_VEC(state1.getPositions()[0], state2.getPositions()[0], 1e-6);
ASSERT_EQUAL_VEC(state1.getPositions()[1], state2.getPositions()[1], 1e-6);
ASSERT_EQUAL_VEC(state1.getVelocities()[0], state2.getVelocities()[0], 1e-6);
ASSERT_EQUAL_VEC(state1.getVelocities()[1], state2.getVelocities()[1], 1e-6);
}
void testAPIChangeNumParticles() {
bool constrain = true;
int numMolecules = 20;
double bondLength = 0.1;
double bondLengthSquared = bondLength * bondLength;
System system;
std::vector<Vec3> positions(numMolecules*2);
int numDOF = makeDimerBox(system, positions, constrain, numMolecules, bondLength);
}
void runPlatformTests();
int main(int argc, char* argv[]) {
try {
initializeTests(argc, argv);
testHarmonicOscillator();
bool constrain;
constrain = false; testDimerBox(constrain);
constrain = true; testDimerBox(constrain);
testCheckpoints();
runPlatformTests();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
return 1;
}
cout << "Done" << endl;
return 0;
}
...@@ -82,6 +82,22 @@ void testPeriodicSpline() { ...@@ -82,6 +82,22 @@ void testPeriodicSpline() {
ASSERT_EQUAL_TOL(sin((double)i), SplineFitter::evaluateSpline(x, y, deriv, i), 0.05); ASSERT_EQUAL_TOL(sin((double)i), SplineFitter::evaluateSpline(x, y, deriv, i), 0.05);
ASSERT_EQUAL_TOL(cos((double)i), SplineFitter::evaluateSplineDerivative(x, y, deriv, i), 0.05); ASSERT_EQUAL_TOL(cos((double)i), SplineFitter::evaluateSplineDerivative(x, y, deriv, i), 0.05);
} }
for (unsigned int i = 0; i < x.size(); i++)
x[i] = i/(x.size()-1.0);
double ya[] = {15.579, 16.235, 17.325, 18.741, 20.454, 22.517, 24.944, 27.554, 29.942, 31.657,
32.486, 32.612, 32.494, 32.532, 32.785, 32.917, 32.402, 30.842, 28.229, 24.989,
21.762, 19.074, 17.147, 15.970, 15.467, 15.579};
// scipy.interpolate.CubicSpline solution:
double sol[] = { 345.520, 271.991, 194.015, 174.449, 221.940, 250.291, 141.895, -131.620,
-447.916, -600.465, -472.723, -144.892, 137.290, 180.733, -53.971, -418.600,
-697.879, -708.635, -416.330, 22.704, 374.262, 501.498, 473.496, 417.019,
385.928, 345.520};
y.assign(begin(ya), end(ya));
SplineFitter::createPeriodicSpline(x, y, deriv);
ASSERT_EQUAL_TOL(SplineFitter::evaluateSplineDerivative(x, y, deriv, x[0]),
SplineFitter::evaluateSplineDerivative(x, y, deriv, x[x.size()-1]), 1e-6);
for (int i = 0; i < x.size(); i++)
ASSERT_EQUAL_TOL(deriv[i], sol[i], 1e-3);
} }
void test2DSpline() { void test2DSpline() {
...@@ -177,5 +193,3 @@ int main() { ...@@ -177,5 +193,3 @@ int main() {
cout << "Done" << endl; cout << "Done" << endl;
return 0; return 0;
} }
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
* * * *
* Portions copyright (c) 2014-2015 Stanford University and the Authors. * * Portions copyright (c) 2014-2015 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: Daniel Towner *
* * * *
* Permission is hereby granted, free of charge, to any person obtaining a * * Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), * * copy of this software and associated documentation files (the "Software"), *
...@@ -35,6 +35,9 @@ ...@@ -35,6 +35,9 @@
#include "openmm/internal/AssertionUtilities.h" #include "openmm/internal/AssertionUtilities.h"
#include "openmm/internal/vectorize.h" #include "openmm/internal/vectorize.h"
#include "TestVectorizeGeneric.h"
#include <iostream> #include <iostream>
using namespace OpenMM; using namespace OpenMM;
...@@ -68,6 +71,14 @@ void testLoadStore() { ...@@ -68,6 +71,14 @@ void testLoadStore() {
ASSERT_EQUAL(i3[1], 3); ASSERT_EQUAL(i3[1], 3);
ASSERT_EQUAL(i3[2], 4); ASSERT_EQUAL(i3[2], 4);
ASSERT_EQUAL(i3[3], 5); ASSERT_EQUAL(i3[3], 5);
// Partial store of vec3 should not overwrite beyond the 3 elements.
float overwriteTest[4] = {9, 9, 9, 9};
f2.storeVec3(overwriteTest);
ASSERT_EQUAL(overwriteTest[0], f2[0]);
ASSERT_EQUAL(overwriteTest[1], f2[1]);
ASSERT_EQUAL(overwriteTest[2], f2[2]);
ASSERT_EQUAL(overwriteTest[3], 9);
} }
void testArithmetic() { void testArithmetic() {
...@@ -160,15 +171,39 @@ void testMathFunctions() { ...@@ -160,15 +171,39 @@ void testMathFunctions() {
} }
void testTranspose() { void testTranspose() {
fvec4 f1(1.0, 2.0, 3.0, 4.0); fvec4 f[4] = {
fvec4 f2(5.0, 6.0, 7.0, 8.0); {1.0, 2.0, 3.0, 4.0},
fvec4 f3(9.0, 10.0, 11.0, 12.0); {5.0, 6.0, 7.0, 8.0},
fvec4 f4(13.0, 14.0, 15.0, 16.0); {9.0, 10.0, 11.0, 12.0},
transpose(f1, f2, f3, f4); {13.0, 14.0, 15.0, 16.0}
ASSERT_VEC4_EQUAL(f1, 1.0, 5.0, 9.0, 13.0); };
ASSERT_VEC4_EQUAL(f2, 2.0, 6.0, 10.0, 14.0);
ASSERT_VEC4_EQUAL(f3, 3.0, 7.0, 11.0, 15.0); // Out-of-place tranpose into specific variables. Done before in-place transpose test.
ASSERT_VEC4_EQUAL(f4, 4.0, 8.0, 12.0, 16.0); fvec4 out0, out1, out2, out3;
transpose(f, out0, out1, out2, out3);
ASSERT_VEC4_EQUAL(out0, 1.0, 5.0, 9.0, 13.0);
ASSERT_VEC4_EQUAL(out1, 2.0, 6.0, 10.0, 14.0);
ASSERT_VEC4_EQUAL(out2, 3.0, 7.0, 11.0, 15.0);
ASSERT_VEC4_EQUAL(out3, 4.0, 8.0, 12.0, 16.0);
// In-place transpose. Done after the out-of-place transpose so avoid breaking that.
transpose(f[0], f[1], f[2], f[3]);
ASSERT_VEC4_EQUAL(f[0], 1.0, 5.0, 9.0, 13.0);
ASSERT_VEC4_EQUAL(f[1], 2.0, 6.0, 10.0, 14.0);
ASSERT_VEC4_EQUAL(f[2], 3.0, 7.0, 11.0, 15.0);
ASSERT_VEC4_EQUAL(f[3], 4.0, 8.0, 12.0, 16.0);
// Out-of-place transpose from named variables into an array.
fvec4 h[4];
fvec4 p0(0.1, 0.2, 0.3, 0.4);
fvec4 p1(0.5, 0.6, 0.7, 0.8);
fvec4 p2(0.9, 1.0, 1.1, 1.2);
fvec4 p3(1.3, 1.4, 1.5, 1.6);
transpose(p0, p1, p2, p3, h);
ASSERT_VEC4_EQUAL(h[0], 0.1, 0.5, 0.9, 1.3);
ASSERT_VEC4_EQUAL(h[1], 0.2, 0.6, 1.0, 1.4);
ASSERT_VEC4_EQUAL(h[2], 0.3, 0.7, 1.1, 1.5);
ASSERT_VEC4_EQUAL(h[3], 0.4, 0.8, 1.2, 1.6);
} }
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -178,11 +213,10 @@ int main(int argc, char* argv[]) { ...@@ -178,11 +213,10 @@ int main(int argc, char* argv[]) {
return 0; return 0;
} }
testLoadStore(); testLoadStore();
testArithmetic();
testLogic(); testLogic();
testComparisons();
testMathFunctions(); TestFvec<fvec4>::testAll();
testTranspose();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
* * * *
* Portions copyright (c) 2014-2015 Stanford University and the Authors. * * Portions copyright (c) 2014-2015 Stanford University and the Authors. *
* Authors: Robert T. McGibbon * * Authors: Robert T. McGibbon *
* Contributors: * * Contributors: Daniel Towner *
* * * *
* Permission is hereby granted, free of charge, to any person obtaining a * * Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), * * copy of this software and associated documentation files (the "Software"), *
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include "openmm/internal/vectorize8.h" #include "openmm/internal/vectorize8.h"
#include <iostream> #include <iostream>
#include "TestVectorizeGeneric.h"
#ifndef __AVX__ #ifndef __AVX__
bool isVec8Supported() { bool isVec8Supported() {
...@@ -66,32 +67,15 @@ using namespace std; ...@@ -66,32 +67,15 @@ using namespace std;
#define ASSERT_VEC8_EQUAL(found, expected0, expected1, expected2, expected3, expected4, expected5, expected6, expected7) {if (std::abs((found).lowerVec()[0]-(expected0))>1e-6 || std::abs((found).lowerVec()[1]-(expected1))>1e-6 || std::abs((found).lowerVec()[2]-(expected2))>1e-6 || std::abs((found).lowerVec()[3]-(expected3))>1e-6 || std::abs((found).upperVec()[0]-(expected4))>1e-6 || std::abs((found).upperVec()[1]-(expected5))>1e-6 || std::abs((found).upperVec()[2]-(expected6))>1e-6 || std::abs((found).upperVec()[3]-(expected7))>1e-6) {std::stringstream details; details << " Expected ("<<(expected0)<<","<<(expected1)<<","<<(expected2)<<","<<(expected3)<<","<<(expected4)<<","<<(expected5)<<","<<(expected6)<<","<<(expected7)<<"), found ("<<(found).lowerVec()[0]<<","<<(found).lowerVec()[1]<<","<<(found).lowerVec()[2]<<","<<(found).lowerVec()[3]<<","<<(found).upperVec()[0]<<","<<(found).upperVec()[1]<<","<<(found).upperVec()[2]<<","<<(found).upperVec()[3]<<")"; throwException(__FILE__, __LINE__, details.str());}}; #define ASSERT_VEC8_EQUAL(found, expected0, expected1, expected2, expected3, expected4, expected5, expected6, expected7) {if (std::abs((found).lowerVec()[0]-(expected0))>1e-6 || std::abs((found).lowerVec()[1]-(expected1))>1e-6 || std::abs((found).lowerVec()[2]-(expected2))>1e-6 || std::abs((found).lowerVec()[3]-(expected3))>1e-6 || std::abs((found).upperVec()[0]-(expected4))>1e-6 || std::abs((found).upperVec()[1]-(expected5))>1e-6 || std::abs((found).upperVec()[2]-(expected6))>1e-6 || std::abs((found).upperVec()[3]-(expected7))>1e-6) {std::stringstream details; details << " Expected ("<<(expected0)<<","<<(expected1)<<","<<(expected2)<<","<<(expected3)<<","<<(expected4)<<","<<(expected5)<<","<<(expected6)<<","<<(expected7)<<"), found ("<<(found).lowerVec()[0]<<","<<(found).lowerVec()[1]<<","<<(found).lowerVec()[2]<<","<<(found).lowerVec()[3]<<","<<(found).upperVec()[0]<<","<<(found).upperVec()[1]<<","<<(found).upperVec()[2]<<","<<(found).upperVec()[3]<<")"; throwException(__FILE__, __LINE__, details.str());}};
#define ASSERT_VEC8_EQUAL_INT(found, expected0, expected1, expected2, expected3, expected4, expected5, expected6, expected7) {if ((found).lowerVec()[0] != (expected0) || (found).lowerVec()[1] != (expected1) || (found).lowerVec()[2] != (expected2) || (found).lowerVec()[3] != (expected3) || (found).upperVec()[0] != (expected4) || (found).upperVec()[1] != (expected5) ||(found).upperVec()[2] != (expected6) || (found).upperVec()[3] != (expected7)) {std::stringstream details; details << " Expected ("<<(expected0)<<","<<(expected1)<<","<<(expected2)<<","<<(expected3)<<","<<(expected4)<<","<<(expected5)<<","<<(expected6)<<","<<(expected7)<<"), found ("<<(found).lowerVec()[0]<<","<<(found).lowerVec()[1]<<","<<(found).lowerVec()[2]<<","<<(found).lowerVec()[3]<<","<<(found).upperVec()[0]<<","<<(found).upperVec()[1]<<","<<(found).upperVec()[2]<<","<<(found).upperVec()[3]<<")"; throwException(__FILE__, __LINE__, details.str());}}; #define ASSERT_VEC8_EQUAL_INT(found, expected0, expected1, expected2, expected3, expected4, expected5, expected6, expected7) {if ((found).lowerVec()[0] != (expected0) || (found).lowerVec()[1] != (expected1) || (found).lowerVec()[2] != (expected2) || (found).lowerVec()[3] != (expected3) || (found).upperVec()[0] != (expected4) || (found).upperVec()[1] != (expected5) ||(found).upperVec()[2] != (expected6) || (found).upperVec()[3] != (expected7)) {std::stringstream details; details << " Expected ("<<(expected0)<<","<<(expected1)<<","<<(expected2)<<","<<(expected3)<<","<<(expected4)<<","<<(expected5)<<","<<(expected6)<<","<<(expected7)<<"), found ("<<(found).lowerVec()[0]<<","<<(found).lowerVec()[1]<<","<<(found).lowerVec()[2]<<","<<(found).lowerVec()[3]<<","<<(found).upperVec()[0]<<","<<(found).upperVec()[1]<<","<<(found).upperVec()[2]<<","<<(found).upperVec()[3]<<")"; throwException(__FILE__, __LINE__, details.str());}};
void testLoadStore() { void testLoadStore() {
fvec8 f1(2.0);
ivec8 i1(3); ivec8 i1(3);
ASSERT_VEC8_EQUAL(f1, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0);
ASSERT_VEC8_EQUAL_INT(i1, 3, 3, 3, 3, 3, 3, 3, 3); ASSERT_VEC8_EQUAL_INT(i1, 3, 3, 3, 3, 3, 3, 3, 3);
fvec8 f2(2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0);
ivec8 i2(2, 3, 4, 5, 6, 7, 8, 9); ivec8 i2(2, 3, 4, 5, 6, 7, 8, 9);
ASSERT_VEC8_EQUAL(f2, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0);
ASSERT_VEC8_EQUAL_INT(i2, 2, 3, 4, 5, 6, 7, 8, 9); ASSERT_VEC8_EQUAL_INT(i2, 2, 3, 4, 5, 6, 7, 8, 9);
float farray[8];
int iarray[8]; int iarray[8];
f2.store(farray);
i2.store(iarray); i2.store(iarray);
fvec8 f3(farray);
ivec8 i3(iarray); ivec8 i3(iarray);
ASSERT_VEC8_EQUAL(f3, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0);
ASSERT_VEC8_EQUAL_INT(i3, 2, 3, 4, 5, 6, 7, 8, 9); ASSERT_VEC8_EQUAL_INT(i3, 2, 3, 4, 5, 6, 7, 8, 9);
ASSERT_EQUAL(f3.lowerVec()[0], 2.5);
ASSERT_EQUAL(f3.lowerVec()[1], 3.0);
ASSERT_EQUAL(f3.lowerVec()[2], 3.5);
ASSERT_EQUAL(f3.lowerVec()[3], 4.0);
ASSERT_EQUAL(f3.upperVec()[0], 4.5);
ASSERT_EQUAL(f3.upperVec()[1], 5.0);
ASSERT_EQUAL(f3.upperVec()[2], 5.5);
ASSERT_EQUAL(f3.upperVec()[3], 6.0);
ASSERT_EQUAL(i3.lowerVec()[0], 2); ASSERT_EQUAL(i3.lowerVec()[0], 2);
ASSERT_EQUAL(i3.lowerVec()[1], 3); ASSERT_EQUAL(i3.lowerVec()[1], 3);
ASSERT_EQUAL(i3.lowerVec()[2], 4); ASSERT_EQUAL(i3.lowerVec()[2], 4);
...@@ -100,27 +84,16 @@ void testLoadStore() { ...@@ -100,27 +84,16 @@ void testLoadStore() {
ASSERT_EQUAL(i3.upperVec()[1], 7); ASSERT_EQUAL(i3.upperVec()[1], 7);
ASSERT_EQUAL(i3.upperVec()[2], 8); ASSERT_EQUAL(i3.upperVec()[2], 8);
ASSERT_EQUAL(i3.upperVec()[3], 9); ASSERT_EQUAL(i3.upperVec()[3], 9);
}
void testArithmetic() { // Partial store of vec3 should not overwrite beyond the 3 elements.
fvec8 f1(0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0); // Note that this is a fvec4 method, but is conditionally compiled for AVX so needs to be
ASSERT_VEC8_EQUAL(f1+fvec8(1, 2, 3, 4, 5, 6, 7, 8), 1.5, 3. , 4.5, 6. , 7.5, 9. , 10.5, 12.); // tested here too.
ASSERT_VEC8_EQUAL(f1-fvec8(1, 2, 3, 4, 5, 6, 7, 8), -0.5, -1. , -1.5, -2. , -2.5, -3. , -3.5, -4.); float overwriteTest[4] = {9, 9, 9, 9};
ASSERT_VEC8_EQUAL(f1*fvec8(1, 2, 3, 4, 5, 6, 7, 8), 0.5, 2. , 4.5, 8. , 12.5, 18. , 24.5, 32.); fvec4(1, 2, 3, 7777).storeVec3(overwriteTest);
ASSERT_VEC8_EQUAL(f1/fvec8(1, 2, 3, 4, 5, 6, 7, 8), 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5); ASSERT_EQUAL(overwriteTest[0], 1);
ASSERT_EQUAL(overwriteTest[1], 2);
f1 = fvec8(0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0); ASSERT_EQUAL(overwriteTest[2], 3);
f1 += fvec8(1, 2, 3, 4, 5, 6, 7, 8); ASSERT_EQUAL(overwriteTest[3], 9);
ASSERT_VEC8_EQUAL(f1, 1.5, 3. , 4.5, 6. , 7.5, 9. , 10.5, 12.);
f1 = fvec8(0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0);
f1 -= fvec8(1, 2, 3, 4, 5, 6, 7, 8);
ASSERT_VEC8_EQUAL(f1, -0.5, -1. , -1.5, -2. , -2.5, -3. , -3.5, -4.);
f1 = fvec8(0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0);
f1 *= fvec8(1, 2, 3, 4, 5, 6, 7, 8);
ASSERT_VEC8_EQUAL(f1, 0.5, 2. , 4.5, 8. , 12.5, 18. , 24.5, 32.);
f1 = fvec8(0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0);
f1 /= fvec8(1, 2, 3, 4, 5, 6, 7, 8);
ASSERT_VEC8_EQUAL(f1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5);
} }
void testLogic() { void testLogic() {
...@@ -144,76 +117,6 @@ void testLogic() { ...@@ -144,76 +117,6 @@ void testLogic() {
ASSERT_VEC8_EQUAL_INT(i1|mask, 1, allBits, allBits, 4, 5, allBits, allBits, 8); ASSERT_VEC8_EQUAL_INT(i1|mask, 1, allBits, allBits, 4, 5, allBits, allBits, 8);
} }
void testComparisons() {
fvec8 v1(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0);
fvec8 v2(1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5);
ASSERT_VEC8_EQUAL(blend(v1, v2,
fvec8(1.0, 1.5, 3.0, 2.2, 10.0, 10.5, 13.0, 12.2)==fvec8(1.1, 1.5, 3.0, 2.1, 10.1, 10.5, 13.0, 12.1)),
0.0, 1.5, 1.5, 0.0, 0.0, 1.5, 1.5, 0.0);
ASSERT_VEC8_EQUAL(blend(v1, v2,
fvec8(1.0, 1.5, 3.0, 2.2, 10.0, 10.5, 13.0, 12.2)!=fvec8(1.1, 1.5, 3.0, 2.1, 10.1, 10.5, 13.0, 12.1)),
1.5, 0.0, 0.0, 1.5, 1.5, 0.0, 0.0, 1.5);
ASSERT_VEC8_EQUAL(blend(v1, v2,
fvec8(1.0, 1.5, 3.0, 2.2, 10.0, 10.5, 13.0, 12.2)<fvec8(1.1, 1.5, 3.0, 2.1, 10.1, 10.5, 13.0, 12.1)),
1.5, 0.0, 0.0, 0.0, 1.5, 0.0, 0.0, 0.0);
ASSERT_VEC8_EQUAL(blend(v1, v2,
fvec8(1.0, 1.5, 3.0, 2.2, 10.0, 10.5, 13.0, 12.2)>fvec8(1.1, 1.5, 3.0, 2.1, 10.1, 10.5, 13.0, 12.1)),
0.0, 0.0, 0.0, 1.5, 0.0, 0.0, 0.0, 1.5);
ASSERT_VEC8_EQUAL(blend(v1, v2,
fvec8(1.0, 1.5, 3.0, 2.2, 10.0, 10.5, 13.0, 12.2)<=fvec8(1.1, 1.5, 3.0, 2.1, 10.1, 10.5, 13.0, 12.1)),
1.5, 1.5, 1.5, 0.0, 1.5, 1.5, 1.5, 0.0);
ASSERT_VEC8_EQUAL(blend(v1, v2,
fvec8(1.0, 1.5, 3.0, 2.2, 10.0, 10.5, 13.0, 12.2)>=fvec8(1.1, 1.5, 3.0, 2.1, 10.1, 10.5, 13.0, 12.1)),
0.0, 1.5, 1.5, 1.5, 0.0, 1.5, 1.5, 1.5);
}
void testMathFunctions() {
fvec8 f1(0.4, 1.9, -1.2, -3.8, 0.4, 1.9, -1.2, -3.8);
fvec8 f2(1.1, 1.2, 1.3, -5.0, 1.1, 1.2, 1.3, -5.0);
ASSERT_VEC8_EQUAL(floor(f1), 0.0, 1.0, -2.0, -4.0, 0.0, 1.0, -2.0, -4.0);
ASSERT_VEC8_EQUAL(ceil(f1), 1.0, 2.0, -1.0, -3.0, 1.0, 2.0, -1.0, -3.0);
ASSERT_VEC8_EQUAL(round(f1), 0.0, 2.0, -1.0, -4.0, 0.0, 2.0, -1.0, -4.0);
ASSERT_VEC8_EQUAL(abs(f1), 0.4, 1.9, 1.2, 3.8, 0.4, 1.9, 1.2, 3.8);
ASSERT_VEC8_EQUAL(min(f1, f2), 0.4, 1.2, -1.2, -5.0, 0.4, 1.2, -1.2, -5.0);
ASSERT_VEC8_EQUAL(max(f1, f2), 1.1, 1.9, 1.3, -3.8, 1.1, 1.9, 1.3, -3.8);
ASSERT_VEC8_EQUAL(sqrt(fvec8(1.5, 3.1, 4.0, 15.0, 1.5, 3.1, 4.0, 15.0)), sqrt(1.5), sqrt(3.1), sqrt(4.0), sqrt(15.0), sqrt(1.5), sqrt(3.1), sqrt(4.0), sqrt(15.0));
ASSERT_VEC8_EQUAL(rsqrt(fvec8(1.5, 3.1, 4.0, 15.0, 1.5, 3.1, 4.0, 15.0)), 1.0/sqrt(1.5), 1.0/sqrt(3.1), 1.0/sqrt(4.0), 1.0/sqrt(15.0), 1.0/sqrt(1.5), 1.0/sqrt(3.1), 1.0/sqrt(4.0), 1.0/sqrt(15.0));
ASSERT_EQUAL_TOL(f1.lowerVec()[0]*f2.lowerVec()[0]+f1.lowerVec()[1]*f2.lowerVec()[1]+f1.lowerVec()[2]*f2.lowerVec()[2]+f1.lowerVec()[3]*f2.lowerVec()[3]+f1.upperVec()[0]*f2.upperVec()[0]+f1.upperVec()[1]*f2.upperVec()[1]+f1.upperVec()[2]*f2.upperVec()[2]+f1.upperVec()[3]*f2.upperVec()[3], dot8(f1, f2), 1e-6);
ASSERT(any(f1 > 0.5));
ASSERT(!any(f1 > 2.0));
ASSERT_VEC8_EQUAL(blend(f1, f2, ivec8(-1, 0, -1, 0, -1, 0, -1, 0)), 1.1, 1.9, 1.3, -3.8, 1.1, 1.9, 1.3, -3.8);
}
void testTranspose() {
fvec4 f1(0.0, 1.0, 2.0, 3.0);
fvec4 f2(10.0, 11.0, 12.0, 13.0);
fvec4 f3(20.0, 21.0, 22.0, 23.0);
fvec4 f4(30.0, 31.0, 32.0, 33.0);
fvec4 f5(40.0, 41.0, 42.0, 43.0);
fvec4 f6(50.0, 51.0, 52.0, 53.0);
fvec4 f7(60.0, 61.0, 62.0, 63.0);
fvec4 f8(70.0, 71.0, 72.0, 73.0);
fvec8 o1, o2, o3, o4;
transpose(f1, f2, f3, f4, f5, f6, f7, f8, o1, o2, o3, o4);
ASSERT_VEC8_EQUAL(o1, 0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0);
ASSERT_VEC8_EQUAL(o2, 1.0, 11.0, 21.0, 31.0, 41.0, 51.0, 61.0, 71.0);
ASSERT_VEC8_EQUAL(o3, 2.0, 12.0, 22.0, 32.0, 42.0, 52.0, 62.0, 72.0);
ASSERT_VEC8_EQUAL(o4, 3.0, 13.0, 23.0, 33.0, 43.0, 53.0, 63.0, 73.0);
fvec4 g1, g2, g3, g4, g5, g6, g7, g8;
transpose(o1, o2, o3, o4, g1, g2, g3, g4, g5, g6, g7, g8);
ASSERT_VEC4_EQUAL(g1, 0.0, 1.0, 2.0, 3.0);
ASSERT_VEC4_EQUAL(g2, 10.0, 11.0, 12.0, 13.0);
ASSERT_VEC4_EQUAL(g3, 20.0, 21.0, 22.0, 23.0);
ASSERT_VEC4_EQUAL(g4, 30.0, 31.0, 32.0, 33.0);
ASSERT_VEC4_EQUAL(g5, 40.0, 41.0, 42.0, 43.0);
ASSERT_VEC4_EQUAL(g6, 50.0, 51.0, 52.0, 53.0);
ASSERT_VEC4_EQUAL(g7, 60.0, 61.0, 62.0, 63.0);
ASSERT_VEC4_EQUAL(g8, 70.0, 71.0, 72.0, 73.0);
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
try { try {
if (!isVec8Supported()) { if (!isVec8Supported()) {
...@@ -221,11 +124,10 @@ int main(int argc, char* argv[]) { ...@@ -221,11 +124,10 @@ int main(int argc, char* argv[]) {
return 0; return 0;
} }
testLoadStore(); testLoadStore();
testArithmetic();
testLogic(); testLogic();
testComparisons();
testMathFunctions(); TestFvec<fvec8>::testAll();
testTranspose();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2014-2020 Stanford University and the Authors. *
* Authors: Daniel Towner *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#pragma once
/**
* This tests all sizes of vectorized operations using templated test code.
*/
#include <array>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <memory.h>
#include <sstream>
#include <typeinfo>
/**
* Return the 32-bit integer bit pattern from the given floating-point value.
*/
static int32_t floatAsIntBits(float f) {
int32_t i;
memcpy(&i, &f, 4);
return i;
}
/**
* Compare two floating-point values using units-in-last-place (ULP) as a measure of equality. Two values
* which are only a few representable values apart can be considered to be equal. Note that IEEE
* operations (add, mul, etc.) will always be exact, but sequences of operations might be more than
* a few ULP apart, but still close enough to be considered equal. ULP comparisons work at any scale of
* number, unlike an epsilon-based approach.
*/
static bool almostEqual(float a, float b) {
// Maybe they really are equal.
if (a == b)
return true;
// Infinities and NANs are never equal to anything, even other nans and infinities.
if (std::isnan(a) || std::isinf(a) ||
std::isnan(b) || std::isinf(b))
return false;
// If they are different signs then they can't be equal. For two very small denormal values they might
// be very close to each other but either side of 0, but denormals are a corner case which don't deserve
// to be equal.
if (std::signbit(a) != std::signbit(b))
return false;
// The two numbers must be valid values with the same sign, so treat then as basic integers to
// get at their ULP values. If they are only a few ULP apart, then they are essentially equal.
int32_t intDiff = std::abs(floatAsIntBits(a) - floatAsIntBits(b));
return intDiff < 4;
}
static bool exactlyEqual(float a, float b) { return a == b; }
/**
* Write the contents of the given array-like object to a stream. No formatting is applied.
*/
template<typename FVEC>
void VecToStream(std::ostream& stream, const FVEC& vec)
{
constexpr int numElements = sizeof(FVEC) / sizeof(float);
const float* vptr = (const float*)&vec;
for (int i=0; i<numElements; ++i)
stream << vptr[i] << ", ";
}
/**
* Given two vector-like objects compared each of their elements for equality. The vector objects can be
* anything which in memory is a list of 32-bit floating-point values, so SIMD vectors, C arrays or
* C++ arrays would all be valid.
*/
template<typename S, typename T>
static void checkElementsEqual(const S& computed, const T& expected,
std::function<bool(float, float)> equal_fn,
const char* file, int line) {
// Both S and T should be arrays of floats of the same length.
static_assert(sizeof(T) == sizeof(S), "Array-like elements must have the same size");
constexpr int numElements = sizeof(S) / sizeof(float);
const float* computedPtr = (const float*)&computed;
const float* expectedPtr = (const float*)&expected;
std::ostringstream details;
details << "Error during test for type " << typeid(S).name() << '\n';
bool passed = true;
for (int i=0; i<numElements; ++i)
{
if (!equal_fn(computedPtr[i], expectedPtr[i]))
passed = false;
}
if (!passed)
{
details << "Values differ. ";
VecToStream(details, computed);
details << " and ";
VecToStream(details, expected);
OpenMM::throwException(file, line, details.str());
}
}
#define ASSERT_VEC_EQUAL(computed, expected) {checkElementsEqual(computed, expected, exactlyEqual, __FILE__, __LINE__);}
#define ASSERT_VEC_ALMOST_EQUAL(computed, expected) {checkElementsEqual(computed, expected, almostEqual, __FILE__, __LINE__);}
static float getRandomFloat () {
// Between -50 and 50.
return float(rand()) / float(RAND_MAX/100.0f) - 50.0f;
}
/**
* Given an array-like memory object containing floats, apply the given function to every element.
*/
template<typename FVEC>
FVEC applyUnaryFn(const FVEC& v, std::function<float(float)> fn) {
constexpr int numElements = sizeof(FVEC) / sizeof(float);
FVEC result;
float* rp = (float*)&result;
const float* vp = (const float*)&v;
for (int i=0; i<numElements; ++i)
rp[i] = fn(vp[i]);
return result;
}
/**
* Given an array-like memory object containing floats, apply the given function to every element.
*/
template<typename FVEC>
FVEC applyBinaryFn(const FVEC& a, const FVEC& b, std::function<float(float, float)> fn) {
constexpr int numElements = sizeof(FVEC) / sizeof(float);
FVEC result;
float* rp = (float*)&result;
const float* ap = (const float*)&a;
const float* bp = (const float*)&b;
for (int i=0; i<numElements; ++i)
rp[i] = fn(ap[i], bp[i]);
return result;
}
/**
* Provide a test fixture class which underpins all verification for a given
* type of vector SIMD implementation, as well as providing common utility functions
*/
template<typename FVEC>
class TestFvec {
public:
static constexpr int numElements = sizeof(FVEC) / sizeof(float);
void testInitializers() const;
void testUnaryOps() const;
void testBinaryOps() const;
void testUtilities() const;
void testBlendAndCompare() const;
void testTranspose() const;
static void testAll() {
TestFvec<FVEC> testUnit;
testUnit.testInitializers();
testUnit.testUnaryOps();
testUnit.testBinaryOps();
testUnit.testUtilities();
testUnit.testBlendAndCompare();
testUnit.testTranspose();
}
FVEC getRandomFvec() const {
union {
FVEC v;
float f[numElements];
};
for (auto& e : f)
e = getRandomFloat();
return v;
}
};
template<typename FVEC>
void TestFvec<FVEC>::testInitializers() const {
FVEC computedZero = {};
float expectedZero[numElements] = {};
ASSERT_VEC_EQUAL(computedZero, expectedZero);
FVEC computedBroadcast(14.5f);
float expectedBroadcast[numElements];
std::fill_n(expectedBroadcast, numElements, 14.5f);
ASSERT_VEC_EQUAL(computedBroadcast, expectedBroadcast);
float expectedArray[numElements];
std::iota(expectedArray, expectedArray + numElements, 23);
FVEC computedFromLoad(expectedArray);
ASSERT_VEC_EQUAL(computedFromLoad, expectedArray);
// Gather values from a table. Variants for both one vector and two vector gathers are provided.
// The indexes to gather (multiples of 7) are also generated, along with the expected answers.
float gatherTable[2048];
for (int i=0; i<2048;++i)
gatherTable[i] = -i; // Same index to make it easy to debug, but negative to avoid copying idx.
int gatherIndexes[numElements];
float gatherIndexesAsFloat[numElements]; // Same as above, but in float format.
float expectedGather0[numElements];
float expectedGather1[numElements];
for (int i=0; i<numElements; ++i)
{
gatherIndexes[i] = i * 7;
gatherIndexesAsFloat[i] = float(gatherIndexes[i]);
expectedGather0[i] = -(i * 7);
expectedGather1[i] = -(i * 7) - 1; // Each value is one less than previous.
}
// Single value gather
FVEC computedFromGather(gatherTable, gatherIndexes);
ASSERT_VEC_EQUAL(computedFromGather, expectedGather0);
// Pair-wise vector gather. The first values should be the same as a normal gather, and the
// second are just increments from the first. Note that there musty be some suitable conversion
// from a floating-point index (i.e., an integer value in float format), and the type required
// for the second operand of gatherVecPair. gatherVecPair can then take either an actual
// float vector, or some suitable format like ivec4 or ivec8.
FVEC findex(gatherIndexesAsFloat);
FVEC p0, p1;
gatherVecPair(gatherTable, findex, p0, p1);
ASSERT_VEC_EQUAL(p0, expectedGather0);
ASSERT_VEC_EQUAL(p1, expectedGather1);
}
template<typename FVEC>
void TestFvec<FVEC>::testUnaryOps() const {
const auto v = getRandomFvec();
// Note that these are exact comparisons because all these SIMD operators are
// just applying the scalar operator, so there should be no loss of precision.
ASSERT_VEC_EQUAL(abs(v), applyUnaryFn(v, [](float x) { return std::abs(x);} ));
ASSERT_VEC_EQUAL(-v, applyUnaryFn(v, [](float x) { return 0 - x;} ));
ASSERT_VEC_EQUAL(floor(v), applyUnaryFn(v, [](float x) { return std::floor(x);} ));
ASSERT_VEC_EQUAL(ceil(v), applyUnaryFn(v, [](float x) { return std::ceil(x);} ));
ASSERT_VEC_EQUAL(round(v), applyUnaryFn(v, [](float x) { return std::round(x);} ));
// Borrow a few other functions to test sqrt neatly.
const auto positiveValue = abs(v) + 1;
ASSERT_VEC_ALMOST_EQUAL(sqrt(positiveValue * positiveValue), positiveValue);
ASSERT_VEC_ALMOST_EQUAL(rsqrt(positiveValue * positiveValue), 1.0f / abs(positiveValue));
}
template<typename FVEC>
void TestFvec<FVEC>::testBinaryOps() const {
const auto v0 = getRandomFvec();
const auto v1 = getRandomFvec();
// Note that most of these are exact comparisons because all these SIMD operators are
// just applying the scalar operator, so there should be no loss of precision. The one
// exception is division, which does often do something slightly different
// since division is an expensive operation (e.g., multiply by reciprocal).
// Binary operators.
ASSERT_VEC_EQUAL(v0 + v1, applyBinaryFn(v0, v1, std::plus<float>()));
ASSERT_VEC_EQUAL(v0 - v1, applyBinaryFn(v0, v1, std::minus<float>()));
ASSERT_VEC_EQUAL(v0 * v1, applyBinaryFn(v0, v1, std::multiplies<float>()));
ASSERT_VEC_ALMOST_EQUAL(v0 / v1, applyBinaryFn(v0, v1, std::divides<float>()));
// Assignment operators.
auto addAssign = v0;
addAssign += v1;
ASSERT_VEC_EQUAL(addAssign, applyBinaryFn(v0, v1, std::plus<float>()));
auto subAssign = v0;
subAssign -= v1;
ASSERT_VEC_EQUAL(subAssign, applyBinaryFn(v0, v1, std::minus<float>()));
auto mulAssign = v0;
mulAssign *= v1;
ASSERT_VEC_EQUAL(mulAssign, applyBinaryFn(v0, v1, std::multiplies<float>()));
auto divAssign = v0;
divAssign /= v1;
ASSERT_VEC_ALMOST_EQUAL(divAssign, applyBinaryFn(v0, v1, std::divides<float>()));
// Binary ops between SIMD and scalar.
const float f = getRandomFloat();
const FVEC fdup(f);
ASSERT_VEC_EQUAL(v0 + f, applyBinaryFn(v0, fdup, std::plus<float>()));
ASSERT_VEC_EQUAL(f + v0, applyBinaryFn(fdup, v0, std::plus<float>()));
ASSERT_VEC_EQUAL(v0 - f, applyBinaryFn(v0, fdup, std::minus<float>()));
ASSERT_VEC_EQUAL(f - v0, applyBinaryFn(fdup, v0, std::minus<float>()));
ASSERT_VEC_EQUAL(v0 * f, applyBinaryFn(v0, fdup, std::multiplies<float>()));
ASSERT_VEC_EQUAL(f * v0, applyBinaryFn(fdup, v0, std::multiplies<float>()));
ASSERT_VEC_ALMOST_EQUAL(v0 / f, applyBinaryFn(v0, fdup, std::divides<float>()));
ASSERT_VEC_ALMOST_EQUAL(f / v0, applyBinaryFn(fdup, v0, std::divides<float>()));
// Binary functions.
using std::min;
using std::max;
ASSERT_VEC_EQUAL(min(v0, v1),
applyBinaryFn(v0, v1, [](float x, float y) { return min(x, y); }));
ASSERT_VEC_EQUAL(max(v0, v1),
applyBinaryFn(v0, v1, [](float x, float y) { return max(x, y); }));
}
template<typename FVEC>
void TestFvec<FVEC>::testTranspose() const {
// A table of random data to transpose.
float table[numElements * 4];
for (auto& e : table) e = std::round(getRandomFloat());
// Load the table row data into vectors.
const auto i0 = FVEC(table + 0 * numElements);
const auto i1 = FVEC(table + 1 * numElements);
const auto i2 = FVEC(table + 2 * numElements);
const auto i3 = FVEC(table + 3 * numElements);
// Manually transpose the data.
std::array<float, numElements * 4> expectedTranspose;
for (auto r=0; r<4; ++r)
{
for (auto c=0; c<numElements; ++c)
{
expectedTranspose[c * 4 + r] = table[r * numElements + c];
}
}
fvec4 computedTranspose[numElements];
transpose(i0, i1, i2, i3, computedTranspose);
ASSERT_VEC_EQUAL(computedTranspose, expectedTranspose);
FVEC o0, o1, o2, o3;
transpose(computedTranspose, o0, o1, o2, o3);
ASSERT_VEC_EQUAL(i0, o0);
ASSERT_VEC_EQUAL(i1, o1);
ASSERT_VEC_EQUAL(i2, o2);
ASSERT_VEC_EQUAL(i3, o3);
}
template<typename FVEC>
void TestFvec<FVEC>::testBlendAndCompare() const {
const FVEC zero = {};
const FVEC allOne(1.0f);
const FVEC allTwo(2.0f);
// Note that different targets use different types of mask, so rather than checking
// the mask directly, instead check the output of using the mask as a blend to provide
// an indirect test.
const auto maskNone = FVEC::expandBitsToMask(0);
ASSERT_VEC_EQUAL(blend(allOne, allTwo, maskNone), allOne);
ASSERT_VEC_EQUAL(blendZero(allOne, maskNone), zero);
const auto maskAll = FVEC::expandBitsToMask(-1);
ASSERT_VEC_EQUAL(blend(allOne, allTwo, maskAll), allTwo);
ASSERT_VEC_EQUAL(blendZero(allOne, maskAll), allOne);
// Repeating pattern big enough to do most SIMD lengths.
const int bitmask = 0b1100001101101001;
const auto maskSome = FVEC::expandBitsToMask(bitmask);
float expectedMaskSome[numElements];
float expectedZeroMaskSome[numElements];
for (int i=0; i<numElements; ++i)
{
expectedMaskSome[i] = (bitmask & (1 << i)) ? 2.0f : 1.0f;
expectedZeroMaskSome[i] = (bitmask & (1 << i)) ? 2.0f : 0.0f;
}
ASSERT_VEC_EQUAL(blend(allOne, allTwo, maskSome), expectedMaskSome);
ASSERT_VEC_EQUAL(blendZero(allTwo, maskSome), expectedZeroMaskSome);
// Test comparisons too, using random numbers, and then blending in either 0 or 1.
const auto v0 = getRandomFvec();
const auto v1 = getRandomFvec();
ASSERT_VEC_EQUAL(blend(allOne, allTwo, v0 < v1),
applyBinaryFn(v0, v1, [](float x, float y) { return x < y ? 2.0f : 1.0f; }));
ASSERT_VEC_EQUAL(blend(allOne, allTwo, v0 <= v1),
applyBinaryFn(v0, v1, [](float x, float y) { return x <= y ? 2.0f : 1.0f; }));
ASSERT_VEC_EQUAL(blend(allOne, allTwo, v0 <= v0), allTwo);
ASSERT_VEC_EQUAL(blend(allOne, allTwo, v0 > v1),
applyBinaryFn(v0, v1, [](float x, float y) { return x > y ? 2.0f : 1.0f; }));
ASSERT_VEC_EQUAL(blend(allOne, allTwo, v0 >= v1),
applyBinaryFn(v0, v1, [](float x, float y) { return x >= y ? 2.0f : 1.0f; }));
ASSERT_VEC_EQUAL(blend(allOne, allTwo, v0 >= v0), allTwo);
}
template<typename FVEC>
void TestFvec<FVEC>::testUtilities() const {
/** Use rounded (i.e., integer) values for the reductions. Reduction operations are very sensitive
* to ordering. The correct result is found by sorting values into ascending order to ensure that
* similar sized numbers are accumulated earlier than less similar numbers. If completely random
* numbers were used, this effect would show up here, making it more a test of what random numbers
* you got, than of the code itself. By rounding to integers, the numbers will behave sanely for the
* reduction, meaning it is a test of the reduction, and not of the format.
*/
const auto v0 = round(getRandomFvec());
const auto v1 = round(getRandomFvec());
const auto v2 = round(getRandomFvec());
const float* v0p = (const float*)&v0;
const float* v1p = (const float*)&v1;
const float* v2p = (const float*)&v2;
const auto expectedRedAddV0 = std::accumulate(v0p, v0p + numElements, 0.0f);
const auto expectedRedAddV1 = std::accumulate(v1p, v1p + numElements, 0.0f);
const auto expectedRedAddV2 = std::accumulate(v2p, v2p + numElements, 0.0f);
ASSERT_VEC_EQUAL(reduceAdd(v0), expectedRedAddV0);
// Reduction of three vectors by addition into a single 3-element vector. Note that the final element
// of the reduction is undefined, so the expected value copies over whatever that undefined value is.
const auto computedRed3 = reduceToVec3(v0, v1, v2);
const auto expectedRed3 = fvec4(expectedRedAddV0, expectedRedAddV1, expectedRedAddV2, computedRed3[3]);
ASSERT_VEC_EQUAL(computedRed3, expectedRed3);
}
\ No newline at end of file
...@@ -108,15 +108,12 @@ if not release: ...@@ -108,15 +108,12 @@ if not release:
if not IS_RELEASED: if not IS_RELEASED:
full_version += '.dev-' + git_revision[:7] full_version += '.dev-' + git_revision[:7]
a = open(filename, 'w') with open(filename, 'w') as a:
try:
a.write(cnt % {'version': version, a.write(cnt % {'version': version,
'full_version' : full_version, 'full_version' : full_version,
'git_revision' : git_revision, 'git_revision' : git_revision,
'isrelease': str(IS_RELEASED), 'isrelease': str(IS_RELEASED),
'path': os.getenv('OPENMM_LIB_PATH')}) 'path': os.getenv('OPENMM_LIB_PATH')})
finally:
a.close()
def buildKeywordDictionary(major_version_num=MAJOR_VERSION_NUM, def buildKeywordDictionary(major_version_num=MAJOR_VERSION_NUM,
...@@ -249,5 +246,3 @@ def main(): ...@@ -249,5 +246,3 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -18,7 +18,7 @@ if sys.platform == 'win32': ...@@ -18,7 +18,7 @@ if sys.platform == 'win32':
from simtk.openmm.openmm import * from simtk.openmm.openmm import *
from simtk.openmm.vec3 import Vec3 from simtk.openmm.vec3 import Vec3
from simtk.openmm.mtsintegrator import MTSIntegrator from simtk.openmm.mtsintegrator import MTSIntegrator, MTSLangevinIntegrator
from simtk.openmm.amd import AMDIntegrator, AMDForceGroupIntegrator, DualAMDIntegrator from simtk.openmm.amd import AMDIntegrator, AMDForceGroupIntegrator, DualAMDIntegrator
if os.getenv('OPENMM_PLUGIN_DIR') is None and os.path.isdir(version.openmm_library_path): if os.getenv('OPENMM_PLUGIN_DIR') is None and os.path.isdir(version.openmm_library_path):
...@@ -30,3 +30,7 @@ if sys.platform == 'win32': ...@@ -30,3 +30,7 @@ if sys.platform == 'win32':
os.environ['PATH'] = _path os.environ['PATH'] = _path
del _path del _path
__version__ = Platform.getOpenMMVersion() __version__ = Platform.getOpenMMVersion()
class OpenMMException(Exception):
"""This is the class used for all exceptions thrown by the C++ library."""
pass
...@@ -272,6 +272,9 @@ class CharmmPsfFile(object): ...@@ -272,6 +272,9 @@ class CharmmPsfFile(object):
drudepair_list.append([min(id1,id2), max(id1,id2)]) drudepair_list.append([min(id1,id2), max(id1,id2)])
elif (atom_list[id1].name[0:2]=='LP' or atom_list[id2].name[0:2]=='LP' or atom_list[id1].name=='OM' or atom_list[id2].name=='OM'): elif (atom_list[id1].name[0:2]=='LP' or atom_list[id2].name[0:2]=='LP' or atom_list[id1].name=='OM' or atom_list[id2].name=='OM'):
pass pass
# Ignore H-H bond in water if present
elif atom_list[id1].name[0]=='H' and atom_list[id2].name[0]=='H' and (atom_list[id1].residue.resname in WATNAMES):
pass
else: else:
bond_list.append(Bond(atom_list[id1], atom_list[id2])) bond_list.append(Bond(atom_list[id1], atom_list[id2]))
bond_list.changed = False bond_list.changed = False
...@@ -804,7 +807,8 @@ class CharmmPsfFile(object): ...@@ -804,7 +807,8 @@ class CharmmPsfFile(object):
ewaldErrorTolerance=0.0005, ewaldErrorTolerance=0.0005,
flexibleConstraints=True, flexibleConstraints=True,
verbose=False, verbose=False,
gbsaModel=None): gbsaModel=None,
drudeMass=0.4*u.amu):
"""Construct an OpenMM System representing the topology described by the """Construct an OpenMM System representing the topology described by the
prmtop file. You MUST have loaded a parameter set into this PSF before prmtop file. You MUST have loaded a parameter set into this PSF before
calling createSystem. If not, AttributeError will be raised. ValueError calling createSystem. If not, AttributeError will be raised. ValueError
...@@ -862,6 +866,9 @@ class CharmmPsfFile(object): ...@@ -862,6 +866,9 @@ class CharmmPsfFile(object):
gbsaModel : str=None gbsaModel : str=None
Can be ACE (to use the ACE solvation model) or None. Other values Can be ACE (to use the ACE solvation model) or None. Other values
raise a ValueError raise a ValueError
drudeMass : mass=0.4*amu
The mass to use for Drude particles. Any mass added to a Drude particle is
subtracted from its parent atom to keep their total mass the same.
""" """
# Load the parameter set # Load the parameter set
self.loadParameters(params) self.loadParameters(params)
...@@ -1384,9 +1391,7 @@ class CharmmPsfFile(object): ...@@ -1384,9 +1391,7 @@ class CharmmPsfFile(object):
# Add excluded atoms # Add excluded atoms
# Drude and lonepairs will be excluded based on their parent atoms # Drude and lonepairs will be excluded based on their parent atoms
parent_exclude_list=[] parent_exclude_list=[[] for _ in self.atom_list]
for atom in self.atom_list:
parent_exclude_list.append([])
for lpsite in self.lonepair_list: for lpsite in self.lonepair_list:
idx = lpsite[1] idx = lpsite[1]
idxa = lpsite[0] idxa = lpsite[0]
...@@ -1462,6 +1467,17 @@ class CharmmPsfFile(object): ...@@ -1462,6 +1467,17 @@ class CharmmPsfFile(object):
drude2 = ia2 + 1 drude2 = ia2 + 1
drudeforce.addScreenedPair(particleMap[drude1], particleMap[drude2], thole1+thole2) drudeforce.addScreenedPair(particleMap[drude1], particleMap[drude2], thole1+thole2)
# Set the masses of Drude particles.
if not u.is_quantity(drudeMass):
drudeMass *= u.dalton
for i in range(drudeforce.getNumParticles()):
params = drudeforce.getParticleParameters(i)
particle = params[0]
parent = params[1]
transferMass = drudeMass-system.getParticleMass(particle)
system.setParticleMass(particle, drudeMass)
system.setParticleMass(parent, system.getParticleMass(parent)-transferMass)
# If we needed a CustomNonbondedForce, map all of the exceptions from # If we needed a CustomNonbondedForce, map all of the exceptions from
# the NonbondedForce to the CustomNonbondedForce # the NonbondedForce to the CustomNonbondedForce
if has_nbfix_terms: if has_nbfix_terms:
......
...@@ -360,6 +360,9 @@ class ForceField(object): ...@@ -360,6 +360,9 @@ class ForceField(object):
for bond in patch.findall('RemoveExternalBond'): for bond in patch.findall('RemoveExternalBond'):
atom = ForceField._PatchAtomData(bond.attrib['atomName']) atom = ForceField._PatchAtomData(bond.attrib['atomName'])
patchData.deletedExternalBonds.append(atom) patchData.deletedExternalBonds.append(atom)
atomIndices = dict((atom.name, i) for i, atom in enumerate(patchData.addedAtoms[atomDescription.residue]+patchData.changedAtoms[atomDescription.residue]))
for site in patch.findall('VirtualSite'):
patchData.virtualSites[atomDescription.residue].append(ForceField._VirtualSiteData(site, atomIndices))
for residue in patch.findall('ApplyToResidue'): for residue in patch.findall('ApplyToResidue'):
name = residue.attrib['name'] name = residue.attrib['name']
if ':' in name: if ':' in name:
...@@ -555,20 +558,30 @@ class ForceField(object): ...@@ -555,20 +558,30 @@ class ForceField(object):
class _SystemData(object): class _SystemData(object):
"""Inner class used to encapsulate data about the system being created.""" """Inner class used to encapsulate data about the system being created."""
def __init__(self): def __init__(self, topology):
self.atomType = {} self.atomType = {}
self.atomParameters = {} self.atomParameters = {}
self.atomTemplateIndexes = {} self.atomTemplateIndexes = {}
self.atoms = [] self.atoms = list(topology.atoms())
self.excludeAtomWith = [] self.excludeAtomWith = [[] for a in self.atoms]
self.virtualSites = {} self.virtualSites = {}
self.bonds = [] self.bonds = [ForceField._BondData(bond[0].index, bond[1].index) for bond in topology.bonds()]
self.angles = [] self.angles = []
self.propers = [] self.propers = []
self.impropers = [] self.impropers = []
self.atomBonds = [] self.atomBonds = [[] for a in self.atoms]
self.isAngleConstrained = [] self.isAngleConstrained = []
self.constraints = {} self.constraints = {}
self.bondedToAtom = [set() for a in self.atoms]
# Record which atoms are bonded to each other atom
for i in range(len(self.bonds)):
bond = self.bonds[i]
self.bondedToAtom[bond.atom1].add(bond.atom2)
self.bondedToAtom[bond.atom2].add(bond.atom1)
self.atomBonds[bond.atom1].append(i)
self.atomBonds[bond.atom2].append(i)
def addConstraint(self, system, atom1, atom2, distance): def addConstraint(self, system, atom1, atom2, distance):
"""Add a constraint to the system, avoiding duplicate constraints.""" """Add a constraint to the system, avoiding duplicate constraints."""
...@@ -732,6 +745,7 @@ class ForceField(object): ...@@ -732,6 +745,7 @@ class ForceField(object):
self.addedExternalBonds = [] self.addedExternalBonds = []
self.deletedExternalBonds = [] self.deletedExternalBonds = []
self.allAtomNames = set() self.allAtomNames = set()
self.virtualSites = [[] for i in range(numResidues)]
def createPatchedTemplates(self, templates): def createPatchedTemplates(self, templates):
"""Apply this patch to a set of templates, creating new modified ones.""" """Apply this patch to a set of templates, creating new modified ones."""
...@@ -793,6 +807,16 @@ class ForceField(object): ...@@ -793,6 +807,16 @@ class ForceField(object):
newTemplate.addExternalBondByName(atom2.name) newTemplate.addExternalBondByName(atom2.name)
for atom in self.addedExternalBonds: for atom in self.addedExternalBonds:
newTemplate.addExternalBondByName(atom.name) newTemplate.addExternalBondByName(atom.name)
# Add new virtual sites.
indexMap = dict((i, newAtomIndex[atom.name]) for i, atom in enumerate(self.addedAtoms[index]+self.changedAtoms[index]))
for site in self.virtualSites[index]:
newSite = deepcopy(site)
newSite.index = indexMap[site.index]
newSite.atoms = [indexMap[i] for i in site.atoms]
newTemplate.virtualSites = [site for site in newTemplate.virtualSites if site.index != newSite.index]
newTemplate.virtualSites.append(newSite)
return newTemplates return newTemplates
class _PatchAtomData(object): class _PatchAtomData(object):
...@@ -885,8 +909,8 @@ class ForceField(object): ...@@ -885,8 +909,8 @@ class ForceField(object):
raise ValueError('%s: No parameters defined for atom type %s' % (self.forceName, t)) raise ValueError('%s: No parameters defined for atom type %s' % (self.forceName, t))
def _getResidueTemplateMatches(self, res, bondedToAtom, templateSignatures=None, ignoreExternalBonds=False): def _getResidueTemplateMatches(self, res, bondedToAtom, templateSignatures=None, ignoreExternalBonds=False, ignoreExtraParticles=False):
"""Return the residue template matches, or None if none are found. """Return the templates that match a residue, or None if none are found.
Parameters Parameters
---------- ----------
...@@ -912,7 +936,7 @@ class ForceField(object): ...@@ -912,7 +936,7 @@ class ForceField(object):
if signature in templateSignatures: if signature in templateSignatures:
allMatches = [] allMatches = []
for t in templateSignatures[signature]: for t in templateSignatures[signature]:
match = compiled.matchResidueToTemplate(res, t, bondedToAtom, ignoreExternalBonds) match = compiled.matchResidueToTemplate(res, t, bondedToAtom, ignoreExternalBonds, ignoreExtraParticles)
if match is not None: if match is not None:
allMatches.append((t, match)) allMatches.append((t, match))
if len(allMatches) == 1: if len(allMatches) == 1:
...@@ -1060,7 +1084,7 @@ class ForceField(object): ...@@ -1060,7 +1084,7 @@ class ForceField(object):
def createSystem(self, topology, nonbondedMethod=NoCutoff, nonbondedCutoff=1.0*unit.nanometer, def createSystem(self, topology, nonbondedMethod=NoCutoff, nonbondedCutoff=1.0*unit.nanometer,
constraints=None, rigidWater=None, removeCMMotion=True, hydrogenMass=None, residueTemplates=dict(), constraints=None, rigidWater=None, removeCMMotion=True, hydrogenMass=None, residueTemplates=dict(),
ignoreExternalBonds=False, switchDistance=None, flexibleConstraints=False, **args): ignoreExternalBonds=False, switchDistance=None, flexibleConstraints=False, drudeMass=0.4*unit.amu, **args):
"""Construct an OpenMM System representing a Topology with this force field. """Construct an OpenMM System representing a Topology with this force field.
Parameters Parameters
...@@ -1102,6 +1126,9 @@ class ForceField(object): ...@@ -1102,6 +1126,9 @@ class ForceField(object):
Lennard-Jones interactions. If this is None, no switching function will be used. Lennard-Jones interactions. If this is None, no switching function will be used.
flexibleConstraints : boolean=False flexibleConstraints : boolean=False
If True, parameters for constrained degrees of freedom will be added to the System If True, parameters for constrained degrees of freedom will be added to the System
drudeMass : mass=0.4*amu
The mass to use for Drude particles. Any mass added to a Drude particle is
subtracted from its parent atom to keep their total mass the same.
args args
Arbitrary additional keyword arguments may also be specified. Arbitrary additional keyword arguments may also be specified.
This allows extra parameters to be specified that are specific to This allows extra parameters to be specified that are specific to
...@@ -1114,83 +1141,21 @@ class ForceField(object): ...@@ -1114,83 +1141,21 @@ class ForceField(object):
""" """
args['switchDistance'] = switchDistance args['switchDistance'] = switchDistance
args['flexibleConstraints'] = flexibleConstraints args['flexibleConstraints'] = flexibleConstraints
data = ForceField._SystemData() args['drudeMass'] = drudeMass
data.atoms = list(topology.atoms()) data = ForceField._SystemData(topology)
for atom in data.atoms:
data.excludeAtomWith.append([])
rigidResidue = [False]*topology.getNumResidues() rigidResidue = [False]*topology.getNumResidues()
# Make a list of all bonds
for bond in topology.bonds():
data.bonds.append(ForceField._BondData(bond[0].index, bond[1].index))
# Record which atoms are bonded to each other atom
bondedToAtom = []
for i in range(len(data.atoms)):
bondedToAtom.append(set())
data.atomBonds.append([])
for i in range(len(data.bonds)):
bond = data.bonds[i]
bondedToAtom[bond.atom1].add(bond.atom2)
bondedToAtom[bond.atom2].add(bond.atom1)
data.atomBonds[bond.atom1].append(i)
data.atomBonds[bond.atom2].append(i)
# Find the template matching each residue and assign atom types. # Find the template matching each residue and assign atom types.
unmatchedResidues = [] templateForResidue = self._matchAllResiduesToTemplates(data, topology, residueTemplates, ignoreExternalBonds)
for chain in topology.chains(): for res in topology.residues():
for res in chain.residues(): if res.name == 'HOH':
if res in residueTemplates: # Determine whether this should be a rigid water.
tname = residueTemplates[res]
template = self._templates[tname] if rigidWater is None:
matches = compiled.matchResidueToTemplate(res, template, bondedToAtom, ignoreExternalBonds) rigidResidue[res.index] = templateForResidue[res.index].rigidWater
if matches is None: elif rigidWater:
raise Exception('User-supplied template %s does not match the residue %d (%s)' % (tname, res.index+1, res.name)) rigidResidue[res.index] = True
else:
# Attempt to match one of the existing templates.
[template, matches] = self._getResidueTemplateMatches(res, bondedToAtom, ignoreExternalBonds=ignoreExternalBonds)
if matches is None:
unmatchedResidues.append(res)
else:
data.recordMatchedAtomParameters(res, template, matches)
if res.name == 'HOH':
# Determine whether this should be a rigid water.
if rigidWater is None and template is not None:
rigidResidue[res.index] = template.rigidWater
elif rigidWater:
rigidResidue[res.index] = True
# Try to apply patches to find matches for any unmatched residues.
if len(unmatchedResidues) > 0:
unmatchedResidues = _applyPatchesToMatchResidues(self, data, unmatchedResidues, bondedToAtom, ignoreExternalBonds)
# If we still haven't found a match for a residue, attempt to use residue template generators to create
# new templates (and potentially atom types/parameters).
for res in unmatchedResidues:
# A template might have been generated on an earlier iteration of this loop.
[template, matches] = self._getResidueTemplateMatches(res, bondedToAtom, ignoreExternalBonds=ignoreExternalBonds)
if matches is None:
# Try all generators.
for generator in self._templateGenerators:
if generator(self, res):
# This generator has registered a new residue template that should match.
[template, matches] = self._getResidueTemplateMatches(res, bondedToAtom, ignoreExternalBonds=ignoreExternalBonds)
if matches is None:
# Something went wrong because the generated template does not match the residue signature.
raise Exception('The residue handler %s indicated it had correctly parameterized residue %s, but the generated template did not match the residue signature.' % (generator.__class__.__name__, str(res)))
else:
# We successfully generated a residue template. Break out of the for loop.
break
if matches is None:
raise ValueError('No template found for residue %d (%s). %s' % (res.index+1, res.name, _findMatchErrors(self, res)))
else:
data.recordMatchedAtomParameters(res, template, matches)
# Create the System and add atoms # Create the System and add atoms
...@@ -1236,13 +1201,13 @@ class ForceField(object): ...@@ -1236,13 +1201,13 @@ class ForceField(object):
uniqueAngles = set() uniqueAngles = set()
for bond in data.bonds: for bond in data.bonds:
for atom in bondedToAtom[bond.atom1]: for atom in data.bondedToAtom[bond.atom1]:
if atom != bond.atom2: if atom != bond.atom2:
if atom < bond.atom2: if atom < bond.atom2:
uniqueAngles.add((atom, bond.atom1, bond.atom2)) uniqueAngles.add((atom, bond.atom1, bond.atom2))
else: else:
uniqueAngles.add((bond.atom2, bond.atom1, atom)) uniqueAngles.add((bond.atom2, bond.atom1, atom))
for atom in bondedToAtom[bond.atom2]: for atom in data.bondedToAtom[bond.atom2]:
if atom != bond.atom1: if atom != bond.atom1:
if atom > bond.atom1: if atom > bond.atom1:
uniqueAngles.add((bond.atom1, bond.atom2, atom)) uniqueAngles.add((bond.atom1, bond.atom2, atom))
...@@ -1254,13 +1219,13 @@ class ForceField(object): ...@@ -1254,13 +1219,13 @@ class ForceField(object):
uniquePropers = set() uniquePropers = set()
for angle in data.angles: for angle in data.angles:
for atom in bondedToAtom[angle[0]]: for atom in data.bondedToAtom[angle[0]]:
if atom not in angle: if atom not in angle:
if atom < angle[2]: if atom < angle[2]:
uniquePropers.add((atom, angle[0], angle[1], angle[2])) uniquePropers.add((atom, angle[0], angle[1], angle[2]))
else: else:
uniquePropers.add((angle[2], angle[1], angle[0], atom)) uniquePropers.add((angle[2], angle[1], angle[0], atom))
for atom in bondedToAtom[angle[2]]: for atom in data.bondedToAtom[angle[2]]:
if atom not in angle: if atom not in angle:
if atom > angle[0]: if atom > angle[0]:
uniquePropers.add((angle[0], angle[1], angle[2], atom)) uniquePropers.add((angle[0], angle[1], angle[2], atom))
...@@ -1270,8 +1235,8 @@ class ForceField(object): ...@@ -1270,8 +1235,8 @@ class ForceField(object):
# Make a list of all unique improper torsions # Make a list of all unique improper torsions
for atom in range(len(bondedToAtom)): for atom in range(len(data.bondedToAtom)):
bondedTo = bondedToAtom[atom] bondedTo = data.bondedToAtom[atom]
if len(bondedTo) > 2: if len(bondedTo) > 2:
for subset in itertools.combinations(bondedTo, 3): for subset in itertools.combinations(bondedTo, 3):
data.impropers.append((atom, subset[0], subset[1], subset[2])) data.impropers.append((atom, subset[0], subset[1], subset[2]))
...@@ -1350,6 +1315,60 @@ class ForceField(object): ...@@ -1350,6 +1315,60 @@ class ForceField(object):
return sys return sys
def _matchAllResiduesToTemplates(self, data, topology, residueTemplates, ignoreExternalBonds, ignoreExtraParticles=False, recordParameters=True):
"""Return a list of which template matches each residue in the topology, and assign atom types."""
templateForResidue = [None]*topology.getNumResidues()
unmatchedResidues = []
for chain in topology.chains():
for res in chain.residues():
if res in residueTemplates:
tname = residueTemplates[res]
template = self._templates[tname]
matches = compiled.matchResidueToTemplate(res, template, data.bondedToAtom, ignoreExternalBonds, ignoreExtraParticles)
if matches is None:
raise Exception('User-supplied template %s does not match the residue %d (%s)' % (tname, res.index+1, res.name))
else:
# Attempt to match one of the existing templates.
[template, matches] = self._getResidueTemplateMatches(res, data.bondedToAtom, ignoreExternalBonds=ignoreExternalBonds, ignoreExtraParticles=ignoreExtraParticles)
if matches is None:
unmatchedResidues.append(res)
else:
if recordParameters:
data.recordMatchedAtomParameters(res, template, matches)
templateForResidue[res.index] = template
# Try to apply patches to find matches for any unmatched residues.
if len(unmatchedResidues) > 0:
unmatchedResidues = _applyPatchesToMatchResidues(self, data, unmatchedResidues, templateForResidue, data.bondedToAtom, ignoreExternalBonds, ignoreExtraParticles)
# If we still haven't found a match for a residue, attempt to use residue template generators to create
# new templates (and potentially atom types/parameters).
for res in unmatchedResidues:
# A template might have been generated on an earlier iteration of this loop.
[template, matches] = self._getResidueTemplateMatches(res, data.bondedToAtom, ignoreExternalBonds=ignoreExternalBonds, ignoreExtraParticles=ignoreExtraParticles)
if matches is None:
# Try all generators.
for generator in self._templateGenerators:
if generator(self, res):
# This generator has registered a new residue template that should match.
[template, matches] = self._getResidueTemplateMatches(res, data.bondedToAtom, ignoreExternalBonds=ignoreExternalBonds, ignoreExtraParticles=ignoreExtraParticles)
if matches is None:
# Something went wrong because the generated template does not match the residue signature.
raise Exception('The residue handler %s indicated it had correctly parameterized residue %s, but the generated template did not match the residue signature.' % (generator.__class__.__name__, str(res)))
else:
# We successfully generated a residue template. Break out of the for loop.
break
if matches is None:
raise ValueError('No template found for residue %d (%s). %s' % (res.index+1, res.name, _findMatchErrors(self, res)))
else:
if recordParameters:
data.recordMatchedAtomParameters(res, template, matches)
templateForResidue[res.index] = template
return templateForResidue
def _findBondsForExclusions(data, sys): def _findBondsForExclusions(data, sys):
"""Create a list of bonds to use when identifying exclusions.""" """Create a list of bonds to use when identifying exclusions."""
bondIndices = [] bondIndices = []
...@@ -1471,7 +1490,7 @@ def _createResidueSignature(elements): ...@@ -1471,7 +1490,7 @@ def _createResidueSignature(elements):
return s return s
def _applyPatchesToMatchResidues(forcefield, data, residues, bondedToAtom, ignoreExternalBonds): def _applyPatchesToMatchResidues(forcefield, data, residues, templateForResidue, bondedToAtom, ignoreExternalBonds, ignoreExtraParticles):
"""Try to apply patches to find matches for residues.""" """Try to apply patches to find matches for residues."""
# Start by creating all templates than can be created by applying a combination of one-residue patches # Start by creating all templates than can be created by applying a combination of one-residue patches
# to a single template. The number of these is usually not too large, and they often cover a large fraction # to a single template. The number of these is usually not too large, and they often cover a large fraction
...@@ -1497,11 +1516,12 @@ def _applyPatchesToMatchResidues(forcefield, data, residues, bondedToAtom, ignor ...@@ -1497,11 +1516,12 @@ def _applyPatchesToMatchResidues(forcefield, data, residues, bondedToAtom, ignor
unmatchedResidues = [] unmatchedResidues = []
for res in residues: for res in residues:
[template, matches] = forcefield._getResidueTemplateMatches(res, bondedToAtom, patchedTemplateSignatures, ignoreExternalBonds) [template, matches] = forcefield._getResidueTemplateMatches(res, bondedToAtom, patchedTemplateSignatures, ignoreExternalBonds, ignoreExtraParticles)
if matches is None: if matches is None:
unmatchedResidues.append(res) unmatchedResidues.append(res)
else: else:
data.recordMatchedAtomParameters(res, template, matches) data.recordMatchedAtomParameters(res, template, matches)
templateForResidue[res.index] = template
if len(unmatchedResidues) == 0: if len(unmatchedResidues) == 0:
return [] return []
...@@ -1549,7 +1569,7 @@ def _applyPatchesToMatchResidues(forcefield, data, residues, bondedToAtom, ignor ...@@ -1549,7 +1569,7 @@ def _applyPatchesToMatchResidues(forcefield, data, residues, bondedToAtom, ignor
for patchName in patches: for patchName in patches:
patch = forcefield._patches[patchName] patch = forcefield._patches[patchName]
if patch.numResidues == clusterSize: if patch.numResidues == clusterSize:
matchedClusters = _matchToMultiResiduePatchedTemplates(data, clusters, patch, patches[patchName], bondedToAtom, ignoreExternalBonds) matchedClusters = _matchToMultiResiduePatchedTemplates(data, clusters, patch, patches[patchName], bondedToAtom, ignoreExternalBonds, ignoreExtraParticles)
for cluster in matchedClusters: for cluster in matchedClusters:
for residue in cluster: for residue in cluster:
unmatchedResidues.remove(residue) unmatchedResidues.remove(residue)
...@@ -1599,21 +1619,21 @@ def _generatePatchedSingleResidueTemplates(template, patches, index, newTemplate ...@@ -1599,21 +1619,21 @@ def _generatePatchedSingleResidueTemplates(template, patches, index, newTemplate
_generatePatchedSingleResidueTemplates(patchedTemplate, patches, index+1, newTemplates, newAlteredAtoms) _generatePatchedSingleResidueTemplates(patchedTemplate, patches, index+1, newTemplates, newAlteredAtoms)
def _matchToMultiResiduePatchedTemplates(data, clusters, patch, residueTemplates, bondedToAtom, ignoreExternalBonds): def _matchToMultiResiduePatchedTemplates(data, clusters, patch, residueTemplates, bondedToAtom, ignoreExternalBonds, ignoreExtraParticles):
"""Apply a multi-residue patch to templates, then try to match them against clusters of residues.""" """Apply a multi-residue patch to templates, then try to match them against clusters of residues."""
matchedClusters = [] matchedClusters = []
selectedTemplates = [None]*patch.numResidues selectedTemplates = [None]*patch.numResidues
_applyMultiResiduePatch(data, clusters, patch, residueTemplates, selectedTemplates, 0, matchedClusters, bondedToAtom, ignoreExternalBonds) _applyMultiResiduePatch(data, clusters, patch, residueTemplates, selectedTemplates, 0, matchedClusters, bondedToAtom, ignoreExternalBonds, ignoreExtraParticles)
return matchedClusters return matchedClusters
def _applyMultiResiduePatch(data, clusters, patch, candidateTemplates, selectedTemplates, index, matchedClusters, bondedToAtom, ignoreExternalBonds): def _applyMultiResiduePatch(data, clusters, patch, candidateTemplates, selectedTemplates, index, matchedClusters, bondedToAtom, ignoreExternalBonds, ignoreExtraParticles):
"""This is called recursively to apply a multi-residue patch to all possible combinations of templates.""" """This is called recursively to apply a multi-residue patch to all possible combinations of templates."""
if index < patch.numResidues: if index < patch.numResidues:
for template in candidateTemplates[index]: for template in candidateTemplates[index]:
selectedTemplates[index] = template selectedTemplates[index] = template
_applyMultiResiduePatch(data, clusters, patch, candidateTemplates, selectedTemplates, index+1, matchedClusters, bondedToAtom, ignoreExternalBonds) _applyMultiResiduePatch(data, clusters, patch, candidateTemplates, selectedTemplates, index+1, matchedClusters, bondedToAtom, ignoreExternalBonds, ignoreExtraParticles)
else: else:
# We're at the deepest level of the recursion. We've selected a template for each residue, so apply the patch, # We're at the deepest level of the recursion. We've selected a template for each residue, so apply the patch,
# then try to match it against clusters. # then try to match it against clusters.
...@@ -1629,7 +1649,7 @@ def _applyMultiResiduePatch(data, clusters, patch, candidateTemplates, selectedT ...@@ -1629,7 +1649,7 @@ def _applyMultiResiduePatch(data, clusters, patch, candidateTemplates, selectedT
for residues in itertools.permutations(cluster): for residues in itertools.permutations(cluster):
residueMatches = [] residueMatches = []
for residue, template in zip(residues, patchedTemplates): for residue, template in zip(residues, patchedTemplates):
matches = compiled.matchResidueToTemplate(residue, template, bondedToAtom, ignoreExternalBonds) matches = compiled.matchResidueToTemplate(residue, template, bondedToAtom, ignoreExternalBonds, ignoreExtraParticles)
if matches is None: if matches is None:
residueMatches = None residueMatches = None
break break
...@@ -2433,12 +2453,12 @@ class LennardJonesGenerator(object): ...@@ -2433,12 +2453,12 @@ class LennardJonesGenerator(object):
def registerNBFIX(self, parameters): def registerNBFIX(self, parameters):
types = self.ff._findAtomTypes(parameters, 2) types = self.ff._findAtomTypes(parameters, 2)
if None not in types: if None not in types:
type1 = types[0][0] for type1 in types[0]:
type2 = types[1][0] for type2 in types[1]:
epsilon = _convertParameterToNumber(parameters['epsilon']) epsilon = _convertParameterToNumber(parameters['epsilon'])
sigma = _convertParameterToNumber(parameters['sigma']) sigma = _convertParameterToNumber(parameters['sigma'])
self.nbfixTypes[(type1, type2)] = [sigma, epsilon] self.nbfixTypes[(type1, type2)] = [sigma, epsilon]
self.nbfixTypes[(type2, type1)] = [sigma, epsilon] self.nbfixTypes[(type2, type1)] = [sigma, epsilon]
def registerLennardJones(self, parameters): def registerLennardJones(self, parameters):
self.ljTypes.registerAtom(parameters) self.ljTypes.registerAtom(parameters)
...@@ -5572,11 +5592,6 @@ class AmoebaUreyBradleyGenerator(object): ...@@ -5572,11 +5592,6 @@ class AmoebaUreyBradleyGenerator(object):
generator.length.append(float(bond.attrib['d'])) generator.length.append(float(bond.attrib['d']))
generator.k.append(float(bond.attrib['k'])) generator.k.append(float(bond.attrib['k']))
else:
outputString = "AmoebaUreyBradleyGenerator: error getting types: %s %s %s" % (
bond.attrib['class1'], bond.attrib['class2'], bond.attrib['class3'])
raise ValueError(outputString)
#============================================================================================= #=============================================================================================
def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):
...@@ -5794,6 +5809,19 @@ class DrudeGenerator(object): ...@@ -5794,6 +5809,19 @@ class DrudeGenerator(object):
thole2 = self.typeMap[type2][8] thole2 = self.typeMap[type2][8]
drude.addScreenedPair(drude1, drude2, thole1+thole2) drude.addScreenedPair(drude1, drude2, thole1+thole2)
# Set the masses of Drude particles.
drudeMass = args['drudeMass']
if not unit.is_quantity(drudeMass):
drudeMass *= unit.dalton
for i in range(drude.getNumParticles()):
params = drude.getParticleParameters(i)
particle = params[0]
parent = params[1]
transferMass = drudeMass-sys.getParticleMass(particle)
sys.setParticleMass(particle, drudeMass)
sys.setParticleMass(parent, sys.getParticleMass(parent)-transferMass)
parsers["DrudeForce"] = DrudeGenerator.parseElement parsers["DrudeForce"] = DrudeGenerator.parseElement
#============================================================================================= #=============================================================================================
...@@ -678,7 +678,7 @@ def readAmberSystem(topology, prmtop_filename=None, prmtop_loader=None, shake=No ...@@ -678,7 +678,7 @@ def readAmberSystem(topology, prmtop_filename=None, prmtop_loader=None, shake=No
if shake in ('all-bonds', 'h-angles'): if shake in ('all-bonds', 'h-angles'):
for (iAtom, jAtom, k, rMin) in prmtop.getBondsNoH(): for (iAtom, jAtom, k, rMin) in prmtop.getBondsNoH():
system.addConstraint(iAtom, jAtom, rMin) system.addConstraint(iAtom, jAtom, rMin)
if rigidWater and shake == None: if rigidWater and shake is None:
for (iAtom, jAtom, k, rMin) in prmtop.getBondsWithH(): for (iAtom, jAtom, k, rMin) in prmtop.getBondsWithH():
if isWater[iAtom] and isWater[jAtom]: if isWater[iAtom] and isWater[jAtom]:
system.addConstraint(iAtom, jAtom, rMin) system.addConstraint(iAtom, jAtom, rMin)
......
...@@ -157,6 +157,9 @@ class AtomType(object): ...@@ -157,6 +157,9 @@ class AtomType(object):
return self.number == other return self.number == other
return other == (self.number, self.name) return other == (self.number, self.name)
def __ne__(self, other):
return not self == other
def set_lj_params(self, eps, rmin, eps14=None, rmin14=None): def set_lj_params(self, eps, rmin, eps14=None, rmin14=None):
""" Sets Lennard-Jones parameters on this atom type """ """ Sets Lennard-Jones parameters on this atom type """
if eps14 is None: if eps14 is None:
...@@ -220,6 +223,7 @@ class WildCard(AtomType): ...@@ -220,6 +223,7 @@ class WildCard(AtomType):
# Define comparison operators # Define comparison operators
def __eq__(self, other): return True def __eq__(self, other): return True
def __ne__(self, other): return False
def __lt__(self, other): return True def __lt__(self, other): return True
def __gt__(self, other): return False def __gt__(self, other): return False
def __le__(self, other): return True def __le__(self, other): return True
...@@ -985,6 +989,9 @@ class BondType(object): ...@@ -985,6 +989,9 @@ class BondType(object):
def __eq__(self, other): def __eq__(self, other):
return self.k == other.k and self.req == other.req return self.k == other.k and self.req == other.req
def __ne__(self, other):
return not self == other
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
class AngleType(object): class AngleType(object):
...@@ -1005,6 +1012,9 @@ class AngleType(object): ...@@ -1005,6 +1012,9 @@ class AngleType(object):
def __eq__(self, other): def __eq__(self, other):
return self.k == other.k and self.theteq == other.theteq return self.k == other.k and self.theteq == other.theteq
def __ne__(self, other):
return not self == other
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
class UreyBradleyType(BondType): class UreyBradleyType(BondType):
...@@ -1046,6 +1056,9 @@ class DihedralType(object): ...@@ -1046,6 +1056,9 @@ class DihedralType(object):
return (self.phi_k == other.phi_k and self.per == other.per and return (self.phi_k == other.phi_k and self.per == other.per and
self.phase == other.phase) self.phase == other.phase)
def __ne__(self, other):
return not self == other
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
class ImproperType(object): class ImproperType(object):
...@@ -1066,6 +1079,9 @@ class ImproperType(object): ...@@ -1066,6 +1079,9 @@ class ImproperType(object):
def __eq__(self, other): def __eq__(self, other):
return self.k == other.k and self.phieq == other.phieq return self.k == other.k and self.phieq == other.phieq
def __ne__(self, other):
return not self == other
def __repr__(self): def __repr__(self):
return '<ImproperType; k=%s; phieq=%s>' % (self.k, self.phieq) return '<ImproperType; k=%s; phieq=%s>' % (self.k, self.phieq)
...@@ -1101,6 +1117,9 @@ class CmapType(object): ...@@ -1101,6 +1117,9 @@ class CmapType(object):
return (self.resolution == other.resolution and return (self.resolution == other.resolution and
all([abs(i - j) < TINY for i, j in zip(self.grid, other.grid)])) all([abs(i - j) < TINY for i, j in zip(self.grid, other.grid)]))
def __ne__(self, other):
return not self == other
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Take the CmapGrid class from the Chamber prmtop topology objects # Take the CmapGrid class from the Chamber prmtop topology objects
...@@ -1198,6 +1217,9 @@ class _CmapGrid(object): ...@@ -1198,6 +1217,9 @@ class _CmapGrid(object):
except AttributeError: except AttributeError:
return TypeError('Bad type comparison with _CmapGrid') return TypeError('Bad type comparison with _CmapGrid')
def __ne__(self, other):
return not self == other
def switch_range(self): def switch_range(self):
""" """
Returns a grid object whose range is 0 to 360 degrees in both dimensions Returns a grid object whose range is 0 to 360 degrees in both dimensions
......
...@@ -6,7 +6,7 @@ Simbios, the NIH National Center for Physics-Based Simulation of ...@@ -6,7 +6,7 @@ Simbios, the NIH National Center for Physics-Based Simulation of
Biological Structures at Stanford, funded under the NIH Roadmap for Biological Structures at Stanford, funded under the NIH Roadmap for
Medical Research, grant U54 GM072970. See https://simtk.org. Medical Research, grant U54 GM072970. See https://simtk.org.
Portions copyright (c) 2018 Stanford University and the Authors. Portions copyright (c) 2018-2020 Stanford University and the Authors.
Authors: Peter Eastman Authors: Peter Eastman
Contributors: Contributors:
...@@ -68,7 +68,7 @@ cdef class periodicDistance: ...@@ -68,7 +68,7 @@ cdef class periodicDistance:
return sqrt(dx*dx + dy*dy + dz*dz) return sqrt(dx*dx + dy*dy + dz*dz)
def matchResidueToTemplate(res, template, bondedToAtom, bint ignoreExternalBonds=False): def matchResidueToTemplate(res, template, bondedToAtom, bint ignoreExternalBonds=False, bint ignoreExtraParticles=False):
"""Determine whether a residue matches a template and return a list of corresponding atoms. """Determine whether a residue matches a template and return a list of corresponding atoms.
This is used heavily in ForceField. This is used heavily in ForceField.
...@@ -82,6 +82,8 @@ def matchResidueToTemplate(res, template, bondedToAtom, bint ignoreExternalBonds ...@@ -82,6 +82,8 @@ def matchResidueToTemplate(res, template, bondedToAtom, bint ignoreExternalBonds
Enumerates which other atoms each atom is bonded to Enumerates which other atoms each atom is bonded to
ignoreExternalBonds : bool ignoreExternalBonds : bool
If true, ignore external bonds when matching templates If true, ignore external bonds when matching templates
ignoreExtraParticles : bool
If true, ignore extra particles (ones whose element is None) when matching templates
Returns Returns
------- -------
...@@ -91,8 +93,18 @@ def matchResidueToTemplate(res, template, bondedToAtom, bint ignoreExternalBonds ...@@ -91,8 +93,18 @@ def matchResidueToTemplate(res, template, bondedToAtom, bint ignoreExternalBonds
""" """
cdef int numAtoms, i, j cdef int numAtoms, i, j
atoms = list(res.atoms()) atoms = list(res.atoms())
if ignoreExtraParticles:
atoms = [a for a in atoms if a.element is not None]
templateAtoms = [a for a in template.atoms if a.element is not None]
templateBondedTo = {}
for i, atom in enumerate(template.atoms):
if atom.element is not None:
templateBondedTo[atom] = [templateAtoms.index(template.atoms[j]) for j in atom.bondedTo if template.atoms[j].element is not None]
else:
templateAtoms = template.atoms
templateBondedTo = dict((atom, atom.bondedTo) for atom in template.atoms)
numAtoms = len(atoms) numAtoms = len(atoms)
if numAtoms != len(template.atoms): if numAtoms != len(templateAtoms):
return None return None
# Translate from global to local atom indices, and record the bonds for each atom. # Translate from global to local atom indices, and record the bonds for each atom.
...@@ -117,8 +129,8 @@ def matchResidueToTemplate(res, template, bondedToAtom, bint ignoreExternalBonds ...@@ -117,8 +129,8 @@ def matchResidueToTemplate(res, template, bondedToAtom, bint ignoreExternalBonds
residueTypeCount[key] = 1 residueTypeCount[key] = 1
residueTypeCount[key] += 1 residueTypeCount[key] += 1
templateTypeCount = {} templateTypeCount = {}
for i, atom in enumerate(template.atoms): for i, atom in enumerate(templateAtoms):
key = (atom.element, len(atom.bondedTo), 0 if ignoreExternalBonds else atom.externalBonds) key = (atom.element, len(templateBondedTo[atom]), 0 if ignoreExternalBonds else atom.externalBonds)
if key not in templateTypeCount: if key not in templateTypeCount:
templateTypeCount[key] = 1 templateTypeCount[key] = 1
templateTypeCount[key] += 1 templateTypeCount[key] += 1
...@@ -130,11 +142,11 @@ def matchResidueToTemplate(res, template, bondedToAtom, bint ignoreExternalBonds ...@@ -130,11 +142,11 @@ def matchResidueToTemplate(res, template, bondedToAtom, bint ignoreExternalBonds
candidates = [[] for i in range(numAtoms)] candidates = [[] for i in range(numAtoms)]
cdef bint exactNameMatch cdef bint exactNameMatch
for i in range(numAtoms): for i in range(numAtoms):
exactNameMatch = (atoms[i].element is None and any(atom.element is None and atom.name == atoms[i].name for atom in template.atoms)) exactNameMatch = (atoms[i].element is None and any(atom.element is None and atom.name == atoms[i].name for atom in templateAtoms))
for j, atom in enumerate(template.atoms): for j, atom in enumerate(templateAtoms):
if (atom.element is not None and atom.element != atoms[i].element) or (exactNameMatch and atom.name != atoms[i].name): if (atom.element is not None and atom.element != atoms[i].element) or (exactNameMatch and atom.name != atoms[i].name):
continue continue
if len(atom.bondedTo) != len(bondedTo[i]): if len(templateBondedTo[atom]) != len(bondedTo[i]):
continue continue
if not ignoreExternalBonds and atom.externalBonds != externalBonds[i]: if not ignoreExternalBonds and atom.externalBonds != externalBonds[i]:
continue continue
...@@ -174,38 +186,38 @@ def matchResidueToTemplate(res, template, bondedToAtom, bint ignoreExternalBonds ...@@ -174,38 +186,38 @@ def matchResidueToTemplate(res, template, bondedToAtom, bint ignoreExternalBonds
matches = numAtoms*[0] matches = numAtoms*[0]
hasMatch = numAtoms*[False] hasMatch = numAtoms*[False]
if _findAtomMatches(template, bondedTo, matches, hasMatch, candidates, 0): if _findAtomMatches(templateAtoms, bondedTo, templateBondedTo, matches, hasMatch, candidates, 0):
return [matches[inverseSearchOrder[i]] for i in range(numAtoms)] return [matches[inverseSearchOrder[i]] for i in range(numAtoms)]
return None return None
def _getAtomMatchCandidates(template, bondedTo, matches, candidates, position): def _getAtomMatchCandidates(templateAtoms, bondedTo, templateBondedTo, matches, candidates, position):
"""Get a list of template atoms that are potential matches for the next atom.""" """Get a list of template atoms that are potential matches for the next atom."""
for bonded in bondedTo[position]: for bonded in bondedTo[position]:
if bonded < position: if bonded < position:
# This atom is bonded to another one for which we already have a match, so only consider # This atom is bonded to another one for which we already have a match, so only consider
# template atoms that *that* one is bonded to. # template atoms that *that* one is bonded to.
return template.atoms[matches[bonded]].bondedTo return templateBondedTo[templateAtoms[matches[bonded]]]
return candidates[position] return candidates[position]
def _findAtomMatches(template, bondedTo, matches, hasMatch, candidates, int position): def _findAtomMatches(templateAtoms, bondedTo, templateBondedTo, matches, hasMatch, candidates, int position):
"""This is called recursively from inside matchResidueToTemplate() to identify matching atoms.""" """This is called recursively from inside matchResidueToTemplate() to identify matching atoms."""
if position == len(matches): if position == len(matches):
return True return True
cdef int i cdef int i
for i in _getAtomMatchCandidates(template, bondedTo, matches, candidates, position): for i in _getAtomMatchCandidates(templateAtoms, bondedTo, templateBondedTo, matches, candidates, position):
atom = template.atoms[i] atom = templateAtoms[i]
if not hasMatch[i] and i in candidates[position]: if not hasMatch[i] and i in candidates[position]:
# See if the bonds for this identification are consistent # See if the bonds for this identification are consistent
allBondsMatch = all((bonded > position or matches[bonded] in atom.bondedTo for bonded in bondedTo[position])) allBondsMatch = all((bonded > position or matches[bonded] in templateBondedTo[atom] for bonded in bondedTo[position]))
if allBondsMatch: if allBondsMatch:
# This is a possible match, so try matching the rest of the residue. # This is a possible match, so try matching the rest of the residue.
matches[position] = i matches[position] = i
hasMatch[i] = True hasMatch[i] = True
if _findAtomMatches(template, bondedTo, matches, hasMatch, candidates, position+1): if _findAtomMatches(templateAtoms, bondedTo, templateBondedTo, matches, hasMatch, candidates, position+1):
return True return True
hasMatch[i] = False hasMatch[i] = False
return False return False
...@@ -226,7 +226,7 @@ class PdbStructure(object): ...@@ -226,7 +226,7 @@ class PdbStructure(object):
print("END", file=output_stream) print("END", file=output_stream)
def _add_model(self, model): def _add_model(self, model):
if self.default_model == None: if self.default_model is None:
self.default_model = model self.default_model = model
self.models.append(model) self.models.append(model)
self._current_model = model self._current_model = model
...@@ -292,7 +292,7 @@ class PdbStructure(object): ...@@ -292,7 +292,7 @@ class PdbStructure(object):
def _add_atom(self, atom): def _add_atom(self, atom):
""" """
""" """
if self._current_model == None: if self._current_model is None:
self._add_model(Model(0)) self._add_model(Model(0))
atom.model_number = self._current_model.number atom.model_number = self._current_model.number
# Atom might be alternate position for existing atom # Atom might be alternate position for existing atom
...@@ -560,20 +560,20 @@ class Residue(object): ...@@ -560,20 +560,20 @@ class Residue(object):
def set_name_with_spaces(self, name, alt_loc=None): def set_name_with_spaces(self, name, alt_loc=None):
# Gromacs ffamber PDB files can have 4-character residue names # Gromacs ffamber PDB files can have 4-character residue names
# assert len(name) == 3 # assert len(name) == 3
if alt_loc == None: if alt_loc is None:
alt_loc = self.primary_location_id alt_loc = self.primary_location_id
loc = self.locations[alt_loc] loc = self.locations[alt_loc]
loc.name_with_spaces = name loc.name_with_spaces = name
loc.name = name.strip() loc.name = name.strip()
def get_name_with_spaces(self, alt_loc=None): def get_name_with_spaces(self, alt_loc=None):
if alt_loc == None: if alt_loc is None:
alt_loc = self.primary_location_id alt_loc = self.primary_location_id
loc = self.locations[alt_loc] loc = self.locations[alt_loc]
return loc.name_with_spaces return loc.name_with_spaces
name_with_spaces = property(get_name_with_spaces, set_name_with_spaces, doc='four-character residue name including spaces') name_with_spaces = property(get_name_with_spaces, set_name_with_spaces, doc='four-character residue name including spaces')
def get_name(self, alt_loc=None): def get_name(self, alt_loc=None):
if alt_loc == None: if alt_loc is None:
alt_loc = self.primary_location_id alt_loc = self.primary_location_id
loc = self.locations[alt_loc] loc = self.locations[alt_loc]
return loc.name return loc.name
...@@ -616,7 +616,7 @@ class Residue(object): ...@@ -616,7 +616,7 @@ class Residue(object):
# Three possibilities: primary alt_loc, certain alt_loc, or all alt_locs # Three possibilities: primary alt_loc, certain alt_loc, or all alt_locs
def iter_atoms(self, alt_loc=None): def iter_atoms(self, alt_loc=None):
if alt_loc == None: if alt_loc is None:
locs = [self.primary_location_id] locs = [self.primary_location_id]
elif alt_loc == "": elif alt_loc == "":
locs = [self.primary_location_id] locs = [self.primary_location_id]
...@@ -629,7 +629,7 @@ class Residue(object): ...@@ -629,7 +629,7 @@ class Residue(object):
use_atom = False # start pessimistic use_atom = False # start pessimistic
for loc2 in atom.locations.keys(): for loc2 in atom.locations.keys():
# print "#%s#%s" % (loc2,locs) # print "#%s#%s" % (loc2,locs)
if locs == None: # means all locations if locs is None: # means all locations
use_atom = True use_atom = True
elif loc2 in locs: elif loc2 in locs:
use_atom = True use_atom = True
...@@ -805,7 +805,7 @@ class Atom(object): ...@@ -805,7 +805,7 @@ class Atom(object):
try: try:
# Try to find a sensible element symbol from columns 76-77 # Try to find a sensible element symbol from columns 76-77
self.element = element.get_by_symbol(self.element_symbol) self.element = element.get_by_symbol(self.element_symbol)
except KeyError: except KeyError:
self.element = None self.element = None
if pdbstructure is not None: if pdbstructure is not None:
pdbstructure._next_atom_number = self.serial_number+1 pdbstructure._next_atom_number = self.serial_number+1
...@@ -850,12 +850,12 @@ class Atom(object): ...@@ -850,12 +850,12 @@ class Atom(object):
# Hide existence of multiple alternate locations to avoid scaring casual users # Hide existence of multiple alternate locations to avoid scaring casual users
def get_location(self, location_id=None): def get_location(self, location_id=None):
id = location_id id = location_id
if (id == None): if id is None:
id = self.default_location_id id = self.default_location_id
return self.locations[id] return self.locations[id]
def set_location(self, new_location, location_id=None): def set_location(self, new_location, location_id=None):
id = location_id id = location_id
if (id == None): if id is None:
id = self.default_location_id id = self.default_location_id
self.locations[id] = new_location self.locations[id] = new_location
location = property(get_location, set_location, doc='default Atom.Location object') location = property(get_location, set_location, doc='default Atom.Location object')
...@@ -891,9 +891,9 @@ class Atom(object): ...@@ -891,9 +891,9 @@ class Atom(object):
""" """
Produce a PDB line for this atom using a particular serial number and alternate location Produce a PDB line for this atom using a particular serial number and alternate location
""" """
if serial_number == None: if serial_number is None:
serial_number = self.serial_number serial_number = self.serial_number
if alternate_location_indicator == None: if alternate_location_indicator is None:
alternate_location_indicator = self.alternate_location_indicator alternate_location_indicator = self.alternate_location_indicator
# produce PDB line in three parts: names, numbers, and end # produce PDB line in three parts: names, numbers, and end
# Accomodate 4-character residue names that use column 21 # Accomodate 4-character residue names that use column 21
...@@ -927,7 +927,7 @@ class Atom(object): ...@@ -927,7 +927,7 @@ class Atom(object):
alt_loc = None means write just the primary location alt_loc = None means write just the primary location
alt_loc = "AB" means write locations "A" and "B" alt_loc = "AB" means write locations "A" and "B"
""" """
if alt_loc == None: if alt_loc is None:
locs = [self.default_location_id] locs = [self.default_location_id]
elif alt_loc == "": elif alt_loc == "":
locs = [self.default_location_id] locs = [self.default_location_id]
......
...@@ -1028,44 +1028,14 @@ class Modeller(object): ...@@ -1028,44 +1028,14 @@ class Modeller(object):
This is useful when the Topology represents one piece of a larger This is useful when the Topology represents one piece of a larger
molecule, so chains are not terminated properly. molecule, so chains are not terminated properly.
""" """
# Create copies of all residue templates that have had all extra points removed. # Record which atoms are bonded to each other atom.
templatesNoEP = {}
for resName, template in forcefield._templates.items():
if any(atom.element is None for atom in template.atoms):
index = 0
newIndex = {}
newTemplate = ForceField._TemplateData(resName)
for i, atom in enumerate(template.atoms):
if atom.element is not None:
newIndex[i] = index
index += 1
newAtom = ForceField._TemplateAtomData(atom.name, atom.type, atom.element)
newAtom.externalBonds = atom.externalBonds
newTemplate.atoms.append(newAtom)
for b1, b2 in template.bonds:
if b1 in newIndex and b2 in newIndex:
newTemplate.bonds.append((newIndex[b1], newIndex[b2]))
newTemplate.atoms[newIndex[b1]].bondedTo.append(newIndex[b2])
newTemplate.atoms[newIndex[b2]].bondedTo.append(newIndex[b1])
for b in template.externalBonds:
if b in newIndex:
newTemplate.externalBonds.append(newIndex[b])
templatesNoEP[template] = newTemplate
# Record which atoms are bonded to each other atom, with and without extra particles.
bondedToAtom = [] bondedToAtom = []
bondedToAtomNoEP = []
for atom in self.topology.atoms(): for atom in self.topology.atoms():
bondedToAtom.append(set()) bondedToAtom.append(set())
bondedToAtomNoEP.append(set())
for atom1, atom2 in self.topology.bonds(): for atom1, atom2 in self.topology.bonds():
bondedToAtom[atom1.index].add(atom2.index) bondedToAtom[atom1.index].add(atom2.index)
bondedToAtom[atom2.index].add(atom1.index) bondedToAtom[atom2.index].add(atom1.index)
if atom1.element is not None and atom2.element is not None:
bondedToAtomNoEP[atom1.index].add(atom2.index)
bondedToAtomNoEP[atom2.index].add(atom1.index)
# If the force field has a DrudeForce, record the types of Drude particles and their parents since we'll # If the force field has a DrudeForce, record the types of Drude particles and their parents since we'll
# need them for picking particle positions. # need them for picking particle positions.
...@@ -1076,6 +1046,10 @@ class Modeller(object): ...@@ -1076,6 +1046,10 @@ class Modeller(object):
for type in force.typeMap: for type in force.typeMap:
drudeTypeMap[type] = force.typeMap[type][0] drudeTypeMap[type] = force.typeMap[type][0]
# Identify the template to use for each residue.
templates = forcefield._matchAllResiduesToTemplates(ForceField._SystemData(self.topology), self.topology, {}, False, True, False)
# Create the new Topology. # Create the new Topology.
newTopology = Topology() newTopology = Topology()
...@@ -1087,16 +1061,8 @@ class Modeller(object): ...@@ -1087,16 +1061,8 @@ class Modeller(object):
newChain = newTopology.addChain(chain.id) newChain = newTopology.addChain(chain.id)
for residue in chain.residues(): for residue in chain.residues():
newResidue = newTopology.addResidue(residue.name, newChain, residue.id, residue.insertionCode) newResidue = newTopology.addResidue(residue.name, newChain, residue.id, residue.insertionCode)
template = templates[residue.index]
# Look for a matching template. if len(template.atoms) == len(list(residue.atoms())):
matchFound = False
signature = _createResidueSignature([atom.element for atom in residue.atoms()])
if signature in forcefield._templateSignatures:
for t in forcefield._templateSignatures[signature]:
if compiled.matchResidueToTemplate(residue, t, bondedToAtom, ignoreExternalBonds) is not None:
matchFound = True
if matchFound:
# Just copy the residue over. # Just copy the residue over.
for atom in residue.atoms(): for atom in residue.atoms():
...@@ -1104,28 +1070,17 @@ class Modeller(object): ...@@ -1104,28 +1070,17 @@ class Modeller(object):
newAtoms[atom] = newAtom newAtoms[atom] = newAtom
newPositions.append(deepcopy(self.positions[atom.index])) newPositions.append(deepcopy(self.positions[atom.index]))
else: else:
# There's no matching template. Try to find one that matches based on everything except # Record the corresponding atoms.
# extra points.
matches = compiled.matchResidueToTemplate(residue, template, bondedToAtom, ignoreExternalBonds, True)
template = None atomsNoEP = [a for a in residue.atoms() if a.element is not None]
residueNoEP = Residue(residue.name, residue.index, residue.chain, residue.id, residue.insertionCode) templateAtomsNoEP = [a for a in template.atoms if a.element is not None]
residueNoEP._atoms = [atom for atom in residue.atoms() if atom.element is not None] matchingAtoms = {}
if signature in forcefield._templateSignatures: for atom, match in zip(atomsNoEP, matches):
for t in forcefield._templateSignatures[signature]: templateAtomName = templateAtomsNoEP[match].name
if t in templatesNoEP: for templateAtom in template.atoms:
matches = compiled.matchResidueToTemplate(residueNoEP, templatesNoEP[t], bondedToAtomNoEP, ignoreExternalBonds) if templateAtom.name == templateAtomName:
if matches is not None: matchingAtoms[templateAtom] = atom
template = t;
# Record the corresponding atoms.
matchingAtoms = {}
for atom, match in zip(residueNoEP.atoms(), matches):
templateAtomName = templatesNoEP[t].atoms[match].name
for templateAtom in template.atoms:
if templateAtom.name == templateAtomName:
matchingAtoms[templateAtom] = atom
break
if template is None:
raise ValueError('Residue %d (%s) does not match any template defined by the ForceField.' % (residue.index+1, residue.name))
# Add the regular atoms. # Add the regular atoms.
......
...@@ -6,7 +6,7 @@ Simbios, the NIH National Center for Physics-Based Simulation of ...@@ -6,7 +6,7 @@ Simbios, the NIH National Center for Physics-Based Simulation of
Biological Structures at Stanford, funded under the NIH Roadmap for Biological Structures at Stanford, funded under the NIH Roadmap for
Medical Research, grant U54 GM072970. See https://simtk.org. Medical Research, grant U54 GM072970. See https://simtk.org.
Portions copyright (c) 2012-2013 Stanford University and the Authors. Portions copyright (c) 2012-2020 Stanford University and the Authors.
Authors: Peter Eastman Authors: Peter Eastman
Contributors: Robert McGibbon Contributors: Robert McGibbon
...@@ -289,7 +289,10 @@ class StateDataReporter(object): ...@@ -289,7 +289,10 @@ class StateDataReporter(object):
for i in range(system.getNumParticles()): for i in range(system.getNumParticles()):
if system.getParticleMass(i) > 0*unit.dalton: if system.getParticleMass(i) > 0*unit.dalton:
dof += 3 dof += 3
dof -= system.getNumConstraints() for i in range(system.getNumConstraints()):
p1, p2, distance = system.getConstraintParameters(i)
if system.getParticleMass(p1) > 0*unit.dalton or system.getParticleMass(p2) > 0*unit.dalton:
dof -= 1
if any(type(system.getForce(i)) == mm.CMMotionRemover for i in range(system.getNumForces())): if any(type(system.getForce(i)) == mm.CMMotionRemover for i in range(system.getNumForces())):
dof -= 3 dof -= 3
self._dof = dof self._dof = dof
......
...@@ -6,7 +6,7 @@ Simbios, the NIH National Center for Physics-Based Simulation of ...@@ -6,7 +6,7 @@ Simbios, the NIH National Center for Physics-Based Simulation of
Biological Structures at Stanford, funded under the NIH Roadmap for Biological Structures at Stanford, funded under the NIH Roadmap for
Medical Research, grant U54 GM072970. See https://simtk.org. Medical Research, grant U54 GM072970. See https://simtk.org.
Portions copyright (c) 2013-2015 Stanford University and the Authors. Portions copyright (c) 2013-2020 Stanford University and the Authors.
Authors: Peter Eastman Authors: Peter Eastman
Contributors: Contributors:
...@@ -107,3 +107,93 @@ class MTSIntegrator(CustomIntegrator): ...@@ -107,3 +107,93 @@ class MTSIntegrator(CustomIntegrator):
else: else:
self._createSubsteps(substeps, groups[1:]) self._createSubsteps(substeps, groups[1:])
self.addComputePerDof("v", "v+0.5*(dt/"+str(substeps)+")*f"+str(group)+"/m") self.addComputePerDof("v", "v+0.5*(dt/"+str(substeps)+")*f"+str(group)+"/m")
class MTSLangevinIntegrator(CustomIntegrator):
"""MTSLangevinIntegrator implements the BAOAB-RESPA multiple time step algorithm for
constant temperature dynamics.
This integrator allows different forces to be evaluated at different frequencies,
for example to evaluate the expensive, slowly changing forces less frequently than
the inexpensive, quickly changing forces.
To use it, you must first divide your forces into two or more groups (by calling
setForceGroup() on them) that should be evaluated at different frequencies. When
you create the integrator, you provide a tuple for each group specifying the index
of the force group and the frequency (as a fraction of the outermost time step) at
which to evaluate it. For example::
integrator = MTSLangevinIntegrator(300*kelvin, 1/picosecond, 4*femtoseconds, [(0,1), (1,2), (2,8)])
This specifies that the outermost time step is 4 fs, so each step of the integrator
will advance time by that much. It also says that force group 0 should be evaluated
once per time step, force group 1 should be evaluated twice per time step (every 2 fs),
and force group 2 should be evaluated eight times per time step (every 0.5 fs).
A common use of this algorithm is to evaluate reciprocal space nonbonded interactions
less often than the bonded and direct space nonbonded interactions. The following
example looks up the NonbondedForce, sets the reciprocal space interactions to their
own force group, and then creates an integrator that evaluates them once every 4 fs,
but all other interactions every 2 fs::
nonbonded = [f for f in system.getForces() if isinstance(f, NonbondedForce)][0]
nonbonded.setReciprocalSpaceForceGroup(1)
integrator = MTSLangevinIntegrator(300*kelvin, 1/picosecond, 4*femtoseconds, [(1,1), (0,2)])
For details, see Tuckerman et al., J. Chem. Phys. 97(3) pp. 1990-2001 (1992) and
Lagardere et al., J. Phys. Chem. Lett. 10(10) pp. 2593-2599 (2019).
"""
def __init__(self, temperature, friction, dt, groups):
"""Create an MTSLangevinIntegrator.
Parameters
----------
temperature : temperature
the temperature of the heat bath
friction : 1/temperature
the friction coefficient which couples the system to the heat bath
dt : time
The largest (outermost) integration time step to use
groups : list
A list of tuples defining the force groups. The first element of
each tuple is the force group index, and the second element is the
number of times that force group should be evaluated in one time step.
"""
if len(groups) == 0:
raise ValueError("No force groups specified")
groups = sorted(groups, key=lambda x: x[1])
CustomIntegrator.__init__(self, dt)
self.temperature = temperature
self.friction = friction
import math
self.addGlobalVariable("a", math.exp(-friction*dt))
self.addGlobalVariable("b", math.sqrt(1-math.exp(-2*friction*dt)))
from simtk.unit import MOLAR_GAS_CONSTANT_R
self.addGlobalVariable('kT', MOLAR_GAS_CONSTANT_R*temperature)
self.addPerDofVariable("x1", 0)
self.addUpdateContextState();
self._createSubsteps(1, groups)
self.addConstrainVelocities();
def _createSubsteps(self, parentSubsteps, groups):
group, substeps = groups[0]
stepsPerParentStep = substeps / parentSubsteps
if stepsPerParentStep < 1 or stepsPerParentStep != int(stepsPerParentStep):
raise ValueError("The number for substeps for each group must be a multiple of the number for the previous group")
stepsPerParentStep = int(stepsPerParentStep)
if group < 0 or group > 31:
raise ValueError("Force group must be between 0 and 31")
for i in range(stepsPerParentStep):
self.addComputePerDof("v", "v+0.5*(dt/"+str(substeps)+")*f"+str(group)+"/m")
if len(groups) == 1:
self.addComputePerDof("x", "x+(dt/"+str(2*substeps)+")*v")
self.addComputePerDof("v", "a*v + b*sqrt(kT/m)*gaussian")
self.addComputePerDof("x", "x+(dt/"+str(2*substeps)+")*v")
self.addComputePerDof("x1", "x")
self.addConstrainPositions();
self.addComputePerDof("v", "v+(x-x1)/(dt/"+str(substeps)+")");
self.addConstrainVelocities()
else:
self._createSubsteps(substeps, groups[1:])
self.addComputePerDof("v", "v+0.5*(dt/"+str(substeps)+")*f"+str(group)+"/m")
...@@ -267,7 +267,7 @@ class Quantity(object): ...@@ -267,7 +267,7 @@ class Quantity(object):
def __ne__(self, other): def __ne__(self, other):
""" """
""" """
return not self.__eq__(other) return not self == other
def __lt__(self, other): def __lt__(self, other):
"""Compares two quantities. """Compares two quantities.
......
...@@ -179,7 +179,7 @@ class Unit(object): ...@@ -179,7 +179,7 @@ class Unit(object):
return self.get_name() == other.get_name() return self.get_name() == other.get_name()
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self == other
def __lt__(self, other): def __lt__(self, other):
"""Compare two Units. """Compare two Units.
......
...@@ -18,6 +18,7 @@ SKIP_METHODS = [('State', 'getPositions'), ...@@ -18,6 +18,7 @@ SKIP_METHODS = [('State', 'getPositions'),
('State', 'getForces'), ('State', 'getForces'),
('StateBuilder',), ('StateBuilder',),
('Vec3',), ('Vec3',),
('OpenMMException',),
('AngleInfo',), ('AngleInfo',),
('ApplyAndersenThermostatKernel',), ('ApplyAndersenThermostatKernel',),
('ApplyConstraintsKernel',), ('ApplyConstraintsKernel',),
...@@ -63,11 +64,10 @@ SKIP_METHODS = [('State', 'getPositions'), ...@@ -63,11 +64,10 @@ SKIP_METHODS = [('State', 'getPositions'),
('InitializeForcesKernel',), ('InitializeForcesKernel',),
('IntegrateBrownianStepKernel',), ('IntegrateBrownianStepKernel',),
('IntegrateLangevinStepKernel',), ('IntegrateLangevinStepKernel',),
('IntegrateNoseHooverStepKernel',),
('IntegrateVariableLangevinStepKernel',), ('IntegrateVariableLangevinStepKernel',),
('IntegrateVariableVerletStepKernel',), ('IntegrateVariableVerletStepKernel',),
('IntegrateVerletStepKernel',), ('IntegrateVerletStepKernel',),
('IntegrateVelocityVerletStepKernel',),
('NoseHooverChainKernel',),
('IntegrateCustomStepKernel',), ('IntegrateCustomStepKernel',),
('Kernel',), ('Kernel',),
('KernelFactory',), ('KernelFactory',),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment