Commit 1df4991d authored by Gustaf Ahdritz

DeepSpeed + PL bfloat16 working

parent 02fc4376
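For reference, a minimal sketch of the kind of trainer setup the commit message refers to, assuming a PyTorch Lightning version that accepts strategy= and precision="bf16"; the module and datamodule names are placeholders, not OpenFold's actual training entry point:

    # Hypothetical sketch, not part of this commit: wiring DeepSpeed + bfloat16
    # through PyTorch Lightning.
    import pytorch_lightning as pl

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        strategy="deepspeed_stage_2",  # DeepSpeed ZeRO stage 2 (requires the deepspeed package)
        precision="bf16",              # run parameters and activations in bfloat16
    )
    # trainer.fit(model_module, data_module)  # placeholder LightningModule / DataModule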
@@ -148,7 +148,7 @@ config = mlc.ConfigDict(
                     "same_prob": 0.1,
                     "uniform_prob": 0.1,
                 },
-                "max_extra_msa": 2048,
+                "max_extra_msa": 5120,
                 "max_recycling_iters": 3,
                 "msa_cluster_features": True,
                 "reduce_msa_clusters_by_max_templates": False,
@@ -226,7 +226,7 @@ config = mlc.ConfigDict(
                 "use_small_bfd": False,
                 "data_loaders": {
                     "batch_size": 1,
-                    "num_workers": 1,
+                    "num_workers": 2,
                 },
             },
         },
@@ -17,7 +17,7 @@ import torch
 import torch.nn as nn
 from typing import Tuple
-from openfold.model.primitives import Linear
+from openfold.model.primitives import Linear, LayerNorm
 from openfold.utils.tensor_utils import one_hot
@@ -168,8 +168,8 @@ class RecyclingEmbedder(nn.Module):
         self.bins = None
         self.linear = Linear(self.no_bins, self.c_z)
-        self.layer_norm_m = nn.LayerNorm(self.c_m)
-        self.layer_norm_z = nn.LayerNorm(self.c_z)
+        self.layer_norm_m = LayerNorm(self.c_m)
+        self.layer_norm_z = LayerNorm(self.c_z)
     def forward(
         self,
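Every nn.LayerNorm in the model is being replaced with LayerNorm from openfold.model.primitives. Below is a minimal sketch of the kind of dtype-aware wrapper this implies; it is an illustrative assumption, not the actual primitives.LayerNorm, whose implementation may differ:

    # Illustrative only: a LayerNorm whose affine parameters follow the input
    # dtype, so bfloat16 activations produced under DeepSpeed do not collide
    # with float32 weights.
    import torch
    import torch.nn as nn

    class LayerNorm(nn.Module):
        def __init__(self, c_in, eps=1e-5):
            super().__init__()
            self.c_in = (c_in,)
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(c_in))
            self.bias = nn.Parameter(torch.zeros(c_in))

        def forward(self, x):
            d = x.dtype
            # Cast the affine parameters to the activation dtype before normalizing.
            return nn.functional.layer_norm(
                x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps
            )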
@@ -19,7 +19,7 @@ import torch.nn as nn
 from typing import Tuple, Optional
 from functools import partial
-from openfold.model.primitives import Linear
+from openfold.model.primitives import Linear, LayerNorm
 from openfold.model.dropout import DropoutRowwise, DropoutColumnwise
 from openfold.model.msa import (
     MSARowAttentionWithPairBias,
@@ -61,7 +61,7 @@ class MSATransition(nn.Module):
         self.c_m = c_m
         self.n = n
-        self.layer_norm = nn.LayerNorm(self.c_m)
+        self.layer_norm = LayerNorm(self.c_m)
         self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
         self.relu = nn.ReLU()
         self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
@@ -355,8 +355,8 @@ class ExtraMSABlock(nn.Module):
         ) -> Tuple[torch.Tensor, torch.Tensor]:
             m = m + self.msa_dropout_layer(
                 self.msa_att_row(
-                    m,
-                    z=z,
+                    m.clone(),
+                    z=z.clone(),
                     mask=msa_mask,
                     chunk_size=chunk_size,
                     _chunk_logits=_chunk_logits,
@@ -372,7 +372,7 @@ class ExtraMSABlock(nn.Module):
             return m, z
-        if(self.ckpt):
+        if(torch.is_grad_enabled() and self.ckpt):
             checkpoint_fn = get_checkpoint_fn()
             m, z = checkpoint_fn(fn, m, z)
         else:
@@ -596,27 +596,27 @@ class ExtraMSAStack(nn.Module):
         Returns:
             [*, N_res, N_res, C_z] pair update
         """
-        checkpoint_fn = get_checkpoint_fn()
-        blocks = [
-            partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
-        ]
+        #checkpoint_fn = get_checkpoint_fn()
+        #blocks = [
+        #    partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
+        #]
-        def dodo(b, *args):
-            torch.cuda.empty_cache()
-            return b(*args)
+        #def dodo(b, *args):
+        #    torch.cuda.empty_cache()
+        #    return b(*args)
-        blocks = [partial(dodo, b) for b in blocks]
+        #blocks = [partial(dodo, b) for b in blocks]
-        for b in blocks:
-            if(torch.is_grad_enabled()):
-                m, z = checkpoint_fn(b, m, z)
-            else:
-                m, z = b(m, z)
+        #for b in blocks:
+        #    if(torch.is_grad_enabled()):
+        #        m, z = checkpoint_fn(b, *(m, z))
+        #    else:
+        #        m, z = b(m, z)
-        #for b in self.blocks:
-        #    m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+        for b in self.blocks:
+            m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
-        #    if(self.clear_cache_between_blocks):
-        #        torch.cuda.empty_cache()
+            if(self.clear_cache_between_blocks):
+                torch.cuda.empty_cache()
         return z
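The active path now calls each ExtraMSABlock directly, with optional cache clearing between blocks, while the torch.is_grad_enabled() and self.ckpt gate above keeps per-block checkpointing from being attempted during no-grad evaluation. For reference, here is a sketch of what the commented-out stack-level path was doing, using torch.utils.checkpoint in place of OpenFold's get_checkpoint_fn (which may dispatch to DeepSpeed's checkpointing); run_blocks is a hypothetical helper, not code from this commit:

    import torch
    from functools import partial
    from torch.utils.checkpoint import checkpoint

    def run_blocks(blocks, m, z, msa_mask, pair_mask, chunk_size=None):
        # Bind the masks and chunk size to each ExtraMSABlock-like callable,
        # then checkpoint it only when gradients are enabled.
        wrapped = [
            partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size)
            for b in blocks
        ]
        for b in wrapped:
            if torch.is_grad_enabled():
                m, z = checkpoint(b, m, z)  # recompute activations in backward
            else:
                m, z = b(m, z)
        return m, z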
@@ -16,7 +16,7 @@
 import torch
 import torch.nn as nn
-from openfold.model.primitives import Linear
+from openfold.model.primitives import Linear, LayerNorm
 from openfold.utils.loss import (
     compute_plddt,
     compute_tm,
@@ -96,7 +96,7 @@ class PerResidueLDDTCaPredictor(nn.Module):
         self.c_in = c_in
         self.c_hidden = c_hidden
-        self.layer_norm = nn.LayerNorm(self.c_in)
+        self.layer_norm = LayerNorm(self.c_in)
         self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu")
         self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu")
@@ -134,7 +134,7 @@ class AlphaFold(nn.Module):
                 inf=self.config.template.inf,
                 eps=self.config.template.eps,
                 **self.config.template.distogram,
-            )
+            ).to(z.dtype)
             t = self.template_pair_embedder(t)
             single_template_embeds.update({"pair": t})
@@ -175,6 +175,12 @@ class AlphaFold(nn.Module):
         # Primary output dictionary
         outputs = {}
+        # This needs to be done manually for DeepSpeed's sake
+        dtype = next(self.parameters()).dtype
+        for k in feats:
+            if(feats[k].dtype == torch.float32):
+                feats[k] = feats[k].to(dtype=dtype)
         # Grab some data about the input
         batch_dims = feats["target_feat"].shape[:-2]
         no_batch_dims = len(batch_dims)
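The block added above casts every float32 input feature to the model's parameter dtype, since DeepSpeed converts the parameters to bfloat16 but the data pipeline still yields float32 tensors. The same pattern in isolation, with a hypothetical helper name:

    import torch

    def cast_features_to_model_dtype(feats, module):
        # Hypothetical helper mirroring the added loop: match float32 feature
        # tensors to the module's (possibly bfloat16) parameter dtype, leaving
        # integer and boolean features untouched.
        dtype = next(module.parameters()).dtype
        return {
            k: v.to(dtype=dtype) if v.dtype == torch.float32 else v
            for k, v in feats.items()
        }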
@@ -217,7 +223,9 @@ class AlphaFold(nn.Module):
                 requires_grad=False,
             )
-        x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None)
+        x_prev = pseudo_beta_fn(
+            feats["aatype"], x_prev, None
+        ).to(dtype=z.dtype)
         # m_1_prev_emb: [*, N, C_m]
         # z_prev_emb: [*, N, N, C_z]