Commit a8601529 authored by Gustaf Ahdritz

Prep for bfloat16 training

parent df89bb28
@@ -148,7 +148,7 @@ config = mlc.ConfigDict(
"same_prob": 0.1,
"uniform_prob": 0.1,
},
"max_extra_msa": 1024,
"max_extra_msa": 2048,
"max_recycling_iters": 3,
"msa_cluster_features": True,
"reduce_msa_clusters_by_max_templates": False,
@@ -211,12 +211,12 @@ config = mlc.ConfigDict(
"fixed_size": True,
"subsample_templates": True,
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_msa_clusters": 512,
"max_template_hits": 4,
"max_templates": 4,
"shuffle_top_k_prefiltered": 20,
"crop": True,
"crop_size": 256,
"crop_size": 384,
"supervised": True,
"clamp_prob": 0.9,
"subsample_recycling": True,
@@ -226,7 +226,7 @@ config = mlc.ConfigDict(
"use_small_bfd": False,
"data_loaders": {
"batch_size": 1,
"num_workers": 8,
"num_workers": 1,
},
},
},
@@ -340,7 +340,7 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"clear_cache_between_blocks": True,
"inf": 1e9,
"eps": eps, # 1e-10,
},
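For context, a minimal sketch of how these training-time values could be overridden with ml_collections from user code; the nested key paths below are assumptions made for illustration, not necessarily the exact structure of this config file.

```python
# Hypothetical sketch: overriding training config values with ml_collections.
# The nested key paths are illustrative assumptions only.
import ml_collections as mlc

config = mlc.ConfigDict({
    "data": {
        "common": {"max_extra_msa": 1024},
        "train": {"crop_size": 256, "max_msa_clusters": 128},
    }
})

# Mirror the larger crops and MSA sizes used for this training setup.
config.data.common.max_extra_msa = 2048
config.data.train.crop_size = 384
config.data.train.max_msa_clusters = 512

print(config.data.train.crop_size)  # 384
```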
......
@@ -185,7 +185,7 @@ class EvoformerBlockCore(nn.Module):
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
# the original.
@@ -229,7 +229,7 @@ class EvoformerBlock(nn.Module):
inf: float,
eps: float,
):
super().__init__()
super(EvoformerBlock, self).__init__()
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
@@ -246,7 +246,6 @@ class EvoformerBlock(nn.Module):
inf=inf,
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.core = EvoformerBlockCore(
@@ -310,7 +309,7 @@ class ExtraMSABlock(nn.Module):
eps: float,
ckpt: bool,
):
super().__init__()
super(ExtraMSABlock, self).__init__()
self.ckpt = ckpt
@@ -352,16 +351,16 @@ class ExtraMSABlock(nn.Module):
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
checkpoint_chunk_size: Optional[int] = 512,
_chunk_logits: Optional[int] = 1024,
) -> Tuple[torch.Tensor, torch.Tensor]:
checkpoint_chunk_size = checkpoint_chunk_size if self.ckpt else None
m = m + self.msa_dropout_layer(
self.msa_att_row(
m,
z=z,
mask=msa_mask,
chunk_size=chunk_size,
_chunk_and_checkpoint=checkpoint_chunk_size,
_chunk_logits=_chunk_logits,
_checkpoint_chunks=self.ckpt,
)
)
@@ -370,6 +369,7 @@ class ExtraMSABlock(nn.Module):
m, z = self.core(
m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
)
return m, z
if(self.ckpt):
@@ -521,11 +521,8 @@ class EvoformerStack(nn.Module):
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
seq_dim = -3
index = torch.tensor([0], device=m.device)
s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
s = s.squeeze(seq_dim)
s = self.linear(m[..., 0, :, :])
return m, z, s
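The single-row slice above is equivalent to the index_select it replaces; a small self-contained check with toy shapes:

```python
# Equivalence of the two ways of taking the first MSA row, on toy shapes.
import torch

m = torch.randn(2, 5, 7, 3)  # [*, N_seq, N_res, C_m]
index = torch.tensor([0])
a = torch.index_select(m, dim=-3, index=index).squeeze(-3)
b = m[..., 0, :, :]
assert torch.equal(a, b)
print(b.shape)  # torch.Size([2, 7, 3])
```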
@@ -574,7 +571,7 @@ class ExtraMSAStack(nn.Module):
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
ckpt=ckpt,
ckpt=False,
)
self.blocks.append(block)
@@ -599,10 +596,27 @@ class ExtraMSAStack(nn.Module):
Returns:
[*, N_res, N_res, C_z] pair update
"""
for b in self.blocks:
m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
checkpoint_fn = get_checkpoint_fn()
blocks = [
partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
]
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
def dodo(b, *args):
torch.cuda.empty_cache()
return b(*args)
blocks = [partial(dodo, b) for b in blocks]
for b in blocks:
if(torch.is_grad_enabled()):
m, z = checkpoint_fn(b, m, z)
else:
m, z = b(m, z)
#for b in self.blocks:
# m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
# if(self.clear_cache_between_blocks):
# torch.cuda.empty_cache()
return z
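As a rough illustration of the pattern introduced here (per-block activation checkpointing, with an optional CUDA cache clear before each block), a self-contained sketch using toy modules and plain torch.utils.checkpoint in place of OpenFold's get_checkpoint_fn; every name in the sketch is invented for the example.

```python
# Minimal sketch of per-block checkpointing with optional cache clearing.
# ToyBlock and run_blocks are stand-ins invented for this example.
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.linear_m = nn.Linear(c, c)
        self.linear_z = nn.Linear(c, c)

    def forward(self, m, z, scale=1.0):
        return m + scale * self.linear_m(m), z + scale * self.linear_z(z)


def run_blocks(blocks, m, z, clear_cache_between_blocks=False):
    # Bind per-call keyword arguments up front, as the diff does with partial().
    fns = [partial(b, scale=0.5) for b in blocks]

    if clear_cache_between_blocks:
        def clear_then_run(fn, *args):
            torch.cuda.empty_cache()
            return fn(*args)
        fns = [partial(clear_then_run, fn) for fn in fns]

    for fn in fns:
        if torch.is_grad_enabled():
            # Store only the block inputs; recompute activations in backward.
            m, z = checkpoint(fn, m, z)
        else:
            m, z = fn(m, z)
    return m, z


blocks = nn.ModuleList([ToyBlock(8) for _ in range(3)])
m = torch.randn(4, 8, requires_grad=True)
z = torch.randn(4, 8, requires_grad=True)
m_out, z_out = run_blocks(blocks, m, z)
(m_out.sum() + z_out.sum()).backward()
```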
@@ -16,9 +16,16 @@
import math
import torch
import torch.nn as nn
from typing import Optional, List
from openfold.model.primitives import Linear, Attention, GlobalAttention
from typing import Optional, List, Tuple
from openfold.model.primitives import (
Linear,
LayerNorm,
Attention,
GlobalAttention,
_attention_chunked_trainable,
)
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
@@ -61,16 +68,16 @@ class MSAAttention(nn.Module):
self.c_z = c_z
self.inf = inf
self.layer_norm_m = nn.LayerNorm(self.c_in)
self.layer_norm_m = LayerNorm(self.c_in)
self.layer_norm_z = None
self.linear_z = None
if self.pair_bias:
self.layer_norm_z = nn.LayerNorm(self.c_z)
self.layer_norm_z = LayerNorm(self.c_z)
self.linear_z = Linear(
self.c_z, self.no_heads, bias=False, init="normal"
)
self.mha = Attention(
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
)
@@ -83,33 +90,16 @@ class MSAAttention(nn.Module):
) -> torch.Tensor:
return chunk_layer(
self.mha,
{"q_x": m, "k_x": m, "v_x": m, "biases": biases},
{"q_x": m, "kv_x": m, "biases": biases},
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
def forward(self,
m: torch.Tensor,
z: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
_chunk_and_checkpoint: Optional[int] = None
) -> torch.Tensor:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding. Required only if
pair_bias is True
mask:
[*, N_seq, N_res] MSA mask
chunk_size:
Size of chunks into which the inputs are split along their
batch dimensions. A low value decreases memory overhead at the
cost of slower execution. Chunking is not performed by default.
"""
def _prep_inputs(self,
m: torch.Tensor,
z: Optional[torch.Tensor],
mask: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# [*, N_seq, N_res, C_m]
m = self.layer_norm_m(m)
@@ -121,7 +111,7 @@ class MSAAttention(nn.Module):
)
# [*, N_seq, 1, 1, N_res]
bias = (self.inf * (mask - 1))[..., :, None, None, :]
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
# This step simply returns a larger view of the bias, and does not
# consume additional memory.
@@ -129,9 +119,7 @@ class MSAAttention(nn.Module):
#bias = bias.expand(
# ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
#)
biases = [bias]
if (self.pair_bias and
z is not None and # For the
self.layer_norm_z is not None and # benefit of
@@ -139,13 +127,88 @@ class MSAAttention(nn.Module):
):
# [*, N_res, N_res, C_z]
z = self.layer_norm_z(z)
# [*, N_res, N_res, no_heads]
z = self.linear_z(z)
# [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
return m, mask_bias, z
@torch.jit.ignore
def _chunked_msa_attn(self,
m: torch.Tensor,
z: Optional[torch.Tensor],
mask: Optional[torch.Tensor],
chunk_logits: int,
checkpoint: bool,
) -> torch.Tensor:
MSA_DIM = -4
def _get_qkv(m, z):
m, mask_bias, z = self._prep_inputs(m, z, mask)
q, k, v = self.mha._prep_qkv(m, m)
return q, k, v, mask_bias, z
checkpoint_fn = get_checkpoint_fn()
if(checkpoint):
q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
else:
q, k, v, mask_bias, z = _get_qkv(m, z)
o = _attention_chunked_trainable(
query=q,
key=k,
value=v,
biases=[mask_bias, z],
chunk_size=chunk_logits,
chunk_dim=MSA_DIM,
checkpoint=checkpoint,
)
if(checkpoint):
# Storing an additional m here is far from ideal
m = checkpoint_fn(self.mha._wrap_up, o, m)
else:
m = self.mha._wrap_up(o, m)
return m
def forward(self,
m: torch.Tensor,
z: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
_chunk_logits: Optional[int] = None,
_checkpoint_chunks: Optional[bool] = None,
) -> torch.Tensor:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding. Required only if
pair_bias is True
mask:
[*, N_seq, N_res] MSA mask
chunk_size:
Size of chunks into which the inputs are split along their
batch dimensions. A low value decreases memory overhead at the
cost of slower execution. Chunking is not performed by default.
"""
if(_chunk_logits is not None):
return self._chunked_msa_attn(
m=m, z=z, mask=mask,
chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
)
m, mask_bias, z = self._prep_inputs(m, z, mask)
biases = [mask_bias]
if(z is not None):
biases.append(z)
if chunk_size is not None:
@@ -153,10 +216,8 @@ class MSAAttention(nn.Module):
else:
m = self.mha(
q_x=m,
k_x=m,
v_x=m,
biases=biases,
_chunk_and_checkpoint=_chunk_and_checkpoint
kv_x=m,
biases=biases
)
return m
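A hedged usage sketch of the new low-memory path: the row attention called with `_chunk_logits` set and chunk checkpointing enabled, roughly as ExtraMSABlock does above. The import path, constructor keywords, and tensor shapes are assumptions made for illustration.

```python
# Illustrative call into the chunked/checkpointed attention path.
# Import path and constructor arguments are assumed, not verified here.
import torch
from openfold.model.msa import MSARowAttentionWithPairBias

msa_att_row = MSARowAttentionWithPairBias(c_m=32, c_z=16, c_hidden=8, no_heads=4)

m = torch.randn(1, 64, 128, 32)   # [*, N_seq, N_res, C_m] MSA embedding
z = torch.randn(1, 128, 128, 16)  # [*, N_res, N_res, C_z] pair embedding
msa_mask = torch.ones(1, 64, 128)

m = m + msa_att_row(
    m, z=z, mask=msa_mask,
    _chunk_logits=1024,       # chunk the attention along the MSA dimension
    _checkpoint_chunks=True,  # recompute each chunk during the backward pass
)
```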
......
@@ -18,6 +18,7 @@ import math
from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np
import deepspeed
import torch
import torch.nn as nn
from scipy.stats import truncnorm
@@ -166,65 +167,126 @@ class Linear(nn.Linear):
raise ValueError("Invalid init string.")
class LayerNorm(nn.Module):
def __init__(self, c_in, eps=1e-5):
super(LayerNorm, self).__init__()
self.c_in = (c_in,)
self.eps = eps
self.weight = nn.Parameter(torch.ones(c_in))
self.bias = nn.Parameter(torch.zeros(c_in))
def forward(self, x):
d = x.dtype
if(d == torch.bfloat16 and not deepspeed.utils.is_initialized()):
with torch.cuda.amp.autocast(enabled=False):
out = nn.functional.layer_norm(
x,
self.c_in,
self.weight.to(dtype=d),
self.bias.to(dtype=d),
self.eps
)
elif(d == torch.bfloat16):
raise NotImplementedError
else:
# Standard layer norm for non-bfloat16 inputs
out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps)
return out
def softmax(t, dim=-1):
"""
Softmax, but without automatic casting to fp32 when the input is of
type bfloat16
"""
d = t.dtype
if(d == torch.bfloat16 and not deepspeed.utils.is_initialized()):
with torch.cuda.amp.autocast(enabled=False):
s = torch.nn.functional.softmax(t, dim=dim)
elif(d == torch.bfloat16):
raise NotImplementedError
else:
s = torch.nn.functional.softmax(t, dim=dim)
return s
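To make the motivation for these wrappers concrete: under CUDA autocast, layer_norm and softmax are normally executed in float32, so a bfloat16 activation comes back as float32 unless autocast is locally disabled. A small sketch of that behaviour (requires a CUDA device with bfloat16 support; the exact autocast policy can vary between PyTorch releases):

```python
# Illustrative dtype check for the behaviour the wrappers above work around.
import torch

x = torch.randn(2, 8, device="cuda", dtype=torch.bfloat16)
w = torch.ones(8, device="cuda")
b = torch.zeros(8, device="cuda")

with torch.cuda.amp.autocast():
    y = torch.nn.functional.layer_norm(x, (8,), w, b, 1e-5)
    print(y.dtype)  # torch.float32: autocast runs layer_norm in fp32

    with torch.cuda.amp.autocast(enabled=False):
        y = torch.nn.functional.layer_norm(
            x, (8,), w.to(x.dtype), b.to(x.dtype), 1e-5
        )
        print(y.dtype)  # torch.bfloat16: stays in the low-precision dtype
```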
def _attention(query, key, value, biases):
# [*, H, Q, C_hidden]
query = permute_final_dims(query, (1, 0, 2))
# [*, H, C_hidden, K]
key = permute_final_dims(key, (1, 2, 0))
# [*, H, V, C_hidden]
value = permute_final_dims(value, (1, 0, 2))
# [*, H, Q, K]
a = torch.matmul(query, key)
for b in biases:
a += b
a = torch.nn.functional.softmax(a, dim=-1)
a = softmax(a, dim=-1)
# [*, H, Q, C_hidden]
o = torch.matmul(a, value)
a = torch.matmul(a, value)
# [*, Q, H, C_hidden]
o = o.transpose(-2, -3)
a = a.transpose(-2, -3)
return o
return a
@torch.jit.ignore
def _attention_chunk_and_checkpoint(query, key, value, biases, chunk_size):
if(len(biases) > 2):
def _attention_chunked_trainable(
query, key, value, biases, chunk_size, chunk_dim, checkpoint,
):
if(checkpoint and len(biases) > 2):
raise ValueError(
"_chunk_and_checkpoint only permits two bias terms"
"Checkpointed version permits only permits two bias terms"
)
biases = biases + [None, None]
bias_1, bias_2 = biases[:2]
def _checkpointable_attention(q, k, v, b1, b2):
bs = [b1, b2]
bs = [b for b in [b1, b2] if b is not None]
return _attention(q, k, v, bs)
batch_dims = query.shape[:-3]
no_batch_dims = len(query.shape[:-3])
# q, k, and v are assumed to have no singleton dimensions
flat_q = query.reshape(-1, *query.shape[-3:])
flat_k = key.reshape(-1, *key.shape[-3:])
flat_v = value.reshape(-1, *value.shape[-3:])
o_chunks = []
checkpoint_fn = get_checkpoint_fn()
count = flat_q.shape[0]
count = query.shape[chunk_dim]
for start in range(0, count, chunk_size):
end = start + chunk_size
q_chunk = flat_q[start: end, ...]
k_chunk = flat_k[start: end, ...]
v_chunk = flat_v[start: end, ...]
bias_1_chunk = _chunk_slice(bias_1, start, end, no_batch_dims)
bias_2_chunk = _chunk_slice(bias_2, start, end, no_batch_dims)
o_chunk = checkpoint_fn(_checkpointable_attention,
q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
)
idx = [slice(None)] * len(query.shape)
idx[chunk_dim] = slice(start, end)
idx_tup = tuple(idx)
q_chunk = query[idx_tup]
k_chunk = key[idx_tup]
v_chunk = value[idx_tup]
def _slice_bias(b):
idx[chunk_dim] = (
slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)
)
return b[tuple(idx)]
o_chunks.append(o_chunk)
if(checkpoint):
bias_1_chunk, bias_2_chunk = [
_slice_bias(b) if b is not None else None
for b in (biases + [None, None])[:2]
]
o_chunk = checkpoint_fn(_checkpointable_attention,
q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
)
else:
bias_chunks = [
_slice_bias(b) for b in biases
]
o_flat = torch.cat(o_chunks, dim=0)
o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
return o_flat.reshape(batch_dims + o_flat.shape[1:])
o_chunks.append(o_chunk)
o = torch.cat(o_chunks, dim=chunk_dim)
return o
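A rough usage sketch of this helper with toy tensors; it assumes the `[*, N, H, C_hidden]` layout that `_attention` expects and a bias broadcastable against the attention logits. The shapes and values are invented for the example.

```python
# Toy call of _attention_chunked_trainable, chunking along an MSA-like dim.
import torch
from openfold.model.primitives import _attention_chunked_trainable

B, S, N, H, C = 1, 6, 16, 4, 8
q = torch.randn(B, S, N, H, C, requires_grad=True)
k = torch.randn(B, S, N, H, C, requires_grad=True)
v = torch.randn(B, S, N, H, C, requires_grad=True)
mask_bias = torch.zeros(B, S, 1, 1, N)  # broadcasts over heads and queries

o = _attention_chunked_trainable(
    query=q, key=k, value=v,
    biases=[mask_bias],
    chunk_size=2,        # process the S dimension two slices at a time
    chunk_dim=-4,
    checkpoint=True,     # recompute each chunk in the backward pass
)
o.sum().backward()
print(o.shape)  # torch.Size([1, 6, 16, 4, 8])
```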
class Attention(nn.Module):
@@ -289,16 +351,50 @@ class Attention(nn.Module):
self.sigmoid = nn.Sigmoid()
def _prep_qkv(self,
q_x: torch.Tensor,
kv_x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x)
k = self.linear_k(kv_x)
v = self.linear_v(kv_x)
# [*, Q/K, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))
q /= math.sqrt(self.c_hidden)
return q, k, v
def _wrap_up(self,
o: torch.Tensor,
q_x: torch.Tensor
) -> torch.Tensor:
if(self.linear_g is not None):
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
o = o * g
# [*, Q, H * C_hidden]
o = flatten_final_dims(o, 2)
# [*, Q, C_q]
o = self.linear_o(o)
return o
def forward(
self,
q_x: torch.Tensor,
k_x: torch.Tensor,
v_x: torch.Tensor,
kv_x: torch.Tensor,
biases: Optional[List[torch.Tensor]] = None,
use_lma: bool = False,
q_chunk_size: Optional[int] = None,
kv_chunk_size: Optional[int] = None,
_chunk_and_checkpoint: Optional[int] = None
) -> torch.Tensor:
"""
Args:
@@ -318,59 +414,20 @@ class Attention(nn.Module):
"If use_lma is specified, q_chunk_size and kv_chunk_size must "
"be provided"
)
if(use_lma and _chunk_and_checkpoint is not None):
raise ValueError(
"use_lma and _chunk_and_checkpoint are mutually exclusive"
)
# [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x)
k = self.linear_k(k_x)
v = self.linear_v(v_x)
# [*, Q/K, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))
q = q / math.sqrt(self.c_hidden)
q, k, v = self._prep_qkv(q_x, kv_x)
if(use_lma):
biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
for b in biases
]
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
else:
# [*, H, Q, C_hidden]
q = permute_final_dims(q, (1, 0, 2))
# [*, H, C_hidden, K]
k = permute_final_dims(k, (1, 2, 0))
# [*, H, V, C_hidden]
v = permute_final_dims(v, (1, 0, 2))
if(_chunk_and_checkpoint):
# REMEMBER THAT THE K, Q, V COMPUTATION AND GATING ARE *NOT*
# CHECKPOINTED HERE
o = _attention_chunk_and_checkpoint(
q, k, v, biases, _chunk_and_checkpoint
)
else:
o = _attention(q, k, v, biases)
if(self.linear_g is not None):
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
o = o * g
o = _attention(q, k, v, biases)
# [*, Q, H * C_hidden]
o = flatten_final_dims(o, 2)
# [*, Q, C_q]
o = self.linear_o(o)
o = self._wrap_up(o, q_x)
return o
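A hedged sketch of the refactored call convention: self-attention now supplies the same tensor as `q_x` and `kv_x` rather than separate `k_x`/`v_x` inputs. Channel sizes are arbitrary and the biases list is passed explicitly.

```python
# Illustrative self-attention call with the new q_x/kv_x signature.
import torch
from openfold.model.primitives import Attention

# Positional args follow the (c_q, c_k, c_v, c_hidden, no_heads) order used
# elsewhere in this diff when constructing Attention.
attn = Attention(32, 32, 32, 8, 4)

x = torch.randn(2, 16, 32)
out = attn(q_x=x, kv_x=x, biases=[])
print(out.shape)  # torch.Size([2, 16, 32])
```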
@@ -399,7 +456,6 @@ class GlobalAttention(nn.Module):
self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
# [*, N_res, C_in]
@@ -425,7 +481,7 @@ class GlobalAttention(nn.Module):
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a += bias
a = self.softmax(a)
a = softmax(a)
# [*, N_res, H, C_hidden]
o = torch.matmul(
......