"tests/git@developer.sourcefind.cn:chenpangpang/diffusers.git" did not exist on "11b3002b48353b33880e385c576888ca5405918a"
Unverified Commit 45b04fda authored by oahzxl's avatar oahzxl Committed by GitHub
Browse files

[hotfix] Fix multimer communication bug in evoformer (#105)

* fix multimer bug

* fix bug

* remove print

* fix index

* fix layernorm length

* update test

* improve template code

* remove print

* polish layernorm

* fix multimer communication in evoformer
parent 820b7a9a
...@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc ...@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
from fastfold.model.fastnn import MSACore, OutProductMean, PairCore from fastfold.model.fastnn import MSACore, OutProductMean, PairCore
from fastfold.model.fastnn.ops import Linear from fastfold.model.fastnn.ops import Linear
from fastfold.distributed.comm import gather, scatter from fastfold.distributed.comm import gather, scatter, col_to_row
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
from fastfold.utils.checkpointing import checkpoint_blocks from fastfold.utils.checkpointing import checkpoint_blocks
...@@ -49,7 +49,10 @@ class Evoformer(nn.Module): ...@@ -49,7 +49,10 @@ class Evoformer(nn.Module):
m = torch.nn.functional.pad(m, (0, 0, 0, padding_size)) m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size)) z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
m = scatter(m, dim=1) if self.is_multimer:
m = scatter(m, dim=2)
else:
m = scatter(m, dim=1)
z = scatter(z, dim=1) z = scatter(z, dim=1)
msa_mask = msa_mask.unsqueeze(0) msa_mask = msa_mask.unsqueeze(0)
...@@ -76,7 +79,10 @@ class Evoformer(nn.Module): ...@@ -76,7 +79,10 @@ class Evoformer(nn.Module):
m = m.squeeze(0) m = m.squeeze(0)
z = z.squeeze(0) z = z.squeeze(0)
m = gather(m, dim=0) if self.is_multimer:
m = gather(m, dim=1)
else:
m = gather(m, dim=0)
z = gather(z, dim=0) z = gather(z, dim=0)
m = m[:, :-padding_size, :] m = m[:, :-padding_size, :]
...@@ -106,7 +112,10 @@ class Evoformer(nn.Module): ...@@ -106,7 +112,10 @@ class Evoformer(nn.Module):
m[0] = torch.nn.functional.pad(m[0], (0, 0, 0, padding_size)) m[0] = torch.nn.functional.pad(m[0], (0, 0, 0, padding_size))
z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size)) z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
m[0] = scatter(m[0], dim=1) if self.is_multimer:
m[0] = scatter(m[0], dim=2)
else:
m[0] = scatter(m[0], dim=1)
z[0] = scatter(z[0], dim=1) z[0] = scatter(z[0], dim=1)
msa_mask = msa_mask.unsqueeze(0) msa_mask = msa_mask.unsqueeze(0)
...@@ -122,15 +131,8 @@ class Evoformer(nn.Module): ...@@ -122,15 +131,8 @@ class Evoformer(nn.Module):
z = self.pair.inplace(z, pair_mask) z = self.pair.inplace(z, pair_mask)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2) m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
else: else:
# z = self.communication.inplace(m[0], msa_mask, z)
# z_ori = z[0].clone()
# m[0], work = All_to_All_Async.apply(m[0], 1, 2)
# z = self.pair_stack.inplace(z, pair_mask)
# m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
# m[0] = self.msa_stack(m[0], z_ori, msa_mask)
z = self.communication.inplace(m[0], msa_mask, z) z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2) m[0] = col_to_row(m[0])
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
m[0] = self.msa(m[0], z[0], msa_mask) m[0] = self.msa(m[0], z[0], msa_mask)
z = self.pair.inplace(z, pair_mask) z = self.pair.inplace(z, pair_mask)
...@@ -138,7 +140,10 @@ class Evoformer(nn.Module): ...@@ -138,7 +140,10 @@ class Evoformer(nn.Module):
m[0] = m[0].squeeze(0) m[0] = m[0].squeeze(0)
z[0] = z[0].squeeze(0) z[0] = z[0].squeeze(0)
m[0] = gather(m[0], dim=0) if self.is_multimer:
m[0] = gather(m[0], dim=1)
else:
m[0] = gather(m[0], dim=0)
z[0] = gather(z[0], dim=0) z[0] = gather(z[0], dim=0)
m[0] = m[0][:, :-padding_size, :] m[0] = m[0][:, :-padding_size, :]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment