Unverified Commit 45b04fda authored by oahzxl's avatar oahzxl Committed by GitHub
Browse files

[hotfix] Fix multimer communication bug in evoformer (#105)

* fix multimer bug

* fix bug

* remove print

* fix index

* fix layernorm length

* update test

* improve template code

* remove print

* polish layernorm

* fix multimer communication in evoformer
parent 820b7a9a
......@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
from fastfold.model.fastnn import MSACore, OutProductMean, PairCore
from fastfold.model.fastnn.ops import Linear
from fastfold.distributed.comm import gather, scatter
from fastfold.distributed.comm import gather, scatter, col_to_row
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
from fastfold.utils.checkpointing import checkpoint_blocks
......@@ -49,6 +49,9 @@ class Evoformer(nn.Module):
m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
if self.is_multimer:
m = scatter(m, dim=2)
else:
m = scatter(m, dim=1)
z = scatter(z, dim=1)
......@@ -76,6 +79,9 @@ class Evoformer(nn.Module):
m = m.squeeze(0)
z = z.squeeze(0)
if self.is_multimer:
m = gather(m, dim=1)
else:
m = gather(m, dim=0)
z = gather(z, dim=0)
......@@ -106,6 +112,9 @@ class Evoformer(nn.Module):
m[0] = torch.nn.functional.pad(m[0], (0, 0, 0, padding_size))
z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
if self.is_multimer:
m[0] = scatter(m[0], dim=2)
else:
m[0] = scatter(m[0], dim=1)
z[0] = scatter(z[0], dim=1)
......@@ -122,15 +131,8 @@ class Evoformer(nn.Module):
z = self.pair.inplace(z, pair_mask)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
else:
# z = self.communication.inplace(m[0], msa_mask, z)
# z_ori = z[0].clone()
# m[0], work = All_to_All_Async.apply(m[0], 1, 2)
# z = self.pair_stack.inplace(z, pair_mask)
# m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
# m[0] = self.msa_stack(m[0], z_ori, msa_mask)
z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
m[0] = col_to_row(m[0])
m[0] = self.msa(m[0], z[0], msa_mask)
z = self.pair.inplace(z, pair_mask)
......@@ -138,6 +140,9 @@ class Evoformer(nn.Module):
m[0] = m[0].squeeze(0)
z[0] = z[0].squeeze(0)
if self.is_multimer:
m[0] = gather(m[0], dim=1)
else:
m[0] = gather(m[0], dim=0)
z[0] = gather(z[0], dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment