"git@developer.sourcefind.cn:OpenDAS/dcnv3.git" did not exist on "6a31be8f76da915bc20fdf8d1de4d31f9e93faf8"
Commit e56b5976 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix template masking bug

parent 4d513bb1
......@@ -183,10 +183,16 @@ class AlphaFold(nn.Module):
use_lma=self.globals.use_lma,
)
t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
# Append singletons
t_mask = t_mask.reshape(
*t_mask.shape, *([1] * (len(t.shape) - len(t_mask.shape)))
)
if(inplace_safe):
t *= (torch.sum(batch["template_mask"], dim=-1) > 0)
t *= t_mask
else:
t = t * (torch.sum(batch["template_mask"], dim=-1) > 0)
t = t * t_mask
ret = {}
......@@ -380,7 +386,6 @@ class AlphaFold(nn.Module):
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash,
_mask_trans=self.config._mask_trans,
)
......
......@@ -168,7 +168,6 @@ def trace_model_(model, sample_input):
# Trim unspecified arguments
fn_arg_names = fn_arg_names[:len(arg_list)]
name_tups = list(zip(fn_arg_names, [n for n, _ in arg_list]))
print(name_tups)
assert(all([n1 == n2 for n1, n2 in name_tups]))
evoformer_attn_chunk_size = max(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment