Unverified Commit e6ff7528 authored by Mengqing Cao's avatar Mengqing Cao Committed by GitHub
Browse files

Add npu support (#7144)

* Add npu support

* fix for code quality check

* fix for code quality check
parent 3f9c746f
...@@ -11,6 +11,7 @@ from ..utils import ( ...@@ -11,6 +11,7 @@ from ..utils import (
is_note_seq_available, is_note_seq_available,
is_onnx_available, is_onnx_available,
is_torch_available, is_torch_available,
is_torch_npu_available,
is_transformers_available, is_transformers_available,
) )
......
...@@ -53,12 +53,19 @@ from ..utils import ( ...@@ -53,12 +53,19 @@ from ..utils import (
deprecate, deprecate,
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
is_torch_npu_available,
is_torch_version, is_torch_version,
logging, logging,
numpy_to_pil, numpy_to_pil,
) )
from ..utils.hub_utils import load_or_create_model_card, populate_model_card from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from ..utils.torch_utils import is_compiled_module from ..utils.torch_utils import is_compiled_module
# Importing `torch_npu` is a side-effect import: it registers Huawei Ascend's
# NPU backend with PyTorch (hence the `noqa: F401` — the name itself is unused).
if is_torch_npu_available():
    import torch_npu  # noqa: F401
from .pipeline_loading_utils import ( from .pipeline_loading_utils import (
ALL_IMPORTABLE_CLASSES, ALL_IMPORTABLE_CLASSES,
CONNECTED_PIPES_KEYS, CONNECTED_PIPES_KEYS,
......
...@@ -12,6 +12,7 @@ from .utils import ( ...@@ -12,6 +12,7 @@ from .utils import (
convert_state_dict_to_peft, convert_state_dict_to_peft,
deprecate, deprecate,
is_peft_available, is_peft_available,
is_torch_npu_available,
is_torchvision_available, is_torchvision_available,
is_transformers_available, is_transformers_available,
) )
...@@ -26,6 +27,9 @@ if is_peft_available(): ...@@ -26,6 +27,9 @@ if is_peft_available():
if is_torchvision_available(): if is_torchvision_available():
from torchvision import transforms from torchvision import transforms
# Side-effect import: loading `torch_npu` makes the Ascend NPU backend
# available to PyTorch (the module name is intentionally unused — noqa: F401).
if is_torch_npu_available():
    import torch_npu  # noqa: F401
def set_seed(seed: int):
    """
    Seed every random number generator used during training for reproducibility.

    Seeds Python's `random`, NumPy, and PyTorch (CPU plus all NPU or CUDA
    devices, whichever accelerator is present).

    Args:
        seed (`int`): The seed value to apply to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if is_torch_npu_available():
        torch.npu.manual_seed_all(seed)
    else:
        torch.cuda.manual_seed_all(seed)
        # ^^ safe to call this function even if cuda is not available
def compute_snr(noise_scheduler, timesteps): def compute_snr(noise_scheduler, timesteps):
......
...@@ -72,6 +72,7 @@ from .import_utils import ( ...@@ -72,6 +72,7 @@ from .import_utils import (
is_scipy_available, is_scipy_available,
is_tensorboard_available, is_tensorboard_available,
is_torch_available, is_torch_available,
is_torch_npu_available,
is_torch_version, is_torch_version,
is_torch_xla_available, is_torch_xla_available,
is_torchsde_available, is_torchsde_available,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
""" """
Import utilities: Utilities related to imports and our lazy inits. Import utilities: Utilities related to imports and our lazy inits.
""" """
import importlib.util import importlib.util
import operator as op import operator as op
import os import os
...@@ -72,6 +73,15 @@ if _torch_xla_available: ...@@ -72,6 +73,15 @@ if _torch_xla_available:
except ImportError: except ImportError:
_torch_xla_available = False _torch_xla_available = False
# check whether torch_npu is available
# (`torch_npu` is the Ascend NPU extension for PyTorch; finding its module
# spec is a cheap availability probe that does not import the package)
_torch_npu_available = importlib.util.find_spec("torch_npu") is not None
if _torch_npu_available:
    try:
        # Confirm an actual installed distribution and record its version.
        # `importlib_metadata.version` raises PackageNotFoundError (an
        # ImportError subclass) when only a stale spec exists, in which case
        # we mark the backend unavailable instead of failing at import time.
        _torch_npu_version = importlib_metadata.version("torch_npu")
        logger.info(f"torch_npu version {_torch_npu_version} available.")
    except ImportError:
        _torch_npu_available = False
_jax_version = "N/A" _jax_version = "N/A"
_flax_version = "N/A" _flax_version = "N/A"
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
...@@ -294,6 +304,10 @@ def is_torch_xla_available(): ...@@ -294,6 +304,10 @@ def is_torch_xla_available():
return _torch_xla_available return _torch_xla_available
def is_torch_npu_available():
    """Return `True` if the `torch_npu` (Ascend NPU) package was detected at import time."""
    return _torch_npu_available
def is_flax_available(): def is_flax_available():
return _flax_available return _flax_available
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment