Unverified Commit 85494e88 authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

[Pytorch] add dep. warning for pytorch schedulers (#651)

* add dep. warning for schedulers

* fix format
parent 33045382
......@@ -120,7 +120,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":
......
......@@ -15,6 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
......@@ -112,7 +113,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[np.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
......@@ -86,7 +87,15 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
s_churn: float = 80,
s_min: float = 0.05,
s_max: float = 50,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
# setable values
self.num_inference_steps: int = None
self.timesteps: np.ndarray = None
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
......@@ -74,7 +75,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":
......
......@@ -100,7 +100,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
steps_offset: int = 0,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":
......
......@@ -76,7 +76,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
sigma_max: float = 1348.0,
sampling_eps: float = 1e-5,
correct_steps: int = 1,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
# setable values
self.timesteps = None
......
......@@ -17,6 +17,7 @@
# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
import math
import warnings
import torch
......@@ -40,7 +41,13 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
"""
@register_to_config
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, **kwargs):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
self.sigmas = None
self.discrete_sigmas = None
self.timesteps = None
......
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
import torch
......@@ -41,3 +42,12 @@ class SchedulerMixin:
"""
config_name = SCHEDULER_CONFIG_NAME
def set_format(self, tensor_format="pt"):
warnings.warn(
"The method `set_format` is deprecated and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this function as the schedulers"
"are always in Pytorch",
DeprecationWarning,
)
return self
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment