Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
99c2c3dc
Unverified
Commit
99c2c3dc
authored
Aug 17, 2018
by
Gao, Xiang
Committed by
GitHub
Aug 17, 2018
Browse files
Split batch to avoid performance penalty on padding (#66)
parent
22975fa7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
138 additions
and
24 deletions
+138
-24
examples/training-benchmark.py
examples/training-benchmark.py
+1
-1
tests/test_data.py
tests/test_data.py
+43
-6
torchani/training/container.py
torchani/training/container.py
+10
-7
torchani/training/data.py
torchani/training/data.py
+84
-10
No files found.
examples/training-benchmark.py
View file @
99c2c3dc
...
...
@@ -16,7 +16,7 @@ parser.add_argument('-d', '--device',
default
=
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
))
parser
.
add_argument
(
'--batch_size'
,
help
=
'Number of conformations of each batch'
,
default
=
256
,
type
=
int
)
default
=
1024
,
type
=
int
)
parser
=
parser
.
parse_args
()
# set up benchmark
...
...
tests/test_data.py
View file @
99c2c3dc
import
os
import
torch
import
torchani
import
unittest
path
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
dataset_path
=
os
.
path
.
join
(
path
,
'../dataset'
)
print
(
dataset_path
)
batch_size
=
256
aev
=
torchani
.
AEVComputer
()
...
...
@@ -16,10 +16,47 @@ class TestData(unittest.TestCase):
aev
.
species
,
batch_size
)
def _assertTensorEqual(self, t1, t2):
    """Assert that two tensors are element-wise identical."""
    max_abs_diff = (t1 - t2).abs().max()
    self.assertEqual(max_abs_diff, 0)
def testSplitBatch(self):
    """split_batch should partition a padded batch losslessly.

    Builds three conformation groups of different sizes, pads them into
    one batch, splits the batch, and checks that each chunk is exactly
    the corresponding slice of the batch with redundant padding
    stripped, and that re-padding the chunks reproduces the batch.
    """
    species1 = torch.randint(4, (5, 4), dtype=torch.long)
    coordinates1 = torch.randn(5, 4, 3)
    species2 = torch.randint(4, (2, 8), dtype=torch.long)
    coordinates2 = torch.randn(2, 8, 3)
    species3 = torch.randint(4, (10, 20), dtype=torch.long)
    coordinates3 = torch.randn(10, 20, 3)
    species, coordinates = torchani.padding.pad_and_batch([
        (species1, coordinates1),
        (species2, coordinates2),
        (species3, coordinates3),
    ])
    # padding entries are negative, so counting non-negative species
    # gives the real atom count of each conformation
    natoms = (species >= 0).to(torch.long).sum(1)
    chunks = torchani.training.data.split_batch(natoms, species,
                                                coordinates)
    start = 0
    last = None
    for s, c in chunks:
        n = (s >= 0).to(torch.long).sum(1)
        if last is not None:
            # adjacent chunks must not share an atom count at the cut,
            # otherwise the split gained nothing
            self.assertNotEqual(last[-1], n[0])
        conformations = s.shape[0]
        self.assertGreater(conformations, 0)
        s_ = species[start:start + conformations, ...]
        c_ = coordinates[start:start + conformations, ...]
        s_, c_ = torchani.padding.strip_redundant_padding(s_, c_)
        self._assertTensorEqual(s, s_)
        self._assertTensorEqual(c, c_)
        start += conformations
        # BUG FIX: `last` was never updated, which left the
        # assertNotEqual boundary check above permanently dead
        last = n
    s, c = torchani.padding.pad_and_batch(chunks)
    self._assertTensorEqual(s, species)
    self._assertTensorEqual(c, coordinates)
def
testTensorShape
(
self
):
for
i
in
self
.
ds
:
input
,
output
=
i
species
,
coordinates
=
input
species
,
coordinates
=
torchani
.
padding
.
pad_and_batch
(
input
)
energies
=
output
[
'energies'
]
self
.
assertEqual
(
len
(
species
.
shape
),
2
)
self
.
assertLessEqual
(
species
.
shape
[
0
],
batch_size
)
...
...
@@ -32,10 +69,10 @@ class TestData(unittest.TestCase):
def testNoUnnecessaryPadding(self):
    """Each chunk's last atom column must contain a real atom.

    NOTE(review): the scrape interleaves the pre-commit body (which
    inspected the whole batch) with the post-commit body (which inspects
    every chunk); this is the post-commit version.
    """
    for i in self.ds:
        for species_coordinates in i[0]:
            species, _ = species_coordinates
            # padding entries are negative; if the last column were all
            # padding, the chunk would carry unnecessary padding
            non_padding = (species >= 0)[:, -1].nonzero()
            self.assertGreater(non_padding.numel(), 0)
if
__name__
==
'__main__'
:
...
...
torchani/training/container.py
View file @
99c2c3dc
import
torch
from
..
import
padding
class
Container
(
torch
.
nn
.
Module
):
...
...
@@ -10,12 +11,14 @@ class Container(torch.nn.Module):
setattr
(
self
,
'model_'
+
i
,
models
[
i
])
def forward(self, species_coordinates):
    """Run every contained model on each chunk and merge the results.

    NOTE(review): the scrape interleaves the pre-commit single-batch body
    with the post-commit chunked body; this is the post-commit version.

    Args:
        species_coordinates: list of ``(species, coordinates)`` chunk
            pairs, as produced by splitting a batch.

    Returns:
        Dict containing the re-padded ``'species'`` and ``'coordinates'``
        of the whole batch, plus, for each key in ``self.keys``, the
        concatenation of that model's per-chunk outputs.
    """
    results = {k: [] for k in self.keys}
    for sc in species_coordinates:
        for k in self.keys:
            model = getattr(self, 'model_' + k)
            _, result = model(sc)
            results[k].append(result)
    results['species'], results['coordinates'] = \
        padding.pad_and_batch(species_coordinates)
    for k in self.keys:
        # per-chunk outputs concatenate along the conformation dimension
        results[k] = torch.cat(results[k])
    return results
torchani/training/data.py
View file @
99c2c3dc
...
...
@@ -8,6 +8,73 @@ import pickle
from
..
import
padding
def chunk_counts(counts, split):
    """Summarize the chunks produced by cutting `counts` at `split`.

    Args:
        counts: run-length encoding of per-conformation atom counts as
            ``[natoms, conformations]`` pairs; assumed sorted by
            ``natoms`` increasing (the caller sorts the batch — TODO
            confirm against BatchedANIDataset).
        split: sorted indices into ``counts``; index ``i`` means "cut
            after run ``i``".

    Returns:
        Tuple ``(chunk_conformations, chunk_maxatoms)``:
        ``chunk_conformations[j]`` is the total number of conformations
        in chunk ``j``; ``chunk_maxatoms[j]`` is the atom count of the
        chunk's last run (the maximum, given the sort order).
    """
    # Convert inclusive cut indices to slice boundaries; the trailing
    # None makes the final slice run to the end of `counts`.
    boundaries = [x + 1 for x in split] + [None]
    count_chunks = []
    start = 0
    for end in boundaries:
        count_chunks.append(counts[start:end])
        start = end
    # generator expression instead of a temporary list inside sum()
    chunk_conformations = [sum(y[1] for y in x) for x in count_chunks]
    chunk_maxatoms = [x[-1][0] for x in count_chunks]
    return chunk_conformations, chunk_maxatoms
def split_cost(counts, split):
    """Estimated cost of evaluating the batch when cut at `split`.

    Each chunk contributes ``conformations * maxatoms ** 2`` (presumably
    because downstream work scales roughly quadratically with the padded
    atom count — confirm against the AEV computation), floored at a
    minimum per-chunk cost so that cutting into many tiny chunks is
    penalized rather than free.
    """
    min_chunk_cost = 40000
    chunk_conformations, chunk_maxatoms = chunk_counts(counts, split)
    total = 0
    for nconf, maxatoms in zip(chunk_conformations, chunk_maxatoms):
        chunk_cost = nconf * maxatoms ** 2
        total += chunk_cost if chunk_cost > min_chunk_cost else min_chunk_cost
    return total
def split_batch(natoms, species, coordinates):
    """Split a padded batch into chunks to reduce padding overhead.

    Conformations are assumed ordered by atom count (the caller sorts
    them — TODO confirm).  Cut points are chosen by a greedy search that
    minimizes ``split_cost``, and each resulting chunk has its redundant
    padding stripped.

    Args:
        natoms: 1-D tensor of per-conformation atom counts, sorted.
        species: ``(conformations, atoms)`` tensor of species indices,
            padded with negative values.
        coordinates: ``(conformations, atoms, 3)`` tensor.

    Returns:
        List of ``(species, coordinates)`` pairs, one per chunk, which
        concatenate (up to padding) back to the inputs.
    """
    # Run-length encode the atom counts: [[natoms, conformations], ...]
    natoms = natoms.tolist()
    counts = []
    for i in natoms:
        if counts and i == counts[-1][0]:
            counts[-1][1] += 1
        else:
            counts.append([i, 1])
    # Greedy search: repeatedly add the single cut that lowers the cost
    # the most, until no cut improves it.
    split = []
    cost = split_cost(counts, split)
    improved = True
    while improved:
        improved = False
        cycle_split = split
        cycle_cost = cost
        for i in range(len(counts) - 1):
            if i not in split:
                candidate = sorted(split + [i])
                candidate_cost = split_cost(counts, candidate)
                if candidate_cost < cycle_cost:
                    improved = True
                    cycle_cost = candidate_cost
                    cycle_split = candidate
        if improved:
            split = cycle_split
            cost = cycle_cost
    # Slice the batch at the chosen boundaries and strip the padding
    # that is now redundant inside each chunk.
    # (BUG FIX: removed a dead `s = species` assignment that was
    # immediately overwritten by the slice below.)
    start = 0
    species_coordinates = []
    chunk_conformations, _ = chunk_counts(counts, split)
    for n in chunk_conformations:
        end = start + n
        s = species[start:end, ...]
        c = coordinates[start:end, ...]
        s, c = padding.strip_redundant_padding(s, c)
        species_coordinates.append((s, c))
        start = end
    return species_coordinates
class
BatchedANIDataset
(
Dataset
):
def
__init__
(
self
,
path
,
species
,
batch_size
,
shuffle
=
True
,
...
...
@@ -71,29 +138,36 @@ class BatchedANIDataset(Dataset):
properties
)
# split into minibatches, and strip redundant padding
natoms
=
(
species
>=
0
).
to
(
torch
.
long
).
sum
(
1
)
batches
=
[]
num_batches
=
(
conformations
+
batch_size
-
1
)
//
batch_size
for
i
in
range
(
num_batches
):
start
=
i
*
batch_size
end
=
min
((
i
+
1
)
*
batch_size
,
conformations
)
species_batch
=
species
[
start
:
end
,
...]
coordinates_batch
=
coordinates
[
start
:
end
,
...]
natoms_batch
=
natoms
[
start
:
end
]
natoms_batch
,
indices
=
natoms_batch
.
sort
()
species_batch
=
species
[
start
:
end
,
...].
index_select
(
0
,
indices
)
coordinates_batch
=
coordinates
[
start
:
end
,
...]
\
.
index_select
(
0
,
indices
)
properties_batch
=
{
k
:
properties
[
k
][
start
:
end
,
...]
for
k
in
properties
k
:
properties
[
k
][
start
:
end
,
...].
index_select
(
0
,
indices
)
for
k
in
properties
}
batches
.
append
((
padding
.
strip_redundant_padding
(
species_batch
,
coordinates_batch
),
properties_batch
))
# further split batch into chunks
species_coordinates
=
split_batch
(
natoms_batch
,
species_batch
,
coordinates_batch
)
batch
=
species_coordinates
,
properties_batch
batches
.
append
(
batch
)
self
.
batches
=
batches
def __getitem__(self, idx):
    """Return batch `idx` with every tensor moved to ``self.device``.

    NOTE(review): the scrape interleaves the pre-commit body (single
    padded pair) with the post-commit body (list of chunk pairs); this
    is the post-commit version.

    Returns:
        ``(species_coordinates, properties)`` where the former is a list
        of per-chunk ``(species, coordinates)`` pairs and the latter a
        dict of property tensors, all on ``self.device``.
    """
    species_coordinates, properties = self.batches[idx]
    species_coordinates = [(s.to(self.device), c.to(self.device))
                           for s, c in species_coordinates]
    # iterate items() instead of re-indexing the dict per key
    properties = {k: v.to(self.device) for k, v in properties.items()}
    return species_coordinates, properties
def __len__(self):
    """Return the number of pre-built batches."""
    return len(self.batches)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment