Unverified Commit 6fbc402e authored by LuGY and committed by GitHub

refactor chunk (#117)

parent 3b096d67
@@ -91,20 +91,19 @@ class ChunkTransition(nn.Module):
         self.linear2 = Linear(n * d, d, initializer='zeros')
 
     def forward(self, src):
-        para_dim = src.shape[1]
-        chunk_size = 48
         if CHUNK_SIZE == None:
-            chunk_size = para_dim
+            out = self.norm(src)
+            out = self.linear2(F.relu(self.linear1(out)))
         else:
             chunk_size = CHUNK_SIZE * 48
-        out = torch.empty_like(src)
-        for ax in range(0, para_dim, chunk_size):
-            if DEBUG and ax > 10:
-                break
-            x = self.norm(src[:, ax:ax + chunk_size, :, :])
-            x = self.linear2(F.relu(self.linear1(x)))
-            out[:, ax:ax + chunk_size, :, :] = x
+            para_dim = src.shape[1]
+            out = torch.empty_like(src)
+            for ax in range(0, para_dim, chunk_size):
+                if DEBUG and ax > 10:
+                    break
+                x = self.norm(src[:, ax:ax + chunk_size, :, :])
+                x = self.linear2(F.relu(self.linear1(x)))
+                out[:, ax:ax + chunk_size, :, :] = x
         out.add_(src)
         return out
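The hunk above is the core of the refactor: ChunkTransition.forward now takes an unchunked fast path when CHUNK_SIZE is None and otherwise processes fixed-size slices of the chunked axis to bound peak activation memory. A minimal sketch of the same dispatch pattern, assuming plain torch.nn layers and a module-level CHUNK_SIZE flag rather than the repository's custom Linear and fused kernels:

import torch
import torch.nn as nn
import torch.nn.functional as F

CHUNK_SIZE = None  # hypothetical stand-in for the module-level flag used in the diff

class ChunkedTransitionSketch(nn.Module):
    def __init__(self, d, n=4):
        super().__init__()
        self.norm = nn.LayerNorm(d)
        self.linear1 = nn.Linear(d, n * d)
        self.linear2 = nn.Linear(n * d, d)

    def forward(self, src):
        if CHUNK_SIZE is None:
            # Unchunked fast path: one pass over the full tensor.
            out = self.linear2(F.relu(self.linear1(self.norm(src))))
        else:
            # Chunked path: slice dim 1 so peak memory scales with the chunk,
            # not the full parallel dimension.
            chunk_size = CHUNK_SIZE * 48
            out = torch.empty_like(src)
            for ax in range(0, src.shape[1], chunk_size):
                x = self.norm(src[:, ax:ax + chunk_size])
                out[:, ax:ax + chunk_size] = self.linear2(F.relu(self.linear1(x)))
        return out + src

As in the diff, the residual add happens once at the end, so both branches return the same result up to numerical noise.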
@@ -155,18 +154,21 @@ class OutProductMean(nn.Module):
         right_act_all = gather_async_opp(right_act_all, work, dim=2)
         right_act_all = M_mask * right_act_all
 
-        para_dim = left_act.shape[2]
-        chunk_size = CHUNK_SIZE
         if CHUNK_SIZE == None:
-            chunk_size = para_dim
-
-        for ax in range(0, para_dim, chunk_size):
-            left_act_part = left_act[:, :, ax:ax + chunk_size, :]
-            O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
-            O = rearrange(O, 'b i j d e -> b i j (d e)')
-            O = self.o_linear(O)
-            norm0 = norm[:, ax:ax + chunk_size, :, :]
-            Z[:, ax:ax + chunk_size, :, :] = O / norm0
+            out = torch.einsum('bsid, bsje->bijde', left_act, right_act_all)
+            out = rearrange(out, 'b i j d e -> b i j (d e)')
+            out = self.o_linear(out)
+            Z = out / norm
+        else:
+            para_dim = left_act.shape[2]
+            chunk_size = CHUNK_SIZE
+            for ax in range(0, para_dim, chunk_size):
+                left_act_part = left_act[:, :, ax:ax + chunk_size, :]
+                O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
+                O = rearrange(O, 'b i j d e -> b i j (d e)')
+                O = self.o_linear(O)
+                norm0 = norm[:, ax:ax + chunk_size, :, :]
+                Z[:, ax:ax + chunk_size, :, :] = O / norm0
 
         return Z + Z_raw
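In OutProductMean the unchunked branch now computes the outer-product mean with a single einsum, while the chunked branch keeps the per-slice loop. A quick, self-contained check (illustrative shapes, not the repository's) that slicing the left activation along its residue axis and concatenating reproduces the full einsum:

import torch

b, s, n_res, d, e = 1, 4, 6, 3, 3
left_act = torch.randn(b, s, n_res, d)
right_act_all = torch.randn(b, s, n_res, e)

# Unchunked reference: sum over the sequence axis s for every residue pair (i, j).
full = torch.einsum('bsid,bsje->bijde', left_act, right_act_all)

# Chunked version: each slice of left_act yields a block of output rows.
chunk_size = 2
parts = []
for ax in range(0, n_res, chunk_size):
    left_part = left_act[:, :, ax:ax + chunk_size, :]
    parts.append(torch.einsum('bsid,bsje->bijde', left_part, right_act_all))

assert torch.allclose(full, torch.cat(parts, dim=1), atol=1e-6)

Since the reduction runs only over s, each output row block depends on one slice of left_act and all of right_act_all, which is what makes the per-slice division by norm in the chunked branch safe.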
@@ -291,11 +293,6 @@ class SelfAttention(nn.Module):
         :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
         """
 
-        para_dim = in_data.shape[1]
-        chunk_size = CHUNK_SIZE
-        if CHUNK_SIZE == None:
-            chunk_size = para_dim
-
         if nonbatched_bias is not None:
             if nonbatched_bias[-1] == -1:
                 bias = nonbatched_bias[0]
@@ -303,14 +300,9 @@ class SelfAttention(nn.Module):
                 # logits += nonbatched_bias.unsqueeze(1)
                 bias = gather_async_opp(*nonbatched_bias, dim=1)
                 bias = rearrange(bias, 'b q k h -> b h q k')
 
-        output = []
-        for ax in range(0, para_dim, chunk_size):
-            in_data_part = in_data[:, ax:ax + chunk_size, :, :]
-            mask_part = mask[:, ax:ax + chunk_size, :]
-
-            qkv = self.to_qkv(in_data_part).chunk(3, dim=-1)
+        if CHUNK_SIZE == None:
+            qkv = self.to_qkv(in_data).chunk(3, dim=-1)
             q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
 
             q = q * self.scaling
@@ -318,25 +310,55 @@ class SelfAttention(nn.Module):
             logits = torch.matmul(q, k.transpose(-1, -2))
 
             if nonbatched_bias is not None:
-                # logits += bias.unsqueeze(1)
-                # logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
-                # weights = torch.nn.functional.softmax(logits, -1)
-                weights = fused_softmax(logits, mask_part, bias.unsqueeze(1))
+                weights = fused_softmax(logits, mask, bias.unsqueeze(1))
             else:
-                # logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
-                # weights = torch.nn.functional.softmax(logits, -1)
-                weights = fused_softmax(logits, mask_part)
+                weights = fused_softmax(logits, mask)
 
             weighted_avg = torch.matmul(weights, v)
             weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
 
             if self.gating:
-                gate_values = self.gating_linear(in_data_part)
+                gate_values = self.gating_linear(in_data)
                 weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
 
-            output.append(self.o_linear(weighted_avg))
-
-        output = torch.cat(output, dim=1)
+            output = self.o_linear(weighted_avg)
+        else:
+            para_dim = in_data.shape[1]
+            chunk_size = CHUNK_SIZE
+
+            output = []
+            for ax in range(0, para_dim, chunk_size):
+                in_data_part = in_data[:, ax:ax + chunk_size, :, :]
+                mask_part = mask[:, ax:ax + chunk_size, :]
+
+                qkv = self.to_qkv(in_data_part).chunk(3, dim=-1)
+                q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
+
+                q = q * self.scaling
+
+                logits = torch.matmul(q, k.transpose(-1, -2))
+
+                if nonbatched_bias is not None:
+                    # logits += bias.unsqueeze(1)
+                    # logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
+                    # weights = torch.nn.functional.softmax(logits, -1)
+                    weights = fused_softmax(logits, mask_part, bias.unsqueeze(1))
+                else:
+                    # logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
+                    # weights = torch.nn.functional.softmax(logits, -1)
+                    weights = fused_softmax(logits, mask_part)
+
+                weighted_avg = torch.matmul(weights, v)
+                weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
+
+                if self.gating:
+                    gate_values = self.gating_linear(in_data_part)
+                    weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
+
+                output.append(self.o_linear(weighted_avg))
+
+            output = torch.cat(output, dim=1)
 
         return output
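The SelfAttention hunks apply the same split: the axis being chunked (dim 1 of in_data) is batch-like, so each slice attends only within itself. A small standalone check of that equivalence, assuming plain scaled dot-product attention in place of the fused softmax kernel:

import torch
import torch.nn.functional as F

b1, b2, h, n, d = 1, 8, 4, 16, 8
q = torch.randn(b1, b2, h, n, d)
k = torch.randn(b1, b2, h, n, d)
v = torch.randn(b1, b2, h, n, d)

def attn(q, k, v):
    # Plain scaled dot-product attention over the last two axes.
    logits = torch.matmul(q, k.transpose(-1, -2)) / d ** 0.5
    return torch.matmul(F.softmax(logits, dim=-1), v)

full = attn(q, k, v)

# Run attention on slices of the batch-like dim 1 and stitch them back together.
chunk_size = 3
parts = [attn(q[:, ax:ax + chunk_size], k[:, ax:ax + chunk_size], v[:, ax:ax + chunk_size])
         for ax in range(0, b2, chunk_size)]

assert torch.allclose(full, torch.cat(parts, dim=1), atol=1e-5)

The mask is sliced per chunk and the nonbatched bias is broadcast over that dimension, so cutting dim 1 does not change the attention pattern inside any slice.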
@@ -981,22 +1003,22 @@ class ChunkMSAColumnGlobalAttention(nn.Module):
         )
 
     def forward(self, M_raw, M_mask):
-        para_dim = M_raw.shape[2]
         if CHUNK_SIZE is None:
-            chunk_size = para_dim
+            m = self.layernormM(M_raw.transpose(-2, -3))
+            m = self.global_attention(m, M_mask.transpose(-1, -2))
+            m = m.transpose(-2, -3)
+            M_raw = M_raw + m
         else:
             chunk_size = CHUNK_SIZE
-
-        for i in range(0, para_dim, chunk_size):
-            if DEBUG and i > 10:
-                break
-            m = M_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
-            m = self.layernormM(m)
-            m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
-            m = self.global_attention(m, m_mask)
-            m = m.transpose(-2, -3)
-            M_raw[:, :, i:i + chunk_size, :] += m
+            para_dim = M_raw.shape[2]
+            for i in range(0, para_dim, chunk_size):
+                m = M_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
+                m = self.layernormM(m)
+                m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
+                m = self.global_attention(m, m_mask)
+                m = m.transpose(-2, -3)
+                M_raw[:, :, i:i + chunk_size, :] += m
 
         return M_raw
@@ -1109,16 +1131,16 @@ class RecyclingEmbedder(nn.Module):
         # [*, N, N, no_bins]
         d = ((d > squared_bins) * (d < upper)).type(x.dtype)
 
         # [*, N, N, C_z]
-        para_dim = d.shape[1]
         if CHUNK_SIZE == None:
-            chunk_size = para_dim
+            d = self.linear(d)
+            z = d + self.layer_norm_z(z)
         else:
             chunk_size = CHUNK_SIZE * 48
-
-        for i in range(0, para_dim, chunk_size):
-            di = self.linear(d[i:i + chunk_size, :, :])
-            z[i:i + chunk_size, :, :] = di + self.layer_norm_z(z[i:i + chunk_size, :, :])
+            para_dim = d.shape[1]
+            for i in range(0, para_dim, chunk_size):
+                di = self.linear(d[i:i + chunk_size, :, :])
+                z[i:i + chunk_size, :, :] = di + self.layer_norm_z(z[i:i + chunk_size, :, :])
 
         return m_update, z
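In RecyclingEmbedder the two branches agree because Linear and LayerNorm act only on the last dimension, so applying them slice by slice and writing the slices back matches the unchunked computation. A quick check with illustrative sizes (not the repository's configuration):

import torch
import torch.nn as nn

n, no_bins, c_z = 12, 15, 8
linear = nn.Linear(no_bins, c_z)
layer_norm_z = nn.LayerNorm(c_z)

d = torch.randn(n, n, no_bins)
z = torch.randn(n, n, c_z)

with torch.no_grad():
    # Unchunked branch: one pass over the whole pair representation.
    full = linear(d) + layer_norm_z(z)

    # Chunked branch: process row slices and write them into a preallocated buffer.
    chunked = torch.empty_like(full)
    chunk_size = 5
    for i in range(0, n, chunk_size):
        di = linear(d[i:i + chunk_size])
        chunked[i:i + chunk_size] = di + layer_norm_z(z[i:i + chunk_size])

    assert torch.allclose(full, chunked, atol=1e-6)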
@@ -1152,44 +1174,66 @@ class GlobalAttention(nn.Module):
     def forward(self, m, mask):
-        para_dim = m.shape[1]
-        chunk_size = CHUNK_SIZE
         if CHUNK_SIZE == None:
-            chunk_size = para_dim
-
-        output = []
-        for ax in range(0, para_dim, chunk_size):
-            m_part = m[:, ax : ax + chunk_size, :, :]
-            mask_part = mask[:, ax : ax + chunk_size, :]
-            q = torch.sum(m_part * mask_part.unsqueeze(-1), dim=-2) / (
-                torch.sum(mask_part, dim=-1)[..., None] + self.eps
+            q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
+                torch.sum(mask, dim=-1)[..., None] + self.eps
             )
             q = q * self.scaling
             q = self.to_q(q)
             q = q.view(q.shape[:-1] + (self.n_head, -1))
-            k, v = self.to_kv(m_part).chunk(2, dim=-1)
+            k, v = self.to_kv(m).chunk(2, dim=-1)
             logits = torch.matmul(q, k.transpose(-1, -2))
-            weights = fused_softmax(logits, mask_part)
+            weights = fused_softmax(logits, mask)
             weighted_avg = torch.matmul(weights, v)
             weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
-            gate_values = self.gating_linear(m_part)
+            gate_values = self.gating_linear(m)
             weighted_avg = bias_sigmod_ele(
                 gate_values, self.gating_bias, weighted_avg.unsqueeze(-2)
            )
-            output.append(self.o_linear(weighted_avg))
-
-        m = torch.cat(output, dim=1)
+            m = self.o_linear(weighted_avg)
+        else:
+            para_dim = m.shape[1]
+            chunk_size = CHUNK_SIZE
+
+            output = []
+            for ax in range(0, para_dim, chunk_size):
+                m_part = m[:, ax : ax + chunk_size, :, :]
+                mask_part = mask[:, ax : ax + chunk_size, :]
+                q = torch.sum(m_part * mask_part.unsqueeze(-1), dim=-2) / (
+                    torch.sum(mask_part, dim=-1)[..., None] + self.eps
+                )
+                q = q * self.scaling
+                q = self.to_q(q)
+                q = q.view(q.shape[:-1] + (self.n_head, -1))
+                k, v = self.to_kv(m_part).chunk(2, dim=-1)
+                logits = torch.matmul(q, k.transpose(-1, -2))
+                weights = fused_softmax(logits, mask_part)
+                weighted_avg = torch.matmul(weights, v)
+                weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
+                gate_values = self.gating_linear(m_part)
+                weighted_avg = bias_sigmod_ele(
+                    gate_values, self.gating_bias, weighted_avg.unsqueeze(-2)
+                )
+                output.append(self.o_linear(weighted_avg))
+
+            m = torch.cat(output, dim=1)
 
         return m
 
 
 class InputEmbedder(nn.Module):
     """
......
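Before the template-stack change, a hedged, self-contained sketch of the shape flow that the GlobalAttention hunk restructures: one mask-averaged query per row attends over every position. Tensor names and dimensions are illustrative, and plain softmax with an additive mask stands in for the fused kernel and learned projections:

import torch

b1, b2, n, h, dh, eps = 1, 2, 5, 4, 8, 1e-10
m = torch.randn(b1, b2, n, h * dh)
mask = (torch.rand(b1, b2, n) > 0.3).float()

# Mask-weighted mean over the sequence axis gives a single query per (b1, b2) row.
q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (torch.sum(mask, dim=-1)[..., None] + eps)
q = q.view(b1, b2, h, dh)                        # split heads: (b1, b2, h, dh)

# Stand-ins for the to_kv projection: keys and values shared across heads.
k = torch.randn(b1, b2, n, dh)
v = torch.randn(b1, b2, n, dh)

logits = torch.matmul(q, k.transpose(-1, -2))    # (b1, b2, h, n)
logits = logits + (1e9 * (mask - 1))[..., None, :]
weights = torch.softmax(logits, dim=-1)
out = torch.matmul(weights, v)                   # (b1, b2, h, dh)
out = out.reshape(b1, b2, h * dh)
print(out.shape)                                 # torch.Size([1, 2, 32])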
@@ -387,9 +387,11 @@ class TemplatePairStack(nn.Module):
             args=(t,),
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )
 
-        for i in range(0, t.shape[0]):
-            t[i] = self.layer_norm(t[i])
+        if not self.training:
+            for i in range(0, t.shape[0]):
+                t[i] = self.layer_norm(t[i])
+        else:
+            t = self.layer_norm(t)
 
         return t
 
     def inplace(
......
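For the TemplatePairStack change above: the per-template, in-place layer norm is only taken at inference, which is typically run under torch.no_grad() and caps peak memory, while training keeps the out-of-place call so autograd sees a single differentiable op. A toy sketch of that split (not the repository module):

import torch
import torch.nn as nn

class NormTailSketch(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.layer_norm = nn.LayerNorm(c)

    def forward(self, t):
        if not self.training:
            # Inference: normalise one template at a time to cap peak memory.
            for i in range(t.shape[0]):
                t[i] = self.layer_norm(t[i])
        else:
            # Training: out-of-place call keeps the autograd graph intact.
            t = self.layer_norm(t)
        return t

tail = NormTailSketch(8).eval()
with torch.no_grad():
    out = tail(torch.randn(4, 16, 16, 8))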