Unverified Commit 793362fb authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #3 from aqlaboratory/marta-sd/slurm

Manually specify cluster_environment for DeepSpeed if running with SLURM
parents 4b1c4488 cb4eecb0
......@@ -14,6 +14,7 @@ import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin
from pytorch_lightning.plugins.environments import SLURMEnvironment
import torch
from openfold.config import model_config
......@@ -158,7 +159,14 @@ def main(args):
callbacks.append(perf)
if(args.deepspeed_config_path is not None):
strategy = DeepSpeedPlugin(config=args.deepspeed_config_path)
if "SLURM_JOB_ID" in os.environ:
cluster_environment = SLURMEnvironment()
else:
cluster_environment = None
strategy = DeepSpeedPlugin(
config=args.deepspeed_config_path,
cluster_environment=cluster_environment,
)
elif args.gpus > 1 or args.num_nodes > 1:
strategy = "ddp"
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment