Unverified Commit ac27ed3d authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Update model URL (#484)

* Update model URL

* Update models.py

* Update models.py

* fix

* cleanup

* save
parent 9e7baa6e
...@@ -26,7 +26,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') ...@@ -26,7 +26,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The ``periodic_table_index`` arguments tells TorchANI to use element index # The ``periodic_table_index`` arguments tells TorchANI to use element index
# in periodic table to index species. If not specified, you need to use # in periodic table to index species. If not specified, you need to use
# 0, 1, 2, 3, ... to index species # 0, 1, 2, 3, ... to index species
model = torchani.models.ANI1ccx(periodic_table_index=True).to(device) model = torchani.models.ANI2x(periodic_table_index=True).to(device)
############################################################################### ###############################################################################
# Now let's define the coordinate and species. If you just want to compute the # Now let's define the coordinate and species. If you just want to compute the
......
...@@ -26,7 +26,9 @@ directly calculate energies or get an ASE calculator. For example: ...@@ -26,7 +26,9 @@ directly calculate energies or get an ASE calculator. For example:
import os import os
import io import io
import requests import requests
import glob
import zipfile import zipfile
import shutil
import torch import torch
from torch import Tensor from torch import Tensor
from typing import Tuple, Optional from typing import Tuple, Optional
...@@ -73,15 +75,19 @@ class BuiltinModel(torch.nn.Module): ...@@ -73,15 +75,19 @@ class BuiltinModel(torch.nn.Module):
@staticmethod @staticmethod
def _parse_neurochem_resources(info_file_path): def _parse_neurochem_resources(info_file_path):
def get_resource(resource_path, file_path): def get_resource(resource_path, file_path):
return os.path.join(resource_path, 'resources/' + file_path) return os.path.join(resource_path, file_path)
resource_path = os.path.dirname(__file__) resource_path = os.path.join(os.path.dirname(__file__), 'resources/')
local_dir = os.path.expanduser('~/.local/torchani') local_dir = os.path.expanduser('~/.local/torchani/')
repo_name = "ani-model-zoo"
tag_name = "ani-2x"
extracted_name = '{}-{}'.format(repo_name, tag_name)
url = "https://github.com/aiqm/{}/archive/{}.zip".format(repo_name, tag_name)
if not os.path.isfile(get_resource(resource_path, info_file_path)): if not os.path.isfile(get_resource(resource_path, info_file_path)):
if not os.path.isfile(get_resource(local_dir, info_file_path)): if not os.path.isfile(get_resource(local_dir, info_file_path)):
print('Downloading ANI model parameters ...') print('Downloading ANI model parameters ...')
resource_res = requests.get("https://www.dropbox.com/sh/otrzul6yuye8uzs/AABuaihE22vtaB_rdrI0r6TUa?dl=1") resource_res = requests.get(url)
resource_zip = zipfile.ZipFile(io.BytesIO(resource_res.content)) resource_zip = zipfile.ZipFile(io.BytesIO(resource_res.content))
try: try:
resource_zip.extractall(resource_path) resource_zip.extractall(resource_path)
...@@ -91,6 +97,14 @@ class BuiltinModel(torch.nn.Module): ...@@ -91,6 +97,14 @@ class BuiltinModel(torch.nn.Module):
else: else:
resource_path = local_dir resource_path = local_dir
files = glob.glob(os.path.join(resource_path, extracted_name, "resources", "*"))
for f in files:
try:
shutil.move(f, resource_path)
except shutil.Error:
pass
shutil.rmtree(os.path.join(resource_path, extracted_name))
info_file = get_resource(resource_path, info_file_path) info_file = get_resource(resource_path, info_file_path)
with open(info_file) as f: with open(info_file) as f:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment