"src/git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "6ba62cf25d07300d8bca2ea41d37a0823e4f96f8"
Unverified Commit 6fbc402e authored by LuGY, committed by GitHub

refactor chunk (#117)

parent 3b096d67
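This changes every chunk-aware module from "compute a chunk_size up front and fall back to chunk_size = para_dim when CHUNK_SIZE is None" to an explicit two-branch dispatch: an unchunked fast path when CHUNK_SIZE is None, and the sliced loop only when a chunk size is actually set. A minimal, hedged sketch of the new pattern, using plain torch.nn layers in place of the repository's custom Linear and fused kernels (all names below are illustrative, not the repo's):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

CHUNK_SIZE = None  # hypothetical stand-in for the module-level global in the diff; None disables chunking


class TransitionSketch(nn.Module):
    """Illustrative module mirroring the ChunkTransition dispatch: a full-tensor
    fast path when CHUNK_SIZE is None, otherwise a loop over slices of dim 1."""

    def __init__(self, d: int, n: int = 4):
        super().__init__()
        self.norm = nn.LayerNorm(d)
        self.linear1 = nn.Linear(d, n * d)
        self.linear2 = nn.Linear(n * d, d)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        if CHUNK_SIZE is None:
            # Unchunked path: one pass over the whole tensor.
            out = self.norm(src)
            return self.linear2(F.relu(self.linear1(out)))
        # Chunked path: process dim-1 slices so that only one slice's
        # n * d intermediate is alive at a time.
        chunk_size = CHUNK_SIZE * 48
        out = torch.empty_like(src)
        for ax in range(0, src.shape[1], chunk_size):
            part = self.norm(src[:, ax:ax + chunk_size])
            out[:, ax:ax + chunk_size] = self.linear2(F.relu(self.linear1(part)))
        return out
```

The chunked branch bounds peak activation memory at the cost of extra kernel launches; the unchunked branch keeps the original single-pass behaviour for small inputs.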
@@ -91,13 +91,12 @@ class ChunkTransition(nn.Module):
         self.linear2 = Linear(n * d, d, initializer='zeros')
     def forward(self, src):
-        para_dim = src.shape[1]
-        chunk_size = 48
         if CHUNK_SIZE == None:
-            chunk_size = para_dim
+            out = self.norm(src)
+            out = self.linear2(F.relu(self.linear1(out)))
         else:
             chunk_size = CHUNK_SIZE * 48
+            para_dim = src.shape[1]
             out = torch.empty_like(src)
             for ax in range(0, para_dim, chunk_size):
                 if DEBUG and ax > 10:
@@ -155,11 +154,14 @@ class OutProductMean(nn.Module):
         right_act_all = gather_async_opp(right_act_all, work, dim=2)
         right_act_all = M_mask * right_act_all
+        if CHUNK_SIZE == None:
+            out = torch.einsum('bsid, bsje->bijde', left_act, right_act_all)
+            out = rearrange(out, 'b i j d e -> b i j (d e)')
+            out = self.o_linear(out)
+            Z = out / norm
+        else:
             para_dim = left_act.shape[2]
             chunk_size = CHUNK_SIZE
-        if CHUNK_SIZE == None:
-            chunk_size = para_dim
             for ax in range(0, para_dim, chunk_size):
                 left_act_part = left_act[:, :, ax:ax + chunk_size, :]
                 O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
@@ -291,11 +293,6 @@ class SelfAttention(nn.Module):
         :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
         """
-        para_dim = in_data.shape[1]
-        chunk_size = CHUNK_SIZE
-        if CHUNK_SIZE == None:
-            chunk_size = para_dim
         if nonbatched_bias is not None:
             if nonbatched_bias[-1] == -1:
                 bias = nonbatched_bias[0]
@@ -304,6 +301,31 @@ class SelfAttention(nn.Module):
                 bias = gather_async_opp(*nonbatched_bias, dim=1)
                 bias = rearrange(bias, 'b q k h -> b h q k')
+        if CHUNK_SIZE == None:
+            qkv = self.to_qkv(in_data).chunk(3, dim=-1)
+            q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
+            q = q * self.scaling
+            logits = torch.matmul(q, k.transpose(-1, -2))
+            if nonbatched_bias is not None:
+                weights = fused_softmax(logits, mask, bias.unsqueeze(1))
+            else:
+                weights = fused_softmax(logits, mask)
+            weighted_avg = torch.matmul(weights, v)
+            weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
+            if self.gating:
+                gate_values = self.gating_linear(in_data)
+                weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
+            output = self.o_linear(weighted_avg)
+        else:
+            para_dim = in_data.shape[1]
+            chunk_size = CHUNK_SIZE
             output = []
             for ax in range(0, para_dim, chunk_size):
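The SelfAttention hunk above applies the same split: the new unchunked branch runs fused attention on the full input, while the existing loop now lives only in the else branch and slices dimension 1 of in_data. A simplified sketch of that slicing idea, leaving out the mask, bias, and gating handling of the real module (function and tensor names are assumptions):

```python
import torch

def attention_over_slices(q, k, v, chunk_size=None):
    """Illustrative only: scaled dot-product attention over tensors shaped
    [b1, b2, heads, n, d], computed either on the whole batch-like dim 1 at
    once (chunk_size is None) or one dim-1 slice at a time."""
    scale = q.shape[-1] ** -0.5

    def attend(q_part, k_part, v_part):
        logits = torch.matmul(q_part * scale, k_part.transpose(-1, -2))
        return torch.matmul(logits.softmax(dim=-1), v_part)

    if chunk_size is None:
        return attend(q, k, v)
    out = []
    for ax in range(0, q.shape[1], chunk_size):
        sl = slice(ax, ax + chunk_size)
        out.append(attend(q[:, sl], k[:, sl], v[:, sl]))
    return torch.cat(out, dim=1)
```

Because attention is independent across that batch-like dimension, both branches produce the same result; chunking only changes how much of the logits tensor exists at once.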
@@ -981,16 +1003,16 @@ class ChunkMSAColumnGlobalAttention(nn.Module):
         )
     def forward(self, M_raw, M_mask):
-        para_dim = M_raw.shape[2]
         if CHUNK_SIZE is None:
-            chunk_size = para_dim
+            m = self.layernormM(M_raw.transpose(-2, -3))
+            m = self.global_attention(m, M_mask.transpose(-1, -2))
+            m = m.transpose(-2, -3)
+            M_raw = M_raw + m
         else:
             chunk_size = CHUNK_SIZE
+            para_dim = M_raw.shape[2]
             for i in range(0, para_dim, chunk_size):
-                if DEBUG and i > 10:
-                    break
                 m = M_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
                 m = self.layernormM(m)
                 m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
@@ -1109,12 +1131,12 @@ class RecyclingEmbedder(nn.Module):
         # [*, N, N, no_bins]
         d = ((d > squared_bins) * (d < upper)).type(x.dtype)
-        # [*, N, N, C_z]
-        para_dim = d.shape[1]
         if CHUNK_SIZE == None:
-            chunk_size = para_dim
+            d = self.linear(d)
+            z = d + self.layer_norm_z(z)
         else:
             chunk_size = CHUNK_SIZE * 48
+            para_dim = d.shape[1]
             for i in range(0, para_dim, chunk_size):
                 di = self.linear(d[i:i + chunk_size, :, :])
@@ -1152,10 +1174,33 @@ class GlobalAttention(nn.Module):
     def forward(self, m, mask):
+        if CHUNK_SIZE == None:
+            q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
+                torch.sum(mask, dim=-1)[..., None] + self.eps
+            )
+            q = q * self.scaling
+            q = self.to_q(q)
+            q = q.view(q.shape[:-1] + (self.n_head, -1))
+            k, v = self.to_kv(m).chunk(2, dim=-1)
+            logits = torch.matmul(q, k.transpose(-1, -2))
+            weights = fused_softmax(logits, mask)
+            weighted_avg = torch.matmul(weights, v)
+            weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
+            gate_values = self.gating_linear(m)
+            weighted_avg = bias_sigmod_ele(
+                gate_values, self.gating_bias, weighted_avg.unsqueeze(-2)
+            )
+            m = self.o_linear(weighted_avg)
+        else:
             para_dim = m.shape[1]
             chunk_size = CHUNK_SIZE
-        if CHUNK_SIZE == None:
-            chunk_size = para_dim
             output = []
             for ax in range(0, para_dim, chunk_size):
@@ -1190,7 +1235,6 @@ class GlobalAttention(nn.Module):
         return m
 
-
 class InputEmbedder(nn.Module):
     """
     Embeds a subset of the input features.
@@ -387,9 +387,11 @@ class TemplatePairStack(nn.Module):
             args=(t,),
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )
+        if not self.training:
             for i in range(0, t.shape[0]):
                 t[i] = self.layer_norm(t[i])
+        else:
+            t = self.layer_norm(t[i])
         return t
     def inplace(
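The TemplatePairStack hunk moves the final layer norm behind a training check: the per-template loop, which keeps peak memory low, now runs only at inference, while training keeps a single call over the whole stack. A hedged sketch of that presumed intent, with the training branch normalizing the full tensor (helper name and signature are mine, not the repository's):

```python
import torch
import torch.nn as nn

def finalize_template_stack(t: torch.Tensor, layer_norm: nn.LayerNorm, training: bool) -> torch.Tensor:
    """Sketch, not the repository's exact code: at inference, layer-norm one
    template slice at a time to keep peak memory low; during training,
    normalize the whole stack in a single call."""
    if not training:
        for i in range(0, t.shape[0]):
            t[i] = layer_norm(t[i])
        return t
    return layer_norm(t)
```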