Unverified Commit 0e6ec2a4 authored by jianan-gu, committed by GitHub

Extend Transformers Trainer Class to Enable PyTorch SGD/Adagrad Optimizers for Training (#17154)



* add torch SGD and Adagrad optimizer bits

* refine naming
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 63517fdf
src/transformers/trainer.py
@@ -978,6 +978,10 @@ class Trainer:
                 optimizer_kwargs.update(adam_kwargs)
             except ImportError:
                 raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
+        elif args.optim == OptimizerNames.SGD:
+            optimizer_cls = torch.optim.SGD
+        elif args.optim == OptimizerNames.ADAGRAD:
+            optimizer_cls = torch.optim.Adagrad
         else:
             raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
         return optimizer_cls, optimizer_kwargs
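
For context, a minimal usage sketch of what this hunk enables, assuming the public `optim` training argument and the `Trainer.get_optimizer_cls_and_kwargs` static method that the hunk modifies; the string values match the OptimizerNames entries added in the next hunk:

    # Hypothetical usage sketch, not part of this commit.
    from transformers import Trainer, TrainingArguments

    args = TrainingArguments(output_dir="out", optim="sgd", learning_rate=1e-2)
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
    # optimizer_cls should now be torch.optim.SGD, with optimizer_kwargs
    # carrying the learning rate, e.g. {"lr": 1e-2}.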
src/transformers/training_args.py
@@ -87,6 +87,8 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
     ADAMW_BNB = "adamw_bnb_8bit"
+    SGD = "sgd"
+    ADAGRAD = "adagrad"
 
 
 @dataclass