"docs/vscode:/vscode.git/clone" did not exist on "3047ab97ef858ef0faeeaf6e9f43f40b87f0e5fc"
Unverified Commit 5652c43f authored by nupurkmr9's avatar nupurkmr9 Committed by GitHub
Browse files

Resolve bf16 error as mentioned in this...

Resolve bf16 error as mentioned in this [issue](https://github.com/huggingface/diffusers/issues/4139#issuecomment-1639977304) (#4214)

* resolve bf16 error

* resolve bf16 error

* resolve bf16 error

* resolve bf16 error

* resolve bf16 error

* resolve bf16 error

* resolve bf16 error
parent 365e8461
...@@ -277,4 +277,4 @@ To save even more memory, pass the `--set_grads_to_none` argument to the script. ...@@ -277,4 +277,4 @@ To save even more memory, pass the `--set_grads_to_none` argument to the script.
More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
## Experimental results ## Experimental results
You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail. You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail. We also released a more extensive dataset of 101 concepts for evaluating model customization methods. For more details please refer to our [dataset webpage](https://www.cs.cmu.edu/~custom-diffusion/dataset.html).
\ No newline at end of file \ No newline at end of file
...@@ -57,7 +57,7 @@ def retrieve(class_prompt, class_data_dir, num_class_images): ...@@ -57,7 +57,7 @@ def retrieve(class_prompt, class_data_dir, num_class_images):
images = class_images[count] images = class_images[count]
count += 1 count += 1
try: try:
img = requests.get(images["url"]) img = requests.get(images["url"], timeout=30)
if img.status_code == 200: if img.status_code == 200:
_ = Image.open(BytesIO(img.content)) _ = Image.open(BytesIO(img.content))
with open(f"{class_data_dir}/images/{total}.jpg", "wb") as f: with open(f"{class_data_dir}/images/{total}.jpg", "wb") as f:
......
...@@ -1210,6 +1210,7 @@ def main(args): ...@@ -1210,6 +1210,7 @@ def main(args):
text_encoder=accelerator.unwrap_model(text_encoder), text_encoder=accelerator.unwrap_model(text_encoder),
tokenizer=tokenizer, tokenizer=tokenizer,
revision=args.revision, revision=args.revision,
torch_dtype=weight_dtype,
) )
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device)
......
...@@ -667,9 +667,9 @@ class CustomDiffusionAttnProcessor(nn.Module): ...@@ -667,9 +667,9 @@ class CustomDiffusionAttnProcessor(nn.Module):
batch_size, sequence_length, _ = hidden_states.shape batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if self.train_q_out: if self.train_q_out:
query = self.to_q_custom_diffusion(hidden_states) query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
else: else:
query = attn.to_q(hidden_states) query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
if encoder_hidden_states is None: if encoder_hidden_states is None:
crossattn = False crossattn = False
...@@ -680,8 +680,10 @@ class CustomDiffusionAttnProcessor(nn.Module): ...@@ -680,8 +680,10 @@ class CustomDiffusionAttnProcessor(nn.Module):
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
if self.train_kv: if self.train_kv:
key = self.to_k_custom_diffusion(encoder_hidden_states) key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
value = self.to_v_custom_diffusion(encoder_hidden_states) value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
key = key.to(attn.to_q.weight.dtype)
value = value.to(attn.to_q.weight.dtype)
else: else:
key = attn.to_k(encoder_hidden_states) key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) value = attn.to_v(encoder_hidden_states)
...@@ -1392,9 +1394,9 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module): ...@@ -1392,9 +1394,9 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if self.train_q_out: if self.train_q_out:
query = self.to_q_custom_diffusion(hidden_states) query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
else: else:
query = attn.to_q(hidden_states) query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
if encoder_hidden_states is None: if encoder_hidden_states is None:
crossattn = False crossattn = False
...@@ -1405,8 +1407,10 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module): ...@@ -1405,8 +1407,10 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
if self.train_kv: if self.train_kv:
key = self.to_k_custom_diffusion(encoder_hidden_states) key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
value = self.to_v_custom_diffusion(encoder_hidden_states) value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
key = key.to(attn.to_q.weight.dtype)
value = value.to(attn.to_q.weight.dtype)
else: else:
key = attn.to_k(encoder_hidden_states) key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) value = attn.to_v(encoder_hidden_states)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment