Commit c4cccb6b authored by Jennifer
Browse files

Adds a deepspeed_evoformer flag to the inference script.

parent 19525826
...@@ -62,7 +62,8 @@ def model_config( ...@@ -62,7 +62,8 @@ def model_config(
name, name,
train=False, train=False,
low_prec=False, low_prec=False,
long_sequence_inference=False long_sequence_inference=False,
use_deepspeed_evoformer_attention=False,
): ):
c = copy.deepcopy(config) c = copy.deepcopy(config)
# TRAINING PRESETS # TRAINING PRESETS
...@@ -237,6 +238,9 @@ def model_config( ...@@ -237,6 +238,9 @@ def model_config(
c.model.extra_msa.extra_msa_stack.tune_chunk_size = False c.model.extra_msa.extra_msa_stack.tune_chunk_size = False
c.model.evoformer_stack.tune_chunk_size = False c.model.evoformer_stack.tune_chunk_size = False
if use_deepspeed_evoformer_attention:
c.globals.use_deepspeed_evo_attention = True
if train: if train:
c.globals.blocks_per_ckpt = 1 c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None c.globals.chunk_size = None
......
...@@ -179,7 +179,11 @@ def main(args): ...@@ -179,7 +179,11 @@ def main(args):
if args.config_preset.startswith("seq"): if args.config_preset.startswith("seq"):
args.use_single_seq_mode = True args.use_single_seq_mode = True
config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference) config = model_config(
args.config_preset,
long_sequence_inference=args.long_sequence_inference,
use_deepspeed_evoformer_attention=args.use_deepspeed_evoformer_attention,
)
if args.experiment_config_json: if args.experiment_config_json:
with open(args.experiment_config_json, 'r') as f: with open(args.experiment_config_json, 'r') as f:
...@@ -462,6 +466,10 @@ if __name__ == "__main__": ...@@ -462,6 +466,10 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting", "--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
) )
parser.add_argument(
"--use_deepspeed_evoformer_attention", action="store_true", default=False,
help="Whether to use the DeepSpeed evoformer attention layer. Must have deepspeed installed in the environment.",
)
add_data_args(parser) add_data_args(parser)
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment