# TestATMForce.py
import copy
import unittest

from openmm import *
from openmm.app import *
from openmm.unit import *

class TestATMForce(unittest.TestCase):
    """Tests the ATMForce"""

    def test2ParticlesNonbonded(self):
        """Test an ATMForce wrapping a NonbondedForce previously added to the System.

        Two particles with charges +1 and -1 (sigma = epsilon = 1) sit 1 nm
        apart along x.  The NonbondedForce is moved out of the System and into
        an ATMForce that gives particle 1 a displacement of (1, 0, 0), so in
        the displaced state the particles are 2 nm apart.

        The expected values follow from Coulomb + Lennard-Jones at the two
        separations:
          u0 ~= -138.935 kJ/mol  (r = 1 nm; the LJ term vanishes at r = sigma)
          u1 ~=  -69.529 kJ/mol  (r = 2 nm)
        so u1 - u0 ~= 69.406 kJ/mol, and the total potential energy reported
        for these ATMForce parameters is (u0 + u1) / 2 ~= -104.232 kJ/mol.
        """
        system = System()
        system.addParticle(1.0)
        system.addParticle(1.0)

        nbforce = NonbondedForce()
        nbforce.addParticle( 1.0, 1.0, 1.0)  # charge, sigma, epsilon
        nbforce.addParticle(-1.0, 1.0, 1.0)

        system.addForce(nbforce)

        # Positional args: lambda1, lambda2, alpha, uh, w0, umax, ubcore, acore, direction.
        atmforce = ATMForce(0.5, 0.5, 0, 0, 0,   0, 0, 0,  1.0)
        atmforce.addParticle(Vec3(0., 0., 0.))  # particle 0 is not displaced
        atmforce.addParticle(Vec3(1., 0., 0.))  # particle 1 is displaced by 1 nm along x

        # Copy the NonbondedForce into the ATMForce *before* removing it from
        # the System: removeForce() deletes the Force the System owns, so the
        # copy must be taken while nbforce is still valid.
        atmforce.addForce(copy.copy(nbforce))
        system.removeForce(0)
        system.addForce(atmforce)

        integrator = VerletIntegrator(1.0)
        platform = Platform.getPlatformByName('Reference')
        context = Context(system, integrator, platform)
        context.setPositions([Vec3(0., 0., 0.), Vec3(1., 0., 0.)])

        state = context.getState(getEnergy=True, getForces=True)
        epot = state.getPotentialEnergy()

        u1, u0, _ = atmforce.getPerturbationEnergy(context)
        epert = u1 - u0

        epot_expected = -104.2320*kilojoules_per_mole
        epert_expected = 69.4062*kilojoules_per_mole
        tolerance = 1.e-3*kilojoules_per_mole
        # unittest assertions (unlike bare `assert`) are not stripped under -O.
        self.assertLess(abs(epot - epot_expected), tolerance,
                        "unexpected ATMForce potential energy")
        self.assertLess(abs(epert - epert_expected), tolerance,
                        "unexpected ATMForce perturbation energy u1 - u0")