renzhc / diffusers_dcu · Commits · e6ff7528 (unverified)

Commit e6ff7528, authored Mar 09, 2024 by Mengqing Cao, committed by GitHub Mar 08, 2024
Parent: 3f9c746f

Add npu support (#7144)

* Add npu support
* fix for code quality check
* fix for code quality check

Showing 5 changed files with 32 additions and 2 deletions (+32, -2):
src/diffusers/pipelines/__init__.py       +1  -0
src/diffusers/pipelines/pipeline_utils.py +7  -0
src/diffusers/training_utils.py           +9  -2
src/diffusers/utils/__init__.py           +1  -0
src/diffusers/utils/import_utils.py       +14 -0
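Taken together, the five changes teach diffusers to detect Huawei Ascend NPUs through the optional torch_npu plugin and to seed them correctly in the training utilities. A minimal usage sketch of what this enables, assuming torch_npu and an Ascend device are installed; the checkpoint name is only an illustrative example:

    import torch
    from diffusers import DiffusionPipeline
    from diffusers.utils import is_torch_npu_available

    # Prefer the Ascend NPU when the torch_npu plugin is present,
    # otherwise fall back to CUDA or CPU.
    if is_torch_npu_available():
        device = "npu"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # example checkpoint
        torch_dtype=torch.float16 if device != "cpu" else torch.float32,
    )
    pipe = pipe.to(device)

    image = pipe("a photo of an astronaut riding a horse").images[0]
    image.save("astronaut.png")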
src/diffusers/pipelines/__init__.py

@@ -11,6 +11,7 @@ from ..utils import (
     is_note_seq_available,
     is_onnx_available,
     is_torch_available,
+    is_torch_npu_available,
     is_transformers_available,
 )
src/diffusers/pipelines/pipeline_utils.py

@@ -53,12 +53,19 @@ from ..utils import (
     deprecate,
     is_accelerate_available,
     is_accelerate_version,
+    is_torch_npu_available,
     is_torch_version,
     logging,
     numpy_to_pil,
 )
 from ..utils.hub_utils import load_or_create_model_card, populate_model_card
 from ..utils.torch_utils import is_compiled_module
+
+
+if is_torch_npu_available():
+    import torch_npu  # noqa: F401
+
+
 from .pipeline_loading_utils import (
     ALL_IMPORTABLE_CLASSES,
     CONNECTED_PIPES_KEYS,
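The bare import above is intentional: importing torch_npu registers the Ascend backend with PyTorch as a side effect, exposing the "npu" device and the torch.npu namespace, which is why the line carries a # noqa: F401 marker rather than being used directly. A small sketch of the same guard pattern in user code, assuming torch_npu is installed on the target machine:

    import torch
    from diffusers.utils import is_torch_npu_available

    if is_torch_npu_available():
        import torch_npu  # noqa: F401  # side effect: makes torch.npu and the "npu" device usable

    device = torch.device("npu" if is_torch_npu_available() else "cpu")

    # Tensors and modules can now be placed on the Ascend device just like on CUDA.
    x = torch.randn(2, 3, device=device)
    print(x.device)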
src/diffusers/training_utils.py

@@ -12,6 +12,7 @@ from .utils import (
     convert_state_dict_to_peft,
     deprecate,
     is_peft_available,
+    is_torch_npu_available,
     is_torchvision_available,
     is_transformers_available,
 )

@@ -26,6 +27,9 @@ if is_peft_available():
 if is_torchvision_available():
     from torchvision import transforms
 
+if is_torch_npu_available():
+    import torch_npu  # noqa: F401
+
 
 def set_seed(seed: int):
     """

@@ -36,8 +40,11 @@ def set_seed(seed: int):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    # ^^ safe to call this function even if cuda is not available
+    if is_torch_npu_available():
+        torch.npu.manual_seed_all(seed)
+    else:
+        torch.cuda.manual_seed_all(seed)
+        # ^^ safe to call this function even if cuda is not available
 
 
 def compute_snr(noise_scheduler, timesteps):
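The new branch in set_seed seeds every Ascend NPU instead of the CUDA devices when torch_npu is present. A standalone sketch that mirrors the updated function, with the guarded import included so torch.npu exists when that branch is taken (the function name here is mine, not part of the diff):

    import random

    import numpy as np
    import torch
    from diffusers.utils import is_torch_npu_available

    if is_torch_npu_available():
        import torch_npu  # noqa: F401  # exposes torch.npu

    def set_seed_like_diffusers(seed: int) -> None:
        """Mirror of diffusers.training_utils.set_seed after this commit."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if is_torch_npu_available():
            torch.npu.manual_seed_all(seed)   # seeds all Ascend NPUs
        else:
            torch.cuda.manual_seed_all(seed)  # safe to call even when CUDA is absent

    set_seed_like_diffusers(42)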
src/diffusers/utils/__init__.py

@@ -72,6 +72,7 @@ from .import_utils import (
     is_scipy_available,
     is_tensorboard_available,
     is_torch_available,
+    is_torch_npu_available,
     is_torch_version,
     is_torch_xla_available,
     is_torchsde_available,
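With this re-export in place, callers can query NPU availability from the public diffusers.utils namespace instead of reaching into the internal import_utils module, for example:

    from diffusers.utils import is_torch_npu_available

    print("torch_npu (Ascend NPU support) available:", is_torch_npu_available())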
src/diffusers/utils/import_utils.py

@@ -14,6 +14,7 @@
 """
 Import utilities: Utilities related to imports and our lazy inits.
 """
+
 import importlib.util
 import operator as op
 import os

@@ -72,6 +73,15 @@ if _torch_xla_available:
     except ImportError:
         _torch_xla_available = False
 
+# check whether torch_npu is available
+_torch_npu_available = importlib.util.find_spec("torch_npu") is not None
+if _torch_npu_available:
+    try:
+        _torch_npu_version = importlib_metadata.version("torch_npu")
+        logger.info(f"torch_npu version {_torch_npu_version} available.")
+    except ImportError:
+        _torch_npu_available = False
+
 _jax_version = "N/A"
 _flax_version = "N/A"
 if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:

@@ -294,6 +304,10 @@ def is_torch_xla_available():
     return _torch_xla_available
 
 
+def is_torch_npu_available():
+    return _torch_npu_available
+
+
 def is_flax_available():
     return _flax_available
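The probe follows the same pattern the file already uses for torch, torch_xla and the other optional backends: importlib.util.find_spec checks that the package is importable without actually importing it, and importlib.metadata confirms an installed distribution and records its version. A generic sketch of that pattern for an arbitrary package name (the helper and the example names are mine, for illustration only):

    import importlib.util
    from importlib import metadata as importlib_metadata


    def probe_package(name: str) -> tuple:
        """Return (available, version), using the same two-step probe as import_utils."""
        available = importlib.util.find_spec(name) is not None
        version = "N/A"
        if available:
            try:
                version = importlib_metadata.version(name)
            except importlib_metadata.PackageNotFoundError:
                available = False
        return available, version


    print(probe_package("numpy"))      # e.g. (True, "1.26.4")
    print(probe_package("torch_npu"))  # (False, "N/A") unless the Ascend plugin is installed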