"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "7e473f194d52ffcfec577ed0f58fda16c1e1e093"
Unverified commit 4d302345, authored by oahzxl, committed by GitHub

Fix multimer kernel bug (#103)

* fix multimer bug

* fix bug

* remove print

* fix index

* fix layernorm length

* update test

* improve template code

* remove print

* polish layernorm
parent 930a58ad
@@ -300,6 +300,7 @@ class TemplateEmbedderMultimer(nn.Module):
     ):
         template_embeds = []
         n_templ = batch["template_aatype"].shape[templ_dim]
+        template_pair_embeddings = torch.zeros((z.shape[0], z.shape[1], 64), dtype=z.dtype, device=z.device)
         for i in range(n_templ):
             idx = batch["template_aatype"].new_tensor(i)
             single_template_feats = tensor_tree_map(
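The zeros buffer added here turns the template pair embedding into a running sum: only one template's activation has to be live at a time instead of a full [S_t, N, N, C] stack. A minimal, self-contained sketch of the pattern (the sizes and names below are illustrative, not the project's):

import torch

# Illustrative sizes only: N residues, C channels, S_t templates.
N, C, S_t = 8, 64, 4
z = torch.randn(N, N, C)

# Pre-allocate the accumulator with the same dtype/device as the pair
# representation, mirroring the torch.zeros(...) line added above.
acc = torch.zeros((N, N, C), dtype=z.dtype, device=z.device)

for s in range(S_t):
    per_template = torch.randn(N, N, C)   # stands in for one processed template
    acc = acc + per_template              # peak memory stays O(N*N*C), not O(S_t*N*N*C)

mean_embedding = acc / S_t                # same mean as stacking and summing over the template dim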
@@ -336,7 +337,7 @@ class TemplateEmbedderMultimer(nn.Module):
             rigid_vec = rigid[..., None].inverse().apply_to_point(points)
             unit_vector = rigid_vec.normalized()
-            pair_act = self.template_pair_embedder(
+            pair_embedding = self.template_pair_embedder(
                 template_dgram,
                 aatype_one_hot,
                 z,
@@ -346,7 +347,23 @@ class TemplateEmbedderMultimer(nn.Module):
                 unit_vector,
             )
-            single_template_embeds["template_pair_embedding"] = pair_act
+            if not inplace:
+                # [*, S_t, N, N, C_z]
+                template_pair_embeddings = template_pair_embeddings + self.template_pair_stack(
+                    pair_embedding,
+                    padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
+                    chunk_size=chunk_size,
+                    _mask_trans=False,
+                ).squeeze(0)
+            else:
+                # [*, S_t, N, N, C_z]
+                template_pair_embeddings += self.template_pair_stack.inplace(
+                    [pair_embedding],
+                    padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
+                    chunk_size=chunk_size,
+                    _mask_trans=False,
+                )[0].squeeze(0)
             single_template_embeds.update(
                 self.template_single_embedder(
                     single_template_feats,
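Moving the template_pair_stack call inside the per-template loop changes the order of computation but not the result, because the stack processes each template independently and the final reduction is a plain mean. A hedged sketch checking that equivalence with a stand-in for the stack (the function below is made up for illustration):

import torch

def pair_stack(x):
    # Stand-in for self.template_pair_stack: any shape-preserving op that
    # treats the leading template dimension as batch.
    return torch.relu(x) * 2.0

S_t, N, C = 3, 8, 16
templates = [torch.randn(N, N, C) for _ in range(S_t)]

# Previous ordering: stack all templates, run the stack once, reduce at the end.
stacked = pair_stack(torch.stack(templates, dim=0)).sum(dim=0) / S_t

# New ordering: run the stack per template and fold into the accumulator.
acc = torch.zeros(N, N, C)
for t in templates:
    acc = acc + pair_stack(t)
acc = acc / S_t

assert torch.allclose(stacked, acc, atol=1e-6)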
@@ -361,27 +378,10 @@ class TemplateEmbedderMultimer(nn.Module):
             template_embeds,
         )

-        if not inplace:
-            # [*, S_t, N, N, C_z]
-            template_embeds["template_pair_embedding"] = self.template_pair_stack(
-                template_embeds["template_pair_embedding"],
-                padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
-                chunk_size=chunk_size,
-                _mask_trans=False,
-            )
-        else:
-            template_embeds["template_pair_embedding"] = [template_embeds["template_pair_embedding"]]
-            # [*, S_t, N, N, C_z]
-            template_embeds["template_pair_embedding"] = self.template_pair_stack.inplace(
-                template_embeds["template_pair_embedding"],
-                padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
-                chunk_size=chunk_size,
-                _mask_trans=False,
-            )[0].to(z.device)
         # [*, N, N, C_z]
-        template_embeds["template_pair_embedding"] = torch.sum(template_embeds["template_pair_embedding"], dim=-4) / n_templ
-        template_embeds["template_pair_embedding"] = torch.nn.functional.relu(template_embeds["template_pair_embedding"])
-        template_embeds["template_pair_embedding"] = self.linear_t(template_embeds["template_pair_embedding"])
+        template_pair_embeddings = template_pair_embeddings / n_templ
+        template_pair_embeddings = torch.nn.functional.relu(template_pair_embeddings)
+        template_pair_embeddings = self.linear_t(template_pair_embeddings)
+        template_embeds["template_pair_embedding"] = template_pair_embeddings

         return template_embeds
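One detail the rewrite preserves: the accumulator holds raw pair-stack outputs, and ReLU and linear_t are applied only once, after dividing by n_templ, because a non-linearity does not commute with the mean. A quick check of that ordering (shapes are illustrative):

import torch

x = torch.randn(4, 8, 8, 16)  # [S_t, N, N, C], illustrative

mean_then_relu = torch.relu(x.mean(dim=0))   # what both the old and new code compute
relu_then_mean = torch.relu(x).mean(dim=0)   # would be a different result
print(torch.allclose(mean_then_relu, relu_then_mean))  # False in general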
@@ -34,6 +34,24 @@ class FusedLayerNorm(torch.nn.Module):
             torch.nn.init.zeros_(self.bias)

     def forward(self, input):
+        if len(input.shape) >= 3 and input.shape[-3] > 4000:
+            out = torch.empty_like(input)
+            # set max chunk_size = dim / 2, to max compute efficiency
+            chunk_size = min(4000 * 4000 // input.shape[-3], (input.shape[-3] + 1) // 2)
+            if len(input.shape) == 3:
+                for i in range(input.shape[-3]):
+                    out[i:i + chunk_size] = self.kernel_forward(input[i:i + chunk_size])
+            elif len(input.shape) == 4:
+                for j in range(input.shape[-4]):
+                    for i in range(0, input.shape[-3], chunk_size):
+                        out[j, i:i + chunk_size] = self.kernel_forward(input[j, i:i + chunk_size])
+            else:
+                raise RuntimeError("Shape" + input.shape + "not implemented for layernorm yet!")
+            return out
+        else:
+            return self.kernel_forward(input)
+
+    def kernel_forward(self, input):
         if _triton_available:
             return LayerNormTritonFunc.apply(input, self.normalized_shape, self.weight, self.bias,
                                              self.eps)
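The new forward path splits a very long leading dimension into chunks and feeds each slice to the fused kernel, which is safe because layer norm only normalizes over the trailing normalized_shape dimensions. A framework-level sketch of the same idea, using torch.nn.functional.layer_norm instead of the project's Triton kernel:

import torch
import torch.nn.functional as F

def chunked_layer_norm(x, normalized_shape, chunk_size):
    # Slicing any leading dimension and normalizing chunk by chunk is exact,
    # since each row is normalized independently over the trailing dims.
    out = torch.empty_like(x)
    for i in range(0, x.shape[0], chunk_size):
        out[i:i + chunk_size] = F.layer_norm(x[i:i + chunk_size], normalized_shape)
    return out

x = torch.randn(10, 6, 32)
assert torch.allclose(chunked_layer_norm(x, (32,), 3), F.layer_norm(x, (32,)), atol=1e-5)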
@@ -240,26 +240,18 @@ class TemplatePairBlock(nn.Module):
         z = scatter(z, dim=1)
         mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))
-        # single_templates = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)]
-        # single_templates_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)]
-        for i in range(z.shape[0]):
-            single = z[i].unsqueeze(-4)
-            single_mask = mask[i].unsqueeze(-3)
-
-            single_mask_row = scatter(single_mask, dim=1)
-            single_mask_col = scatter(single_mask, dim=2)
-
-            single = self.TriangleAttentionStartingNode(single, single_mask_row)
-            single = row_to_col(single)
-            single = self.TriangleAttentionEndingNode(single, single_mask_col)
-            single = col_to_row(single)
-            single = self.TriangleMultiplicationOutgoing(single, single_mask_row)
-            single = row_to_col(single)
-            single = self.TriangleMultiplicationIncoming(single, single_mask_col)
-            single = self.PairTransition(single)
-            single = col_to_row(single)
-            z[i] = single
+        single_mask_row = scatter(mask, dim=1)
+        single_mask_col = scatter(mask, dim=2)
+
+        z = self.TriangleAttentionStartingNode(z, single_mask_row)
+        z = row_to_col(z)
+        z = self.TriangleAttentionEndingNode(z, single_mask_col)
+        z = col_to_row(z)
+        z = self.TriangleMultiplicationOutgoing(z, single_mask_row)
+        z = row_to_col(z)
+        z = self.TriangleMultiplicationIncoming(z, single_mask_col)
+        z = self.PairTransition(z)
+        z = col_to_row(z)
         # z = torch.cat(single_templates, dim=-4)

         if self.last_block:
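The forward pass now runs the triangle modules on the whole [S_t, N, N, C] tensor instead of looping over templates with unsqueeze and write-back, since these blocks treat the leading dimension as batch. A small hedged demonstration that a per-slice loop and a single batched call agree (a plain Linear stands in for the triangle blocks):

import torch

block = torch.nn.Linear(16, 16)     # stand-in for a triangle-update module

z = torch.randn(4, 8, 8, 16)        # [S_t, N, N, C], illustrative

with torch.no_grad():
    # Old style: slice one template, add a fake batch dim, write the result back.
    looped = torch.empty_like(z)
    for i in range(z.shape[0]):
        looped[i] = block(z[i].unsqueeze(0)).squeeze(0)
    # New style: one batched call over the whole tensor.
    batched = block(z)

assert torch.allclose(looped, batched, atol=1e-5)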
@@ -275,10 +267,7 @@ class TemplatePairBlock(nn.Module):
         chunk_size: Optional[int] = None,
         _mask_trans: bool = True,
     ):
-        if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
-            z[0] = z[0].cpu()
-
         dap_size = gpc.get_world_size(ParallelMode.TENSOR)

         seq_length = mask.size(-1)
         padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
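The removed guard parked the pair tensor on CPU whenever a small chunk size (1 to 4) was used as an offload signal, presumably because the old per-template loop pulled slices back to the compute device one at a time; the batched rewrite in the next hunk no longer does that, so the early round trip is dropped. For reference, the dropped logic amounted to:

import torch

def maybe_offload(z, chunk_size):
    # Sketch of the removed behaviour: treat chunk_size in [1, 4] as a request
    # to keep the buffer on CPU and stream slices to the GPU on demand.
    if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
        return z.cpu()
    return z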
@@ -287,31 +276,21 @@ class TemplatePairBlock(nn.Module):
         z[0] = scatter(z[0], dim=1)
         mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))
-        # single_templates = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)]
-        # single_templates_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)]
-        for i in range(z[0].shape[0]):
-            single = z[0][i].unsqueeze(-4).to(mask.device)
-            single_mask = mask[i].unsqueeze(-3)
-
-            single_mask_row = scatter(single_mask, dim=1)
-            single_mask_col = scatter(single_mask, dim=2)
-
-            single = self.TriangleAttentionStartingNode(single, single_mask_row)
-            single = row_to_col(single)
-            single = self.TriangleAttentionEndingNode(single, single_mask_col)
-            single = col_to_row(single)
-            single = self.TriangleMultiplicationOutgoing(single, single_mask_row)
-            single = row_to_col(single)
-            single = self.TriangleMultiplicationIncoming(single, single_mask_col)
-            single = self.PairTransition(single)
-            single = col_to_row(single)
-            z[0][i] = single.to(z[0].device)
+        single_mask_row = scatter(mask, dim=1)
+        single_mask_col = scatter(mask, dim=2)
+
+        z = self.TriangleAttentionStartingNode.inplace(z, single_mask_row)
+        z[0] = row_to_col(z[0])
+        z = self.TriangleAttentionEndingNode.inplace(z, single_mask_col)
+        z[0] = col_to_row(z[0])
+        z[0] = self.TriangleMultiplicationOutgoing(z[0], single_mask_row)
+        z[0] = row_to_col(z[0])
+        z[0] = self.TriangleMultiplicationIncoming(z[0], single_mask_col)
+        z = self.PairTransition.inplace(z)
+        z[0] = col_to_row(z[0])
         # z = torch.cat(single_templates, dim=-4)

         if self.last_block:
-            if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
-                z[0] = z[0].to(mask.device)
             z[0] = gather(z[0], dim=1)
             z[0] = z[0][:, :-padding_size, :-padding_size, :]
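The inplace path now calls the .inplace variants of the attention and transition blocks directly on the shared single-element list, rather than slicing each template and bouncing it through mask.device. A toy module (made up for illustration) showing the list-based calling convention these blocks appear to follow:

import torch

class AddBias(torch.nn.Module):
    """Toy block (made up) illustrating the out-of-place vs. list-based
    in-place calling convention seen in this diff."""
    def __init__(self, c):
        super().__init__()
        self.register_buffer("bias", torch.ones(c))

    def forward(self, z):
        return z + self.bias            # allocates a new tensor

    def inplace(self, z_list):
        z_list[0] += self.bias          # mutates the caller's buffer in place
        return z_list                   # the same list is handed back, as in z = block.inplace(z)

block = AddBias(16)
z = [torch.randn(8, 8, 16)]
out = block.inplace(z)
assert out is z and out[0] is z[0]      # no new pair tensor was allocated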
@@ -408,15 +387,9 @@ class TemplatePairStack(nn.Module):
                 args=(t,),
                 blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
             )

-        if chunk_size is None:
-            chunk_size = t.shape[0]
-        for i in range(0, t.shape[0], chunk_size):
-            if t.shape[1] > 4000:
-                chunk_new = int(4000 * 4000 / t.shape[1])
-                for j in range(0, t.shape[1], chunk_new):
-                    t[i:i + chunk_size, j:j + chunk_new] = self.layer_norm(t[i:i + chunk_size, j:j + chunk_new])
-            else:
-                t[i:i + chunk_size] = self.layer_norm(t[i:i + chunk_size])
+        for i in range(0, t.shape[0]):
+            t[i] = self.layer_norm(t[i])
         return t

     def inplace(
@@ -453,13 +426,7 @@ class TemplatePairStack(nn.Module):
                 args=(t,),
                 blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
             )

-        if chunk_size is None:
-            chunk_size = t[0].shape[0]
-        for i in range(0, t[0].shape[0], chunk_size):
-            if t[0].shape[1] > 4000:
-                chunk_new = int(4000 * 4000 / t[0].shape[1])
-                for j in range(0, t[0].shape[1], chunk_new):
-                    t[0][i:i + chunk_size, j:j + chunk_new] = self.layer_norm(t[0][i:i + chunk_size, j:j + chunk_new].to(mask.device)).to(t[0].device)
-            else:
-                t[0][i:i + chunk_size] = self.layer_norm(t[0][i:i + chunk_size].to(mask.device)).to(t[0].device)
+        for i in range(0, t[0].shape[0]):
+            t[0][i] = self.layer_norm(t[0][i].to(mask.device)).to(t[0].device)
         return t
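With the chunking now handled inside FusedLayerNorm itself (see the layer-norm hunk above), the stack simply normalizes one template at a time; in the inplace variant the buffer may live on a different device than the compute stream, so each slice is moved to mask.device for the norm and copied back. A hedged sketch of that per-slice round trip using the standard functional layer norm (the device choice is an assumption for illustration):

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
t = torch.randn(4, 8, 8, 32)   # e.g. an offloaded [S_t, N, N, C] buffer

for i in range(t.shape[0]):
    # Move one slice to the compute device, normalize, and write it back
    # to wherever the buffer lives.
    t[i] = F.layer_norm(t[i].to(device), (32,)).to(t.device)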
@@ -73,4 +73,4 @@ def _test_msa_att_col(rank, world_size, chunk_size, get_openfold_module_and_data):
     m_fast = m_fast[:, :-padding_size, :]

     error = torch.max(torch.abs(m_out.cuda() - m_fast))
-    assert error < 5e-5, f"Test m failed at chunk size: {chunk_size}. The position dif is {error}"
+    assert error < 1e-4, f"Test m failed at chunk size: {chunk_size}. The position dif is {error}"
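The tolerance in this assertion is loosened from 5e-5 to 1e-4, which still bounds the worst-case element-wise deviation between the reference output and the fast path. The check itself is a plain max-abs comparison, for example:

import torch

ref = torch.randn(16, 16)
fast = ref + 1e-5 * torch.randn(16, 16)   # stand-in for the fused-path output

error = torch.max(torch.abs(ref - fast))
assert error < 1e-4, f"max element-wise error {error} exceeds tolerance"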
@@ -46,7 +46,7 @@ def get_openfold_module_and_data():
 @pytest.mark.parametrize('world_size', [1, 2])
-@pytest.mark.parametrize('chunk_size', [None, 32])
+@pytest.mark.parametrize('chunk_size', [None, 4])  # should set 4 to test offload
 @pytest.mark.parametrize('inplace', [False, True])
 def test_state_dict(world_size, chunk_size, inplace, get_openfold_module_and_data):
     run_func = partial(_test_template_embedder, world_size=world_size, chunk_size=chunk_size,