Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
10699bf7
"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "a6ba254fa78b063f7367d2495b9bd4b64c1eb7db"
Unverified
Commit
10699bf7
authored
Jul 30, 2018
by
Gao, Xiang
Committed by
GitHub
Jul 30, 2018
Browse files
allow energy shifter as transformations to dataset (#30)
parent
18e4867d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
5 deletions
+17
-5
tests/test_ignite.py
tests/test_ignite.py
+4
-1
torchani/data.py
torchani/data.py
+6
-4
torchani/energyshifter.py
torchani/energyshifter.py
+7
-0
No files found.
tests/test_ignite.py
View file @
10699bf7
...
...
@@ -18,7 +18,10 @@ if sys.version_info.major >= 3:
class
TestIgnite
(
unittest
.
TestCase
):
def
testIgnite
(
self
):
ds
=
torchani
.
data
.
ANIDataset
(
path
,
chunksize
)
shift_energy
=
torchani
.
EnergyShifter
()
ds
=
torchani
.
data
.
ANIDataset
(
path
,
chunksize
,
transform
=
[
shift_energy
.
dataset_subtract_sae
])
loader
=
torchani
.
data
.
dataloader
(
ds
,
batch_chunks
)
aev_computer
=
torchani
.
SortedAEV
(
dtype
=
dtype
,
device
=
device
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
)
...
...
torchani/data.py
View file @
10699bf7
...
...
@@ -8,8 +8,8 @@ import torch
class
ANIDataset
(
Dataset
):
def
__init__
(
self
,
path
,
chunk_size
,
shuffle
=
True
,
properties
=
[
'energies'
]
,
dtype
=
default_dtype
):
def
__init__
(
self
,
path
,
chunk_size
,
shuffle
=
True
,
properties
=
[
'energies'
],
transform
=
()
,
dtype
=
default_dtype
):
super
(
ANIDataset
,
self
).
__init__
()
self
.
path
=
path
self
.
chunks_size
=
chunk_size
...
...
@@ -54,6 +54,8 @@ class ANIDataset(Dataset):
for
j
in
full
:
chunk
[
j
]
=
full
[
j
].
index_select
(
0
,
chunk_indices
)
chunk
[
'species'
]
=
species
for
t
in
transform
:
chunk
=
t
(
chunk
)
chunks
.
append
(
chunk
)
self
.
chunks
=
chunks
...
...
@@ -80,6 +82,6 @@ def _collate(batch):
return
inputs
,
outputs
def
dataloader
(
dataset
,
batch_chunks
,
**
kwargs
):
return
DataLoader
(
dataset
,
batch_chunks
,
dataset
.
shuffle
,
def
dataloader
(
dataset
,
batch_chunks
,
shuffle
=
True
,
**
kwargs
):
return
DataLoader
(
dataset
,
batch_chunks
,
shuffle
,
collate_fn
=
_collate
,
**
kwargs
)
torchani/energyshifter.py
View file @
10699bf7
...
...
@@ -67,3 +67,10 @@ class EnergyShifter:
for
i
in
species
:
s
+=
self
.
self_energies
[
i
]
return
energies
+
s
def
dataset_subtract_sae
(
self
,
data
):
"""Allow object of this class to be used as transforms of pytorch's
dataset.
"""
data
[
'energies'
]
=
self
.
subtract_sae
(
data
[
'energies'
],
data
[
'species'
])
return
data
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment