Commit 187888bb authored by charlie's avatar charlie
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into...

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into split_multiple_same_dyn_dim
parents bab722b3 57f734a5
......@@ -465,7 +465,7 @@ jobs:
- name: Upload code coverage
if: "matrix.configuration == 'codecov'"
env:
CODECOV_TOKEN: "8545af1c-f90b-4345-92a5-0d075503ca56"
CODECOV_TOKEN: "f5d5a10b-3177-4c76-b25f-9b1c2f165e8b"
run: |
sudo apt-get install -y lcov
cd build
......
......@@ -22,6 +22,8 @@ def rocmtestnode(Map conf) {
def cmd = """
ulimit -c unlimited
echo "leak:dnnl::impl::malloc" > suppressions.txt
echo "leak:libtbb.so" >> suppressions.txt
cat suppressions.txt
export LSAN_OPTIONS="suppressions=\$(pwd)/suppressions.txt"
export MIGRAPHX_GPU_DEBUG=${gpu_debug}
export CXX=${compiler}
......
......@@ -6,4 +6,5 @@ This directory contains examples of common use cases for MIGraphX.
## Examples:
- [MIGraphX usage and utilities](./migraphx)
- [Vision inference examples](./vision)
- [Natural language inference examples](./nlp)
\ No newline at end of file
- [Natural language inference examples](./nlp)
- [Diffusion inference examples](./diffusion)
# Diffusion Inference Examples
- [Python Stable Diffusion 2.1](./python_stable_diffusion_21)
# Stable Diffusion 2.1
This version was tested with [rocm 5.7](https://github.com/ROCmSoftwarePlatform/AMDMIGraphX/tree/rocm-5.7.0) revision.
## Jupyter notebook
There is a dedicated step-by-step notebook. See [sd21.ipynb](./sd21.ipynb)
## Console application
To run the console application, follow these steps below.
Setup python environment
```bash
# this will require the python venv to installed (e.g. apt install python3.8-venv)
python3 -m venv sd_venv
. sd_venv/bin/activate
```
Install dependencies
```bash
pip install -r requirements.txt
```
Use MIGraphX Python Module
```bash
export PYTHONPATH=/opt/rocm/lib:$PYTHONPATH
```
Get models with optimum
```bash
optimum-cli export onnx --model stabilityai/stable-diffusion-2-1 models/sd21-onnx
```
*Note: `models/sd21-onnx` will be used in the scripts.*
Run the text-to-image script with the following example prompt and seed:
```bash
python txt2img.py --prompt "a photograph of an astronaut riding a horse" --seed 13 --output astro_horse.jpg
```
*Note: The first run will compile the models and cache them to make subsequent runs faster.*
The result should look like this:
![example_output.jpg](./example_output.jpg)
## Gradio application
Note: requires `Console application` to work
Install gradio dependencies
```bash
pip install -r gradio_requirements.txt
```
Usage
```bash
python gradio_app.py
```
This will load the models (which can take several minutes), and when the setup is ready, starts a server on `http://127.0.0.1:7860`.
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
from txt2img import StableDiffusionMGX
import gradio as gr
def main():
# Note: This will load the models, which can take several minutes
sd = StableDiffusionMGX()
def gr_wrapper(prompt, negative_prompt, steps, seed, scale):
result = sd.run(str(prompt), str(negative_prompt), int(steps),
int(seed), float(scale))
return StableDiffusionMGX.convert_to_rgb_image(result)
demo = gr.Interface(
gr_wrapper,
[
gr.Textbox(value="a photograph of an astronaut riding a horse",
label="Prompt"),
gr.Textbox(value="", label="Negative prompt (Optional)"),
gr.Slider(1, 100, step=1, value=20, label="Number of steps"),
gr.Textbox(value=13, label="Random seed"),
gr.Slider(1, 20, step=0.1, value=7.0, label="Guidance scale"),
],
"image",
)
demo.launch()
if __name__ == "__main__":
main()
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
-f requirements.txt
gradio
\ No newline at end of file
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
accelerate
diffusers
optimum[onnxruntime]
transformers
\ No newline at end of file
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# The MIT License (MIT)\n",
"#\n",
"# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.\n",
"#\n",
"# Permission is hereby granted, free of charge, to any person obtaining a copy\n",
"# of this software and associated documentation files (the 'Software'), to deal\n",
"# in the Software without restriction, including without limitation the rights\n",
"# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
"# copies of the Software, and to permit persons to whom the Software is\n",
"# furnished to do so, subject to the following conditions:\n",
"#\n",
"# The above copyright notice and this permission notice shall be included in\n",
"# all copies or substantial portions of the Software.\n",
"#\n",
"# THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
"# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
"# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
"# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
"# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
"# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
"# THE SOFTWARE."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Stable Diffusion 2.1\n",
"\n",
"The following example will show how to run `Stable Diffusion 2.1` with `MIGraphX`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install the required dependencies."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install dependencies\n",
"!pip install optimum[onnxruntime] transformers diffusers accelerate"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use optimum to generate the onnx files."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export models\n",
"!optimum-cli export onnx --model stabilityai/stable-diffusion-2-1 models/sd21-onnx"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now it is time to load these models with python.\n",
"\n",
"First, we make sure that MIGraphX module is found in the python path."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"mgx_lib_path = \"/opt/rocm/lib/\" # or \"/code/AMDMIGraphX/build/lib/\"\n",
"if mgx_lib_path not in sys.path:\n",
" sys.path.append(mgx_lib_path)\n",
"import migraphx as mgx"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, a helper method to load and cache the models.\n",
"\n",
"This will use the `models/sd21-onnx` path. If you changed it, make sure to update here as well."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"# helper for model loading\n",
"def load_mgx_model(name, shapes):\n",
" file = f\"models/sd21-onnx/{name}/model\"\n",
" print(f\"Loading {name} model from {file}\")\n",
" if os.path.isfile(f\"{file}.mxr\"):\n",
" print(f\"Found mxr, loading it...\")\n",
" model = mgx.load(f\"{file}.mxr\", format=\"msgpack\")\n",
" elif os.path.isfile(f\"{file}.onnx\"):\n",
" print(f\"Parsing from onnx file...\")\n",
" model = mgx.parse_onnx(f\"{file}.onnx\", map_input_dims=shapes)\n",
" model.compile(mgx.get_target(\"gpu\"))\n",
" print(f\"Saving {name} model to mxr file...\")\n",
" mgx.save(model, f\"{file}.mxr\", format=\"msgpack\")\n",
" else:\n",
" print(f\"No {name} model found. Please verify the path is correct and re-try, or re-download model.\")\n",
" os.exit(1)\n",
" return model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With that, we can load the models. This could take several minutes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text_encoder = load_mgx_model(\"text_encoder\", {\"input_ids\": [1, 77]})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"unet = load_mgx_model(\n",
" \"unet\", {\n",
" \"sample\": [1, 4, 64, 64],\n",
" \"encoder_hidden_states\": [1, 77, 1024],\n",
" \"timestep\": [1],\n",
" })"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"vae = load_mgx_model(\"vae_decoder\", {\"latent_sample\": [1, 4, 64, 64]})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import the remaining packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from diffusers import EulerDiscreteScheduler\n",
"from transformers import CLIPTokenizer\n",
"import torch\n",
"import numpy as np\n",
"from tqdm.auto import tqdm\n",
"from PIL import Image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Time to load the scheduler and tokenizer from the original source."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_id = \"stabilityai/stable-diffusion-2-1\"\n",
"scheduler = EulerDiscreteScheduler.from_pretrained(model_id,\n",
" subfolder=\"scheduler\")\n",
"tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder=\"tokenizer\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we will define all the steps one by one, to make the last step short and simple.\n",
"\n",
"The first step will be to tokenize the user prompt. It will make a `(1, 77)` shaped `input_ids`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tokenize(input):\n",
" return tokenizer([input],\n",
" padding=\"max_length\",\n",
" max_length=tokenizer.model_max_length,\n",
" truncation=True,\n",
" return_tensors=\"np\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional\n",
"test_tk = tokenize(\"test tokenizer to see the tokens\")\n",
"test_tk.input_ids.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We run the tokenized prompt through the `Text Encoder` model. It expects the `(1, 77)` data as `int32`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional\n",
"text_encoder.get_parameter_shapes()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_embeddings(input):\n",
" return np.array(\n",
" text_encoder.run({\"input_ids\": input.input_ids.astype(np.int32)\n",
" })[0]).astype(np.float32)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional\n",
"test_emb = get_embeddings(tokenize(\"test tokenizer to see the tokens\"))\n",
"test_emb.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The other input of the model is latent representation (pure noise). It will be transformed into a 512x512 image later.\n",
"The last input will be the timestep."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def generate_latents(seed):\n",
" return torch.randn(\n",
" (1, 4, 64, 64),\n",
" generator=torch.manual_seed(seed),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional\n",
"test_latents = generate_latents(42)\n",
"latents.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we add two helpers to access and convert from torch to numpy with the proper datatype."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_scaled_sample(latents, t):\n",
" return scheduler.scale_model_input(latents, t).numpy().astype(np.float32)\n",
"\n",
"\n",
"def get_timestep(t):\n",
" return np.atleast_1d(t.numpy().astype(np.int64)) # convert 0D -> 1D"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The UNet model will be run in a loop. It will predict the noise residual."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional\n",
"unet.get_parameter_shapes()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def denoise(sample, embeddings, timestep):\n",
" return np.array(\n",
" unet.run({\n",
" \"sample\": sample,\n",
" \"encoder_hidden_states\": embeddings,\n",
" \"timestep\": timestep\n",
" })[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Helpers to do the classifier-free guidance and computing the previous noisy sample."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def perform_guidance(noise_pred_uncond, noise_pred_text, scale):\n",
" return noise_pred_uncond + scale * (noise_pred_text - noise_pred_uncond)\n",
"\n",
"def compute_previous(noise_pred, t, latents):\n",
" # compute the previous noisy sample x_t -> x_t-1\n",
" return scheduler.step(noise_pred, t, latents).prev_sample\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Scale and decode the image latents with VAE."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def scale_denoised(latents):\n",
" return 1 / 0.18215 * latents\n",
"\n",
"\n",
"def decode(latents):\n",
" return np.array(\n",
" vae.run({\"latent_sample\": latents.numpy().astype(np.float32)})[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And lastly, we need to convert it to an image to display or save."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def convert_to_rgb_image(image):\n",
" image = np.clip(image / 2 + 0.5, 0, 1)\n",
" image = np.transpose(image, (0, 2, 3, 1))\n",
" images = (image * 255).round().astype(\"uint8\")\n",
" return Image.fromarray(images[0])\n",
"\n",
"def save_image(pil_image, filename=\"output.png\"):\n",
" pil_image.save(filename, format=\"png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Feel free to play around with these params."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"a photograph of an astronaut riding a horse\"\n",
"negative_prompt = \"\"\n",
"steps = 20\n",
"seed = 13\n",
"scale = 7.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now, to put everything together and run the whole pipeline:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"scheduler.set_timesteps(steps)\n",
"\n",
"text_input, uncond_input = tokenize(prompt), tokenize(negative_prompt)\n",
"text_embeddings, uncond_embeddings = get_embeddings(\n",
" text_input), get_embeddings(uncond_input)\n",
"latents = generate_latents(seed) * scheduler.init_noise_sigma\n",
"\n",
"for t in tqdm(scheduler.timesteps):\n",
" sample = get_scaled_sample(latents, t)\n",
" timestep = get_timestep(t)\n",
"\n",
" noise_pred_uncond = denoise(sample, uncond_embeddings, timestep)\n",
" noise_pred_text = denoise(sample, text_embeddings, timestep)\n",
"\n",
" noise_pred = perform_guidance(noise_pred_uncond, noise_pred_text, scale)\n",
" latents = compute_previous(torch.from_numpy(noise_pred), t, latents)\n",
"\n",
"latents = scale_denoised(latents)\n",
"result = decode(latents)\n",
"image = convert_to_rgb_image(result)\n",
"\n",
"# show the image\n",
"image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you like the generated image, save it with the following:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"save_image(image, \"output.png\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "sd_venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the 'Software'), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from argparse import ArgumentParser
from diffusers import EulerDiscreteScheduler
from transformers import CLIPTokenizer
from PIL import Image
import migraphx as mgx
import numpy as np
import os
import torch
import time
from functools import wraps
# measurement helper
def measure(fn):
@wraps(fn)
def measure_ms(*args, **kwargs):
start_time = time.perf_counter_ns()
result = fn(*args, **kwargs)
end_time = time.perf_counter_ns()
print(f"Elapsed time: {(end_time - start_time) * 1e-6:.4f} ms\n")
return result
return measure_ms
def get_args():
parser = ArgumentParser()
parser.add_argument(
"-s",
"--seed",
type=int,
default=42,
help="Random seed",
)
parser.add_argument(
"-t",
"--steps",
type=int,
default=20,
help="Number of steps",
)
parser.add_argument(
"-p",
"--prompt",
type=str,
required=True,
help="Prompt",
)
parser.add_argument(
"-n",
"--negative-prompt",
type=str,
default="",
help="Negative prompt",
)
parser.add_argument(
"--scale",
type=float,
default=7.0,
help="Guidance scale",
)
parser.add_argument(
"-o",
"--output",
type=str,
default=None,
help="Output name",
)
return parser.parse_args()
class StableDiffusionMGX():
def __init__(self):
model_id = "stabilityai/stable-diffusion-2-1"
print(f"Using {model_id}")
print("Creating EulerDiscreteScheduler scheduler")
self.scheduler = EulerDiscreteScheduler.from_pretrained(
model_id, subfolder="scheduler")
print("Creating CLIPTokenizer tokenizer...")
self.tokenizer = CLIPTokenizer.from_pretrained(model_id,
subfolder="tokenizer")
print("Load models...")
self.vae = StableDiffusionMGX.load_mgx_model(
"vae_decoder", {"latent_sample": [1, 4, 64, 64]})
self.text_encoder = StableDiffusionMGX.load_mgx_model(
"text_encoder", {"input_ids": [1, 77]})
self.unet = StableDiffusionMGX.load_mgx_model(
"unet", {
"sample": [1, 4, 64, 64],
"encoder_hidden_states": [1, 77, 1024],
"timestep": [1],
})
def run(self, prompt, negative_prompt, steps, seed, scale):
# need to set this for each run
self.scheduler.set_timesteps(steps)
print("Tokenizing prompt...")
text_input = self.tokenize(prompt)
print("Creating text embeddings for prompt...")
text_embeddings = self.get_embeddings(text_input)
print("Tokenizing negative prompt...")
uncond_input = self.tokenize(negative_prompt)
print("Creating text embeddings for negative prompt...")
uncond_embeddings = self.get_embeddings(uncond_input)
print(
f"Creating random input data ({1}x{4}x{64}x{64}) (latents) with seed={seed}..."
)
latents = torch.randn((1, 4, 64, 64),
generator=torch.manual_seed(seed))
print("Apply initial noise sigma\n")
latents = latents * self.scheduler.init_noise_sigma
print("Running denoising loop...")
for step, t in enumerate(self.scheduler.timesteps):
print(f"#{step}/{len(self.scheduler.timesteps)} step")
latents = self.denoise_step(text_embeddings, uncond_embeddings,
latents, t, scale)
print("Scale denoised result...")
latents = 1 / 0.18215 * latents
print("Decode denoised result...")
image = self.decode(latents)
return image
@staticmethod
@measure
def load_mgx_model(name, shapes):
file = f"models/sd21-onnx/{name}/model"
print(f"Loading {name} model from {file}")
if os.path.isfile(f"{file}.mxr"):
print("Found mxr, loading it...")
model = mgx.load(f"{file}.mxr", format="msgpack")
elif os.path.isfile(f"{file}.onnx"):
print("Parsing from onnx file...")
model = mgx.parse_onnx(f"{file}.onnx", map_input_dims=shapes)
model.compile(mgx.get_target("gpu"))
print(f"Saving {name} model to mxr file...")
mgx.save(model, f"{file}.mxr", format="msgpack")
else:
print(f"No {name} model found. Please download it and re-try.")
os.exit(1)
return model
@measure
def tokenize(self, input):
return self.tokenizer([input],
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np")
@measure
def get_embeddings(self, input):
return np.array(
self.text_encoder.run(
{"input_ids":
input.input_ids.astype(np.int32)})[0]).astype(np.float32)
@staticmethod
def convert_to_rgb_image(image):
image = np.clip(image / 2 + 0.5, 0, 1)
image = np.transpose(image, (0, 2, 3, 1))
images = (image * 255).round().astype("uint8")
return Image.fromarray(images[0])
@staticmethod
def save_image(pil_image, filename="output.png"):
pil_image.save(filename)
@measure
def denoise_step(self, text_embeddings, uncond_embeddings, latents, t,
scale):
sample = self.scheduler.scale_model_input(latents,
t).numpy().astype(np.float32)
timestep = np.atleast_1d(t.numpy().astype(
np.int64)) # convert 0D -> 1D
noise_pred_uncond = np.array(
self.unet.run({
"sample": sample,
"encoder_hidden_states": uncond_embeddings,
"timestep": timestep
})[0])
noise_pred_text = np.array(
self.unet.run({
"sample": sample,
"encoder_hidden_states": text_embeddings,
"timestep": timestep
})[0])
# perform guidance
noise_pred = noise_pred_uncond + scale * (noise_pred_text -
noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
return self.scheduler.step(torch.from_numpy(noise_pred), t,
latents).prev_sample
@measure
def decode(self, latents):
return np.array(
self.vae.run({"latent_sample":
latents.numpy().astype(np.float32)})[0])
if __name__ == "__main__":
args = get_args()
sd = StableDiffusionMGX()
result = sd.run(args.prompt, args.negative_prompt, args.steps, args.seed,
args.scale)
print("Convert result to rgb image...")
image = StableDiffusionMGX.convert_to_rgb_image(result)
filename = args.output if args.output else f"output_s{args.seed}_t{args.steps}.png"
StableDiffusionMGX.save_image(image, args.output)
print(f"Image saved to {filename}")
......@@ -28,6 +28,7 @@ include(ROCMInstallTargets)
include(ROCMPackageConfigHelpers)
include(RegisterOp)
include(CheckCXXLinkerFlag)
include(CheckCXXSourceCompiles)
add_library(migraphx
adjust_allocation.cpp
......@@ -263,6 +264,50 @@ endif()
target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_link_libraries(migraphx PUBLIC Threads::Threads)
function(check_execution_par RESULT)
set(CMAKE_REQUIRED_LIBRARIES ${ARGN})
set(CMAKE_REQUIRED_FLAGS)
if(NOT MSVC)
set(CMAKE_REQUIRED_FLAGS "-std=c++17")
endif()
string(MD5 _flags_hash "${CMAKE_REQUIRED_FLAGS} ${CMAKE_REQUIRED_LIBRARIES}")
set(_source "
#include <execution>
int main() {
int* i = nullptr;
std::sort(std::execution::par, i, i);
}
")
check_cxx_source_compiles("${_source}" _has_execution_${_flags_hash})
set(${RESULT} ${_has_execution_${_flags_hash}} PARENT_SCOPE)
endfunction()
set(MIGRAPHX_HAS_EXECUTORS_DEFAULT Off)
find_package(TBB QUIET)
if(TBB_FOUND)
check_execution_par(TBB_HAS_EXECUTION_PAR TBB::tbb)
if(TBB_HAS_EXECUTION_PAR)
target_link_libraries(migraphx PUBLIC TBB::tbb)
set(MIGRAPHX_HAS_EXECUTORS_DEFAULT On)
message(STATUS "Using TBB for parallel execution")
endif()
else()
check_execution_par(HAS_EXECUTION_PAR)
if(HAS_EXECUTION_PAR)
set(MIGRAPHX_HAS_EXECUTORS_DEFAULT On)
endif()
endif()
option(MIGRAPHX_HAS_EXECUTORS "C++ supports parallel executors" ${MIGRAPHX_HAS_EXECUTORS_DEFAULT})
if(MIGRAPHX_HAS_EXECUTORS)
message("Parallel STL enabled")
target_compile_definitions(migraphx PUBLIC MIGRAPHX_HAS_EXECUTORS=1)
else()
message("Parallel STL disabled")
target_compile_definitions(migraphx PUBLIC MIGRAPHX_HAS_EXECUTORS=0)
endif()
find_package(nlohmann_json 3.8.0 REQUIRED)
target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json)
migraphx_generate_export_header(migraphx)
......
......@@ -29,6 +29,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/par.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -95,11 +96,11 @@ struct binary : op_name<Derived>
{
argument result{dyn_out.computed_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
std::transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
par_transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
});
return result;
}
......
......@@ -70,7 +70,8 @@ struct pooling
// 2 smaller than the input tensor rank (NCHW layout)
std::vector<std::size_t> lengths = {1, 1};
// Dilations are not supported at this time.
// Spacing between the elements of the pooling kernel. Must be the same ndim as lengths.
std::vector<std::size_t> dilations = {1, 1};
// ceiling mode is a flag affecting output size
// or equivalently, placements of the pooling kernel.
......@@ -99,6 +100,7 @@ struct pooling
f(self.padding_mode, "padding_mode"),
f(self.stride, "stride"),
f(self.lengths, "lengths"),
f(self.dilations, "dilations"),
f(self.ceil_mode, "ceil_mode"),
f(self.lp_order, "lp_order"),
f(self.dyn_global, "dyn_global"));
......@@ -112,14 +114,17 @@ struct pooling
return;
if((padding_mode != default_ and padding.size() != stride.size() and
(padding.size()) != stride.size() * 2) or
stride.size() != lengths.size())
stride.size() != lengths.size() or dilations.size() != lengths.size())
{
MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
}
if(std::any_of(lengths.begin(), lengths.end(), [&](auto i) { return (i == 0); }) or
std::any_of(stride.begin(), stride.end(), [&](auto i) { return (i == 0); }))
const auto is_zero = [](auto el) { return el == 0; };
if(std::any_of(lengths.begin(), lengths.end(), is_zero) or
std::any_of(stride.begin(), stride.end(), is_zero) or
std::any_of(dilations.begin(), dilations.end(), is_zero))
{
MIGRAPHX_THROW("POOLING: size 0 pooling kernel or stride");
MIGRAPHX_THROW("POOLING: size 0 pooling kernel or stride or dilations");
}
// TODO: update lowering to run the reference
......@@ -142,6 +147,11 @@ struct pooling
value attributes() const { return {{"normalize_padding", "padding"}}; }
inline std::size_t dilate_dim(std::size_t dim, std::size_t dilation) const
{
return 1 + dilation * (dim - 1);
}
std::vector<std::size_t> calc_spatial_dim_out(const std::vector<std::size_t>& input_lens,
std::size_t kdims) const
{
......@@ -151,8 +161,9 @@ struct pooling
std::size_t padding_factor = 2 * padding[i];
if(padding.size() == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
std::size_t dilated_length = dilate_dim(lengths[i], dilations[i]);
std::size_t dim_size;
if(input_lens[i + 2] + padding_factor < lengths[i])
if(input_lens[i + 2] + padding_factor < dilated_length)
{
if(padding_mode == default_)
MIGRAPHX_THROW("POOLING: not enough padding for the given kernel size");
......@@ -162,7 +173,7 @@ struct pooling
}
else
{
dim_size = input_lens[i + 2] + padding_factor - lengths[i];
dim_size = input_lens[i + 2] + padding_factor - dilated_length;
}
std::size_t len =
(ceil_mode)
......@@ -331,6 +342,7 @@ struct pooling
int start = static_cast<int>(idx_o[dim] * stride[d_2]) -
static_cast<int>(padding_vals[d_2]);
int end;
std::size_t dilated_kernel_dim = dilate_dim(kernel_dims[d_2], dilations[d_2]);
// NOLINT
if(count_include_pad and ceil_mode and (mode != pooling_mode::max))
{
......@@ -340,15 +352,14 @@ struct pooling
// padding. Clip out-of-bounds indexes but not padding.
// Check if this kernel extends beyond the padding at end of dimension
end = std::min(start + kernel_dims[d_2],
end = std::min(start + dilated_kernel_dim,
in_lens[dim] + static_cast<int>(padding_vals[d_2]));
}
else
{
// In non-ceiling mode, when
// count_include_pad is false, or for max pooling, clip off padding.
end = std::min(start + kernel_dims[d_2], in_lens[dim]);
start = std::max(start, 0);
end = std::min(start + dilated_kernel_dim, in_lens[dim]);
}
win_start.push_back(start);
if(end < start)
......@@ -366,6 +377,16 @@ struct pooling
// for each element in the window...
shape_for_each(win_shape, [&](const auto& idx_w) {
// Skip elements that belong to the dilated area
for(size_t axis = 0; axis < idx_w.size(); ++axis)
{
if(idx_w[axis] % dilations[axis])
{
pool_size -= 1;
return;
}
}
// the coordinates of this element
auto idx = idx_o;
......@@ -390,7 +411,15 @@ struct pooling
// this is a padding element. Padding locations
// don't contribute to average or max pooling total but can play in
// lpnorm pooling.
output_val = op(output_val, 0);
if(mode == pooling_mode::lpnorm)
{
output_val = op(output_val, op.template init<Type>());
}
if(mode == pooling_mode::average)
{
// Ignore padding
pool_size -= 1;
}
}
});
output[i] = Type(op.final(output_val, pool_size));
......
......@@ -31,6 +31,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/par.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -84,10 +85,10 @@ struct unary : op_name<Derived>
argument result{dyn_out.computed_shape};
result.visit([&](auto output) {
args[0].visit([&](auto input) {
std::transform(input.begin(),
input.end(),
output.begin(),
static_cast<const Derived&>(*this).apply());
par_transform(input.begin(),
input.end(),
output.begin(),
static_cast<const Derived&>(*this).apply());
});
});
return result;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
#include <migraphx/config.hpp>
#if MIGRAPHX_HAS_EXECUTORS
#include <execution>
#else
#include <migraphx/simple_par_for.hpp>
#endif
#include <algorithm>
#include <mutex>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace detail {
struct exception_list
{
std::vector<std::exception_ptr> exceptions;
std::mutex m;
void add_exception()
{
std::lock_guard<std::mutex> guard(m);
exceptions.push_back(std::current_exception());
}
template <class F>
auto collect(F f)
{
return [f, this](auto&&... xs) {
try
{
f(std::forward<decltype(xs)>(xs)...);
}
catch(...)
{
this->add_exception();
}
};
}
void throw_if_exception() const
{
if(not exceptions.empty())
std::rethrow_exception(exceptions.front());
}
};
} // namespace detail
template <class InputIt, class OutputIt, class UnaryOperation>
OutputIt par_transform(InputIt first1, InputIt last1, OutputIt d_first, UnaryOperation unary_op)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::transform(std::execution::par, first1, last1, d_first, std::move(unary_op));
#else
simple_par_for(last1 - first1, [&](auto i) { d_first[i] = unary_op(first1[i]); });
return d_first + (last1 - first1);
#endif
}
template <class InputIt1, class InputIt2, class OutputIt, class BinaryOperation>
OutputIt par_transform(
InputIt1 first1, InputIt1 last1, InputIt2 first2, OutputIt d_first, BinaryOperation binary_op)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::transform(
std::execution::par, first1, last1, first2, d_first, std::move(binary_op));
#else
simple_par_for(last1 - first1, [&](auto i) { d_first[i] = binary_op(first1[i], first2[i]); });
return d_first + (last1 - first1);
#endif
}
template <class InputIt, class UnaryFunction>
void par_for_each(InputIt first, InputIt last, UnaryFunction f)
{
#if MIGRAPHX_HAS_EXECUTORS
// Propagate the exception
detail::exception_list ex;
std::for_each(std::execution::par, first, last, ex.collect(std::move(f)));
ex.throw_if_exception();
#else
simple_par_for(last - first, [&](auto i) { f(first[i]); });
#endif
}
template <class... Ts>
auto par_copy_if(Ts&&... xs)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::copy_if(std::execution::par, std::forward<Ts>(xs)...);
#else
return std::copy_if(std::forward<Ts>(xs)...);
#endif
}
template <class... Ts>
auto par_sort(Ts&&... xs)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::sort(std::execution::par, std::forward<Ts>(xs)...);
#else
return std::sort(std::forward<Ts>(xs)...);
#endif
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
......@@ -24,93 +24,23 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#include <thread>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cassert>
#include <migraphx/par.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct joinable_thread : std::thread
{
template <class... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...) // NOLINT
{
}
joinable_thread& operator=(joinable_thread&& other) = default;
joinable_thread(joinable_thread&& other) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
template <class F>
auto thread_invoke(std::size_t i, std::size_t tid, F f) -> decltype(f(i, tid))
{
f(i, tid);
}
template <class F>
auto thread_invoke(std::size_t i, std::size_t, F f) -> decltype(f(i))
{
f(i);
}
template <class F>
void par_for_impl(std::size_t n, std::size_t threadsize, F f)
{
if(threadsize <= 1)
{
for(std::size_t i = 0; i < n; i++)
thread_invoke(i, 0, f);
}
else
{
std::vector<joinable_thread> threads(threadsize);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0;
std::size_t tid = 0;
std::generate(threads.begin(), threads.end(), [=, &work, &tid] {
auto result = joinable_thread([=] {
std::size_t start = work;
std::size_t last = std::min(n, work + grainsize);
for(std::size_t i = start; i < last; i++)
{
thread_invoke(i, tid, f);
}
});
work += grainsize;
++tid;
return result;
});
assert(work >= n);
}
}
template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f)
void par_for(std::size_t n, F f)
{
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
n / std::max<std::size_t>(1, min_grain));
par_for_impl(n, threadsize, f);
using iterator = basic_iota_iterator<id, std::size_t>;
par_for_each(iterator{0, {}}, iterator{n, {}}, f);
}
template <class F>
void par_for(std::size_t n, F f)
void par_for(std::size_t n, std::size_t, F f)
{
const int min_grain = 8;
par_for(n, min_grain, f);
par_for(n, f);
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -26,6 +26,7 @@
#include <string>
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLE_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLE_PAR_FOR_HPP
#include <thread>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cassert>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct joinable_thread : std::thread
{
template <class... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...) // NOLINT
{
}
joinable_thread& operator=(joinable_thread&& other) = default;
joinable_thread(joinable_thread&& other) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
template <class F>
auto thread_invoke(std::size_t i, std::size_t tid, F f) -> decltype(f(i, tid))
{
f(i, tid);
}
template <class F>
auto thread_invoke(std::size_t i, std::size_t, F f) -> decltype(f(i))
{
f(i);
}
template <class F>
void simple_par_for_impl(std::size_t n, std::size_t threadsize, F f)
{
if(threadsize <= 1)
{
for(std::size_t i = 0; i < n; i++)
thread_invoke(i, 0, f);
}
else
{
std::vector<joinable_thread> threads(threadsize);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0;
std::size_t tid = 0;
std::generate(threads.begin(), threads.end(), [=, &work, &tid] {
auto result = joinable_thread([=] {
std::size_t start = work;
std::size_t last = std::min(n, work + grainsize);
for(std::size_t i = start; i < last; i++)
{
thread_invoke(i, tid, f);
}
});
work += grainsize;
++tid;
return result;
});
assert(work >= n);
}
}
template <class F>
void simple_par_for(std::size_t n, std::size_t min_grain, F f)
{
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
n / std::max<std::size_t>(1, min_grain));
simple_par_for_impl(n, threadsize, f);
}
template <class F>
void simple_par_for(std::size_t n, F f)
{
const int min_grain = 8;
simple_par_for(n, min_grain, f);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -91,6 +91,14 @@ struct parse_pooling : op_parser<parse_pooling>
kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
}
if(contains(info.attributes, "dilations"))
{
values["dilations"].clear();
copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilations"]));
check_attr_sizes(
kdims, values["dilations"].size(), "PARSE_POOLING: inconsistent dilations");
}
// lp_order attribute
if(contains(info.attributes, "p"))
{
......@@ -169,10 +177,15 @@ struct parse_pooling : op_parser<parse_pooling>
std::fill_n(values["stride"].begin(), kdims, 1);
}
if(values["dilations"].size() != kdims)
{
values["dilations"].resize(kdims);
std::fill_n(values["dilations"].begin(), kdims, 1);
}
// used to calculate the supposed output shape
std::vector<int64_t> orig_padding = paddings;
// TODO: add parsing for dilations
if(contains(info.attributes, "auto_pad") and
to_upper(info.attributes["auto_pad"].s()) != "NOTSET")
{
......@@ -189,11 +202,10 @@ struct parse_pooling : op_parser<parse_pooling>
else
{
// Calculate auto padding
// dilations (argument 4) not supported; default to all 1's
cal_auto_padding_size(info,
values,
values["lengths"].to_vector<std::size_t>(),
std::vector<size_t>(in_shape.ndim() - 2, 1),
values["dilations"].to_vector<std::size_t>(),
in_shape.lens(),
paddings);
values["padding"] = paddings;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment