Unverified Commit 7bbdfd7b authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

Fix accelerate logger bug (#23650)



* fix logger bug

* Update tests/mixed_int8/test_mixed_int8.py
Co-authored-by: default avatarZachary Mueller <muellerzr@gmail.com>

* import `PartialState`

---------
Co-authored-by: default avatarZachary Mueller <muellerzr@gmail.com>
parent 29294b0e
...@@ -29,6 +29,7 @@ from transformers import ( ...@@ -29,6 +29,7 @@ from transformers import (
pipeline, pipeline,
) )
from transformers.testing_utils import ( from transformers.testing_utils import (
is_accelerate_available,
is_torch_available, is_torch_available,
require_accelerate, require_accelerate,
require_bitsandbytes, require_bitsandbytes,
...@@ -40,6 +41,13 @@ from transformers.testing_utils import ( ...@@ -40,6 +41,13 @@ from transformers.testing_utils import (
from transformers.utils.versions import importlib_metadata from transformers.utils.versions import importlib_metadata
if is_accelerate_available():
from accelerate import PartialState
from accelerate.logging import get_logger
logger = get_logger(__name__)
_ = PartialState()
if is_torch_available(): if is_torch_available():
import torch import torch
import torch.nn as nn import torch.nn as nn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment