"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "394a48d16955d06bd2255adc28bf7c322156707c"
Unverified Commit 85494e88 authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

[Pytorch] add dep. warning for pytorch schedulers (#651)

* add dep. warning for schedulers

* fix format
parent 33045382
...@@ -120,7 +120,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -120,7 +120,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample: bool = True, clip_sample: bool = True,
set_alpha_to_one: bool = True, set_alpha_to_one: bool = True,
steps_offset: int = 0, steps_offset: int = 0,
**kwargs,
): ):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas) self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear": if beta_schedule == "linear":
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math import math
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
...@@ -112,7 +113,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -112,7 +113,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[np.ndarray] = None, trained_betas: Optional[np.ndarray] = None,
variance_type: str = "fixed_small", variance_type: str = "fixed_small",
clip_sample: bool = True, clip_sample: bool = True,
**kwargs,
): ):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas) self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear": elif beta_schedule == "linear":
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
...@@ -86,7 +87,15 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -86,7 +87,15 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
s_churn: float = 80, s_churn: float = 80,
s_min: float = 0.05, s_min: float = 0.05,
s_max: float = 50, s_max: float = 50,
**kwargs,
): ):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
# setable values # setable values
self.num_inference_steps: int = None self.num_inference_steps: int = None
self.timesteps: np.ndarray = None self.timesteps: np.ndarray = None
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
...@@ -74,7 +75,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -74,7 +75,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.02, beta_end: float = 0.02,
beta_schedule: str = "linear", beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None, trained_betas: Optional[np.ndarray] = None,
**kwargs,
): ):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas) self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear": if beta_schedule == "linear":
......
...@@ -100,7 +100,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -100,7 +100,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
skip_prk_steps: bool = False, skip_prk_steps: bool = False,
set_alpha_to_one: bool = False, set_alpha_to_one: bool = False,
steps_offset: int = 0, steps_offset: int = 0,
**kwargs,
): ):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas) self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear": if beta_schedule == "linear":
......
...@@ -76,7 +76,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -76,7 +76,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
sigma_max: float = 1348.0, sigma_max: float = 1348.0,
sampling_eps: float = 1e-5, sampling_eps: float = 1e-5,
correct_steps: int = 1, correct_steps: int = 1,
**kwargs,
): ):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
# setable values # setable values
self.timesteps = None self.timesteps = None
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
import math import math
import warnings
import torch import torch
...@@ -40,7 +41,13 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): ...@@ -40,7 +41,13 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
""" """
@register_to_config @register_to_config
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3): def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, **kwargs):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
self.sigmas = None self.sigmas = None
self.discrete_sigmas = None self.discrete_sigmas = None
self.timesteps = None self.timesteps = None
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings
from dataclasses import dataclass from dataclasses import dataclass
import torch import torch
...@@ -41,3 +42,12 @@ class SchedulerMixin: ...@@ -41,3 +42,12 @@ class SchedulerMixin:
""" """
config_name = SCHEDULER_CONFIG_NAME config_name = SCHEDULER_CONFIG_NAME
def set_format(self, tensor_format="pt"):
warnings.warn(
"The method `set_format` is deprecated and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this function as the schedulers"
"are always in Pytorch",
DeprecationWarning,
)
return self
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment