Unverified Commit a24ed87e authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into optimize_jenkinsfile

parents 6481cd69 a09dc502
...@@ -46,11 +46,11 @@ Trim instructions from the end (Default: 0) ...@@ -46,11 +46,11 @@ Trim instructions from the end (Default: 0)
Dim of a parameter (format: "@name d1 d2 dn") Dim of a parameter (format: "@name d1 d2 dn")
.. options:: --dyn-input-dim [std::vector<std::string>] .. option:: --dyn-input-dim [std::vector<std::string>]
Set dynamic dimensions of a parameter using JSON formatting (format "@name" "dynamic_dimension_json") Set dynamic dimensions of a parameter using JSON formatting (format "@name" "dynamic_dimension_json")
.. options:: --default-dyn-dim .. option:: --default-dyn-dim
Set the default dynamic dimension (format {min:x, max:y, optimals:[o1,o2,...]}) Set the default dynamic dimension (format {min:x, max:y, optimals:[o1,o2,...]})
...@@ -58,6 +58,10 @@ Set the default dynamic dimension (format {min:x, max:y, optimals:[o1,o2,...]}) ...@@ -58,6 +58,10 @@ Set the default dynamic dimension (format {min:x, max:y, optimals:[o1,o2,...]})
Optimize when reading Optimize when reading
.. option:: --apply-pass, -p
Passes to apply to model
.. option:: --graphviz, -g .. option:: --graphviz, -g
Print out a graphviz representation. Print out a graphviz representation.
......
...@@ -8,45 +8,65 @@ shape ...@@ -8,45 +8,65 @@ shape
.. doxygenenum:: migraphx_shape_datatype_t .. doxygenenum:: migraphx_shape_datatype_t
.. doxygenstruct:: migraphx::shape .. doxygenstruct:: migraphx::shape
:members:
:undoc-members:
argument argument
-------- --------
.. doxygenstruct:: migraphx::argument .. doxygenstruct:: migraphx::argument
:members:
:undoc-members:
target target
------ ------
.. doxygenstruct:: migraphx::target .. doxygenstruct:: migraphx::target
:members:
:undoc-members:
program program
------- -------
.. doxygenstruct:: migraphx::program_parameter_shapes .. doxygenstruct:: migraphx::program_parameter_shapes
:members:
:undoc-members:
.. doxygenstruct:: migraphx::program_parameters .. doxygenstruct:: migraphx::program_parameters
:members:
:undoc-members:
.. doxygenstruct:: migraphx_compile_options .. doxygenstruct:: migraphx_compile_options
:members:
:undoc-members:
.. doxygenstruct:: migraphx::program .. doxygenstruct:: migraphx::program
:members:
:undoc-members:
quantize quantize
-------- --------
.. doxygenstruct:: migraphx::quantize_op_names .. doxygenstruct:: migraphx::quantize_op_names
:members:
:undoc-members:
.. doxygenfunction:: migraphx::quantize_fp16(const program&) .. doxygenfunction:: migraphx::quantize_fp16(const program&)
.. doxygenfunction:: migraphx::quantize_fp16(const program&, const quantize_op_names&) .. doxygenfunction:: migraphx::quantize_fp16(const program&, const quantize_op_names&)
.. doxygenstruct:: migraphx::quantize_int8_options .. doxygenstruct:: migraphx::quantize_int8_options
:members:
:undoc-members:
.. doxygenfunction:: migraphx::quantize_int8 .. doxygenfunction::migraphx::quantize_int8
parse_onnx parse_onnx
---------- ----------
.. doxygenstruct:: migraphx::onnx_options .. doxygenstruct:: migraphx::onnx_options
:members:
:undoc-members:
.. doxygenfunction:: migraphx::parse_onnx(const char *) .. doxygenfunction:: migraphx::parse_onnx(const char *)
...@@ -63,16 +83,18 @@ parse_onnx ...@@ -63,16 +83,18 @@ parse_onnx
load load
---- ----
.. doxygenstruct:: migraphx_file_options .. doxygenstruct:: migraphx::file_options
:members:
:undoc-members:
.. doxygenfunction:: migraphx::load(const char *) .. doxygenfunction:: migraphx::load(const char *)
.. doxygenfunction:: migraphx::load(const char *, migraphx_file_options) .. doxygenfunction:: migraphx::load(const char *, const file_options&)
save save
---- ----
.. doxygenfunction:: migraphx::save(const program&, const char *) .. doxygenfunction:: migraphx::save(const program&, const char *)
.. doxygenfunction:: migraphx::save(const program&, const char *, migraphx_file_options) .. doxygenfunction:: migraphx::save(const program&, const char *, const file_options&)
...@@ -95,7 +95,7 @@ shape ...@@ -95,7 +95,7 @@ shape
:rtype: bool :rtype: bool
dynamic_dimension dynamic_dimension
-------- -----------------
.. py:class:: dynamic_dimension(min, max, optimals) .. py:class:: dynamic_dimension(min, max, optimals)
...@@ -326,7 +326,7 @@ op ...@@ -326,7 +326,7 @@ op
parse_onnx parse_onnx
---------- ----------
.. py:function:: parse_onnx(filename, default_dim_value=1, map_input_dims={}, skip_unknown_operators=false, print_program_on_error=false, max_loop_iterations=10) .. py:function:: parse_onnx(filename, default_dim_value=1, map_input_dims={}, skip_unknown_operators=false, print_program_on_error=false, max_loop_iterations=10, limit_max_iterations=65535)
Load and parse an onnx file. Load and parse an onnx file.
...@@ -337,7 +337,8 @@ parse_onnx ...@@ -337,7 +337,8 @@ parse_onnx
:param list[dynamic_dimension] map_dyn_input_dims: Explicitly specify the dynamic_dimensions of an input. :param list[dynamic_dimension] map_dyn_input_dims: Explicitly specify the dynamic_dimensions of an input.
:param str skip_unknown_operators: Continue parsing onnx file if an unknown operator is found. :param str skip_unknown_operators: Continue parsing onnx file if an unknown operator is found.
:param str print_program_on_error: Print program if an error occurs. :param str print_program_on_error: Print program if an error occurs.
:param int max_loop_iterations: Maximum iteration number for the loop operator. :param int max_loop_iterations: Maximum iteration number for the loop operator if trip count is not set.
:param int limit_max_iterations: Maximum iteration limit for the loop operator.
:rtype: program :rtype: program
parse_tf parse_tf
......
...@@ -6,4 +6,5 @@ This directory contains examples of common use cases for MIGraphX. ...@@ -6,4 +6,5 @@ This directory contains examples of common use cases for MIGraphX.
## Examples: ## Examples:
- [MIGraphX usage and utilities](./migraphx) - [MIGraphX usage and utilities](./migraphx)
- [Vision inference examples](./vision) - [Vision inference examples](./vision)
- [Natural language inference examples](./nlp) - [Natural language inference examples](./nlp)
\ No newline at end of file - [Diffusion inference examples](./diffusion)
# Diffusion Inference Examples
- [Python Stable Diffusion 2.1](./python_stable_diffusion_21)
# Stable Diffusion 2.1
This version was tested with [rocm 5.7](https://github.com/ROCmSoftwarePlatform/AMDMIGraphX/tree/rocm-5.7.0) revision.
## Jupyter notebook
There is a dedicated step-by-step notebook. See [sd21.ipynb](./sd21.ipynb)
## Console application
To run the console application, follow these steps below.
Setup python environment
```bash
# this will require the python venv to installed (e.g. apt install python3.8-venv)
python3 -m venv sd_venv
. sd_venv/bin/activate
```
Install dependencies
```bash
pip install -r requirements.txt
```
Use MIGraphX Python Module
```bash
export PYTHONPATH=/opt/rocm/lib:$PYTHONPATH
```
Get models with optimum
```bash
optimum-cli export onnx --model stabilityai/stable-diffusion-2-1 models/sd21-onnx
```
*Note: `models/sd21-onnx` will be used in the scripts.*
Run the text-to-image script with the following example prompt and seed:
```bash
python txt2img.py --prompt "a photograph of an astronaut riding a horse" --seed 13 --output astro_horse.jpg
```
*Note: The first run will compile the models and cache them to make subsequent runs faster.*
The result should look like this:
![example_output.jpg](./example_output.jpg)
## Gradio application
Note: requires `Console application` to work
Install gradio dependencies
```bash
pip install -r gradio_requirements.txt
```
Usage
```bash
python gradio_app.py
```
This will load the models (which can take several minutes), and when the setup is ready, starts a server on `http://127.0.0.1:7860`.
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
from txt2img import StableDiffusionMGX
import gradio as gr
def main():
# Note: This will load the models, which can take several minutes
sd = StableDiffusionMGX()
def gr_wrapper(prompt, negative_prompt, steps, seed, scale):
result = sd.run(str(prompt), str(negative_prompt), int(steps),
int(seed), float(scale))
return StableDiffusionMGX.convert_to_rgb_image(result)
demo = gr.Interface(
gr_wrapper,
[
gr.Textbox(value="a photograph of an astronaut riding a horse",
label="Prompt"),
gr.Textbox(value="", label="Negative prompt (Optional)"),
gr.Slider(1, 100, step=1, value=20, label="Number of steps"),
gr.Textbox(value=13, label="Random seed"),
gr.Slider(1, 20, step=0.1, value=7.0, label="Guidance scale"),
],
"image",
)
demo.launch()
if __name__ == "__main__":
main()
#!/bin/bash
##################################################################################### #####################################################################################
# The MIT License (MIT) # The MIT License (MIT)
# #
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy # Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal # of this software and associated documentation files (the "Software"), to deal
...@@ -23,34 +21,5 @@ ...@@ -23,34 +21,5 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
-f requirements.txt
set -e gradio
\ No newline at end of file
if [ -z "$ONNX_HOME" ]
then
# The onnx library uses ONNX_HOME, by default if it doesn't exist
# the path of " ~/.onnx " is used
ONNX_HOME=$HOME/.onnx
fi
model_dir=$ONNX_HOME/models
tmp_dir=$ONNX_HOME/tmp/
mkdir -p $model_dir
mkdir -p $tmp_dir
models="bvlc_alexnet \
densenet121 \
inception_v2 \
shufflenet \
vgg19 \
zfnet512"
for name in $models
do
curl https://download.onnxruntime.ai/onnx/models/$name.tar.gz --output $tmp_dir/$name.tar.gz
tar -xzvf $tmp_dir/$name.tar.gz --directory $model_dir && rm $tmp_dir/$name.tar.gz
done
# CI jobs can run as a different user then the docker image builder.
# Allow read/write access to the models
chmod 777 $model_dir
...@@ -21,5 +21,7 @@ ...@@ -21,5 +21,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
accelerate
numpy==1.21.6 diffusers
\ No newline at end of file optimum[onnxruntime]
transformers
\ No newline at end of file
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# The MIT License (MIT)\n",
"#\n",
"# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.\n",
"#\n",
"# Permission is hereby granted, free of charge, to any person obtaining a copy\n",
"# of this software and associated documentation files (the 'Software'), to deal\n",
"# in the Software without restriction, including without limitation the rights\n",
"# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
"# copies of the Software, and to permit persons to whom the Software is\n",
"# furnished to do so, subject to the following conditions:\n",
"#\n",
"# The above copyright notice and this permission notice shall be included in\n",
"# all copies or substantial portions of the Software.\n",
"#\n",
"# THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
"# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
"# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
"# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
"# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
"# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
"# THE SOFTWARE."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Stable Diffusion 2.1\n",
"\n",
"The following example will show how to run `Stable Diffusion 2.1` with `MIGraphX`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install the required dependencies."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install dependencies\n",
"!pip install optimum[onnxruntime] transformers diffusers accelerate"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use optimum to generate the onnx files."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export models\n",
"!optimum-cli export onnx --model stabilityai/stable-diffusion-2-1 models/sd21-onnx"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now it is time to load these models with python.\n",
"\n",
"First, we make sure that MIGraphX module is found in the python path."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"mgx_lib_path = \"/opt/rocm/lib/\" # or \"/code/AMDMIGraphX/build/lib/\"\n",
"if mgx_lib_path not in sys.path:\n",
" sys.path.append(mgx_lib_path)\n",
"import migraphx as mgx"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, a helper method to load and cache the models.\n",
"\n",
"This will use the `models/sd21-onnx` path. If you changed it, make sure to update here as well."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"# helper for model loading\n",
"def load_mgx_model(name, shapes):\n",
" file = f\"models/sd21-onnx/{name}/model\"\n",
" print(f\"Loading {name} model from {file}\")\n",
" if os.path.isfile(f\"{file}.mxr\"):\n",
" print(f\"Found mxr, loading it...\")\n",
" model = mgx.load(f\"{file}.mxr\", format=\"msgpack\")\n",
" elif os.path.isfile(f\"{file}.onnx\"):\n",
" print(f\"Parsing from onnx file...\")\n",
" model = mgx.parse_onnx(f\"{file}.onnx\", map_input_dims=shapes)\n",
" model.compile(mgx.get_target(\"gpu\"))\n",
" print(f\"Saving {name} model to mxr file...\")\n",
" mgx.save(model, f\"{file}.mxr\", format=\"msgpack\")\n",
" else:\n",
" print(f\"No {name} model found. Please verify the path is correct and re-try, or re-download model.\")\n",
" os.exit(1)\n",
" return model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With that, we can load the models. This could take several minutes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text_encoder = load_mgx_model(\"text_encoder\", {\"input_ids\": [1, 77]})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"unet = load_mgx_model(\n",
" \"unet\", {\n",
" \"sample\": [1, 4, 64, 64],\n",
" \"encoder_hidden_states\": [1, 77, 1024],\n",
" \"timestep\": [1],\n",
" })"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"vae = load_mgx_model(\"vae_decoder\", {\"latent_sample\": [1, 4, 64, 64]})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import the remaining packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from diffusers import EulerDiscreteScheduler\n",
"from transformers import CLIPTokenizer\n",
"import torch\n",
"import numpy as np\n",
"from tqdm.auto import tqdm\n",
"from PIL import Image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Time to load the scheduler and tokenizer from the original source."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_id = \"stabilityai/stable-diffusion-2-1\"\n",
"scheduler = EulerDiscreteScheduler.from_pretrained(model_id,\n",
" subfolder=\"scheduler\")\n",
"tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder=\"tokenizer\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we will define all the steps one by one, to make the last step short and simple.\n",
"\n",
"The first step will be to tokenize the user prompt. It will make a `(1, 77)` shaped `input_ids`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tokenize(input):\n",
" return tokenizer([input],\n",
" padding=\"max_length\",\n",
" max_length=tokenizer.model_max_length,\n",
" truncation=True,\n",
" return_tensors=\"np\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional\n",
"test_tk = tokenize(\"test tokenizer to see the tokens\")\n",
"test_tk.input_ids.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We run the tokenized prompt through the `Text Encoder` model. It expects the `(1, 77)` data as `int32`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional\n",
"text_encoder.get_parameter_shapes()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_embeddings(input):\n",
" return np.array(\n",
" text_encoder.run({\"input_ids\": input.input_ids.astype(np.int32)\n",
" })[0]).astype(np.float32)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional\n",
"test_emb = get_embeddings(tokenize(\"test tokenizer to see the tokens\"))\n",
"test_emb.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The other input of the model is latent representation (pure noise). It will be transformed into a 512x512 image later.\n",
"The last input will be the timestep."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def generate_latents(seed):\n",
" return torch.randn(\n",
" (1, 4, 64, 64),\n",
" generator=torch.manual_seed(seed),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional\n",
"test_latents = generate_latents(42)\n",
"latents.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we add two helpers to access and convert from torch to numpy with the proper datatype."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_scaled_sample(latents, t):\n",
" return scheduler.scale_model_input(latents, t).numpy().astype(np.float32)\n",
"\n",
"\n",
"def get_timestep(t):\n",
" return np.atleast_1d(t.numpy().astype(np.int64)) # convert 0D -> 1D"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The UNet model will be run in a loop. It will predict the noise residual."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional\n",
"unet.get_parameter_shapes()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def denoise(sample, embeddings, timestep):\n",
" return np.array(\n",
" unet.run({\n",
" \"sample\": sample,\n",
" \"encoder_hidden_states\": embeddings,\n",
" \"timestep\": timestep\n",
" })[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Helpers to do the classifier-free guidance and computing the previous noisy sample."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def perform_guidance(noise_pred_uncond, noise_pred_text, scale):\n",
" return noise_pred_uncond + scale * (noise_pred_text - noise_pred_uncond)\n",
"\n",
"def compute_previous(noise_pred, t, latents):\n",
" # compute the previous noisy sample x_t -> x_t-1\n",
" return scheduler.step(noise_pred, t, latents).prev_sample\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Scale and decode the image latents with VAE."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def scale_denoised(latents):\n",
" return 1 / 0.18215 * latents\n",
"\n",
"\n",
"def decode(latents):\n",
" return np.array(\n",
" vae.run({\"latent_sample\": latents.numpy().astype(np.float32)})[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And lastly, we need to convert it to an image to display or save."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def convert_to_rgb_image(image):\n",
" image = np.clip(image / 2 + 0.5, 0, 1)\n",
" image = np.transpose(image, (0, 2, 3, 1))\n",
" images = (image * 255).round().astype(\"uint8\")\n",
" return Image.fromarray(images[0])\n",
"\n",
"def save_image(pil_image, filename=\"output.png\"):\n",
" pil_image.save(filename, format=\"png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Feel free to play around with these params."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"a photograph of an astronaut riding a horse\"\n",
"negative_prompt = \"\"\n",
"steps = 20\n",
"seed = 13\n",
"scale = 7.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now, to put everything together and run the whole pipeline:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"scheduler.set_timesteps(steps)\n",
"\n",
"text_input, uncond_input = tokenize(prompt), tokenize(negative_prompt)\n",
"text_embeddings, uncond_embeddings = get_embeddings(\n",
" text_input), get_embeddings(uncond_input)\n",
"latents = generate_latents(seed) * scheduler.init_noise_sigma\n",
"\n",
"for t in tqdm(scheduler.timesteps):\n",
" sample = get_scaled_sample(latents, t)\n",
" timestep = get_timestep(t)\n",
"\n",
" noise_pred_uncond = denoise(sample, uncond_embeddings, timestep)\n",
" noise_pred_text = denoise(sample, text_embeddings, timestep)\n",
"\n",
" noise_pred = perform_guidance(noise_pred_uncond, noise_pred_text, scale)\n",
" latents = compute_previous(torch.from_numpy(noise_pred), t, latents)\n",
"\n",
"latents = scale_denoised(latents)\n",
"result = decode(latents)\n",
"image = convert_to_rgb_image(result)\n",
"\n",
"# show the image\n",
"image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you like the generated image, save it with the following:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"save_image(image, \"output.png\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "sd_venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the 'Software'), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from argparse import ArgumentParser
from diffusers import EulerDiscreteScheduler
from transformers import CLIPTokenizer
from PIL import Image
import migraphx as mgx
import numpy as np
import os
import torch
import time
from functools import wraps
# measurement helper
def measure(fn):
@wraps(fn)
def measure_ms(*args, **kwargs):
start_time = time.perf_counter_ns()
result = fn(*args, **kwargs)
end_time = time.perf_counter_ns()
print(f"Elapsed time: {(end_time - start_time) * 1e-6:.4f} ms\n")
return result
return measure_ms
def get_args():
parser = ArgumentParser()
parser.add_argument(
"-s",
"--seed",
type=int,
default=42,
help="Random seed",
)
parser.add_argument(
"-t",
"--steps",
type=int,
default=20,
help="Number of steps",
)
parser.add_argument(
"-p",
"--prompt",
type=str,
required=True,
help="Prompt",
)
parser.add_argument(
"-n",
"--negative-prompt",
type=str,
default="",
help="Negative prompt",
)
parser.add_argument(
"--scale",
type=float,
default=7.0,
help="Guidance scale",
)
parser.add_argument(
"-o",
"--output",
type=str,
default=None,
help="Output name",
)
return parser.parse_args()
class StableDiffusionMGX():
def __init__(self):
model_id = "stabilityai/stable-diffusion-2-1"
print(f"Using {model_id}")
print("Creating EulerDiscreteScheduler scheduler")
self.scheduler = EulerDiscreteScheduler.from_pretrained(
model_id, subfolder="scheduler")
print("Creating CLIPTokenizer tokenizer...")
self.tokenizer = CLIPTokenizer.from_pretrained(model_id,
subfolder="tokenizer")
print("Load models...")
self.vae = StableDiffusionMGX.load_mgx_model(
"vae_decoder", {"latent_sample": [1, 4, 64, 64]})
self.text_encoder = StableDiffusionMGX.load_mgx_model(
"text_encoder", {"input_ids": [1, 77]})
self.unet = StableDiffusionMGX.load_mgx_model(
"unet", {
"sample": [1, 4, 64, 64],
"encoder_hidden_states": [1, 77, 1024],
"timestep": [1],
})
def run(self, prompt, negative_prompt, steps, seed, scale):
# need to set this for each run
self.scheduler.set_timesteps(steps)
print("Tokenizing prompt...")
text_input = self.tokenize(prompt)
print("Creating text embeddings for prompt...")
text_embeddings = self.get_embeddings(text_input)
print("Tokenizing negative prompt...")
uncond_input = self.tokenize(negative_prompt)
print("Creating text embeddings for negative prompt...")
uncond_embeddings = self.get_embeddings(uncond_input)
print(
f"Creating random input data ({1}x{4}x{64}x{64}) (latents) with seed={seed}..."
)
latents = torch.randn((1, 4, 64, 64),
generator=torch.manual_seed(seed))
print("Apply initial noise sigma\n")
latents = latents * self.scheduler.init_noise_sigma
print("Running denoising loop...")
for step, t in enumerate(self.scheduler.timesteps):
print(f"#{step}/{len(self.scheduler.timesteps)} step")
latents = self.denoise_step(text_embeddings, uncond_embeddings,
latents, t, scale)
print("Scale denoised result...")
latents = 1 / 0.18215 * latents
print("Decode denoised result...")
image = self.decode(latents)
return image
@staticmethod
@measure
def load_mgx_model(name, shapes):
file = f"models/sd21-onnx/{name}/model"
print(f"Loading {name} model from {file}")
if os.path.isfile(f"{file}.mxr"):
print("Found mxr, loading it...")
model = mgx.load(f"{file}.mxr", format="msgpack")
elif os.path.isfile(f"{file}.onnx"):
print("Parsing from onnx file...")
model = mgx.parse_onnx(f"{file}.onnx", map_input_dims=shapes)
model.compile(mgx.get_target("gpu"))
print(f"Saving {name} model to mxr file...")
mgx.save(model, f"{file}.mxr", format="msgpack")
else:
print(f"No {name} model found. Please download it and re-try.")
os.exit(1)
return model
@measure
def tokenize(self, input):
return self.tokenizer([input],
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np")
@measure
def get_embeddings(self, input):
return np.array(
self.text_encoder.run(
{"input_ids":
input.input_ids.astype(np.int32)})[0]).astype(np.float32)
@staticmethod
def convert_to_rgb_image(image):
image = np.clip(image / 2 + 0.5, 0, 1)
image = np.transpose(image, (0, 2, 3, 1))
images = (image * 255).round().astype("uint8")
return Image.fromarray(images[0])
@staticmethod
def save_image(pil_image, filename="output.png"):
pil_image.save(filename)
@measure
def denoise_step(self, text_embeddings, uncond_embeddings, latents, t,
scale):
sample = self.scheduler.scale_model_input(latents,
t).numpy().astype(np.float32)
timestep = np.atleast_1d(t.numpy().astype(
np.int64)) # convert 0D -> 1D
noise_pred_uncond = np.array(
self.unet.run({
"sample": sample,
"encoder_hidden_states": uncond_embeddings,
"timestep": timestep
})[0])
noise_pred_text = np.array(
self.unet.run({
"sample": sample,
"encoder_hidden_states": text_embeddings,
"timestep": timestep
})[0])
# perform guidance
noise_pred = noise_pred_uncond + scale * (noise_pred_text -
noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
return self.scheduler.step(torch.from_numpy(noise_pred), t,
latents).prev_sample
@measure
def decode(self, latents):
return np.array(
self.vae.run({"latent_sample":
latents.numpy().astype(np.float32)})[0])
if __name__ == "__main__":
args = get_args()
sd = StableDiffusionMGX()
result = sd.run(args.prompt, args.negative_prompt, args.steps, args.seed,
args.scale)
print("Convert result to rgb image...")
image = StableDiffusionMGX.convert_to_rgb_image(result)
filename = args.output if args.output else f"output_s{args.seed}_t{args.steps}.png"
StableDiffusionMGX.save_image(image, args.output)
print(f"Image saved to {filename}")
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
#define MIGRAPHX_MIOPEN_ASSERT(x) (assert((x) == miopenStatusSuccess)) #define MIGRAPHX_MIOPEN_ASSERT(x) (assert((x) == miopenStatusSuccess))
#define MIGRAPHX_HIP_ASSERT(x) (assert((x) == hipSuccess)) #define MIGRAPHX_HIP_ASSERT(x) (assert((x) == hipSuccess))
inline miopenTensorDescriptor_t make_miopen_tensor(const migraphx::shape& s, bool pack = false) inline miopenTensorDescriptor_t make_miopen_tensor(const migraphx::shape& s)
{ {
miopenTensorDescriptor_t t; miopenTensorDescriptor_t t;
MIGRAPHX_MIOPEN_ASSERT(miopenCreateTensorDescriptor(&t)); MIGRAPHX_MIOPEN_ASSERT(miopenCreateTensorDescriptor(&t));
...@@ -49,23 +49,9 @@ inline miopenTensorDescriptor_t make_miopen_tensor(const migraphx::shape& s, boo ...@@ -49,23 +49,9 @@ inline miopenTensorDescriptor_t make_miopen_tensor(const migraphx::shape& s, boo
else if(s.type() == migraphx_shape_int32_type) else if(s.type() == migraphx_shape_int32_type)
d = miopenInt32; d = miopenInt32;
else if(s.type() == migraphx_shape_int8_type) else if(s.type() == migraphx_shape_int8_type)
{ d = miopenInt8;
if(pack)
{
// update the lens and corresponding strides
d = miopenInt8x4;
lens[1] = ((lens[1] + 3) / 4) * 4;
strides[0] = strides[1] * lens[1];
}
else
{
d = miopenInt8;
}
}
else else
{
throw("MAKE_TENSOR: unsupported type"); throw("MAKE_TENSOR: unsupported type");
}
miopenSetTensorDescriptor(t, d, s_lens.size(), lens.data(), strides.data()); miopenSetTensorDescriptor(t, d, s_lens.size(), lens.data(), strides.data());
return t; return t;
} }
......
...@@ -149,9 +149,6 @@ gpu::gelu ...@@ -149,9 +149,6 @@ gpu::gelu
gpu::gelu_new gpu::gelu_new
gpu::gemm gpu::gemm
gpu::greater gpu::greater
gpu::int8_conv_pack
gpu::int8_gemm_pack_a
gpu::int8_gemm_pack_b
gpu::layernorm gpu::layernorm
gpu::leaky_relu gpu::leaky_relu
gpu::less gpu::less
......
...@@ -21,12 +21,12 @@ ...@@ -21,12 +21,12 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
google/protobuf@v3.11.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off google/protobuf@v3.19.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off
nlohmann/json@v3.8.0 nlohmann/json@v3.8.0
live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f47538285178908ea5f2cab46bb6a8f55b1edb6e03224a82dbc1a3212 live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f47538285178908ea5f2cab46bb6a8f55b1edb6e03224a82dbc1a3212
ROCmSoftwarePlatform/half@rocm-5.6.0 ROCmSoftwarePlatform/half@rocm-5.6.0
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/rocMLIR@507bb94ce7873786486d296ec81d2eadaab49003 -DBUILD_FAT_LIBROCKCOMPILER=On ROCmSoftwarePlatform/rocMLIR@a6880f1e6daec99876cd6a4820fbc69c57216401 -DBUILD_FAT_LIBROCKCOMPILER=On
\ No newline at end of file
...@@ -28,9 +28,9 @@ include(ROCMInstallTargets) ...@@ -28,9 +28,9 @@ include(ROCMInstallTargets)
include(ROCMPackageConfigHelpers) include(ROCMPackageConfigHelpers)
include(RegisterOp) include(RegisterOp)
include(CheckCXXLinkerFlag) include(CheckCXXLinkerFlag)
include(CheckCXXSourceCompiles)
add_library(migraphx add_library(migraphx
adjust_allocation.cpp adjust_allocation.cpp
analyze_streams.cpp analyze_streams.cpp
apply_alpha_beta.cpp apply_alpha_beta.cpp
...@@ -104,6 +104,12 @@ add_library(migraphx ...@@ -104,6 +104,12 @@ add_library(migraphx
value.cpp value.cpp
verify_args.cpp verify_args.cpp
) )
if(WIN32)
# Due to compilation crashing, we need to use type-erased matchers on Windows.
target_compile_definitions(migraphx PUBLIC MIGRAPHX_USE_TYPE_ERASED_MATCHERS=1)
endif()
configure_file(version.h.in include/migraphx/version.h) configure_file(version.h.in include/migraphx/version.h)
rocm_set_soversion(migraphx ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx ${MIGRAPHX_SO_VERSION})
function(register_migraphx_ops) function(register_migraphx_ops)
...@@ -155,6 +161,7 @@ register_migraphx_ops( ...@@ -155,6 +161,7 @@ register_migraphx_ops(
identity identity
if_op if_op
im2col im2col
isinf
isnan isnan
layout layout
leaky_relu leaky_relu
...@@ -174,6 +181,7 @@ register_migraphx_ops( ...@@ -174,6 +181,7 @@ register_migraphx_ops(
mul mul
multibroadcast multibroadcast
multinomial multinomial
nearbyint
neg neg
nonmaxsuppression nonmaxsuppression
nonzero nonzero
...@@ -204,7 +212,6 @@ register_migraphx_ops( ...@@ -204,7 +212,6 @@ register_migraphx_ops(
rnn_last_hs_output rnn_last_hs_output
rnn_var_sl_last_output rnn_var_sl_last_output
roialign roialign
round
rsqrt rsqrt
run_on_target run_on_target
scalar scalar
...@@ -214,6 +221,8 @@ register_migraphx_ops( ...@@ -214,6 +221,8 @@ register_migraphx_ops(
scatternd_add scatternd_add
scatternd_mul scatternd_mul
scatternd_none scatternd_none
scatternd_max
scatternd_min
select_module select_module
sigmoid sigmoid
sign sign
...@@ -232,6 +241,7 @@ register_migraphx_ops( ...@@ -232,6 +241,7 @@ register_migraphx_ops(
transpose transpose
unary_not unary_not
undefined undefined
unique
unknown unknown
unsqueeze unsqueeze
where where
...@@ -246,24 +256,68 @@ rocm_install_targets( ...@@ -246,24 +256,68 @@ rocm_install_targets(
${CMAKE_CURRENT_BINARY_DIR}/include ${CMAKE_CURRENT_BINARY_DIR}/include
) )
if(NOT WIN32)
check_cxx_linker_flag(-lstdc++fs HAS_LIB_STD_FILESYSTEM) check_cxx_linker_flag(-lstdc++fs HAS_LIB_STD_FILESYSTEM)
if(HAS_LIB_STD_FILESYSTEM) if(HAS_LIB_STD_FILESYSTEM)
target_link_libraries(migraphx PRIVATE -lstdc++fs) target_link_libraries(migraphx PRIVATE -lstdc++fs)
endif()
target_link_libraries(migraphx PRIVATE -ldl)
endif() endif()
target_link_libraries(migraphx PRIVATE -ldl)
target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>) target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_link_libraries(migraphx PUBLIC Threads::Threads) target_link_libraries(migraphx PUBLIC Threads::Threads)
function(check_execution_par RESULT)
set(CMAKE_REQUIRED_LIBRARIES ${ARGN})
set(CMAKE_REQUIRED_FLAGS)
if(NOT MSVC)
set(CMAKE_REQUIRED_FLAGS "-std=c++17")
endif()
string(MD5 _flags_hash "${CMAKE_REQUIRED_FLAGS} ${CMAKE_REQUIRED_LIBRARIES}")
set(_source "
#include <execution>
int main() {
int* i = nullptr;
std::sort(std::execution::par, i, i);
}
")
check_cxx_source_compiles("${_source}" _has_execution_${_flags_hash})
set(${RESULT} ${_has_execution_${_flags_hash}} PARENT_SCOPE)
endfunction()
set(MIGRAPHX_HAS_EXECUTORS_DEFAULT Off)
find_package(TBB QUIET)
if(TBB_FOUND)
check_execution_par(TBB_HAS_EXECUTION_PAR TBB::tbb)
if(TBB_HAS_EXECUTION_PAR)
list(APPEND PACKAGE_DEPENDS PACKAGE TBB)
target_link_libraries(migraphx PUBLIC TBB::tbb)
set(MIGRAPHX_HAS_EXECUTORS_DEFAULT On)
message(STATUS "Using TBB for parallel execution")
endif()
else()
check_execution_par(HAS_EXECUTION_PAR)
if(HAS_EXECUTION_PAR)
set(MIGRAPHX_HAS_EXECUTORS_DEFAULT On)
endif()
endif()
option(MIGRAPHX_HAS_EXECUTORS "C++ supports parallel executors" ${MIGRAPHX_HAS_EXECUTORS_DEFAULT})
if(MIGRAPHX_HAS_EXECUTORS)
message("Parallel STL enabled")
target_compile_definitions(migraphx PUBLIC MIGRAPHX_HAS_EXECUTORS=1)
else()
message("Parallel STL disabled")
target_compile_definitions(migraphx PUBLIC MIGRAPHX_HAS_EXECUTORS=0)
endif()
find_package(nlohmann_json 3.8.0 REQUIRED) find_package(nlohmann_json 3.8.0 REQUIRED)
target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json) target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json)
migraphx_generate_export_header(migraphx) migraphx_generate_export_header(migraphx)
find_package(PkgConfig) find_package(SQLite3 REQUIRED)
pkg_check_modules(SQLITE3 REQUIRED IMPORTED_TARGET sqlite3) target_link_libraries(migraphx PRIVATE SQLite::SQLite3)
target_link_libraries(migraphx PRIVATE PkgConfig::SQLITE3)
find_package(msgpackc-cxx QUIET) find_package(msgpackc-cxx QUIET)
if(NOT msgpackc-cxx_FOUND) if(NOT msgpackc-cxx_FOUND)
...@@ -275,8 +329,6 @@ target_link_libraries(migraphx INTERFACE $<BUILD_INTERFACE:msgpackc-cxx>) ...@@ -275,8 +329,6 @@ target_link_libraries(migraphx INTERFACE $<BUILD_INTERFACE:msgpackc-cxx>)
add_library(migraphx_all_targets INTERFACE) add_library(migraphx_all_targets INTERFACE)
set(PACKAGE_DEPENDS)
add_subdirectory(api) add_subdirectory(api)
add_subdirectory(driver) add_subdirectory(driver)
add_subdirectory(onnx) add_subdirectory(onnx)
......
...@@ -164,6 +164,11 @@ void set_default_loop_iterations(onnx_options& options, int64_t value) ...@@ -164,6 +164,11 @@ void set_default_loop_iterations(onnx_options& options, int64_t value)
options.max_loop_iterations = value; options.max_loop_iterations = value;
} }
void set_limit_loop_iterations(onnx_options& options, int64_t value)
{
options.limit_max_iterations = value;
}
void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; } void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; }
void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; } void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; }
...@@ -1904,6 +1909,17 @@ migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_o ...@@ -1904,6 +1909,17 @@ migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_o
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_onnx_options_set_limit_loop_iterations(migraphx_onnx_options_t onnx_options, int64_t value)
{
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_limit_loop_iterations((onnx_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options) extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options)
{ {
auto api_error_result = migraphx::try_([&] { destroy((file_options)); }); auto api_error_result = migraphx::try_([&] { destroy((file_options)); });
......
...@@ -44,7 +44,8 @@ ...@@ -44,7 +44,8 @@
m(int32_type, int32_t) \ m(int32_type, int32_t) \
m(int64_type, int64_t) \ m(int64_type, int64_t) \
m(uint32_type, uint32_t) \ m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on // clang-format on
#ifdef __cplusplus #ifdef __cplusplus
...@@ -514,6 +515,9 @@ MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_default_dyn_dim_valu ...@@ -514,6 +515,9 @@ MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_default_dyn_dim_valu
MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_default_loop_iterations( MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_default_loop_iterations(
migraphx_onnx_options_t onnx_options, int64_t value); migraphx_onnx_options_t onnx_options, int64_t value);
MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_limit_loop_iterations(
migraphx_onnx_options_t onnx_options, int64_t value);
MIGRAPHX_C_EXPORT migraphx_status MIGRAPHX_C_EXPORT migraphx_status
migraphx_file_options_destroy(migraphx_file_options_t file_options); migraphx_file_options_destroy(migraphx_file_options_t file_options);
......
...@@ -1321,6 +1321,12 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) ...@@ -1321,6 +1321,12 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{ {
call(&migraphx_onnx_options_set_default_loop_iterations, this->get_handle_ptr(), value); call(&migraphx_onnx_options_set_default_loop_iterations, this->get_handle_ptr(), value);
} }
/// Set max iteration limit for the loop operator
void set_limit_loop_iterations(int64_t value)
{
call(&migraphx_onnx_options_set_limit_loop_iterations, this->get_handle_ptr(), value);
}
}; };
/// Parse an onnx file into a migraphx program /// Parse an onnx file into a migraphx program
......
...@@ -349,6 +349,11 @@ def onnx_options(h): ...@@ -349,6 +349,11 @@ def onnx_options(h):
api.params(value='int64_t'), api.params(value='int64_t'),
invoke='migraphx::set_default_loop_iterations($@)', invoke='migraphx::set_default_loop_iterations($@)',
) )
h.method(
'set_limit_loop_iterations',
api.params(value='int64_t'),
invoke='migraphx::set_limit_loop_iterations($@)',
)
@auto_handle() @auto_handle()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment