Commit 1df4991d authored by Gustaf Ahdritz

DeepSpeed + PL bfloat16 working

parent 02fc4376
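
Note on the commit as a whole: the bfloat16 here comes from running the model under PyTorch Lightning's DeepSpeed integration. As a hedged illustration only (flag spellings vary across Lightning releases, and none of this appears in the diff below), enabling it might look like:

    # Hypothetical launch snippet, not part of this commit.
    import pytorch_lightning as pl

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        strategy="deepspeed_stage_2",  # ZeRO stage choice is illustrative
        precision="bf16",              # parameters/activations in bfloat16
    )
    # trainer.fit(model, datamodule)   # model and datamodule assumed to exist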
openfold/config.py
@@ -148,7 +148,7 @@ config = mlc.ConfigDict(
             "same_prob": 0.1,
             "uniform_prob": 0.1,
         },
-        "max_extra_msa": 2048,
+        "max_extra_msa": 5120,
         "max_recycling_iters": 3,
         "msa_cluster_features": True,
         "reduce_msa_clusters_by_max_templates": False,
@@ -226,7 +226,7 @@ config = mlc.ConfigDict(
             "use_small_bfd": False,
             "data_loaders": {
                 "batch_size": 1,
-                "num_workers": 1,
+                "num_workers": 2,
             },
         },
     },
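
The two value changes above enlarge the extra-MSA input (2048 -> 5120 sequences) and add a second dataloader worker. Since these live in an ml_collections ConfigDict, they can be overridden at runtime; a sketch with an assumed (not exact) nesting:

    # Illustrative override of the fields changed above; the nesting of
    # openfold's real config tree is assumed here.
    import ml_collections as mlc

    config = mlc.ConfigDict({
        "common": {"max_extra_msa": 5120, "max_recycling_iters": 3},
        "data_loaders": {"batch_size": 1, "num_workers": 2},
    })
    config.common.max_extra_msa = 1024  # e.g. shrink to fit a smaller GPU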
openfold/model/embedders.py
@@ -17,7 +17,7 @@
 import torch
 import torch.nn as nn
 from typing import Tuple

-from openfold.model.primitives import Linear
+from openfold.model.primitives import Linear, LayerNorm
 from openfold.utils.tensor_utils import one_hot
@@ -168,8 +168,8 @@ class RecyclingEmbedder(nn.Module):
         self.bins = None
         self.linear = Linear(self.no_bins, self.c_z)
-        self.layer_norm_m = nn.LayerNorm(self.c_m)
-        self.layer_norm_z = nn.LayerNorm(self.c_z)
+        self.layer_norm_m = LayerNorm(self.c_m)
+        self.layer_norm_z = LayerNorm(self.c_z)

     def forward(
         self,
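
The recurring nn.LayerNorm -> LayerNorm swap (here and in evoformer.py and heads.py below) points the modules at a custom primitive, presumably so normalization stays numerically stable when parameters are bfloat16. A minimal sketch of what such a primitive might look like; the actual openfold.model.primitives.LayerNorm may differ:

    # Sketch of a bfloat16-friendly LayerNorm: normalize in float32 for
    # numerical stability, then cast back to the input dtype.
    import torch
    import torch.nn as nn

    class LayerNorm(nn.Module):
        def __init__(self, c_in: int, eps: float = 1e-5):
            super().__init__()
            self.c_in = (c_in,)
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(c_in))
            self.bias = nn.Parameter(torch.zeros(c_in))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            d = x.dtype
            out = nn.functional.layer_norm(
                x.to(dtype=torch.float32),
                self.c_in,
                self.weight.to(dtype=torch.float32),
                self.bias.to(dtype=torch.float32),
                self.eps,
            )
            return out.to(dtype=d)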
openfold/model/evoformer.py
@@ -19,7 +19,7 @@
 from typing import Tuple, Optional
 from functools import partial

-from openfold.model.primitives import Linear
+from openfold.model.primitives import Linear, LayerNorm
 from openfold.model.dropout import DropoutRowwise, DropoutColumnwise
 from openfold.model.msa import (
     MSARowAttentionWithPairBias,
@@ -61,7 +61,7 @@ class MSATransition(nn.Module):
         self.c_m = c_m
         self.n = n

-        self.layer_norm = nn.LayerNorm(self.c_m)
+        self.layer_norm = LayerNorm(self.c_m)
         self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
         self.relu = nn.ReLU()
         self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
@@ -355,8 +355,8 @@ class ExtraMSABlock(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         m = m + self.msa_dropout_layer(
             self.msa_att_row(
-                m,
-                z=z,
+                m.clone(),
+                z=z.clone(),
                 mask=msa_mask,
                 chunk_size=chunk_size,
                 _chunk_logits=_chunk_logits,
@@ -372,7 +372,7 @@ class ExtraMSABlock(nn.Module):
             return m, z

-        if(self.ckpt):
+        if(torch.is_grad_enabled() and self.ckpt):
             checkpoint_fn = get_checkpoint_fn()
             m, z = checkpoint_fn(fn, m, z)
         else:
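
Two details in the ExtraMSABlock change above: activation checkpointing only makes sense while autograd is recording, hence the torch.is_grad_enabled() guard, and a checkpointed function is re-run during backward, so inputs that would otherwise be mutated are cloned first. A minimal standalone sketch of the gating pattern (maybe_checkpoint is a hypothetical helper, not an openfold API):

    import torch
    from torch.utils.checkpoint import checkpoint

    def maybe_checkpoint(fn, *args, use_ckpt: bool = True):
        # Checkpoint only when gradients are being tracked; under
        # torch.no_grad() (e.g. validation) fall back to a plain call.
        if torch.is_grad_enabled() and use_ckpt:
            return checkpoint(fn, *args)
        return fn(*args)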
@@ -596,27 +596,27 @@ class ExtraMSAStack(nn.Module):
         Returns:
             [*, N_res, N_res, C_z] pair update
         """
-        checkpoint_fn = get_checkpoint_fn()
-        blocks = [
-            partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
-        ]
-        def dodo(b, *args):
-            torch.cuda.empty_cache()
-            return b(*args)
-        blocks = [partial(dodo, b) for b in blocks]
-        for b in blocks:
-            if(torch.is_grad_enabled()):
-                m, z = checkpoint_fn(b, m, z)
-            else:
-                m, z = b(m, z)
-        #for b in self.blocks:
-        #    m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
-        #    if(self.clear_cache_between_blocks):
-        #        torch.cuda.empty_cache()
+        #checkpoint_fn = get_checkpoint_fn()
+        #blocks = [
+        #    partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
+        #]
+        #def dodo(b, *args):
+        #    torch.cuda.empty_cache()
+        #    return b(*args)
+        #blocks = [partial(dodo, b) for b in blocks]
+        #for b in blocks:
+        #    if(torch.is_grad_enabled()):
+        #        m, z = checkpoint_fn(b, *(m, z))
+        #    else:
+        #        m, z = b(m, z)
+        for b in self.blocks:
+            m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+            if(self.clear_cache_between_blocks):
+                torch.cuda.empty_cache()

         return z
openfold/model/heads.py
@@ -16,7 +16,7 @@
 import torch
 import torch.nn as nn

-from openfold.model.primitives import Linear
+from openfold.model.primitives import Linear, LayerNorm
 from openfold.utils.loss import (
     compute_plddt,
     compute_tm,
@@ -96,7 +96,7 @@ class PerResidueLDDTCaPredictor(nn.Module):
         self.c_in = c_in
         self.c_hidden = c_hidden

-        self.layer_norm = nn.LayerNorm(self.c_in)
+        self.layer_norm = LayerNorm(self.c_in)
         self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu")
         self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu")
openfold/model/model.py
@@ -134,7 +134,7 @@ class AlphaFold(nn.Module):
                 inf=self.config.template.inf,
                 eps=self.config.template.eps,
                 **self.config.template.distogram,
-            )
+            ).to(z.dtype)
             t = self.template_pair_embedder(t)
             single_template_embeds.update({"pair": t})
@@ -175,6 +175,12 @@ class AlphaFold(nn.Module):
         # Primary output dictionary
         outputs = {}

+        # This needs to be done manually for DeepSpeed's sake
+        dtype = next(self.parameters()).dtype
+        for k in feats:
+            if(feats[k].dtype == torch.float32):
+                feats[k] = feats[k].to(dtype=dtype)
+
         # Grab some data about the input
         batch_dims = feats["target_feat"].shape[:-2]
         no_batch_dims = len(batch_dims)
@@ -217,7 +223,9 @@ class AlphaFold(nn.Module):
             requires_grad=False,
         )
-        x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None)
+        x_prev = pseudo_beta_fn(
+            feats["aatype"], x_prev, None
+        ).to(dtype=z.dtype)

         # m_1_prev_emb: [*, N, C_m]
         # z_prev_emb: [*, N, N, C_z]
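
The manual cast in forward() exists because DeepSpeed converts the module's parameters to bfloat16 but leaves the input feature dict in float32; aligning dtypes up front avoids mixed-precision mismatches in the embedders. The same pattern as a standalone helper (name is illustrative):

    import torch

    def cast_float_feats(model: torch.nn.Module, feats: dict) -> dict:
        # Align float32 features with whatever dtype DeepSpeed gave the
        # parameters (e.g. torch.bfloat16); other dtypes pass through.
        dtype = next(model.parameters()).dtype
        return {
            k: v.to(dtype=dtype) if v.dtype == torch.float32 else v
            for k, v in feats.items()
        }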