Unverified Commit 90096804 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Add slurm support for distributed training (#508)

* implement _init_dist_slurm()

* add slurm train/test scripts

* fix linting error

* minor fix
parent 53c647ea
import logging import logging
import os import os
import random import random
import subprocess
import numpy as np import numpy as np
import torch import torch
...@@ -34,8 +35,19 @@ def _init_dist_mpi(backend, **kwargs): ...@@ -34,8 +35,19 @@ def _init_dist_mpi(backend, **kwargs):
raise NotImplementedError raise NotImplementedError
def _init_dist_slurm(backend, **kwargs): def _init_dist_slurm(backend, port=29500, **kwargs):
raise NotImplementedError proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(proc_id % num_gpus)
addr = subprocess.getoutput(
'scontrol show hostname {} | head -n1'.format(node_list))
os.environ['MASTER_PORT'] = str(port)
os.environ['MASTER_ADDR'] = addr
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['RANK'] = str(proc_id)
dist.init_process_group(backend=backend)
def set_random_seed(seed): def set_random_seed(seed):
......
#!/usr/bin/env bash
set -x
PARTITION=$1
JOB_NAME=$2
CONFIG=$3
CHECKPOINT=$4
GPUS=${GPUS:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-32}
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS} \
--ntasks=1 \
--ntasks-per-node=1 \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
${SRUN_ARGS} \
python tools/test.py ${CONFIG} ${CHECKPOINT} --gpus ${GPUS} ${PY_ARGS}
#!/usr/bin/env bash
set -x
PARTITION=$1
JOB_NAME=$2
CONFIG=$3
WORK_DIR=$4
GPUS=${5:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
PY_ARGS=${PY_ARGS:-"--validate"}
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS_PER_NODE} \
--ntasks=${GPUS} \
--ntasks-per-node=${GPUS_PER_NODE} \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
${SRUN_ARGS} \
python -u tools/train.py ${CONFIG} --work_dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment