Unverified Commit 5652c43f authored by nupurkmr9's avatar nupurkmr9 Committed by GitHub
Browse files

Resolve bf16 error as mentioned in this...

Resolve bf16 error as mentioned in this [issue](https://github.com/huggingface/diffusers/issues/4139#issuecomment-1639977304) (#4214)

* resolve bf16 error

* resolve bf16 error

* resolve bf16 error

* resolve bf16 error

* resolve bf16 error

* resolve bf16 error

* resolve bf16 error
parent 365e8461
@@ -277,4 +277,4 @@ To save even more memory, pass the `--set_grads_to_none` argument to the script.
More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
## Experimental results
You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail. We also released a more extensive dataset of 101 concepts for evaluating model customization methods. For more details please refer to our [dataset webpage](https://www.cs.cmu.edu/~custom-diffusion/dataset.html).
\ No newline at end of file
...@@ -57,7 +57,7 @@ def retrieve(class_prompt, class_data_dir, num_class_images): ...@@ -57,7 +57,7 @@ def retrieve(class_prompt, class_data_dir, num_class_images):
images = class_images[count] images = class_images[count]
count += 1 count += 1
try: try:
img = requests.get(images["url"]) img = requests.get(images["url"], timeout=30)
if img.status_code == 200: if img.status_code == 200:
_ = Image.open(BytesIO(img.content)) _ = Image.open(BytesIO(img.content))
with open(f"{class_data_dir}/images/{total}.jpg", "wb") as f: with open(f"{class_data_dir}/images/{total}.jpg", "wb") as f:
......
...@@ -1210,6 +1210,7 @@ def main(args): ...@@ -1210,6 +1210,7 @@ def main(args):
text_encoder=accelerator.unwrap_model(text_encoder), text_encoder=accelerator.unwrap_model(text_encoder),
tokenizer=tokenizer, tokenizer=tokenizer,
revision=args.revision, revision=args.revision,
torch_dtype=weight_dtype,
) )
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device)
......
...@@ -667,9 +667,9 @@ class CustomDiffusionAttnProcessor(nn.Module): ...@@ -667,9 +667,9 @@ class CustomDiffusionAttnProcessor(nn.Module):
batch_size, sequence_length, _ = hidden_states.shape batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if self.train_q_out: if self.train_q_out:
query = self.to_q_custom_diffusion(hidden_states) query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
else: else:
query = attn.to_q(hidden_states) query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
if encoder_hidden_states is None: if encoder_hidden_states is None:
crossattn = False crossattn = False
...@@ -680,8 +680,10 @@ class CustomDiffusionAttnProcessor(nn.Module): ...@@ -680,8 +680,10 @@ class CustomDiffusionAttnProcessor(nn.Module):
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
if self.train_kv: if self.train_kv:
key = self.to_k_custom_diffusion(encoder_hidden_states) key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
value = self.to_v_custom_diffusion(encoder_hidden_states) value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
key = key.to(attn.to_q.weight.dtype)
value = value.to(attn.to_q.weight.dtype)
else: else:
key = attn.to_k(encoder_hidden_states) key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) value = attn.to_v(encoder_hidden_states)
...@@ -1392,9 +1394,9 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module): ...@@ -1392,9 +1394,9 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if self.train_q_out: if self.train_q_out:
query = self.to_q_custom_diffusion(hidden_states) query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
else: else:
query = attn.to_q(hidden_states) query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
if encoder_hidden_states is None: if encoder_hidden_states is None:
crossattn = False crossattn = False
...@@ -1405,8 +1407,10 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module): ...@@ -1405,8 +1407,10 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
if self.train_kv: if self.train_kv:
key = self.to_k_custom_diffusion(encoder_hidden_states) key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
value = self.to_v_custom_diffusion(encoder_hidden_states) value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
key = key.to(attn.to_q.weight.dtype)
value = value.to(attn.to_q.weight.dtype)
else: else:
key = attn.to_k(encoder_hidden_states) key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) value = attn.to_v(encoder_hidden_states)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment