Commit df2b723d authored by leeping's avatar leeping
Browse files

Merge branch 'master' of github.com:SimTk/openmm

parents a0f16cc0 7be6e8fb
...@@ -87,6 +87,7 @@ vector<string> RPMDIntegrator::getKernelNames() { ...@@ -87,6 +87,7 @@ vector<string> RPMDIntegrator::getKernelNames() {
void RPMDIntegrator::setPositions(int copy, const vector<Vec3>& positions) { void RPMDIntegrator::setPositions(int copy, const vector<Vec3>& positions) {
kernel.getAs<IntegrateRPMDStepKernel>().setPositions(copy, positions); kernel.getAs<IntegrateRPMDStepKernel>().setPositions(copy, positions);
forcesAreValid = false;
hasSetPosition = true; hasSetPosition = true;
} }
...@@ -168,6 +169,8 @@ double RPMDIntegrator::computeKineticEnergy() { ...@@ -168,6 +169,8 @@ double RPMDIntegrator::computeKineticEnergy() {
} }
void RPMDIntegrator::step(int steps) { void RPMDIntegrator::step(int steps) {
if (context == NULL)
throw OpenMMException("This Integrator is not bound to a context!");
if (!hasSetPosition) { if (!hasSetPosition) {
// Initialize the positions from the context. // Initialize the positions from the context.
......
...@@ -23,13 +23,16 @@ extern "C" __global__ void contractPositions(mixed4* posq, mixed4* contracted) { ...@@ -23,13 +23,16 @@ extern "C" __global__ void contractPositions(mixed4* posq, mixed4* contracted) {
const int indexInBlock = threadIdx.x-blockStart; const int indexInBlock = threadIdx.x-blockStart;
__shared__ mixed3 q[2*THREAD_BLOCK_SIZE]; __shared__ mixed3 q[2*THREAD_BLOCK_SIZE];
__shared__ mixed3 temp[2*THREAD_BLOCK_SIZE]; __shared__ mixed3 temp[2*THREAD_BLOCK_SIZE];
__shared__ mixed2 w[NUM_COPIES]; __shared__ mixed2 w1[NUM_COPIES];
__shared__ mixed2 w2[NUM_CONTRACTED_COPIES];
mixed3* qreal = &q[blockStart]; mixed3* qreal = &q[blockStart];
mixed3* qimag = &q[blockStart+blockDim.x]; mixed3* qimag = &q[blockStart+blockDim.x];
mixed3* tempreal = &temp[blockStart]; mixed3* tempreal = &temp[blockStart];
mixed3* tempimag = &temp[blockStart+blockDim.x]; mixed3* tempimag = &temp[blockStart+blockDim.x];
if (threadIdx.x < NUM_COPIES) if (threadIdx.x < NUM_COPIES)
w[indexInBlock] = make_mixed2(cos(-indexInBlock*2*M_PI/NUM_COPIES), sin(-indexInBlock*2*M_PI/NUM_COPIES)); w1[indexInBlock] = make_mixed2(cos(-indexInBlock*2*M_PI/NUM_COPIES), sin(-indexInBlock*2*M_PI/NUM_COPIES));
if (threadIdx.x < NUM_CONTRACTED_COPIES)
w2[indexInBlock] = make_mixed2(cos(-indexInBlock*2*M_PI/NUM_CONTRACTED_COPIES), sin(-indexInBlock*2*M_PI/NUM_CONTRACTED_COPIES));
__syncthreads(); __syncthreads();
for (int particle = (blockIdx.x*blockDim.x+threadIdx.x)/NUM_COPIES; particle < NUM_ATOMS; particle += numBlocks) { for (int particle = (blockIdx.x*blockDim.x+threadIdx.x)/NUM_COPIES; particle < NUM_ATOMS; particle += numBlocks) {
// Load the particle position. // Load the particle position.
...@@ -41,6 +44,7 @@ extern "C" __global__ void contractPositions(mixed4* posq, mixed4* contracted) { ...@@ -41,6 +44,7 @@ extern "C" __global__ void contractPositions(mixed4* posq, mixed4* contracted) {
// Forward FFT. // Forward FFT.
__syncthreads(); __syncthreads();
mixed2* w = w1;
FFT_Q_FORWARD FFT_Q_FORWARD
if (NUM_CONTRACTED_COPIES > 1) { if (NUM_CONTRACTED_COPIES > 1) {
// Compress the data to remove high frequencies. // Compress the data to remove high frequencies.
...@@ -54,6 +58,7 @@ extern "C" __global__ void contractPositions(mixed4* posq, mixed4* contracted) { ...@@ -54,6 +58,7 @@ extern "C" __global__ void contractPositions(mixed4* posq, mixed4* contracted) {
qimag[indexInBlock] = tempimag[indexInBlock < start ? indexInBlock : indexInBlock+(NUM_COPIES-NUM_CONTRACTED_COPIES)]; qimag[indexInBlock] = tempimag[indexInBlock < start ? indexInBlock : indexInBlock+(NUM_COPIES-NUM_CONTRACTED_COPIES)];
} }
__syncthreads(); __syncthreads();
w = w2;
FFT_Q_BACKWARD FFT_Q_BACKWARD
} }
...@@ -74,13 +79,16 @@ extern "C" __global__ void contractForces(long long* force, long long* contracte ...@@ -74,13 +79,16 @@ extern "C" __global__ void contractForces(long long* force, long long* contracte
const mixed forceScale = 1/(mixed) 0x100000000; const mixed forceScale = 1/(mixed) 0x100000000;
__shared__ mixed3 f[2*THREAD_BLOCK_SIZE]; __shared__ mixed3 f[2*THREAD_BLOCK_SIZE];
__shared__ mixed3 temp[2*THREAD_BLOCK_SIZE]; __shared__ mixed3 temp[2*THREAD_BLOCK_SIZE];
__shared__ mixed2 w[NUM_COPIES]; __shared__ mixed2 w1[NUM_COPIES];
__shared__ mixed2 w2[NUM_CONTRACTED_COPIES];
mixed3* freal = &f[blockStart]; mixed3* freal = &f[blockStart];
mixed3* fimag = &f[blockStart+blockDim.x]; mixed3* fimag = &f[blockStart+blockDim.x];
mixed3* tempreal = &temp[blockStart]; mixed3* tempreal = &temp[blockStart];
mixed3* tempimag = &temp[blockStart+blockDim.x]; mixed3* tempimag = &temp[blockStart+blockDim.x];
if (threadIdx.x < NUM_COPIES) if (threadIdx.x < NUM_COPIES)
w[indexInBlock] = make_mixed2(cos(-indexInBlock*2*M_PI/NUM_COPIES), sin(-indexInBlock*2*M_PI/NUM_COPIES)); w1[indexInBlock] = make_mixed2(cos(-indexInBlock*2*M_PI/NUM_COPIES), sin(-indexInBlock*2*M_PI/NUM_COPIES));
if (threadIdx.x < NUM_CONTRACTED_COPIES)
w2[indexInBlock] = make_mixed2(cos(-indexInBlock*2*M_PI/NUM_CONTRACTED_COPIES), sin(-indexInBlock*2*M_PI/NUM_CONTRACTED_COPIES));
__syncthreads(); __syncthreads();
for (int particle = (blockIdx.x*blockDim.x+threadIdx.x)/NUM_COPIES; particle < NUM_ATOMS; particle += numBlocks) { for (int particle = (blockIdx.x*blockDim.x+threadIdx.x)/NUM_COPIES; particle < NUM_ATOMS; particle += numBlocks) {
// Load the force. // Load the force.
...@@ -94,6 +102,7 @@ extern "C" __global__ void contractForces(long long* force, long long* contracte ...@@ -94,6 +102,7 @@ extern "C" __global__ void contractForces(long long* force, long long* contracte
// Forward FFT. // Forward FFT.
mixed2* w = w2;
if (NUM_CONTRACTED_COPIES > 1) { if (NUM_CONTRACTED_COPIES > 1) {
FFT_F_FORWARD FFT_F_FORWARD
} }
...@@ -110,6 +119,7 @@ extern "C" __global__ void contractForces(long long* force, long long* contracte ...@@ -110,6 +119,7 @@ extern "C" __global__ void contractForces(long long* force, long long* contracte
fimag[indexInBlock] = (indexInBlock < end ? make_mixed3(0) : tempimag[indexInBlock-(NUM_COPIES-NUM_CONTRACTED_COPIES)]); fimag[indexInBlock] = (indexInBlock < end ? make_mixed3(0) : tempimag[indexInBlock-(NUM_COPIES-NUM_CONTRACTED_COPIES)]);
} }
__syncthreads(); __syncthreads();
w = w1;
FFT_F_BACKWARD FFT_F_BACKWARD
// Store results. // Store results.
......
...@@ -271,7 +271,7 @@ void testCMMotionRemoval() { ...@@ -271,7 +271,7 @@ void testCMMotionRemoval() {
pos += calcCM(state.getPositions(), system); pos += calcCM(state.getPositions(), system);
} }
pos *= 1.0/numCopies; pos *= 1.0/numCopies;
ASSERT_EQUAL_VEC(Vec3(), pos, 0.5); ASSERT_EQUAL_VEC(Vec3(0,0,0), pos, 0.5);
} }
} }
......
...@@ -23,13 +23,16 @@ __kernel void contractPositions(__global mixed4* posq, __global mixed4* contract ...@@ -23,13 +23,16 @@ __kernel void contractPositions(__global mixed4* posq, __global mixed4* contract
const int indexInBlock = get_local_id(0)-blockStart; const int indexInBlock = get_local_id(0)-blockStart;
__local mixed4 q[2*THREAD_BLOCK_SIZE]; __local mixed4 q[2*THREAD_BLOCK_SIZE];
__local mixed4 temp[2*THREAD_BLOCK_SIZE]; __local mixed4 temp[2*THREAD_BLOCK_SIZE];
__local mixed2 w[NUM_COPIES]; __local mixed2 w1[NUM_COPIES];
__local mixed2 w2[NUM_CONTRACTED_COPIES];
__local mixed4* qreal = &q[blockStart]; __local mixed4* qreal = &q[blockStart];
__local mixed4* qimag = &q[blockStart+get_local_size(0)]; __local mixed4* qimag = &q[blockStart+get_local_size(0)];
__local mixed4* tempreal = &temp[blockStart]; __local mixed4* tempreal = &temp[blockStart];
__local mixed4* tempimag = &temp[blockStart+get_local_size(0)]; __local mixed4* tempimag = &temp[blockStart+get_local_size(0)];
if (get_local_id(0) < NUM_COPIES) if (get_local_id(0) < NUM_COPIES)
w[indexInBlock] = (mixed2) (cos(-indexInBlock*2*M_PI/NUM_COPIES), sin(-indexInBlock*2*M_PI/NUM_COPIES)); w1[indexInBlock] = (mixed2) (cos(-indexInBlock*2*M_PI/NUM_COPIES), sin(-indexInBlock*2*M_PI/NUM_COPIES));
if (get_local_id(0) < NUM_CONTRACTED_COPIES)
w2[indexInBlock] = (mixed2) (cos(-indexInBlock*2*M_PI/NUM_CONTRACTED_COPIES), sin(-indexInBlock*2*M_PI/NUM_CONTRACTED_COPIES));
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
for (int particle = get_global_id(0)/NUM_COPIES; particle < NUM_ATOMS; particle += numBlocks) { for (int particle = get_global_id(0)/NUM_COPIES; particle < NUM_ATOMS; particle += numBlocks) {
// Load the particle position. // Load the particle position.
...@@ -41,6 +44,7 @@ __kernel void contractPositions(__global mixed4* posq, __global mixed4* contract ...@@ -41,6 +44,7 @@ __kernel void contractPositions(__global mixed4* posq, __global mixed4* contract
// Forward FFT. // Forward FFT.
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
__local mixed2* w = w1;
FFT_Q_FORWARD FFT_Q_FORWARD
if (NUM_CONTRACTED_COPIES > 1) { if (NUM_CONTRACTED_COPIES > 1) {
// Compress the data to remove high frequencies. // Compress the data to remove high frequencies.
...@@ -54,6 +58,7 @@ __kernel void contractPositions(__global mixed4* posq, __global mixed4* contract ...@@ -54,6 +58,7 @@ __kernel void contractPositions(__global mixed4* posq, __global mixed4* contract
qimag[indexInBlock] = tempimag[indexInBlock < start ? indexInBlock : indexInBlock+(NUM_COPIES-NUM_CONTRACTED_COPIES)]; qimag[indexInBlock] = tempimag[indexInBlock < start ? indexInBlock : indexInBlock+(NUM_COPIES-NUM_CONTRACTED_COPIES)];
} }
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
w = w2;
FFT_Q_BACKWARD FFT_Q_BACKWARD
} }
...@@ -73,13 +78,16 @@ __kernel void contractForces(__global real4* force, __global real4* contracted) ...@@ -73,13 +78,16 @@ __kernel void contractForces(__global real4* force, __global real4* contracted)
const int indexInBlock = get_local_id(0)-blockStart; const int indexInBlock = get_local_id(0)-blockStart;
__local mixed4 f[2*THREAD_BLOCK_SIZE]; __local mixed4 f[2*THREAD_BLOCK_SIZE];
__local mixed4 temp[2*THREAD_BLOCK_SIZE]; __local mixed4 temp[2*THREAD_BLOCK_SIZE];
__local mixed2 w[NUM_COPIES]; __local mixed2 w1[NUM_COPIES];
__local mixed2 w2[NUM_CONTRACTED_COPIES];
__local mixed4* freal = &f[blockStart]; __local mixed4* freal = &f[blockStart];
__local mixed4* fimag = &f[blockStart+get_local_size(0)]; __local mixed4* fimag = &f[blockStart+get_local_size(0)];
__local mixed4* tempreal = &temp[blockStart]; __local mixed4* tempreal = &temp[blockStart];
__local mixed4* tempimag = &temp[blockStart+get_local_size(0)]; __local mixed4* tempimag = &temp[blockStart+get_local_size(0)];
if (get_local_id(0) < NUM_COPIES) if (get_local_id(0) < NUM_COPIES)
w[indexInBlock] = (mixed2) (cos(-indexInBlock*2*M_PI/NUM_COPIES), sin(-indexInBlock*2*M_PI/NUM_COPIES)); w1[indexInBlock] = (mixed2) (cos(-indexInBlock*2*M_PI/NUM_COPIES), sin(-indexInBlock*2*M_PI/NUM_COPIES));
if (get_local_id(0) < NUM_CONTRACTED_COPIES)
w2[indexInBlock] = (mixed2) (cos(-indexInBlock*2*M_PI/NUM_CONTRACTED_COPIES), sin(-indexInBlock*2*M_PI/NUM_CONTRACTED_COPIES));
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
for (int particle = get_global_id(0)/NUM_COPIES; particle < NUM_ATOMS; particle += numBlocks) { for (int particle = get_global_id(0)/NUM_COPIES; particle < NUM_ATOMS; particle += numBlocks) {
// Load the force. // Load the force.
...@@ -93,6 +101,7 @@ __kernel void contractForces(__global real4* force, __global real4* contracted) ...@@ -93,6 +101,7 @@ __kernel void contractForces(__global real4* force, __global real4* contracted)
// Forward FFT. // Forward FFT.
__local mixed2* w = w2;
if (NUM_CONTRACTED_COPIES > 1) { if (NUM_CONTRACTED_COPIES > 1) {
FFT_F_FORWARD FFT_F_FORWARD
} }
...@@ -109,6 +118,7 @@ __kernel void contractForces(__global real4* force, __global real4* contracted) ...@@ -109,6 +118,7 @@ __kernel void contractForces(__global real4* force, __global real4* contracted)
fimag[indexInBlock] = (indexInBlock < end ? (mixed4) (0.0f, 0.0f, 0.0f, 0.0f) : tempimag[indexInBlock-(NUM_COPIES-NUM_CONTRACTED_COPIES)]); fimag[indexInBlock] = (indexInBlock < end ? (mixed4) (0.0f, 0.0f, 0.0f, 0.0f) : tempimag[indexInBlock-(NUM_COPIES-NUM_CONTRACTED_COPIES)]);
} }
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
w = w1;
FFT_F_BACKWARD FFT_F_BACKWARD
// Store results. // Store results.
......
...@@ -272,7 +272,7 @@ void testCMMotionRemoval() { ...@@ -272,7 +272,7 @@ void testCMMotionRemoval() {
pos += calcCM(state.getPositions(), system); pos += calcCM(state.getPositions(), system);
} }
pos *= 1.0/numCopies; pos *= 1.0/numCopies;
ASSERT_EQUAL_VEC(Vec3(), pos, 0.5); ASSERT_EQUAL_VEC(Vec3(0,0,0), pos, 0.5);
} }
} }
......
...@@ -155,7 +155,7 @@ void testCMMotionRemoval() { ...@@ -155,7 +155,7 @@ void testCMMotionRemoval() {
pos += calcCM(state.getPositions(), system); pos += calcCM(state.getPositions(), system);
} }
pos *= 1.0/numCopies; pos *= 1.0/numCopies;
ASSERT_EQUAL_VEC(Vec3(), pos, 0.5); ASSERT_EQUAL_VEC(Vec3(0,0,0), pos, 0.5);
} }
} }
......
...@@ -68,7 +68,7 @@ class WrapperGenerator: ...@@ -68,7 +68,7 @@ class WrapperGenerator:
def __init__(self, inputDirname, output): def __init__(self, inputDirname, output):
self.skipClasses = ['OpenMM::Vec3', 'OpenMM::XmlSerializer', 'OpenMM::Kernel', 'OpenMM::KernelImpl', 'OpenMM::KernelFactory', 'OpenMM::ContextImpl', 'OpenMM::SerializationNode', 'OpenMM::SerializationProxy'] self.skipClasses = ['OpenMM::Vec3', 'OpenMM::XmlSerializer', 'OpenMM::Kernel', 'OpenMM::KernelImpl', 'OpenMM::KernelFactory', 'OpenMM::ContextImpl', 'OpenMM::SerializationNode', 'OpenMM::SerializationProxy']
self.skipMethods = ['OpenMM::Context::getState', 'OpenMM::Platform::loadPluginsFromDirectory', 'OpenMM::Context::createCheckpoint', 'OpenMM::Context::loadCheckpoint', 'OpenMM::Context::getMolecules'] self.skipMethods = ['OpenMM::Context::getState', 'OpenMM::Platform::loadPluginsFromDirectory', 'OpenMM::Platform::getPluginLoadFailures', 'OpenMM::Context::createCheckpoint', 'OpenMM::Context::loadCheckpoint', 'OpenMM::Context::getMolecules']
self.hideClasses = ['Kernel', 'KernelImpl', 'KernelFactory', 'ContextImpl', 'SerializationNode', 'SerializationProxy'] self.hideClasses = ['Kernel', 'KernelImpl', 'KernelFactory', 'ContextImpl', 'SerializationNode', 'SerializationProxy']
self.nodeByID={} self.nodeByID={}
...@@ -398,6 +398,7 @@ extern OPENMM_EXPORT void %(name)s_insert(%(name)s* set, %(type)s value);""" % v ...@@ -398,6 +398,7 @@ extern OPENMM_EXPORT void %(name)s_insert(%(name)s* set, %(type)s value);""" % v
Unlike the C++ versions, the return value is allocated on the heap, and you must delete it yourself. */ Unlike the C++ versions, the return value is allocated on the heap, and you must delete it yourself. */
extern OPENMM_EXPORT OpenMM_State* OpenMM_Context_getState(const OpenMM_Context* target, int types, int enforcePeriodicBox); extern OPENMM_EXPORT OpenMM_State* OpenMM_Context_getState(const OpenMM_Context* target, int types, int enforcePeriodicBox);
extern OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_loadPluginsFromDirectory(const char* directory); extern OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_loadPluginsFromDirectory(const char* directory);
extern OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_getPluginLoadFailures();
extern OPENMM_EXPORT char* OpenMM_XmlSerializer_serializeSystem(const OpenMM_System* system); extern OPENMM_EXPORT char* OpenMM_XmlSerializer_serializeSystem(const OpenMM_System* system);
extern OPENMM_EXPORT char* OpenMM_XmlSerializer_serializeState(const OpenMM_State* state); extern OPENMM_EXPORT char* OpenMM_XmlSerializer_serializeState(const OpenMM_State* state);
extern OPENMM_EXPORT char* OpenMM_XmlSerializer_serializeIntegrator(const OpenMM_Integrator* integrator); extern OPENMM_EXPORT char* OpenMM_XmlSerializer_serializeIntegrator(const OpenMM_Integrator* integrator);
...@@ -804,6 +805,10 @@ OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_loadPluginsFromDirectory(const ...@@ -804,6 +805,10 @@ OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_loadPluginsFromDirectory(const
vector<string> result = Platform::loadPluginsFromDirectory(string(directory)); vector<string> result = Platform::loadPluginsFromDirectory(string(directory));
return reinterpret_cast<OpenMM_StringArray*>(new vector<string>(result)); return reinterpret_cast<OpenMM_StringArray*>(new vector<string>(result));
} }
OPENMM_EXPORT OpenMM_StringArray* OpenMM_Platform_getPluginLoadFailures() {
vector<string> result = Platform::getPluginLoadFailures();
return reinterpret_cast<OpenMM_StringArray*>(new vector<string>(result));
}
static char* createStringFromStream(stringstream& stream) { static char* createStringFromStream(stringstream& stream) {
int length = stream.str().size(); int length = stream.str().size();
char* result = (char*) malloc(length+1); char* result = (char*) malloc(length+1);
...@@ -1314,6 +1319,10 @@ MODULE OpenMM ...@@ -1314,6 +1319,10 @@ MODULE OpenMM
character(*) directory character(*) directory
type(OpenMM_StringArray) result type(OpenMM_StringArray) result
end subroutine end subroutine
subroutine OpenMM_Platform_getPluginLoadFailures(result)
use OpenMM_Types; implicit none
type(OpenMM_StringArray) result
end subroutine
subroutine OpenMM_XmlSerializer_serializeSystemToC(system, result, result_length) subroutine OpenMM_XmlSerializer_serializeSystemToC(system, result, result_length)
use iso_c_binding; use OpenMM_Types; implicit none use iso_c_binding; use OpenMM_Types; implicit none
type(OpenMM_System), intent(in) :: system type(OpenMM_System), intent(in) :: system
...@@ -1988,6 +1997,12 @@ OPENMM_EXPORT void openmm_platform_loadpluginsfromdirectory_(const char* directo ...@@ -1988,6 +1997,12 @@ OPENMM_EXPORT void openmm_platform_loadpluginsfromdirectory_(const char* directo
OPENMM_EXPORT void OPENMM_PLATFORM_LOADPLUGINSFROMDIRECTORY(const char* directory, OpenMM_StringArray*& result, int length) { OPENMM_EXPORT void OPENMM_PLATFORM_LOADPLUGINSFROMDIRECTORY(const char* directory, OpenMM_StringArray*& result, int length) {
result = OpenMM_Platform_loadPluginsFromDirectory(makeString(directory, length).c_str()); result = OpenMM_Platform_loadPluginsFromDirectory(makeString(directory, length).c_str());
} }
OPENMM_EXPORT void openmm_platform_getpluginloadfailures_(OpenMM_StringArray*& result) {
result = OpenMM_Platform_getPluginLoadFailures();
}
OPENMM_EXPORT void OPENMM_PLATFORM_GETPLUGINLOADFAILURES(OpenMM_StringArray*& result) {
result = OpenMM_Platform_getPluginLoadFailures();
}
OPENMM_EXPORT void openmm_xmlserializer_serializesystemtoc_(OpenMM_System*& system, char*& result, int& result_length) { OPENMM_EXPORT void openmm_xmlserializer_serializesystemtoc_(OpenMM_System*& system, char*& result, int& result_length) {
convertStringToChars(OpenMM_XmlSerializer_serializeSystem(system), result, result_length); convertStringToChars(OpenMM_XmlSerializer_serializeSystem(system), result, result_length);
} }
......
...@@ -1341,7 +1341,6 @@ class GBVIGenerator: ...@@ -1341,7 +1341,6 @@ class GBVIGenerator:
for bond in data.bonds: for bond in data.bonds:
type1 = data.atomType[data.atoms[bond.atom1]] type1 = data.atomType[data.atoms[bond.atom1]]
type2 = data.atomType[data.atoms[bond.atom2]] type2 = data.atomType[data.atoms[bond.atom2]]
hit = 0
for i in range(len(hbGenerator.types1)): for i in range(len(hbGenerator.types1)):
types1 = hbGenerator.types1[i] types1 = hbGenerator.types1[i]
types2 = hbGenerator.types2[i] types2 = hbGenerator.types2[i]
...@@ -1945,23 +1944,17 @@ class AmoebaBondGenerator: ...@@ -1945,23 +1944,17 @@ class AmoebaBondGenerator:
for bond in data.bonds: for bond in data.bonds:
type1 = data.atomType[data.atoms[bond.atom1]] type1 = data.atomType[data.atoms[bond.atom1]]
type2 = data.atomType[data.atoms[bond.atom2]] type2 = data.atomType[data.atoms[bond.atom2]]
hit = 0
for i in range(len(self.types1)): for i in range(len(self.types1)):
types1 = self.types1[i] types1 = self.types1[i]
types2 = self.types2[i] types2 = self.types2[i]
if (type1 in types1 and type2 in types2) or (type1 in types2 and type2 in types1): if (type1 in types1 and type2 in types2) or (type1 in types2 and type2 in types1):
bond.length = self.length[i] bond.length = self.length[i]
hit = 1
if bond.isConstrained: if bond.isConstrained:
sys.addConstraint(bond.atom1, bond.atom2, self.length[i]) sys.addConstraint(bond.atom1, bond.atom2, self.length[i])
elif self.k[i] != 0: elif self.k[i] != 0:
force.addBond(bond.atom1, bond.atom2, self.length[i], self.k[i]) force.addBond(bond.atom1, bond.atom2, self.length[i], self.k[i])
break break
if (hit == 0):
outputString = "AmoebaBondGenerator missing: types=[%5s %5s] atoms=[%6d %6d] " % (type1, type2, bond.atom1, bond.atom2)
raise ValueError(outputString)
parsers["AmoebaBondForce"] = AmoebaBondGenerator.parseElement parsers["AmoebaBondForce"] = AmoebaBondGenerator.parseElement
#============================================================================================= #=============================================================================================
...@@ -2093,13 +2086,11 @@ class AmoebaAngleGenerator: ...@@ -2093,13 +2086,11 @@ class AmoebaAngleGenerator:
type1 = data.atomType[data.atoms[angle[0]]] type1 = data.atomType[data.atoms[angle[0]]]
type2 = data.atomType[data.atoms[angle[1]]] type2 = data.atomType[data.atoms[angle[1]]]
type3 = data.atomType[data.atoms[angle[2]]] type3 = data.atomType[data.atoms[angle[2]]]
hit = 0
for i in range(len(self.types1)): for i in range(len(self.types1)):
types1 = self.types1[i] types1 = self.types1[i]
types2 = self.types2[i] types2 = self.types2[i]
types3 = self.types3[i] types3 = self.types3[i]
if (type1 in types1 and type2 in types2 and type3 in types3) or (type1 in types3 and type2 in types2 and type3 in types1): if (type1 in types1 and type2 in types2 and type3 in types3) or (type1 in types3 and type2 in types2 and type3 in types1):
hit = 1
if isConstrained and self.k[i] != 0.0: if isConstrained and self.k[i] != 0.0:
angleDict['idealAngle'] = self.angle[i][0] angleDict['idealAngle'] = self.angle[i][0]
addAngleConstraint(angle, self.angle[i][0]*math.pi/180.0, data, sys) addAngleConstraint(angle, self.angle[i][0]*math.pi/180.0, data, sys)
...@@ -2127,13 +2118,6 @@ class AmoebaAngleGenerator: ...@@ -2127,13 +2118,6 @@ class AmoebaAngleGenerator:
angleDict['idealAngle'] = angleValue angleDict['idealAngle'] = angleValue
force.addAngle(angle[0], angle[1], angle[2], angleValue, self.k[i]) force.addAngle(angle[0], angle[1], angle[2], angleValue, self.k[i])
break break
if (hit == 0):
outputString = "AmoebaAngleGenerator missing types: [%s %s %s] for atoms: " % (type1, type2, type3)
outputString += getAtomPrint( data, angle[0] ) + ' '
outputString += getAtomPrint( data, angle[1] ) + ' '
outputString += getAtomPrint( data, angle[2] )
outputString += " indices: [%6d %6d %6d]" % (angle[0], angle[1], angle[2])
raise ValueError(outputString)
#============================================================================================= #=============================================================================================
# createForcePostOpBendInPlaneAngle is called by AmoebaOutOfPlaneBendForce with the list of # createForcePostOpBendInPlaneAngle is called by AmoebaOutOfPlaneBendForce with the list of
...@@ -2169,7 +2153,6 @@ class AmoebaAngleGenerator: ...@@ -2169,7 +2153,6 @@ class AmoebaAngleGenerator:
type2 = data.atomType[data.atoms[angle[1]]] type2 = data.atomType[data.atoms[angle[1]]]
type3 = data.atomType[data.atoms[angle[2]]] type3 = data.atomType[data.atoms[angle[2]]]
hit = 0
for i in range(len(self.types1)): for i in range(len(self.types1)):
types1 = self.types1[i] types1 = self.types1[i]
...@@ -2177,7 +2160,6 @@ class AmoebaAngleGenerator: ...@@ -2177,7 +2160,6 @@ class AmoebaAngleGenerator:
types3 = self.types3[i] types3 = self.types3[i]
if (type1 in types1 and type2 in types2 and type3 in types3) or (type1 in types3 and type2 in types2 and type3 in types1): if (type1 in types1 and type2 in types2 and type3 in types3) or (type1 in types3 and type2 in types2 and type3 in types1):
hit = 1
angleDict['idealAngle'] = self.angle[i][0] angleDict['idealAngle'] = self.angle[i][0]
if (isConstrained and self.k[i] != 0.0): if (isConstrained and self.k[i] != 0.0):
addAngleConstraint(angle, self.angle[i][0]*math.pi/180.0, data, sys) addAngleConstraint(angle, self.angle[i][0]*math.pi/180.0, data, sys)
...@@ -2185,14 +2167,6 @@ class AmoebaAngleGenerator: ...@@ -2185,14 +2167,6 @@ class AmoebaAngleGenerator:
force.addAngle(angle[0], angle[1], angle[2], angle[3], self.angle[i][0], self.k[i]) force.addAngle(angle[0], angle[1], angle[2], angle[3], self.angle[i][0], self.k[i])
break break
if (hit == 0):
outputString = "AmoebaInPlaneAngleGenerator missing types: [%s %s %s] atoms: " % (type1, type2, type3)
outputString += getAtomPrint( data, angle[0] ) + ' '
outputString += getAtomPrint( data, angle[1] ) + ' '
outputString += getAtomPrint( data, angle[2] )
outputString += " indices: [%6d %6d %6d]" % (angle[0], angle[1], angle[2])
raise ValueError(outputString)
parsers["AmoebaAngleForce"] = AmoebaAngleGenerator.parseElement parsers["AmoebaAngleForce"] = AmoebaAngleGenerator.parseElement
#============================================================================================= #=============================================================================================
...@@ -2562,7 +2536,6 @@ class AmoebaTorsionGenerator: ...@@ -2562,7 +2536,6 @@ class AmoebaTorsionGenerator:
type3 = data.atomType[data.atoms[torsion[2]]] type3 = data.atomType[data.atoms[torsion[2]]]
type4 = data.atomType[data.atoms[torsion[3]]] type4 = data.atomType[data.atoms[torsion[3]]]
hit = 0
for i in range(len(self.types1)): for i in range(len(self.types1)):
types1 = self.types1[i] types1 = self.types1[i]
...@@ -2573,7 +2546,6 @@ class AmoebaTorsionGenerator: ...@@ -2573,7 +2546,6 @@ class AmoebaTorsionGenerator:
# match types in forward or reverse direction # match types in forward or reverse direction
if (type1 in types1 and type2 in types2 and type3 in types3 and type4 in types4) or (type4 in types1 and type3 in types2 and type2 in types3 and type1 in types4): if (type1 in types1 and type2 in types2 and type3 in types3 and type4 in types4) or (type4 in types1 and type3 in types2 and type2 in types3 and type1 in types4):
hit = 1
if self.t1[i][0] != 0: if self.t1[i][0] != 0:
force.addTorsion(torsion[0], torsion[1], torsion[2], torsion[3], 1, self.t1[i][1], self.t1[i][0]) force.addTorsion(torsion[0], torsion[1], torsion[2], torsion[3], 1, self.t1[i][1], self.t1[i][0])
if self.t2[i][0] != 0: if self.t2[i][0] != 0:
...@@ -2582,15 +2554,6 @@ class AmoebaTorsionGenerator: ...@@ -2582,15 +2554,6 @@ class AmoebaTorsionGenerator:
force.addTorsion(torsion[0], torsion[1], torsion[2], torsion[3], 3, self.t3[i][1], self.t3[i][0]) force.addTorsion(torsion[0], torsion[1], torsion[2], torsion[3], 3, self.t3[i][1], self.t3[i][0])
break break
if (hit == 0):
outputString = "AmoebaTorsionGenerator missing type: [%s %s %s %s] atoms: " % (type1, type2, type3, type4)
outputString += getAtomPrint( data, torsion[0] ) + ' '
outputString += getAtomPrint( data, torsion[1] ) + ' '
outputString += getAtomPrint( data, torsion[2] ) + ' '
outputString += getAtomPrint( data, torsion[3] )
outputString += " indices: [%6d %6d %6d %6d]" % (torsion[0], torsion[1], torsion[2], torsion[3])
raise ValueError(outputString)
parsers["AmoebaTorsionForce"] = AmoebaTorsionGenerator.parseElement parsers["AmoebaTorsionForce"] = AmoebaTorsionGenerator.parseElement
#============================================================================================= #=============================================================================================
......
...@@ -6,7 +6,7 @@ Simbios, the NIH National Center for Physics-Based Simulation of ...@@ -6,7 +6,7 @@ Simbios, the NIH National Center for Physics-Based Simulation of
Biological Structures at Stanford, funded under the NIH Roadmap for Biological Structures at Stanford, funded under the NIH Roadmap for
Medical Research, grant U54 GM072970. See https://simtk.org. Medical Research, grant U54 GM072970. See https://simtk.org.
Portions copyright (c) 2012-2014 University of Virginia and the Authors. Portions copyright (c) 2012-2015 University of Virginia and the Authors.
Authors: Christoph Klein, Michael R. Shirts Authors: Christoph Klein, Michael R. Shirts
Contributors: Jason M. Swails, Peter Eastman Contributors: Jason M. Swails, Peter Eastman
...@@ -31,7 +31,7 @@ USE OR OTHER DEALINGS IN THE SOFTWARE. ...@@ -31,7 +31,7 @@ USE OR OTHER DEALINGS IN THE SOFTWARE.
from __future__ import division from __future__ import division
from simtk.openmm import CustomGBForce, Discrete1DFunction from simtk.openmm import CustomGBForce, Continuous2DFunction
d0=[2.26685,2.32548,2.38397,2.44235,2.50057,2.55867,2.61663,2.67444, d0=[2.26685,2.32548,2.38397,2.44235,2.50057,2.55867,2.61663,2.67444,
2.73212,2.78965,2.84705,2.9043,2.96141,3.0184,3.07524,3.13196, 2.73212,2.78965,2.84705,2.9043,2.96141,3.0184,3.07524,3.13196,
...@@ -308,12 +308,11 @@ def GBSAGBnForce(solventDielectric=78.5, soluteDielectric=1, SA=None, ...@@ -308,12 +308,11 @@ def GBSAGBnForce(solventDielectric=78.5, soluteDielectric=1, SA=None,
custom.addPerParticleParameter("or") # Offset radius custom.addPerParticleParameter("or") # Offset radius
custom.addPerParticleParameter("sr") # Scaled offset radius custom.addPerParticleParameter("sr") # Scaled offset radius
custom.addTabulatedFunction("getd0", Discrete1DFunction(d0)) custom.addTabulatedFunction("getd0", Continuous2DFunction(21, 21, d0, 0.1, 0.2, 0.1, 0.2))
custom.addTabulatedFunction("getm0", Discrete1DFunction(m0)) custom.addTabulatedFunction("getm0", Continuous2DFunction(21, 21, m0, 0.1, 0.2, 0.1, 0.2))
custom.addComputedValue("I", "Ivdw+neckScale*Ineck;" custom.addComputedValue("I", "Ivdw+neckScale*Ineck;"
"Ineck=step(radius1+radius2+neckCut-r)*getm0(index)/(1+100*(r-getd0(index))^2+0.3*1000000*(r-getd0(index))^6);" "Ineck=step(radius1+radius2+neckCut-r)*getm0(radius1,radius2)/(1+100*(r-getd0(radius1,radius2))^2+0.3*1000000*(r-getd0(radius1,radius2))^6);"
"index = (radius2*200-20)*21 + (radius1*200-20);"
"Ivdw=step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(r-sr2^2/r)*(1/(U^2)-1/(L^2))+0.5*log(L/U)/r);" "Ivdw=step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(r-sr2^2/r)*(1/(U^2)-1/(L^2))+0.5*log(L/U)/r);"
"U=r+sr2;" "U=r+sr2;"
"L=max(or1, D);" "L=max(or1, D);"
...@@ -350,12 +349,11 @@ def GBSAGBn2Force(solventDielectric=78.5, soluteDielectric=1, SA=None, ...@@ -350,12 +349,11 @@ def GBSAGBn2Force(solventDielectric=78.5, soluteDielectric=1, SA=None,
custom.addPerParticleParameter("beta") custom.addPerParticleParameter("beta")
custom.addPerParticleParameter("gamma") custom.addPerParticleParameter("gamma")
custom.addTabulatedFunction("getd0", Discrete1DFunction(d0)) custom.addTabulatedFunction("getd0", Continuous2DFunction(21, 21, d0, 0.1, 0.2, 0.1, 0.2))
custom.addTabulatedFunction("getm0", Discrete1DFunction(m0)) custom.addTabulatedFunction("getm0", Continuous2DFunction(21, 21, m0, 0.1, 0.2, 0.1, 0.2))
custom.addComputedValue("I", "Ivdw+neckScale*Ineck;" custom.addComputedValue("I", "Ivdw+neckScale*Ineck;"
"Ineck=step(radius1+radius2+neckCut-r)*getm0(index)/(1+100*(r-getd0(index))^2+0.3*1000000*(r-getd0(index))^6);" "Ineck=step(radius1+radius2+neckCut-r)*getm0(radius1,radius2)/(1+100*(r-getd0(radius1,radius2))^2+0.3*1000000*(r-getd0(radius1,radius2))^6);"
"index = (radius2*200-20)*21 + (radius1*200-20);"
"Ivdw=step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(r-sr2^2/r)*(1/(U^2)-1/(L^2))+0.5*log(L/U)/r);" "Ivdw=step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(r-sr2^2/r)*(1/(U^2)-1/(L^2))+0.5*log(L/U)/r);"
"U=r+sr2;" "U=r+sr2;"
"L=max(or1, D);" "L=max(or1, D);"
......
...@@ -135,7 +135,7 @@ class Simulation(object): ...@@ -135,7 +135,7 @@ class Simulation(object):
if endStep is None: if endStep is None:
endStep = sys.maxint endStep = sys.maxint
nextReport = [None]*len(self.reporters) nextReport = [None]*len(self.reporters)
while self.currentStep < endStep: while self.currentStep < endStep and (endTime is None or datetime.now() < endTime):
nextSteps = endStep-self.currentStep nextSteps = endStep-self.currentStep
anyReport = False anyReport = False
for i, reporter in enumerate(self.reporters): for i, reporter in enumerate(self.reporters):
......
...@@ -55,7 +55,7 @@ class StateDataReporter(object): ...@@ -55,7 +55,7 @@ class StateDataReporter(object):
""" """
def __init__(self, file, reportInterval, step=False, time=False, potentialEnergy=False, kineticEnergy=False, totalEnergy=False, temperature=False, volume=False, density=False, def __init__(self, file, reportInterval, step=False, time=False, potentialEnergy=False, kineticEnergy=False, totalEnergy=False, temperature=False, volume=False, density=False,
progress=False, remainingTime=False, speed=False, separator=',', systemMass=None, totalSteps=None): progress=False, remainingTime=False, speed=False, elapsedTime=False, separator=',', systemMass=None, totalSteps=None):
"""Create a StateDataReporter. """Create a StateDataReporter.
Parameters: Parameters:
...@@ -74,6 +74,7 @@ class StateDataReporter(object): ...@@ -74,6 +74,7 @@ class StateDataReporter(object):
- remainingTime (boolean=False) Whether to write an estimate of the remaining clock time until - remainingTime (boolean=False) Whether to write an estimate of the remaining clock time until
completion to the file. If this is True, you must also specify totalSteps. completion to the file. If this is True, you must also specify totalSteps.
- speed (bool=False) Whether to write an estimate of the simulation speed in ns/day to the file - speed (bool=False) Whether to write an estimate of the simulation speed in ns/day to the file
- elapsedTime (bool=False) Whether to write the elapsed time of the simulation in seconds to the file.
- separator (string=',') The separator to use between columns in the file - separator (string=',') The separator to use between columns in the file
- systemMass (mass=None) The total mass to use for the system when reporting density. If this is - systemMass (mass=None) The total mass to use for the system when reporting density. If this is
None (the default), the system mass is computed by summing the masses of all particles. This None (the default), the system mass is computed by summing the masses of all particles. This
...@@ -113,6 +114,7 @@ class StateDataReporter(object): ...@@ -113,6 +114,7 @@ class StateDataReporter(object):
self._progress = progress self._progress = progress
self._remainingTime = remainingTime self._remainingTime = remainingTime
self._speed = speed self._speed = speed
self._elapsedTime = elapsedTime
self._separator = separator self._separator = separator
self._totalMass = systemMass self._totalMass = systemMass
self._totalSteps = totalSteps self._totalSteps = totalSteps
...@@ -207,6 +209,8 @@ class StateDataReporter(object): ...@@ -207,6 +209,8 @@ class StateDataReporter(object):
values.append('%.3g' % (elapsedNs/elapsedDays)) values.append('%.3g' % (elapsedNs/elapsedDays))
else: else:
values.append('--') values.append('--')
if self._elapsedTime:
values.append(time.time() - self._initialClockTime)
if self._remainingTime: if self._remainingTime:
elapsedSeconds = clockTime-self._initialClockTime elapsedSeconds = clockTime-self._initialClockTime
elapsedSteps = simulation.currentStep-self._initialSteps elapsedSteps = simulation.currentStep-self._initialSteps
...@@ -284,6 +288,8 @@ class StateDataReporter(object): ...@@ -284,6 +288,8 @@ class StateDataReporter(object):
headers.append('Density (g/mL)') headers.append('Density (g/mL)')
if self._speed: if self._speed:
headers.append('Speed (ns/day)') headers.append('Speed (ns/day)')
if self._elapsedTime:
headers.append('Elapsed Time (s)')
if self._remainingTime: if self._remainingTime:
headers.append('Time Remaining') headers.append('Time Remaining')
return headers return headers
......
...@@ -59,6 +59,14 @@ class Topology(object): ...@@ -59,6 +59,14 @@ class Topology(object):
self._bonds = [] self._bonds = []
self._periodicBoxVectors = None self._periodicBoxVectors = None
def __repr__(self):
nchains = len(self._chains)
nres = sum(1 for r in self.residues())
natom = sum(1 for a in self.atoms())
nbond = len(self._bonds)
return '<%s; %d chains, %d residues, %d atoms, %d bonds>' % (
type(self).__name__, nchains, nres, natom, nbond)
def addChain(self, id=None): def addChain(self, id=None):
"""Create a new Chain and add it to the Topology. """Create a new Chain and add it to the Topology.
...@@ -291,6 +299,9 @@ class Chain(object): ...@@ -291,6 +299,9 @@ class Chain(object):
for atom in residue._atoms: for atom in residue._atoms:
yield atom yield atom
def __len__(self):
return len(self._residues)
class Residue(object): class Residue(object):
"""A Residue object represents a residue within a Topology.""" """A Residue object represents a residue within a Topology."""
def __init__(self, name, index, chain, id): def __init__(self, name, index, chain, id):
...@@ -309,6 +320,9 @@ class Residue(object): ...@@ -309,6 +320,9 @@ class Residue(object):
"""Iterate over all Atoms in the Residue.""" """Iterate over all Atoms in the Residue."""
return iter(self._atoms) return iter(self._atoms)
def __len__(self):
return len(self._atoms)
class Atom(object): class Atom(object):
"""An Atom object represents a residue within a Topology.""" """An Atom object represents a residue within a Topology."""
......
...@@ -101,7 +101,7 @@ def run_tests(): ...@@ -101,7 +101,7 @@ def run_tests():
d = f1-f2 d = f1-f2
error = sqrt((d[0]*d[0]+d[1]*d[1]+d[2]*d[2])/(f1[0]*f1[0]+f1[1]*f1[1]+f1[2]*f1[2])) error = sqrt((d[0]*d[0]+d[1]*d[1]+d[2]*d[2])/(f1[0]*f1[0]+f1[1]*f1[1]+f1[2]*f1[2]))
errors.append(error) errors.append(error)
print("{} vs. {}: {:g}".format(Platform.getPlatform(j).getName(), print("{0} vs. {1}: {2:g}".format(Platform.getPlatform(j).getName(),
Platform.getPlatform(i).getName(), Platform.getPlatform(i).getName(),
sorted(errors)[len(errors)//2])) sorted(errors)[len(errors)//2]))
......
#!/bin/env python #!/bin/env python
""" """
Module simtk.unit.basedimension Module simtk.unit.basedimension
...@@ -34,6 +32,7 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR ...@@ -34,6 +32,7 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE. USE OR OTHER DEALINGS IN THE SOFTWARE.
""" """
from __future__ import print_function, division
__author__ = "Christopher M. Bruns" __author__ = "Christopher M. Bruns"
__version__ = "0.6" __version__ = "0.6"
......
...@@ -33,6 +33,7 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR ...@@ -33,6 +33,7 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE. USE OR OTHER DEALINGS IN THE SOFTWARE.
""" """
from __future__ import print_function, division, absolute_import
__author__ = "Christopher M. Bruns" __author__ = "Christopher M. Bruns"
__version__ = "0.6" __version__ = "0.6"
...@@ -44,6 +45,7 @@ class BaseUnit(object): ...@@ -44,6 +45,7 @@ class BaseUnit(object):
For example, meter_base_unit could be a BaseUnit for the length dimension. For example, meter_base_unit could be a BaseUnit for the length dimension.
The BaseUnit class is used internally in the more general Unit class. The BaseUnit class is used internally in the more general Unit class.
''' '''
__array_priority__ = 100
def __init__(self, base_dim, name, symbol): def __init__(self, base_dim, name, symbol):
"""Creates a new BaseUnit. """Creates a new BaseUnit.
...@@ -127,7 +129,7 @@ class BaseUnit(object): ...@@ -127,7 +129,7 @@ class BaseUnit(object):
self._conversion_factor_to_by_name[other.name] = factor self._conversion_factor_to_by_name[other.name] = factor
for (unit, cfac) in other._conversion_factor_to.items(): for (unit, cfac) in other._conversion_factor_to.items():
if unit is self: continue if unit is self: continue
if self._conversion_factor_to.has_key(unit): continue if unit in self._conversion_factor_to: continue
self._conversion_factor_to[unit] = factor * cfac self._conversion_factor_to[unit] = factor * cfac
unit._conversion_factor_to[self] = pow(factor * cfac, -1) unit._conversion_factor_to[self] = pow(factor * cfac, -1)
self._conversion_factor_to_by_name[unit.name] = factor * cfac self._conversion_factor_to_by_name[unit.name] = factor * cfac
...@@ -138,7 +140,7 @@ class BaseUnit(object): ...@@ -138,7 +140,7 @@ class BaseUnit(object):
other._conversion_factor_to_by_name[self.name] = invFac other._conversion_factor_to_by_name[self.name] = invFac
for (unit, cfac) in self._conversion_factor_to.items(): for (unit, cfac) in self._conversion_factor_to.items():
if unit is other: continue if unit is other: continue
if other._conversion_factor_to.has_key(unit): continue if unit in other._conversion_factor_to: continue
other._conversion_factor_to[unit] = invFac * cfac other._conversion_factor_to[unit] = invFac * cfac
unit._conversion_factor_to[other] = pow(invFac * cfac, -1) unit._conversion_factor_to[other] = pow(invFac * cfac, -1)
other._conversion_factor_to_by_name[unit.name] = invFac * cfac other._conversion_factor_to_by_name[unit.name] = invFac * cfac
......
...@@ -29,13 +29,12 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR ...@@ -29,13 +29,12 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE. USE OR OTHER DEALINGS IN THE SOFTWARE.
""" """
from __future__ import print_function, division, absolute_import
from __future__ import division
__author__ = "Christopher M. Bruns" __author__ = "Christopher M. Bruns"
__version__ = "0.5" __version__ = "0.5"
from unit_definitions import * from .unit_definitions import *
################# #################
### CONSTANTS ### ### CONSTANTS ###
......
...@@ -28,6 +28,7 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR ...@@ -28,6 +28,7 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE. USE OR OTHER DEALINGS IN THE SOFTWARE.
""" """
from __future__ import print_function, division, absolute_import
import sys import sys
...@@ -41,9 +42,9 @@ def eye(size): ...@@ -41,9 +42,9 @@ def eye(size):
[0, 0, 1]] [0, 0, 1]]
""" """
result = [] result = []
for row in range(0, size): for row in range(size):
r = [] r = []
for col in range(0, size): for col in range(size):
if row == col: if row == col:
r.append(1) r.append(1)
else: else:
...@@ -63,9 +64,9 @@ def zeros(m, n=None): ...@@ -63,9 +64,9 @@ def zeros(m, n=None):
if n is None: if n is None:
n = m n = m
result = [] result = []
for row in range(0, m): for row in range(m):
r = [] r = []
for col in range(0, n): for col in range(n):
r.append(0) r.append(0)
result.append(r) result.append(r)
return MyMatrix(result) return MyMatrix(result)
...@@ -171,7 +172,7 @@ class MyMatrix(MyVector): ...@@ -171,7 +172,7 @@ class MyMatrix(MyVector):
def __str__(self): def __str__(self):
result = "" result = ""
start_char = "[" start_char = "["
for m in range(0, self.numRows()): for m in range(self.numRows()):
result += start_char result += start_char
result += str(self[m]) result += str(self[m])
if m < self.numRows() - 1: if m < self.numRows() - 1:
...@@ -226,9 +227,9 @@ class MyMatrix(MyVector): ...@@ -226,9 +227,9 @@ class MyMatrix(MyVector):
if self.numCols() != r: if self.numCols() != r:
raise ArithmeticError("Matrix multplication size mismatch (%d vs %d)" % (self.numCols(), r)) raise ArithmeticError("Matrix multplication size mismatch (%d vs %d)" % (self.numCols(), r))
result = zeros(m, n) result = zeros(m, n)
for i in range(0, m): for i in range(m):
for j in range(0, n): for j in range(n):
for k in range(0, r): for k in range(r):
result[i][j] += self[i][k]*rhs[k][j] result[i][j] += self[i][k]*rhs[k][j]
return result return result
...@@ -245,8 +246,8 @@ class MyMatrix(MyVector): ...@@ -245,8 +246,8 @@ class MyMatrix(MyVector):
assert len(rhs) == m assert len(rhs) == m
assert len(rhs[0]) == n assert len(rhs[0]) == n
result = zeros(m,n) result = zeros(m,n)
for i in range(0,m): for i in range(m):
for j in range(0,n): for j in range(n):
result[i][j] = self[i][j] + rhs[i][j] result[i][j] = self[i][j] + rhs[i][j]
return result return result
...@@ -263,8 +264,8 @@ class MyMatrix(MyVector): ...@@ -263,8 +264,8 @@ class MyMatrix(MyVector):
assert len(rhs) == m assert len(rhs) == m
assert len(rhs[0]) == n assert len(rhs[0]) == n
result = zeros(m,n) result = zeros(m,n)
for i in range(0,m): for i in range(m):
for j in range(0,n): for j in range(n):
result[i][j] = self[i][j] - rhs[i][j] result[i][j] = self[i][j] - rhs[i][j]
return result return result
...@@ -275,8 +276,8 @@ class MyMatrix(MyVector): ...@@ -275,8 +276,8 @@ class MyMatrix(MyVector):
m = self.numRows() m = self.numRows()
n = self.numCols() n = self.numCols()
result = zeros(m, n) result = zeros(m, n)
for i in range(0,m): for i in range(m):
for j in range(0,n): for j in range(n):
result[i][j] = -self[i][j] result[i][j] = -self[i][j]
return result return result
...@@ -358,7 +359,7 @@ class MyMatrix(MyVector): ...@@ -358,7 +359,7 @@ class MyMatrix(MyVector):
ipiv[icol] += 1 ipiv[icol] += 1
# We now have the pivot element, so we interchange rows... # We now have the pivot element, so we interchange rows...
if irow != icol: if irow != icol:
for l in range(0,n): for l in range(n):
temp = a[irow][l] temp = a[irow][l]
a[irow][l] = a[icol][l] a[irow][l] = a[icol][l]
a[icol][l] = temp a[icol][l] = temp
...@@ -368,20 +369,20 @@ class MyMatrix(MyVector): ...@@ -368,20 +369,20 @@ class MyMatrix(MyVector):
raise ArithmeticError("Cannot invert singular matrix") raise ArithmeticError("Cannot invert singular matrix")
pivinv = 1.0/a[icol][icol] pivinv = 1.0/a[icol][icol]
a[icol][icol] = 1.0 a[icol][icol] = 1.0
for l in range(0,n): for l in range(n):
a[icol][l] *= pivinv a[icol][l] *= pivinv
for ll in range(0,n): # next we reduce the rows for ll in range(n): # next we reduce the rows
if ll == icol: if ll == icol:
continue # except the pivot one, of course continue # except the pivot one, of course
dum = a[ll][icol] dum = a[ll][icol]
a[ll][icol] = 0.0 a[ll][icol] = 0.0
for l in range(0,n): for l in range(n):
a[ll][l] -= a[icol][l]*dum a[ll][l] -= a[icol][l]*dum
# Unscramble the permuted columns # Unscramble the permuted columns
for l in range(n-1, -1, -1): for l in range(n-1, -1, -1):
if indxr[l] == indxc[l]: if indxr[l] == indxc[l]:
continue continue
for k in range(0,n): for k in range(n):
temp = a[k][indxr[l]] temp = a[k][indxr[l]]
a[k][indxr[l]] = a[k][indxc[l]] a[k][indxr[l]] = a[k][indxc[l]]
a[k][indxc[l]] = temp a[k][indxc[l]] = temp
...@@ -415,7 +416,7 @@ class MyMatrixTranspose(MyMatrix): ...@@ -415,7 +416,7 @@ class MyMatrixTranspose(MyMatrix):
return MyVector(result) return MyVector(result)
def __setitem__(self, key, rhs): def __setitem__(self, key, rhs):
for n in range(0, len(self.data)): for n in range(len(self.data)):
self.data[n][key] = rhs[n] self.data[n][key] = rhs[n]
def __str__(self): def __str__(self):
...@@ -423,11 +424,11 @@ class MyMatrixTranspose(MyMatrix): ...@@ -423,11 +424,11 @@ class MyMatrixTranspose(MyMatrix):
return "[[]]" return "[[]]"
start_char = "[" start_char = "["
result = "" result = ""
for m in range(0, len(self.data[0])): for m in range(len(self.data[0])):
result += start_char result += start_char
result += "[" result += "["
sep_char = "" sep_char = ""
for n in range(0, len(self.data)): for n in range(len(self.data)):
result += sep_char result += sep_char
result += str(self.data[n][m]) result += str(self.data[n][m])
sep_char = ", " sep_char = ", "
...@@ -443,11 +444,11 @@ class MyMatrixTranspose(MyMatrix): ...@@ -443,11 +444,11 @@ class MyMatrixTranspose(MyMatrix):
return "MyMatrixTranspose([[]])" return "MyMatrixTranspose([[]])"
start_char = "[" start_char = "["
result = 'MyMatrixTranspose(' result = 'MyMatrixTranspose('
for m in range(0, len(self.data[0])): for m in range(len(self.data[0])):
result += start_char result += start_char
result += "[" result += "["
sep_char = "" sep_char = ""
for n in range(0, len(self.data)): for n in range(len(self.data)):
result += sep_char result += sep_char
result += repr(self.data[n][m]) result += repr(self.data[n][m])
sep_char = ", " sep_char = ", "
......
...@@ -29,12 +29,13 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR ...@@ -29,12 +29,13 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE. USE OR OTHER DEALINGS IN THE SOFTWARE.
""" """
from __future__ import print_function, division, absolute_import
__author__ = "Christopher M. Bruns" __author__ = "Christopher M. Bruns"
__version__ = "0.6" __version__ = "0.6"
from baseunit import BaseUnit from .baseunit import BaseUnit
from unit import Unit, ScaledUnit from .unit import Unit, ScaledUnit
import sys import sys
################### ###################
......
...@@ -67,8 +67,7 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR ...@@ -67,8 +67,7 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE. USE OR OTHER DEALINGS IN THE SOFTWARE.
""" """
from __future__ import division, print_function, absolute_import
from __future__ import division
__author__ = "Christopher M. Bruns" __author__ = "Christopher M. Bruns"
__version__ = "0.5" __version__ = "0.5"
...@@ -76,8 +75,8 @@ __version__ = "0.5" ...@@ -76,8 +75,8 @@ __version__ = "0.5"
import math import math
import copy import copy
from standard_dimensions import * from .standard_dimensions import *
from unit import Unit, is_unit, dimensionless from .unit import Unit, is_unit, dimensionless
class Quantity(object): class Quantity(object):
"""Physical quantity, such as 1.3 meters per second. """Physical quantity, such as 1.3 meters per second.
...@@ -92,18 +91,8 @@ class Quantity(object): ...@@ -92,18 +91,8 @@ class Quantity(object):
Note - unit conversions will cause tuples to be converted to lists Note - unit conversions will cause tuples to be converted to lists
4 - lists of tuples of numbers, lists of lists of ... etc. of numbers 4 - lists of tuples of numbers, lists of lists of ... etc. of numbers
5 - numpy.arrays 5 - numpy.arrays
Create numpy.arrays with units using the Quantity constructor, not the
multiply operator. e.g.
Quantity(numpy.array([1,2,3]), centimeters) # correct
*NOT*
numpy.array([1,2,3]) * centimeters # won't work
because numpy.arrays already overload the multiply operator for EVERYTHING.
""" """
__array_priority__ = 99
def __init__(self, value=None, unit=None): def __init__(self, value=None, unit=None):
""" """
...@@ -136,7 +125,7 @@ class Quantity(object): ...@@ -136,7 +125,7 @@ class Quantity(object):
if len(value) < 1: if len(value) < 1:
unit = dimensionless unit = dimensionless
else: else:
first_item = iter(value).next() first_item = next(iter(value))
# Avoid infinite recursion for string, because a one-character # Avoid infinite recursion for string, because a one-character
# string is its own first element # string is its own first element
try: try:
...@@ -613,6 +602,9 @@ class Quantity(object): ...@@ -613,6 +602,9 @@ class Quantity(object):
""" """
return bool(self._value) return bool(self._value)
def __bool__(self):
return bool(self._value)
def __complex__(self): def __complex__(self):
return Quantity(complex(self._value), self.unit) return Quantity(complex(self._value), self.unit)
def __float__(self): def __float__(self):
...@@ -713,7 +705,7 @@ class Quantity(object): ...@@ -713,7 +705,7 @@ class Quantity(object):
else: else:
for i in range(len(value)): for i in range(len(value)):
value[i] = factor*value[i] value[i] = factor*value[i]
except TypeError as ex: except TypeError:
if isinstance(value, tuple): if isinstance(value, tuple):
value = tuple([self._scale_sequence(x, factor, post_multiply) for x in value]) value = tuple([self._scale_sequence(x, factor, post_multiply) for x in value])
else: else:
...@@ -826,7 +818,6 @@ def _is_string(x): ...@@ -826,7 +818,6 @@ def _is_string(x):
except StopIteration: except StopIteration:
return False return False
# run module directly for testing # run module directly for testing
if __name__=='__main__': if __name__=='__main__':
# Test the examples in the docstrings # Test the examples in the docstrings
......
...@@ -31,11 +31,12 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR ...@@ -31,11 +31,12 @@ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE. USE OR OTHER DEALINGS IN THE SOFTWARE.
""" """
from __future__ import division, print_function, absolute_import
__author__ = "Christopher M. Bruns" __author__ = "Christopher M. Bruns"
__version__ = "0.6" __version__ = "0.6"
from basedimension import BaseDimension from .basedimension import BaseDimension
################## ##################
### DIMENSIONS ### ### DIMENSIONS ###
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment