Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
ffb075e6
You need to sign in or sign up before continuing.
Commit
ffb075e6
authored
Nov 07, 2019
by
Gao, Xiang
Committed by
Farhad Ramezanghorbani
Nov 07, 2019
Browse files
Simplify triple_by_molecule (#368)
* Simplify triple_by_molecule * fix * fix * fix
parent
89ff3b46
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
10 additions
and
131 deletions
+10
-131
tools/training-benchmark-nsys-profile.py
tools/training-benchmark-nsys-profile.py
+0
-1
tools/training-benchmark-with-aevcache.py
tools/training-benchmark-with-aevcache.py
+0
-95
tools/training-benchmark.py
tools/training-benchmark.py
+0
-1
torchani/aev.py
torchani/aev.py
+10
-34
No files found.
tools/training-benchmark-nsys-profile.py
View file @
ffb075e6
...
@@ -43,7 +43,6 @@ def enable_timers(model):
...
@@ -43,7 +43,6 @@ def enable_timers(model):
torchani
.
aev
.
compute_shifts
=
time_func
(
'compute_shifts'
,
torchani
.
aev
.
compute_shifts
)
torchani
.
aev
.
compute_shifts
=
time_func
(
'compute_shifts'
,
torchani
.
aev
.
compute_shifts
)
torchani
.
aev
.
neighbor_pairs
=
time_func
(
'neighbor_pairs'
,
torchani
.
aev
.
neighbor_pairs
)
torchani
.
aev
.
neighbor_pairs
=
time_func
(
'neighbor_pairs'
,
torchani
.
aev
.
neighbor_pairs
)
torchani
.
aev
.
triu_index
=
time_func
(
'triu_index'
,
torchani
.
aev
.
triu_index
)
torchani
.
aev
.
triu_index
=
time_func
(
'triu_index'
,
torchani
.
aev
.
triu_index
)
torchani
.
aev
.
convert_pair_index
=
time_func
(
'convert_pair_index'
,
torchani
.
aev
.
convert_pair_index
)
torchani
.
aev
.
cumsum_from_zero
=
time_func
(
'cumsum_from_zero'
,
torchani
.
aev
.
cumsum_from_zero
)
torchani
.
aev
.
cumsum_from_zero
=
time_func
(
'cumsum_from_zero'
,
torchani
.
aev
.
cumsum_from_zero
)
torchani
.
aev
.
triple_by_molecule
=
time_func
(
'triple_by_molecule'
,
torchani
.
aev
.
triple_by_molecule
)
torchani
.
aev
.
triple_by_molecule
=
time_func
(
'triple_by_molecule'
,
torchani
.
aev
.
triple_by_molecule
)
torchani
.
aev
.
compute_aev
=
time_func
(
'compute_aev'
,
torchani
.
aev
.
compute_aev
)
torchani
.
aev
.
compute_aev
=
time_func
(
'compute_aev'
,
torchani
.
aev
.
compute_aev
)
...
...
tools/training-benchmark-with-aevcache.py
deleted
100644 → 0
View file @
89ff3b46
import
torch
import
ignite
import
torchani
import
timeit
import
tqdm
import
argparse
# parse command line arguments
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'cache_path'
,
help
=
'Path of the aev cache'
)
parser
.
add_argument
(
'-d'
,
'--device'
,
help
=
'Device of modules and tensors'
,
default
=
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
))
parser
=
parser
.
parse_args
()
# set up benchmark
device
=
torch
.
device
(
parser
.
device
)
ani1x
=
torchani
.
models
.
ANI1x
()
consts
=
ani1x
.
consts
aev_computer
=
ani1x
.
aev_computer
shift_energy
=
ani1x
.
energy_shifter
def
atomic
():
model
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Linear
(
384
,
128
),
torch
.
nn
.
CELU
(
0.1
),
torch
.
nn
.
Linear
(
128
,
128
),
torch
.
nn
.
CELU
(
0.1
),
torch
.
nn
.
Linear
(
128
,
64
),
torch
.
nn
.
CELU
(
0.1
),
torch
.
nn
.
Linear
(
64
,
1
)
)
return
model
model
=
torchani
.
ANIModel
([
atomic
()
for
_
in
range
(
4
)])
class
Flatten
(
torch
.
nn
.
Module
):
def
forward
(
self
,
x
):
return
x
[
0
],
x
[
1
].
flatten
()
nnp
=
torch
.
nn
.
Sequential
(
model
,
Flatten
()).
to
(
device
)
dataset
=
torchani
.
data
.
AEVCacheLoader
(
parser
.
cache_path
)
container
=
torchani
.
ignite
.
Container
({
'energies'
:
nnp
})
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
trainer
=
ignite
.
engine
.
create_supervised_trainer
(
container
,
optimizer
,
torchani
.
ignite
.
MSELoss
(
'energies'
))
@
trainer
.
on
(
ignite
.
engine
.
Events
.
EPOCH_STARTED
)
def
init_tqdm
(
trainer
):
trainer
.
state
.
tqdm
=
tqdm
.
tqdm
(
total
=
len
(
dataset
),
desc
=
'epoch'
)
@
trainer
.
on
(
ignite
.
engine
.
Events
.
ITERATION_COMPLETED
)
def
update_tqdm
(
trainer
):
trainer
.
state
.
tqdm
.
update
(
1
)
@
trainer
.
on
(
ignite
.
engine
.
Events
.
EPOCH_COMPLETED
)
def
finalize_tqdm
(
trainer
):
trainer
.
state
.
tqdm
.
close
()
timers
=
{}
def
time_func
(
key
,
func
):
timers
[
key
]
=
0
def
wrapper
(
*
args
,
**
kwargs
):
start
=
timeit
.
default_timer
()
ret
=
func
(
*
args
,
**
kwargs
)
end
=
timeit
.
default_timer
()
timers
[
key
]
+=
end
-
start
return
ret
return
wrapper
# enable timers
nnp
[
0
].
forward
=
time_func
(
'forward'
,
nnp
[
0
].
forward
)
# run it!
start
=
timeit
.
default_timer
()
trainer
.
run
(
dataset
,
max_epochs
=
1
)
elapsed
=
round
(
timeit
.
default_timer
()
-
start
,
2
)
print
(
'NN:'
,
timers
[
'forward'
])
print
(
'Epoch time:'
,
elapsed
)
tools/training-benchmark.py
View file @
ffb075e6
...
@@ -93,7 +93,6 @@ if __name__ == "__main__":
...
@@ -93,7 +93,6 @@ if __name__ == "__main__":
torchani
.
aev
.
compute_shifts
=
time_func
(
'torchani.aev.compute_shifts'
,
torchani
.
aev
.
compute_shifts
)
torchani
.
aev
.
compute_shifts
=
time_func
(
'torchani.aev.compute_shifts'
,
torchani
.
aev
.
compute_shifts
)
torchani
.
aev
.
neighbor_pairs
=
time_func
(
'torchani.aev.neighbor_pairs'
,
torchani
.
aev
.
neighbor_pairs
)
torchani
.
aev
.
neighbor_pairs
=
time_func
(
'torchani.aev.neighbor_pairs'
,
torchani
.
aev
.
neighbor_pairs
)
torchani
.
aev
.
triu_index
=
time_func
(
'torchani.aev.triu_index'
,
torchani
.
aev
.
triu_index
)
torchani
.
aev
.
triu_index
=
time_func
(
'torchani.aev.triu_index'
,
torchani
.
aev
.
triu_index
)
torchani
.
aev
.
convert_pair_index
=
time_func
(
'torchani.aev.convert_pair_index'
,
torchani
.
aev
.
convert_pair_index
)
torchani
.
aev
.
cumsum_from_zero
=
time_func
(
'torchani.aev.cumsum_from_zero'
,
torchani
.
aev
.
cumsum_from_zero
)
torchani
.
aev
.
cumsum_from_zero
=
time_func
(
'torchani.aev.cumsum_from_zero'
,
torchani
.
aev
.
cumsum_from_zero
)
torchani
.
aev
.
triple_by_molecule
=
time_func
(
'torchani.aev.triple_by_molecule'
,
torchani
.
aev
.
triple_by_molecule
)
torchani
.
aev
.
triple_by_molecule
=
time_func
(
'torchani.aev.triple_by_molecule'
,
torchani
.
aev
.
triple_by_molecule
)
torchani
.
aev
.
compute_aev
=
time_func
(
'torchani.aev.compute_aev'
,
torchani
.
aev
.
compute_aev
)
torchani
.
aev
.
compute_aev
=
time_func
(
'torchani.aev.compute_aev'
,
torchani
.
aev
.
compute_aev
)
...
...
torchani/aev.py
View file @
ffb075e6
...
@@ -174,31 +174,6 @@ def triu_index(num_species):
...
@@ -174,31 +174,6 @@ def triu_index(num_species):
return
ret
return
ret
def
convert_pair_index
(
index
):
# type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
"""Let's say we have a pair:
index: 0 1 2 3 4 5 6 7 8 9 ...
elem1: 0 0 1 0 1 2 0 1 2 3 ...
elem2: 1 2 2 3 3 3 4 4 4 4 ...
This function convert index back to elem1 and elem2
To implement this, divide it into groups, the first group contains 1
elements, the second contains 2 elements, ..., the nth group contains
n elements.
Let's say we want to compute the elem1 and elem2 for index i. We first find
the number of complete groups contained in index 0, 1, ..., i - 1
(all inclusive, not including i), then i will be in the next group. Let's
say there are N complete groups, then these N groups contains
N * (N + 1) / 2 elements, solving for the largest N that satisfies
N * (N + 1) / 2 <= i, will get the N we want.
"""
n
=
(
torch
.
sqrt
(
1.0
+
8.0
*
index
.
to
(
torch
.
float
))
-
1.0
)
/
2.0
n
=
torch
.
floor
(
n
).
to
(
torch
.
long
)
num_elems
=
n
*
(
n
+
1
)
/
2
return
index
-
num_elems
,
n
+
1
def
cumsum_from_zero
(
input_
):
def
cumsum_from_zero
(
input_
):
# type: (torch.Tensor) -> torch.Tensor
# type: (torch.Tensor) -> torch.Tensor
cumsum
=
torch
.
cumsum
(
input_
,
dim
=
0
)
cumsum
=
torch
.
cumsum
(
input_
,
dim
=
0
)
...
@@ -219,7 +194,6 @@ def triple_by_molecule(atom_index1, atom_index2):
...
@@ -219,7 +194,6 @@ def triple_by_molecule(atom_index1, atom_index2):
are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
"""
"""
# convert representation from pair to central-others
# convert representation from pair to central-others
n
=
atom_index1
.
shape
[
0
]
ai1
=
torch
.
cat
([
atom_index1
,
atom_index2
])
ai1
=
torch
.
cat
([
atom_index1
,
atom_index2
])
sorted_ai1
,
rev_indices
=
ai1
.
sort
()
sorted_ai1
,
rev_indices
=
ai1
.
sort
()
...
@@ -228,17 +202,18 @@ def triple_by_molecule(atom_index1, atom_index2):
...
@@ -228,17 +202,18 @@ def triple_by_molecule(atom_index1, atom_index2):
uniqued_central_atom_index
=
unique_results
[
0
]
uniqued_central_atom_index
=
unique_results
[
0
]
counts
=
unique_results
[
-
1
]
counts
=
unique_results
[
-
1
]
#
do local combinations within unique key, assuming sorted
#
compute central_atom_index
pair_sizes
=
(
counts
*
(
counts
-
1
)
/
2
).
long
()
pair_sizes
=
(
counts
*
(
counts
-
1
)
/
2
).
long
()
total_size
=
pair_sizes
.
sum
()
pair_indices
=
torch
.
repeat_interleave
(
pair_sizes
)
pair_indices
=
torch
.
repeat_interleave
(
pair_sizes
)
central_atom_index
=
uniqued_central_atom_index
.
index_select
(
0
,
pair_indices
)
central_atom_index
=
uniqued_central_atom_index
.
index_select
(
0
,
pair_indices
)
cumsum
=
cumsum_from_zero
(
pair_sizes
)
cumsum
=
cumsum
.
index_select
(
0
,
pair_indices
)
# do local combinations within unique key, assuming sorted
sorted_local_pair_index
=
torch
.
arange
(
total_size
,
device
=
cumsum
.
device
,
dtype
=
torch
.
long
)
-
cumsum
m
=
counts
.
max
().
item
()
if
counts
.
numel
()
>
0
else
0
sorted_local_index1
,
sorted_local_index2
=
convert_pair_index
(
sorted_local_pair_index
)
n
=
pair_sizes
.
shape
[
0
]
cumsum
=
cumsum_from_zero
(
counts
)
intra_pair_indices
=
torch
.
tril_indices
(
m
,
m
,
-
1
,
device
=
ai1
.
device
).
t
().
unsqueeze
(
0
).
expand
(
n
,
-
1
,
-
1
)
cumsum
=
cumsum
.
index_select
(
0
,
pair_indices
)
mask
=
(
torch
.
arange
(
intra_pair_indices
.
shape
[
1
],
device
=
ai1
.
device
)
<
pair_sizes
.
unsqueeze
(
1
)).
flatten
()
sorted_local_index1
,
sorted_local_index2
=
intra_pair_indices
.
flatten
(
0
,
1
)[
mask
,
:].
unbind
(
-
1
)
cumsum
=
cumsum_from_zero
(
counts
).
index_select
(
0
,
pair_indices
)
sorted_local_index1
+=
cumsum
sorted_local_index1
+=
cumsum
sorted_local_index2
+=
cumsum
sorted_local_index2
+=
cumsum
...
@@ -247,6 +222,7 @@ def triple_by_molecule(atom_index1, atom_index2):
...
@@ -247,6 +222,7 @@ def triple_by_molecule(atom_index1, atom_index2):
local_index2
=
rev_indices
[
sorted_local_index2
]
local_index2
=
rev_indices
[
sorted_local_index2
]
# compute mapping between representation of central-other to pair
# compute mapping between representation of central-other to pair
n
=
atom_index1
.
shape
[
0
]
sign1
=
((
local_index1
<
n
).
to
(
torch
.
long
)
*
2
)
-
1
sign1
=
((
local_index1
<
n
).
to
(
torch
.
long
)
*
2
)
-
1
sign2
=
((
local_index2
<
n
).
to
(
torch
.
long
)
*
2
)
-
1
sign2
=
((
local_index2
<
n
).
to
(
torch
.
long
)
*
2
)
-
1
return
central_atom_index
,
local_index1
%
n
,
local_index2
%
n
,
sign1
,
sign2
return
central_atom_index
,
local_index1
%
n
,
local_index2
%
n
,
sign1
,
sign2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment