Commit f1823bbe authored by patil-suraj's avatar patil-suraj
Browse files

get the ldm pipeline working

parent e3820fa3
......@@ -773,7 +773,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
in_channels,
resolution,
z_channels,
n_embed,
embed_dim,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
......@@ -794,7 +793,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
n_embed=n_embed,
embed_dim=embed_dim,
remap=remap,
sane_index_shape=sane_index_shape,
......@@ -877,17 +875,16 @@ class LatentDiffusion(DiffusionPipeline):
# get text embedding
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
text_embedding = self.bert(**text_input)[0]
text_embedding = self.bert(text_input.input_ids)[0]
num_trained_timesteps = self.noise_scheduler.num_timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
image = self.noise_scheduler.sample_noise(
(batch_size, self.unet.in_channels, self.unet.resolution // 8, self.unet.resolution // 8),
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
device=torch_device,
generator=generator,
)
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# get actual t and t-1
train_step = inference_step_times[t]
......@@ -928,8 +925,8 @@ class LatentDiffusion(DiffusionPipeline):
else:
image = pred_mean
image = 1 / image
image = self.vqvae(image)
image = 1 / 0.18215 * image
image = self.vqvae.decode(image)
image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0)
return image
......@@ -1026,7 +1026,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
hs = []
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
t_emb = timestep_embedding(timesteps, self.model_channels)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment