"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "0a2fecdf90081102641d6b08795ddc0333443b32"
Unverified commit 4fdf47cd authored by statelesshz, committed by GitHub

Extend Trainer to enable Ascend NPU to use the fused AdamW optimizer when training (#26194)

parent fc296f41
@@ -1068,6 +1068,14 @@ class Trainer:
                 optimizer_kwargs.update(adam_kwargs)
             except ImportError:
                 raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
+        elif args.optim == OptimizerNames.ADAMW_TORCH_NPU_FUSED:
+            try:
+                from torch_npu.optim import NpuFusedAdamW
+
+                optimizer_cls = NpuFusedAdamW
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer failed to import FusedAdamW from torch_npu.")
         elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
             try:
                 from apex.optimizers import FusedAdam
...
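For context, this branch lives in Trainer.get_optimizer_cls_and_kwargs, which maps args.optim to an optimizer class plus its constructor kwargs; the same adam_kwargs (betas, eps, plus the learning rate) are handed to every AdamW variant. A minimal sketch of what the new branch resolves to, assuming torch_npu is installed and an Ascend NPU is visible; the toy model and hyperparameter values below are illustrative, not part of this commit:

import torch
from torch_npu.optim import NpuFusedAdamW  # importing torch_npu registers the "npu" device

# Toy model on the NPU; NpuFusedAdamW accepts the same kwargs the Trainer
# collects in adam_kwargs (lr, betas, eps).
model = torch.nn.Linear(4, 2).to("npu")
optimizer = NpuFusedAdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.999), eps=1e-8)

loss = model(torch.randn(8, 4, device="npu")).sum()
loss.backward()
optimizer.step()  # fused parameter update on the Ascend NPU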
@@ -140,6 +140,7 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_TORCH = "adamw_torch"
     ADAMW_TORCH_FUSED = "adamw_torch_fused"
     ADAMW_TORCH_XLA = "adamw_torch_xla"
+    ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
     ADAMW_ANYPRECISION = "adamw_anyprecision"
...
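With the enum value registered, users opt in through the usual optim training argument; no other changes are required. A hedged sketch of exercising the new mapping directly through Trainer.get_optimizer_cls_and_kwargs (a public static method of Trainer): it succeeds only where torch_npu is importable, and otherwise raises the ValueError from the branch above.

from transformers import Trainer, TrainingArguments

args = TrainingArguments(output_dir="tmp", optim="adamw_torch_npu_fused")

# Resolves to (NpuFusedAdamW, {"lr": ..., "betas": ..., "eps": ...}) when
# torch_npu is installed; otherwise raises
# ValueError("Trainer failed to import FusedAdamW from torch_npu.")
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
print(optimizer_cls.__name__, optimizer_kwargs)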