Commit b399eaac authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add more in-place operations

parent bcd78085
......@@ -56,7 +56,7 @@ class Dropout(nn.Module):
shape[bd] = 1
mask = x.new_ones(shape)
mask = self.dropout(mask)
x = x * mask
x *= mask
return x
......
......@@ -223,19 +223,19 @@ class EvoformerBlock(nn.Module):
m = m + self.msa_dropout_layer(
self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
)
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
m += self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
m = m + self.msa_transition(
m, mask=msa_trans_mask, chunk_size=chunk_size
)
z = z + self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size
)
z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
z = z + self.ps_dropout_row_layer(
z += self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
z += self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
z += self.ps_dropout_row_layer(
self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
)
z = z + self.ps_dropout_col_layer(
z += self.ps_dropout_col_layer(
self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
)
z = z + self.pair_transition(
......
......@@ -332,7 +332,7 @@ class GlobalAttention(nn.Module):
# [*, N_res, H * C_hidden]
q = self.linear_q(q)
q = q * (self.c_hidden ** (-0.5))
q *= (self.c_hidden ** (-0.5))
# [*, N_res, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
......@@ -347,7 +347,7 @@ class GlobalAttention(nn.Module):
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a = a + bias
a += bias
a = self.softmax(a)
# [*, N_res, H, C_hidden]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment