Unverified Commit 20f0cbc8 authored by Viktor Grygorchuk's avatar Viktor Grygorchuk Committed by GitHub
Browse files

fix: error on device for `lpw_stable_diffusion_xl` pipeline if `pipe.enable_sequential_cpu_offload()` enabled

fix: error on device for `lpw_stable_diffusion_xl` pipeline if `pipe.enable_sequential_cpu_offload()` enabled (#5885)

fix: set device for pipe.enable_sequential_cpu_offload()
parent d72a24b7
@@ -250,6 +250,7 @@ def get_weighted_text_embeddings_sdxl(
     neg_prompt: str = "",
     neg_prompt_2: str = None,
     num_images_per_prompt: int = 1,
+    device: Optional[torch.device] = None,
 ):
     """
     This function can process long prompt with weights, no length limitation
@@ -262,10 +263,13 @@ def get_weighted_text_embeddings_sdxl(
         neg_prompt (str)
         neg_prompt_2 (str)
         num_images_per_prompt (int)
+        device (torch.device)
     Returns:
         prompt_embeds (torch.Tensor)
         neg_prompt_embeds (torch.Tensor)
     """
+    device = device or pipe._execution_device
     if prompt_2:
         prompt = f"{prompt} {prompt_2}"
@@ -330,17 +334,17 @@ def get_weighted_text_embeddings_sdxl(
     # get prompt embeddings one by one is not working.
     for i in range(len(prompt_token_groups)):
         # get positive prompt embeddings with weights
-        token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
-        weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
-        token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
+        token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=device)
+        weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=device)
+        token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=device)
         # use first text encoder
-        prompt_embeds_1 = pipe.text_encoder(token_tensor.to(pipe.device), output_hidden_states=True)
+        prompt_embeds_1 = pipe.text_encoder(token_tensor.to(device), output_hidden_states=True)
         prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
         # use second text encoder
-        prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(pipe.device), output_hidden_states=True)
+        prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(device), output_hidden_states=True)
         prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
         pooled_prompt_embeds = prompt_embeds_2[0]
@@ -357,16 +361,16 @@ def get_weighted_text_embeddings_sdxl(
             embeds.append(token_embedding)
         # get negative prompt embeddings with weights
-        neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
-        neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
-        neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
+        neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=device)
+        neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=device)
+        neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=device)
         # use first text encoder
-        neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(pipe.device), output_hidden_states=True)
+        neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(device), output_hidden_states=True)
         neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
         # use second text encoder
-        neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(pipe.device), output_hidden_states=True)
+        neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(device), output_hidden_states=True)
         neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
         negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment