# train-vicuna.yaml
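#
# SkyPilot task that fine-tunes LLaMA into a Vicuna-style chatbot with FastChat
# on 8x A100-80GB spot GPUs, converting the raw LLaMA weights to the Hugging Face
# format and syncing checkpoints to a GCS bucket.
#
# Launch example (a sketch; assumes the SkyPilot CLI is installed and the buckets
# under file_mounts exist and are accessible; the cluster name is arbitrary):
#   sky launch -c vicuna train-vicuna.yaml \
#     --env MODEL_SIZE=13 --env SEQ_LEN=2048 --env GC_SCALE=4 --env USE_FLASH_ATTN=1
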
resources:
  accelerators: A100-80GB:8
  disk_size: 1000
  use_spot: true

num_nodes: 1

file_mounts:
  /artifacts:
    name: skypilot-chatbot # Change to your own bucket
    store: gcs
    mode: MOUNT
  /data:
    name: model-weights # Change to your own bucket
    store: gcs
    mode: MOUNT
  # Uncomment to mount the raw LLaMA weights; only needed if the converted
  # weights are not yet cached under /artifacts/llama-hf.
  # /llama:
  #   name: llama-ckpts # Change to the bucket that contains the raw LLaMA weights
  #   store: gcs
  #   mode: MOUNT

workdir: .

setup: |
  # Set up the environment
  conda create -n chatbot python=3.10 -y
  conda activate chatbot

  # Install PyTorch built against CUDA 11.6
  pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116

  # Install Hugging Face Transformers, pinned to a commit with LLaMA support
  cd ~
  git clone https://github.com/huggingface/transformers.git
  cd transformers
  git checkout 41a2f3529c6b56866c317031375ffd3e7b8bea01
  pip install .
  cd ~/sky_workdir

  # Install FastChat from the synced workdir, plus FlashAttention
  pip install -e .
  pip install flash-attn

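  # Convert the raw LLaMA weights to the Hugging Face format once and cache the
  # result in the bucket; later runs reuse the cached copy.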
  mkdir -p /artifacts/llama-hf/llama-${MODEL_SIZE}B
  if [ ! -f /artifacts/llama-hf/llama-${MODEL_SIZE}B/complete ]; then
    mkdir -p ~/llama-${MODEL_SIZE}b
    gsutil -m rsync -r /llama/${MODEL_SIZE}b/ ~/llama-${MODEL_SIZE}b
    cd ~/transformers
    python src/transformers/models/llama/convert_llama_weights_to_hf.py \
      --input_dir $HOME/llama-${MODEL_SIZE}b \
      --model_size ${MODEL_SIZE}B \
      --output_dir ~/hf-output || exit 1
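    # The conversion script writes the tokenizer to a separate directory;
    # merge it into the model directory.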
    mv ~/hf-output/tokenizer/* ~/hf-output/llama-${MODEL_SIZE}b
    gsutil -m rsync -r ~/hf-output/llama-${MODEL_SIZE}b/ /artifacts/llama-hf/llama-${MODEL_SIZE}B
    touch /artifacts/llama-hf/llama-${MODEL_SIZE}B/complete
  else
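    # The converted weights are already cached in the bucket; download them locally.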
    mkdir -p ~/hf-output/llama-${MODEL_SIZE}b
    gsutil -m cp -r /artifacts/llama-hf/llama-${MODEL_SIZE}B/* ~/hf-output/llama-${MODEL_SIZE}b
  fi

run: |
  conda activate chatbot
  SEQ_LEN=${SEQ_LEN:-512}
  GC_SCALE=${GC_SCALE:-1}
  DATE=${DATE:-20230303}
  USE_FLASH_ATTN=${USE_FLASH_ATTN:-0}
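  # When USE_FLASH_ATTN=1, use FastChat's FlashAttention-enabled training script.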
  if [ $USE_FLASH_ATTN -eq 1 ]; then
    TRAIN_SCRIPT=fastchat/train/train_mem.py
    USE_FLASH_SUFFIX="-flash"
  else
    TRAIN_SCRIPT=fastchat/train/train.py
    USE_FLASH_SUFFIX=""
  fi
  echo "Training with seq_len=${SEQ_LEN} and gc_scale=${GC_SCALE}"
  PER_DEVICE_BATCH_SIZE=$((2048 * $GC_SCALE / $SEQ_LEN))
  NUM_NODES=`echo "$SKYPILOT_NODE_IPS" | wc -l`
  HOST_ADDR=`echo "$SKYPILOT_NODE_IPS" | head -n1`

  # Save checkpoints locally and sync them to the bucket in the background, so
  # that checkpoint writes do not slow down training.
  mkdir -p ~/.checkpoints
  LOCAL_CKPT_PATH=~/.checkpoints
  CKPT_PATH=/artifacts/chatbot/${MODEL_SIZE}b/sharegpt-${DATE}-seq-${SEQ_LEN}${USE_FLASH_SUFFIX}
  last_ckpt=$(ls ${CKPT_PATH} 2> /dev/null | grep -E '[0-9]+' | sort -t'-' -k1,1 -k2,2n | tail -1)
  if [ -n "${last_ckpt}" ]; then
    # Resume from the latest checkpoint in the bucket, if one exists.
    mkdir -p ~/.checkpoints/${last_ckpt}
    gsutil -m rsync -r ${CKPT_PATH}/${last_ckpt}/ ~/.checkpoints/${last_ckpt}
  fi

  bash scripts/sync_local_checkpoint.sh ${LOCAL_CKPT_PATH} ${CKPT_PATH} > sync.log 2>&1 &
  
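  # gradient_accumulation_steps below is derived so that each optimizer step
  # covers roughly 128 * 512 = 65,536 tokens across all GPUs.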
  torchrun \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
    --master_port=12375 \
    --master_addr=$HOST_ADDR \
    --node_rank=${SKYPILOT_NODE_RANK} \
    $TRAIN_SCRIPT \
    --model_name_or_path ~/hf-output/llama-${MODEL_SIZE}b \
    --data_path /data/sharegpt/sharegpt_20230322_clean_lang_split.json \
    --bf16 True \
    --output_dir $LOCAL_CKPT_PATH \
    --num_train_epochs 3 \
    --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
    --per_device_eval_batch_size $PER_DEVICE_BATCH_SIZE \
    --gradient_accumulation_steps $((128 * 512 / $SEQ_LEN / $PER_DEVICE_BATCH_SIZE / $NUM_NODES / $SKYPILOT_NUM_GPUS_PER_NODE)) \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1200 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length ${SEQ_LEN} \
    --gradient_checkpointing True \
    --lazy_preprocess True

  # After training, sync the final model files (everything outside the
  # checkpoint-* folders) to the bucket
  gsutil -m rsync -r -x 'checkpoint-*' $LOCAL_CKPT_PATH/ $CKPT_PATH/


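# Default hyperparameters; can be overridden at launch time (e.g., with SkyPilot's --env flag).
#   MODEL_SIZE: LLaMA size in billions of parameters (e.g., 7 or 13).
#   SEQ_LEN: maximum sequence length (model_max_length) used for training.
#   GC_SCALE: multiplier for the per-device batch size.
#   DATE: data snapshot tag, used in the checkpoint directory name.
#   USE_FLASH_ATTN: 1 to train with the FlashAttention script (train_mem.py).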
envs:
  MODEL_SIZE: 13
  SEQ_LEN: 2048
  GC_SCALE: 4
  DATE: 20230322
  USE_FLASH_ATTN: 1