Commit 1df4991d authored by Gustaf Ahdritz

DeepSpeed + PL bfloat16 working

parent 02fc4376
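
For context, the changes below are what let OpenFold train in bfloat16 under PyTorch Lightning, both with and without the DeepSpeed strategy. A minimal sketch of how such a run might be configured is shown here; the Trainer arguments and DeepSpeed config keys are assumptions based on current Lightning/DeepSpeed releases, not part of this commit.

# Hypothetical launch configuration; DeepSpeedStrategy and the "bf16" config
# key come from current Lightning/DeepSpeed docs, not from this diff.
import pytorch_lightning as pl
from pytorch_lightning.strategies import DeepSpeedStrategy

ds_config = {
    "bf16": {"enabled": True},           # DeepSpeed-native bfloat16
    "zero_optimization": {"stage": 2},
}

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision="bf16",                    # Lightning AMP path when DeepSpeed is not used
    strategy=DeepSpeedStrategy(config=ds_config),
)
# trainer.fit(model, datamodule) would follow as usual.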
@@ -119,7 +119,7 @@ class MSAAttention(nn.Module):
         #bias = bias.expand(
         #    ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
         #)
         if (self.pair_bias and
             z is not None and                  # For the
             self.layer_norm_z is not None and  # benefit of
@@ -149,14 +149,14 @@ class MSAAttention(nn.Module):
         def _get_qkv(m, z):
             m, mask_bias, z = self._prep_inputs(m, z, mask)
             q, k, v = self.mha._prep_qkv(m, m)
-            return q, k, v, mask_bias, z
+            return m, q, k, v, mask_bias, z
         checkpoint_fn = get_checkpoint_fn()
-        if(checkpoint):
-            q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
+        if(torch.is_grad_enabled() and checkpoint):
+            m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
         else:
-            q, k, v, mask_bias, z = _get_qkv(m, z)
+            m, q, k, v, mask_bias, z = _get_qkv(m, z)
         o = _attention_chunked_trainable(
             query=q,
@@ -168,7 +168,7 @@ class MSAAttention(nn.Module):
             checkpoint=checkpoint,
         )
-        if(checkpoint):
+        if(torch.is_grad_enabled() and checkpoint):
             # Storing an additional m here is far from ideal
             m = checkpoint_fn(self.mha._wrap_up, o, m)
         else:
...
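
The MSAAttention hunks above gate activation checkpointing on torch.is_grad_enabled(), so checkpointing is skipped during validation and inference, where re-running the block would only waste compute. A standalone sketch of that pattern, with made-up names:

# Illustrative only; run_block and its args are hypothetical stand-ins.
import torch
from torch.utils.checkpoint import checkpoint

def maybe_checkpoint(run_block, *args):
    # Only checkpoint when gradients are actually being tracked; under
    # torch.no_grad() (validation/inference) just call the block directly.
    if torch.is_grad_enabled():
        return checkpoint(run_block, *args)
    return run_block(*args)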
@@ -17,7 +17,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
-from openfold.model.primitives import Linear
+from openfold.model.primitives import Linear, LayerNorm
 from openfold.utils.tensor_utils import chunk_layer
@@ -40,7 +40,7 @@ class PairTransition(nn.Module):
         self.c_z = c_z
         self.n = n
-        self.layer_norm = nn.LayerNorm(self.c_z)
+        self.layer_norm = LayerNorm(self.c_z)
         self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
         self.relu = nn.ReLU()
         self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
...
@@ -177,9 +177,9 @@ class LayerNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(c_in))
         self.bias = nn.Parameter(torch.zeros(c_in))
     def forward(self, x):
         d = x.dtype
-        if(d == torch.bfloat16 and not deepspeed.utils.is_initialized()):
+        if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
             with torch.cuda.amp.autocast(enabled=False):
                 out = nn.functional.layer_norm(
                     x,
@@ -189,27 +189,34 @@ class LayerNorm(nn.Module):
                     self.eps
                 )
         elif(d == torch.bfloat16):
-            raise NotImplementedError
+            out = nn.functional.layer_norm(
+                x,
+                self.c_in,
+                self.weight,
+                self.bias,
+                self.eps,
+            )
         return out
-def softmax(t, dim=-1):
+@torch.jit.ignore
+def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
     """
     Softmax, but without automatic casting to fp32 when the input is of
     type bfloat16
     """
     d = t.dtype
-    if(d == torch.bfloat16 and not deepspeed.utils.is_initialized()):
+    if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
         with torch.cuda.amp.autocast(enabled=False):
             s = torch.nn.functional.softmax(t, dim=dim)
     elif(d == torch.bfloat16):
-        raise NotImplementedError
+        s = torch.nn.functional.softmax(t, dim=dim)
     return s
-def _attention(query, key, value, biases):
+#@torch.jit.script
+def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
     # [*, H, Q, C_hidden]
     query = permute_final_dims(query, (1, 0, 2))
@@ -218,14 +225,14 @@ def _attention(query, key, value, biases):
     # [*, H, V, C_hidden]
     value = permute_final_dims(value, (1, 0, 2))
     # [*, H, Q, K]
     a = torch.matmul(query, key)
     for b in biases:
         a += b
-    a = softmax(a, dim=-1)
+    a = softmax(a, -1)
     # [*, H, Q, C_hidden]
     a = torch.matmul(a, value)
@@ -354,7 +361,9 @@ class Attention(nn.Module):
     def _prep_qkv(self,
         q_x: torch.Tensor,
         kv_x: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[
+        torch.Tensor, torch.Tensor, torch.Tensor
+    ]:
         # [*, Q/K/V, H * C_hidden]
         q = self.linear_q(q_x)
         k = self.linear_k(kv_x)
@@ -375,6 +384,7 @@ class Attention(nn.Module):
     ) -> torch.Tensor:
         if(self.linear_g is not None):
             g = self.sigmoid(self.linear_g(q_x))
             # [*, Q, H, C_hidden]
             g = g.view(g.shape[:-1] + (self.no_heads, -1))
             o = o * g
...
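
The primitives hunks above replace the bfloat16 `raise NotImplementedError` branches in LayerNorm and softmax with real computation: outside DeepSpeed, the op is run with autocast disabled so it is not silently upcast to fp32; under DeepSpeed, whose engine already holds bf16 parameters, it is called directly. A sketch of the resulting LayerNorm, assembled from the visible hunk lines; the c_in/eps handling and the parameter downcast in the non-DeepSpeed branch are assumptions, since those lines are collapsed in this diff.

# Sketch assembled from the hunks above. The c_in/eps handling and the
# weight/bias downcast in the first branch are assumptions (collapsed lines).
import torch
import torch.nn as nn
import deepspeed

class LayerNorm(nn.Module):
    def __init__(self, c_in, eps=1e-5):
        super().__init__()
        self.c_in = (c_in,)
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(c_in))
        self.bias = nn.Parameter(torch.zeros(c_in))

    def forward(self, x):
        d = x.dtype
        if d is torch.bfloat16 and not deepspeed.utils.is_initialized():
            # Plain AMP bf16: keep the op outside autocast and match dtypes
            # manually so it stays in bfloat16.
            with torch.cuda.amp.autocast(enabled=False):
                out = nn.functional.layer_norm(
                    x,
                    self.c_in,
                    self.weight.to(dtype=d),
                    self.bias.to(dtype=d),
                    self.eps,
                )
        else:
            # DeepSpeed bf16 (parameters already bf16) or full precision.
            out = nn.functional.layer_norm(
                x, self.c_in, self.weight, self.bias, self.eps
            )
        return out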
@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn
 from typing import Optional, Tuple
-from openfold.model.primitives import Linear, ipa_point_weights_init_
+from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
 from openfold.np.residue_constants import (
     restype_rigid_group_default_frame,
     restype_atom14_to_rigid_group,
@@ -298,8 +298,8 @@ class InvariantPointAttention(nn.Module):
             permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
             permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
         )
-        a = a * math.sqrt(1.0 / (3 * self.c_hidden))
-        a = a + (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
+        a *= math.sqrt(1.0 / (3 * self.c_hidden))
+        a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
         # [*, N_res, N_res, H, P_q, 3]
         pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
@@ -323,7 +323,7 @@ class InvariantPointAttention(nn.Module):
         # [*, H, N_res, N_res]
         pt_att = permute_final_dims(pt_att, (2, 0, 1))
         a = a + pt_att
         a = a + square_mask.unsqueeze(-3)
         a = self.softmax(a)
@@ -331,7 +331,9 @@ class InvariantPointAttention(nn.Module):
         # Compute output
         ################
         # [*, N_res, H, C_hidden]
-        o = torch.matmul(a, v.transpose(-2, -3)).transpose(-2, -3)
+        o = torch.matmul(
+            a, v.transpose(-2, -3).to(dtype=a.dtype)
+        ).transpose(-2, -3)
         # [*, N_res, H * C_hidden]
         o = flatten_final_dims(o, 2)
@@ -360,7 +362,7 @@ class InvariantPointAttention(nn.Module):
         o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
         # [*, N_res, H, C_z]
-        o_pair = torch.matmul(a.transpose(-2, -3), z)
+        o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
         # [*, N_res, H * C_z]
         o_pair = flatten_final_dims(o_pair, 2)
@@ -369,7 +371,7 @@ class InvariantPointAttention(nn.Module):
         s = self.linear_out(
             torch.cat(
                 (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
-            )
+            ).to(dtype=z.dtype)
         )
         return s
@@ -444,7 +446,7 @@ class StructureModuleTransition(nn.Module):
             self.layers.append(l)
         self.dropout = nn.Dropout(self.dropout_rate)
-        self.layer_norm = nn.LayerNorm(self.c)
+        self.layer_norm = LayerNorm(self.c)
     def forward(self, s):
         for l in self.layers:
@@ -534,8 +536,8 @@ class StructureModule(nn.Module):
         self.atom_mask = None
         self.lit_positions = None
-        self.layer_norm_s = nn.LayerNorm(self.c_s)
-        self.layer_norm_z = nn.LayerNorm(self.c_z)
+        self.layer_norm_s = LayerNorm(self.c_s)
+        self.layer_norm_z = LayerNorm(self.c_z)
         self.linear_in = Linear(self.c_s, self.c_s)
@@ -551,7 +553,7 @@ class StructureModule(nn.Module):
         )
         self.ipa_dropout = nn.Dropout(self.dropout_rate)
-        self.layer_norm_ipa = nn.LayerNorm(self.c_s)
+        self.layer_norm_ipa = LayerNorm(self.c_s)
         self.transition = StructureModuleTransition(
             self.c_s,
...
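
The InvariantPointAttention hunks above add explicit .to(dtype=...) casts because, under bfloat16 autocast, the attention weights and the value/pair tensors can end up in different dtypes, and matmul does not promote across float dtypes. A small standalone demonstration of the failure mode and the fix; shapes and names are arbitrary, not OpenFold's:

# Demonstration only; shapes and names are arbitrary.
import torch

a = torch.randn(2, 4, 4, dtype=torch.bfloat16)   # attention weights
v = torch.randn(2, 4, 8, dtype=torch.float32)    # values left in fp32

try:
    torch.matmul(a, v)                            # mismatched dtypes
except RuntimeError as err:
    print("matmul failed:", err)

out = torch.matmul(a, v.to(dtype=a.dtype))        # cast first, as in the hunks above
print(out.dtype)                                  # torch.bfloat16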
@@ -19,7 +19,7 @@ from typing import Optional, List
 import torch
 import torch.nn as nn
-from openfold.model.primitives import Linear, Attention
+from openfold.model.primitives import Linear, LayerNorm, Attention
 from openfold.model.dropout import (
     DropoutRowwise,
     DropoutColumnwise,
@@ -80,8 +80,7 @@ class TemplatePointwiseAttention(nn.Module):
     ) -> torch.Tensor:
         mha_inputs = {
             "q_x": z,
-            "k_x": t,
-            "v_x": t,
+            "kv_x": t,
             "biases": biases,
         }
         return chunk_layer(
@@ -125,7 +124,7 @@ class TemplatePointwiseAttention(nn.Module):
         if chunk_size is not None:
             z = self._chunk(z, t, biases, chunk_size)
         else:
-            z = self.mha(q_x=z, k_x=t, v_x=t, biases=biases)
+            z = self.mha(q_x=z, kv_x=t, biases=biases)
         # [*, N_res, N_res, C_z]
         z = z.squeeze(-2)
@@ -292,7 +291,7 @@ class TemplatePairStack(nn.Module):
             )
             self.blocks.append(block)
-        self.layer_norm = nn.LayerNorm(c_t)
+        self.layer_norm = LayerNorm(c_t)
     def forward(
         self,
...
@@ -20,7 +20,7 @@ from typing import Optional, List
 import torch
 import torch.nn as nn
-from openfold.model.primitives import Linear, Attention
+from openfold.model.primitives import Linear, LayerNorm, Attention
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
@@ -49,7 +49,7 @@ class TriangleAttention(nn.Module):
         self.starting = starting
         self.inf = inf
-        self.layer_norm = nn.LayerNorm(self.c_in)
+        self.layer_norm = LayerNorm(self.c_in)
         self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
@@ -116,7 +116,7 @@ class TriangleAttention(nn.Module):
         if chunk_size is not None:
             x = self._chunk(x, biases, chunk_size)
         else:
-            x = self.mha(q_x=x, k_x=x, v_x=x, biases=biases)
+            x = self.mha(q_x=x, kv_x=x, biases=biases)
         if not self.starting:
             x = x.transpose(-2, -3)
...
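
The template and triangle-attention hunks above track a signature change in the shared Attention primitive: since keys and values are always projected from the same tensor, the separate k_x/v_x arguments collapse into a single kv_x. A toy stand-in class showing the consolidated calling convention (not OpenFold's actual Attention):

# Toy stand-in, not the OpenFold class; shows the q_x/kv_x calling convention.
import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    def __init__(self, c: int):
        super().__init__()
        self.linear_q = nn.Linear(c, c)
        self.linear_k = nn.Linear(c, c)
        self.linear_v = nn.Linear(c, c)

    def forward(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.linear_q(q_x), self.linear_k(kv_x), self.linear_v(kv_x)
        a = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
        return a @ v

mha = TinyAttention(8)
x = torch.randn(4, 8)
self_attn = mha(q_x=x, kv_x=x)                    # self-attention: one tensor for both
cross_attn = mha(q_x=x, kv_x=torch.randn(6, 8))   # cross-attention: separate kv source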
@@ -19,7 +19,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
-from openfold.model.primitives import Linear
+from openfold.model.primitives import Linear, LayerNorm
 from openfold.utils.tensor_utils import permute_final_dims
@@ -47,8 +47,8 @@ class TriangleMultiplicativeUpdate(nn.Module):
         self.linear_g = Linear(self.c_z, self.c_z, init="gating")
         self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
-        self.layer_norm_in = nn.LayerNorm(self.c_z)
-        self.layer_norm_out = nn.LayerNorm(self.c_hidden)
+        self.layer_norm_in = LayerNorm(self.c_z)
+        self.layer_norm_out = LayerNorm(self.c_hidden)
         self.sigmoid = nn.Sigmoid()
...
@@ -26,7 +26,7 @@ def rot_matmul(
 ) -> torch.Tensor:
     """
     Performs matrix multiplication of two rotation matrix tensors. Written
-    out by hand to avoid transfer to low-precision tensor cores.
+    out by hand to avoid AMP downcasting.
     Args:
         a: [*, 3, 3] left multiplicand
@@ -86,7 +86,7 @@ def rot_vec_mul(
 ) -> torch.Tensor:
     """
     Applies a rotation to a vector. Written out by hand to avoid transfer
-    to low-precision tensor cores.
+    to avoid AMP downcasting.
     Args:
         r: [*, 3, 3] rotation matrices
@@ -323,6 +323,12 @@ class Rotation:
                 "Incorrectly shaped rotation matrix or quaternion"
             )
+        # Force full-precision
+        if(quats is not None):
+            quats = quats.to(dtype=torch.float32)
+        if(rot_mats is not None):
+            rot_mats = rot_mats.to(dtype=torch.float32)
         if(quats is not None and normalize_quats):
             quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
@@ -857,6 +863,9 @@ class Rigid:
             (rots.device != trans.device)):
             raise ValueError("Rots and trans incompatible")
+        # Force full precision. Happens to the rotations automatically.
+        trans = trans.to(dtype=torch.float32)
         self._rots = rots
         self._trans = trans
...
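
The rigid_utils hunks above pin the geometric state (quaternions, rotation matrices, and translations) to float32 at construction time, so frame math keeps full precision even when the rest of the network runs in bfloat16. A minimal sketch of that pattern with a hypothetical class name:

# Illustrative constructor pattern; the Frame class name is hypothetical.
import torch

class Frame:
    def __init__(self, rot_mats: torch.Tensor, trans: torch.Tensor):
        # Force full precision regardless of what dtype bf16 autocast produced.
        self._rots = rot_mats.to(dtype=torch.float32)
        self._trans = trans.to(dtype=torch.float32)

frame = Frame(
    torch.eye(3, dtype=torch.bfloat16).expand(8, 3, 3),
    torch.zeros(8, 3, dtype=torch.bfloat16),
)
print(frame._rots.dtype, frame._trans.dtype)  # torch.float32 torch.float32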