"tests/benchmarks/micro_benchmarks/test_dist_inference.py" did not exist on "a9b45a072e6a1f4ec2bab2eb9cbdd575bb394037"
Unverified Commit 9c748527 authored by guoshzhao's avatar guoshzhao Committed by GitHub
Browse files

Benchmarks: Code Revision - Replace torch.optim.AdamW with transformers.AdamW. (#106)

* replace torch.optim.AdamW with transformers.AdamW.
parent 05e449a3
......@@ -6,6 +6,7 @@
import os
import torch
import transformers
from torch.utils.data import DataLoader
from superbench.common.utils import logger
......@@ -139,7 +140,7 @@ def _create_optimizer(self):
elif self._optimizer_type == Optimizer.ADAM:
self._optimizer = torch.optim.Adam(self._model.parameters(), lr=1e-5, betas=(0.9, 0.999), eps=1e-08)
elif self._optimizer_type == Optimizer.ADAMW:
self._optimizer = torch.optim.AdamW(self._model.parameters(), lr=1e-5, betas=(0.9, 0.999), eps=1e-08)
self._optimizer = transformers.AdamW(self._model.parameters(), lr=1e-5, betas=(0.9, 0.999), eps=1e-08)
else:
self._optimizer = None
......
......@@ -7,6 +7,7 @@
import numbers
import torch
import transformers
from tests.helper import decorator
from superbench.common.utils import logger
......@@ -220,7 +221,7 @@ def test_pytorch_base():
assert (benchmark._init_dataloader() is False)
# Test _create_optimizer().
assert (isinstance(benchmark._optimizer, torch.optim.AdamW))
assert (isinstance(benchmark._optimizer, transformers.AdamW))
benchmark._optimizer_type = Optimizer.ADAM
assert (benchmark._create_optimizer() is True)
assert (isinstance(benchmark._optimizer, torch.optim.Adam))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment