Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
85494e88
Unverified
Commit
85494e88
authored
Sep 27, 2022
by
Kashif Rasul
Committed by
GitHub
Sep 27, 2022
Browse files
[Pytorch] add dep. warning for pytorch schedulers (#651)
* add dep. warning for schedulers * fix format
parent
33045382
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
69 additions
and
1 deletion
+69
-1
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+8
-0
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+9
-0
src/diffusers/schedulers/scheduling_karras_ve.py
src/diffusers/schedulers/scheduling_karras_ve.py
+9
-0
src/diffusers/schedulers/scheduling_lms_discrete.py
src/diffusers/schedulers/scheduling_lms_discrete.py
+9
-0
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+8
-0
src/diffusers/schedulers/scheduling_sde_ve.py
src/diffusers/schedulers/scheduling_sde_ve.py
+8
-0
src/diffusers/schedulers/scheduling_sde_vp.py
src/diffusers/schedulers/scheduling_sde_vp.py
+8
-1
src/diffusers/schedulers/scheduling_utils.py
src/diffusers/schedulers/scheduling_utils.py
+10
-0
No files found.
src/diffusers/schedulers/scheduling_ddim.py
View file @
85494e88
...
...
@@ -120,7 +120,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample
:
bool
=
True
,
set_alpha_to_one
:
bool
=
True
,
steps_offset
:
int
=
0
,
**
kwargs
,
):
if
"tensor_format"
in
kwargs
:
warnings
.
warn
(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument."
,
DeprecationWarning
,
)
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
from_numpy
(
trained_betas
)
if
beta_schedule
==
"linear"
:
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
85494e88
...
...
@@ -15,6 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import
math
import
warnings
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
...
...
@@ -112,7 +113,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
trained_betas
:
Optional
[
np
.
ndarray
]
=
None
,
variance_type
:
str
=
"fixed_small"
,
clip_sample
:
bool
=
True
,
**
kwargs
,
):
if
"tensor_format"
in
kwargs
:
warnings
.
warn
(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument."
,
DeprecationWarning
,
)
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
from_numpy
(
trained_betas
)
elif
beta_schedule
==
"linear"
:
...
...
src/diffusers/schedulers/scheduling_karras_ve.py
View file @
85494e88
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
warnings
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
...
...
@@ -86,7 +87,15 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
s_churn
:
float
=
80
,
s_min
:
float
=
0.05
,
s_max
:
float
=
50
,
**
kwargs
,
):
if
"tensor_format"
in
kwargs
:
warnings
.
warn
(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument."
,
DeprecationWarning
,
)
# setable values
self
.
num_inference_steps
:
int
=
None
self
.
timesteps
:
np
.
ndarray
=
None
...
...
src/diffusers/schedulers/scheduling_lms_discrete.py
View file @
85494e88
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
warnings
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
...
...
@@ -74,7 +75,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end
:
float
=
0.02
,
beta_schedule
:
str
=
"linear"
,
trained_betas
:
Optional
[
np
.
ndarray
]
=
None
,
**
kwargs
,
):
if
"tensor_format"
in
kwargs
:
warnings
.
warn
(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument."
,
DeprecationWarning
,
)
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
from_numpy
(
trained_betas
)
if
beta_schedule
==
"linear"
:
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
85494e88
...
...
@@ -100,7 +100,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
skip_prk_steps
:
bool
=
False
,
set_alpha_to_one
:
bool
=
False
,
steps_offset
:
int
=
0
,
**
kwargs
,
):
if
"tensor_format"
in
kwargs
:
warnings
.
warn
(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument."
,
DeprecationWarning
,
)
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
from_numpy
(
trained_betas
)
if
beta_schedule
==
"linear"
:
...
...
src/diffusers/schedulers/scheduling_sde_ve.py
View file @
85494e88
...
...
@@ -76,7 +76,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
sigma_max
:
float
=
1348.0
,
sampling_eps
:
float
=
1e-5
,
correct_steps
:
int
=
1
,
**
kwargs
,
):
if
"tensor_format"
in
kwargs
:
warnings
.
warn
(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument."
,
DeprecationWarning
,
)
# setable values
self
.
timesteps
=
None
...
...
src/diffusers/schedulers/scheduling_sde_vp.py
View file @
85494e88
...
...
@@ -17,6 +17,7 @@
# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
import
math
import
warnings
import
torch
...
...
@@ -40,7 +41,13 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
"""
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
=
2000
,
beta_min
=
0.1
,
beta_max
=
20
,
sampling_eps
=
1e-3
):
def
__init__
(
self
,
num_train_timesteps
=
2000
,
beta_min
=
0.1
,
beta_max
=
20
,
sampling_eps
=
1e-3
,
**
kwargs
):
if
"tensor_format"
in
kwargs
:
warnings
.
warn
(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument."
,
DeprecationWarning
,
)
self
.
sigmas
=
None
self
.
discrete_sigmas
=
None
self
.
timesteps
=
None
...
...
src/diffusers/schedulers/scheduling_utils.py
View file @
85494e88
...
...
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
warnings
from
dataclasses
import
dataclass
import
torch
...
...
@@ -41,3 +42,12 @@ class SchedulerMixin:
"""
config_name
=
SCHEDULER_CONFIG_NAME
def
set_format
(
self
,
tensor_format
=
"pt"
):
warnings
.
warn
(
"The method `set_format` is deprecated and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this function as the schedulers"
"are always in Pytorch"
,
DeprecationWarning
,
)
return
self
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment