Unverified commit 2a67dc33 authored by oahzxl, committed by GitHub

Use inplace to save memory (#72)

parent efa8a9e4
@@ -143,6 +143,27 @@ python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
    --enable_workflow
```
#### Inference with lower memory usage
AlphaFold's embedding representations take up a lot of memory as the sequence length grows. To reduce memory usage,
add the parameters `--chunk_size [N]` and `--inplace` to the command line or to the shell script `./inference.sh`.
The smaller you set N, the less memory is used, at the cost of slower inference. With these options we can run
inference on a sequence of length 7000 in fp32 on an 80GB A100.
```shell
python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
--output_dir ./ \
--gpus 2 \
--uniref90_database_path data/uniref90/uniref90.fasta \
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path data/pdb70/pdb70 \
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--jackhmmer_binary_path `which jackhmmer` \
--hhblits_binary_path `which hhblits` \
--hhsearch_binary_path `which hhsearch` \
--kalign_binary_path `which kalign` \
--chunk_size N \
--inplace
```
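The same two switches can also be set from Python, mirroring what `inference.py` does with `config.globals`. A minimal sketch; the import paths and the model name are assumptions, adjust them to your install:
```python
# Sketch only: import locations vary between FastFold versions (assumed paths below).
from fastfold.config import model_config   # assumed path
from fastfold.model import AlphaFold       # assumed path

config = model_config("model_1")
config.globals.chunk_size = 4      # smaller value -> less memory, slower inference
config.globals.inplace = True      # update the MSA/pair tensors in place

model = AlphaFold(config).eval().cuda()
# outputs = model(batch)           # batch prepared as in inference.py
```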
## Performance Benchmark
......
@@ -99,6 +99,63 @@ class EvoformerBlock(nn.Module):
        return m, z
    def inplace(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        dap_size = gpc.get_world_size(ParallelMode.TENSOR)

        seq_length = pair_mask.size(-1)
        padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length

        if self.first_block:
            m[0] = m[0].unsqueeze(0)
            z[0] = z[0].unsqueeze(0)
            m[0] = torch.nn.functional.pad(m[0], (0, 0, 0, padding_size))
            z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
            m[0] = scatter(m[0], dim=1)
            z[0] = scatter(z[0], dim=1)

        msa_mask = msa_mask.unsqueeze(0)
        pair_mask = pair_mask.unsqueeze(0)
        msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
        pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))

        if not self.is_multimer:
            m[0] = self.msa_stack(m[0], z[0], msa_mask)
            z = self.communication.inplace(m[0], msa_mask, z)
            m[0], work = All_to_All_Async.apply(m[0], 1, 2)
            z = self.pair_stack.inplace(z, pair_mask)
            m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
        else:
            z = self.communication(m, msa_mask, z)
            z_ori = z
            m, work = All_to_All_Async.apply(m, 1, 2)
            z = self.pair_stack(z, pair_mask)
            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
            m = self.msa_stack(m, z_ori, msa_mask)

        if self.last_block:
            m[0] = m[0].squeeze(0)
            z[0] = z[0].squeeze(0)
            m[0] = gather(m[0], dim=0)
            z[0] = gather(z[0], dim=0)
            m[0] = m[0][:, :-padding_size, :]
            z[0] = z[0][:-padding_size, :-padding_size, :]

        return m, z
class ExtraMSABlock(nn.Module):
    def __init__(
@@ -183,6 +240,74 @@ class ExtraMSABlock(nn.Module):
        return m, z
    def inplace(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        dap_size = gpc.get_world_size(ParallelMode.TENSOR)

        seq_cnt = msa_mask.size(-2)
        seq_len = pair_mask.size(-1)
        seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
        seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len

        if self.first_block:
            m[0] = m[0].unsqueeze(0)
            z[0] = z[0].unsqueeze(0)
            m[0] = torch.nn.functional.pad(
                m[0], (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size)
            )
            z[0] = torch.nn.functional.pad(
                z[0], (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size)
            )
            m[0] = scatter(m[0], dim=1) if not self.is_multimer else scatter(m[0], dim=2)
            z[0] = scatter(z[0], dim=1)

        msa_mask = msa_mask.unsqueeze(0)
        pair_mask = pair_mask.unsqueeze(0)
        msa_mask = torch.nn.functional.pad(
            msa_mask, (0, seq_len_padding_size, 0, seq_cnt_padding_size)
        )
        pair_mask = torch.nn.functional.pad(
            pair_mask, (0, seq_len_padding_size, 0, seq_len_padding_size)
        )

        if not self.is_multimer:
            m = self.msa_stack.inplace(m, z, msa_mask)
            z = self.communication.inplace(m[0], msa_mask, z)
            m[0], work = All_to_All_Async.apply(m[0], 1, 2)
            z = self.pair_stack.inplace(z, pair_mask)
            m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
        else:
            z = self.communication(m, msa_mask, z)
            z_ori = z
            m, work = All_to_All_Async.apply(m, 1, 2)
            z = self.pair_stack(z, pair_mask)
            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
            m = self.msa_stack(m, z_ori, msa_mask)

        if self.last_block:
            m[0] = gather(m[0], dim=1) if not self.is_multimer else gather(m[0], dim=2)
            z[0] = gather(z[0], dim=1)
            m[0] = m[0][:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
            z[0] = z[0][:, :-seq_len_padding_size, :-seq_len_padding_size, :]
            m[0] = m[0].squeeze(0)
            z[0] = z[0].squeeze(0)

        return m, z
class TemplatePairStackBlock(nn.Module):
    def __init__(
@@ -268,4 +393,51 @@ class TemplatePairStackBlock(nn.Module):
            z = gather(z, dim=1)
            z = z[:, :-padding_size, :-padding_size, :]
        return z
    def inplace(
        self,
        z: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ):
        if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
            z[0] = z[0].cpu()

        dap_size = gpc.get_world_size(ParallelMode.TENSOR)

        seq_length = mask.size(-1)
        padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length

        if self.first_block:
            z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
            z[0] = scatter(z[0], dim=1)

        mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))

        # single_templates = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)]
        # single_templates_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)]

        for i in range(z[0].shape[0]):
            single = z[0][i].unsqueeze(-4).to(mask.device)
            single_mask = mask[i].unsqueeze(-3)

            single_mask_row = scatter(single_mask, dim=1)
            single_mask_col = scatter(single_mask, dim=2)

            single = self.TriangleMultiplicationOutgoing(single, single_mask_row)
            single = row_to_col(single)
            single = self.TriangleMultiplicationIncoming(single, single_mask_col)
            single = col_to_row(single)
            single = self.TriangleAttentionStartingNode(single, single_mask_row)
            single = row_to_col(single)
            single = self.TriangleAttentionEndingNode(single, single_mask_col)
            single = self.PairTransition(single)
            single = col_to_row(single)
            z[0][i] = single.to(z[0].device)

        # z = torch.cat(single_templates, dim=-4)

        if self.last_block:
            z[0] = gather(z[0], dim=1)
            z[0] = z[0][:, :-padding_size, :-padding_size, :]

        return z
\ No newline at end of file
@@ -169,3 +169,15 @@ class ExtraMSAStack(nn.Module):
        node = self.MSATransition(node)
        return node
    def inplace(self, node, pair, node_mask):
        node_mask_row = scatter(node_mask, dim=1)
        node = self.MSARowAttentionWithPairBias.inplace(node, pair, node_mask_row)
        node[0] = row_to_col(node[0])
        node_mask_col = scatter(node_mask, dim=2)
        node = self.MSAColumnAttention.inplace(node, node_mask_col)
        node = self.MSATransition.inplace(node)
        return node
\ No newline at end of file
@@ -29,6 +29,7 @@ from fastfold.distributed.comm_async import gather_async, gather_async_opp, get_
CHUNK_SIZE = None
DEBUG = False
def set_chunk_size(chunk_size):
@@ -94,15 +95,34 @@ class ChunkTransition(nn.Module):
        chunk_size = 48
        if CHUNK_SIZE == None:
            chunk_size = para_dim
        else:
            chunk_size = CHUNK_SIZE * 48
        out = torch.empty_like(src)
        for ax in range(0, para_dim, chunk_size):
            if DEBUG and ax > 10:
                break
            x = self.norm(src[:, ax:ax + chunk_size, :, :])
            x = self.linear2(F.relu(self.linear1(x)))
            out[:, ax:ax + chunk_size, :, :] = x
        out.add_(src)
        return out
    def inplace(self, src):
        para_dim = src[0].shape[1]
        if CHUNK_SIZE == None:
            chunk_size = para_dim
        else:
            chunk_size = CHUNK_SIZE * 48

        for ax in range(0, para_dim, chunk_size):
            if DEBUG and ax > 10:
                break
            x = self.norm(src[0][:, ax:ax + chunk_size, :, :])
            x = self.linear2(F.relu(self.linear1(x)))
            src[0][:, ax:ax + chunk_size, :, :] += x
        return src
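The pattern above recurs across the `inplace` variants in this file: the tensor arrives wrapped in a one-element list so the block can write results back into the caller's storage chunk by chunk, instead of allocating the `torch.empty_like` scratch buffer that `forward` needs. A toy sketch of the idea, with a made-up shape and a stand-in transition rather than FastFold's real layers:
```python
import torch

def transition(src, chunk_size=48):
    out = torch.empty_like(src)                      # full-size scratch buffer
    for ax in range(0, src.shape[1], chunk_size):
        out[:, ax:ax + chunk_size] = src[:, ax:ax + chunk_size] * 2.0
    out.add_(src)
    return out

def transition_inplace(src, chunk_size=48):
    # src is a one-element list, so the chunkwise update is visible to the caller
    for ax in range(0, src[0].shape[1], chunk_size):
        src[0][:, ax:ax + chunk_size] += src[0][:, ax:ax + chunk_size] * 2.0
    return src                                       # same storage, no scratch buffer

x = torch.randn(1, 256, 64)
assert torch.allclose(transition(x.clone()), transition_inplace([x.clone()])[0])
```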
class OutProductMean(nn.Module):
@@ -117,6 +137,7 @@ class OutProductMean(nn.Module):
                               n_feat_out,
                               initializer='zero',
                               use_bias=True)
        self.n_feat_proj = n_feat_proj

    def forward(self, M, M_mask, Z_raw):
        M = self.layernormM(M)
@@ -148,6 +169,59 @@ class OutProductMean(nn.Module):
        return Z_raw
    def inplace(self, M, M_mask, Z_raw):
        chunk_size = CHUNK_SIZE
        if len(M.shape) == 4:
            para_dim = M.shape[1]
            left_act = torch.empty((M.shape[0], M.shape[1], M.shape[2], self.n_feat_proj), dtype=M.dtype, device=M.device)
            right_act = torch.empty((M.shape[0], M.shape[1], M.shape[2], self.n_feat_proj), dtype=M.dtype, device=M.device)
            if CHUNK_SIZE == None:
                chunk_size = para_dim
            else:
                chunk_size = chunk_size * 32
            for ax in range(0, para_dim, chunk_size):
                m = self.layernormM(M[:, ax:ax + chunk_size, :, :])
                right_act[:, ax:ax + chunk_size, :, :] = self.linear_b(m)
                left_act[:, ax:ax + chunk_size, :, :] = self.linear_a(m)
        else:
            para_dim = M.shape[0]
            left_act = torch.empty((M.shape[0], M.shape[1], self.n_feat_proj), dtype=M.dtype, device=M.device)
            right_act = torch.empty((M.shape[0], M.shape[1], self.n_feat_proj), dtype=M.dtype, device=M.device)
            if CHUNK_SIZE == None:
                chunk_size = para_dim
            else:
                chunk_size = chunk_size * 32
            for ax in range(0, para_dim, chunk_size):
                m = self.layernormM(M[ax:ax + chunk_size, :, :])
                right_act[ax:ax + chunk_size, :, :] = self.linear_b(m)
                left_act[ax:ax + chunk_size, :, :] = self.linear_a(m)

        right_act_all, work = gather_async(right_act, dim=2)
        # right_act_all = gather(right_act, dim=2)

        M_mask = M_mask.unsqueeze(-1)
        M_mask_col = scatter(M_mask, dim=2)
        left_act = M_mask_col * left_act
        norm = torch.einsum('bsid,bsjd->bijd', M_mask_col, M_mask) + 1e-3

        right_act_all = gather_async_opp(right_act_all, work, dim=2)
        right_act_all = M_mask * right_act_all

        para_dim = left_act.shape[2]
        chunk_size = CHUNK_SIZE
        if CHUNK_SIZE == None:
            chunk_size = para_dim
        for ax in range(0, para_dim, chunk_size):
            left_act_part = left_act[:, :, ax:ax + chunk_size, :]
            O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
            O = rearrange(O, 'b i j d e -> b i j (d e)')
            O = self.o_linear(O)
            norm0 = norm[:, ax:ax + chunk_size, :, :]
            Z_raw[0][:, ax:ax + chunk_size, :, :] += O / norm0

        return Z_raw
class Linear(nn.Linear):
    """
@@ -316,6 +390,8 @@ class AsyncChunkTriangleMultiplicationOutgoing(nn.Module):
        output = torch.empty_like(Z_raw)
        for i in range(0, para_dim, chunk_size):
            if DEBUG and i > 10:
                break
            zi = Z_raw[:, i:i + chunk_size, :, :]
            zi = self.layernorm1(zi)
            gi = torch.sigmoid(self.left_right_gate(zi))
@@ -443,6 +519,8 @@ class AsyncChunkTriangleMultiplicationIncoming(nn.Module):
        output = torch.empty_like(Z_raw)
        for i in range(0, para_dim, chunk_size):
            if DEBUG and i > 10:
                break
            zi = Z_raw[:, :, i:i + chunk_size, :]
            zi = self.layernorm1(zi)
            gi = torch.sigmoid(self.left_right_gate(zi))
@@ -577,6 +655,8 @@ class ChunkTriangleAttentionStartingNode(nn.Module):
        output = torch.empty_like(Z_raw)
        dropout_mask = torch.ones_like(z[:, 0:1, :, :], device=z.device, dtype=z.dtype)
        for i in range(0, para_dim, chunk_size):
            if DEBUG and i > 10:
                break
            z_raw = Z_raw[:, i:i + chunk_size, :, :]
            z = self.layernorm1(z_raw)
            z_mask = Z_mask[:, i:i + chunk_size, :]
@@ -592,6 +672,52 @@ class ChunkTriangleAttentionStartingNode(nn.Module):
        return output
    def inplace(self, Z_raw, Z_mask):
        if CHUNK_SIZE == None:
            Z = self.layernorm1(Z_raw)
            b = self.linear_b(Z)
            b, work = gather_async(b, dim=1)
            Z = self.attention(Z, Z_mask, (b, work))
            dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
            return bias_dropout_add(Z,
                                    self.out_bias,
                                    dropout_mask,
                                    Z_raw,
                                    prob=self.p_drop,
                                    training=self.training)

        chunk_size = CHUNK_SIZE
        para_dim = Z_raw[0].shape[1]

        # z is big, but b is small. So we compute z in chunk to get b, and recompute z in chunk later instead of storing it
        b = torch.empty((Z_raw[0].shape[0], Z_raw[0].shape[1], Z_raw[0].shape[2], self.n_head), device=Z_raw[0].device, dtype=Z_raw[0].dtype)
        for i in range(0, para_dim, chunk_size):
            z = self.layernorm1(Z_raw[0][:, i:i + chunk_size, :, :])
            b[:, i:i + chunk_size, :, :] = self.linear_b(z)
        b, work = gather_async(b, dim=1)
        b = gather_async_opp(b, work, dim=1)
        b = rearrange(b, 'b q k h -> b h q k')

        # output = torch.empty_like(Z_raw)
        dropout_mask = torch.ones_like(z[:, 0:1, :, :], device=z.device, dtype=z.dtype)
        for i in range(0, para_dim, chunk_size):
            if DEBUG and i > 10:
                break
            z_raw = Z_raw[0][:, i:i + chunk_size, :, :]
            z = self.layernorm1(z_raw)
            z_mask = Z_mask[:, i:i + chunk_size, :]
            z = self.attention(z, z_mask, (b, -1))
            z = bias_dropout_add(z,
                                 self.out_bias,
                                 dropout_mask,
                                 z_raw,
                                 prob=self.p_drop,
                                 training=self.training)
            Z_raw[0][:, i:i + chunk_size, :, :] = z
        return Z_raw
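The comment in the chunked branch describes a two-pass trade-off that several of these `inplace` methods share: the projected bias `b` is small enough to keep around, so it is built chunk by chunk in a first pass, and the LayerNorm'ed pair activations are recomputed per chunk in the second pass rather than stored whole. A standalone sketch of that shape of computation; the modules and the final update below are placeholders, not FastFold's attention:
```python
import torch
import torch.nn as nn

N, C, H, chunk = 256, 32, 4, 64
norm = nn.LayerNorm(C)
proj = nn.Linear(C, H, bias=False)        # stand-in for linear_b
mix = nn.Linear(C, C, bias=False)         # stand-in for the attention update

Z = torch.randn(1, N, N, C)
b = torch.empty(1, N, N, H)

with torch.no_grad():
    for i in range(0, N, chunk):          # pass 1: keep only the small bias b
        b[:, i:i + chunk] = proj(norm(Z[:, i:i + chunk]))

    for i in range(0, N, chunk):          # pass 2: recompute norm(Z) per chunk
        z = norm(Z[:, i:i + chunk])       # transient, freed after this chunk
        Z[:, i:i + chunk] += mix(z) * b[:, i:i + chunk].mean(-1, keepdim=True)
```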
class ChunkMSARowAttentionWithPairBias(nn.Module):
@@ -648,6 +774,8 @@ class ChunkMSARowAttentionWithPairBias(nn.Module):
        output = torch.empty_like(M_raw)
        dropout_mask = torch.ones_like(M_raw[:, 0:1, :, :], device=M_raw.device, dtype=M_raw.dtype)
        for i in range(0, para_dim_m, chunk_size):
            if DEBUG and i > 10:
                break
            m_raw = M_raw[:, i:i + chunk_size, :, :]
            m = self.layernormM(m_raw)
            m_mask = M_mask[:, i:i + chunk_size, :]
@@ -662,6 +790,51 @@ class ChunkMSARowAttentionWithPairBias(nn.Module):
            output[:, i:i + chunk_size, :, :] = m
        return output
    def inplace(self, M_raw, Z, M_mask):
        if CHUNK_SIZE == None:
            ## Input projections
            M = self.layernormM(M_raw)
            Z = self.layernormZ(Z)
            b = F.linear(Z, self.linear_b_weights)
            b, work = gather_async(b, dim=1)
            # b = rearrange(b, 'b q k h -> b h q k')
            # padding_bias = (1e9 * (M_mask - 1.))[:, :, None, None, :]
            M = self.attention(M, M_mask, (b, work))
            dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
            return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop, training=self.training)

        chunk_size = CHUNK_SIZE
        para_dim_z = Z[0].shape[1]
        para_dim_m = M_raw[0].shape[1]

        # z is big, but b is small. So we compute z in chunk to get b, and recompute z in chunk later instead of storing it
        b = torch.empty((Z[0].shape[0], Z[0].shape[1], Z[0].shape[2], self.n_head), device=Z[0].device, dtype=Z[0].dtype)
        for i in range(0, para_dim_z, chunk_size):
            z = self.layernormZ(Z[0][:, i:i + chunk_size, :, :])
            b[:, i:i + chunk_size, :, :] = F.linear(z, self.linear_b_weights)
        b, work = gather_async(b, dim=1)
        b = gather_async_opp(b, work, dim=1)
        b = rearrange(b, 'b q k h -> b h q k')

        dropout_mask = torch.ones_like(M_raw[0][:, 0:1, :, :], device=M_raw[0].device, dtype=M_raw[0].dtype)
        for i in range(0, para_dim_m, chunk_size):
            if DEBUG and i > 10:
                break
            m_raw = M_raw[0][:, i:i + chunk_size, :, :]
            m = self.layernormM(m_raw)
            m_mask = M_mask[:, i:i + chunk_size, :]
            m = self.attention(m, m_mask, (b, -1))
            m = bias_dropout_add(m,
                                 self.out_bias,
                                 dropout_mask,
                                 m_raw,
                                 prob=self.p_drop,
                                 training=self.training)
            M_raw[0][:, i:i + chunk_size, :, :] = m
        return M_raw
class ChunkTriangleAttentionEndingNode(nn.Module):
@@ -716,6 +889,8 @@ class ChunkTriangleAttentionEndingNode(nn.Module):
        output = torch.empty_like(Z_raw)
        dropout_mask = torch.ones_like(Z_raw[:, :, 0:1, :], device=z.device, dtype=z.dtype)
        for i in range(0, para_dim, chunk_size):
            if DEBUG and i > 10:
                break
            z_raw = Z_raw[:, :, i:i + chunk_size, :]
            z = self.layernorm1(z_raw.transpose(-2, -3))
            z_mask = Z_mask[:, :, i:i + chunk_size].transpose(-1, -2)
@@ -731,6 +906,57 @@ class ChunkTriangleAttentionEndingNode(nn.Module):
        return output
    def inplace(self, Z_raw, Z_mask):
        if CHUNK_SIZE == None:
            Z = Z_raw.transpose(-2, -3)
            Z_mask = Z_mask.transpose(-1, -2)

            Z = self.layernorm1(Z)
            b = self.linear_b(Z)
            b, work = gather_async(b, dim=1)
            Z = self.attention(Z, Z_mask, (b, work))

            Z = Z.transpose(-2, -3)
            dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
            return bias_dropout_add(Z,
                                    self.out_bias,
                                    dropout_mask,
                                    Z_raw,
                                    prob=self.p_drop,
                                    training=self.training)

        para_dim = Z_raw[0].shape[2]
        chunk_size = CHUNK_SIZE

        # z is big, but b is small. So we compute z in chunk to get b, and recompute z in chunk later instead of storing it
        b = torch.empty((Z_raw[0].shape[0], Z_raw[0].shape[2], Z_raw[0].shape[1], self.n_head), device=Z_raw[0].device, dtype=Z_raw[0].dtype)
        for i in range(0, para_dim, chunk_size):
            z = Z_raw[0][:, :, i:i + chunk_size, :].transpose(-2, -3)
            z = self.layernorm1(z)
            b[:, i:i + chunk_size, :, :] = self.linear_b(z)
        b, work = gather_async(b, dim=1)
        b = gather_async_opp(b, work, dim=1)
        b = rearrange(b, 'b q k h -> b h q k')

        dropout_mask = torch.ones_like(Z_raw[0][:, :, 0:1, :], device=z.device, dtype=z.dtype)
        for i in range(0, para_dim, chunk_size):
            if DEBUG and i > 10:
                break
            z_raw = Z_raw[0][:, :, i:i + chunk_size, :]
            z = self.layernorm1(z_raw.transpose(-2, -3))
            z_mask = Z_mask[:, :, i:i + chunk_size].transpose(-1, -2)
            z = self.attention(z, z_mask, (b, -1)).transpose(-2, -3)
            z = bias_dropout_add(z,
                                 self.out_bias,
                                 dropout_mask,
                                 z_raw,
                                 prob=self.p_drop,
                                 training=self.training)
            Z_raw[0][:, :, i:i + chunk_size, :] = z
        return Z_raw
class ChunkMSAColumnGlobalAttention(nn.Module):
    def __init__(self, d_node, c=8, n_head=8):
@@ -754,6 +980,8 @@ class ChunkMSAColumnGlobalAttention(nn.Module):
        chunk_size = CHUNK_SIZE
        for i in range(0, para_dim, chunk_size):
            if DEBUG and i > 10:
                break
            m = M_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
            m = self.layernormM(m)
            m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
@@ -763,6 +991,26 @@ class ChunkMSAColumnGlobalAttention(nn.Module):
        return M_raw
    def inplace(self, M_raw, M_mask):
        para_dim = M_raw[0].shape[2]
        if CHUNK_SIZE is None:
            chunk_size = para_dim
        else:
            chunk_size = CHUNK_SIZE

        for i in range(0, para_dim, chunk_size):
            if DEBUG and i > 10:
                break
            m = M_raw[0][:, :, i:i + chunk_size, :].transpose(-2, -3)
            m = self.layernormM(m)
            m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
            m = self.global_attention(m, m_mask)
            m = m.transpose(-2, -3)
            M_raw[0][:, :, i:i + chunk_size, :] += m
        return M_raw
class RecyclingEmbedder(nn.Module):
    """
......
@@ -247,3 +247,21 @@ class PairStack(nn.Module):
        pair = self.PairTransition(pair)
        pair = col_to_row(pair)
        return pair
    def inplace(self, pair, pair_mask):
        pair_mask_row = scatter(pair_mask, dim=1)
        pair_mask_col = scatter(pair_mask, dim=2)
        pair[0] = self.TriangleMultiplicationOutgoing(pair[0], pair_mask_row)
        pair[0] = row_to_col(pair[0])
        pair[0] = self.TriangleMultiplicationIncoming(pair[0], pair_mask_col)
        pair[0] = col_to_row(pair[0])
        pair = self.TriangleAttentionStartingNode.inplace(pair, pair_mask_row)
        pair[0] = row_to_col(pair[0])
        pair = self.TriangleAttentionEndingNode.inplace(pair, pair_mask_col)
        pair = self.PairTransition.inplace(pair)
        pair[0] = col_to_row(pair[0])
        return pair
\ No newline at end of file
@@ -281,7 +281,8 @@ class AlphaFold(nn.Module):
                z,
                pair_mask.to(dtype=z.dtype),
                no_batch_dims,
                self.globals.chunk_size,
                inplace=self.globals.inplace
            )

            if(
@@ -320,28 +321,54 @@ class AlphaFold(nn.Module):
            extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)

            # [*, N, N, C_z]
            if not self.globals.inplace or self.globals.is_multimer:
                z = self.extra_msa_stack(
                    extra_msa_feat,
                    z,
                    msa_mask=feats["extra_msa_mask"].to(dtype=extra_msa_feat.dtype),
                    chunk_size=self.globals.chunk_size,
                    pair_mask=pair_mask.to(dtype=z.dtype),
                    _mask_trans=self.config._mask_trans,
                )
            else:
                extra_msa_feat = [extra_msa_feat]
                z = [z]
                z = self.extra_msa_stack.inplace(
                    extra_msa_feat,
                    z,
                    msa_mask=feats["extra_msa_mask"].to(dtype=extra_msa_feat[0].dtype),
                    chunk_size=self.globals.chunk_size,
                    pair_mask=pair_mask.to(dtype=z[0].dtype),
                    _mask_trans=self.config._mask_trans,
                )[0]

            del extra_msa_feat, extra_msa_fn
        # Run MSA + pair embeddings through the trunk of the network
        # m: [*, S, N, C_m]
        # z: [*, N, N, C_z]
        # s: [*, N, C_s]
        if not self.globals.inplace or self.globals.is_multimer:
            m, z, s = self.evoformer(
                m,
                z,
                msa_mask=msa_mask.to(dtype=m.dtype),
                pair_mask=pair_mask.to(dtype=z.dtype),
                chunk_size=self.globals.chunk_size,
                _mask_trans=self.config._mask_trans,
            )
        else:
            m = [m]
            z = [z]
            m, z, s = self.evoformer.inplace(
                m,
                z,
                msa_mask=msa_mask.to(dtype=m[0].dtype),
                pair_mask=pair_mask.to(dtype=z[0].dtype),
                chunk_size=self.globals.chunk_size,
                _mask_trans=self.config._mask_trans,
            )
            m = m[0]
            z = z[0]

        outputs["msa"] = m[..., :n_seq, :, :]
        outputs["pair"] = z
......
@@ -162,7 +162,8 @@ class TemplateEmbedder(nn.Module):
        pair_mask,
        templ_dim,
        chunk_size,
        _mask_trans=True,
        inplace=False
    ):
        # Embed the templates one at a time (with a poor man's vmap)
        template_embeds = []
@@ -205,14 +206,25 @@ class TemplateEmbedder(nn.Module):
            # single_template_embeds.update({"pair": t})
            template_embeds.append(single_template_embeds)

            # [*, S_t, N, N, C_z]
            if inplace:
                tt = [tt]
                t[i] = self.template_pair_stack.inplace(
                    tt,
                    pair_mask.unsqueeze(-3).to(dtype=z.dtype),
                    chunk_size=chunk_size,
                    _mask_trans=_mask_trans,
                )[0].to(t.device)
            else:
                t[i] = self.template_pair_stack(
                    tt,
                    pair_mask.unsqueeze(-3).to(dtype=z.dtype),
                    chunk_size=chunk_size,
                    _mask_trans=_mask_trans,
                ).to(t.device)

            del tt, single_template_feats
        template_embeds = dict_multimap(
            partial(torch.cat, dim=templ_dim),
@@ -220,21 +232,17 @@ class TemplateEmbedder(nn.Module):
        )

        # [*, N, N, C_z]
        z = self.template_pointwise_att(
            t,
            z,
            template_mask=batch["template_mask"].to(dtype=z.dtype),
            chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
        )

        ret = {}
        if self.config.embed_angles:
            ret["template_single_embedding"] = template_embeds["angle"]

        return ret, z
......
@@ -533,6 +533,60 @@ class EvoformerStack(nn.Module):
        return m, z, s
    def inplace(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: int,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            msa_mask:
                [*, N_seq, N_res] MSA mask
            pair_mask:
                [*, N_res, N_res] pair mask
        Returns:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            s:
                [*, N_res, C_s] single embedding (or None if extra MSA stack)
        """
        blocks = [
            partial(
                b.inplace,
                msa_mask=msa_mask,
                pair_mask=pair_mask,
                chunk_size=chunk_size,
                _mask_trans=_mask_trans,
            )
            for b in self.blocks
        ]

        if(self.clear_cache_between_blocks):
            def block_with_cache_clear(block, *args):
                torch.cuda.empty_cache()
                return block(*args)

            blocks = [partial(block_with_cache_clear, b) for b in blocks]

        m, z = checkpoint_blocks(
            blocks,
            args=(m, z),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )

        s = self.linear(m[0][..., 0, :, :])

        return m, z, s
class ExtraMSAStack(nn.Module):
    """
@@ -626,3 +680,49 @@ class ExtraMSAStack(nn.Module):
                torch.cuda.empty_cache()
        return z
    def inplace(self,
        m: torch.Tensor,
        z: torch.Tensor,
        chunk_size: int,
        msa_mask: Optional[torch.Tensor] = None,
        pair_mask: Optional[torch.Tensor] = None,
        _mask_trans: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_extra, N_res, C_m] extra MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            msa_mask:
                Optional [*, N_extra, N_res] MSA mask
            pair_mask:
                Optional [*, N_res, N_res] pair mask
        Returns:
            [*, N_res, N_res, C_z] pair update
        """
        #checkpoint_fn = get_checkpoint_fn()
        #blocks = [
        #    partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
        #]

        #def dodo(b, *args):
        #    torch.cuda.empty_cache()
        #    return b(*args)

        #blocks = [partial(dodo, b) for b in blocks]

        #for b in blocks:
        #    if(torch.is_grad_enabled()):
        #        m, z = checkpoint_fn(b, *(m, z))
        #    else:
        #        m, z = b(m, z)

        for b in self.blocks:
            m, z = b.inplace(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
            if(self.clear_cache_between_blocks):
                torch.cuda.empty_cache()

        return z
\ No newline at end of file
@@ -122,11 +122,26 @@ class TemplatePointwiseAttention(nn.Module):
        # [*, N_res, N_res, 1, C_z]
        biases = [bias]
        if chunk_size is not None:
            para_dim_t0 = t.shape[0]
            para_dim_t1 = t.shape[1]
            chunk_size_t = chunk_size * 4
            mask = torch.sum(template_mask.to(z.device)) > 0
            for ti in range(0, para_dim_t0, chunk_size_t):
                t0 = t[ti:ti + chunk_size_t, :, :, :]
                t0 = t0.to(z.device)
                para_dim_t_part = t0.shape[0]
                for i in range(0, para_dim_t_part, chunk_size):
                    for j in range(0, para_dim_t1, chunk_size):
                        z[i:i + chunk_size, j:j + chunk_size, :, :] += self.mha(
                            q_x=z[i + ti:i + ti + chunk_size, j:j + chunk_size, :, :],
                            kv_x=t0[i:i + chunk_size, j:j + chunk_size, :, :],
                            biases=biases
                        ) * mask
        else:
            t = self.mha(q_x=z, kv_x=t, biases=biases)

            # [*, N_res, N_res, C_z]
            t = t * (torch.sum(template_mask) > 0)
            z = z + t

        # [*, N_res, N_res, C_z]
        z = z.squeeze(-2)
        return z
@@ -358,3 +373,43 @@ class TemplatePairStack(nn.Module):
        for i in range(0, t.shape[0], chunk_size):
            t[i:i + chunk_size] = self.layer_norm(t[i:i + chunk_size])
        return t
    def inplace(
        self,
        t: torch.tensor,
        mask: torch.tensor,
        chunk_size: int,
        _mask_trans: bool = True,
    ):
        """
        Args:
            t:
                [*, N_templ, N_res, N_res, C_t] template embedding
            mask:
                [*, N_templ, N_res, N_res] mask
        Returns:
            [*, N_templ, N_res, N_res, C_t] template embedding update
        """
        if(mask.shape[-3] == 1):
            expand_idx = list(mask.shape)
            expand_idx[-3] = t[0].shape[-4]
            mask = mask.expand(*expand_idx)

        t, = checkpoint_blocks(
            blocks=[
                partial(
                    b.inplace,
                    mask=mask,
                    chunk_size=chunk_size,
                    _mask_trans=_mask_trans,
                )
                for b in self.blocks
            ],
            args=(t,),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )

        if chunk_size is None:
            chunk_size = t[0].shape[0]
        for i in range(0, t[0].shape[0], chunk_size):
            t[0][i:i + chunk_size] = self.layer_norm(t[0][i:i + chunk_size].to(mask.device)).to(t[0].device)

        return t
@@ -101,6 +101,7 @@ def add_data_args(parser: argparse.ArgumentParser):
    parser.add_argument('--release_dates_path', type=str, default=None)
    parser.add_argument('--chunk_size', type=int, default=None)
    parser.add_argument('--enable_workflow', default=False, action='store_true', help='run inference with ray workflow or not')
    parser.add_argument('--inplace', default=False, action='store_true')
def inference_model(rank, world_size, result_q, batch, args):
@@ -113,6 +114,7 @@ def inference_model(rank, world_size, result_q, batch, args):
    config = model_config(args.model_name)
    if args.chunk_size:
        config.globals.chunk_size = args.chunk_size
    config.globals.inplace = args.inplace
    model = AlphaFold(config)
    import_jax_weights_(model, args.param_path, version=args.model_name)
......
# add '--gpus [N]' to use N gpus for inference
# add '--enable_workflow' to use parallel workflow for data processing
# add '--use_precomputed_alignments [path_to_alignments]' to use precomputed msa
# add '--chunk_size [N]' to use chunk to reduce peak memory
# add '--inplace' to use inplace to save memory
python inference.py target.fasta data/pdb_mmcif/mmcif_files \
    --output_dir ./ \
......