Unverified Commit b671cb09 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Remove deprecated `torch_device` kwarg (#623)

* Remove deprecated `torch_device` kwarg.

* Remove unused imports.
parent bb0c5d15
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import warnings
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
...@@ -74,20 +73,6 @@ class DDIMPipeline(DiffusionPipeline): ...@@ -74,20 +73,6 @@ class DDIMPipeline(DiffusionPipeline):
generated images. generated images.
""" """
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
" Consider using `pipe.to(torch_device)` instead."
)
# Set device as before (to be removed in 0.3.0)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.to(device)
# eta corresponds to η in paper and should be between [0, 1]
# Sample gaussian noise to begin loop # Sample gaussian noise to begin loop
image = torch.randn( image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
...@@ -103,6 +88,7 @@ class DDIMPipeline(DiffusionPipeline): ...@@ -103,6 +88,7 @@ class DDIMPipeline(DiffusionPipeline):
model_output = self.unet(image, t).sample model_output = self.unet(image, t).sample
# 2. predict previous mean of image x_t-1 and add variance depending on eta # 2. predict previous mean of image x_t-1 and add variance depending on eta
# eta corresponds to η in paper and should be between [0, 1]
# do x_t -> x_t-1 # do x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, eta).prev_sample image = self.scheduler.step(model_output, t, image, eta).prev_sample
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import warnings
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
...@@ -66,17 +65,6 @@ class DDPMPipeline(DiffusionPipeline): ...@@ -66,17 +65,6 @@ class DDPMPipeline(DiffusionPipeline):
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images. generated images.
""" """
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
" Consider using `pipe.to(torch_device)` instead."
)
# Set device as before (to be removed in 0.3.0)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.to(device)
# Sample gaussian noise to begin loop # Sample gaussian noise to begin loop
image = torch.randn( image = torch.randn(
......
import inspect import inspect
import warnings
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
...@@ -94,17 +93,6 @@ class LDMTextToImagePipeline(DiffusionPipeline): ...@@ -94,17 +93,6 @@ class LDMTextToImagePipeline(DiffusionPipeline):
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images. generated images.
""" """
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
" Consider using `pipe.to(torch_device)` instead."
)
# Set device as before (to be removed in 0.3.0)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.to(device)
if isinstance(prompt, str): if isinstance(prompt, str):
batch_size = 1 batch_size = 1
......
import inspect import inspect
import warnings
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
...@@ -60,18 +59,6 @@ class LDMPipeline(DiffusionPipeline): ...@@ -60,18 +59,6 @@ class LDMPipeline(DiffusionPipeline):
generated images. generated images.
""" """
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
" Consider using `pipe.to(torch_device)` instead."
)
# Set device as before (to be removed in 0.3.0)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.to(device)
latents = torch.randn( latents = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator, generator=generator,
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import warnings
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
...@@ -75,18 +74,6 @@ class PNDMPipeline(DiffusionPipeline): ...@@ -75,18 +74,6 @@ class PNDMPipeline(DiffusionPipeline):
# For more information on the sampling method you can take a look at Algorithm 2 of # For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf # the official paper: https://arxiv.org/pdf/2202.09778.pdf
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
" Consider using `pipe.to(torch_device)` instead."
)
# Set device as before (to be removed in 0.3.0)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.to(device)
# Sample gaussian noise to begin loop # Sample gaussian noise to begin loop
image = torch.randn( image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
......
#!/usr/bin/env python3 #!/usr/bin/env python3
import warnings
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
...@@ -53,18 +52,6 @@ class ScoreSdeVePipeline(DiffusionPipeline): ...@@ -53,18 +52,6 @@ class ScoreSdeVePipeline(DiffusionPipeline):
generated images. generated images.
""" """
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
" Consider using `pipe.to(torch_device)` instead."
)
# Set device as before (to be removed in 0.3.0)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.to(device)
img_size = self.unet.config.sample_size img_size = self.unet.config.sample_size
shape = (batch_size, 3, img_size, img_size) shape = (batch_size, 3, img_size, img_size)
......
...@@ -169,18 +169,6 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -169,18 +169,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`. (nsfw) content, according to the `safety_checker`.
""" """
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
" Consider using `pipe.to(torch_device)` instead."
)
# Set device as before (to be removed in 0.3.0)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.to(device)
if isinstance(prompt, str): if isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif isinstance(prompt, list): elif isinstance(prompt, list):
......
#!/usr/bin/env python3 #!/usr/bin/env python3
import warnings
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
...@@ -64,17 +63,6 @@ class KarrasVePipeline(DiffusionPipeline): ...@@ -64,17 +63,6 @@ class KarrasVePipeline(DiffusionPipeline):
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images. generated images.
""" """
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
" Consider using `pipe.to(torch_device)` instead."
)
# Set device as before (to be removed in 0.3.0)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.to(device)
img_size = self.unet.config.sample_size img_size = self.unet.config.sample_size
shape = (batch_size, 3, img_size, img_size) shape = (batch_size, 3, img_size, img_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment