Unverified commit eba49680 authored by LuGY, committed by GitHub

fix perf for api changing (#177)

parent 05681304
@@ -12,7 +12,6 @@ def main():
     parser = argparse.ArgumentParser(description='Evoformer Standalone Perf Benchmark')
     parser.add_argument("--dap-size", default=1, type=int, help='batch size')
-    parser.add_argument('--batch-size', default=1, type=int, help='batch size')
     parser.add_argument('--msa-length', default=132, type=int, help='Sequence Length of MSA')
     parser.add_argument('--res-length',
                         default=256,
@@ -85,7 +84,9 @@ def main():
         if args.openfold:
             attn_layers.append(OpenFoldEvoformer(d_node=args.cm, d_pair=args.cz))
         else:
-            attn_layers.append(Evoformer(d_node=args.cm, d_pair=args.cz))
+            first_block = idx == 0
+            last_block = idx == args.layers - 1
+            attn_layers.append(Evoformer(c_m=args.cm, c_z=args.cz, first_block=first_block, last_block=last_block))
         attn_layers[idx].cuda()
         attn_layers[idx].to(dtype=precision)
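
For context, this hunk adapts the benchmark to the changed constructor: the Evoformer block now takes c_m/c_z channel sizes plus per-block first_block/last_block flags instead of d_node/d_pair. A minimal sketch of the updated construction loop follows; the fastfold.model.fastnn import path and the helper name build_evoformer_stack are assumptions, while the keyword names come from the diff itself.

# Sketch only: mirrors the benchmark's layer-construction loop under the new API.
# The import path and helper name are assumptions; c_m, c_z, first_block and
# last_block are the keyword arguments shown in the diff.
import torch
from fastfold.model.fastnn import Evoformer  # assumed import path

def build_evoformer_stack(num_layers: int, c_m: int, c_z: int,
                          dtype: torch.dtype = torch.bfloat16) -> list:
    layers = []
    for idx in range(num_layers):
        block = Evoformer(
            c_m=c_m,
            c_z=c_z,
            first_block=(idx == 0),              # True only for the first block in the stack
            last_block=(idx == num_layers - 1),  # True only for the last block in the stack
        )
        layers.append(block.cuda().to(dtype=dtype))
    return layers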
@@ -97,22 +98,23 @@ def main():
         start_evt_bwd.append(torch.cuda.Event(enable_timing=True))
         stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))

-    inputs_node = torch.randn(args.batch_size,
-                              args.msa_length // args.dap_size,
+    batch_size = 1
+    inputs_node = torch.randn(batch_size,
+                              args.msa_length,
                               args.res_length,
                               args.cm,
                               dtype=precision,
                               device=torch.device("cuda")).requires_grad_(True)
-    inputs_pair = torch.randn(args.batch_size,
-                              args.res_length // args.dap_size,
+    inputs_pair = torch.randn(batch_size,
+                              args.res_length,
                               args.res_length,
                               args.cz,
                               dtype=precision,
                               device=torch.device("cuda")).requires_grad_(True)
-    node_mask = torch.ones((args.batch_size, args.msa_length, args.res_length),
+    node_mask = torch.ones((batch_size, args.msa_length, args.res_length),
                            dtype=precision,
                            device=torch.device("cuda")).requires_grad_(False)
-    pair_mask = torch.ones((args.batch_size, args.res_length, args.res_length),
+    pair_mask = torch.ones((batch_size, args.res_length, args.res_length),
                            dtype=precision,
                            device=torch.device("cuda")).requires_grad_(False)
     grads_node = torch.randn_like(inputs_pair)
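
With the --batch-size argument gone and the per-rank pre-sharding (// args.dap_size) dropped, the benchmark now always builds full-size inputs with a fixed batch of 1: an MSA ("node") tensor of shape (1, msa_length, res_length, c_m), a pair tensor of shape (1, res_length, res_length, c_z), and masks of shape (1, msa_length, res_length) and (1, res_length, res_length). The sketch below restates that construction with toy sizes; the variable values are illustrative only.

# Sketch of the benchmark's input construction with a hardcoded batch of 1.
# Sizes are toy values; dtype and device mirror the diff.
import torch

batch_size, msa_length, res_length, c_m, c_z = 1, 32, 64, 256, 128
precision = torch.bfloat16
device = torch.device("cuda")

inputs_node = torch.randn(batch_size, msa_length, res_length, c_m,
                          dtype=precision, device=device).requires_grad_(True)
inputs_pair = torch.randn(batch_size, res_length, res_length, c_z,
                          dtype=precision, device=device).requires_grad_(True)
node_mask = torch.ones((batch_size, msa_length, res_length), dtype=precision, device=device)
pair_mask = torch.ones((batch_size, res_length, res_length), dtype=precision, device=device)

assert inputs_node.shape == (1, msa_length, res_length, c_m)
assert inputs_pair.shape == (1, res_length, res_length, c_z)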
@@ -129,6 +131,13 @@ def main():
                                       with_stack=False)
         prof.start()

+    if not args.openfold:
+        inputs_node = inputs_node.squeeze(0)
+        inputs_pair = inputs_pair.squeeze(0)
+        node_mask = node_mask.squeeze(0)
+        pair_mask = pair_mask.squeeze(0)
+        grads_node = grads_node.squeeze(0)
+
     for trial in range(0, args.trials + args.warmup_trials):
         layer_inputs = inputs_node, inputs_pair
         evt_idx = trial - args.warmup_trials
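
This hunk suggests the updated Evoformer operates on unbatched tensors, while the OpenFold reference keeps the leading batch dimension, so the benchmark drops dim 0 only on the non-OpenFold path. A hedged helper sketch of the same idea is below; the function name is hypothetical and the batch-size-1 assumption comes from the diff.

# Hypothetical helper: drop the leading batch dimension only for the FastFold path,
# as the diff does with repeated .squeeze(0) calls. Assumes batch size 1.
import torch
from typing import Iterable, List

def maybe_drop_batch_dim(tensors: Iterable[torch.Tensor], use_openfold: bool) -> List[torch.Tensor]:
    if use_openfold:
        return list(tensors)                    # OpenFold path keeps shapes like (1, s, r, c_m)
    return [t.squeeze(0) for t in tensors]      # FastFold path gets (s, r, c_m), (r, r, c_z), etc.

# Usage mirroring the diff:
# inputs_node, inputs_pair, node_mask, pair_mask, grads_node = maybe_drop_batch_dim(
#     [inputs_node, inputs_pair, node_mask, pair_mask, grads_node], args.openfold)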
@@ -168,7 +177,7 @@ def main():
         elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])

     print("[ MSA Attn ] Input: {:4d}, {:4d}, {:4d}, ({:4d} {:4d}) Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms".format(
-        args.batch_size, args.msa_length, args.res_length, \
+        batch_size, args.msa_length, args.res_length, \
         args.cm, args.cz, \
         elapsed_time_fwd / ( args.trials * args.layers ), \
         elapsed_time_bwd / ( args.trials * args.layers )))
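
The reported numbers come from per-trial CUDA events: warmup trials are excluded, each timed trial records start/stop events around the forward (and backward) pass, and the accumulated milliseconds are divided by trials × layers to give a per-layer figure. A self-contained sketch of that pattern is below; `layers`, `run_forward`, and the default trial counts are stand-ins rather than the benchmark's actual values.

# Sketch of the CUDA-event timing pattern used by the benchmark: warmup trials
# are skipped, elapsed_time() (milliseconds) is accumulated over timed trials,
# then averaged per layer. `layers` and `run_forward` are stand-ins.
import torch

def benchmark(layers, run_forward, trials: int = 25, warmup_trials: int = 5) -> float:
    start_evts = [torch.cuda.Event(enable_timing=True) for _ in range(trials)]
    stop_evts = [torch.cuda.Event(enable_timing=True) for _ in range(trials)]

    for trial in range(trials + warmup_trials):
        evt_idx = trial - warmup_trials          # negative during warmup trials
        if evt_idx >= 0:
            start_evts[evt_idx].record()
        for layer in layers:
            run_forward(layer)                   # one forward pass per layer
        if evt_idx >= 0:
            stop_evts[evt_idx].record()

    torch.cuda.synchronize()
    elapsed_ms = sum(s.elapsed_time(e) for s, e in zip(start_evts, stop_evts))
    return elapsed_ms / (trials * len(layers))   # average time per layer in ms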