Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
c3369f56
Unverified
Commit
c3369f56
authored
Feb 01, 2024
by
Patrick von Platen
Committed by
GitHub
Jan 31, 2024
Browse files
fix torchvision import (#6796)
parent
04cd6adf
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
1 deletion
+9
-1
src/diffusers/training_utils.py
src/diffusers/training_utils.py
+9
-1
No files found.
src/diffusers/training_utils.py
View file @
c3369f56
...
@@ -5,7 +5,7 @@ from typing import Any, Dict, Iterable, List, Optional, Union
...
@@ -5,7 +5,7 @@ from typing import Any, Dict, Iterable, List, Optional, Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
t
orchvision
import
transforms
from
t
ransformers
import
is_torchvision_available
from
.models
import
UNet2DConditionModel
from
.models
import
UNet2DConditionModel
from
.utils
import
(
from
.utils
import
(
...
@@ -23,6 +23,9 @@ if is_transformers_available():
...
@@ -23,6 +23,9 @@ if is_transformers_available():
if
is_peft_available
():
if
is_peft_available
():
from
peft
import
set_peft_model_state_dict
from
peft
import
set_peft_model_state_dict
if
is_torchvision_available
():
from
torchvision
import
transforms
def
set_seed
(
seed
:
int
):
def
set_seed
(
seed
:
int
):
"""
"""
...
@@ -79,6 +82,11 @@ def resolve_interpolation_mode(interpolation_type: str):
...
@@ -79,6 +82,11 @@ def resolve_interpolation_mode(interpolation_type: str):
`torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
`torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
transform.
transform.
"""
"""
if
not
is_torchvision_available
():
raise
ImportError
(
"Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
)
if
interpolation_type
==
"bilinear"
:
if
interpolation_type
==
"bilinear"
:
interpolation_mode
=
transforms
.
InterpolationMode
.
BILINEAR
interpolation_mode
=
transforms
.
InterpolationMode
.
BILINEAR
elif
interpolation_type
==
"bicubic"
:
elif
interpolation_type
==
"bicubic"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment