Commit 7a23deb5 authored by oahzxl's avatar oahzxl
Browse files

code style

parent 5a916c0a
...@@ -34,15 +34,23 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title): ...@@ -34,15 +34,23 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title):
def benchmark_evoformer(): def benchmark_evoformer():
# data # init data and model
msa_len = 300 msa_len = 300
pair_len = 800 pair_len = 800
node = torch.randn(1, msa_len, pair_len, 256).cuda() node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda()
model = evoformer_base().cuda()
# build gm model # build autochunk model
max_memory = 3000 # MB max_memory = 3000 # MB
model = evoformer_base().cuda() autochunk = _build_autochunk(model, max_memory, node, pair)
# benchmark
_benchmark_evoformer(model, node, pair, "openfold")
_benchmark_evoformer(autochunk, node, pair, "autochunk")
def _build_autochunk(model, max_memory, node, pair):
# trace the module and replace codegen # trace the module and replace codegen
graph = ColoTracer().trace( graph = ColoTracer().trace(
model, model,
...@@ -70,9 +78,7 @@ def benchmark_evoformer(): ...@@ -70,9 +78,7 @@ def benchmark_evoformer():
# print # print
code = graph.python_code("self").src code = graph.python_code("self").src
print(code) print(code)
return gm
_benchmark_evoformer(gm, node, pair, "autochunk")
_benchmark_evoformer(model, node, pair, "openfold")
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment