Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
99c2c3dc
Unverified
Commit
99c2c3dc
authored
Aug 17, 2018
by
Gao, Xiang
Committed by
GitHub
Aug 17, 2018
Browse files
Split batch to avoid performance penalty on padding (#66)
parent
22975fa7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
138 additions
and
24 deletions
+138
-24
examples/training-benchmark.py
examples/training-benchmark.py
+1
-1
tests/test_data.py
tests/test_data.py
+43
-6
torchani/training/container.py
torchani/training/container.py
+10
-7
torchani/training/data.py
torchani/training/data.py
+84
-10
No files found.
examples/training-benchmark.py
View file @
99c2c3dc
...
@@ -16,7 +16,7 @@ parser.add_argument('-d', '--device',
...
@@ -16,7 +16,7 @@ parser.add_argument('-d', '--device',
default
=
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
))
default
=
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
))
parser
.
add_argument
(
'--batch_size'
,
parser
.
add_argument
(
'--batch_size'
,
help
=
'Number of conformations of each batch'
,
help
=
'Number of conformations of each batch'
,
default
=
256
,
type
=
int
)
default
=
1024
,
type
=
int
)
parser
=
parser
.
parse_args
()
parser
=
parser
.
parse_args
()
# set up benchmark
# set up benchmark
...
...
tests/test_data.py
View file @
99c2c3dc
import
os
import
os
import
torch
import
torchani
import
torchani
import
unittest
import
unittest
path
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
path
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
dataset_path
=
os
.
path
.
join
(
path
,
'../dataset'
)
dataset_path
=
os
.
path
.
join
(
path
,
'../dataset'
)
print
(
dataset_path
)
batch_size
=
256
batch_size
=
256
aev
=
torchani
.
AEVComputer
()
aev
=
torchani
.
AEVComputer
()
...
@@ -16,10 +16,47 @@ class TestData(unittest.TestCase):
...
@@ -16,10 +16,47 @@ class TestData(unittest.TestCase):
aev
.
species
,
aev
.
species
,
batch_size
)
batch_size
)
def
_assertTensorEqual
(
self
,
t1
,
t2
):
self
.
assertEqual
((
t1
-
t2
).
abs
().
max
(),
0
)
def
testSplitBatch
(
self
):
species1
=
torch
.
randint
(
4
,
(
5
,
4
),
dtype
=
torch
.
long
)
coordinates1
=
torch
.
randn
(
5
,
4
,
3
)
species2
=
torch
.
randint
(
4
,
(
2
,
8
),
dtype
=
torch
.
long
)
coordinates2
=
torch
.
randn
(
2
,
8
,
3
)
species3
=
torch
.
randint
(
4
,
(
10
,
20
),
dtype
=
torch
.
long
)
coordinates3
=
torch
.
randn
(
10
,
20
,
3
)
species
,
coordinates
=
torchani
.
padding
.
pad_and_batch
([
(
species1
,
coordinates1
),
(
species2
,
coordinates2
),
(
species3
,
coordinates3
),
])
natoms
=
(
species
>=
0
).
to
(
torch
.
long
).
sum
(
1
)
chunks
=
torchani
.
training
.
data
.
split_batch
(
natoms
,
species
,
coordinates
)
start
=
0
last
=
None
for
s
,
c
in
chunks
:
n
=
(
s
>=
0
).
to
(
torch
.
long
).
sum
(
1
)
if
last
is
not
None
:
self
.
assertNotEqual
(
last
[
-
1
],
n
[
0
])
conformations
=
s
.
shape
[
0
]
self
.
assertGreater
(
conformations
,
0
)
s_
=
species
[
start
:
start
+
conformations
,
...]
c_
=
coordinates
[
start
:
start
+
conformations
,
...]
s_
,
c_
=
torchani
.
padding
.
strip_redundant_padding
(
s_
,
c_
)
self
.
_assertTensorEqual
(
s
,
s_
)
self
.
_assertTensorEqual
(
c
,
c_
)
start
+=
conformations
s
,
c
=
torchani
.
padding
.
pad_and_batch
(
chunks
)
self
.
_assertTensorEqual
(
s
,
species
)
self
.
_assertTensorEqual
(
c
,
coordinates
)
def
testTensorShape
(
self
):
def
testTensorShape
(
self
):
for
i
in
self
.
ds
:
for
i
in
self
.
ds
:
input
,
output
=
i
input
,
output
=
i
species
,
coordinates
=
input
species
,
coordinates
=
torchani
.
padding
.
pad_and_batch
(
input
)
energies
=
output
[
'energies'
]
energies
=
output
[
'energies'
]
self
.
assertEqual
(
len
(
species
.
shape
),
2
)
self
.
assertEqual
(
len
(
species
.
shape
),
2
)
self
.
assertLessEqual
(
species
.
shape
[
0
],
batch_size
)
self
.
assertLessEqual
(
species
.
shape
[
0
],
batch_size
)
...
@@ -32,10 +69,10 @@ class TestData(unittest.TestCase):
...
@@ -32,10 +69,10 @@ class TestData(unittest.TestCase):
def
testNoUnnecessaryPadding
(
self
):
def
testNoUnnecessaryPadding
(
self
):
for
i
in
self
.
ds
:
for
i
in
self
.
ds
:
input
,
_
=
i
for
input
in
i
[
0
]:
species
,
_
=
input
species
,
_
=
input
non_padding
=
(
species
>=
0
)[:,
-
1
].
nonzero
()
non_padding
=
(
species
>=
0
)[:,
-
1
].
nonzero
()
self
.
assertGreater
(
non_padding
.
numel
(),
0
)
self
.
assertGreater
(
non_padding
.
numel
(),
0
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
torchani/training/container.py
View file @
99c2c3dc
import
torch
import
torch
from
..
import
padding
class
Container
(
torch
.
nn
.
Module
):
class
Container
(
torch
.
nn
.
Module
):
...
@@ -10,12 +11,14 @@ class Container(torch.nn.Module):
...
@@ -10,12 +11,14 @@ class Container(torch.nn.Module):
setattr
(
self
,
'model_'
+
i
,
models
[
i
])
setattr
(
self
,
'model_'
+
i
,
models
[
i
])
def
forward
(
self
,
species_coordinates
):
def
forward
(
self
,
species_coordinates
):
species
,
coordinates
=
species_coordinates
results
=
{
k
:
[]
for
k
in
self
.
keys
}
results
=
{
for
sc
in
species_coordinates
:
'species'
:
species
,
for
k
in
self
.
keys
:
'coordinates'
:
coordinates
,
model
=
getattr
(
self
,
'model_'
+
k
)
}
_
,
result
=
model
(
sc
)
results
[
k
].
append
(
result
)
results
[
'species'
],
results
[
'coordinates'
]
=
\
padding
.
pad_and_batch
(
species_coordinates
)
for
k
in
self
.
keys
:
for
k
in
self
.
keys
:
model
=
getattr
(
self
,
'model_'
+
k
)
results
[
k
]
=
torch
.
cat
(
results
[
k
])
_
,
results
[
k
]
=
model
((
species
,
coordinates
))
return
results
return
results
torchani/training/data.py
View file @
99c2c3dc
...
@@ -8,6 +8,73 @@ import pickle
...
@@ -8,6 +8,73 @@ import pickle
from
..
import
padding
from
..
import
padding
def chunk_counts(counts, split):
    """Partition a run-length-encoded atom-count list at the given boundaries.

    Arguments:
        counts: list of ``[natoms, n_conformations]`` pairs, assumed sorted
            ascending by ``natoms`` (the last pair of each chunk supplies
            that chunk's maximum atom count).
        split: sorted list of indices into ``counts``; a boundary at index
            ``i`` means "cut after counts[i]".

    Returns:
        Pair of lists ``(chunk_conformations, chunk_maxatoms)``: for each
        chunk, the total number of conformations it holds and the atom count
        of its largest molecule.
    """
    # Convert inclusive cut indices to exclusive slice stops; the trailing
    # None makes the final slice run to the end of `counts`.
    boundaries = [i + 1 for i in split] + [None]
    chunks = []
    begin = 0
    for end in boundaries:
        chunks.append(counts[begin:end])
        begin = end
    conformations_per_chunk = [sum(pair[1] for pair in chunk)
                               for chunk in chunks]
    maxatoms_per_chunk = [chunk[-1][0] for chunk in chunks]
    return conformations_per_chunk, maxatoms_per_chunk
def split_cost(counts, split):
    """Estimate the cost of evaluating a batch partitioned at ``split``.

    Each chunk is priced as ``conformations * maxatoms**2`` (padding makes
    every conformation in a chunk pay for its largest molecule), floored at
    a fixed minimum so that creating many tiny chunks is penalized.

    Arguments:
        counts: run-length-encoded ``[natoms, n_conformations]`` pairs.
        split: candidate boundary indices into ``counts``.

    Returns:
        Integer cost; lower is better.
    """
    min_chunk_cost = 40000  # floor cost charged per chunk, however small
    chunk_conformations, chunk_maxatoms = chunk_counts(counts, split)
    return sum(
        max(conformations * maxatoms ** 2, min_chunk_cost)
        for conformations, maxatoms in zip(chunk_conformations, chunk_maxatoms)
    )
def split_batch(natoms, species, coordinates):
    """Split a padded batch into chunks grouped by molecule size.

    Conformations with very different atom counts waste work when padded
    together; this groups consecutive same-size runs and greedily inserts
    chunk boundaries wherever doing so lowers :func:`split_cost`.

    Arguments:
        natoms: 1-D tensor of real-atom counts per conformation; assumed
            sorted so equal sizes are adjacent (runs are detected by
            comparing neighbors only).
        species: padded species tensor, first dim indexes conformations.
        coordinates: padded coordinates tensor, same leading dim.

    Returns:
        List of ``(species, coordinates)`` chunk pairs, each with its
        redundant padding stripped.
    """
    # Run-length encode natoms into [natoms, n_conformations] pairs.
    counts = []
    for n in natoms.tolist():
        if counts and counts[-1][0] == n:
            counts[-1][1] += 1
        else:
            counts.append([n, 1])

    # Greedy search: keep adding the single boundary that most reduces the
    # cost until no further insertion helps.
    split = []
    cost = split_cost(counts, split)
    improved = True
    while improved:
        improved = False
        best_split, best_cost = split, cost
        for boundary in range(len(counts) - 1):
            if boundary in split:
                continue
            candidate = sorted(split + [boundary])
            candidate_cost = split_cost(counts, candidate)
            if candidate_cost < best_cost:
                improved = True
                best_cost = candidate_cost
                best_split = candidate
        if improved:
            split, cost = best_split, best_cost

    # Slice the batch at the chosen boundaries and strip the padding each
    # chunk no longer needs.
    species_coordinates = []
    chunk_conformations, _ = chunk_counts(counts, split)
    start = 0
    for size in chunk_conformations:
        end = start + size
        s, c = padding.strip_redundant_padding(species[start:end, ...],
                                               coordinates[start:end, ...])
        species_coordinates.append((s, c))
        start = end
    return species_coordinates
class
BatchedANIDataset
(
Dataset
):
class
BatchedANIDataset
(
Dataset
):
def
__init__
(
self
,
path
,
species
,
batch_size
,
shuffle
=
True
,
def
__init__
(
self
,
path
,
species
,
batch_size
,
shuffle
=
True
,
...
@@ -71,29 +138,36 @@ class BatchedANIDataset(Dataset):
...
@@ -71,29 +138,36 @@ class BatchedANIDataset(Dataset):
properties
)
properties
)
# split into minibatches, and strip reduncant padding
# split into minibatches, and strip reduncant padding
natoms
=
(
species
>=
0
).
to
(
torch
.
long
).
sum
(
1
)
batches
=
[]
batches
=
[]
num_batches
=
(
conformations
+
batch_size
-
1
)
//
batch_size
num_batches
=
(
conformations
+
batch_size
-
1
)
//
batch_size
for
i
in
range
(
num_batches
):
for
i
in
range
(
num_batches
):
start
=
i
*
batch_size
start
=
i
*
batch_size
end
=
min
((
i
+
1
)
*
batch_size
,
conformations
)
end
=
min
((
i
+
1
)
*
batch_size
,
conformations
)
species_batch
=
species
[
start
:
end
,
...]
natoms_batch
=
natoms
[
start
:
end
]
coordinates_batch
=
coordinates
[
start
:
end
,
...]
natoms_batch
,
indices
=
natoms_batch
.
sort
()
species_batch
=
species
[
start
:
end
,
...].
index_select
(
0
,
indices
)
coordinates_batch
=
coordinates
[
start
:
end
,
...]
\
.
index_select
(
0
,
indices
)
properties_batch
=
{
properties_batch
=
{
k
:
properties
[
k
][
start
:
end
,
...]
for
k
in
properties
k
:
properties
[
k
][
start
:
end
,
...].
index_select
(
0
,
indices
)
for
k
in
properties
}
}
batches
.
append
((
padding
.
strip_redundant_padding
(
species_batch
,
# further split batch into chunks
coordinates_batch
),
species_coordinates
=
split_batch
(
natoms_batch
,
species_batch
,
properties_batch
))
coordinates_batch
)
batch
=
species_coordinates
,
properties_batch
batches
.
append
(
batch
)
self
.
batches
=
batches
self
.
batches
=
batches
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
(
species
,
coordinates
)
,
properties
=
self
.
batches
[
idx
]
species
_
coordinates
,
properties
=
self
.
batches
[
idx
]
species
=
species
.
to
(
self
.
device
)
species
_coordinates
=
[(
s
.
to
(
self
.
device
),
c
.
to
(
self
.
device
)
)
coordinates
=
coordinates
.
to
(
self
.
device
)
for
s
,
c
in
species_coordinates
]
properties
=
{
properties
=
{
k
:
properties
[
k
].
to
(
self
.
device
)
for
k
in
properties
k
:
properties
[
k
].
to
(
self
.
device
)
for
k
in
properties
}
}
return
(
species
,
coordinates
)
,
properties
return
species
_
coordinates
,
properties
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
batches
)
return
len
(
self
.
batches
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment