Commit cb4eecb0 authored by Marta's avatar Marta
Browse files

specify cluster_environment for DeepSpeed if running with SLURM

parent 4b1c4488
......@@ -14,6 +14,7 @@ import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin
from pytorch_lightning.plugins.environments import SLURMEnvironment
import torch
from openfold.config import model_config
......@@ -158,7 +159,14 @@ def main(args):
callbacks.append(perf)
if(args.deepspeed_config_path is not None):
strategy = DeepSpeedPlugin(config=args.deepspeed_config_path)
if "SLURM_JOB_ID" in os.environ:
cluster_environment = SLURMEnvironment()
else:
cluster_environment = None
strategy = DeepSpeedPlugin(
config=args.deepspeed_config_path,
cluster_environment=cluster_environment,
)
elif args.gpus > 1 or args.num_nodes > 1:
strategy = "ddp"
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment