# testInstallation.py: check that OpenMM can be imported, compute forces for
# input.pdb with every installed platform, and report how closely the
# platforms agree with each other.
# First make sure OpenMM is installed.  If it cannot be imported, explain why
# and exit with a non-zero status so scripts and CI jobs see the failure.

import sys
try:
    try:
        # Newer OpenMM releases install under the top-level "openmm" package.
        from openmm.app import *
        from openmm import *
        from openmm.unit import *
    except ImportError:
        # Older releases only provide the "simtk" namespace.
        from simtk.openmm.app import *
        from simtk.openmm import *
        from simtk.unit import *
except ImportError as err:
    # str(err) works on every Python version; BaseException.message does not.
    print("Failed to import OpenMM packages: %s" % err)
    print("Make sure OpenMM is installed and the library path is set correctly.")
    sys.exit(1)

# Create the System that every platform is tested against: read the
# structure from input.pdb, parameterize it with the amber99sb.xml and
# tip3p.xml force field files, and use PME for nonbonded interactions with
# a 1 nm cutoff and HBonds constraints.

pdb = PDBFile('input.pdb')
forcefield = ForceField('amber99sb.xml', 'tip3p.xml')
system = forcefield.createSystem(
    pdb.topology,
    constraints=HBonds,
    nonbondedMethod=PME,
    nonbondedCutoff=1*nanometer,
)

# List all installed platforms and compute forces with each one.
# forces[i] holds platform i's forces, or None if that platform failed.

numPlatforms = Platform.getNumPlatforms()
print("There are %d Platforms available:" % numPlatforms)
print("")
forces = [None]*numPlatforms
for i in range(numPlatforms):
    platform = Platform.getPlatform(i)
    # Start the status line without a newline; the success/error message
    # below completes it.  Flush now so that, if the platform's native code
    # crashes the process, the output still shows which platform was running.
    sys.stdout.write("%d %s " % (i+1, platform.getName()))
    sys.stdout.flush()
    # Each platform gets its own integrator: an OpenMM Integrator can be bound
    # to only one Context.
    integrator = LangevinIntegrator(300*kelvin, 1/picosecond, 0.002*picoseconds)
    try:
        simulation = Simulation(pdb.topology, system, integrator, platform)
        simulation.context.setPositions(pdb.positions)
        forces[i] = simulation.context.getState(getForces=True).getForces()
        del simulation
        print("- Successfully computed forces")
    except Exception as err:
        # Catch Exception rather than everything, so Ctrl-C still stops the
        # script.  Report the reason, since that is what the user needs to fix.
        print("- Error computing forces with %s platform: %s" % (platform.getName(), err))

# See how well the platforms agree.  For every pair of platforms that both
# produced forces, report the median over particles of the relative force
# error |f_i - f_j| / |f_i|, where f_i comes from the later platform i.

if numPlatforms > 1:
    print("")
    print("Median difference in forces between platforms:")
    print("")
    for i in range(numPlatforms):
        for j in range(i):
            if forces[i] is None or forces[j] is None:
                # At least one of the two platforms failed above.
                continue
            errors = []
            for f1, f2 in zip(forces[i], forces[j]):
                d = f1-f2
                error = sqrt((d[0]*d[0]+d[1]*d[1]+d[2]*d[2])/(f1[0]*f1[0]+f1[1]*f1[1]+f1[2]*f1[2]))
                errors.append(error)
            if not errors:
                continue
            # Floor division keeps the index an int on Python 3 as well as
            # Python 2.  For an even count this picks the upper of the two
            # middle values.
            median = sorted(errors)[len(errors)//2]
            print("%s vs. %s: %g" % (Platform.getPlatform(j).getName(), Platform.getPlatform(i).getName(), median))