Commit 1df4991d authored by Gustaf Ahdritz

DeepSpeed + PL bfloat16 working

parent 02fc4376
@@ -149,14 +149,14 @@ class MSAAttention(nn.Module):
def _get_qkv(m, z):
m, mask_bias, z = self._prep_inputs(m, z, mask)
q, k, v = self.mha._prep_qkv(m, m)
return q, k, v, mask_bias, z
return m, q, k, v, mask_bias, z
checkpoint_fn = get_checkpoint_fn()
if(checkpoint):
q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
if(torch.is_grad_enabled() and checkpoint):
m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
else:
q, k, v, mask_bias, z = _get_qkv(m, z)
m, q, k, v, mask_bias, z = _get_qkv(m, z)
o = _attention_chunked_trainable(
query=q,
@@ -168,7 +168,7 @@ class MSAAttention(nn.Module):
checkpoint=checkpoint,
)
if(checkpoint):
if(torch.is_grad_enabled() and checkpoint):
# Storing an additional m here is far from ideal
m = checkpoint_fn(self.mha._wrap_up, o, m)
else:
......
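A note on the MSAAttention hunks above: the checkpointed closure now also returns the post-_prep_inputs m, so the transformed MSA tensor is available for the later _wrap_up call, and checkpointing is skipped whenever autograd is off, since recomputation buys nothing at inference time. A minimal sketch of that guard pattern, using torch.utils.checkpoint as a stand-in for get_checkpoint_fn and a hypothetical maybe_checkpoint helper:

    import torch
    from torch.utils.checkpoint import checkpoint as checkpoint_fn  # stand-in for get_checkpoint_fn()

    def maybe_checkpoint(fn, *args, use_checkpoint: bool = True):
        # Only trade compute for memory when autograd is actually recording;
        # under torch.no_grad() checkpointing adds overhead for no benefit.
        if torch.is_grad_enabled() and use_checkpoint:
            return checkpoint_fn(fn, *args)
        return fn(*args)

    linear = torch.nn.Linear(8, 8)
    x = torch.randn(2, 8, requires_grad=True)
    y_train = maybe_checkpoint(lambda t: linear(t).relu(), x)          # checkpointed
    with torch.no_grad():
        y_eval = maybe_checkpoint(lambda t: linear(t).relu(), x)      # plain call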
@@ -17,7 +17,7 @@ from typing import Optional
import torch
import torch.nn as nn
from openfold.model.primitives import Linear
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import chunk_layer
@@ -40,7 +40,7 @@ class PairTransition(nn.Module):
self.c_z = c_z
self.n = n
self.layer_norm = nn.LayerNorm(self.c_z)
self.layer_norm = LayerNorm(self.c_z)
self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
......
@@ -179,7 +179,7 @@ class LayerNorm(nn.Module):
def forward(self, x):
d = x.dtype
if(d == torch.bfloat16 and not deepspeed.utils.is_initialized()):
if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
with torch.cuda.amp.autocast(enabled=False):
out = nn.functional.layer_norm(
x,
@@ -189,27 +189,34 @@ class LayerNorm(nn.Module):
self.eps
)
elif(d == torch.bfloat16):
raise NotImplementedError
out = nn.functional.layer_norm(
x,
self.c_in,
self.weight,
self.bias,
self.eps,
)
return out
def softmax(t, dim=-1):
@torch.jit.ignore
def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""
Softmax, but without automatic casting to fp32 when the input is of
type bfloat16
"""
d = t.dtype
if(d == torch.bfloat16 and not deepspeed.utils.is_initialized()):
if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
with torch.cuda.amp.autocast(enabled=False):
s = torch.nn.functional.softmax(t, dim=dim)
elif(d == torch.bfloat16):
raise NotImplementedError
s = torch.nn.functional.softmax(t, dim=dim)
return s
def _attention(query, key, value, biases):
#@torch.jit.script
def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
# [*, H, Q, C_hidden]
query = permute_final_dims(query, (1, 0, 2))
@@ -225,7 +232,7 @@ def _attention(query, key, value, biases):
for b in biases:
a += b
a = softmax(a, dim=-1)
a = softmax(a, -1)
# [*, H, Q, C_hidden]
a = torch.matmul(a, value)
@@ -354,7 +361,9 @@ class Attention(nn.Module):
def _prep_qkv(self,
q_x: torch.Tensor,
kv_x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor
]:
# [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x)
k = self.linear_k(kv_x)
@@ -375,6 +384,7 @@ class Attention(nn.Module):
) -> torch.Tensor:
if(self.linear_g is not None):
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
o = o * g
......
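The LayerNorm and softmax changes above are the core of the bfloat16 fix: the bfloat16-with-DeepSpeed path no longer dead-ends in NotImplementedError, and in the non-DeepSpeed case autocast is disabled so the kernel stays in bfloat16 rather than being silently upcast to float32. A reduced, self-contained sketch of the same idea; BF16SafeLayerNorm is illustrative only and omits the deepspeed.utils.is_initialized() branch:

    import torch
    import torch.nn as nn

    class BF16SafeLayerNorm(nn.Module):
        def __init__(self, c_in: int, eps: float = 1e-5):
            super().__init__()
            self.c_in = (c_in,)
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(c_in))
            self.bias = nn.Parameter(torch.zeros(c_in))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if x.dtype is torch.bfloat16:
                # Keep the op in bfloat16: disable autocast and cast the
                # affine parameters to the input's dtype.
                with torch.cuda.amp.autocast(enabled=False):
                    return nn.functional.layer_norm(
                        x,
                        self.c_in,
                        self.weight.to(dtype=x.dtype),
                        self.bias.to(dtype=x.dtype),
                        self.eps,
                    )
            return nn.functional.layer_norm(
                x, self.c_in, self.weight, self.bias, self.eps
            )

    ln = BF16SafeLayerNorm(16)
    out = ln(torch.randn(2, 16, dtype=torch.bfloat16))  # output stays bfloat16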
@@ -18,7 +18,7 @@ import torch
import torch.nn as nn
from typing import Optional, Tuple
from openfold.model.primitives import Linear, ipa_point_weights_init_
from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
from openfold.np.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
@@ -298,8 +298,8 @@ class InvariantPointAttention(nn.Module):
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
)
a = a * math.sqrt(1.0 / (3 * self.c_hidden))
a = a + (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
a *= math.sqrt(1.0 / (3 * self.c_hidden))
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
@@ -331,7 +331,9 @@ class InvariantPointAttention(nn.Module):
# Compute output
################
# [*, N_res, H, C_hidden]
o = torch.matmul(a, v.transpose(-2, -3)).transpose(-2, -3)
o = torch.matmul(
a, v.transpose(-2, -3).to(dtype=a.dtype)
).transpose(-2, -3)
# [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2)
@@ -360,7 +362,7 @@ class InvariantPointAttention(nn.Module):
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
# [*, N_res, H, C_z]
o_pair = torch.matmul(a.transpose(-2, -3), z)
o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
# [*, N_res, H * C_z]
o_pair = flatten_final_dims(o_pair, 2)
@@ -369,7 +371,7 @@ class InvariantPointAttention(nn.Module):
s = self.linear_out(
torch.cat(
(o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
)
).to(dtype=z.dtype)
)
return s
@@ -444,7 +446,7 @@ class StructureModuleTransition(nn.Module):
self.layers.append(l)
self.dropout = nn.Dropout(self.dropout_rate)
self.layer_norm = nn.LayerNorm(self.c)
self.layer_norm = LayerNorm(self.c)
def forward(self, s):
for l in self.layers:
@@ -534,8 +536,8 @@ class StructureModule(nn.Module):
self.atom_mask = None
self.lit_positions = None
self.layer_norm_s = nn.LayerNorm(self.c_s)
self.layer_norm_z = nn.LayerNorm(self.c_z)
self.layer_norm_s = LayerNorm(self.c_s)
self.layer_norm_z = LayerNorm(self.c_z)
self.linear_in = Linear(self.c_s, self.c_s)
@@ -551,7 +553,7 @@ class StructureModule(nn.Module):
)
self.ipa_dropout = nn.Dropout(self.dropout_rate)
self.layer_norm_ipa = nn.LayerNorm(self.c_s)
self.layer_norm_ipa = LayerNorm(self.c_s)
self.transition = StructureModuleTransition(
self.c_s,
......
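The InvariantPointAttention hunks switch the bias additions to in-place updates and add explicit dtype casts: under bfloat16 training the value and pair tensors can end up in a different dtype than the softmaxed attention weights, so they are cast to a.dtype before the matmuls and the concatenated output is cast back to z.dtype. A toy sketch of the cast-then-matmul pattern; mixed_precision_matmul is a hypothetical helper, not OpenFold code:

    import torch

    def mixed_precision_matmul(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
        # matmul operands must share a dtype outside autocast, so cast the
        # second operand to match the first, then cast the result back.
        return torch.matmul(a, b.to(dtype=a.dtype)).to(dtype=out_dtype)

    attn = torch.randn(4, 8, 8)                            # attention weights
    values = torch.randn(4, 8, 16, dtype=torch.bfloat16)   # values in bfloat16
    out = mixed_precision_matmul(attn, values, out_dtype=torch.bfloat16)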
@@ -19,7 +19,7 @@ from typing import Optional, List
import torch
import torch.nn as nn
from openfold.model.primitives import Linear, Attention
from openfold.model.primitives import Linear, LayerNorm, Attention
from openfold.model.dropout import (
DropoutRowwise,
DropoutColumnwise,
@@ -80,8 +80,7 @@ class TemplatePointwiseAttention(nn.Module):
) -> torch.Tensor:
mha_inputs = {
"q_x": z,
"k_x": t,
"v_x": t,
"kv_x": t,
"biases": biases,
}
return chunk_layer(
@@ -125,7 +124,7 @@ class TemplatePointwiseAttention(nn.Module):
if chunk_size is not None:
z = self._chunk(z, t, biases, chunk_size)
else:
z = self.mha(q_x=z, k_x=t, v_x=t, biases=biases)
z = self.mha(q_x=z, kv_x=t, biases=biases)
# [*, N_res, N_res, C_z]
z = z.squeeze(-2)
@@ -292,7 +291,7 @@ class TemplatePairStack(nn.Module):
)
self.blocks.append(block)
self.layer_norm = nn.LayerNorm(c_t)
self.layer_norm = LayerNorm(c_t)
def forward(
self,
......
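In TemplatePointwiseAttention (and in TriangleAttention below), the Attention module is now called with a single kv_x argument instead of separate k_x and v_x, matching the _prep_qkv(q_x, kv_x) signature shown earlier: keys and values are always projected from the same tensor, so there is one fewer input for chunk_layer to slice. A stripped-down sketch of an attention block with that calling convention; TinyAttention is illustrative only, with no heads, gating, or masking:

    import torch
    import torch.nn as nn
    from typing import List, Optional

    class TinyAttention(nn.Module):
        def __init__(self, c_q: int, c_kv: int, c_hidden: int):
            super().__init__()
            self.linear_q = nn.Linear(c_q, c_hidden, bias=False)
            self.linear_k = nn.Linear(c_kv, c_hidden, bias=False)
            self.linear_v = nn.Linear(c_kv, c_hidden, bias=False)

        def forward(
            self,
            q_x: torch.Tensor,
            kv_x: torch.Tensor,
            biases: Optional[List[torch.Tensor]] = None,
        ) -> torch.Tensor:
            q = self.linear_q(q_x)
            k = self.linear_k(kv_x)  # keys ...
            v = self.linear_v(kv_x)  # ... and values come from the same input
            a = torch.matmul(q, k.transpose(-1, -2)) / (q.shape[-1] ** 0.5)
            for b in biases or []:
                a = a + b
            a = torch.softmax(a, dim=-1)
            return torch.matmul(a, v)

    mha = TinyAttention(c_q=16, c_kv=16, c_hidden=32)
    z = torch.randn(2, 10, 16)   # queries
    t = torch.randn(2, 4, 16)    # keys/values
    out = mha(q_x=z, kv_x=t)     # mirrors self.mha(q_x=z, kv_x=t, biases=biases)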
@@ -20,7 +20,7 @@ from typing import Optional, List
import torch
import torch.nn as nn
from openfold.model.primitives import Linear, Attention
from openfold.model.primitives import Linear, LayerNorm, Attention
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
@@ -49,7 +49,7 @@ class TriangleAttention(nn.Module):
self.starting = starting
self.inf = inf
self.layer_norm = nn.LayerNorm(self.c_in)
self.layer_norm = LayerNorm(self.c_in)
self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
@@ -116,7 +116,7 @@ class TriangleAttention(nn.Module):
if chunk_size is not None:
x = self._chunk(x, biases, chunk_size)
else:
x = self.mha(q_x=x, k_x=x, v_x=x, biases=biases)
x = self.mha(q_x=x, kv_x=x, biases=biases)
if not self.starting:
x = x.transpose(-2, -3)
......
@@ -19,7 +19,7 @@ from typing import Optional
import torch
import torch.nn as nn
from openfold.model.primitives import Linear
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import permute_final_dims
@@ -47,8 +47,8 @@ class TriangleMultiplicativeUpdate(nn.Module):
self.linear_g = Linear(self.c_z, self.c_z, init="gating")
self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
self.layer_norm_in = nn.LayerNorm(self.c_z)
self.layer_norm_out = nn.LayerNorm(self.c_hidden)
self.layer_norm_in = LayerNorm(self.c_z)
self.layer_norm_out = LayerNorm(self.c_hidden)
self.sigmoid = nn.Sigmoid()
......
@@ -26,7 +26,7 @@ def rot_matmul(
) -> torch.Tensor:
"""
Performs matrix multiplication of two rotation matrix tensors. Written
out by hand to avoid transfer to low-precision tensor cores.
out by hand to avoid AMP downcasting.
Args:
a: [*, 3, 3] left multiplicand
@@ -86,7 +86,7 @@ def rot_vec_mul(
) -> torch.Tensor:
"""
Applies a rotation to a vector. Written out by hand to avoid transfer
to low-precision tensor cores.
to avoid AMP downcasting.
Args:
r: [*, 3, 3] rotation matrices
@@ -323,6 +323,12 @@ class Rotation:
"Incorrectly shaped rotation matrix or quaternion"
)
# Force full-precision
if(quats is not None):
quats = quats.to(dtype=torch.float32)
if(rot_mats is not None):
rot_mats = rot_mats.to(dtype=torch.float32)
if(quats is not None and normalize_quats):
quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
@@ -857,6 +863,9 @@ class Rigid:
(rots.device != trans.device)):
raise ValueError("Rots and trans incompatible")
# Force full precision. Happens to the rotations automatically.
trans = trans.to(dtype=torch.float32)
self._rots = rots
self._trans = trans
......
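The Rotation and Rigid hunks above pin the geometric state to float32: quaternions and rotation matrices are cast in the Rotation constructor, and translations in the Rigid constructor (the rotations are already handled by Rotation), so frame arithmetic keeps full precision even when the surrounding network runs in bfloat16. A reduced sketch of that pattern; Frame is a hypothetical stand-in, not the Rigid class:

    import torch

    class Frame:
        def __init__(self, rot_mats: torch.Tensor, trans: torch.Tensor):
            # Rotations and translations are numerically sensitive, so force
            # them to float32 regardless of the incoming dtype.
            self.rot_mats = rot_mats.to(dtype=torch.float32)  # [*, 3, 3]
            self.trans = trans.to(dtype=torch.float32)        # [*, 3]

        def apply(self, pts: torch.Tensor) -> torch.Tensor:
            # Rotate and translate points in float32, return in the caller's dtype.
            rotated = torch.einsum("...ij,...j->...i", self.rot_mats, pts.to(torch.float32))
            return (rotated + self.trans).to(dtype=pts.dtype)

    pts = torch.randn(5, 3, dtype=torch.bfloat16)
    frame = Frame(torch.eye(3).expand(5, 3, 3), torch.zeros(5, 3))
    out = frame.apply(pts)  # bfloat16 outside, float32 inside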