Commit 50b1854e authored by Khalique Ahmed

batch unet pass, wip

parent bb5ea026
@@ -24,4 +24,5 @@
 accelerate
 diffusers
 optimum[onnxruntime]
-transformers
\ No newline at end of file
+transformers
+protobuf==3.20.0
\ No newline at end of file
@@ -123,8 +123,8 @@ class StableDiffusionMGX():
             "text_encoder", {"input_ids": [1, 77]}, fp16)
         self.unet = StableDiffusionMGX.load_mgx_model(
             "unet", {
-                "sample": [1, 4, 64, 64],
-                "encoder_hidden_states": [1, 77, 1024],
+                "sample": [2, 4, 64, 64],
+                "encoder_hidden_states": [2, 77, 1024],
                 "timestep": [1],
             }, fp16)
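Context on the shape change (not part of the diff itself): the unet is now compiled with a static batch of 2 so the unconditional and text-conditioned passes of classifier-free guidance can share a single inference. A minimal NumPy sketch, using dummy arrays as stand-ins for the real latents and prompt embeddings, showing how the batched inputs line up with the new static shapes:

```python
import numpy as np

# Dummy stand-ins with the single-prompt shapes the pipeline produces.
sample = np.zeros((1, 4, 64, 64), dtype=np.float32)            # latent for one image
uncond_embeddings = np.zeros((1, 77, 1024), dtype=np.float32)  # unconditional prompt
text_embeddings = np.zeros((1, 77, 1024), dtype=np.float32)    # text prompt

# Stacking both passes along the batch dimension yields exactly the
# static shapes the unet is compiled with above.
batched_sample = np.concatenate((sample, sample))                      # (2, 4, 64, 64)
batched_states = np.concatenate((uncond_embeddings, text_embeddings))  # (2, 77, 1024)

assert batched_sample.shape == (2, 4, 64, 64)
assert batched_states.shape == (2, 77, 1024)
```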
@@ -220,32 +220,45 @@ class StableDiffusionMGX():
     def denoising_loop(self, text_embeddings, uncond_embeddings, latents,
                        scale):
         for step, t in enumerate(self.scheduler.timesteps):
-            print(f"#{step}/{len(self.scheduler.timesteps)} step")
+            # print(f"#{step}/{len(self.scheduler.timesteps)} step")
             latents = self.denoise_step(text_embeddings, uncond_embeddings,
                                         latents, t, scale)
         return latents

-    @measure
+    # @measure
     def denoise_step(self, text_embeddings, uncond_embeddings, latents, t,
                      scale):
         sample = self.scheduler.scale_model_input(latents,
                                                   t).numpy().astype(np.float32)
+        sample = np.concatenate((sample, sample))
+        encoder_hidden_states = np.concatenate((uncond_embeddings, text_embeddings))
         timestep = np.atleast_1d(t.numpy().astype(
             np.int64))  # convert 0D -> 1D
-        noise_pred_uncond = np.array(
+        start_time = time.perf_counter_ns()
+        noise_pred = np.array(
             self.unet.run({
                 "sample": sample,
-                "encoder_hidden_states": uncond_embeddings,
+                "encoder_hidden_states": encoder_hidden_states,
                 "timestep": timestep
             })[0])
+        end_time = time.perf_counter_ns()
+        print(
+            f"Elapsed time for migx unet run: {(end_time - start_time) * 1e-6:.4f} ms\n"
+        )

-        noise_pred_text = np.array(
-            self.unet.run({
-                "sample": sample,
-                "encoder_hidden_states": text_embeddings,
-                "timestep": timestep
-            })[0])
+        noise_pred_split = np.split(noise_pred, 2)
+        noise_pred_uncond = noise_pred_split[0]
+        noise_pred_text = noise_pred_split[1]
+
+        # noise_pred_text = np.array(
+        #     self.unet.run({
+        #         "sample": sample,
+        #         "encoder_hidden_states": text_embeddings,
+        #         "timestep": timestep
+        #     })[0])

         # perform guidance
         noise_pred = noise_pred_uncond + scale * (noise_pred_text -
...
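For reference, a self-contained sketch of the batched classifier-free-guidance step that the new denoise_step implements. `run_unet` here is a hypothetical stand-in for `self.unet.run({...})[0]`, not an API from this repo; it only mimics the output shape:

```python
import numpy as np

def run_unet(sample, encoder_hidden_states, timestep):
    # Hypothetical stand-in for the compiled MIGraphX unet: returns a noise
    # prediction with the same shape as the (batched) sample.
    return np.zeros_like(sample)

def batched_denoise_step(sample, uncond_embeddings, text_embeddings, timestep, scale):
    # One batched unet call covers both the unconditional and the
    # text-conditioned pass instead of two separate runs.
    batched_sample = np.concatenate((sample, sample))
    encoder_hidden_states = np.concatenate((uncond_embeddings, text_embeddings))

    noise_pred = run_unet(batched_sample, encoder_hidden_states, timestep)

    # Split the batch back into the two predictions and apply guidance.
    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
    return noise_pred_uncond + scale * (noise_pred_text - noise_pred_uncond)
```

With the unet compiled for batch 2, splitting the output back into `noise_pred_uncond` and `noise_pred_text` preserves the previous guidance math while halving the number of unet launches per step.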