Commit e56b5976 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix template masking bug

parent 4d513bb1
...@@ -183,10 +183,16 @@ class AlphaFold(nn.Module): ...@@ -183,10 +183,16 @@ class AlphaFold(nn.Module):
use_lma=self.globals.use_lma, use_lma=self.globals.use_lma,
) )
t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
# Append singletons
t_mask = t_mask.reshape(
*t_mask.shape, *([1] * (len(t.shape) - len(t_mask.shape)))
)
if(inplace_safe): if(inplace_safe):
t *= (torch.sum(batch["template_mask"], dim=-1) > 0) t *= t_mask
else: else:
t = t * (torch.sum(batch["template_mask"], dim=-1) > 0) t = t * t_mask
ret = {} ret = {}
...@@ -380,7 +386,6 @@ class AlphaFold(nn.Module): ...@@ -380,7 +386,6 @@ class AlphaFold(nn.Module):
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype), pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size, chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma, use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash,
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
) )
......
...@@ -168,7 +168,6 @@ def trace_model_(model, sample_input): ...@@ -168,7 +168,6 @@ def trace_model_(model, sample_input):
# Trim unspecified arguments # Trim unspecified arguments
fn_arg_names = fn_arg_names[:len(arg_list)] fn_arg_names = fn_arg_names[:len(arg_list)]
name_tups = list(zip(fn_arg_names, [n for n, _ in arg_list])) name_tups = list(zip(fn_arg_names, [n for n, _ in arg_list]))
print(name_tups)
assert(all([n1 == n2 for n1, n2 in name_tups])) assert(all([n1 == n2 for n1, n2 in name_tups]))
evoformer_attn_chunk_size = max( evoformer_attn_chunk_size = max(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment