Unverified Commit e3096aa1 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

fix deprecated in inference benchmark (#192)

parent d6932a16
......@@ -76,7 +76,7 @@ print()
# test batch mode
print('[Batch mode]')
species, coordinates = torch.utils.data.dataloader.default_collate(list(xyz))
coordinates = torch.tensor(coordinates, requires_grad=True)
coordinates.requires_grad_(True)
start = timeit.default_timer()
energies = nnp((species, coordinates))[1]
mid = timeit.default_timer()
......@@ -92,7 +92,7 @@ if parser.tqdm:
xyz = tqdm.tqdm(xyz)
for species, coordinates in xyz:
species = species.unsqueeze(0)
coordinates = torch.tensor(coordinates.unsqueeze(0), requires_grad=True)
coordinates = coordinates.unsqueeze(0).detach().requires_grad_(True)
energies = nnp((species, coordinates))[1]
force = -torch.autograd.grad(energies.sum(), coordinates)[0]
print('Time:', timeit.default_timer() - start)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment