Commit 50b1854e authored by Khalique Ahmed
Browse files

batch unet pass, wip

parent bb5ea026
......@@ -24,4 +24,5 @@
accelerate
diffusers
optimum[onnxruntime]
transformers
\ No newline at end of file
transformers
protobuf==3.20.0
\ No newline at end of file
......@@ -123,8 +123,8 @@ class StableDiffusionMGX():
"text_encoder", {"input_ids": [1, 77]}, fp16)
self.unet = StableDiffusionMGX.load_mgx_model(
"unet", {
"sample": [1, 4, 64, 64],
"encoder_hidden_states": [1, 77, 1024],
"sample": [2, 4, 64, 64],
"encoder_hidden_states": [2, 77, 1024],
"timestep": [1],
}, fp16)
......@@ -220,32 +220,45 @@ class StableDiffusionMGX():
def denoising_loop(self, text_embeddings, uncond_embeddings, latents,
scale):
for step, t in enumerate(self.scheduler.timesteps):
print(f"#{step}/{len(self.scheduler.timesteps)} step")
# print(f"#{step}/{len(self.scheduler.timesteps)} step")
latents = self.denoise_step(text_embeddings, uncond_embeddings,
latents, t, scale)
return latents
@measure
# @measure
def denoise_step(self, text_embeddings, uncond_embeddings, latents, t,
scale):
sample = self.scheduler.scale_model_input(latents,
t).numpy().astype(np.float32)
sample = np.concatenate((sample,sample))
encoder_hidden_states = np.concatenate((uncond_embeddings, text_embeddings))
timestep = np.atleast_1d(t.numpy().astype(
np.int64)) # convert 0D -> 1D
noise_pred_uncond = np.array(
start_time = time.perf_counter_ns()
noise_pred = np.array(
self.unet.run({
"sample": sample,
"encoder_hidden_states": uncond_embeddings,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep
})[0])
end_time = time.perf_counter_ns()
print(
f"Elapsed time for migx unet run: {(end_time - start_time) * 1e-6:.4f} ms\n"
)
noise_pred_text = np.array(
self.unet.run({
"sample": sample,
"encoder_hidden_states": text_embeddings,
"timestep": timestep
})[0])
noise_pred_split = np.split(noise_pred, 2)
noise_pred_uncond = noise_pred_split[0]
noise_pred_text = noise_pred_split[1]
# noise_pred_text = np.array(
# self.unet.run({
# "sample": sample,
# "encoder_hidden_states": text_embeddings,
# "timestep": timestep
# })[0])
# perform guidance
noise_pred = noise_pred_uncond + scale * (noise_pred_text -
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment