#!/bin/bash
# Multi-node launch script for GNMT fp32 training under Slurm.
source "$(pwd)/config_DGX1_multi.sh"
set -e

# start timing
start=$(date +%s)
start_fmt=$(date +%Y-%m-%d\ %r)
echo "STARTING TIMING RUN AT $start_fmt"
export NCCL_DEBUG=${NCCL_DEBUG:-"WARN"}

# run benchmark
set -x
DATASET_DIR='../wmt16_de_en/'
PREPROC_DATADIR='./preproc_data'
RESULTS_DIR='gnmt_wmt16'
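# Paths are relative to the launch directory; a filesystem shared across nodes
# is assumed, since every remote shell below cd's back into this directory.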
## DL params
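# Every hyperparameter below honors an environment override, e.g. a
# hypothetical invocation: LR=1.0e-3 NUMEPOCHS=8 sbatch run_fp32_multi.sh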
export LR=${LR:-"2.0e-3"}
export TRAIN_BATCH_SIZE=${TRAIN_BATCH_SIZE:-64}
export TEST_BATCH_SIZE=${TEST_BATCH_SIZE:-64}
export WARMUP_STEPS=${WARMUP_STEPS:-200}
export REMAIN_STEPS=${REMAIN_STEPS:-6453}
export DECAY_INTERVAL=${DECAY_INTERVAL:-809}
export TARGET=${TARGET:-24.0}
export MAX_SEQ_LEN=${MAX_SEQ_LEN:-75}
export NUMEPOCHS=${NUMEPOCHS:-20}
export MATH=${MATH:-fp32}
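# ${VAR-default} (no colon) below differs from ${VAR:-default}: an explicitly
# empty DIST_OPTS="" or EXTRA_OPTS="" disables those option groups entirely,
# while leaving the variable unset keeps the defaults.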
export DIST_OPTS=${DIST_OPTS-"\
   --distributed-weight-update 2 \
   --dwu-num-blocks 1 \
   --dwu-num-chunks 2 \
   --dwu-num-rs-pg 2 \
   --dwu-num-ar-pg 2 \
   --dwu-num-ag-pg 0 \
   --dwu-grad-norm \
   "}
export EXTRA_OPTS=${EXTRA_OPTS-"\
   --fused-attention \
   --fused-xentropy \
   --no-log-all-ranks \
   "}

echo "running benchmark"

CMD_ARGS=(
    --save "${RESULTS_DIR}"
    --dataset-dir "${DATASET_DIR}"
    --preproc-data-dir "${PREPROC_DATADIR}/${MAX_SEQ_LEN}"
    --target-bleu "${TARGET}"
    --epochs "${NUMEPOCHS}"
    --math "${MATH}"
    --max-length-train "${MAX_SEQ_LEN}"
    --print-freq 10
    --train-batch-size "${TRAIN_BATCH_SIZE}"
    --test-batch-size "${TEST_BATCH_SIZE}"
    --optimizer FusedAdam
    --lr "${LR}"
    --warmup-steps "${WARMUP_STEPS}"
    --remain-steps "${REMAIN_STEPS}"
    --decay-interval "${DECAY_INTERVAL}"
)
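# The array is interpolated into the remote command string as ${CMD_ARGS[*]},
# so every element must remain a single shell word on the remote side.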

# Build the node list for this Slurm job; the first node doubles as the
# torch.distributed rendezvous master.
hostfile=./${SLURM_JOB_ID}
scontrol show hostnames "${SLURM_JOB_NODELIST}" > "${hostfile}"
cp "${hostfile}" ./tmp
dist_url=$(sed -n '1p' ./tmp)
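# ./tmp now holds one hostname per line, e.g. (hypothetical node names):
#   node001
#   node002
# dist_url is the first entry and serves as the rendezvous master address.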

num_nodes=$(wc -l < ./tmp)
# Launch ranks 0 .. num_nodes-2 in the background, one per node. Each remote
# shell swaps in the ROCm PyTorch environment, then bind_launch pins
# SLURM_NTASKS_PER_NODE training processes to sockets/cores on that node.
for ((i = 0; i < num_nodes - 1; i++)); do
    nodename=$(sed -n "$((i + 1))p" ./tmp)
    ssh "${nodename}" "cd $(pwd) && module rm compiler/rocm/2.9 && source /public/home/aiss/Pytorch/env_torch1.5_miopen24.sh && python3 -m bind_launch --nnodes=${num_nodes} --node_rank=${i} --master_addr=${dist_url} --master_port=4567 --nsockets_per_node ${DGXNSOCKET} --ncores_per_socket ${DGXSOCKETCORES} --nproc_per_node ${SLURM_NTASKS_PER_NODE} --no_hyperthreads $(pwd)/train.py ${CMD_ARGS[*]}" &
done
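# The background launches are intentionally not wait-ed on; the last node runs
# in the foreground below and keeps the script alive until training finishes.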

i=$((num_nodes - 1))
nodename=$(sed -n "${num_nodes}p" ./tmp)
ssh "${nodename}" "cd $(pwd) && module rm compiler/rocm/2.9 && source /public/home/aiss/Pytorch/env_torch1.5_miopen24.sh && python3 -m bind_launch --nnodes=${num_nodes} --node_rank=${i} --master_addr=${dist_url} --master_port=4567 --nsockets_per_node ${DGXNSOCKET} --ncores_per_socket ${DGXSOCKETCORES} --nproc_per_node ${SLURM_NTASKS_PER_NODE} --no_hyperthreads $(pwd)/train.py ${CMD_ARGS[*]}"

set +x

# Presumably gives lingering background ssh sessions a moment to flush logs.
sleep 3

# end timing
end=$(date +%s)
end_fmt=$(date +%Y-%m-%d\ %r)
echo "ENDING TIMING RUN AT $end_fmt"

# report result
result=$(( end - start ))
result_name="RNN_TRANSLATOR"

echo "RESULT,$result_name,,$result,nvidia,$start_fmt"