Commit b399eaac authored by Gustaf Ahdritz

Add more in-place operations

parent bcd78085
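
For context on the change: PyTorch's in-place operators (+=, *=) write their result into the left operand's existing storage instead of allocating a new tensor, which trims peak memory, most usefully during inference. A minimal sketch of the difference, with illustrative tensors (nothing below is from the repo):

    import torch

    x = torch.randn(4, 8)
    mask = torch.ones(4, 8)

    y = x * mask                          # out-of-place: allocates a new result tensor
    print(x.data_ptr() == y.data_ptr())   # False: y lives in fresh storage

    x *= mask                             # in-place: result written into x's own storage

One caveat worth noting: autograd forbids in-place writes to tensors whose original values are needed for the backward pass and raises a RuntimeError in that case, so rewrites like these are only safe where that constraint holds (e.g. under torch.no_grad() or where the overwritten value is not saved for backward).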
@@ -56,7 +56,7 @@ class Dropout(nn.Module):
         shape[bd] = 1
         mask = x.new_ones(shape)
         mask = self.dropout(mask)
-        x = x * mask
+        x *= mask
         return x
...
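
The hunk above is the shared-mask dropout used to drop whole rows or columns at once: a mask of ones with one dimension collapsed to 1 goes through nn.Dropout and is broadcast-multiplied into x, and that multiply is now in-place. A self-contained sketch of the pattern, with a hypothetical class name and constructor (the repo's class is the Dropout shown above):

    import torch
    import torch.nn as nn

    class SharedDimDropout(nn.Module):
        """Sketch: one dropout decision is shared along batch_dim by
        broadcasting a reduced-shape mask. (Hypothetical name.)"""
        def __init__(self, p: float, batch_dim: int):
            super().__init__()
            self.dropout = nn.Dropout(p)
            self.batch_dim = batch_dim

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            shape = list(x.shape)
            shape[self.batch_dim] = 1               # mask broadcasts along this dim
            mask = self.dropout(x.new_ones(shape))  # 0 or 1/(1-p) per kept slice
            x *= mask                               # in-place, as in the commit
            return x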
@@ -223,19 +223,19 @@ class EvoformerBlock(nn.Module):
         m = m + self.msa_dropout_layer(
             self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
         )
-        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+        m += self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
         m = m + self.msa_transition(
             m, mask=msa_trans_mask, chunk_size=chunk_size
         )
         z = z + self.outer_product_mean(
             m, mask=msa_mask, chunk_size=chunk_size
         )
-        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
-        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
-        z = z + self.ps_dropout_row_layer(
+        z += self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
+        z += self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
+        z += self.ps_dropout_row_layer(
             self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
         )
-        z = z + self.ps_dropout_col_layer(
+        z += self.ps_dropout_col_layer(
             self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
         )
         z = z + self.pair_transition(
...
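
Worth noting about the hunk above: Python evaluates the right-hand side completely before the in-place add, so m += self.msa_att_col(m, ...) still feeds the pre-update m into the attention, exactly as the out-of-place form did. A hedged sketch of the residual pattern and the allocation it saves, with f standing in for any sub-block:

    import torch

    def f(t: torch.Tensor) -> torch.Tensor:
        # Stand-in for a sub-block such as msa_att_col; returns a new tensor.
        return torch.tanh(t)

    m = torch.randn(16, 64)

    # Out-of-place residual: allocates f(m), then allocates m + f(m).
    m = m + f(m)

    # In-place residual: f(m) is computed first, from the old m, then
    # accumulated into m's existing storage; one fewer live allocation.
    m += f(m)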
@@ -332,7 +332,7 @@ class GlobalAttention(nn.Module):
         # [*, N_res, H * C_hidden]
         q = self.linear_q(q)
-        q = q * (self.c_hidden ** (-0.5))
+        q *= (self.c_hidden ** (-0.5))

         # [*, N_res, H, C_hidden]
         q = q.view(q.shape[:-1] + (self.no_heads, -1))
@@ -347,7 +347,7 @@ class GlobalAttention(nn.Module):
             k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
         )
         bias = (self.inf * (mask - 1))[..., :, None, :]
-        a = a + bias
+        a += bias
         a = self.softmax(a)

         # [*, N_res, H, C_hidden]
...
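
The two GlobalAttention hunks apply the same idea inside attention: the 1/sqrt(c_hidden) query scaling and the additive mask bias both mutate their tensors in place. Below is a simplified sketch of the mask-bias step; the sizes and the 1-D mask are assumptions for illustration (in the repo the mask carries leading batch/residue dimensions, hence the [..., :, None, :] indexing):

    import torch

    inf = 1e9                                 # large constant, like self.inf
    N_res, H, N_seq, c_hidden = 5, 4, 7, 8    # illustrative sizes

    q = torch.randn(N_res, H * c_hidden)
    q *= c_hidden ** (-0.5)                   # in-place 1/sqrt(c_hidden) scaling

    a = torch.randn(N_res, H, N_seq)          # raw attention logits
    mask = torch.ones(N_seq)
    mask[-2:] = 0                             # 1 = keep, 0 = masked out

    bias = (inf * (mask - 1))[None, None, :]  # 0 where kept, -1e9 where masked
    a += bias                                 # in-place add, as in the commit
    a = torch.softmax(a, dim=-1)              # masked positions get ~zero weight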