Unverified Commit ac27ed3d authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Update model URL (#484)

* Update model URL

* Update models.py

* Update models.py

* fix

* cleanup

* save
parent 9e7baa6e
......@@ -26,7 +26,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The ``periodic_table_index`` arguments tells TorchANI to use element index
# in periodic table to index species. If not specified, you need to use
# 0, 1, 2, 3, ... to index species
model = torchani.models.ANI1ccx(periodic_table_index=True).to(device)
model = torchani.models.ANI2x(periodic_table_index=True).to(device)
###############################################################################
# Now let's define the coordinate and species. If you just want to compute the
......
......@@ -26,7 +26,9 @@ directly calculate energies or get an ASE calculator. For example:
import os
import io
import requests
import glob
import zipfile
import shutil
import torch
from torch import Tensor
from typing import Tuple, Optional
......@@ -73,15 +75,19 @@ class BuiltinModel(torch.nn.Module):
@staticmethod
def _parse_neurochem_resources(info_file_path):
def get_resource(resource_path, file_path):
return os.path.join(resource_path, 'resources/' + file_path)
return os.path.join(resource_path, file_path)
resource_path = os.path.dirname(__file__)
local_dir = os.path.expanduser('~/.local/torchani')
resource_path = os.path.join(os.path.dirname(__file__), 'resources/')
local_dir = os.path.expanduser('~/.local/torchani/')
repo_name = "ani-model-zoo"
tag_name = "ani-2x"
extracted_name = '{}-{}'.format(repo_name, tag_name)
url = "https://github.com/aiqm/{}/archive/{}.zip".format(repo_name, tag_name)
if not os.path.isfile(get_resource(resource_path, info_file_path)):
if not os.path.isfile(get_resource(local_dir, info_file_path)):
print('Downloading ANI model parameters ...')
resource_res = requests.get("https://www.dropbox.com/sh/otrzul6yuye8uzs/AABuaihE22vtaB_rdrI0r6TUa?dl=1")
resource_res = requests.get(url)
resource_zip = zipfile.ZipFile(io.BytesIO(resource_res.content))
try:
resource_zip.extractall(resource_path)
......@@ -91,6 +97,14 @@ class BuiltinModel(torch.nn.Module):
else:
resource_path = local_dir
files = glob.glob(os.path.join(resource_path, extracted_name, "resources", "*"))
for f in files:
try:
shutil.move(f, resource_path)
except shutil.Error:
pass
shutil.rmtree(os.path.join(resource_path, extracted_name))
info_file = get_resource(resource_path, info_file_path)
with open(info_file) as f:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment