"projects/Task001_Decathlon/README.md" did not exist on "4116e6adfb84c918033a51a0d0e9ee39eadb58d3"
Unverified Commit 9c748527 authored by guoshzhao's avatar guoshzhao Committed by GitHub
Browse files

Benchmarks: Code Revision - Replace torch.optim.AdamW with transformers.AdamW. (#106)

* replace torch.optim.AdamW with transformers.AdamW.
parent 05e449a3
@@ -6,6 +6,7 @@

 import os
 import torch
+import transformers
 from torch.utils.data import DataLoader

 from superbench.common.utils import logger
@@ -139,7 +140,7 @@ def _create_optimizer(self):
         elif self._optimizer_type == Optimizer.ADAM:
             self._optimizer = torch.optim.Adam(self._model.parameters(), lr=1e-5, betas=(0.9, 0.999), eps=1e-08)
         elif self._optimizer_type == Optimizer.ADAMW:
-            self._optimizer = torch.optim.AdamW(self._model.parameters(), lr=1e-5, betas=(0.9, 0.999), eps=1e-08)
+            self._optimizer = transformers.AdamW(self._model.parameters(), lr=1e-5, betas=(0.9, 0.999), eps=1e-08)
         else:
             self._optimizer = None
...
@@ -7,6 +7,7 @@

 import numbers
 import torch
+import transformers

 from tests.helper import decorator
 from superbench.common.utils import logger
@@ -220,7 +221,7 @@ def test_pytorch_base():
     assert (benchmark._init_dataloader() is False)

     # Test _create_optimizer().
-    assert (isinstance(benchmark._optimizer, torch.optim.AdamW))
+    assert (isinstance(benchmark._optimizer, transformers.AdamW))
     benchmark._optimizer_type = Optimizer.ADAM
     assert (benchmark._create_optimizer() is True)
     assert (isinstance(benchmark._optimizer, torch.optim.Adam))
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment