Unverified Commit 6fbc402e authored by LuGY, committed by GitHub

refactor chunk (#117)

parent 3b096d67
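The change applies one pattern throughout: instead of falling back to chunk_size = para_dim when CHUNK_SIZE is unset and always looping, each module now takes a dedicated unchunked branch when CHUNK_SIZE is None and only loops over slices otherwise. Below is a minimal, self-contained sketch of that pattern, not the repo's code: the class and layer names are hypothetical stand-ins, while CHUNK_SIZE, DEBUG, and the factor-of-48 chunk sizing mirror the hunks that follow.

import torch
import torch.nn as nn
import torch.nn.functional as F

CHUNK_SIZE = None   # None -> run the whole tensor in one pass
DEBUG = False

class TransitionSketch(nn.Module):
    def __init__(self, d, n=4):
        super().__init__()
        self.norm = nn.LayerNorm(d)
        self.linear1 = nn.Linear(d, n * d)
        self.linear2 = nn.Linear(n * d, d)

    def forward(self, src):
        if CHUNK_SIZE is None:
            # Unchunked branch: a single pass over the full tensor.
            out = self.norm(src)
            return self.linear2(F.relu(self.linear1(out)))
        # Chunked branch: slice along dim 1 to bound peak activation memory.
        chunk_size = CHUNK_SIZE * 48
        para_dim = src.shape[1]
        out = torch.empty_like(src)
        for ax in range(0, para_dim, chunk_size):
            if DEBUG and ax > 10:
                break
            part = self.norm(src[:, ax:ax + chunk_size])
            out[:, ax:ax + chunk_size] = self.linear2(F.relu(self.linear1(part)))
        return out

With CHUNK_SIZE set to None the module behaves like an ordinary transition block; with an integer it processes 48 * CHUNK_SIZE rows of dim 1 at a time, trading a longer loop for lower peak activation memory.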
@@ -91,13 +91,12 @@ class ChunkTransition(nn.Module):
        self.linear2 = Linear(n * d, d, initializer='zeros')

    def forward(self, src):
        para_dim = src.shape[1]
        chunk_size = 48
        if CHUNK_SIZE == None:
            chunk_size = para_dim
            out = self.norm(src)
            out = self.linear2(F.relu(self.linear1(out)))
        else:
            chunk_size = CHUNK_SIZE * 48
            para_dim = src.shape[1]
            out = torch.empty_like(src)
            for ax in range(0, para_dim, chunk_size):
                if DEBUG and ax > 10:
@@ -155,11 +154,14 @@ class OutProductMean(nn.Module):
        right_act_all = gather_async_opp(right_act_all, work, dim=2)
        right_act_all = M_mask * right_act_all
        if CHUNK_SIZE == None:
            out = torch.einsum('bsid, bsje->bijde', left_act, right_act_all)
            out = rearrange(out, 'b i j d e -> b i j (d e)')
            out = self.o_linear(out)
            Z = out / norm
        else:
            para_dim = left_act.shape[2]
            chunk_size = CHUNK_SIZE
            if CHUNK_SIZE == None:
                chunk_size = para_dim
            for ax in range(0, para_dim, chunk_size):
                left_act_part = left_act[:, :, ax:ax + chunk_size, :]
                O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
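The viewer truncates this hunk; the chunked branch presumably repeats the unchunked einsum on column slices of left_act and stitches the per-chunk results back together, dividing by the same norm. A rough standalone sketch under that assumption (o_linear, norm, and the toy shapes are stand-ins, not the module's real ones):

import torch
from torch import nn
from einops import rearrange

CHUNK_SIZE = 4
b, s, n, d, c_z = 1, 8, 16, 8, 32
left_act = torch.randn(b, s, n, d)
right_act_all = torch.randn(b, s, n, d)
o_linear = nn.Linear(d * d, c_z)
# In the real module, norm counts valid MSA rows per (i, j) pair from the mask;
# an all-ones mask would make every entry equal to s.
norm = torch.full((b, n, n, 1), float(s))

chunks = []
for ax in range(0, left_act.shape[2], CHUNK_SIZE):
    left_part = left_act[:, :, ax:ax + CHUNK_SIZE, :]
    O = torch.einsum('bsid,bsje->bijde', left_part, right_act_all)
    O = rearrange(O, 'b i j d e -> b i j (d e)')
    chunks.append(o_linear(O) / norm[:, ax:ax + CHUNK_SIZE])
Z = torch.cat(chunks, dim=1)   # [b, n, n, c_z], matching the unchunked Z = out / norm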
@@ -291,11 +293,6 @@ class SelfAttention(nn.Module):
        :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
        """
        para_dim = in_data.shape[1]
        chunk_size = CHUNK_SIZE
        if CHUNK_SIZE == None:
            chunk_size = para_dim
        if nonbatched_bias is not None:
            if nonbatched_bias[-1] == -1:
                bias = nonbatched_bias[0]
@@ -304,6 +301,31 @@ class SelfAttention(nn.Module):
                bias = gather_async_opp(*nonbatched_bias, dim=1)
                bias = rearrange(bias, 'b q k h -> b h q k')
        if CHUNK_SIZE == None:
            qkv = self.to_qkv(in_data).chunk(3, dim=-1)
            q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
            q = q * self.scaling
            logits = torch.matmul(q, k.transpose(-1, -2))
            if nonbatched_bias is not None:
                weights = fused_softmax(logits, mask, bias.unsqueeze(1))
            else:
                weights = fused_softmax(logits, mask)
            weighted_avg = torch.matmul(weights, v)
            weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
            if self.gating:
                gate_values = self.gating_linear(in_data)
                weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
            output = self.o_linear(weighted_avg)
        else:
            para_dim = in_data.shape[1]
            chunk_size = CHUNK_SIZE
            output = []
            for ax in range(0, para_dim, chunk_size):
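The body of the chunked attention loop is cut off above; the likely shape is that each iteration runs the same QKV attention on a slice of in_data (and the matching slice of mask) and the per-chunk outputs are concatenated at the end. A standalone sketch under those assumptions, with a plain masked softmax standing in for fused_softmax and the gating and nonbatched-bias paths left out:

import torch
from torch import nn
from einops import rearrange

CHUNK_SIZE = 2
n_head, d_head = 4, 16
d_model = n_head * d_head
to_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
o_linear = nn.Linear(d_model, d_model)

in_data = torch.randn(1, 8, 32, d_model)    # [b1, b2, n, d]; chunking runs over dim 1
mask = torch.ones(1, 8, 1, 1, 32)           # 1 = keep, broadcast over heads and queries

output = []
for ax in range(0, in_data.shape[1], CHUNK_SIZE):
    part = in_data[:, ax:ax + CHUNK_SIZE]
    qkv = to_qkv(part).chunk(3, dim=-1)
    q, k, v = (rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=n_head) for t in qkv)
    q = q * d_head ** -0.5
    logits = torch.matmul(q, k.transpose(-1, -2))
    logits = logits + (1.0 - mask[:, ax:ax + CHUNK_SIZE]) * -1e9   # masked softmax
    weights = torch.softmax(logits, dim=-1)
    chunk_out = torch.matmul(weights, v)
    chunk_out = rearrange(chunk_out, 'b1 b2 h n d -> b1 b2 n (h d)')
    output.append(o_linear(chunk_out))
output = torch.cat(output, dim=1)            # [1, 8, 32, d_model]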
@@ -981,16 +1003,16 @@ class ChunkMSAColumnGlobalAttention(nn.Module):
        )

    def forward(self, M_raw, M_mask):
        para_dim = M_raw.shape[2]
        if CHUNK_SIZE is None:
            chunk_size = para_dim
            m = self.layernormM(M_raw.transpose(-2, -3))
            m = self.global_attention(m, M_mask.transpose(-1, -2))
            m = m.transpose(-2, -3)
            M_raw = M_raw + m
        else:
            chunk_size = CHUNK_SIZE
            para_dim = M_raw.shape[2]
            for i in range(0, para_dim, chunk_size):
                if DEBUG and i > 10:
                    break
                m = M_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
                m = self.layernormM(m)
                m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
@@ -1109,12 +1131,12 @@ class RecyclingEmbedder(nn.Module):
        # [*, N, N, no_bins]
        d = ((d > squared_bins) * (d < upper)).type(x.dtype)
        # [*, N, N, C_z]
        para_dim = d.shape[1]
        if CHUNK_SIZE == None:
            chunk_size = para_dim
            d = self.linear(d)
            z = d + self.layer_norm_z(z)
        else:
            chunk_size = CHUNK_SIZE * 48
            para_dim = d.shape[1]
            for i in range(0, para_dim, chunk_size):
                di = self.linear(d[i:i + chunk_size, :, :])
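Here the one-hot bin tensor d goes through the pair linear in blocks of 48 * CHUNK_SIZE rows instead of all at once. A small self-contained sketch of that idea (bin count, channel sizes, and the in-place update into z are illustrative, not the module's actual values):

import torch
from torch import nn

CHUNK_SIZE = 1
no_bins, c_z, n_res = 15, 32, 96
linear = nn.Linear(no_bins, c_z)
layer_norm_z = nn.LayerNorm(c_z)

d = torch.rand(n_res, n_res, no_bins)    # one-hot distance bins, [N, N, no_bins]
z = torch.randn(n_res, n_res, c_z)       # previous pair representation

chunk_size = CHUNK_SIZE * 48
for i in range(0, d.shape[1], chunk_size):
    di = linear(d[i:i + chunk_size, :, :])
    # add the normalized previous pair embedding block by block
    z[i:i + chunk_size] = di + layer_norm_z(z[i:i + chunk_size])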
@@ -1152,10 +1174,33 @@ class GlobalAttention(nn.Module):
    def forward(self, m, mask):
        if CHUNK_SIZE == None:
            q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
                torch.sum(mask, dim=-1)[..., None] + self.eps
            )
            q = q * self.scaling
            q = self.to_q(q)
            q = q.view(q.shape[:-1] + (self.n_head, -1))
            k, v = self.to_kv(m).chunk(2, dim=-1)
            logits = torch.matmul(q, k.transpose(-1, -2))
            weights = fused_softmax(logits, mask)
            weighted_avg = torch.matmul(weights, v)
            weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
            gate_values = self.gating_linear(m)
            weighted_avg = bias_sigmod_ele(
                gate_values, self.gating_bias, weighted_avg.unsqueeze(-2)
            )
            m = self.o_linear(weighted_avg)
        else:
            para_dim = m.shape[1]
            chunk_size = CHUNK_SIZE
            if CHUNK_SIZE == None:
                chunk_size = para_dim
            output = []
            for ax in range(0, para_dim, chunk_size):
@@ -1190,7 +1235,6 @@ class GlobalAttention(nn.Module):
        return m


class InputEmbedder(nn.Module):
    """
    Embeds a subset of the input features.
@@ -387,9 +387,11 @@ class TemplatePairStack(nn.Module):
            args=(t,),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )
        if not self.training:
            for i in range(0, t.shape[0]):
                t[i] = self.layer_norm(t[i])
        else:
            t = self.layer_norm(t)
        return t

    def inplace(
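The TemplatePairStack change applies the final layer norm one template slice at a time during inference, presumably to cap peak memory, while training keeps the single full-tensor call. A tiny standalone illustration of that branch with made-up shapes:

import torch
from torch import nn

layer_norm = nn.LayerNorm(64)
t = torch.randn(4, 32, 32, 64)    # e.g. [n_templates, N, N, c_t]
training = False

if not training:
    # inference: normalize slice by slice to keep peak memory low
    for i in range(0, t.shape[0]):
        t[i] = layer_norm(t[i])
else:
    # training: one call over the whole tensor keeps autograd simple
    t = layer_norm(t)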