Commit ba7ba4df authored by Khalique Ahmed
Browse files

modify timings, extend script for fp16

parent a41cd5c0
......@@ -40,7 +40,7 @@ def measure(fn):
start_time = time.perf_counter_ns()
result = fn(*args, **kwargs)
end_time = time.perf_counter_ns()
print(f"Elapsed time: {(end_time - start_time) * 1e-6:.4f} ms\n")
print(f"Elapsed time for {fn.__name__}: {(end_time - start_time) * 1e-6:.4f} ms\n")
return result
return measure_ms
......@@ -87,6 +87,12 @@ def get_args():
help="Guidance scale",
)
parser.add_argument(
"--fp16",
action="store_true",
help="Quantize MIGraphX models to fp16"
)
parser.add_argument(
"-o",
"--output",
......@@ -98,7 +104,7 @@ def get_args():
class StableDiffusionMGX():
def __init__(self):
def __init__(self, fp16):
model_id = "stabilityai/stable-diffusion-2-1"
print(f"Using {model_id}")
......@@ -112,16 +118,17 @@ class StableDiffusionMGX():
print("Load models...")
self.vae = StableDiffusionMGX.load_mgx_model(
"vae_decoder", {"latent_sample": [1, 4, 64, 64]})
"vae_decoder", {"latent_sample": [1, 4, 64, 64]}, fp16)
self.text_encoder = StableDiffusionMGX.load_mgx_model(
"text_encoder", {"input_ids": [1, 77]})
"text_encoder", {"input_ids": [1, 77]}, fp16)
self.unet = StableDiffusionMGX.load_mgx_model(
"unet", {
"sample": [1, 4, 64, 64],
"encoder_hidden_states": [1, 77, 1024],
"timestep": [1],
})
}, fp16)
@measure
def run(self, prompt, negative_prompt, steps, seed, scale):
# need to set this for each run
self.scheduler.set_timesteps(steps)
......@@ -148,10 +155,7 @@ class StableDiffusionMGX():
latents = latents * self.scheduler.init_noise_sigma
print("Running denoising loop...")
for step, t in enumerate(self.scheduler.timesteps):
print(f"#{step}/{len(self.scheduler.timesteps)} step")
latents = self.denoise_step(text_embeddings, uncond_embeddings,
latents, t, scale)
latents = self.denoising_loop(text_embeddings, uncond_embeddings, latents, scale)
print("Scale denoised result...")
latents = 1 / 0.18215 * latents
......@@ -163,21 +167,25 @@ class StableDiffusionMGX():
@staticmethod
@measure
def load_mgx_model(name, shapes, fp16=False):
    """Load a MIGraphX model, compiling it from ONNX if no cached build exists.

    A previously saved ``.mxr`` file is preferred. When none is found, the
    ONNX file is parsed, optionally quantized to fp16, compiled for the GPU
    target and saved as ``.mxr`` so that later runs can load it directly.

    Args:
        name: Sub-directory of ``models/sd21-onnx`` holding the model
            (the callers in this file pass e.g. ``"unet"``).
        shapes: Mapping of input name to dimensions, forwarded to
            ``mgx.parse_onnx`` as ``map_input_dims``.
        fp16: When true, use the ``model_fp16.mxr`` cache file and quantize
            a freshly parsed model to fp16 before compiling.

    Returns:
        The loaded (or newly compiled) MIGraphX program.

    Raises:
        SystemExit: With status 1 when neither the ``.mxr`` nor the
            ``.onnx`` file exists.
    """
    # Keep the un-suffixed path around: the ONNX source is shared by both
    # precisions, only the compiled cache differs. Deriving it again with
    # str.rstrip("_fp16") would strip a *set of characters*, not a suffix.
    base = f"models/sd21-onnx/{name}/model"
    file = f"{base}_fp16" if fp16 else base
    mxr_path = f"{file}.mxr"
    onnx_path = f"{base}.onnx"

    print(f"Loading {name} model from {file}")
    if os.path.isfile(mxr_path):
        print("Found mxr, loading it...")
        model = mgx.load(mxr_path, format="msgpack")
    elif os.path.isfile(onnx_path):
        print("Parsing from onnx file...")
        model = mgx.parse_onnx(onnx_path, map_input_dims=shapes)
        if fp16:
            mgx.quantize_fp16(model)
        model.compile(mgx.get_target("gpu"))
        print(f"Saving {name} model to mxr file...")
        mgx.save(model, mxr_path, format="msgpack")
    else:
        print(f"No {name} model found. Please download it and re-try.")
        # os.exit() does not exist and the bare exit() helper is only
        # installed by the site module; raising SystemExit always works.
        raise SystemExit(1)
    return model
@measure
......@@ -206,6 +214,14 @@ class StableDiffusionMGX():
def save_image(pil_image, filename="output.png"):
    """Write ``pil_image`` to ``filename`` by calling its ``save`` method."""
    destination = filename
    pil_image.save(destination)
@measure
def denoising_loop(self, text_embeddings, uncond_embeddings, latents, scale):
    """Apply ``denoise_step`` once for every scheduler timestep.

    The latents returned by one step are fed into the next; a progress
    line is printed before each step.

    Args:
        text_embeddings: Forwarded unchanged to ``denoise_step``.
        uncond_embeddings: Forwarded unchanged to ``denoise_step``.
        latents: Starting latents for the first step.
        scale: Forwarded unchanged to ``denoise_step``.

    Returns:
        The latents produced by the last step, or the input ``latents``
        if the scheduler has no timesteps.
    """
    current = latents
    for index, timestep in enumerate(self.scheduler.timesteps):
        total = len(self.scheduler.timesteps)
        print(f"#{index}/{total} step")
        current = self.denoise_step(
            text_embeddings, uncond_embeddings, current, timestep, scale)
    return current
@measure
def denoise_step(self, text_embeddings, uncond_embeddings, latents, t,
scale):
......@@ -246,7 +262,7 @@ class StableDiffusionMGX():
if __name__ == "__main__":
args = get_args()
sd = StableDiffusionMGX()
sd = StableDiffusionMGX(args.fp16)
result = sd.run(args.prompt, args.negative_prompt, args.steps, args.seed,
args.scale)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment