Commit 665e6c97 authored by shenggan, committed by Shenggan

add chunk for self_att and opm (#38)

parent 48ae1b08
@@ -96,7 +96,7 @@ c_t = mlc.FieldReference(64, field_type=int)
 c_e = mlc.FieldReference(64, field_type=int)
 c_s = mlc.FieldReference(384, field_type=int)
 blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
-chunk_size = mlc.FieldReference(4, field_type=int)
+chunk_size = mlc.FieldReference(None, field_type=int)
 aux_distogram_bins = mlc.FieldReference(64, field_type=int)
 tm_enabled = mlc.FieldReference(False, field_type=bool)
 eps = mlc.FieldReference(1e-8, field_type=float)
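The config change above flips the default from chunking in blocks of 4 to no chunking at all (None), so chunking is now opt-in. A minimal sketch of the FieldReference pattern, assuming only that ml_collections is installed; the bare "globals" tree below is illustrative, not FastFold's full config:

```python
# Minimal sketch of the FieldReference default used above (ml_collections);
# the "globals" tree here is illustrative, not FastFold's full config.
import ml_collections as mlc

chunk_size = mlc.FieldReference(None, field_type=int)  # None = chunking disabled
cfg = mlc.ConfigDict({"globals": {"chunk_size": chunk_size}})

print(cfg.globals.chunk_size)  # None -> layers fall back to the full axis length
cfg.globals.chunk_size = 4     # opt back in to blocks of 4
print(cfg.globals.chunk_size)  # 4
```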
 from .msa import MSAStack
-from .ops import OutProductMean
+from .ops import OutProductMean, set_chunk_size
 from .triangle import PairStack
 from .evoformer import Evoformer
-__all__ = ['MSAStack', 'OutProductMean', 'PairStack', 'Evoformer']
+__all__ = ['MSAStack', 'OutProductMean', 'PairStack', 'Evoformer', 'set_chunk_size']
@@ -11,6 +11,14 @@ from fastfold.model.fastnn.kernel import bias_sigmod_ele
 from fastfold.distributed import gather, scatter
 from fastfold.distributed.comm_async import gather_async, gather_async_opp
+
+CHUNK_SIZE = None
+
+
+def set_chunk_size(chunk_size):
+    global CHUNK_SIZE
+    CHUNK_SIZE = chunk_size
+
 
 class DropoutRowwise(nn.Module):
 
     def __init__(self, p):
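The new module-level switch is read by the fastnn layers at forward time rather than being threaded through their constructors. A standalone sketch of the pattern (the resolve_chunk_size helper is illustrative, not part of FastFold), showing how None falls back to the full axis length:

```python
# Standalone sketch of the global chunk-size switch above; resolve_chunk_size
# is an illustrative helper, not part of FastFold's API.
CHUNK_SIZE = None

def set_chunk_size(chunk_size):
    global CHUNK_SIZE
    CHUNK_SIZE = chunk_size

def resolve_chunk_size(para_dim):
    # None means "no chunking": process the whole axis in one pass.
    return para_dim if CHUNK_SIZE is None else CHUNK_SIZE

set_chunk_size(64)
assert resolve_chunk_size(256) == 64
set_chunk_size(None)
assert resolve_chunk_size(256) == 256
```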
@@ -81,9 +89,23 @@ class OutProductMean(nn.Module):
         right_act_all = gather_async_opp(right_act_all, work, dim=2)
         right_act_all = M_mask * right_act_all
-        O = torch.einsum('bsid,bsje->bijde', left_act, right_act_all)
-        O = rearrange(O, 'b i j d e -> b i j (d e)')
-        Z = self.o_linear(O)
+        para_dim = left_act.shape[2]
+        chunk_size = CHUNK_SIZE
+        if CHUNK_SIZE == None:
+            chunk_size = para_dim
+
+        out = []
+        for ax in range(0, para_dim, chunk_size):
+            left_act_part = left_act[:, :, ax:ax + chunk_size, :]
+            O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
+            O = rearrange(O, 'b i j d e -> b i j (d e)')
+            out.append(self.o_linear(O))
+        Z = torch.cat(out, dim=1)
 
         Z /= (1e-3 + norm)
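The loop above avoids materializing the full (b, i, j, d, e) outer-product tensor at once by slicing the i axis, projecting each block, and concatenating the results. A self-contained sketch with synthetic tensors and a plain nn.Linear (not FastFold's OutProductMean) showing that the chunked path reproduces the unchunked einsum:

```python
# Self-contained sketch: chunking the outer-product einsum over the i axis and
# concatenating the projected blocks matches the unchunked computation, while
# only a (b, chunk, j, d, e) slice is ever materialized. Shapes are synthetic.
import torch
from einops import rearrange

b, s, n, d, e, c_out, chunk = 1, 4, 8, 3, 3, 5, 3
left = torch.randn(b, s, n, d)
right = torch.randn(b, s, n, e)
proj = torch.nn.Linear(d * e, c_out)

full = proj(rearrange(torch.einsum('bsid,bsje->bijde', left, right),
                      'b i j d e -> b i j (d e)'))

parts = []
for ax in range(0, n, chunk):
    o = torch.einsum('bsid,bsje->bijde', left[:, :, ax:ax + chunk, :], right)
    parts.append(proj(rearrange(o, 'b i j d e -> b i j (d e)')))
chunked = torch.cat(parts, dim=1)

assert torch.allclose(full, chunked, atol=1e-5)
```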
@@ -157,7 +179,23 @@ class SelfAttention(nn.Module):
         :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
         """
-        qkv = self.to_qkv(in_data).chunk(3, dim=-1)
+        para_dim = in_data.shape[1]
+        chunk_size = CHUNK_SIZE
+        if CHUNK_SIZE == None:
+            chunk_size = para_dim
+
+        if nonbatched_bias is not None:
+            # logits += nonbatched_bias.unsqueeze(1)
+            bias = gather_async_opp(*nonbatched_bias, dim=1)
+            bias = rearrange(bias, 'b q k h -> b h q k')
+
+        output = []
+        for ax in range(0, para_dim, chunk_size):
+            in_data_part = in_data[:, ax:ax + chunk_size, :, :]
+            mask_part = mask[:, ax:ax + chunk_size, :]
+
+            qkv = self.to_qkv(in_data_part).chunk(3, dim=-1)
             q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
 
             q = q * self.scaling
@@ -165,10 +203,7 @@ class SelfAttention(nn.Module):
             logits = torch.matmul(q, k.transpose(-1, -2))
 
             if nonbatched_bias is not None:
-                # logits += nonbatched_bias.unsqueeze(1)
-                bias = gather_async_opp(*nonbatched_bias, dim=1)
-                bias = rearrange(bias, 'b q k h -> b h q k')
-                weights = mask_bias_softmax(logits, mask, bias.unsqueeze(1))
+                weights = mask_bias_softmax(logits, mask_part, bias.unsqueeze(1))
             else:
                 weights = mask_softmax(logits, mask)
@@ -176,8 +211,11 @@ class SelfAttention(nn.Module):
             weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
 
             if self.gating:
-                gate_values = self.gating_linear(in_data)
+                gate_values = self.gating_linear(in_data_part)
                 weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
 
-            output = self.o_linear(weighted_avg)
+            output.append(self.o_linear(weighted_avg))
+
+        output = torch.cat(output, dim=1)
 
         return output
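Chunking over in_data.shape[1] is valid here because each slice along that axis attends only within itself, and the gating is likewise element-local; only the gathered pair bias is shared across chunks, which is why it is hoisted out of the loop. A self-contained sketch with plain scaled-dot-product attention (not FastFold's SelfAttention, which adds gating, masking, and the pair bias) illustrating the equivalence:

```python
# Self-contained sketch: running attention chunk by chunk over the first
# sequence-like axis matches running it on the full tensor, because attention
# mixes information only within each slice of that axis.
import torch
import torch.nn.functional as F

b1, b2, n, d, chunk = 1, 6, 5, 8, 2
x = torch.randn(b1, b2, n, d)
wq, wk, wv = (torch.nn.Linear(d, d) for _ in range(3))

def attend(t):
    q, k, v = wq(t), wk(t), wv(t)
    w = F.softmax(q @ k.transpose(-1, -2) / d ** 0.5, dim=-1)
    return w @ v

full = attend(x)
chunked = torch.cat([attend(x[:, ax:ax + chunk]) for ax in range(0, b2, chunk)],
                    dim=1)
assert torch.allclose(full, chunked, atol=1e-5)
```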
@@ -29,6 +29,7 @@ import fastfold
 import fastfold.relax.relax as relax
 from fastfold.common import protein, residue_constants
 from fastfold.config import model_config
+from fastfold.model.fastnn import set_chunk_size
 from fastfold.data import data_pipeline, feature_pipeline, templates
 from fastfold.utils import inject_fastnn
 from fastfold.utils.import_weights import import_jax_weights_
@@ -89,6 +90,8 @@ def inference_model(rank, world_size, result_q, batch, args):
     model = model.eval()
     model = model.cuda()
 
+    set_chunk_size(model.globals.chunk_size)
+
     with torch.no_grad():
         batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
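At inference time the chunk size is read from the model's globals config and installed once before the forward pass; leaving it at None keeps the original unchunked behavior. A hedged sketch of that wiring (the 'model_1' preset name is illustrative):

```python
# Sketch of the wiring above; 'model_1' is an illustrative preset name.
from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size

config = model_config('model_1')
config.globals.chunk_size = 4              # or leave as None to disable chunking
set_chunk_size(config.globals.chunk_size)  # install before running inference
```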