"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "ba47517342ca8ae98986878ead3decb00c28a37a"
Commit efb1c64c authored by oahzxl's avatar oahzxl
Browse files

restruct dir

parent 27ab5240
...@@ -3,13 +3,13 @@ import time ...@@ -3,13 +3,13 @@ import time
import torch import torch
import torch.fx import torch.fx
from autochunk.chunk_codegen import ChunkCodeGen from colossalai.autochunk.chunk_codegen import ChunkCodeGen
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor from colossalai.fx.profiler import MetaTensor
from autochunk.evoformer.evoformer import evoformer_base from tests.test_autochunk.evoformer.evoformer import evoformer_base
from autochunk.openfold.evoformer import EvoformerBlock from tests.test_autochunk.openfold.evoformer import EvoformerBlock
def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None): def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
...@@ -94,7 +94,7 @@ def _build_openfold(): ...@@ -94,7 +94,7 @@ def _build_openfold():
def benchmark_evoformer(): def benchmark_evoformer():
# init data and model # init data and model
msa_len = 256 msa_len = 256
pair_len = 1024 pair_len = 256
node = torch.randn(1, msa_len, pair_len, 256).cuda() node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda()
model = evoformer_base().cuda() model = evoformer_base().cuda()
...@@ -106,11 +106,11 @@ def benchmark_evoformer(): ...@@ -106,11 +106,11 @@ def benchmark_evoformer():
# build openfold # build openfold
chunk_size = 64 chunk_size = 64
# openfold = _build_openfold() openfold = _build_openfold()
# benchmark # benchmark
# _benchmark_evoformer(model, node, pair, "base") _benchmark_evoformer(model, node, pair, "base")
# _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size) _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
_benchmark_evoformer(autochunk, node, pair, "autochunk") _benchmark_evoformer(autochunk, node, pair, "autochunk")
......
...@@ -12,8 +12,8 @@ from colossalai.core import global_context as gpc ...@@ -12,8 +12,8 @@ from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
from colossalai.fx.profiler import MetaTensor from colossalai.fx.profiler import MetaTensor
from autochunk.evoformer.evoformer import evoformer_base from tests.test_autochunk.evoformer.evoformer import evoformer_base
from autochunk.chunk_codegen import ChunkCodeGen from ...colossalai.autochunk.chunk_codegen import ChunkCodeGen
with_codegen = True with_codegen = True
......
...@@ -19,25 +19,25 @@ import torch.nn as nn ...@@ -19,25 +19,25 @@ import torch.nn as nn
from typing import Tuple, Optional from typing import Tuple, Optional
from functools import partial from functools import partial
from openfold.primitives import Linear, LayerNorm from .primitives import Linear, LayerNorm
from openfold.dropout import DropoutRowwise, DropoutColumnwise from .dropout import DropoutRowwise, DropoutColumnwise
from openfold.msa import ( from .msa import (
MSARowAttentionWithPairBias, MSARowAttentionWithPairBias,
MSAColumnAttention, MSAColumnAttention,
MSAColumnGlobalAttention, MSAColumnGlobalAttention,
) )
from openfold.outer_product_mean import OuterProductMean from .outer_product_mean import OuterProductMean
from openfold.pair_transition import PairTransition from .pair_transition import PairTransition
from openfold.triangular_attention import ( from .triangular_attention import (
TriangleAttentionStartingNode, TriangleAttentionStartingNode,
TriangleAttentionEndingNode, TriangleAttentionEndingNode,
) )
from openfold.triangular_multiplicative_update import ( from .triangular_multiplicative_update import (
TriangleMultiplicationOutgoing, TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming, TriangleMultiplicationIncoming,
) )
from openfold.checkpointing import checkpoint_blocks, get_checkpoint_fn from .checkpointing import checkpoint_blocks, get_checkpoint_fn
from openfold.tensor_utils import chunk_layer from .tensor_utils import chunk_layer
class MSATransition(nn.Module): class MSATransition(nn.Module):
......
...@@ -18,15 +18,15 @@ import torch ...@@ -18,15 +18,15 @@ import torch
import torch.nn as nn import torch.nn as nn
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from openfold.primitives import ( from .primitives import (
Linear, Linear,
LayerNorm, LayerNorm,
Attention, Attention,
GlobalAttention, GlobalAttention,
_attention_chunked_trainable, _attention_chunked_trainable,
) )
from openfold.checkpointing import get_checkpoint_fn from .checkpointing import get_checkpoint_fn
from openfold.tensor_utils import ( from .tensor_utils import (
chunk_layer, chunk_layer,
permute_final_dims, permute_final_dims,
flatten_final_dims, flatten_final_dims,
......
...@@ -19,8 +19,8 @@ from typing import Optional ...@@ -19,8 +19,8 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.primitives import Linear from .primitives import Linear
from openfold.tensor_utils import chunk_layer from .tensor_utils import chunk_layer
class OuterProductMean(nn.Module): class OuterProductMean(nn.Module):
......
...@@ -17,8 +17,8 @@ from typing import Optional ...@@ -17,8 +17,8 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.primitives import Linear, LayerNorm from .primitives import Linear, LayerNorm
from openfold.tensor_utils import chunk_layer from .tensor_utils import chunk_layer
class PairTransition(nn.Module): class PairTransition(nn.Module):
......
...@@ -21,8 +21,8 @@ import numpy as np ...@@ -21,8 +21,8 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.checkpointing import get_checkpoint_fn from .checkpointing import get_checkpoint_fn
from openfold.tensor_utils import ( from .tensor_utils import (
permute_final_dims, permute_final_dims,
flatten_final_dims, flatten_final_dims,
_chunk_slice, _chunk_slice,
......
...@@ -20,8 +20,8 @@ from typing import Optional, List ...@@ -20,8 +20,8 @@ from typing import Optional, List
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.primitives import Linear, LayerNorm, Attention from .primitives import Linear, LayerNorm, Attention
from openfold.tensor_utils import ( from .tensor_utils import (
chunk_layer, chunk_layer,
permute_final_dims, permute_final_dims,
flatten_final_dims, flatten_final_dims,
......
...@@ -19,8 +19,8 @@ from typing import Optional ...@@ -19,8 +19,8 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.primitives import Linear, LayerNorm from .primitives import Linear, LayerNorm
from openfold.tensor_utils import permute_final_dims from .tensor_utils import permute_final_dims
class TriangleMultiplicativeUpdate(nn.Module): class TriangleMultiplicativeUpdate(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment