"src/tl_templates/vscode:/vscode.git/clone" did not exist on "5ccac4fa53c2c0ab7cdd0e0bb8f0965d8b670682"
Commit ba7ba4df authored by Khalique Ahmed
Browse files

modify timings, extend script for fp16

parent a41cd5c0
...@@ -40,7 +40,7 @@ def measure(fn): ...@@ -40,7 +40,7 @@ def measure(fn):
start_time = time.perf_counter_ns() start_time = time.perf_counter_ns()
result = fn(*args, **kwargs) result = fn(*args, **kwargs)
end_time = time.perf_counter_ns() end_time = time.perf_counter_ns()
print(f"Elapsed time: {(end_time - start_time) * 1e-6:.4f} ms\n") print(f"Elapsed time for {fn.__name__}: {(end_time - start_time) * 1e-6:.4f} ms\n")
return result return result
return measure_ms return measure_ms
...@@ -87,6 +87,12 @@ def get_args(): ...@@ -87,6 +87,12 @@ def get_args():
help="Guidance scale", help="Guidance scale",
) )
parser.add_argument(
"--fp16",
action="store_true",
help="Quantize MIGraphX models to fp16"
)
parser.add_argument( parser.add_argument(
"-o", "-o",
"--output", "--output",
...@@ -98,7 +104,7 @@ def get_args(): ...@@ -98,7 +104,7 @@ def get_args():
class StableDiffusionMGX(): class StableDiffusionMGX():
def __init__(self): def __init__(self, fp16):
model_id = "stabilityai/stable-diffusion-2-1" model_id = "stabilityai/stable-diffusion-2-1"
print(f"Using {model_id}") print(f"Using {model_id}")
...@@ -112,16 +118,17 @@ class StableDiffusionMGX(): ...@@ -112,16 +118,17 @@ class StableDiffusionMGX():
print("Load models...") print("Load models...")
self.vae = StableDiffusionMGX.load_mgx_model( self.vae = StableDiffusionMGX.load_mgx_model(
"vae_decoder", {"latent_sample": [1, 4, 64, 64]}) "vae_decoder", {"latent_sample": [1, 4, 64, 64]}, fp16)
self.text_encoder = StableDiffusionMGX.load_mgx_model( self.text_encoder = StableDiffusionMGX.load_mgx_model(
"text_encoder", {"input_ids": [1, 77]}) "text_encoder", {"input_ids": [1, 77]}, fp16)
self.unet = StableDiffusionMGX.load_mgx_model( self.unet = StableDiffusionMGX.load_mgx_model(
"unet", { "unet", {
"sample": [1, 4, 64, 64], "sample": [1, 4, 64, 64],
"encoder_hidden_states": [1, 77, 1024], "encoder_hidden_states": [1, 77, 1024],
"timestep": [1], "timestep": [1],
}) }, fp16)
@measure
def run(self, prompt, negative_prompt, steps, seed, scale): def run(self, prompt, negative_prompt, steps, seed, scale):
# need to set this for each run # need to set this for each run
self.scheduler.set_timesteps(steps) self.scheduler.set_timesteps(steps)
...@@ -148,10 +155,7 @@ class StableDiffusionMGX(): ...@@ -148,10 +155,7 @@ class StableDiffusionMGX():
latents = latents * self.scheduler.init_noise_sigma latents = latents * self.scheduler.init_noise_sigma
print("Running denoising loop...") print("Running denoising loop...")
for step, t in enumerate(self.scheduler.timesteps): latents = self.denoising_loop(text_embeddings, uncond_embeddings, latents, scale)
print(f"#{step}/{len(self.scheduler.timesteps)} step")
latents = self.denoise_step(text_embeddings, uncond_embeddings,
latents, t, scale)
print("Scale denoised result...") print("Scale denoised result...")
latents = 1 / 0.18215 * latents latents = 1 / 0.18215 * latents
...@@ -163,21 +167,25 @@ class StableDiffusionMGX(): ...@@ -163,21 +167,25 @@ class StableDiffusionMGX():
@staticmethod @staticmethod
@measure @measure
def load_mgx_model(name, shapes, fp16):
    """Load a MIGraphX model from a cached ``.mxr`` file or from its ONNX export.

    Lookup order:
      1. ``models/sd21-onnx/<name>/model[_fp16].mxr`` -- loaded as-is.
      2. ``models/sd21-onnx/<name>/model.onnx`` -- parsed, optionally
         quantized to fp16, compiled for the "gpu" target and then saved
         to the ``.mxr`` path from step 1 so later runs can reuse it.

    Args:
        name: Model sub-directory under ``models/sd21-onnx``.
        shapes: Mapping forwarded to ``mgx.parse_onnx`` as ``map_input_dims``.
        fp16: When truthy, use the ``_fp16`` cache file name and call
            ``mgx.quantize_fp16`` on a freshly parsed model.

    Returns:
        The model object returned by ``mgx.load`` or ``mgx.parse_onnx``.

    Raises:
        SystemExit: With status 1 when neither the ``.mxr`` nor the
            ``.onnx`` file exists.
    """
    import sys

    # Keep the un-suffixed path around instead of deriving it later with
    # str.rstrip("_fp16"): rstrip removes any trailing run of the characters
    # "_", "f", "p", "1", "6" rather than the literal suffix, so it would
    # corrupt base names that happen to end in one of those characters.
    base = f"models/sd21-onnx/{name}/model"
    file = f"{base}_fp16" if fp16 else base

    print(f"Loading {name} model from {file}")
    if os.path.isfile(f"{file}.mxr"):
        print("Found mxr, loading it...")
        model = mgx.load(f"{file}.mxr", format="msgpack")
    elif os.path.isfile(f"{base}.onnx"):
        # The onnx export is shared between precisions; only the compiled
        # cache file name carries the "_fp16" suffix.
        print("Parsing from onnx file...")
        model = mgx.parse_onnx(f"{base}.onnx", map_input_dims=shapes)
        if fp16:
            mgx.quantize_fp16(model)
        model.compile(mgx.get_target("gpu"))
        print(f"Saving {name} model to mxr file...")
        mgx.save(model, f"{file}.mxr", format="msgpack")
    else:
        print(f"No {name} model found. Please download it and re-try.")
        # sys.exit is always available; the bare exit() builtin is only
        # injected by the site module and os.exit does not exist.
        sys.exit(1)
    return model
@measure @measure
...@@ -206,6 +214,14 @@ class StableDiffusionMGX(): ...@@ -206,6 +214,14 @@ class StableDiffusionMGX():
def save_image(pil_image, filename="output.png"): def save_image(pil_image, filename="output.png"):
pil_image.save(filename) pil_image.save(filename)
@measure
def denoising_loop(self, text_embeddings, uncond_embeddings, latents, scale):
    """Run one ``denoise_step`` per scheduler timestep and return the result.

    Iterates over ``self.scheduler.timesteps`` in order, printing a progress
    line before each step and feeding the latents returned by one step into
    the next.

    Args:
        text_embeddings: Forwarded unchanged to ``self.denoise_step``.
        uncond_embeddings: Forwarded unchanged to ``self.denoise_step``.
        latents: Initial latents; replaced by each step's return value.
        scale: Forwarded unchanged to ``self.denoise_step``.

    Returns:
        The latents returned by the final step, or the input ``latents``
        unchanged when there are no timesteps.
    """
    timesteps = self.scheduler.timesteps
    # The step count does not change during the loop, so compute it once.
    total_steps = len(timesteps)
    for step, t in enumerate(timesteps):
        print(f"#{step}/{total_steps} step")
        latents = self.denoise_step(text_embeddings, uncond_embeddings,
                                    latents, t, scale)
    return latents
@measure @measure
def denoise_step(self, text_embeddings, uncond_embeddings, latents, t, def denoise_step(self, text_embeddings, uncond_embeddings, latents, t,
scale): scale):
...@@ -246,7 +262,7 @@ class StableDiffusionMGX(): ...@@ -246,7 +262,7 @@ class StableDiffusionMGX():
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
sd = StableDiffusionMGX() sd = StableDiffusionMGX(args.fp16)
result = sd.run(args.prompt, args.negative_prompt, args.steps, args.seed, result = sd.run(args.prompt, args.negative_prompt, args.steps, args.seed,
args.scale) args.scale)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment