Commit 61edd67d authored by Sam Wu's avatar Sam Wu
Browse files

Merge branch 'develop' into doc-standard

parents a72c9e83 eafd55de
...@@ -136,12 +136,14 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build -> ...@@ -136,12 +136,14 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build ->
} }
}, mlir_debug: rocmnode('mi100+') { cmake_build -> }, mlir_debug: rocmnode('mi100+') { cmake_build ->
stage('MLIR Debug') { stage('MLIR Debug') {
withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1']) { withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot']) {
def sanitizers = "undefined" def sanitizers = "undefined"
// Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS. // Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS.
def debug_flags_cxx = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}" def debug_flags_cxx = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}"
def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr -fno-sanitize-recover=${sanitizers}" def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr -fno-sanitize-recover=${sanitizers}"
def gpu_targets = getgputargets() def gpu_targets = getgputargets()
// Since the purpose of this run is to verify everything MLIR supports,
// enable all possible types of offloads.
cmake_build(flags: "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags_cxx}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}' -DGPU_TARGETS='${gpu_targets}'") cmake_build(flags: "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags_cxx}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}' -DGPU_TARGETS='${gpu_targets}'")
} }
} }
......
...@@ -4,13 +4,13 @@ Environment Variables ...@@ -4,13 +4,13 @@ Environment Variables
For parsing For parsing
--------------- ---------------
**MIGRAPHX_TRACE_ONNX_PARSER** .. envvar:: MIGRAPHX_TRACE_ONNX_PARSER
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Print debugging traces for the onnx parser. Print debugging traces for the onnx parser.
Prints: initializers (if used), ONNX node operators, added MIGraphX instructions Prints: initializers (if used), ONNX node operators, added MIGraphX instructions
**MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT** .. envvar:: MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Disables the conversion from fp16 to fp32 for the InstanceNormalization ONNX operator that MIGX does as a workaround for accuracy issues with reduce_mean/variance. Disables the conversion from fp16 to fp32 for the InstanceNormalization ONNX operator that MIGX does as a workaround for accuracy issues with reduce_mean/variance.
...@@ -20,16 +20,16 @@ See ``parse_instancenorm.cpp`` for more details. ...@@ -20,16 +20,16 @@ See ``parse_instancenorm.cpp`` for more details.
Matchers Matchers
------------ ------------
**MIGRAPHX_TRACE_MATCHES** .. envvar:: MIGRAPHX_TRACE_MATCHES
Set to "1" to print the matcher that matches an instruction and the matched instruction. Set to "1" to print the matcher that matches an instruction and the matched instruction.
Set to "2" and use the ``MIGRAPHX_TRACE_MATCHES_FOR`` flag to filter out results. Set to "2" and use the ``MIGRAPHX_TRACE_MATCHES_FOR`` flag to filter out results.
**MIGRAPHX_TRACE_MATCHES_FOR** .. envvar:: MIGRAPHX_TRACE_MATCHES_FOR
Set to the name of any matcher and only traces for that matcher will be printed out. Set to the name of any matcher and only traces for that matcher will be printed out.
**MIGRAPHX_VALIDATE_MATCHES** .. envvar:: MIGRAPHX_VALIDATE_MATCHES
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Validate the module after finding the matches (runs ``module.validate()``). Validate the module after finding the matches (runs ``module.validate()``).
...@@ -37,7 +37,7 @@ Validate the module after finding the matches (runs ``module.validate()``). ...@@ -37,7 +37,7 @@ Validate the module after finding the matches (runs ``module.validate()``).
Program Execution Program Execution
--------------------- ---------------------
**MIGRAPHX_TRACE_EVAL** .. envvar:: MIGRAPHX_TRACE_EVAL
Set to "1", "2", or "3" to use. Set to "1", "2", or "3" to use.
"1" prints the instruction run and the time taken. "1" prints the instruction run and the time taken.
...@@ -48,7 +48,7 @@ Set to "1", "2", or "3" to use. ...@@ -48,7 +48,7 @@ Set to "1", "2", or "3" to use.
Program Verification Program Verification
------------------------ ------------------------
**MIGRAPHX_VERIFY_ENABLE_ALLCLOSE** .. envvar:: MIGRAPHX_VERIFY_ENABLE_ALLCLOSE
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Uses ``allclose`` with the given ``atol`` and ``rtol`` for verifying ranges with ``driver verify`` or the tests that use ``migraphx/verify.hpp``. Uses ``allclose`` with the given ``atol`` and ``rtol`` for verifying ranges with ``driver verify`` or the tests that use ``migraphx/verify.hpp``.
...@@ -57,76 +57,76 @@ Uses ``allclose`` with the given ``atol`` and ``rtol`` for verifying ranges with ...@@ -57,76 +57,76 @@ Uses ``allclose`` with the given ``atol`` and ``rtol`` for verifying ranges with
Pass debugging or Pass controls Pass debugging or Pass controls
----------------------------------- -----------------------------------
**MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS** .. envvar:: MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Debug print the instructions that have input ``contiguous`` instructions removed. Debug print the instructions that have input ``contiguous`` instructions removed.
**MIGRAPHX_DISABLE_POINTWISE_FUSION** .. envvar:: MIGRAPHX_DISABLE_POINTWISE_FUSION
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Disables the ``fuse_pointwise`` compile pass. Disables the ``fuse_pointwise`` compile pass.
**MIGRAPHX_DEBUG_MEMORY_COLORING** .. envvar:: MIGRAPHX_DEBUG_MEMORY_COLORING
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Print debug statements for the ``memory_coloring`` pass. Print debug statements for the ``memory_coloring`` pass.
**MIGRAPHX_TRACE_SCHEDULE** .. envvar:: MIGRAPHX_TRACE_SCHEDULE
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Print debug statements for the ``schedule`` pass. Print debug statements for the ``schedule`` pass.
**MIGRAPHX_TRACE_PROPAGATE_CONSTANT** .. envvar:: MIGRAPHX_TRACE_PROPAGATE_CONSTANT
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Traces instructions replaced with a constant. Traces instructions replaced with a constant.
**MIGRAPHX_INT8_QUANTIZATION_PARAMS** .. envvar:: MIGRAPHX_INT8_QUANTIZATION_PARAMS
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Print the quantization parameters in only the main module. Print the quantization parameters in only the main module.
**MIGRAPHX_DISABLE_DNNL_POST_OPS_WORKAROUND** .. envvar:: MIGRAPHX_DISABLE_DNNL_POST_OPS_WORKAROUND
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Disable the DNNL post ops workaround. Disable the DNNL post ops workaround.
**MIGRAPHX_DISABLE_MIOPEN_FUSION** .. envvar:: MIGRAPHX_DISABLE_MIOPEN_FUSION
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Disable MIOpen fusions. Disable MIOpen fusions.
**MIGRAPHX_DISABLE_SCHEDULE_PASS** .. envvar:: MIGRAPHX_DISABLE_SCHEDULE_PASS
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Disable the ``schedule`` pass. Disable the ``schedule`` pass.
**MIGRAPHX_DISABLE_REDUCE_FUSION** .. envvar:: MIGRAPHX_DISABLE_REDUCE_FUSION
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Disable the ``fuse_reduce`` pass. Disable the ``fuse_reduce`` pass.
**MIGRAPHX_ENABLE_NHWC** .. envvar:: MIGRAPHX_ENABLE_NHWC
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Enable the ``layout_nhwc`` pass. Enable the ``layout_nhwc`` pass.
**MIGRAPHX_ENABLE_CK** .. envvar:: MIGRAPHX_ENABLE_CK
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Enable using the Composable Kernels library. Enable using the Composable Kernels library.
Should be used in conjunction with ``MIGRAPHX_DISABLE_MLIR=1``. Should be used in conjunction with ``MIGRAPHX_DISABLE_MLIR=1``.
**MIGRAPHX_DISABLE_MLIR** .. envvar:: MIGRAPHX_DISABLE_MLIR
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Disable using the rocMLIR library. Disable using the rocMLIR library.
**MIGRAPHX_ENABLE_EXTRA_MLIR** .. envvar:: MIGRAPHX_ENABLE_EXTRA_MLIR
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Enables additional opportunities to use MLIR that may improve performance. Enables additional opportunities to use MLIR that may improve performance.
**MIGRAPHX_COPY_LITERALS** .. envvar:: MIGRAPHX_COPY_LITERALS
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Use ``hip_copy_to_gpu`` with a new ``literal`` instruction rather than use ``hip_copy_literal{}``. Use ``hip_copy_to_gpu`` with a new ``literal`` instruction rather than use ``hip_copy_literal{}``.
...@@ -134,22 +134,22 @@ Use ``hip_copy_to_gpu`` with a new ``literal`` instruction rather than use ``hip ...@@ -134,22 +134,22 @@ Use ``hip_copy_to_gpu`` with a new ``literal`` instruction rather than use ``hip
Compilation traces Compilation traces
---------------------- ----------------------
**MIGRAPHX_TRACE_FINALIZE** .. envvar:: MIGRAPHX_TRACE_FINALIZE
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Debug print instructions during the ``module.finalize()`` step. Debug print instructions during the ``module.finalize()`` step.
**MIGRAPHX_TRACE_COMPILE** .. envvar:: MIGRAPHX_TRACE_COMPILE
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Print trace information for the graph compilation process. Print trace information for the graph compilation process.
**MIGRAPHX_TRACE_PASSES** .. envvar:: MIGRAPHX_TRACE_PASSES
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Print the compile pass and the program after the pass. Print the compile pass and the program after the pass.
**MIGRAPHX_TIME_PASSES** .. envvar:: MIGRAPHX_TIME_PASSES
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Time the compile passes. Time the compile passes.
...@@ -158,77 +158,77 @@ Time the compile passes. ...@@ -158,77 +158,77 @@ Time the compile passes.
GPU Kernels JIT compilation debugging (applicable for both hiprtc and hipclang) GPU Kernels JIT compilation debugging (applicable for both hiprtc and hipclang)
----------------------------------------- -----------------------------------------
**MIGRAPHX_TRACE_CMD_EXECUTE** .. envvar:: MIGRAPHX_TRACE_CMD_EXECUTE
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Print commands executed by the MIGraphX ``process``. Print commands executed by the MIGraphX ``process``.
**MIGRAPHX_TRACE_HIPRTC** .. envvar:: MIGRAPHX_TRACE_HIPRTC
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Print HIPRTC options and C++ file executed. Print HIPRTC options and C++ file executed.
**MIGRAPHX_DEBUG_SAVE_TEMP_DIR** .. envvar:: MIGRAPHX_DEBUG_SAVE_TEMP_DIR
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Make it so the created temporary directories are not deleted. Make it so the created temporary directories are not deleted.
**MIGRAPHX_GPU_DEBUG** .. envvar:: MIGRAPHX_GPU_DEBUG
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Internally, this adds the option ``-DMIGRAPHX_DEBUG`` when compiling GPU kernels. It enables assertions and capture of source locations for the errors. Internally, this adds the option ``-DMIGRAPHX_DEBUG`` when compiling GPU kernels. It enables assertions and capture of source locations for the errors.
**MIGRAPHX_GPU_DEBUG_SYM** .. envvar:: MIGRAPHX_GPU_DEBUG_SYM
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Adds the option ``-g`` when compiling HIPRTC. Adds the option ``-g`` when compiling HIPRTC.
**MIGRAPHX_GPU_DUMP_SRC** .. envvar:: MIGRAPHX_GPU_DUMP_SRC
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Dump the HIPRTC source files compiled. Dump the HIPRTC source files compiled.
**MIGRAPHX_GPU_DUMP_ASM** .. envvar:: MIGRAPHX_GPU_DUMP_ASM
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Dump the hip-clang assembly. Dump the hip-clang assembly.
**MIGRAPHX_GPU_OPTIMIZE** .. envvar:: MIGRAPHX_GPU_OPTIMIZE
Set the optimization mode for GPU compile (``-O`` option). Set the optimization mode for GPU compile (``-O`` option).
Defaults to ``-O3``. Defaults to ``-O3``.
**MIGRAPHX_GPU_COMPILE_PARALLEL** .. envvar:: MIGRAPHX_GPU_COMPILE_PARALLEL
Set to the number of threads to use. Set to the number of threads to use.
Compile GPU code in parallel with the given number of threads. Compile GPU code in parallel with the given number of threads.
**MIGRAPHX_TRACE_NARY** .. envvar:: MIGRAPHX_TRACE_NARY
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Print the ``nary`` device functions used. Print the ``nary`` device functions used.
**MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS** .. envvar:: MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Enable HIPRTC workarounds for bugs in HIPRTC. Enable HIPRTC workarounds for bugs in HIPRTC.
**MIGRAPHX_USE_FAST_SOFTMAX** .. envvar:: MIGRAPHX_USE_FAST_SOFTMAX
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Use the fast softmax optimization. Use the fast softmax optimization.
**MIGRAPHX_ENABLE_NULL_STREAM** .. envvar:: MIGRAPHX_ENABLE_NULL_STREAM
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Allow using null stream for miopen and hipStream. Allow using null stream for miopen and hipStream.
**MIGRAPHX_NSTREAMS** .. envvar:: MIGRAPHX_NSTREAMS
Set to the number of streams to use. Set to the number of streams to use.
Defaults to 1. Defaults to 1.
**MIGRAPHX_TRACE_BENCHMARKING** .. envvar:: MIGRAPHX_TRACE_BENCHMARKING
Set to "1" to print benchmarking trace. Set to "1" to print benchmarking trace.
Set to "2" to print benchmarking trace with more detail. Set to "2" to print benchmarking trace with more detail.
...@@ -236,45 +236,49 @@ Set to "2" to print benchmarching trace with more detail. ...@@ -236,45 +236,49 @@ Set to "2" to print benchmarching trace with more detail.
MLIR vars MLIR vars
------------- -------------
**MIGRAPHX_TRACE_MLIR** .. envvar:: MIGRAPHX_TRACE_MLIR
Set to "1" to trace MLIR and print any failures. Set to "1" to trace MLIR and print any failures.
Set to "2" to additionally print all MLIR operations. Set to "2" to additionally print all MLIR operations.
**MIGRAPHX_MLIR_USE_SPECIFIC_OPS** .. envvar:: MIGRAPHX_MLIR_USE_SPECIFIC_OPS
Set to the name of the operations you want to always use MLIR regardless of GPU architecture. Set to the name of the operations you want to always use MLIR regardless of GPU architecture.
Accepts a list of operators separated by commas (ex: "fused,convolution,dot"). Accepts a list of operators separated by commas (ex: "fused,convolution,dot").
**MIGRAPHX_MLIR_TUNING_DB** .. envvar:: MIGRAPHX_MLIR_TUNING_DB
Set to the path of the MLIR tuning database to load. Set to the path of the MLIR tuning database to load.
**MIGRAPHX_MLIR_TUNING_CFG** .. envvar:: MIGRAPHX_MLIR_TUNING_CFG
Set to the path of the tuning configuration. Set to the path of the tuning configuration.
Appends to tuning cfg file that could be used with rocMLIR tuning scripts. Appends to tuning cfg file that could be used with rocMLIR tuning scripts.
**MIGRAPHX_MLIR_TUNE_EXHAUSTIVE** .. envvar:: MIGRAPHX_MLIR_TUNE_EXHAUSTIVE
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Do exhaustive tuning for MLIR. Do exhaustive tuning for MLIR.
.. envvar:: MIGRAPHX_MLIR_TUNE_LIMIT
Set to an integer greater than 1.
Limits the number of solutions that MLIR will use for tuning.
CK vars CK vars
----------- -----------
**MIGRAPHX_LOG_CK_GEMM** .. envvar:: MIGRAPHX_LOG_CK_GEMM
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Print Composable Kernels GEMM traces. Print Composable Kernels GEMM traces.
**MIGRAPHX_CK_DEBUG** .. envvar:: MIGRAPHX_CK_DEBUG
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Always add the ``-DMIGRAPHX_CK_CHECK=1`` for compiling Composable Kernels operators. Always add the ``-DMIGRAPHX_CK_CHECK=1`` for compiling Composable Kernels operators.
**MIGRAPHX_TUNE_CK** .. envvar:: MIGRAPHX_TUNE_CK
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Use tuning for Composable Kernels. Use tuning for Composable Kernels.
...@@ -282,19 +286,19 @@ Use tuning for Composable Kernels. ...@@ -282,19 +286,19 @@ Use tuning for Composable Kernels.
Testing Testing
------------ ------------
**MIGRAPHX_TRACE_TEST_COMPILE** .. envvar:: MIGRAPHX_TRACE_TEST_COMPILE
Set to the target that you want to trace the compilation of (ex. "gpu", "cpu"). Set to the target that you want to trace the compilation of (ex. "gpu", "cpu").
Prints the compile trace for the given target for the verify tests. Prints the compile trace for the given target for the verify tests.
This flag shouldn't be used in conjunction with ``MIGRAPHX_TRACE_COMPILE``. This flag shouldn't be used in conjunction with ``MIGRAPHX_TRACE_COMPILE``.
For the verify tests only use ``MIGRAPHX_TRACE_TEST_COMPILE``. For the verify tests only use ``MIGRAPHX_TRACE_TEST_COMPILE``.
**MIGRAPHX_TRACE_TEST** .. envvar:: MIGRAPHX_TRACE_TEST
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Prints the reference and target programs even if the verify passed successfully. Prints the reference and target programs even if the verify passed successfully.
**MIGRAPHX_DUMP_TEST** .. envvar:: MIGRAPHX_DUMP_TEST
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Dumps verify tests to ``.mxr`` files. Dumps verify tests to ``.mxr`` files.
rocm-docs-core==0.29.0 rocm-docs-core==0.30.0
...@@ -49,7 +49,7 @@ charset-normalizer==3.1.0 ...@@ -49,7 +49,7 @@ charset-normalizer==3.1.0
# via requests # via requests
click==8.1.3 click==8.1.3
# via sphinx-external-toc # via sphinx-external-toc
cryptography==41.0.4 cryptography==41.0.6
# via pyjwt # via pyjwt
deprecated==1.2.13 deprecated==1.2.13
# via pygithub # via pygithub
...@@ -121,7 +121,7 @@ requests==2.28.2 ...@@ -121,7 +121,7 @@ requests==2.28.2
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core==0.29.0 rocm-docs-core==0.30.0
# via -r requirements.in # via -r requirements.in
smmap==5.0.0 smmap==5.0.0
# via gitdb # via gitdb
......
...@@ -6,4 +6,5 @@ This directory contains examples of common use cases for MIGraphX. ...@@ -6,4 +6,5 @@ This directory contains examples of common use cases for MIGraphX.
## Examples: ## Examples:
- [MIGraphX usage and utilities](./migraphx) - [MIGraphX usage and utilities](./migraphx)
- [Vision inference examples](./vision) - [Vision inference examples](./vision)
- [Natural language inference examples](./nlp) - [Natural language inference examples](./nlp)
\ No newline at end of file - [Diffusion inference examples](./diffusion)
# Diffusion Inference Examples
- [Python Stable Diffusion 2.1](./python_stable_diffusion_21)
# Stable Diffusion 2.1
This version was tested with [rocm 5.7](https://github.com/ROCmSoftwarePlatform/AMDMIGraphX/tree/rocm-5.7.0) revision.
## Jupyter notebook
There is a dedicated step-by-step notebook. See [sd21.ipynb](./sd21.ipynb)
## Console application
To run the console application, follow these steps below.
Setup python environment
```bash
# this will require the python venv to be installed (e.g. apt install python3.8-venv)
python3 -m venv sd_venv
. sd_venv/bin/activate
```
Install dependencies
```bash
pip install -r requirements.txt
```
Use MIGraphX Python Module
```bash
export PYTHONPATH=/opt/rocm/lib:$PYTHONPATH
```
Get models with optimum
```bash
optimum-cli export onnx --model stabilityai/stable-diffusion-2-1 models/sd21-onnx
```
*Note: `models/sd21-onnx` will be used in the scripts.*
Run the text-to-image script with the following example prompt and seed:
```bash
python txt2img.py --prompt "a photograph of an astronaut riding a horse" --seed 13 --output astro_horse.jpg
```
*Note: The first run will compile the models and cache them to make subsequent runs faster.*
The result should look like this:
![example_output.jpg](./example_output.jpg)
## Gradio application
Note: requires `Console application` to work
Install gradio dependencies
```bash
pip install -r gradio_requirements.txt
```
Usage
```bash
python gradio_app.py
```
This will load the models (which can take several minutes), and when the setup is ready, starts a server on `http://127.0.0.1:7860`.
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
from txt2img import StableDiffusionMGX
import gradio as gr
def main():
    """Create the Gradio interface around StableDiffusionMGX and launch it."""
    # Note: This will load the models, which can take several minutes
    pipeline = StableDiffusionMGX()

    def gr_wrapper(prompt, negative_prompt, steps, seed, scale):
        # Coerce the raw widget values to the types passed on to run().
        generated = pipeline.run(
            str(prompt),
            str(negative_prompt),
            int(steps),
            int(seed),
            float(scale),
        )
        return StableDiffusionMGX.convert_to_rgb_image(generated)

    # Input widgets, listed in the same order as gr_wrapper's parameters.
    prompt_box = gr.Textbox(
        value="a photograph of an astronaut riding a horse", label="Prompt")
    negative_box = gr.Textbox(value="", label="Negative prompt (Optional)")
    steps_slider = gr.Slider(1, 100, step=1, value=20, label="Number of steps")
    seed_box = gr.Textbox(value=13, label="Random seed")
    scale_slider = gr.Slider(1, 20, step=0.1, value=7.0, label="Guidance scale")

    demo = gr.Interface(
        gr_wrapper,
        [prompt_box, negative_box, steps_slider, seed_box, scale_slider],
        "image",
    )
    demo.launch()


if __name__ == "__main__":
    main()
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
# Install the console application's dependencies first (-r includes another
# requirements file; -f would only add a --find-links location), then gradio.
-r requirements.txt
gradio
\ No newline at end of file
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
accelerate
diffusers
optimum[onnxruntime]
transformers
\ No newline at end of file
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# The MIT License (MIT)\n",
"#\n",
"# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.\n",
"#\n",
"# Permission is hereby granted, free of charge, to any person obtaining a copy\n",
"# of this software and associated documentation files (the 'Software'), to deal\n",
"# in the Software without restriction, including without limitation the rights\n",
"# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
"# copies of the Software, and to permit persons to whom the Software is\n",
"# furnished to do so, subject to the following conditions:\n",
"#\n",
"# The above copyright notice and this permission notice shall be included in\n",
"# all copies or substantial portions of the Software.\n",
"#\n",
"# THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
"# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
"# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
"# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
"# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
"# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
"# THE SOFTWARE."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Stable Diffusion 2.1\n",
"\n",
"The following example will show how to run `Stable Diffusion 2.1` with `MIGraphX`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install the required dependencies."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install dependencies\n",
"!pip install optimum[onnxruntime] transformers diffusers accelerate"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use optimum to generate the onnx files."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export models\n",
"!optimum-cli export onnx --model stabilityai/stable-diffusion-2-1 models/sd21-onnx"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now it is time to load these models with python.\n",
"\n",
"First, we make sure that MIGraphX module is found in the python path."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"mgx_lib_path = \"/opt/rocm/lib/\" # or \"/code/AMDMIGraphX/build/lib/\"\n",
"if mgx_lib_path not in sys.path:\n",
" sys.path.append(mgx_lib_path)\n",
"import migraphx as mgx"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, a helper method to load and cache the models.\n",
"\n",
"This will use the `models/sd21-onnx` path. If you changed it, make sure to update here as well."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"# helper for model loading\n",
"def load_mgx_model(name, shapes):\n",
"    \"\"\"Load the named model, preferring a previously saved .mxr file.\n",
"\n",
"    If no .mxr file exists, the .onnx file is parsed with the given input\n",
"    shapes, compiled for the gpu target and saved as .mxr for later runs.\n",
"    Exits with status 1 when neither file is found.\n",
"    \"\"\"\n",
"    file = f\"models/sd21-onnx/{name}/model\"\n",
"    print(f\"Loading {name} model from {file}\")\n",
"    if os.path.isfile(f\"{file}.mxr\"):\n",
"        print(f\"Found mxr, loading it...\")\n",
"        model = mgx.load(f\"{file}.mxr\", format=\"msgpack\")\n",
"    elif os.path.isfile(f\"{file}.onnx\"):\n",
"        print(f\"Parsing from onnx file...\")\n",
"        model = mgx.parse_onnx(f\"{file}.onnx\", map_input_dims=shapes)\n",
"        model.compile(mgx.get_target(\"gpu\"))\n",
"        print(f\"Saving {name} model to mxr file...\")\n",
"        mgx.save(model, f\"{file}.mxr\", format=\"msgpack\")\n",
"    else:\n",
"        print(f\"No {name} model found. Please verify the path is correct and re-try, or re-download model.\")\n",
"        # os has no exit(); sys.exit raises SystemExit with the given status.\n",
"        sys.exit(1)\n",
"    return model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With that, we can load the models. This could take several minutes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text_encoder = load_mgx_model(\"text_encoder\", {\"input_ids\": [1, 77]})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"unet = load_mgx_model(\n",
" \"unet\", {\n",
" \"sample\": [1, 4, 64, 64],\n",
" \"encoder_hidden_states\": [1, 77, 1024],\n",
" \"timestep\": [1],\n",
" })"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"vae = load_mgx_model(\"vae_decoder\", {\"latent_sample\": [1, 4, 64, 64]})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import the remaining packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from diffusers import EulerDiscreteScheduler\n",
"from transformers import CLIPTokenizer\n",
"import torch\n",
"import numpy as np\n",
"from tqdm.auto import tqdm\n",
"from PIL import Image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Time to load the scheduler and tokenizer from the original source."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_id = \"stabilityai/stable-diffusion-2-1\"\n",
"scheduler = EulerDiscreteScheduler.from_pretrained(model_id,\n",
" subfolder=\"scheduler\")\n",
"tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder=\"tokenizer\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we will define all the steps one by one, to make the last step short and simple.\n",
"\n",
"The first step will be to tokenize the user prompt. It will make a `(1, 77)` shaped `input_ids`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tokenize(input):\n",
" return tokenizer([input],\n",
" padding=\"max_length\",\n",
" max_length=tokenizer.model_max_length,\n",
" truncation=True,\n",
" return_tensors=\"np\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional\n",
"test_tk = tokenize(\"test tokenizer to see the tokens\")\n",
"test_tk.input_ids.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We run the tokenized prompt through the `Text Encoder` model. It expects the `(1, 77)` data as `int32`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional\n",
"text_encoder.get_parameter_shapes()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_embeddings(input):\n",
" return np.array(\n",
" text_encoder.run({\"input_ids\": input.input_ids.astype(np.int32)\n",
" })[0]).astype(np.float32)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional\n",
"test_emb = get_embeddings(tokenize(\"test tokenizer to see the tokens\"))\n",
"test_emb.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The other input of the model is the latent representation (pure noise). It will be transformed into a 512x512 image later.\n",
"The last input will be the timestep."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def generate_latents(seed):\n",
" return torch.randn(\n",
" (1, 4, 64, 64),\n",
" generator=torch.manual_seed(seed),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: inspect the shape of freshly generated latents.\n",
"test_latents = generate_latents(42)\n",
"test_latents.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we add two helpers to access and convert from torch to numpy with the proper datatype."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_scaled_sample(latents, t):\n",
" return scheduler.scale_model_input(latents, t).numpy().astype(np.float32)\n",
"\n",
"\n",
"def get_timestep(t):\n",
" return np.atleast_1d(t.numpy().astype(np.int64)) # convert 0D -> 1D"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The UNet model will be run in a loop. It will predict the noise residual."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional\n",
"unet.get_parameter_shapes()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def denoise(sample, embeddings, timestep):\n",
" return np.array(\n",
" unet.run({\n",
" \"sample\": sample,\n",
" \"encoder_hidden_states\": embeddings,\n",
" \"timestep\": timestep\n",
" })[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Helpers to perform classifier-free guidance and to compute the previous noisy sample."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def perform_guidance(noise_pred_uncond, noise_pred_text, scale):\n",
" return noise_pred_uncond + scale * (noise_pred_text - noise_pred_uncond)\n",
"\n",
"def compute_previous(noise_pred, t, latents):\n",
" # compute the previous noisy sample x_t -> x_t-1\n",
" return scheduler.step(noise_pred, t, latents).prev_sample\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Scale and decode the image latents with VAE."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def scale_denoised(latents):\n",
" return 1 / 0.18215 * latents\n",
"\n",
"\n",
"def decode(latents):\n",
" return np.array(\n",
" vae.run({\"latent_sample\": latents.numpy().astype(np.float32)})[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And lastly, we need to convert it to an image to display or save."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def convert_to_rgb_image(image):\n",
" image = np.clip(image / 2 + 0.5, 0, 1)\n",
" image = np.transpose(image, (0, 2, 3, 1))\n",
" images = (image * 255).round().astype(\"uint8\")\n",
" return Image.fromarray(images[0])\n",
"\n",
"def save_image(pil_image, filename=\"output.png\"):\n",
" pil_image.save(filename, format=\"png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Feel free to play around with these params."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"a photograph of an astronaut riding a horse\"\n",
"negative_prompt = \"\"\n",
"steps = 20\n",
"seed = 13\n",
"scale = 7.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now, to put everything together and run the whole pipeline:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"scheduler.set_timesteps(steps)\n",
"\n",
"text_input, uncond_input = tokenize(prompt), tokenize(negative_prompt)\n",
"text_embeddings, uncond_embeddings = get_embeddings(\n",
" text_input), get_embeddings(uncond_input)\n",
"latents = generate_latents(seed) * scheduler.init_noise_sigma\n",
"\n",
"for t in tqdm(scheduler.timesteps):\n",
" sample = get_scaled_sample(latents, t)\n",
" timestep = get_timestep(t)\n",
"\n",
" noise_pred_uncond = denoise(sample, uncond_embeddings, timestep)\n",
" noise_pred_text = denoise(sample, text_embeddings, timestep)\n",
"\n",
" noise_pred = perform_guidance(noise_pred_uncond, noise_pred_text, scale)\n",
" latents = compute_previous(torch.from_numpy(noise_pred), t, latents)\n",
"\n",
"latents = scale_denoised(latents)\n",
"result = decode(latents)\n",
"image = convert_to_rgb_image(result)\n",
"\n",
"# show the image\n",
"image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you like the generated image, save it with the following:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"save_image(image, \"output.png\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "sd_venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the 'Software'), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from argparse import ArgumentParser
from diffusers import EulerDiscreteScheduler
from transformers import CLIPTokenizer
from PIL import Image
import migraphx as mgx
import numpy as np
import os
import torch
import time
from functools import wraps
# measurement helper
def measure(fn):
    """Decorator that reports how long each call to ``fn`` takes.

    The returned callable forwards every positional and keyword argument to
    ``fn``, prints the elapsed wall-clock time of that call in milliseconds,
    and hands back whatever ``fn`` returned. If ``fn`` raises, the exception
    propagates and no timing line is printed.
    """

    def timed(*args, **kwargs):
        began_ns = time.perf_counter_ns()
        outcome = fn(*args, **kwargs)
        elapsed_ns = time.perf_counter_ns() - began_ns
        # perf_counter_ns() yields integer nanoseconds; 1e-6 converts to ms.
        print(f"Elapsed time: {elapsed_ns * 1e-6:.4f} ms\n")
        return outcome

    # Copy fn's metadata (name, docstring, ...) onto the timing wrapper.
    return wraps(fn)(timed)
def get_args():
    """Parse the command-line options of the Stable Diffusion example.

    Returns:
        argparse.Namespace with the attributes ``seed``, ``steps``, ``prompt``,
        ``negative_prompt``, ``scale`` and ``output``. Only ``--prompt`` is
        required; every other option has a default.
    """
    # (flags, keyword arguments) pairs, registered in this order.
    option_table = (
        (("-s", "--seed"), dict(type=int, default=42, help="Random seed")),
        (("-t", "--steps"), dict(type=int, default=20,
                                 help="Number of steps")),
        (("-p", "--prompt"), dict(type=str, required=True, help="Prompt")),
        (("-n", "--negative-prompt"), dict(type=str, default="",
                                           help="Negative prompt")),
        (("--scale", ), dict(type=float, default=7.0, help="Guidance scale")),
        (("-o", "--output"), dict(type=str, default=None,
                                  help="Output name")),
    )

    cli = ArgumentParser()
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()
class StableDiffusionMGX():
    """Text-to-image pipeline for Stable Diffusion 2.1 using MIGraphX models.

    The scheduler and tokenizer are loaded through ``diffusers`` and
    ``transformers``; the text encoder, UNet and VAE decoder are MIGraphX
    programs loaded from ``models/sd21-onnx/<name>/model.{mxr,onnx}``.
    """

    def __init__(self):
        """Create the scheduler and tokenizer and load the three models."""
        model_id = "stabilityai/stable-diffusion-2-1"
        print(f"Using {model_id}")

        print("Creating EulerDiscreteScheduler scheduler")
        self.scheduler = EulerDiscreteScheduler.from_pretrained(
            model_id, subfolder="scheduler")

        print("Creating CLIPTokenizer tokenizer...")
        self.tokenizer = CLIPTokenizer.from_pretrained(model_id,
                                                       subfolder="tokenizer")

        print("Load models...")
        # Input shapes are fixed at parse time via map_input_dims.
        self.vae = StableDiffusionMGX.load_mgx_model(
            "vae_decoder", {"latent_sample": [1, 4, 64, 64]})
        self.text_encoder = StableDiffusionMGX.load_mgx_model(
            "text_encoder", {"input_ids": [1, 77]})
        self.unet = StableDiffusionMGX.load_mgx_model(
            "unet", {
                "sample": [1, 4, 64, 64],
                "encoder_hidden_states": [1, 77, 1024],
                "timestep": [1],
            })

    def run(self, prompt, negative_prompt, steps, seed, scale):
        """Run the whole pipeline and return the decoded image array.

        Args:
            prompt: text the image should follow.
            negative_prompt: text used for the unconditional branch.
            steps: number of scheduler timesteps.
            seed: seed for the initial random latents.
            scale: guidance scale applied in each denoising step.

        Returns:
            The numpy array produced by ``decode``; pass it to
            ``convert_to_rgb_image`` to obtain a PIL image.
        """
        # need to set this for each run
        self.scheduler.set_timesteps(steps)

        print("Tokenizing prompt...")
        text_input = self.tokenize(prompt)

        print("Creating text embeddings for prompt...")
        text_embeddings = self.get_embeddings(text_input)

        print("Tokenizing negative prompt...")
        uncond_input = self.tokenize(negative_prompt)

        print("Creating text embeddings for negative prompt...")
        uncond_embeddings = self.get_embeddings(uncond_input)

        print(
            f"Creating random input data ({1}x{4}x{64}x{64}) (latents) with seed={seed}..."
        )
        latents = torch.randn((1, 4, 64, 64),
                              generator=torch.manual_seed(seed))

        print("Apply initial noise sigma\n")
        latents = latents * self.scheduler.init_noise_sigma

        print("Running denoising loop...")
        for step, t in enumerate(self.scheduler.timesteps):
            print(f"#{step}/{len(self.scheduler.timesteps)} step")
            latents = self.denoise_step(text_embeddings, uncond_embeddings,
                                        latents, t, scale)

        print("Scale denoised result...")
        # NOTE(review): 0.18215 is presumably the VAE latent scaling factor
        # of the original model - confirm against the model config.
        latents = 1 / 0.18215 * latents

        print("Decode denoised result...")
        image = self.decode(latents)

        return image

    @staticmethod
    @measure
    def load_mgx_model(name, shapes):
        """Load the MIGraphX program for ``name``, compiling it if needed.

        A previously saved ``model.mxr`` is preferred. Otherwise ``model.onnx``
        is parsed with ``shapes`` as ``map_input_dims``, compiled for the
        ``gpu`` target and saved as ``model.mxr`` for the next run.

        Raises:
            SystemExit: with status 1 when neither file exists.
        """
        file = f"models/sd21-onnx/{name}/model"
        print(f"Loading {name} model from {file}")
        if os.path.isfile(f"{file}.mxr"):
            print("Found mxr, loading it...")
            model = mgx.load(f"{file}.mxr", format="msgpack")
        elif os.path.isfile(f"{file}.onnx"):
            print("Parsing from onnx file...")
            model = mgx.parse_onnx(f"{file}.onnx", map_input_dims=shapes)
            model.compile(mgx.get_target("gpu"))
            print(f"Saving {name} model to mxr file...")
            mgx.save(model, f"{file}.mxr", format="msgpack")
        else:
            print(f"No {name} model found. Please download it and re-try.")
            # The os module has no exit(); raising SystemExit is exactly what
            # sys.exit(1) does and needs no additional import.
            raise SystemExit(1)
        return model

    @measure
    def tokenize(self, input):
        """Tokenize one prompt, padded/truncated to the tokenizer's max length.

        Returns the tokenizer output with numpy tensors
        (``return_tensors="np"``).
        """
        return self.tokenizer([input],
                              padding="max_length",
                              max_length=self.tokenizer.model_max_length,
                              truncation=True,
                              return_tensors="np")

    @measure
    def get_embeddings(self, input):
        """Run the text encoder on tokenized input; return float32 embeddings.

        ``input.input_ids`` is cast to int32 before being fed to the model, and
        the first model output is converted to a float32 numpy array.
        """
        return np.array(
            self.text_encoder.run(
                {"input_ids":
                 input.input_ids.astype(np.int32)})[0]).astype(np.float32)

    @staticmethod
    def convert_to_rgb_image(image):
        """Convert a decoded image batch into a PIL image of its first entry.

        Values are mapped with ``image / 2 + 0.5`` and clipped to [0, 1], the
        axes are reordered from (N, C, H, W) to (N, H, W, C), and the result is
        scaled to 8-bit.
        """
        image = np.clip(image / 2 + 0.5, 0, 1)
        image = np.transpose(image, (0, 2, 3, 1))
        images = (image * 255).round().astype("uint8")
        return Image.fromarray(images[0])

    @staticmethod
    def save_image(pil_image, filename="output.png"):
        """Save ``pil_image`` to ``filename``."""
        pil_image.save(filename)

    @measure
    def denoise_step(self, text_embeddings, uncond_embeddings, latents, t,
                     scale):
        """Perform one guided denoising step and return the updated latents.

        The UNet is evaluated twice on the same scaled sample, once with the
        negative-prompt embeddings and once with the prompt embeddings; the two
        noise predictions are blended with ``scale`` and passed to the
        scheduler.
        """
        sample = self.scheduler.scale_model_input(latents,
                                                  t).numpy().astype(np.float32)
        timestep = np.atleast_1d(t.numpy().astype(
            np.int64))  # convert 0D -> 1D

        noise_pred_uncond = np.array(
            self.unet.run({
                "sample": sample,
                "encoder_hidden_states": uncond_embeddings,
                "timestep": timestep
            })[0])

        noise_pred_text = np.array(
            self.unet.run({
                "sample": sample,
                "encoder_hidden_states": text_embeddings,
                "timestep": timestep
            })[0])

        # perform guidance
        noise_pred = noise_pred_uncond + scale * (noise_pred_text -
                                                  noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        return self.scheduler.step(torch.from_numpy(noise_pred), t,
                                   latents).prev_sample

    @measure
    def decode(self, latents):
        """Run the VAE decoder on ``latents`` and return its first output."""
        return np.array(
            self.vae.run({"latent_sample":
                          latents.numpy().astype(np.float32)})[0])
if __name__ == "__main__":
    # Script entry point: parse options, generate one image and save it.
    args = get_args()

    sd = StableDiffusionMGX()
    result = sd.run(args.prompt, args.negative_prompt, args.steps, args.seed,
                    args.scale)

    print("Convert result to rgb image...")
    image = StableDiffusionMGX.convert_to_rgb_image(result)
    # Fall back to a name derived from seed and steps when -o is not given.
    filename = args.output if args.output else f"output_s{args.seed}_t{args.steps}.png"
    # Save under the resolved name: args.output is None when -o is omitted,
    # and the message below must name the file that was actually written.
    StableDiffusionMGX.save_image(image, filename)
    print(f"Image saved to {filename}")
...@@ -29,4 +29,4 @@ pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build ...@@ -29,4 +29,4 @@ pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/rocMLIR@13f6c2a69cfe80a575c6b241ec7353d1e953cb12 -DBUILD_FAT_LIBROCKCOMPILER=On ROCmSoftwarePlatform/rocMLIR@9e66e8050209f03349a41b6b497f0da2b285a53b -DBUILD_FAT_LIBROCKCOMPILER=On
...@@ -221,6 +221,8 @@ register_migraphx_ops( ...@@ -221,6 +221,8 @@ register_migraphx_ops(
scatternd_add scatternd_add
scatternd_mul scatternd_mul
scatternd_none scatternd_none
scatternd_max
scatternd_min
select_module select_module
sigmoid sigmoid
sign sign
...@@ -239,6 +241,7 @@ register_migraphx_ops( ...@@ -239,6 +241,7 @@ register_migraphx_ops(
transpose transpose
unary_not unary_not
undefined undefined
unique
unknown unknown
unsqueeze unsqueeze
where where
...@@ -288,6 +291,7 @@ find_package(TBB QUIET) ...@@ -288,6 +291,7 @@ find_package(TBB QUIET)
if(TBB_FOUND) if(TBB_FOUND)
check_execution_par(TBB_HAS_EXECUTION_PAR TBB::tbb) check_execution_par(TBB_HAS_EXECUTION_PAR TBB::tbb)
if(TBB_HAS_EXECUTION_PAR) if(TBB_HAS_EXECUTION_PAR)
list(APPEND PACKAGE_DEPENDS PACKAGE TBB)
target_link_libraries(migraphx PUBLIC TBB::tbb) target_link_libraries(migraphx PUBLIC TBB::tbb)
set(MIGRAPHX_HAS_EXECUTORS_DEFAULT On) set(MIGRAPHX_HAS_EXECUTORS_DEFAULT On)
message(STATUS "Using TBB for parallel execution") message(STATUS "Using TBB for parallel execution")
......
...@@ -21,10 +21,13 @@ ...@@ -21,10 +21,13 @@
* ************************************************************************ */ * ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP #define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#include <type_traits>
#if defined(__GNUC__) && !defined(__clang__) #if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing" #pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif #endif
#include <migraphx/requires.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) // NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
...@@ -32,7 +35,10 @@ ...@@ -32,7 +35,10 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <typename To, typename From> template <typename To,
typename From,
MIGRAPHX_REQUIRES(std::is_trivially_copyable<To>{} and
std::is_trivially_copyable<From>{})>
inline constexpr To bit_cast(From fr) noexcept inline constexpr To bit_cast(From fr) noexcept
{ {
static_assert(sizeof(To) == sizeof(From)); static_assert(sizeof(To) == sizeof(From));
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MAX_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// ScatterND variant whose reduction keeps the larger of the value already in
// the output and the incoming update.
struct scatternd_max : scatternd_op<scatternd_max>
{
    scatternd_max() {}

    // Returns the functor used to merge one update `y` into the output
    // element `x`.
    auto reduction() const
    {
        return [](auto& x, const auto& y) {
            // Same selection std::max(x, y) makes: y only when x < y.
            x = (x < y) ? y : x;
        };
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MIN_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// ScatterND variant whose reduction keeps the smaller of the value already in
// the output and the incoming update.
struct scatternd_min : scatternd_op<scatternd_min>
{
    scatternd_min() {}

    // Returns the functor used to merge one update `y` into the output
    // element `x`.
    auto reduction() const
    {
        return [](auto& x, const auto& y) {
            // Same selection std::min(x, y) makes: y only when y < x.
            x = (y < x) ? y : x;
        };
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -121,7 +121,8 @@ struct scatternd_op : op_name<Derived> ...@@ -121,7 +121,8 @@ struct scatternd_op : op_name<Derived>
auto k = indices_shape.lens().back(); auto k = indices_shape.lens().back();
auto q = indices_shape.ndim(); auto q = indices_shape.ndim();
auto r = dyn_out.computed_shape.ndim(); auto r = dyn_out.computed_shape.ndim();
par_for(updates_shape.elements(), [&](const auto i) { for(auto i = 0u; i < updates_shape.elements(); ++i)
{
auto updates_idx = updates_std.multi(i); auto updates_idx = updates_std.multi(i);
std::vector<std::size_t> indices_idx(q, 0); std::vector<std::size_t> indices_idx(q, 0);
std::copy( std::copy(
...@@ -135,7 +136,7 @@ struct scatternd_op : op_name<Derived> ...@@ -135,7 +136,7 @@ struct scatternd_op : op_name<Derived>
std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k); std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
self.reduction()(output[dyn_out.computed_shape.index(out_idx)], updates[i]); self.reduction()(output[dyn_out.computed_shape.index(out_idx)], updates[i]);
}); }
}); });
}); });
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_UNIQUE_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNIQUE_HPP
#include <migraphx/shape_for_each.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/tune_axis.hpp>
#include <utility>
#include <map>
#include <limits>
#include <optional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// https://onnx.ai/onnx/operators/onnx__Unique.html
// The Onnx spec refers to numpy specification, used as a reference:
// https://numpy.org/doc/stable/reference/generated/numpy.unique.html
// Input : Given an array of elements : X.
// Output(s) :
// 1. Find the unique elements (Y) of input (X).
//
// There are three outputs in addition to the unique elements in Y:
// 2. the indices of the input array that give the unique values
// 3. the indices of the unique array that reconstruct the input array
// 4. the number of times each unique value comes up in the input array
// Optional Attribute: 'Sorted' = 1 for sorted; = 0 for unsorted.
// Onnx specification makes 'sorted' a default, while Numpy always sorts.
//
// Optional Attribute: 'Axis' is 'None' (default) or a valid int < rank(X).
// Negative values are allowed.
//
// Numpy has the following important note on Axis:
// ------------------------------------------------------------------
// When an axis is specified the subarrays indexed by the axis are
// sorted. This is done by making the specified axis the first
// dimension of the array (move the axis to the first dimension to
// keep the order of the other axes) and then flattening the subarrays
// in C order. The flattened subarrays are then viewed as a structured
// type with each element given a label, with the effect that we end
// up with a 1-D array of structured types that can be treated in the
// same way as any other 1-D array. The result is that the flattened
// subarrays are sorted in lexicographic order starting with the first
// element.
// ------------------------------------------------------------------
struct unique
{
    // Builds a "less than" comparator over flat offsets into `data`.
    // Two offsets are ordered by lexicographically comparing the `chunk_sz`
    // consecutive elements that start at each offset.
    // The returned lambda holds a reference to `data`, so `data` must outlive it.
    template <class T>
    auto make_idx_less_fn(const T& data, size_t chunk_sz) const
    {
        return [&data, chunk_sz](auto idx1, auto idx2) {
            return std::lexicographical_compare(data.begin() + idx1,
                                                data.begin() + idx1 + chunk_sz,
                                                data.begin() + idx2,
                                                data.begin() + idx2 + chunk_sz);
        };
    }

    // CASE SORTED:
    //
    // To process into a sorted unique series of elements/chunks:
    // Chunk size == 1 means a simple element; >1 means a flat representation.
    // Steps: first go through the input elements/chunks for uniqueness.
    // At the end of this processing, per the sorted sequence of unique elements:
    // update/create data structures: y, y_indices, x_rev_indices, y_count
    //
    // INPUT x: [2, 1, 1, 3, 4, 3], attr_sorted = 1;
    // OUTPUT(s): indices..
    // y_indices: [1, 0, 3, 4] --- first incidence, in terms of index in sequence x
    // x_rev_indices: [1, 0, 0, 2, 3, 2] --- x seen in terms of indices of unique sequence y
    // y_count: [2, 1, 2, 1] -- count at each y_index. sum = len(x)
    // NOTE: y [1, 2, 3, 4] --- the unique output is constructed from x[y_indices[...]]
    //
    // Returns the tuple (y_indices, x_rev_indices, y_count).
    template <class T>
    auto sorted_uniq_indices(const T& input_data, size_t chunk_sz) const
    {
        struct y_info
        {
            size_t y_idx;  // rank of the chunk in order of first appearance
            size_t x_idx;  // index (in chunks) of its first appearance in x
            size_t ct = 0; // number of occurrences seen so far
        };

        auto idx_less_fn = make_idx_less_fn(input_data, chunk_sz);
        // Keyed by flat offset, ordered by chunk content: iteration is in sorted order.
        std::map<size_t, y_info, decltype(idx_less_fn)> uniq_val_map(idx_less_fn);

        // rv is returned by name so that NRVO can apply.
        std::tuple<std::vector<std::size_t>, std::vector<std::size_t>, std::vector<std::size_t>> rv;
        auto& [y_indices, x_rev_indices, y_count] = rv;

        // go through all the elements and find the unique elements..
        size_t count_x = input_data.size();
        for(size_t f_idx = 0, x_idx = 0; f_idx < count_x; f_idx += chunk_sz, x_idx++)
        {
            // Positional aggregate initialization; designated initializers are C++20.
            y_info entry{uniq_val_map.size(), x_idx, 0};
            // insert() leaves an existing entry untouched and returns it.
            auto itr = uniq_val_map.insert({f_idx, entry}).first;
            itr->second.ct++;
            x_rev_indices.push_back(itr->second.y_idx);
        }

        // Maps "order of first appearance" -> "position in sorted order".
        std::vector<std::size_t> y2x_indices(uniq_val_map.size());
        y_indices.resize(uniq_val_map.size());
        y_count.resize(uniq_val_map.size());
        size_t idx = 0;
        // the unique elements are now sorted:
        // post-processing for all the return indices.
        for(const auto& v : uniq_val_map)
        {
            y2x_indices[v.second.y_idx] = idx;
            y_indices[idx]              = v.second.x_idx;
            y_count[idx]                = v.second.ct;
            idx++;
        }

        // update x_rev_indices as per the sorted order of y_indices
        for(auto& i : x_rev_indices)
            i = y2x_indices[i];

        return rv;
    }

    // CASE UNSORTED:
    //
    // To process into an un-sorted unique series of elements/chunks:
    // For chunk size = 1 is a simple element, else use a flat representation of a tensor obj
    // Go through the input elements/chunks one by one with inline processing of indices..
    // INPUT x: [2, 1, 1, 3, 4, 3], attr_sorted = 0;
    // OUTPUT(s): indices..
    // y_indices: [0, 1, 3, 4] --- first incidence, in terms of index in sequence x
    // x_rev_indices: [0, 1, 1, 2, 3, 2] --- x seen in terms of indices of unique sequence y
    // y_count: [1, 2, 2, 1] -- count at each y_index. sum = len(x)
    // NOTE: y [2, 1, 3, 4] --- the unique output is constructed from x[y_indices[...]]
    // Output data structures: y_indices, x_rev_indices, y_count are processed inline.
    //
    // Returns the tuple (y_indices, x_rev_indices, y_count).
    template <class T>
    auto unsorted_uniq_indices(const T& input_data, size_t chunk_sz) const
    {
        auto idx_less_fn = make_idx_less_fn(input_data, chunk_sz);
        // Maps the flat offset of a chunk to its position in the unique sequence y.
        std::map<size_t, size_t, decltype(idx_less_fn)> uniq_val_map(idx_less_fn);

        // rv is used for NRVO below..
        std::tuple<std::vector<std::size_t>, std::vector<std::size_t>, std::vector<std::size_t>> rv;
        auto& [y_indices, x_rev_indices, y_count] = rv;

        // go through all the elements and add the unique elements into the map..
        // inline processing for outputs: y_indices, x_rev_indices, y_count
        size_t count_x = input_data.size();
        for(size_t f_idx = 0; f_idx < count_x; f_idx += chunk_sz)
        {
            auto [itr, added_new] = uniq_val_map.insert({f_idx, y_indices.size()});
            if(added_new)
            {
                y_count.push_back(0);
                // x_rev_indices.size() is the index of the chunk being processed.
                y_indices.push_back(x_rev_indices.size());
            }
            y_count[itr->second]++;
            x_rev_indices.push_back(itr->second);
        }
        return rv;
    }

    // Axis. Default: none. Range: [-rank, rank-1]
    std::optional<int64_t> axis;
    // Sorted, Default: 1= sorted. 0 = unsorted.
    bool sorted = true;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axis, "axis"), f(self.sorted, "sorted"));
    }

    std::string name() const { return "unique"; }

    // Computes a tuple of four dynamic shapes: y, y_indices, x_rev_indices, y_count.
    // The leading dimension ranges from 1 to the largest possible unique count.
    // Throws when an axis other than 0 (after tuning) is requested.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);

        auto& sh_x         = inputs[0];
        auto lens_x        = sh_x.lens();
        size_t dim_x       = sh_x.ndim();
        size_t max_uniq_ct = sh_x.elements();
        std::vector<shape::dynamic_dimension> d_out;

        if(axis)
        {
            int64_t t_axis = migraphx::tune_axis(dim_x, *axis, name());
            if(t_axis != 0)
                MIGRAPHX_THROW("Unique: Only supports axis = 0 or None");

            d_out = sh_x.to_dynamic().dyn_dims();
            // only axis = 0 is supported:
            max_uniq_ct = lens_x[0];
            // min = 1 unique element; max = full dimension along axis 0
            d_out[0] = {1, max_uniq_ct};
        }
        else
        {
            d_out.push_back({1, max_uniq_ct});
        }

        shape sh_y = {sh_x.type(), d_out};
        // The three outputted Indices are just 1-D:
        shape sh_idx{shape::int64_type, {d_out[0]}};

        return {{sh_y, sh_idx, sh_idx, sh_idx}};
    }

    // Evaluates the operator on args.front() and returns the four results
    // (y, y_indices, x_rev_indices, y_count) packed into one argument.
    // The results are first created with shapes of ct_x elements (the upper
    // bound) and reshaped to the actual sizes once the unique count is known.
    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
    {
        auto sh_x          = args.front().get_shape();
        auto lens_x        = sh_x.lens();
        shape output_shape = dyn_out.computed_shape;
        auto vec_ss        = output_shape.sub_shapes();
        auto ct_x          = sh_x.elements();
        shape sh_y         = {vec_ss[0].type(), {ct_x}};
        shape sh_idx       = {vec_ss[1].type(), {ct_x}};
        shape sh_x_idx     = {vec_ss[1].type(), {ct_x}};

        argument res_y{sh_y};
        argument res_y_idx{sh_idx};
        argument res_x_rev_idx{sh_idx};
        argument res_y_ct_idx{sh_idx};

        std::vector<size_t> out_y_idx;
        std::vector<size_t> out_x_rev_idx;
        std::vector<size_t> out_y_ct;

        // If axis is not none, for >1D tensors, we have to consider
        // then, the uniqueness of chunks of sub-tensors: a subsequence of built-ins..
        // For a built-in type, chunk_sz is of course = 1
        size_t chunk_sz = 1;
        // Guard the division: a zero-length leading dimension would divide by zero.
        if(axis and lens_x[0] != 0)
            chunk_sz = ct_x / lens_x[0]; // axis = 0 is supported.

        visit_all(args.front(), res_y)([&](auto x, auto y_flat) {
            using o_type = typename decltype(x)::value_type;
            // Flat copy of the input so chunks can be addressed by offset.
            std::vector<o_type> x_in(x.begin(), x.end());

            std::tie(out_y_idx, out_x_rev_idx, out_y_ct) =
                sorted ? sorted_uniq_indices(x_in, chunk_sz)
                       : unsorted_uniq_indices(x_in, chunk_sz);

            const auto uniq_ct = out_y_idx.size();

            // construct y from x[indices] in flattened form
            // later we reshape y to the final shape..
            auto y_dst = y_flat.begin();
            for(size_t idx = 0; idx < uniq_ct; idx++)
                // Qualified call: do not depend on ADL finding copy_n.
                y_dst = std::copy_n(x_in.begin() + out_y_idx[idx] * chunk_sz, chunk_sz, y_dst);

            std::vector<size_t> lens_y;
            // if axis is specified:
            // the output shape keeps the n-1 dimensions of x
            if(axis)
            {
                lens_y    = lens_x;
                lens_y[0] = uniq_ct;
            }
            else
            {
                lens_y = {uniq_ct};
            }

            sh_y   = {sh_y.type(), lens_y};
            sh_idx = {sh_idx.type(), {uniq_ct}};
        });

        visit_all(res_y_idx, res_x_rev_idx, res_y_ct_idx)(
            [&](auto y_indices, auto x_rev_indices, auto y_count) {
                std::copy(out_y_idx.begin(), out_y_idx.end(), y_indices.begin());
                std::copy(out_x_rev_idx.begin(), out_x_rev_idx.end(), x_rev_indices.begin());
                std::copy(out_y_ct.begin(), out_y_ct.end(), y_count.begin());
                // x_rev_indices has one entry per input element/chunk, not per unique value.
                sh_x_idx = {sh_idx.type(), {out_x_rev_idx.size()}};
            });

        return {{res_y.reshape(sh_y),
                 res_y_idx.reshape(sh_idx),
                 res_x_rev_idx.reshape(sh_x_idx),
                 res_y_ct_idx.reshape(sh_idx)}};
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment