"tests/data/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "de767fc126d994586b2c460a14a43e1a393be573"
Unverified Commit cbfbab3f authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Change comp6 script a bit (#211)

parent 077c12b8
...@@ -13,7 +13,7 @@ dtype = torch.float32 ...@@ -13,7 +13,7 @@ dtype = torch.float32
# parse command line arguments # parse command line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('dir', help='Path to the COMP6 directory') parser.add_argument('dir', help='Path to the COMP6 directory')
parser.add_argument('-b', '--batchatoms', type=int, default=512, parser.add_argument('-b', '--batchatoms', type=int, default=4096,
help='Maximum number of ATOMs in each batch') help='Maximum number of ATOMs in each batch')
parser.add_argument('-d', '--device', parser.add_argument('-d', '--device',
help='Device of modules and tensors', help='Device of modules and tensors',
...@@ -42,8 +42,7 @@ def by_batch(species, coordinates, model): ...@@ -42,8 +42,7 @@ def by_batch(species, coordinates, model):
coordinates = torch.split(coordinates, batchsize) coordinates = torch.split(coordinates, batchsize)
energies = [] energies = []
forces = [] forces = []
for s, c in tqdm.tqdm(zip(species, coordinates), total=len(species), for s, c in zip(species, coordinates):
position=1, desc="batch of {}x{}".format(*shape)):
_, e = model((s, c)) _, e = model((s, c))
f, = torch.autograd.grad(e.sum(), c) f, = torch.autograd.grad(e.sum(), c)
energies.append(e) energies.append(e)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment