Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
a9a12890
"vscode:/vscode.git/clone" did not exist on "635f1e94e855d7832363ecdb2ed70affe487608a"
Commit
a9a12890
authored
Aug 09, 2022
by
Gustaf Ahdritz
Browse files
Make tracing more granular
parent
e8b3789f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
182 additions
and
18 deletions
+182
-18
openfold/utils/chunk_utils.py
openfold/utils/chunk_utils.py
+1
-0
openfold/utils/trace_utils.py
openfold/utils/trace_utils.py
+181
-18
No files found.
openfold/utils/chunk_utils.py
View file @
a9a12890
...
@@ -358,6 +358,7 @@ class ChunkSizeTuner:
...
@@ -358,6 +358,7 @@ class ChunkSizeTuner:
candidates
=
[
2
**
l
for
l
in
range
(
int
(
math
.
log
(
self
.
max_chunk_size
,
2
))
+
1
)]
candidates
=
[
2
**
l
for
l
in
range
(
int
(
math
.
log
(
self
.
max_chunk_size
,
2
))
+
1
)]
candidates
=
[
c
for
c
in
candidates
if
c
>
min_chunk_size
]
candidates
=
[
c
for
c
in
candidates
if
c
>
min_chunk_size
]
candidates
=
[
min_chunk_size
]
+
candidates
candidates
=
[
min_chunk_size
]
+
candidates
candidates
[
-
1
]
+=
4
def
test_chunk_size
(
chunk_size
):
def
test_chunk_size
(
chunk_size
):
try
:
try
:
...
...
openfold/utils/trace_utils.py
View file @
a9a12890
...
@@ -71,15 +71,15 @@ def trace_model_(model, sample_input):
...
@@ -71,15 +71,15 @@ def trace_model_(model, sample_input):
seq_mask
=
feats
[
"seq_mask"
].
to
(
device
)
seq_mask
=
feats
[
"seq_mask"
].
to
(
device
)
pair_mask
=
seq_mask
[...,
None
]
*
seq_mask
[...,
None
,
:]
pair_mask
=
seq_mask
[...,
None
]
*
seq_mask
[...,
None
,
:]
msa_mask
=
feats
[
"msa_mask"
].
to
(
device
)
extra_msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
device
)
extra_msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
device
)
template_pair_mask
=
torch
.
stack
([
pair_mask
]
*
no_templates
,
dim
=-
3
)
template_pair_mask
=
torch
.
stack
([
pair_mask
]
*
no_templates
,
dim
=-
3
)
# Create some fake representations with the correct shapes
# Create some fake representations with the correct shapes
m
=
torch
.
rand
(
msa_depth
,
n
,
model
.
globals
.
c_m
).
to
(
device
)
m
=
torch
.
rand
(
msa_depth
+
4
,
n
,
model
.
globals
.
c_m
).
to
(
device
)
z
=
torch
.
rand
(
n
,
n
,
model
.
globals
.
c_z
).
to
(
device
)
z
=
torch
.
rand
(
n
,
n
,
model
.
globals
.
c_z
).
to
(
device
)
t
=
torch
.
rand
(
no_templates
,
n
,
n
,
model
.
globals
.
c_t
).
to
(
device
)
t
=
torch
.
rand
(
no_templates
,
n
,
n
,
model
.
globals
.
c_t
).
to
(
device
)
a
=
torch
.
rand
(
extra_msa_depth
,
n
,
model
.
globals
.
c_e
).
to
(
device
)
a
=
torch
.
rand
(
extra_msa_depth
,
n
,
model
.
globals
.
c_e
).
to
(
device
)
msa_mask
=
torch
.
randint
(
0
,
1
,
(
msa_depth
+
4
,
n
)).
to
(
device
)
# We need to do a dry run through the model so the chunk size tuners'
# We need to do a dry run through the model so the chunk size tuners'
# trial runs (which run during the first-ever model iteration) aren't
# trial runs (which run during the first-ever model iteration) aren't
...
@@ -140,10 +140,15 @@ def trace_model_(model, sample_input):
...
@@ -140,10 +140,15 @@ def trace_model_(model, sample_input):
# Yes, yes, I know
# Yes, yes, I know
with
contextlib
.
redirect_stderr
(
None
):
with
contextlib
.
redirect_stderr
(
None
):
traced_block
=
torch
.
jit
.
trace
(
block
,
block_inputs
)
traced_block
=
torch
.
jit
.
trace
(
block
,
block_inputs
)
traced_block
=
torch
.
jit
.
optimize_for_inference
(
traced_block
)
traced_block
=
torch
.
jit
.
freeze
(
traced_block
,
optimize_numerics
=
True
)
# It would be nice to use this, but its runtimes are extremely
# unpredictable
# traced_block = torch.jit.optimize_for_inference(traced_block)
# All trace inputs need to be tensors. This wrapper takes care of that
# All trace inputs need to be tensors. This wrapper takes care of that
def
traced_block_wrapper
(
*
args
,
**
kwargs
):
def
traced_block_wrapper
(
*
args
,
**
kwargs
):
to_tensor
=
lambda
t
:
torch
.
tensor
(
t
)
if
type
(
t
)
!=
torch
.
Tensor
else
t
to_tensor
=
lambda
t
:
torch
.
tensor
(
t
)
if
type
(
t
)
!=
torch
.
Tensor
else
t
args
=
[
to_tensor
(
a
)
for
a
in
args
]
args
=
[
to_tensor
(
a
)
for
a
in
args
]
kwargs
=
{
k
:
to_tensor
(
v
)
for
k
,
v
in
kwargs
.
items
()}
kwargs
=
{
k
:
to_tensor
(
v
)
for
k
,
v
in
kwargs
.
items
()}
...
@@ -162,35 +167,193 @@ def trace_model_(model, sample_input):
...
@@ -162,35 +167,193 @@ def trace_model_(model, sample_input):
fn_arg_names
=
fn_arg_names
[
1
:]
fn_arg_names
=
fn_arg_names
[
1
:]
# Trim unspecified arguments
# Trim unspecified arguments
fn_arg_names
=
fn_arg_names
[:
len
(
arg_list
)]
fn_arg_names
=
fn_arg_names
[:
len
(
arg_list
)]
name_tups
=
zip
(
fn_arg_names
,
[
n
for
n
,
_
in
arg_list
])
name_tups
=
list
(
zip
(
fn_arg_names
,
[
n
for
n
,
_
in
arg_list
]))
print
(
name_tups
)
assert
(
all
([
n1
==
n2
for
n1
,
n2
in
name_tups
]))
assert
(
all
([
n1
==
n2
for
n1
,
n2
in
name_tups
]))
evoformer_attn_chunk_size
=
max
(
evoformer_attn_chunk_size
=
max
(
model
.
globals
.
chunk_size
,
evoformer_chunk_size
//
4
model
.
globals
.
chunk_size
,
evoformer_chunk_size
//
4
)
)
evoformer_arg_tuples
=
[
# MSA row attention
msa_att_row_arg_tuples
=
[
(
"m"
,
m
),
(
"m"
,
m
),
(
"z"
,
z
),
(
"z"
,
z
),
(
"msa_mask"
,
msa_mask
),
(
"mask"
,
msa_mask
),
(
"pair_mask"
,
pair_mask
),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_attn_chunk_size
)),
(
"use_memory_efficient_kernel"
,
torch
.
tensor
(
False
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
]
verify_arg_order
(
model
.
evoformer
.
blocks
[
0
].
msa_att_row
.
forward
,
msa_att_row_arg_tuples
)
msa_att_row_args
=
[
arg
for
_
,
arg
in
msa_att_row_arg_tuples
]
with
torch
.
no_grad
():
for
b
in
model
.
evoformer
.
blocks
:
traced_block
=
trace_block
(
b
.
msa_att_row
,
msa_att_row_args
)
del
b
.
msa_att_row
b
.
msa_att_row
=
traced_block
# MSA col attention
msa_att_col_arg_tuples
=
[
(
"m"
,
m
),
(
"mask"
,
msa_mask
),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_chunk_size
)),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_chunk_size
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
(
"use_flash"
,
torch
.
tensor
(
model
.
globals
.
use_flash
)),
(
"use_flash"
,
torch
.
tensor
(
model
.
globals
.
use_flash
)),
(
"inplace_safe"
,
torch
.
tensor
(
1
)),
(
"_mask_trans"
,
torch
.
tensor
(
model
.
config
.
_mask_trans
)),
(
"_attn_chunk_size"
,
torch
.
tensor
(
evoformer_attn_chunk_size
)),
]
]
verify_arg_order
(
model
.
evoformer
.
blocks
[
0
].
forward
,
evoformer_arg_tuples
)
verify_arg_order
(
evoformer_args
=
[
arg
for
_
,
arg
in
evoformer_arg_tuples
]
model
.
evoformer
.
blocks
[
0
].
msa_att_col
.
forward
,
msa_att_col_arg_tuples
)
msa_att_col_args
=
[
arg
for
_
,
arg
in
msa_att_col_arg_tuples
]
with
torch
.
no_grad
():
with
torch
.
no_grad
():
traced_evoformer_stack
=
[]
for
b
in
model
.
evoformer
.
blocks
:
for
b
in
model
.
evoformer
.
blocks
:
traced_block
=
trace_block
(
b
,
evoformer_args
)
traced_block
=
trace_block
(
traced_evoformer_stack
.
append
(
traced_block
)
b
.
msa_att_col
,
msa_att_col_args
)
del
b
.
msa_att_col
b
.
msa_att_col
=
traced_block
del
model
.
evoformer
.
blocks
# OPM
model
.
evoformer
.
blocks
=
traced_evoformer_stack
opm_arg_tuples
=
[
(
"m"
,
m
),
(
"mask"
,
msa_mask
.
float
()),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_chunk_size
)),
(
"inplace_safe"
,
torch
.
tensor
(
True
)),
]
verify_arg_order
(
model
.
evoformer
.
blocks
[
0
].
core
.
outer_product_mean
.
forward
,
opm_arg_tuples
)
opm_args
=
[
arg
for
_
,
arg
in
opm_arg_tuples
]
with
torch
.
no_grad
():
for
b
in
model
.
evoformer
.
blocks
:
traced_block
=
trace_block
(
b
.
core
.
outer_product_mean
,
opm_args
)
del
b
.
core
.
outer_product_mean
b
.
core
.
outer_product_mean
=
traced_block
# Triangular multiplicative update (out)
tri_mul_out_arg_tuples
=
[
(
"z"
,
z
),
(
"mask"
,
pair_mask
.
float
()),
(
"inplace_safe"
,
torch
.
tensor
(
True
)),
(
"_add_with_inplace"
,
torch
.
tensor
(
True
)),
]
verify_arg_order
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_out
.
forward
,
tri_mul_out_arg_tuples
)
tri_mul_out_args
=
[
arg
for
_
,
arg
in
tri_mul_out_arg_tuples
]
with
torch
.
no_grad
():
for
b
in
model
.
evoformer
.
blocks
:
traced_block
=
trace_block
(
b
.
core
.
tri_mul_out
,
tri_mul_out_args
)
del
b
.
core
.
tri_mul_out
b
.
core
.
tri_mul_out
=
traced_block
# Triangular multiplicative update (in)
tri_mul_in_arg_tuples
=
[
(
"z"
,
z
),
(
"mask"
,
pair_mask
.
float
()),
(
"inplace_safe"
,
torch
.
tensor
(
True
)),
(
"_add_with_inplace"
,
torch
.
tensor
(
True
)),
]
verify_arg_order
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_in
.
forward
,
tri_mul_in_arg_tuples
)
tri_mul_in_args
=
[
arg
for
_
,
arg
in
tri_mul_in_arg_tuples
]
with
torch
.
no_grad
():
for
b
in
model
.
evoformer
.
blocks
:
traced_block
=
trace_block
(
b
.
core
.
tri_mul_in
,
tri_mul_in_args
)
del
b
.
core
.
tri_mul_in
b
.
core
.
tri_mul_in
=
traced_block
# Triangular attention (start)
tri_att_start_arg_tuples
=
[
(
"x"
,
z
),
(
"mask"
,
pair_mask
.
float
()),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_attn_chunk_size
)),
(
"use_memory_efficient_kernel"
,
torch
.
tensor
(
False
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
(
"inplace_safe"
,
torch
.
tensor
(
True
)),
]
verify_arg_order
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_start
.
forward
,
tri_att_start_arg_tuples
)
tri_att_start_args
=
[
arg
for
_
,
arg
in
tri_att_start_arg_tuples
]
with
torch
.
no_grad
():
for
b
in
model
.
evoformer
.
blocks
:
traced_block
=
trace_block
(
b
.
core
.
tri_att_start
,
tri_att_start_args
)
del
b
.
core
.
tri_att_start
b
.
core
.
tri_att_start
=
traced_block
# Triangular attention (end)
tri_att_end_arg_tuples
=
[
(
"x"
,
z
.
transpose
(
-
2
,
-
3
)),
(
"mask"
,
pair_mask
.
transpose
(
-
1
,
-
2
).
float
()),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_attn_chunk_size
)),
(
"use_memory_efficient_kernel"
,
torch
.
tensor
(
False
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
(
"inplace_safe"
,
torch
.
tensor
(
True
)),
]
verify_arg_order
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_end
.
forward
,
tri_att_end_arg_tuples
)
tri_att_end_args
=
[
arg
for
_
,
arg
in
tri_att_end_arg_tuples
]
with
torch
.
no_grad
():
for
b
in
model
.
evoformer
.
blocks
:
traced_block
=
trace_block
(
b
.
core
.
tri_att_end
,
tri_att_end_args
)
del
b
.
core
.
tri_att_end
b
.
core
.
tri_att_end
=
traced_block
#evoformer_arg_tuples = [
# ("m", m),
# ("z", z),
# ("msa_mask", msa_mask),
# ("pair_mask", pair_mask),
# ("chunk_size", torch.tensor(evoformer_chunk_size)),
# ("use_lma", torch.tensor(model.globals.use_lma)),
# ("use_flash", torch.tensor(model.globals.use_flash)),
# ("inplace_safe", torch.tensor(1)),
# ("_mask_trans", torch.tensor(model.config._mask_trans)),
# ("_attn_chunk_size", torch.tensor(evoformer_attn_chunk_size)),
#]
#verify_arg_order(model.evoformer.blocks[0].forward, evoformer_arg_tuples)
#evoformer_args = [arg for _, arg in evoformer_arg_tuples]
#with torch.no_grad():
# traced_evoformer_stack = []
# for b in model.evoformer.blocks:
# traced_block = trace_block(b, evoformer_args)
# traced_evoformer_stack.append(traced_block)
#del model.evoformer.blocks
#model.evoformer.blocks = traced_evoformer_stack
# with torch.no_grad():
# for b in model.evoformer.blocks:
# _ = b(*evoformer_args)
#
# with torch.no_grad():
# for b in model.evoformer.blocks:
# _ = b(*evoformer_args)
# extra_msa_attn_chunk_size = max(
# extra_msa_attn_chunk_size = max(
# model.globals.chunk_size, extra_msa_chunk_size // 4
# model.globals.chunk_size, extra_msa_chunk_size // 4
# )
# )
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment