Unverified Commit 408b2b3c authored by Nathan Lambert's avatar Nathan Lambert Committed by GitHub
Browse files

Add torch `RMSProp` optimizer (#26425)

add rmsprop
parent 6ba63ac3
......@@ -1139,6 +1139,8 @@ class Trainer:
optimizer_cls = torch.optim.SGD
elif args.optim == OptimizerNames.ADAGRAD:
optimizer_cls = torch.optim.Adagrad
elif args.optim == OptimizerNames.RMSPROP:
optimizer_cls = torch.optim.RMSprop
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
......
......@@ -153,6 +153,7 @@ class OptimizerNames(ExplicitEnum):
PAGED_ADAMW_8BIT = "paged_adamw_8bit"
PAGED_LION = "paged_lion_32bit"
PAGED_LION_8BIT = "paged_lion_8bit"
RMSPROP = "rmsprop"
# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment