"sgl-kernel/vscode:/vscode.git/clone" did not exist on "704ced1b2ec45a87bb42494fe9da18fa5004e546"
Unverified Commit 5652c43f authored by nupurkmr9, committed by GitHub

Resolve bf16 error as mentioned in this [issue](https://github.com/huggingface/diffusers/issues/4139#issuecomment-1639977304) (#4214)

* resolve bf16 error

* resolve bf16 error

* resolve bf16 error

* resolve bf16 error

* resolve bf16 error

* resolve bf16 error

* resolve bf16 error
parent 365e8461
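The underlying failure is a dtype mismatch: under bf16 mixed precision the custom-diffusion projection layers and the rest of the attention module can end up in different dtypes, and a matmul between a bfloat16 tensor and a float32 weight raises a runtime error. Below is a minimal sketch of the mismatch and of the cast pattern the diff applies; the layer sizes, shapes, and dtype assignments are illustrative, not taken from the training script.

```python
import torch
import torch.nn as nn

# Illustrative setup: a trainable projection kept in float32 receiving bf16
# activations -- the kind of mismatch reported in the linked issue.
to_k_custom_diffusion = nn.Linear(768, 320, bias=False)                 # float32 weights
encoder_hidden_states = torch.randn(2, 77, 768, dtype=torch.bfloat16)  # bf16 activations

try:
    key = to_k_custom_diffusion(encoder_hidden_states)  # dtype mismatch -> RuntimeError
except RuntimeError as err:
    print(f"without cast: {err}")

# The cast pattern used throughout this commit: match the input to the layer's
# weight dtype, then cast the output back to the dtype the attention expects.
key = to_k_custom_diffusion(encoder_hidden_states.to(to_k_custom_diffusion.weight.dtype))
key = key.to(torch.bfloat16)
print(f"with cast: {key.dtype}")
```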
......@@ -277,4 +277,4 @@ To save even more memory, pass the `--set_grads_to_none` argument to the script.
 More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
 ## Experimental results
-You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail.
\ No newline at end of file
+You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail. We also released a more extensive dataset of 101 concepts for evaluating model customization methods. For more details please refer to our [dataset webpage](https://www.cs.cmu.edu/~custom-diffusion/dataset.html).
\ No newline at end of file
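For context on the `--set_grads_to_none` flag mentioned in the hunk above: it corresponds to PyTorch's `Optimizer.zero_grad(set_to_none=True)`, documented at the linked page. A minimal sketch with a placeholder model and optimizer:

```python
import torch

# Placeholder model and optimizer; the training script's are much larger.
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()

# set_to_none=True drops the .grad tensors instead of zero-filling them,
# saving memory and a little work on the next step.
optimizer.zero_grad(set_to_none=True)
```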
......@@ -57,7 +57,7 @@ def retrieve(class_prompt, class_data_dir, num_class_images):
         images = class_images[count]
         count += 1
         try:
-            img = requests.get(images["url"])
+            img = requests.get(images["url"], timeout=30)
             if img.status_code == 200:
                 _ = Image.open(BytesIO(img.content))
                 with open(f"{class_data_dir}/images/{total}.jpg", "wb") as f:
......
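The `timeout=30` added above bounds how long a single class-image download can block. A sketch of the same pattern in isolation (the URL is a placeholder), paired with the kind of broad exception handling the retrieval loop uses:

```python
from io import BytesIO

import requests
from PIL import Image

url = "https://example.com/image.jpg"  # placeholder URL

try:
    # Without a timeout a stalled server can hang the download loop forever;
    # timeout=30 raises requests.exceptions.Timeout instead.
    resp = requests.get(url, timeout=30)
    if resp.status_code == 200:
        image = Image.open(BytesIO(resp.content))
except Exception:
    pass  # covers timeouts, connection errors, and undecodable images
```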
......@@ -1210,6 +1210,7 @@ def main(args):
             text_encoder=accelerator.unwrap_model(text_encoder),
             tokenizer=tokenizer,
             revision=args.revision,
+            torch_dtype=weight_dtype,
         )
         pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
         pipeline = pipeline.to(accelerator.device)
......
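Passing `torch_dtype=weight_dtype` loads the sampling pipeline in the same precision used for training, so the newly trained weights and the frozen components agree on dtype. A sketch of the idea outside the script; the model id and dtype choice are illustrative:

```python
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

weight_dtype = torch.bfloat16  # e.g. when training with --mixed_precision=bf16

# Loading every pipeline component in the training dtype avoids fp32/bf16
# mixes between the freshly trained weights and the frozen components.
pipeline = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # illustrative base model
    torch_dtype=weight_dtype,
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to("cuda")
```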
......@@ -667,9 +667,9 @@ class CustomDiffusionAttnProcessor(nn.Module):
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         if self.train_q_out:
-            query = self.to_q_custom_diffusion(hidden_states)
+            query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
         else:
-            query = attn.to_q(hidden_states)
+            query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
         if encoder_hidden_states is None:
             crossattn = False
......@@ -680,8 +680,10 @@ class CustomDiffusionAttnProcessor(nn.Module):
                 encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
         if self.train_kv:
-            key = self.to_k_custom_diffusion(encoder_hidden_states)
-            value = self.to_v_custom_diffusion(encoder_hidden_states)
+            key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
+            value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
+            key = key.to(attn.to_q.weight.dtype)
+            value = value.to(attn.to_q.weight.dtype)
         else:
             key = attn.to_k(encoder_hidden_states)
             value = attn.to_v(encoder_hidden_states)
......@@ -1392,9 +1394,9 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         if self.train_q_out:
-            query = self.to_q_custom_diffusion(hidden_states)
+            query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
         else:
-            query = attn.to_q(hidden_states)
+            query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
         if encoder_hidden_states is None:
             crossattn = False
......@@ -1405,8 +1407,10 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
                 encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
         if self.train_kv:
-            key = self.to_k_custom_diffusion(encoder_hidden_states)
-            value = self.to_v_custom_diffusion(encoder_hidden_states)
+            key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
+            value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
+            key = key.to(attn.to_q.weight.dtype)
+            value = value.to(attn.to_q.weight.dtype)
         else:
             key = attn.to_k(encoder_hidden_states)
             value = attn.to_v(encoder_hidden_states)
......
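Both processors cast `key` and `value` back to `attn.to_q.weight.dtype` because the attention computation downstream expects query, key, and value to share a dtype (xformers' memory-efficient attention checks this explicitly). A small, self-contained illustration of that invariant, using PyTorch's scaled_dot_product_attention and made-up shapes:

```python
import torch
import torch.nn.functional as F

# Made-up shapes: (batch, heads, tokens, head_dim).
query = torch.randn(2, 8, 77, 40, dtype=torch.bfloat16)
key = torch.randn(2, 8, 77, 40, dtype=torch.float32)   # e.g. from an fp32 custom projection
value = torch.randn(2, 8, 77, 40, dtype=torch.float32)

# Aligning key/value with the query dtype, as the diff does, lets the kernel run.
key, value = key.to(query.dtype), value.to(query.dtype)
out = F.scaled_dot_product_attention(query, key, value)
print(out.dtype)  # torch.bfloat16
```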