Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
ffb075e6
Commit
ffb075e6
authored
Nov 07, 2019
by
Gao, Xiang
Committed by
Farhad Ramezanghorbani
Nov 07, 2019
Browse files
Simplify triple_by_molecule (#368)
* Simplify triple_by_molecule * fix * fix * fix
parent
89ff3b46
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
10 additions
and
131 deletions
+10
-131
tools/training-benchmark-nsys-profile.py
tools/training-benchmark-nsys-profile.py
+0
-1
tools/training-benchmark-with-aevcache.py
tools/training-benchmark-with-aevcache.py
+0
-95
tools/training-benchmark.py
tools/training-benchmark.py
+0
-1
torchani/aev.py
torchani/aev.py
+10
-34
No files found.
tools/training-benchmark-nsys-profile.py
View file @
ffb075e6
...
@@ -43,7 +43,6 @@ def enable_timers(model):
...
@@ -43,7 +43,6 @@ def enable_timers(model):
torchani
.
aev
.
compute_shifts
=
time_func
(
'compute_shifts'
,
torchani
.
aev
.
compute_shifts
)
torchani
.
aev
.
compute_shifts
=
time_func
(
'compute_shifts'
,
torchani
.
aev
.
compute_shifts
)
torchani
.
aev
.
neighbor_pairs
=
time_func
(
'neighbor_pairs'
,
torchani
.
aev
.
neighbor_pairs
)
torchani
.
aev
.
neighbor_pairs
=
time_func
(
'neighbor_pairs'
,
torchani
.
aev
.
neighbor_pairs
)
torchani
.
aev
.
triu_index
=
time_func
(
'triu_index'
,
torchani
.
aev
.
triu_index
)
torchani
.
aev
.
triu_index
=
time_func
(
'triu_index'
,
torchani
.
aev
.
triu_index
)
torchani
.
aev
.
convert_pair_index
=
time_func
(
'convert_pair_index'
,
torchani
.
aev
.
convert_pair_index
)
torchani
.
aev
.
cumsum_from_zero
=
time_func
(
'cumsum_from_zero'
,
torchani
.
aev
.
cumsum_from_zero
)
torchani
.
aev
.
cumsum_from_zero
=
time_func
(
'cumsum_from_zero'
,
torchani
.
aev
.
cumsum_from_zero
)
torchani
.
aev
.
triple_by_molecule
=
time_func
(
'triple_by_molecule'
,
torchani
.
aev
.
triple_by_molecule
)
torchani
.
aev
.
triple_by_molecule
=
time_func
(
'triple_by_molecule'
,
torchani
.
aev
.
triple_by_molecule
)
torchani
.
aev
.
compute_aev
=
time_func
(
'compute_aev'
,
torchani
.
aev
.
compute_aev
)
torchani
.
aev
.
compute_aev
=
time_func
(
'compute_aev'
,
torchani
.
aev
.
compute_aev
)
...
...
tools/training-benchmark-with-aevcache.py
deleted
100644 → 0
View file @
89ff3b46
import
torch
import
ignite
import
torchani
import
timeit
import
tqdm
import
argparse
# parse command line arguments
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'cache_path'
,
help
=
'Path of the aev cache'
)
parser
.
add_argument
(
'-d'
,
'--device'
,
help
=
'Device of modules and tensors'
,
default
=
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
))
parser
=
parser
.
parse_args
()
# set up benchmark
device
=
torch
.
device
(
parser
.
device
)
ani1x
=
torchani
.
models
.
ANI1x
()
consts
=
ani1x
.
consts
aev_computer
=
ani1x
.
aev_computer
shift_energy
=
ani1x
.
energy_shifter
def
atomic
():
model
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Linear
(
384
,
128
),
torch
.
nn
.
CELU
(
0.1
),
torch
.
nn
.
Linear
(
128
,
128
),
torch
.
nn
.
CELU
(
0.1
),
torch
.
nn
.
Linear
(
128
,
64
),
torch
.
nn
.
CELU
(
0.1
),
torch
.
nn
.
Linear
(
64
,
1
)
)
return
model
model
=
torchani
.
ANIModel
([
atomic
()
for
_
in
range
(
4
)])
class
Flatten
(
torch
.
nn
.
Module
):
def
forward
(
self
,
x
):
return
x
[
0
],
x
[
1
].
flatten
()
nnp
=
torch
.
nn
.
Sequential
(
model
,
Flatten
()).
to
(
device
)
dataset
=
torchani
.
data
.
AEVCacheLoader
(
parser
.
cache_path
)
container
=
torchani
.
ignite
.
Container
({
'energies'
:
nnp
})
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
trainer
=
ignite
.
engine
.
create_supervised_trainer
(
container
,
optimizer
,
torchani
.
ignite
.
MSELoss
(
'energies'
))
@
trainer
.
on
(
ignite
.
engine
.
Events
.
EPOCH_STARTED
)
def
init_tqdm
(
trainer
):
trainer
.
state
.
tqdm
=
tqdm
.
tqdm
(
total
=
len
(
dataset
),
desc
=
'epoch'
)
@
trainer
.
on
(
ignite
.
engine
.
Events
.
ITERATION_COMPLETED
)
def
update_tqdm
(
trainer
):
trainer
.
state
.
tqdm
.
update
(
1
)
@
trainer
.
on
(
ignite
.
engine
.
Events
.
EPOCH_COMPLETED
)
def
finalize_tqdm
(
trainer
):
trainer
.
state
.
tqdm
.
close
()
timers
=
{}
def
time_func
(
key
,
func
):
timers
[
key
]
=
0
def
wrapper
(
*
args
,
**
kwargs
):
start
=
timeit
.
default_timer
()
ret
=
func
(
*
args
,
**
kwargs
)
end
=
timeit
.
default_timer
()
timers
[
key
]
+=
end
-
start
return
ret
return
wrapper
# enable timers
nnp
[
0
].
forward
=
time_func
(
'forward'
,
nnp
[
0
].
forward
)
# run it!
start
=
timeit
.
default_timer
()
trainer
.
run
(
dataset
,
max_epochs
=
1
)
elapsed
=
round
(
timeit
.
default_timer
()
-
start
,
2
)
print
(
'NN:'
,
timers
[
'forward'
])
print
(
'Epoch time:'
,
elapsed
)
tools/training-benchmark.py
View file @
ffb075e6
...
@@ -93,7 +93,6 @@ if __name__ == "__main__":
...
@@ -93,7 +93,6 @@ if __name__ == "__main__":
torchani
.
aev
.
compute_shifts
=
time_func
(
'torchani.aev.compute_shifts'
,
torchani
.
aev
.
compute_shifts
)
torchani
.
aev
.
compute_shifts
=
time_func
(
'torchani.aev.compute_shifts'
,
torchani
.
aev
.
compute_shifts
)
torchani
.
aev
.
neighbor_pairs
=
time_func
(
'torchani.aev.neighbor_pairs'
,
torchani
.
aev
.
neighbor_pairs
)
torchani
.
aev
.
neighbor_pairs
=
time_func
(
'torchani.aev.neighbor_pairs'
,
torchani
.
aev
.
neighbor_pairs
)
torchani
.
aev
.
triu_index
=
time_func
(
'torchani.aev.triu_index'
,
torchani
.
aev
.
triu_index
)
torchani
.
aev
.
triu_index
=
time_func
(
'torchani.aev.triu_index'
,
torchani
.
aev
.
triu_index
)
torchani
.
aev
.
convert_pair_index
=
time_func
(
'torchani.aev.convert_pair_index'
,
torchani
.
aev
.
convert_pair_index
)
torchani
.
aev
.
cumsum_from_zero
=
time_func
(
'torchani.aev.cumsum_from_zero'
,
torchani
.
aev
.
cumsum_from_zero
)
torchani
.
aev
.
cumsum_from_zero
=
time_func
(
'torchani.aev.cumsum_from_zero'
,
torchani
.
aev
.
cumsum_from_zero
)
torchani
.
aev
.
triple_by_molecule
=
time_func
(
'torchani.aev.triple_by_molecule'
,
torchani
.
aev
.
triple_by_molecule
)
torchani
.
aev
.
triple_by_molecule
=
time_func
(
'torchani.aev.triple_by_molecule'
,
torchani
.
aev
.
triple_by_molecule
)
torchani
.
aev
.
compute_aev
=
time_func
(
'torchani.aev.compute_aev'
,
torchani
.
aev
.
compute_aev
)
torchani
.
aev
.
compute_aev
=
time_func
(
'torchani.aev.compute_aev'
,
torchani
.
aev
.
compute_aev
)
...
...
torchani/aev.py
View file @
ffb075e6
...
@@ -174,31 +174,6 @@ def triu_index(num_species):
...
@@ -174,31 +174,6 @@ def triu_index(num_species):
return
ret
return
ret
def
convert_pair_index
(
index
):
# type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
"""Let's say we have a pair:
index: 0 1 2 3 4 5 6 7 8 9 ...
elem1: 0 0 1 0 1 2 0 1 2 3 ...
elem2: 1 2 2 3 3 3 4 4 4 4 ...
This function convert index back to elem1 and elem2
To implement this, divide it into groups, the first group contains 1
elements, the second contains 2 elements, ..., the nth group contains
n elements.
Let's say we want to compute the elem1 and elem2 for index i. We first find
the number of complete groups contained in index 0, 1, ..., i - 1
(all inclusive, not including i), then i will be in the next group. Let's
say there are N complete groups, then these N groups contains
N * (N + 1) / 2 elements, solving for the largest N that satisfies
N * (N + 1) / 2 <= i, will get the N we want.
"""
n
=
(
torch
.
sqrt
(
1.0
+
8.0
*
index
.
to
(
torch
.
float
))
-
1.0
)
/
2.0
n
=
torch
.
floor
(
n
).
to
(
torch
.
long
)
num_elems
=
n
*
(
n
+
1
)
/
2
return
index
-
num_elems
,
n
+
1
def
cumsum_from_zero
(
input_
):
def
cumsum_from_zero
(
input_
):
# type: (torch.Tensor) -> torch.Tensor
# type: (torch.Tensor) -> torch.Tensor
cumsum
=
torch
.
cumsum
(
input_
,
dim
=
0
)
cumsum
=
torch
.
cumsum
(
input_
,
dim
=
0
)
...
@@ -219,7 +194,6 @@ def triple_by_molecule(atom_index1, atom_index2):
...
@@ -219,7 +194,6 @@ def triple_by_molecule(atom_index1, atom_index2):
are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
"""
"""
# convert representation from pair to central-others
# convert representation from pair to central-others
n
=
atom_index1
.
shape
[
0
]
ai1
=
torch
.
cat
([
atom_index1
,
atom_index2
])
ai1
=
torch
.
cat
([
atom_index1
,
atom_index2
])
sorted_ai1
,
rev_indices
=
ai1
.
sort
()
sorted_ai1
,
rev_indices
=
ai1
.
sort
()
...
@@ -228,17 +202,18 @@ def triple_by_molecule(atom_index1, atom_index2):
...
@@ -228,17 +202,18 @@ def triple_by_molecule(atom_index1, atom_index2):
uniqued_central_atom_index
=
unique_results
[
0
]
uniqued_central_atom_index
=
unique_results
[
0
]
counts
=
unique_results
[
-
1
]
counts
=
unique_results
[
-
1
]
#
do local combinations within unique key, assuming sorted
#
compute central_atom_index
pair_sizes
=
(
counts
*
(
counts
-
1
)
/
2
).
long
()
pair_sizes
=
(
counts
*
(
counts
-
1
)
/
2
).
long
()
total_size
=
pair_sizes
.
sum
()
pair_indices
=
torch
.
repeat_interleave
(
pair_sizes
)
pair_indices
=
torch
.
repeat_interleave
(
pair_sizes
)
central_atom_index
=
uniqued_central_atom_index
.
index_select
(
0
,
pair_indices
)
central_atom_index
=
uniqued_central_atom_index
.
index_select
(
0
,
pair_indices
)
cumsum
=
cumsum_from_zero
(
pair_sizes
)
cumsum
=
cumsum
.
index_select
(
0
,
pair_indices
)
# do local combinations within unique key, assuming sorted
sorted_local_pair_index
=
torch
.
arange
(
total_size
,
device
=
cumsum
.
device
,
dtype
=
torch
.
long
)
-
cumsum
m
=
counts
.
max
().
item
()
if
counts
.
numel
()
>
0
else
0
sorted_local_index1
,
sorted_local_index2
=
convert_pair_index
(
sorted_local_pair_index
)
n
=
pair_sizes
.
shape
[
0
]
cumsum
=
cumsum_from_zero
(
counts
)
intra_pair_indices
=
torch
.
tril_indices
(
m
,
m
,
-
1
,
device
=
ai1
.
device
).
t
().
unsqueeze
(
0
).
expand
(
n
,
-
1
,
-
1
)
cumsum
=
cumsum
.
index_select
(
0
,
pair_indices
)
mask
=
(
torch
.
arange
(
intra_pair_indices
.
shape
[
1
],
device
=
ai1
.
device
)
<
pair_sizes
.
unsqueeze
(
1
)).
flatten
()
sorted_local_index1
,
sorted_local_index2
=
intra_pair_indices
.
flatten
(
0
,
1
)[
mask
,
:].
unbind
(
-
1
)
cumsum
=
cumsum_from_zero
(
counts
).
index_select
(
0
,
pair_indices
)
sorted_local_index1
+=
cumsum
sorted_local_index1
+=
cumsum
sorted_local_index2
+=
cumsum
sorted_local_index2
+=
cumsum
...
@@ -247,6 +222,7 @@ def triple_by_molecule(atom_index1, atom_index2):
...
@@ -247,6 +222,7 @@ def triple_by_molecule(atom_index1, atom_index2):
local_index2
=
rev_indices
[
sorted_local_index2
]
local_index2
=
rev_indices
[
sorted_local_index2
]
# compute mapping between representation of central-other to pair
# compute mapping between representation of central-other to pair
n
=
atom_index1
.
shape
[
0
]
sign1
=
((
local_index1
<
n
).
to
(
torch
.
long
)
*
2
)
-
1
sign1
=
((
local_index1
<
n
).
to
(
torch
.
long
)
*
2
)
-
1
sign2
=
((
local_index2
<
n
).
to
(
torch
.
long
)
*
2
)
-
1
sign2
=
((
local_index2
<
n
).
to
(
torch
.
long
)
*
2
)
-
1
return
central_atom_index
,
local_index1
%
n
,
local_index2
%
n
,
sign1
,
sign2
return
central_atom_index
,
local_index1
%
n
,
local_index2
%
n
,
sign1
,
sign2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment