Unverified Commit 3e8106d2 authored by Gilad Turok, committed by GitHub

Docs: fix GaLore optimizer code example (#32249)

Fix incorrect usage of the GaLore optimizer in the Transformers Trainer code examples.

The GaLore optimizer uses low-rank gradient updates to reduce optimizer memory usage. GaLore is widely used, and the authors' reference implementation is available at https://github.com/jiaweizzhao/GaLore. It was added to the Hugging Face Transformers library a few months ago in https://github.com/huggingface/transformers/pull/29588.
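As a rough illustration of the idea (a conceptual sketch only, not the library's actual implementation; the function name `galore_style_step` and the plain SVD-based projection are assumptions for illustration), GaLore projects each weight matrix's gradient into a low-rank subspace, keeps optimizer state there, and projects the resulting update back:

```python
import torch

def galore_style_step(grad: torch.Tensor, rank: int = 8) -> torch.Tensor:
    # Hypothetical sketch: project the gradient onto its top-`rank`
    # left singular vectors, so optimizer state is (r, n) not (m, n).
    U, _, _ = torch.linalg.svd(grad, full_matrices=False)
    P = U[:, :rank]                 # (m, r) projection matrix
    low_rank_grad = P.T @ grad      # (r, n): moment estimates would live here
    # ... an Adam-style moment update would be applied to `low_rank_grad` ...
    return P @ low_rank_grad        # project the update back to (m, n)

g = torch.randn(128, 256)
print(galore_style_step(g).shape)   # torch.Size([128, 256])
```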

The Trainer documentation includes several code examples showing how to use GaLore. However, the `optim_target_modules` argument passed to `TrainingArguments` is incorrect in those examples: it needs regex patterns that actually match the model's module names, not the bare strings `"attn"` and `"mlp"`, as discussed in https://github.com/huggingface/transformers/pull/29588#issuecomment-2006289512. This pull request fixes the examples.
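For reference, here is a minimal end-to-end sketch of the corrected usage, following the docs' own Gemma-2B example (the `output_dir` value and the `trl.SFTTrainer` scaffolding with `dataset_text_field`/`max_seq_length` mirror the docs example as of this PR and should be treated as assumptions, not part of the diff below):

```python
import datasets
import trl
from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer

train_dataset = datasets.load_dataset("imdb", split="train")

args = TrainingArguments(
    output_dir="./test-galore",      # assumed from the docs example
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    # The fix: regex patterns that match module names containing
    # "attn" or "mlp", instead of the bare strings "attn"/"mlp".
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
)

model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=512,
)
trainer.train()
```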
parent f0bc49e7
```diff
@@ -278,7 +278,7 @@ args = TrainingArguments(
     max_steps=100,
     per_device_train_batch_size=2,
     optim="galore_adamw",
-    optim_target_modules=["attn", "mlp"]
+    optim_target_modules=[r".*.attn.*", r".*.mlp.*"]
 )
 model_id = "google/gemma-2b"
@@ -315,7 +315,7 @@ args = TrainingArguments(
     max_steps=100,
     per_device_train_batch_size=2,
     optim="galore_adamw",
-    optim_target_modules=["attn", "mlp"],
+    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
     optim_args="rank=64, update_proj_gap=100, scale=0.10",
 )
@@ -359,7 +359,7 @@ args = TrainingArguments(
     max_steps=100,
     per_device_train_batch_size=2,
     optim="galore_adamw_layerwise",
-    optim_target_modules=["attn", "mlp"]
+    optim_target_modules=[r".*.attn.*", r".*.mlp.*"]
 )
 model_id = "google/gemma-2b"
```