"vscode:/vscode.git/clone" did not exist on "635f1e94e855d7832363ecdb2ed70affe487608a"
Commit a9a12890 authored by Gustaf Ahdritz

Make tracing more granular

parent e8b3789f
@@ -358,6 +358,7 @@ class ChunkSizeTuner:
         candidates = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
         candidates = [c for c in candidates if c > min_chunk_size]
         candidates = [min_chunk_size] + candidates
+        candidates[-1] += 4
 
         def test_chunk_size(chunk_size):
             try:
...
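For context, the hunk above changes how ChunkSizeTuner builds its list of candidate chunk sizes: powers of two up to max_chunk_size, filtered to those above min_chunk_size, with min_chunk_size prepended, and now with the largest candidate offset by 4. A minimal standalone sketch of that logic; the values of max_chunk_size and min_chunk_size are illustrative assumptions, not values from the repo:

import math

max_chunk_size = 512
min_chunk_size = 16

candidates = [2**l for l in range(int(math.log(max_chunk_size, 2)) + 1)]
candidates = [c for c in candidates if c > min_chunk_size]
candidates = [min_chunk_size] + candidates
candidates[-1] += 4  # the line added in this commit

print(candidates)  # [16, 32, 64, 128, 256, 516]

The effect of the `+= 4` is that the tuner's largest trial size is no longer an exact power of two, presumably to match the padded input shapes used elsewhere in this commit.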
@@ -71,15 +71,15 @@ def trace_model_(model, sample_input):
     seq_mask = feats["seq_mask"].to(device)
     pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
-    msa_mask = feats["msa_mask"].to(device)
     extra_msa_mask = feats["extra_msa_mask"].to(device)
     template_pair_mask = torch.stack([pair_mask] * no_templates, dim=-3)
 
     # Create some fake representations with the correct shapes
-    m = torch.rand(msa_depth, n, model.globals.c_m).to(device)
+    m = torch.rand(msa_depth + 4, n, model.globals.c_m).to(device)
     z = torch.rand(n, n, model.globals.c_z).to(device)
     t = torch.rand(no_templates, n, n, model.globals.c_t).to(device)
     a = torch.rand(extra_msa_depth, n, model.globals.c_e).to(device)
+    msa_mask = torch.randint(0, 1, (msa_depth + 4, n)).to(device)
 
     # We need to do a dry run through the model so the chunk size tuners'
     # trial runs (which run during the first-ever model iteration) aren't
@@ -140,10 +140,15 @@ def trace_model_(model, sample_input):
         # Yes, yes, I know
         with contextlib.redirect_stderr(None):
             traced_block = torch.jit.trace(block, block_inputs)
-            traced_block = torch.jit.optimize_for_inference(traced_block)
+            traced_block = torch.jit.freeze(traced_block, optimize_numerics=True)
+            # It would be nice to use this, but its runtimes are extremely
+            # unpredictable
+            # traced_block = torch.jit.optimize_for_inference(traced_block)
 
         # All trace inputs need to be tensors. This wrapper takes care of that
         def traced_block_wrapper(*args, **kwargs):
             to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t
             args = [to_tensor(a) for a in args]
             kwargs = {k: to_tensor(v) for k,v in kwargs.items()}
@@ -162,35 +167,193 @@ def trace_model_(model, sample_input):
             fn_arg_names = fn_arg_names[1:]
 
         # Trim unspecified arguments
         fn_arg_names = fn_arg_names[:len(arg_list)]
-        name_tups = zip(fn_arg_names, [n for n, _ in arg_list])
+        name_tups = list(zip(fn_arg_names, [n for n, _ in arg_list]))
+        print(name_tups)
         assert(all([n1 == n2 for n1, n2 in name_tups]))
 
     evoformer_attn_chunk_size = max(
         model.globals.chunk_size, evoformer_chunk_size // 4
     )
 
-    evoformer_arg_tuples = [
+    # MSA row attention
+    msa_att_row_arg_tuples = [
         ("m", m),
         ("z", z),
-        ("msa_mask", msa_mask),
-        ("pair_mask", pair_mask),
+        ("mask", msa_mask),
+        ("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
+        ("use_memory_efficient_kernel", torch.tensor(False)),
+        ("use_lma", torch.tensor(model.globals.use_lma)),
+    ]
+    verify_arg_order(
+        model.evoformer.blocks[0].msa_att_row.forward,
+        msa_att_row_arg_tuples
+    )
+    msa_att_row_args = [arg for _, arg in msa_att_row_arg_tuples]
+    with torch.no_grad():
+        for b in model.evoformer.blocks:
+            traced_block = trace_block(
+                b.msa_att_row, msa_att_row_args
+            )
+            del b.msa_att_row
+            b.msa_att_row = traced_block
+
+    # MSA col attention
+    msa_att_col_arg_tuples = [
+        ("m", m),
+        ("mask", msa_mask),
         ("chunk_size", torch.tensor(evoformer_chunk_size)),
-        ("use_lma", torch.tensor(model.globals.use_lma)),
-        ("use_flash", torch.tensor(model.globals.use_flash)),
-        ("inplace_safe", torch.tensor(1)),
-        ("_mask_trans", torch.tensor(model.config._mask_trans)),
-        ("_attn_chunk_size", torch.tensor(evoformer_attn_chunk_size)),
+        ("use_lma", torch.tensor(model.globals.use_lma)),
+        ("use_flash", torch.tensor(model.globals.use_flash)),
     ]
-    verify_arg_order(model.evoformer.blocks[0].forward, evoformer_arg_tuples)
-    evoformer_args = [arg for _, arg in evoformer_arg_tuples]
+    verify_arg_order(
+        model.evoformer.blocks[0].msa_att_col.forward,
+        msa_att_col_arg_tuples
+    )
+    msa_att_col_args = [arg for _, arg in msa_att_col_arg_tuples]
     with torch.no_grad():
-        traced_evoformer_stack = []
         for b in model.evoformer.blocks:
-            traced_block = trace_block(b, evoformer_args)
-            traced_evoformer_stack.append(traced_block)
-    del model.evoformer.blocks
-    model.evoformer.blocks = traced_evoformer_stack
+            traced_block = trace_block(
+                b.msa_att_col, msa_att_col_args
+            )
+            del b.msa_att_col
+            b.msa_att_col = traced_block
+
+    # OPM
+    opm_arg_tuples = [
+        ("m", m),
+        ("mask", msa_mask.float()),
+        ("chunk_size", torch.tensor(evoformer_chunk_size)),
+        ("inplace_safe", torch.tensor(True)),
+    ]
+    verify_arg_order(
+        model.evoformer.blocks[0].core.outer_product_mean.forward,
+        opm_arg_tuples
+    )
+    opm_args = [arg for _, arg in opm_arg_tuples]
+    with torch.no_grad():
+        for b in model.evoformer.blocks:
+            traced_block = trace_block(
+                b.core.outer_product_mean, opm_args
+            )
+            del b.core.outer_product_mean
+            b.core.outer_product_mean = traced_block
+
+    # Triangular multiplicative update (out)
+    tri_mul_out_arg_tuples = [
+        ("z", z),
+        ("mask", pair_mask.float()),
+        ("inplace_safe", torch.tensor(True)),
+        ("_add_with_inplace", torch.tensor(True)),
+    ]
+    verify_arg_order(
+        model.evoformer.blocks[0].core.tri_mul_out.forward,
+        tri_mul_out_arg_tuples
+    )
+    tri_mul_out_args = [arg for _, arg in tri_mul_out_arg_tuples]
+    with torch.no_grad():
+        for b in model.evoformer.blocks:
+            traced_block = trace_block(
+                b.core.tri_mul_out, tri_mul_out_args
+            )
+            del b.core.tri_mul_out
+            b.core.tri_mul_out = traced_block
+
+    # Triangular multiplicative update (in)
+    tri_mul_in_arg_tuples = [
+        ("z", z),
+        ("mask", pair_mask.float()),
+        ("inplace_safe", torch.tensor(True)),
+        ("_add_with_inplace", torch.tensor(True)),
+    ]
+    verify_arg_order(
+        model.evoformer.blocks[0].core.tri_mul_in.forward,
+        tri_mul_in_arg_tuples
+    )
+    tri_mul_in_args = [arg for _, arg in tri_mul_in_arg_tuples]
+    with torch.no_grad():
+        for b in model.evoformer.blocks:
+            traced_block = trace_block(
+                b.core.tri_mul_in, tri_mul_in_args
+            )
+            del b.core.tri_mul_in
+            b.core.tri_mul_in = traced_block
+
+    # Triangular attention (start)
+    tri_att_start_arg_tuples = [
+        ("x", z),
+        ("mask", pair_mask.float()),
+        ("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
+        ("use_memory_efficient_kernel", torch.tensor(False)),
+        ("use_lma", torch.tensor(model.globals.use_lma)),
+        ("inplace_safe", torch.tensor(True)),
+    ]
+    verify_arg_order(
+        model.evoformer.blocks[0].core.tri_att_start.forward,
+        tri_att_start_arg_tuples
+    )
+    tri_att_start_args = [arg for _, arg in tri_att_start_arg_tuples]
+    with torch.no_grad():
+        for b in model.evoformer.blocks:
+            traced_block = trace_block(
+                b.core.tri_att_start, tri_att_start_args
+            )
+            del b.core.tri_att_start
+            b.core.tri_att_start = traced_block
+
+    # Triangular attention (end)
+    tri_att_end_arg_tuples = [
+        ("x", z.transpose(-2, -3)),
+        ("mask", pair_mask.transpose(-1, -2).float()),
+        ("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
+        ("use_memory_efficient_kernel", torch.tensor(False)),
+        ("use_lma", torch.tensor(model.globals.use_lma)),
+        ("inplace_safe", torch.tensor(True)),
+    ]
+    verify_arg_order(
+        model.evoformer.blocks[0].core.tri_att_end.forward,
+        tri_att_end_arg_tuples
+    )
+    tri_att_end_args = [arg for _, arg in tri_att_end_arg_tuples]
+    with torch.no_grad():
+        for b in model.evoformer.blocks:
+            traced_block = trace_block(
+                b.core.tri_att_end, tri_att_end_args
+            )
+            del b.core.tri_att_end
+            b.core.tri_att_end = traced_block
+
+    #evoformer_arg_tuples = [
+    #    ("m", m),
+    #    ("z", z),
+    #    ("msa_mask", msa_mask),
+    #    ("pair_mask", pair_mask),
+    #    ("chunk_size", torch.tensor(evoformer_chunk_size)),
+    #    ("use_lma", torch.tensor(model.globals.use_lma)),
+    #    ("use_flash", torch.tensor(model.globals.use_flash)),
+    #    ("inplace_safe", torch.tensor(1)),
+    #    ("_mask_trans", torch.tensor(model.config._mask_trans)),
+    #    ("_attn_chunk_size", torch.tensor(evoformer_attn_chunk_size)),
+    #]
+    #verify_arg_order(model.evoformer.blocks[0].forward, evoformer_arg_tuples)
+    #evoformer_args = [arg for _, arg in evoformer_arg_tuples]
+    #with torch.no_grad():
+    #    traced_evoformer_stack = []
+    #    for b in model.evoformer.blocks:
+    #        traced_block = trace_block(b, evoformer_args)
+    #        traced_evoformer_stack.append(traced_block)
+    #del model.evoformer.blocks
+    #model.evoformer.blocks = traced_evoformer_stack
+
+    # with torch.no_grad():
+    #     for b in model.evoformer.blocks:
+    #         _ = b(*evoformer_args)
+    #
+    # with torch.no_grad():
+    #     for b in model.evoformer.blocks:
+    #         _ = b(*evoformer_args)
 
     # extra_msa_attn_chunk_size = max(
     #     model.globals.chunk_size, extra_msa_chunk_size // 4
     # )
...
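The repeated verify_arg_order calls in the new code guard against supplying trace inputs in the wrong positional order: each list of (name, tensor) tuples is checked against the parameter names of the target forward method. The helper's full body is only partially visible in this diff (the fn_arg_names and name_tups lines), so the following is a hedged reconstruction of such a check using inspect, not the repo's exact implementation:

import inspect

def verify_arg_order(fn, arg_list):
    # Parameter names of fn, in declaration order
    fn_arg_names = list(inspect.signature(fn).parameters)
    # Drop "self" if fn is an unbound function rather than a bound method
    if fn_arg_names and fn_arg_names[0] == "self":
        fn_arg_names = fn_arg_names[1:]
    # Trim unspecified (trailing, defaulted) arguments
    fn_arg_names = fn_arg_names[:len(arg_list)]
    name_tups = list(zip(fn_arg_names, [n for n, _ in arg_list]))
    assert all(n1 == n2 for n1, n2 in name_tups), f"Mismatched args: {name_tups}"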
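The pattern the commit applies to each Evoformer submodule is the same throughout: assemble tensor-only example inputs, torch.jit.trace the submodule, torch.jit.freeze the result (optimize_for_inference was dropped because, per the new comment, its runtimes are unpredictable), then rebind the traced module in place of the original attribute. A minimal self-contained sketch of that trace-freeze-swap pattern on a toy module; ToyBlock and its attn submodule are illustrative stand-ins, not OpenFold classes:

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(8, 8)  # stand-in for e.g. msa_att_row

    def forward(self, m):
        return torch.relu(self.attn(m))

blocks = nn.ModuleList([ToyBlock() for _ in range(2)])
blocks.eval()  # torch.jit.freeze requires modules in eval mode

m = torch.rand(4, 8)
with torch.no_grad():
    for b in blocks:
        traced = torch.jit.trace(b.attn, (m,))
        traced = torch.jit.freeze(traced, optimize_numerics=True)
        # Swap the traced submodule in for the eager one, as the diff does
        del b.attn
        b.attn = traced
    out = blocks[0](m)  # blocks still run end to end with traced internals

Freezing inlines parameters and attributes as constants into the TorchScript graph; optimize_for_inference would additionally apply runtime-specific fusions on top of freezing, which is the step this commit backs out.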