Unverified commit 4953b4ac, authored by oahzxl and committed by GitHub

[autochunk] support evoformer tracer (#2485)

Support the full Evoformer tracer; Evoformer is a main module of AlphaFold. Previously only a simplified version of it was supported.
1. support some of Evoformer's ops in fx
2. add a test for the full Evoformer
3. pull the models used by the tests from external repos (fastfold, simple_evoformer)
parent 67e1912b
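For context, here is a condensed sketch of what this commit enables: tracing a full fastfold `EvoformerBlock` with ColossalAI's experimental fx tracer. Everything below is distilled from the test file in this commit (model hyperparameters, tensor shapes, and tracer arguments); it is illustrative, not additional API.

```python
import torch
from fastfold.model.nn.evoformer import EvoformerBlock
from colossalai.fx.tracer.experimental import symbolic_trace

# Full Evoformer block, configured exactly as in _build_openfold() below.
model = EvoformerBlock(c_m=256, c_z=128, c_hidden_msa_att=32, c_hidden_opm=32,
                       c_hidden_mul=128, c_hidden_pair_att=32, no_heads_msa=8,
                       no_heads_pair=4, transition_n=4, msa_dropout=0.15,
                       pair_dropout=0.15, inf=1e4, eps=1e-4, is_multimer=False).eval()

node = torch.randn(1, 32, 64, 256)    # MSA representation [B, msa_len, pair_len, c_m]
pair = torch.randn(1, 64, 64, 128)    # pair representation [B, pair_len, pair_len, c_z]
node_mask = torch.ones(1, 32, 64)
pair_mask = torch.ones(1, 64, 64)

# Trace on meta tensors so no real activation memory is allocated.
gm = symbolic_trace(
    model,
    meta_args={
        "m": node.to(torch.device("meta")),
        "z": pair.to(torch.device("meta")),
        "msa_mask": node_mask.to(torch.device("meta")),
        "pair_mask": pair_mask.to(torch.device("meta")),
    },
    concrete_args={"chunk_size": None, "_mask_trans": True},
)
print(gm.code)  # the captured forward, ready for autochunk codegen
```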
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partialmethod, partial
import math
from typing import Optional, List
import torch
import torch.nn as nn
from .primitives import Linear, LayerNorm, Attention
from .tensor_utils import (
    chunk_layer,
    permute_final_dims,
    flatten_final_dims,
)
class TriangleAttention(nn.Module):

    def __init__(self, c_in, c_hidden, no_heads, starting, inf=1e9):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
            starting:
                If True, attend around the starting node (Algorithm 13);
                otherwise around the ending node (Algorithm 14)
            inf:
                Large constant used to build the additive attention mask
        """
        super(TriangleAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf

        self.layer_norm = LayerNorm(self.c_in)
        self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
        self.mha = Attention(
            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
        )

    @torch.jit.ignore
    def _chunk(self,
        x: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
    ) -> torch.Tensor:
        mha_inputs = {
            "q_x": x,
            "kv_x": x,
            "biases": biases,
        }
        return chunk_layer(
            partial(self.mha),
            mha_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
        )

    def forward(self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
        Returns:
            [*, I, J, C_in] output tensor
        """
        if mask is None:
            # [*, I, J]
            mask = x.new_ones(
                x.shape[:-1],
            )

        # Shape annotations assume self.starting. Else, I and J are flipped
        if not self.starting:
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # [*, I, J, C_in]
        x = self.layer_norm(x)

        # [*, I, 1, 1, J]
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # [*, H, I, J]
        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))

        # [*, 1, H, I, J]
        triangle_bias = triangle_bias.unsqueeze(-4)

        biases = [mask_bias, triangle_bias]

        if chunk_size is not None:
            x = self._chunk(x, biases, chunk_size)
        else:
            x = self.mha(q_x=x, kv_x=x, biases=biases)

        if not self.starting:
            x = x.transpose(-2, -3)

        return x


class TriangleAttentionStartingNode(TriangleAttention):
    """
    Implements Algorithm 13.
    """
    __init__ = partialmethod(TriangleAttention.__init__, starting=True)


class TriangleAttentionEndingNode(TriangleAttention):
    """
    Implements Algorithm 14.
    """
    __init__ = partialmethod(TriangleAttention.__init__, starting=False)
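A hedged usage sketch for the attention classes above (the sizes are illustrative, borrowed from the pair-attention settings used in the tests further down; nothing here is prescribed by the module itself):

```python
import torch

# Illustrative sizes: c_in matches the pair channel (128) and the head
# settings mirror c_hidden_pair_att=32 / no_heads_pair=4 from the tests below.
attn = TriangleAttentionStartingNode(c_in=128, c_hidden=32, no_heads=4)

z = torch.randn(1, 64, 64, 128)   # [*, I, J, C_in] pair representation
mask = torch.ones(1, 64, 64)      # [*, I, J]

out = attn(z, mask=mask)                         # full attention
out_chunked = attn(z, mask=mask, chunk_size=4)   # same result, lower peak memory
assert out.shape == z.shape
assert torch.allclose(out, out_chunked, atol=1e-5)
```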
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partialmethod
from typing import Optional
import torch
import torch.nn as nn
from .primitives import Linear, LayerNorm
from .tensor_utils import permute_final_dims
class TriangleMultiplicativeUpdate(nn.Module):
    """
    Implements Algorithms 11 and 12.
    """

    def __init__(self, c_z, c_hidden, _outgoing=True):
        """
        Args:
            c_z:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
        """
        super(TriangleMultiplicativeUpdate, self).__init__()

        self.c_z = c_z
        self.c_hidden = c_hidden
        self._outgoing = _outgoing

        self.linear_a_p = Linear(self.c_z, self.c_hidden)
        self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
        self.linear_b_p = Linear(self.c_z, self.c_hidden)
        self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
        self.linear_g = Linear(self.c_z, self.c_z, init="gating")
        self.linear_z = Linear(self.c_hidden, self.c_z, init="final")

        self.layer_norm_in = LayerNorm(self.c_z)
        self.layer_norm_out = LayerNorm(self.c_hidden)

        self.sigmoid = nn.Sigmoid()

    def _combine_projections(self,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        raise NotImplementedError("This method needs to be overridden")

    def forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] input tensor
            mask:
                [*, N_res, N_res] input mask
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        mask = mask.unsqueeze(-1)

        z = self.layer_norm_in(z)
        a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
        a = a * mask
        b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
        b = b * mask
        x = self._combine_projections(a, b)
        x = self.layer_norm_out(x)
        x = self.linear_z(x)
        g = self.sigmoid(self.linear_g(z))
        z = x * g

        return z


class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 11.
    """

    def _combine_projections(self,
        a: torch.Tensor,  # [*, N_i, N_k, C]
        b: torch.Tensor,  # [*, N_j, N_k, C]
    ):
        # [*, C, N_i, N_j]
        p = torch.matmul(
            permute_final_dims(a, (2, 0, 1)),
            permute_final_dims(b, (2, 1, 0)),
        )

        # [*, N_i, N_j, C]
        return permute_final_dims(p, (1, 2, 0))


class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 12.
    """

    def _combine_projections(self,
        a: torch.Tensor,  # [*, N_k, N_i, C]
        b: torch.Tensor,  # [*, N_k, N_j, C]
    ):
        # [*, C, N_i, N_j]
        p = torch.matmul(
            permute_final_dims(a, (2, 1, 0)),
            permute_final_dims(b, (2, 0, 1)),
        )

        # [*, N_i, N_j, C]
        return permute_final_dims(p, (1, 2, 0))
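To make the permute-and-matmul in `_combine_projections` concrete: for outgoing edges (Algorithm 11) it contracts the shared `k` dimension, i.e. an einsum over `ikc,jkc->ijc`. A small self-contained check (plain `Tensor.permute` stands in for `permute_final_dims` here, assuming a single leading batch dim):

```python
import torch

N, C = 16, 8
a = torch.randn(2, N, N, C)  # [*, N_i, N_k, C]
b = torch.randn(2, N, N, C)  # [*, N_j, N_k, C]

p = torch.matmul(
    a.permute(0, 3, 1, 2),   # [*, C, N_i, N_k]
    b.permute(0, 3, 2, 1),   # [*, C, N_k, N_j]
).permute(0, 2, 3, 1)        # [*, N_i, N_j, C]

# The same contraction written as an einsum over the shared k index.
assert torch.allclose(p, torch.einsum("bikc,bjkc->bijc", a, b), atol=1e-5)
```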
from functools import partial
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
    from fastfold.model.nn.evoformer import EvoformerBlock
    HAS_REPO = True
except ImportError:
    HAS_REPO = False
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if CODEGEN_AVAILABLE and is_compatible_with_meta():
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
    # for memory test
    # torch.cuda.reset_peak_memory_stats()
    # now_mem = torch.cuda.memory_allocated() / 1024**2
    # with torch.no_grad():
    #     node1 = node.clone()
    #     pair1 = pair.clone()
    #     gm(node1, pair1)
    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # print(
    #     "autochunk now mem:%.2f max mem:%.2f"
    #     % (new_now_mem - now_mem, new_max_mem - now_mem)
    # )

    # test forward
    model = model.cuda()
    with torch.no_grad():
        non_fx_out = model(node, pair, node_mask, pair_mask)
        fx_out = gm(node, pair, node_mask, pair_mask)

    assert torch.allclose(non_fx_out[0], fx_out[0],
                          atol=1e-4), "fx output does not match the original output, mean diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[0] - fx_out[0]))
    assert torch.allclose(non_fx_out[1], fx_out[1],
                          atol=1e-4), "fx output does not match the original output, mean diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[1] - fx_out[1]))
def _build_openfold():
    model = EvoformerBlock(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        no_heads_msa=8,
        no_heads_pair=4,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.15,
        inf=1e4,
        eps=1e-4,
        is_multimer=False,
    ).eval().cuda()
    return model
def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    model = _build_openfold()
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()

    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_mask_trans": True,
        },
    )
    interp = MetaInfoProp(meta_graph)
    interp.propagate(
        MetaTensor(node, fake_device="cuda:0"),
        MetaTensor(pair, fake_device="cuda:0"),
        MetaTensor(node_mask, fake_device="cuda:0"),
        MetaTensor(pair_mask, fake_device="cuda:0"),
    )
    # codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_mask_trans": True,
        },
    )
    # graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph)
    gm.recompile()

    # assert that chunk_size is preserved in the generated code
    code = graph.python_code("self").src
    assert "chunk_size" in code
    # print(code)

    _test_fwd(model, gm, node, pair, node_mask, pair_mask)
    gpc.destroy()
@pytest.mark.skipif(
    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
    reason="torch version is lower than 1.12.0 or fastfold is not installed",
)
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_evoformer_codegen(msa_len, pair_len, max_memory):
    run_func = partial(
        _test_evoformer_codegen,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    _test_evoformer_codegen(0, 32, 64, 25)
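Note that the autochunk codegen itself is left commented out in the test above. Enabling it inside `_test_evoformer_codegen` would amount to uncommenting those two lines, roughly as sketched below; the `AutoChunkCodeGen(...)` call is copied from the commented line, and the rest reuses names already defined in the function.

```python
# Sketch only: reuses meta_graph, graph, model and max_memory from
# _test_evoformer_codegen above; this is the commented-out path, not new API.
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
graph.set_codegen(codegen)  # attach the chunking codegen before recompiling
gm = ColoGraphModule(model, graph)
gm.recompile()
print(graph.python_code("self").src)  # generated forward with chunked regions
```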
@@ -5,6 +5,12 @@ import torch
import torch.fx
import torch.multiprocessing as mp
+try:
+    from simple_evoformer import base_evoformer
+    HAS_REPO = True
+except ImportError:
+    HAS_REPO = False
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
@@ -13,7 +19,6 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
-from tests.test_autochunk.evoformer.evoformer import evoformer_base
if CODEGEN_AVAILABLE and is_compatible_with_meta():
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
@@ -48,7 +53,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
torch.abs(non_fx_out[1] - fx_out[1]))
-def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
+def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
# launch colossalai
colossalai.launch(
config={},
@@ -60,7 +65,7 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
)
# build model and input
-    model = evoformer_base().cuda()
+    model = base_evoformer().cuda()
node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
@@ -95,13 +100,14 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
gpc.destroy()
-@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
+                    reason='torch version is lower than 1.12.0 or simple_evoformer is not installed')
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
-def test_autochunk_codegen(msa_len, pair_len, max_memory):
+def test_simple_evoformer_codegen(msa_len, pair_len, max_memory):
run_func = partial(
-        _test_autochunk_codegen,
+        _test_simple_evoformer_codegen,
msa_len=msa_len,
pair_len=pair_len,
max_memory=max_memory,
@@ -110,4 +116,4 @@ def test_autochunk_codegen(msa_len, pair_len, max_memory):
if __name__ == "__main__":
-    _test_autochunk_codegen(0, 32, 64, 25)
+    _test_simple_evoformer_codegen(0, 32, 64, 25)
@@ -5,13 +5,18 @@ import torch
import torch.fx
import torch.multiprocessing as mp
+try:
+    from simple_evoformer import base_evoformer
+    HAS_REPO = True
+except ImportError:
+    HAS_REPO = False
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
-from tests.test_autochunk.evoformer.evoformer import evoformer_base
if CODEGEN_AVAILABLE and is_compatible_with_meta():
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
@@ -57,7 +62,7 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
)
-def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
+def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory):
# launch colossalai
colossalai.launch(
config={},
@@ -69,7 +74,7 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
)
# build model and input
-    model = evoformer_base().cuda()
+    model = base_evoformer().cuda()
node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
@@ -84,13 +89,14 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
gpc.destroy()
-@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason="torch version is lower than 1.12.0")
+@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
+                    reason="torch version is lower than 1.12.0 or simple_evoformer is not installed")
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
-def test_autochunk_search(msa_len, pair_len, max_memory):
+def test_simple_evoformer_search(msa_len, pair_len, max_memory):
run_func = partial(
-        _test_autochunk_search,
+        _test_simple_evoformer_search,
msa_len=msa_len,
pair_len=pair_len,
max_memory=max_memory,
@@ -99,4 +105,4 @@ def test_autochunk_search(msa_len, pair_len, max_memory):
if __name__ == "__main__":
-    _test_autochunk_search(0, 32, 64, 20)
+    _test_simple_evoformer_search(0, 32, 64, 20)