Unverified Commit eda36c4c authored by Leo Jiang's avatar Leo Jiang Committed by GitHub
Browse files

Fix dtype error for StableDiffusionXL (#9217)



Fix dtype error
Co-authored-by: default avatar蒋硕 <jiangshuo9@h-partners.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 803e817e
......@@ -1084,7 +1084,7 @@ def main(args):
# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)
# time ids
def compute_time_ids(original_size, crops_coords_top_left):
......@@ -1101,7 +1101,7 @@ def main(args):
# Predict the noise residual
unet_added_conditions = {"time_ids": add_time_ids}
prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype)
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
model_pred = unet(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment