"vscode:/vscode.git/clone" did not exist on "805c5200dc41aa7ca8dbb851688223df8627b411"
Unverified Commit e0603d89 authored by Sourab Mangrulkar's avatar Sourab Mangrulkar Committed by GitHub
Browse files

docs wrt using accelerate launcher with trainer (#24250)



* update docs

* missing part

* Apply suggestions from code review
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* address comments

* address Zach's comment

---------
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 23311314
......@@ -688,6 +688,156 @@ Finally, please, remember that, 🤗 `Trainer` only integrates MPS backend, ther
have any problems or questions with regards to MPS backend usage, please,
file an issue with [PyTorch GitHub](https://github.com/pytorch/pytorch/issues).
## Using Accelerate Launcher with Trainer
Accelerate now powers Trainer. In terms of what users should expect:
- They can keep using the Trainer ingterations such as FSDP, DeepSpeed vis trainer arguments without any changes on their part.
- They can now use Accelerate Launcher with Trainer (recommended).
Steps to use Accelerate Launcher with Trainer:
1. Make sure 🤗 Accelerate is installed, you can't use the `Trainer` without it anyway. If not `pip install accelerate`. You may also need to update your version of Accelerate: `pip install accelerate --upgrade`
2. Run `accelerate config` and fill the questionnaire. Below are example accelerate configs:
a. DDP Multi-node Multi-GPU config:
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0 #change rank as per the node
main_process_ip: 192.168.20.1
main_process_port: 9898
main_training_function: main
mixed_precision: fp16
num_machines: 2
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
b. FSDP config:
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch_policy: BACKWARD_PRE
fsdp_forward_prefetch: true
fsdp_offload_params: false
fsdp_sharding_strategy: 1
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sync_module_states: true
fsdp_transformer_layer_cls_to_wrap: BertLayer
fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
c. DeepSpeed config pointing to a file:
```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
deepspeed_config_file: /home/user/configs/ds_zero3_config.json
zero3_init_flag: true
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
d. DeepSpeed config using accelerate plugin:
```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
gradient_accumulation_steps: 1
gradient_clipping: 0.7
offload_optimizer_device: cpu
offload_param_device: cpu
zero3_init_flag: true
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
3. Run the Trainer script with args other than the ones handled above by accelerate config or launcher args.
Below is an example to run `run_glue.py` using `accelerate launcher` with FSDP config from above.
```bash
cd transformers
accelerate launch \
./examples/pytorch/text-classification/run_glue.py \
--model_name_or_path bert-base-cased \
--task_name $TASK_NAME \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 16 \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--output_dir /tmp/$TASK_NAME/ \
--overwrite_output_dir
```
4. You can also directly use the cmd args for `accelerate launch`. Above example would map to:
```bash
cd transformers
accelerate launch --num_processes=2 \
--use_fsdp \
--mixed_precision=bf16 \
--fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \
--fsdp_transformer_layer_cls_to_wrap="BertLayer" \
--fsdp_sharding_strategy=1 \
--fsdp_state_dict_type=FULL_STATE_DICT \
./examples/pytorch/text-classification/run_glue.py
--model_name_or_path bert-base-cased \
--task_name $TASK_NAME \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 16 \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--output_dir /tmp/$TASK_NAME/ \
--overwrite_output_dir
```
For more information, please refer the 🤗 Accelerate CLI guide: [Launching your 🤗 Accelerate scripts](https://huggingface.co/docs/accelerate/basic_tutorials/launch).
Sections that were moved:
[ <a href="./deepspeed#deepspeed-trainer-integration">DeepSpeed</a><a id="deepspeed"></a>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment