Commit a8601529 authored by Gustaf Ahdritz

Prep for bfloat16 training

parent df89bb28
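The central pattern this commit prepares for, sketched below for context (a minimal illustration, not the committed code; the function name is hypothetical): under torch.cuda.amp.autocast, ops such as softmax and layer norm are promoted to fp32, so a bfloat16 forward pass would otherwise produce fp32 activations; disabling autocast locally keeps them in bfloat16.

    # Illustrative sketch only; not part of this commit.
    import torch

    def bf16_softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
        # For bfloat16 inputs, run softmax with autocast disabled so the
        # result stays in bfloat16 instead of being promoted to fp32.
        if t.dtype == torch.bfloat16:
            with torch.cuda.amp.autocast(enabled=False):
                return torch.nn.functional.softmax(t, dim=dim)
        return torch.nn.functional.softmax(t, dim=dim)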
@@ -148,7 +148,7 @@ config = mlc.ConfigDict(
                     "same_prob": 0.1,
                     "uniform_prob": 0.1,
                 },
-                "max_extra_msa": 1024,
+                "max_extra_msa": 2048,
                 "max_recycling_iters": 3,
                 "msa_cluster_features": True,
                 "reduce_msa_clusters_by_max_templates": False,
@@ -211,12 +211,12 @@ config = mlc.ConfigDict(
                 "fixed_size": True,
                 "subsample_templates": True,
                 "masked_msa_replace_fraction": 0.15,
-                "max_msa_clusters": 128,
+                "max_msa_clusters": 512,
                 "max_template_hits": 4,
                 "max_templates": 4,
                 "shuffle_top_k_prefiltered": 20,
                 "crop": True,
-                "crop_size": 256,
+                "crop_size": 384,
                 "supervised": True,
                 "clamp_prob": 0.9,
                 "subsample_recycling": True,
@@ -226,7 +226,7 @@ config = mlc.ConfigDict(
                 "use_small_bfd": False,
                 "data_loaders": {
                     "batch_size": 1,
-                    "num_workers": 8,
+                    "num_workers": 1,
                 },
             },
         },
@@ -340,7 +340,7 @@ config = mlc.ConfigDict(
                 "msa_dropout": 0.15,
                 "pair_dropout": 0.25,
                 "blocks_per_ckpt": blocks_per_ckpt,
-                "clear_cache_between_blocks": False,
+                "clear_cache_between_blocks": True,
                 "inf": 1e9,
                 "eps": eps, # 1e-10,
             },
@@ -185,7 +185,7 @@ class EvoformerBlockCore(nn.Module):
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
         # should be disabled to better approximate the exact activations of
         # the original.
@@ -229,7 +229,7 @@ class EvoformerBlock(nn.Module):
         inf: float,
         eps: float,
     ):
-        super().__init__()
+        super(EvoformerBlock, self).__init__()
 
         self.msa_att_row = MSARowAttentionWithPairBias(
             c_m=c_m,
@@ -246,7 +246,6 @@ class EvoformerBlock(nn.Module):
             inf=inf,
         )
         self.msa_dropout_layer = DropoutRowwise(msa_dropout)
-
         self.core = EvoformerBlockCore(
@@ -310,7 +309,7 @@ class ExtraMSABlock(nn.Module):
         eps: float,
         ckpt: bool,
     ):
-        super().__init__()
+        super(ExtraMSABlock, self).__init__()
 
         self.ckpt = ckpt
@@ -352,16 +351,16 @@ class ExtraMSABlock(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
-        checkpoint_chunk_size: Optional[int] = 512,
+        _chunk_logits: Optional[int] = 1024,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        checkpoint_chunk_size = checkpoint_chunk_size if self.ckpt else None
 
         m = m + self.msa_dropout_layer(
             self.msa_att_row(
                 m,
                 z=z,
                 mask=msa_mask,
                 chunk_size=chunk_size,
-                _chunk_and_checkpoint=checkpoint_chunk_size,
+                _chunk_logits=_chunk_logits,
+                _checkpoint_chunks=self.ckpt,
             )
         )
@@ -370,6 +369,7 @@ class ExtraMSABlock(nn.Module):
             m, z = self.core(
                 m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
             )
             return m, z
 
         if(self.ckpt):
@@ -521,11 +521,8 @@ class EvoformerStack(nn.Module):
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )
 
-        seq_dim = -3
-        index = torch.tensor([0], device=m.device)
-        s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
-        s = s.squeeze(seq_dim)
+        s = self.linear(m[..., 0, :, :])
 
         return m, z, s
@@ -574,7 +571,7 @@ class ExtraMSAStack(nn.Module):
                 pair_dropout=pair_dropout,
                 inf=inf,
                 eps=eps,
-                ckpt=ckpt,
+                ckpt=False,
             )
             self.blocks.append(block)
@@ -599,10 +596,27 @@ class ExtraMSAStack(nn.Module):
         Returns:
             [*, N_res, N_res, C_z] pair update
         """
-        for b in self.blocks:
-            m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
-
-            if(self.clear_cache_between_blocks):
-                torch.cuda.empty_cache()
+        checkpoint_fn = get_checkpoint_fn()
+        blocks = [
+            partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
+        ]
+
+        def dodo(b, *args):
+            torch.cuda.empty_cache()
+            return b(*args)
+
+        blocks = [partial(dodo, b) for b in blocks]
+
+        for b in blocks:
+            if(torch.is_grad_enabled()):
+                m, z = checkpoint_fn(b, m, z)
+            else:
+                m, z = b(m, z)
+
+        #for b in self.blocks:
+        #    m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+        #    if(self.clear_cache_between_blocks):
+        #        torch.cuda.empty_cache()
 
         return z
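For reference, the per-block checkpointing pattern adopted in the hunk above reduces to the following self-contained sketch (hypothetical names, not the committed code): only each block's inputs are kept during the forward pass, and activations are recomputed during backward.

    # Illustrative sketch only; not part of this commit.
    from functools import partial

    import torch
    from torch.utils.checkpoint import checkpoint

    def run_blocks(blocks, m, z, **kwargs):
        # Bind the non-tensor arguments first so each block becomes a function
        # of the two tensors that torch.utils.checkpoint passes through.
        bound = [partial(b, **kwargs) for b in blocks]
        for b in bound:
            if torch.is_grad_enabled():
                m, z = checkpoint(b, m, z)
            else:
                m, z = b(m, z)
        return m, z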
@@ -16,9 +16,16 @@
 import math
 import torch
 import torch.nn as nn
 
-from typing import Optional, List
-from openfold.model.primitives import Linear, Attention, GlobalAttention
+from typing import Optional, List, Tuple
+from openfold.model.primitives import (
+    Linear,
+    LayerNorm,
+    Attention,
+    GlobalAttention,
+    _attention_chunked_trainable,
+)
+from openfold.utils.checkpointing import get_checkpoint_fn
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
@@ -61,16 +68,16 @@ class MSAAttention(nn.Module):
         self.c_z = c_z
         self.inf = inf
 
-        self.layer_norm_m = nn.LayerNorm(self.c_in)
+        self.layer_norm_m = LayerNorm(self.c_in)
 
         self.layer_norm_z = None
         self.linear_z = None
         if self.pair_bias:
-            self.layer_norm_z = nn.LayerNorm(self.c_z)
+            self.layer_norm_z = LayerNorm(self.c_z)
             self.linear_z = Linear(
                 self.c_z, self.no_heads, bias=False, init="normal"
             )
 
         self.mha = Attention(
             self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
         )
@@ -83,33 +90,16 @@ class MSAAttention(nn.Module):
     ) -> torch.Tensor:
         return chunk_layer(
             self.mha,
-            {"q_x": m, "k_x": m, "v_x": m, "biases": biases},
+            {"q_x": m, "kv_x": m, "biases": biases},
             chunk_size=chunk_size,
             no_batch_dims=len(m.shape[:-2]),
         )
 
-    def forward(self,
+    def _prep_inputs(self,
         m: torch.Tensor,
-        z: Optional[torch.Tensor] = None,
-        mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None,
-        _chunk_and_checkpoint: Optional[int] = None
-    ) -> torch.Tensor:
-        """
-        Args:
-            m:
-                [*, N_seq, N_res, C_m] MSA embedding
-            z:
-                [*, N_res, N_res, C_z] pair embedding. Required only if
-                pair_bias is True
-            mask:
-                [*, N_seq, N_res] MSA mask
-            chunk_size:
-                Size of chunks into which the inputs are split along their
-                batch dimensions. A low value decreases memory overhead at the
-                cost of slower execution. Chunking is not performed by default.
-        """
+        z: Optional[torch.Tensor],
+        mask: Optional[torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # [*, N_seq, N_res, C_m]
         m = self.layer_norm_m(m)
@@ -121,7 +111,7 @@ class MSAAttention(nn.Module):
             )
 
         # [*, N_seq, 1, 1, N_res]
-        bias = (self.inf * (mask - 1))[..., :, None, None, :]
+        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
 
         # This step simply returns a larger view of the bias, and does not
         # consume additional memory.
@@ -129,9 +119,7 @@ class MSAAttention(nn.Module):
         #bias = bias.expand(
         #    ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
         #)
 
-        biases = [bias]
-
         if (self.pair_bias and
             z is not None and                 # For the
             self.layer_norm_z is not None and # benefit of
@@ -139,13 +127,88 @@ class MSAAttention(nn.Module):
         ):
             # [*, N_res, N_res, C_z]
             z = self.layer_norm_z(z)
 
             # [*, N_res, N_res, no_heads]
             z = self.linear_z(z)
 
             # [*, 1, no_heads, N_res, N_res]
             z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
 
+        return m, mask_bias, z
+
+    @torch.jit.ignore
+    def _chunked_msa_attn(self,
+        m: torch.Tensor,
+        z: Optional[torch.Tensor],
+        mask: Optional[torch.Tensor],
+        chunk_logits: int,
+        checkpoint: bool,
+    ) -> torch.Tensor:
+        MSA_DIM = -4
+
+        def _get_qkv(m, z):
+            m, mask_bias, z = self._prep_inputs(m, z, mask)
+            q, k, v = self.mha._prep_qkv(m, m)
+            return q, k, v, mask_bias, z
+
+        checkpoint_fn = get_checkpoint_fn()
+
+        if(checkpoint):
+            q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
+        else:
+            q, k, v, mask_bias, z = _get_qkv(m, z)
+
+        o = _attention_chunked_trainable(
+            query=q,
+            key=k,
+            value=v,
+            biases=[mask_bias, z],
+            chunk_size=chunk_logits,
+            chunk_dim=MSA_DIM,
+            checkpoint=checkpoint,
+        )
+
+        if(checkpoint):
+            # Storing an additional m here is far from ideal
+            m = checkpoint_fn(self.mha._wrap_up, o, m)
+        else:
+            m = self.mha._wrap_up(o, m)
+
+        return m
+
+    def forward(self,
+        m: torch.Tensor,
+        z: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
+        _chunk_logits: Optional[int] = None,
+        _checkpoint_chunks: Optional[bool] = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+            m:
+                [*, N_seq, N_res, C_m] MSA embedding
+            z:
+                [*, N_res, N_res, C_z] pair embedding. Required only if
+                pair_bias is True
+            mask:
+                [*, N_seq, N_res] MSA mask
+            chunk_size:
+                Size of chunks into which the inputs are split along their
+                batch dimensions. A low value decreases memory overhead at the
+                cost of slower execution. Chunking is not performed by default.
+        """
+        if(_chunk_logits is not None):
+            return self._chunked_msa_attn(
+                m=m, z=z, mask=mask,
+                chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
+            )
+
+        m, mask_bias, z = self._prep_inputs(m, z, mask)
+
+        biases = [mask_bias]
+        if(z is not None):
             biases.append(z)
 
         if chunk_size is not None:
@@ -153,10 +216,8 @@ class MSAAttention(nn.Module):
         else:
             m = self.mha(
                 q_x=m,
-                k_x=m,
-                v_x=m,
-                biases=biases,
-                _chunk_and_checkpoint=_chunk_and_checkpoint
+                kv_x=m,
+                biases=biases
             )
 
         return m
@@ -18,6 +18,7 @@ import math
 from typing import Optional, Callable, List, Tuple, Sequence
 
 import numpy as np
+import deepspeed
 import torch
 import torch.nn as nn
 from scipy.stats import truncnorm
@@ -166,65 +167,126 @@ class Linear(nn.Linear):
                 raise ValueError("Invalid init string.")
 
 
+class LayerNorm(nn.Module):
+    def __init__(self, c_in, eps=1e-5):
+        super(LayerNorm, self).__init__()
+
+        self.c_in = (c_in,)
+        self.eps = eps
+
+        self.weight = nn.Parameter(torch.ones(c_in))
+        self.bias = nn.Parameter(torch.zeros(c_in))
+
+    def forward(self, x):
+        d = x.dtype
+        if(d == torch.bfloat16 and not deepspeed.utils.is_initialized()):
+            with torch.cuda.amp.autocast(enabled=False):
+                out = nn.functional.layer_norm(
+                    x,
+                    self.c_in,
+                    self.weight.to(dtype=d),
+                    self.bias.to(dtype=d),
+                    self.eps
+                )
+        elif(d == torch.bfloat16):
+            raise NotImplementedError
+
+        return out
+
+
+def softmax(t, dim=-1):
+    """
+        Softmax, but without automatic casting to fp32 when the input is of
+        type bfloat16
+    """
+    d = t.dtype
+    if(d == torch.bfloat16 and not deepspeed.utils.is_initialized()):
+        with torch.cuda.amp.autocast(enabled=False):
+            s = torch.nn.functional.softmax(t, dim=dim)
+    elif(d == torch.bfloat16):
+        raise NotImplementedError
+
+    return s
+
+
 def _attention(query, key, value, biases):
+    # [*, H, Q, C_hidden]
+    query = permute_final_dims(query, (1, 0, 2))
+
+    # [*, H, C_hidden, K]
+    key = permute_final_dims(key, (1, 2, 0))
+
+    # [*, H, V, C_hidden]
+    value = permute_final_dims(value, (1, 0, 2))
+
+    # [*, H, Q, K]
     a = torch.matmul(query, key)
 
     for b in biases:
         a += b
 
-    a = torch.nn.functional.softmax(a, dim=-1)
+    a = softmax(a, dim=-1)
 
     # [*, H, Q, C_hidden]
-    o = torch.matmul(a, value)
+    a = torch.matmul(a, value)
 
     # [*, Q, H, C_hidden]
-    o = o.transpose(-2, -3)
+    a = a.transpose(-2, -3)
 
-    return o
+    return a
 
 
 @torch.jit.ignore
-def _attention_chunk_and_checkpoint(query, key, value, biases, chunk_size):
-    if(len(biases) > 2):
+def _attention_chunked_trainable(
+    query, key, value, biases, chunk_size, chunk_dim, checkpoint,
+):
+    if(checkpoint and len(biases) > 2):
         raise ValueError(
-            "_chunk_and_checkpoint only permits two bias terms"
+            "Checkpointed version permits only permits two bias terms"
         )
 
-    biases = biases + [None, None]
-    bias_1, bias_2 = biases[:2]
-
     def _checkpointable_attention(q, k, v, b1, b2):
-        bs = [b1, b2]
+        bs = [b for b in [b1, b2] if b is not None]
         return _attention(q, k, v, bs)
 
-    batch_dims = query.shape[:-3]
-    no_batch_dims = len(query.shape[:-3])
-
-    # q, k, and v are assumed to have no singleton dimensions
-    flat_q = query.reshape(-1, *query.shape[-3:])
-    flat_k = key.reshape(-1, *key.shape[-3:])
-    flat_v = value.reshape(-1, *value.shape[-3:])
-
     o_chunks = []
     checkpoint_fn = get_checkpoint_fn()
-    count = flat_q.shape[0]
+    count = query.shape[chunk_dim]
     for start in range(0, count, chunk_size):
         end = start + chunk_size
-        q_chunk = flat_q[start: end, ...]
-        k_chunk = flat_k[start: end, ...]
-        v_chunk = flat_v[start: end, ...]
-        bias_1_chunk = _chunk_slice(bias_1, start, end, no_batch_dims)
-        bias_2_chunk = _chunk_slice(bias_2, start, end, no_batch_dims)
-
-        o_chunk = checkpoint_fn(_checkpointable_attention,
-            q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
-        )
-
-        o_chunks.append(o_chunk)
+        idx = [slice(None)] * len(query.shape)
+        idx[chunk_dim] = slice(start, end)
+        idx_tup = tuple(idx)
+
+        q_chunk = query[idx_tup]
+        k_chunk = key[idx_tup]
+        v_chunk = value[idx_tup]
+
+        def _slice_bias(b):
+            idx[chunk_dim] = (
+                slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)
+            )
+            return b[tuple(idx)]
+
+        if(checkpoint):
+            bias_1_chunk, bias_2_chunk = [
+                _slice_bias(b) if b is not None else None
+                for b in (biases + [None, None])[:2]
+            ]
+
+            o_chunk = checkpoint_fn(_checkpointable_attention,
+                q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
+            )
+        else:
+            bias_chunks = [
+                _slice_bias(b) for b in biases
+            ]
+
+            o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
+
+        o_chunks.append(o_chunk)
 
-    o_flat = torch.cat(o_chunks, dim=0)
-
-    return o_flat.reshape(batch_dims + o_flat.shape[1:])
+    o = torch.cat(o_chunks, dim=chunk_dim)
+    return o
 
 
 class Attention(nn.Module):
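The chunked attention above can be summarized by the following stripped-down sketch (bias-free attention, hypothetical names; not the committed code): the query/key/value tensors are sliced along one batch-like dimension, each slice is processed independently and optionally under gradient checkpointing, and the outputs are concatenated, so peak memory is bounded by the chunk size.

    # Illustrative sketch only; not part of this commit.
    import torch
    from torch.utils.checkpoint import checkpoint

    def _attn(q, k, v):
        # Plain scaled dot-product attention over the last two dims.
        a = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
        return a @ v

    def chunked_attn(q, k, v, chunk_size, chunk_dim=0, use_checkpoint=False):
        out = []
        for start in range(0, q.shape[chunk_dim], chunk_size):
            idx = [slice(None)] * q.dim()
            idx[chunk_dim] = slice(start, start + chunk_size)
            idx = tuple(idx)
            if use_checkpoint and torch.is_grad_enabled():
                # Recompute this chunk's attention during backward.
                out.append(checkpoint(_attn, q[idx], k[idx], v[idx]))
            else:
                out.append(_attn(q[idx], k[idx], v[idx]))
        return torch.cat(out, dim=chunk_dim)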
@@ -289,16 +351,50 @@ class Attention(nn.Module):
         self.sigmoid = nn.Sigmoid()
 
+    def _prep_qkv(self,
+        q_x: torch.Tensor,
+        kv_x: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # [*, Q/K/V, H * C_hidden]
+        q = self.linear_q(q_x)
+        k = self.linear_k(kv_x)
+        v = self.linear_v(kv_x)
+
+        # [*, Q/K, H, C_hidden]
+        q = q.view(q.shape[:-1] + (self.no_heads, -1))
+        k = k.view(k.shape[:-1] + (self.no_heads, -1))
+        v = v.view(v.shape[:-1] + (self.no_heads, -1))
+
+        q /= math.sqrt(self.c_hidden)
+
+        return q, k, v
+
+    def _wrap_up(self,
+        o: torch.Tensor,
+        q_x: torch.Tensor
+    ) -> torch.Tensor:
+        if(self.linear_g is not None):
+            g = self.sigmoid(self.linear_g(q_x))
+
+            # [*, Q, H, C_hidden]
+            g = g.view(g.shape[:-1] + (self.no_heads, -1))
+            o = o * g
+
+        # [*, Q, H * C_hidden]
+        o = flatten_final_dims(o, 2)
+
+        # [*, Q, C_q]
+        o = self.linear_o(o)
+
+        return o
+
     def forward(
         self,
         q_x: torch.Tensor,
-        k_x: torch.Tensor,
-        v_x: torch.Tensor,
+        kv_x: torch.Tensor,
         biases: Optional[List[torch.Tensor]] = None,
         use_lma: bool = False,
         q_chunk_size: Optional[int] = None,
         kv_chunk_size: Optional[int] = None,
-        _chunk_and_checkpoint: Optional[int] = None
     ) -> torch.Tensor:
         """
         Args:
@@ -318,59 +414,20 @@ class Attention(nn.Module):
                 "If use_lma is specified, q_chunk_size and kv_chunk_size must "
                 "be provided"
             )
 
-        if(use_lma and _chunk_and_checkpoint is not None):
-            raise ValueError(
-                "use_lma and _chunk_and_checkpoint are mutually exclusive"
-            )
-
-        # [*, Q/K/V, H * C_hidden]
-        q = self.linear_q(q_x)
-        k = self.linear_k(k_x)
-        v = self.linear_v(v_x)
-
-        # [*, Q/K, H, C_hidden]
-        q = q.view(q.shape[:-1] + (self.no_heads, -1))
-        k = k.view(k.shape[:-1] + (self.no_heads, -1))
-        v = v.view(v.shape[:-1] + (self.no_heads, -1))
-
-        q = q / math.sqrt(self.c_hidden)
+        q, k, v = self._prep_qkv(q_x, kv_x)
 
         if(use_lma):
             biases = [
                 b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],))
                 for b in biases
             ]
 
             o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
         else:
-            # [*, H, Q, C_hidden]
-            q = permute_final_dims(q, (1, 0, 2))
-
-            # [*, H, C_hidden, K]
-            k = permute_final_dims(k, (1, 2, 0))
-
-            # [*, H, V, C_hidden]
-            v = permute_final_dims(v, (1, 0, 2))
-
-            if(_chunk_and_checkpoint):
-                # REMEMBER THAT THE K, Q, V COMPUTATION AND GATING ARE *NOT*
-                # CHECKPOINTED HERE
-                o = _attention_chunk_and_checkpoint(
-                    q, k, v, biases, _chunk_and_checkpoint
-                )
-            else:
-                o = _attention(q, k, v, biases)
-
-        if(self.linear_g is not None):
-            g = self.sigmoid(self.linear_g(q_x))
-
-            # [*, Q, H, C_hidden]
-            g = g.view(g.shape[:-1] + (self.no_heads, -1))
-            o = o * g
+            o = _attention(q, k, v, biases)
 
-        # [*, Q, H * C_hidden]
-        o = flatten_final_dims(o, 2)
-
-        # [*, Q, C_q]
-        o = self.linear_o(o)
+        o = self._wrap_up(o, q_x)
 
         return o
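The forward refactor above separates q/k/v preparation, the attention product, and the gated output projection, so the memory-heavy middle step can be chunked or checkpointed on its own. A toy single-head sketch of that structure (hypothetical class, not the committed module):

    # Illustrative sketch only; not part of this commit.
    import torch
    from torch.utils.checkpoint import checkpoint

    class TinyAttention(torch.nn.Module):
        def __init__(self, c):
            super().__init__()
            self.linear_q = torch.nn.Linear(c, c)
            self.linear_k = torch.nn.Linear(c, c)
            self.linear_v = torch.nn.Linear(c, c)
            self.linear_o = torch.nn.Linear(c, c)

        def _prep_qkv(self, x):
            # Project the input into queries, keys, and values.
            return self.linear_q(x), self.linear_k(x), self.linear_v(x)

        def _core(self, q, k, v):
            # Scaled dot-product attention; the expensive, checkpointable part.
            a = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
            return a @ v

        def _wrap_up(self, o):
            # Output projection (gating omitted in this toy version).
            return self.linear_o(o)

        def forward(self, x, ckpt=False):
            q, k, v = self._prep_qkv(x)
            if ckpt and torch.is_grad_enabled():
                o = checkpoint(self._core, q, k, v)
            else:
                o = self._core(q, k, v)
            return self._wrap_up(o)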
@@ -399,7 +456,6 @@ class GlobalAttention(nn.Module):
         self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
 
         self.sigmoid = nn.Sigmoid()
-        self.softmax = nn.Softmax(dim=-1)
 
     def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
         # [*, N_res, C_in]
@@ -425,7 +481,7 @@ class GlobalAttention(nn.Module):
         )
         bias = (self.inf * (mask - 1))[..., :, None, :]
         a += bias
-        a = self.softmax(a)
+        a = softmax(a)
 
         # [*, N_res, H, C_hidden]
         o = torch.matmul(