Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
85494e88
Unverified
Commit
85494e88
authored
Sep 27, 2022
by
Kashif Rasul
Committed by
GitHub
Sep 27, 2022
Browse files
[Pytorch] add dep. warning for pytorch schedulers (#651)
* add dep. warning for schedulers * fix format
parent
33045382
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
69 additions
and
1 deletion
+69
-1
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+8
-0
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+9
-0
src/diffusers/schedulers/scheduling_karras_ve.py
src/diffusers/schedulers/scheduling_karras_ve.py
+9
-0
src/diffusers/schedulers/scheduling_lms_discrete.py
src/diffusers/schedulers/scheduling_lms_discrete.py
+9
-0
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+8
-0
src/diffusers/schedulers/scheduling_sde_ve.py
src/diffusers/schedulers/scheduling_sde_ve.py
+8
-0
src/diffusers/schedulers/scheduling_sde_vp.py
src/diffusers/schedulers/scheduling_sde_vp.py
+8
-1
src/diffusers/schedulers/scheduling_utils.py
src/diffusers/schedulers/scheduling_utils.py
+10
-0
No files found.
src/diffusers/schedulers/scheduling_ddim.py
View file @
85494e88
...
...
@@ -120,7 +120,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample
:
bool
=
True
,
set_alpha_to_one
:
bool
=
True
,
steps_offset
:
int
=
0
,
**
kwargs
,
):
if
"tensor_format"
in
kwargs
:
warnings
.
warn
(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument."
,
DeprecationWarning
,
)
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
from_numpy
(
trained_betas
)
if
beta_schedule
==
"linear"
:
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
85494e88
...
...
@@ -15,6 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import
math
import
warnings
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
...
...
@@ -112,7 +113,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
trained_betas
:
Optional
[
np
.
ndarray
]
=
None
,
variance_type
:
str
=
"fixed_small"
,
clip_sample
:
bool
=
True
,
**
kwargs
,
):
if
"tensor_format"
in
kwargs
:
warnings
.
warn
(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument."
,
DeprecationWarning
,
)
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
from_numpy
(
trained_betas
)
elif
beta_schedule
==
"linear"
:
...
...
src/diffusers/schedulers/scheduling_karras_ve.py
View file @
85494e88
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
warnings
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
...
...
@@ -86,7 +87,15 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
s_churn
:
float
=
80
,
s_min
:
float
=
0.05
,
s_max
:
float
=
50
,
**
kwargs
,
):
if
"tensor_format"
in
kwargs
:
warnings
.
warn
(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument."
,
DeprecationWarning
,
)
# setable values
self
.
num_inference_steps
:
int
=
None
self
.
timesteps
:
np
.
ndarray
=
None
...
...
src/diffusers/schedulers/scheduling_lms_discrete.py
View file @
85494e88
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
warnings
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
...
...
@@ -74,7 +75,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end
:
float
=
0.02
,
beta_schedule
:
str
=
"linear"
,
trained_betas
:
Optional
[
np
.
ndarray
]
=
None
,
**
kwargs
,
):
if
"tensor_format"
in
kwargs
:
warnings
.
warn
(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument."
,
DeprecationWarning
,
)
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
from_numpy
(
trained_betas
)
if
beta_schedule
==
"linear"
:
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
85494e88
...
...
@@ -100,7 +100,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
skip_prk_steps
:
bool
=
False
,
set_alpha_to_one
:
bool
=
False
,
steps_offset
:
int
=
0
,
**
kwargs
,
):
if
"tensor_format"
in
kwargs
:
warnings
.
warn
(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument."
,
DeprecationWarning
,
)
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
from_numpy
(
trained_betas
)
if
beta_schedule
==
"linear"
:
...
...
src/diffusers/schedulers/scheduling_sde_ve.py
View file @
85494e88
...
...
@@ -76,7 +76,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
sigma_max
:
float
=
1348.0
,
sampling_eps
:
float
=
1e-5
,
correct_steps
:
int
=
1
,
**
kwargs
,
):
if
"tensor_format"
in
kwargs
:
warnings
.
warn
(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument."
,
DeprecationWarning
,
)
# setable values
self
.
timesteps
=
None
...
...
src/diffusers/schedulers/scheduling_sde_vp.py
View file @
85494e88
...
...
@@ -17,6 +17,7 @@
# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
import
math
import
warnings
import
torch
...
...
@@ -40,7 +41,13 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
"""
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
=
2000
,
beta_min
=
0.1
,
beta_max
=
20
,
sampling_eps
=
1e-3
):
def
__init__
(
self
,
num_train_timesteps
=
2000
,
beta_min
=
0.1
,
beta_max
=
20
,
sampling_eps
=
1e-3
,
**
kwargs
):
if
"tensor_format"
in
kwargs
:
warnings
.
warn
(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument."
,
DeprecationWarning
,
)
self
.
sigmas
=
None
self
.
discrete_sigmas
=
None
self
.
timesteps
=
None
...
...
src/diffusers/schedulers/scheduling_utils.py
View file @
85494e88
...
...
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
warnings
from
dataclasses
import
dataclass
import
torch
...
...
@@ -41,3 +42,12 @@ class SchedulerMixin:
"""
config_name
=
SCHEDULER_CONFIG_NAME
def
set_format
(
self
,
tensor_format
=
"pt"
):
warnings
.
warn
(
"The method `set_format` is deprecated and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this function as the schedulers"
"are always in Pytorch"
,
DeprecationWarning
,
)
return
self
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment