"...git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "3fdecc873991c7696a55eb518225b2bd85cbaac2"
Commit c4cccb6b authored by Jennifer's avatar Jennifer
Browse files

adds deepspeed_evoformer flag to inference script.

parent 19525826
......@@ -62,7 +62,8 @@ def model_config(
name,
train=False,
low_prec=False,
long_sequence_inference=False
long_sequence_inference=False,
use_deepspeed_evoformer_attention=False,
):
c = copy.deepcopy(config)
# TRAINING PRESETS
......@@ -237,6 +238,9 @@ def model_config(
c.model.extra_msa.extra_msa_stack.tune_chunk_size = False
c.model.evoformer_stack.tune_chunk_size = False
if use_deepspeed_evoformer_attention:
c.globals.use_deepspeed_evo_attention = True
if train:
c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None
......
......@@ -179,7 +179,11 @@ def main(args):
if args.config_preset.startswith("seq"):
args.use_single_seq_mode = True
config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
config = model_config(
args.config_preset,
long_sequence_inference=args.long_sequence_inference,
use_deepspeed_evoformer_attention=args.use_deepspeed_evoformer_attention,
)
if args.experiment_config_json:
with open(args.experiment_config_json, 'r') as f:
......@@ -462,6 +466,10 @@ if __name__ == "__main__":
parser.add_argument(
"--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
)
parser.add_argument(
"--use_deepspeed_evoformer_attention", action="store_true", default=False,
help="Whether to use the DeepSpeed evoformer attention layer. Must have deepspeed installed in the environment.",
)
add_data_args(parser)
args = parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment