Unverified Commit 543ee1e0 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[LDMTextToImagePipeline] make text model generic (#162)

make text model generic
parent 75b6c165
...@@ -45,11 +45,11 @@ class LDMTextToImagePipeline(DiffusionPipeline): ...@@ -45,11 +45,11 @@ class LDMTextToImagePipeline(DiffusionPipeline):
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if guidance_scale != 1.0: if guidance_scale != 1.0:
uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device)) uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))[0]
# get prompt text embeddings # get prompt text embeddings
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
text_embeddings = self.bert(text_input.input_ids.to(torch_device)) text_embeddings = self.bert(text_input.input_ids.to(torch_device))[0]
latents = torch.randn( latents = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
...@@ -618,5 +618,4 @@ class LDMBertModel(LDMBertPreTrainedModel): ...@@ -618,5 +618,4 @@ class LDMBertModel(LDMBertPreTrainedModel):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
) )
sequence_output = outputs[0] return outputs
return sequence_output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment