Unverified commit 71ba8aec authored by Pedro Cuenca, committed by GitHub

Pipeline to device (#210)

* Implement `pipeline.to(device)`.

* `DiffusionPipeline.to()` decides the best device when passed `None`.

* Breaking change: `torch_device` removed from `__call__`.

  `pipeline.to()` now has PyTorch semantics.

* Accept `torch_device` as a kwarg with a deprecation notice.

* Apply `torch_device` backward compatibility to all pipelines.

* style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: anton-l <anton@huggingface.co>
parent 89e95210
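
For readers skimming the diff, the user-facing change boils down to the following. A minimal sketch, assuming a DDPM checkpoint such as google/ddpm-cifar10-32 and the dict-style return format the library used at the time of this commit (both illustrative):

from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")  # illustrative checkpoint id

# Old (deprecated by this PR; warns until removal in v0.3.0):
# images = pipe(batch_size=1, torch_device="cuda")["sample"]

# New: device placement has PyTorch semantics and happens once, up front.
pipe = pipe.to("cuda")  # assumes a CUDA device is available
images = pipe(batch_size=1)["sample"]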
@@ -19,6 +19,8 @@ import inspect
 import os
 from typing import Optional, Union
 
+import torch
+
 from huggingface_hub import snapshot_download
 from PIL import Image
 
@@ -113,6 +115,26 @@ class DiffusionPipeline(ConfigMixin):
         save_method = getattr(sub_model, save_method_name)
         save_method(os.path.join(save_directory, pipeline_component_name))
 
+    def to(self, torch_device: Optional[Union[str, torch.device]] = None):
+        if torch_device is None:
+            return self
+
+        module_names, _ = self.extract_init_dict(dict(self.config))
+        for name in module_names.keys():
+            module = getattr(self, name)
+            if isinstance(module, torch.nn.Module):
+                module.to(torch_device)
+        return self
+
+    @property
+    def device(self) -> torch.device:
+        module_names, _ = self.extract_init_dict(dict(self.config))
+        for name in module_names.keys():
+            module = getattr(self, name)
+            if isinstance(module, torch.nn.Module):
+                return module.device
+        return torch.device("cpu")
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         r"""
...
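
Two design points worth noting in the hunk above: `to()` iterates over the components registered in the pipeline config and moves only those that are `torch.nn.Module` instances (schedulers, tokenizers, and feature extractors are skipped), and `device` reports the device of the first such module, defaulting to CPU when none exists. A hedged sketch of the resulting behavior (checkpoint id illustrative):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32")  # illustrative checkpoint id

print(pipe.device)  # cpu: weights are loaded on CPU first
pipe.to(None)       # explicit None is a no-op: the pipeline is returned unchanged

if torch.cuda.is_available():
    pipe = pipe.to("cuda")  # moves every nn.Module component; returns self, so calls chain
    print(pipe.device)      # cuda:0, read back from the first nn.Module component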
@@ -14,6 +14,8 @@
 # limitations under the License.
 
+import warnings
+
 import torch
 
 from tqdm.auto import tqdm
 
@@ -28,21 +30,28 @@ class DDIMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(
-        self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
-    ):
-        # eta corresponds to η in paper and should be between [0, 1]
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        self.unet.to(torch_device)
+    def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
+
+        # eta corresponds to η in paper and should be between [0, 1]
 
         # Sample gaussian noise to begin loop
        image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
-        image = image.to(torch_device)
+        image = image.to(self.device)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
...
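
The same shim is repeated verbatim in every pipeline below: `torch_device` is popped from `**kwargs`, a `UserWarning` is emitted, and the old behavior, including the cuda-if-available fallback when the caller passes an explicit `None`, is preserved until v0.3.0. A sketch of what a caller observes (the checkpoint id and pipeline/checkpoint pairing follow the README pattern of the time and are illustrative):

import warnings
from diffusers import DDIMPipeline

pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")  # illustrative checkpoint id

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Old calling convention: still accepted, but warns and calls self.to(device) as a side effect.
    pipe(batch_size=1, num_inference_steps=1, torch_device="cpu")
    assert any("torch_device" in str(w.message) for w in caught)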
@@ -14,6 +14,8 @@
 # limitations under the License.
 
+import warnings
+
 import torch
 
 from tqdm.auto import tqdm
 
@@ -28,18 +30,25 @@ class DDPMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="pil"):
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        self.unet.to(torch_device)
+    def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
-        image = image.to(torch_device)
+        image = image.to(self.device)
 
         # set step values
         self.scheduler.set_timesteps(1000)
...
 import inspect
+import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
 
@@ -31,13 +32,22 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         guidance_scale: Optional[float] = 1.0,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        torch_device: Optional[Union[str, torch.device]] = None,
         output_type: Optional[str] = "pil",
+        **kwargs,
     ):
         # eta corresponds to η in paper and should be between [0, 1]
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         if isinstance(prompt, str):
             batch_size = 1
 
@@ -49,24 +59,20 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
-        self.unet.to(torch_device)
-        self.vqvae.to(torch_device)
-        self.bert.to(torch_device)
-
         # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
             uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
-            uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))[0]
+            uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
 
         # get prompt text embeddings
         text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
-        text_embeddings = self.bert(text_input.input_ids.to(torch_device))[0]
+        text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
 
         latents = torch.randn(
             (batch_size, self.unet.in_channels, height // 8, width // 8),
             generator=generator,
         )
-        latents = latents.to(torch_device)
+        latents = latents.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
...
 import inspect
+import warnings
 
 import torch
 
@@ -14,22 +15,26 @@ class LDMPipeline(DiffusionPipeline):
         self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(
-        self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
-    ):
+    def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
         # eta corresponds to η in paper and should be between [0, 1]
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        self.unet.to(torch_device)
-        self.vqvae.to(torch_device)
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         latents = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
-        latents = latents.to(torch_device)
+        latents = latents.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
...
@@ -14,6 +14,8 @@
 # limitations under the License.
 
+import warnings
+
 import torch
 
 from tqdm.auto import tqdm
 
@@ -28,20 +30,28 @@ class PNDMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="pil"):
+    def __call__(self, batch_size=1, generator=None, num_inference_steps=50, output_type="pil", **kwargs):
         # For more information on the sampling method you can take a look at Algorithm 2 of
         # the official paper: https://arxiv.org/pdf/2202.09778.pdf
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
-        self.unet.to(torch_device)
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
-        image = image.to(torch_device)
+        image = image.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
 
         for t in tqdm(self.scheduler.timesteps):
...
 #!/usr/bin/env python3
+import warnings
+
 import torch
 
 from diffusers import DiffusionPipeline
 
@@ -11,24 +13,32 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
-
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+    def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, output_type="pil", **kwargs):
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         img_size = self.unet.config.sample_size
         shape = (batch_size, 3, img_size, img_size)
 
-        model = self.unet.to(torch_device)
+        model = self.unet
 
         sample = torch.randn(*shape) * self.scheduler.config.sigma_max
-        sample = sample.to(torch_device)
+        sample = sample.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
         self.scheduler.set_sigmas(num_inference_steps)
 
         for i, t in tqdm(enumerate(self.scheduler.timesteps)):
-            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=torch_device)
+            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
 
             # correction step
             for _ in range(self.scheduler.correct_steps):
...
 import inspect
+import warnings
 from typing import List, Optional, Union
 
 import torch
 
@@ -45,11 +46,20 @@ class StableDiffusionPipeline(DiffusionPipeline):
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        torch_device: Optional[Union[str, torch.device]] = None,
         output_type: Optional[str] = "pil",
+        **kwargs,
     ):
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         if isinstance(prompt, str):
             batch_size = 1
 
@@ -61,11 +71,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
-        self.unet.to(torch_device)
-        self.vae.to(torch_device)
-        self.text_encoder.to(torch_device)
-        self.safety_checker.to(torch_device)
-
         # get prompt text embeddings
         text_input = self.tokenizer(
             prompt,
@@ -74,7 +79,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
             truncation=True,
             return_tensors="pt",
         )
-        text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]
+        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -86,7 +91,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
 
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
 
@@ -97,7 +102,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         latents = torch.randn(
             (batch_size, self.unet.in_channels, height // 8, width // 8),
             generator=generator,
-            device=torch_device,
+            device=self.device,
         )
 
         # set timesteps
 
@@ -150,7 +155,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
         # run safety checker
-        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(torch_device)
+        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
         image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
 
         if output_type == "pil":
...
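
One detail specific to this file: the latents are now created directly with `device=self.device`, and the safety-checker inputs follow the pipeline device too, so a single `to()` call covers the whole pipeline. A sketch of the new calling convention (the model id, auth requirement, and dict-style return format reflect the library at the time of this commit and are illustrative):

from diffusers import StableDiffusionPipeline

# illustrative model id; at the time the checkpoint also required license acceptance
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
pipe = pipe.to("cuda")  # assumes a CUDA device is available

image = pipe(prompt="a photograph of an astronaut riding a horse")["sample"][0]
image.save("astronaut.png")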
 #!/usr/bin/env python3
+import warnings
+
 import torch
 
 from tqdm.auto import tqdm
 
@@ -27,18 +29,27 @@ class KarrasVePipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, num_inference_steps=50, generator=None, torch_device=None, output_type="pil"):
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+    def __call__(self, batch_size=1, num_inference_steps=50, generator=None, output_type="pil", **kwargs):
+        if "torch_device" in kwargs:
+            device = kwargs.pop("torch_device")
+            warnings.warn(
+                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                " Consider using `pipe.to(torch_device)` instead."
+            )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
 
         img_size = self.unet.config.sample_size
         shape = (batch_size, 3, img_size, img_size)
 
-        model = self.unet.to(torch_device)
+        model = self.unet
 
         # sample x_0 ~ N(0, sigma_0^2 * I)
         sample = torch.randn(*shape) * self.scheduler.config.sigma_max
-        sample = sample.to(torch_device)
+        sample = sample.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
...