Unverified Commit 688031c5 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Fix import with Flax but without PyTorch (#688)

* Don't use `load_state_dict` if torch is not installed.

* Define `SchedulerOutput` to use torch or flax arrays.

* Don't import LMSDiscreteScheduler without torch.

* Create distinct FlaxSchedulerOutput.

* Additional changes required for FlaxSchedulerMixin

* Do not import torch pipelines in Flax.

* Revert "Define `SchedulerOutput` to use torch or flax arrays."

This reverts commit f653140134b74d9ffec46d970eb46925fe3a409d.

* Prefix Flax scheduler outputs for consistency.

* make style

* FlaxSchedulerOutput is now a dataclass.

* Don't use f-string without placeholders.

* Add blank line.

* Style (docstrings)
parent 7d0ba592
......@@ -73,6 +73,7 @@ if is_flax_available():
FlaxKarrasVeScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
FlaxSchedulerMixin,
FlaxScoreSdeVeScheduler,
)
else:
......
......@@ -27,8 +27,8 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
from . import is_torch_available
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
from .modeling_utils import load_state_dict
from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
......@@ -391,6 +391,14 @@ class FlaxModelMixin:
)
if from_pt:
if is_torch_available():
from .modeling_utils import load_state_dict
else:
raise EnvironmentError(
"Can't load the model in PyTorch format because PyTorch is not installed. "
"Please, install PyTorch or use native Flax weights."
)
# Step 1: Get the pytorch file
pytorch_model_file = load_state_dict(model_file)
......
......@@ -30,7 +30,7 @@ from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
......@@ -46,7 +46,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
"diffusers": {
"FlaxModelMixin": ["save_pretrained", "from_pretrained"],
"SchedulerMixin": ["save_config", "from_config"],
"FlaxSchedulerMixin": ["save_config", "from_config"],
"FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
},
"transformers": {
......@@ -436,7 +436,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
else:
loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
params[name] = loaded_params
elif issubclass(class_obj, SchedulerMixin):
elif issubclass(class_obj, FlaxSchedulerMixin):
loaded_sub_model, scheduler_state = load_method(loadable_folder)
params[name] = scheduler_state
else:
......
from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .latent_diffusion_uncond import LDMPipeline
from .pndm import PNDMPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .stochastic_karras_ve import KarrasVePipeline
if is_torch_available():
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .latent_diffusion_uncond import LDMPipeline
from .pndm import PNDMPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .stochastic_karras_ve import KarrasVePipeline
else:
from ..utils.dummy_pt_objects import * # noqa F403
if is_torch_available() and is_transformers_available():
from .latent_diffusion import LDMTextToImagePipeline
from .stable_diffusion import (
......
......@@ -6,7 +6,7 @@ import numpy as np
import PIL
from PIL import Image
from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available
from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
@dataclass
......@@ -27,7 +27,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
nsfw_content_detected: List[bool]
if is_transformers_available():
if is_transformers_available() and is_torch_available():
from .pipeline_stable_diffusion import StableDiffusionPipeline
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
......
......@@ -34,10 +34,12 @@ if is_flax_available():
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
from .scheduling_pndm_flax import FlaxPNDMScheduler
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
from .scheduling_utils_flax import FlaxSchedulerMixin
else:
from ..utils.dummy_flax_objects import * # noqa F403
if is_scipy_available():
if is_scipy_available() and is_torch_available():
from .scheduling_lms_discrete import LMSDiscreteScheduler
else:
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403
......@@ -23,7 +23,7 @@ import flax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
......@@ -68,11 +68,11 @@ class DDIMSchedulerState:
@dataclass
class FlaxSchedulerOutput(SchedulerOutput):
class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
state: DDIMSchedulerState
class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
diffusion probabilistic models (DDPMs) with non-Markovian guidance.
......@@ -183,7 +183,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
timestep: int,
sample: jnp.ndarray,
return_dict: bool = True,
) -> Union[FlaxSchedulerOutput, Tuple]:
) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
......@@ -197,11 +197,11 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
key (`random.KeyArray`): a PRNG key.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class
Returns:
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
[`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
if state.num_inference_steps is None:
......@@ -252,7 +252,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
if not return_dict:
return (prev_sample, state)
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)
def add_noise(
self,
......
......@@ -23,7 +23,7 @@ import jax.numpy as jnp
from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
......@@ -67,11 +67,11 @@ class DDPMSchedulerState:
@dataclass
class FlaxSchedulerOutput(SchedulerOutput):
class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput):
state: DDPMSchedulerState
class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
Langevin dynamics sampling.
......@@ -191,7 +191,7 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
key: random.KeyArray,
predict_epsilon: bool = True,
return_dict: bool = True,
) -> Union[FlaxSchedulerOutput, Tuple]:
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
......@@ -205,11 +205,11 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
key (`random.KeyArray`): a PRNG key.
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
Returns:
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
[`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
t = timestep
......@@ -257,7 +257,7 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
if not return_dict:
return (pred_prev_sample, state)
return FlaxSchedulerOutput(prev_sample=pred_prev_sample, state=state)
return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state)
def add_noise(
self,
......
......@@ -22,7 +22,7 @@ from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
from .scheduling_utils_flax import FlaxSchedulerMixin
@flax.struct.dataclass
......@@ -56,7 +56,7 @@ class FlaxKarrasVeOutput(BaseOutput):
state: KarrasVeSchedulerState
class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
the VE column of Table 1 from [1] for reference.
......@@ -172,7 +172,7 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
Returns:
[`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion
......@@ -211,7 +211,7 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
Returns:
prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
......
......@@ -20,7 +20,7 @@ import jax.numpy as jnp
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
@flax.struct.dataclass
......@@ -37,11 +37,11 @@ class LMSDiscreteSchedulerState:
@dataclass
class FlaxSchedulerOutput(SchedulerOutput):
class FlaxLMSSchedulerOutput(FlaxSchedulerOutput):
state: LMSDiscreteSchedulerState
class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
Katherine Crowson:
......@@ -147,7 +147,7 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample: jnp.ndarray,
order: int = 4,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[FlaxLMSSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
......@@ -159,11 +159,11 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
order: coefficient for multi-step inference.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class
Returns:
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
[`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
sigma = state.sigmas[timestep]
......@@ -189,7 +189,7 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
if not return_dict:
return (prev_sample, state)
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)
def add_noise(
self,
......
......@@ -23,7 +23,7 @@ import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
......@@ -76,11 +76,11 @@ class PNDMSchedulerState:
@dataclass
class FlaxSchedulerOutput(SchedulerOutput):
class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput):
state: PNDMSchedulerState
class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
namely Runge-Kutta method and a linear multi-step method.
......@@ -211,7 +211,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
timestep: int,
sample: jnp.ndarray,
return_dict: bool = True,
) -> Union[FlaxSchedulerOutput, Tuple]:
) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
......@@ -224,11 +224,11 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
Returns:
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
[`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
if self.config.skip_prk_steps:
......@@ -249,7 +249,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
if not return_dict:
return (prev_sample, state)
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state)
def step_prk(
self,
......@@ -257,7 +257,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
) -> Union[FlaxSchedulerOutput, Tuple]:
) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
"""
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation.
......@@ -268,11 +268,11 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
Returns:
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
[`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
if state.num_inference_steps is None:
......@@ -327,7 +327,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
) -> Union[FlaxSchedulerOutput, Tuple]:
) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
"""
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution.
......@@ -338,11 +338,11 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
Returns:
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
[`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
if state.num_inference_steps is None:
......
......@@ -22,7 +22,7 @@ import jax.numpy as jnp
from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
@flax.struct.dataclass
......@@ -38,7 +38,7 @@ class ScoreSdeVeSchedulerState:
@dataclass
class FlaxSdeVeOutput(SchedulerOutput):
class FlaxSdeVeOutput(FlaxSchedulerOutput):
"""
Output class for the ScoreSdeVeScheduler's step function output.
......@@ -56,7 +56,7 @@ class FlaxSdeVeOutput(SchedulerOutput):
prev_sample_mean: Optional[jnp.ndarray] = None
class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
The variance exploding stochastic differential equation (SDE) scheduler.
......@@ -168,7 +168,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class
Returns:
[`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
......@@ -216,7 +216,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
sample: jnp.ndarray,
key: random.KeyArray,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[FlaxSdeVeOutput, Tuple]:
"""
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
after making the prediction for the previous timestep.
......@@ -227,7 +227,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class
Returns:
[`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
import jax.numpy as jnp
from ..utils import BaseOutput
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@dataclass
class FlaxSchedulerOutput(BaseOutput):
"""
Base class for the scheduler's step function output.
Args:
prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: jnp.ndarray
class FlaxSchedulerMixin:
"""
Mixin containing common functions for the schedulers.
"""
config_name = SCHEDULER_CONFIG_NAME
def set_format(self, tensor_format="pt"):
warnings.warn(
"The method `set_format` is deprecated and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this function as the schedulers"
"are always in Pytorch",
DeprecationWarning,
)
return self
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment