Commit 538dbd75 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Merge branch 'develop' into resize_op

parents c7161d99 e3e00547
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the 'Software'), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from argparse import ArgumentParser
from diffusers import EulerDiscreteScheduler
from transformers import CLIPTokenizer
from PIL import Image
import migraphx as mgx
import numpy as np
import os
import torch
import time
from functools import wraps
# measurement helper
def measure(fn):
@wraps(fn)
def measure_ms(*args, **kwargs):
start_time = time.perf_counter_ns()
result = fn(*args, **kwargs)
end_time = time.perf_counter_ns()
print(f"Elapsed time: {(end_time - start_time) * 1e-6:.4f} ms\n")
return result
return measure_ms
def get_args():
parser = ArgumentParser()
parser.add_argument(
"-s",
"--seed",
type=int,
default=42,
help="Random seed",
)
parser.add_argument(
"-t",
"--steps",
type=int,
default=20,
help="Number of steps",
)
parser.add_argument(
"-p",
"--prompt",
type=str,
required=True,
help="Prompt",
)
parser.add_argument(
"-n",
"--negative-prompt",
type=str,
default="",
help="Negative prompt",
)
parser.add_argument(
"--scale",
type=float,
default=7.0,
help="Guidance scale",
)
parser.add_argument(
"-o",
"--output",
type=str,
default=None,
help="Output name",
)
return parser.parse_args()
class StableDiffusionMGX():
def __init__(self):
model_id = "stabilityai/stable-diffusion-2-1"
print(f"Using {model_id}")
print("Creating EulerDiscreteScheduler scheduler")
self.scheduler = EulerDiscreteScheduler.from_pretrained(
model_id, subfolder="scheduler")
print("Creating CLIPTokenizer tokenizer...")
self.tokenizer = CLIPTokenizer.from_pretrained(model_id,
subfolder="tokenizer")
print("Load models...")
self.vae = StableDiffusionMGX.load_mgx_model(
"vae_decoder", {"latent_sample": [1, 4, 64, 64]})
self.text_encoder = StableDiffusionMGX.load_mgx_model(
"text_encoder", {"input_ids": [1, 77]})
self.unet = StableDiffusionMGX.load_mgx_model(
"unet", {
"sample": [1, 4, 64, 64],
"encoder_hidden_states": [1, 77, 1024],
"timestep": [1],
})
def run(self, prompt, negative_prompt, steps, seed, scale):
# need to set this for each run
self.scheduler.set_timesteps(steps)
print("Tokenizing prompt...")
text_input = self.tokenize(prompt)
print("Creating text embeddings for prompt...")
text_embeddings = self.get_embeddings(text_input)
print("Tokenizing negative prompt...")
uncond_input = self.tokenize(negative_prompt)
print("Creating text embeddings for negative prompt...")
uncond_embeddings = self.get_embeddings(uncond_input)
print(
f"Creating random input data ({1}x{4}x{64}x{64}) (latents) with seed={seed}..."
)
latents = torch.randn((1, 4, 64, 64),
generator=torch.manual_seed(seed))
print("Apply initial noise sigma\n")
latents = latents * self.scheduler.init_noise_sigma
print("Running denoising loop...")
for step, t in enumerate(self.scheduler.timesteps):
print(f"#{step}/{len(self.scheduler.timesteps)} step")
latents = self.denoise_step(text_embeddings, uncond_embeddings,
latents, t, scale)
print("Scale denoised result...")
latents = 1 / 0.18215 * latents
print("Decode denoised result...")
image = self.decode(latents)
return image
@staticmethod
@measure
def load_mgx_model(name, shapes):
file = f"models/sd21-onnx/{name}/model"
print(f"Loading {name} model from {file}")
if os.path.isfile(f"{file}.mxr"):
print("Found mxr, loading it...")
model = mgx.load(f"{file}.mxr", format="msgpack")
elif os.path.isfile(f"{file}.onnx"):
print("Parsing from onnx file...")
model = mgx.parse_onnx(f"{file}.onnx", map_input_dims=shapes)
model.compile(mgx.get_target("gpu"))
print(f"Saving {name} model to mxr file...")
mgx.save(model, f"{file}.mxr", format="msgpack")
else:
print(f"No {name} model found. Please download it and re-try.")
os.exit(1)
return model
@measure
def tokenize(self, input):
return self.tokenizer([input],
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np")
@measure
def get_embeddings(self, input):
return np.array(
self.text_encoder.run(
{"input_ids":
input.input_ids.astype(np.int32)})[0]).astype(np.float32)
@staticmethod
def convert_to_rgb_image(image):
image = np.clip(image / 2 + 0.5, 0, 1)
image = np.transpose(image, (0, 2, 3, 1))
images = (image * 255).round().astype("uint8")
return Image.fromarray(images[0])
@staticmethod
def save_image(pil_image, filename="output.png"):
pil_image.save(filename)
@measure
def denoise_step(self, text_embeddings, uncond_embeddings, latents, t,
scale):
sample = self.scheduler.scale_model_input(latents,
t).numpy().astype(np.float32)
timestep = np.atleast_1d(t.numpy().astype(
np.int64)) # convert 0D -> 1D
noise_pred_uncond = np.array(
self.unet.run({
"sample": sample,
"encoder_hidden_states": uncond_embeddings,
"timestep": timestep
})[0])
noise_pred_text = np.array(
self.unet.run({
"sample": sample,
"encoder_hidden_states": text_embeddings,
"timestep": timestep
})[0])
# perform guidance
noise_pred = noise_pred_uncond + scale * (noise_pred_text -
noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
return self.scheduler.step(torch.from_numpy(noise_pred), t,
latents).prev_sample
@measure
def decode(self, latents):
return np.array(
self.vae.run({"latent_sample":
latents.numpy().astype(np.float32)})[0])
if __name__ == "__main__":
args = get_args()
sd = StableDiffusionMGX()
result = sd.run(args.prompt, args.negative_prompt, args.steps, args.seed,
args.scale)
print("Convert result to rgb image...")
image = StableDiffusionMGX.convert_to_rgb_image(result)
filename = args.output if args.output else f"output_s{args.seed}_t{args.steps}.png"
StableDiffusionMGX.save_image(image, args.output)
print(f"Image saved to {filename}")
...@@ -29,4 +29,4 @@ pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build ...@@ -29,4 +29,4 @@ pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/rocMLIR@13f6c2a69cfe80a575c6b241ec7353d1e953cb12 -DBUILD_FAT_LIBROCKCOMPILER=On ROCmSoftwarePlatform/rocMLIR@a6880f1e6daec99876cd6a4820fbc69c57216401 -DBUILD_FAT_LIBROCKCOMPILER=On
...@@ -28,6 +28,7 @@ include(ROCMInstallTargets) ...@@ -28,6 +28,7 @@ include(ROCMInstallTargets)
include(ROCMPackageConfigHelpers) include(ROCMPackageConfigHelpers)
include(RegisterOp) include(RegisterOp)
include(CheckCXXLinkerFlag) include(CheckCXXLinkerFlag)
include(CheckCXXSourceCompiles)
add_library(migraphx add_library(migraphx
adjust_allocation.cpp adjust_allocation.cpp
...@@ -221,6 +222,8 @@ register_migraphx_ops( ...@@ -221,6 +222,8 @@ register_migraphx_ops(
scatternd_add scatternd_add
scatternd_mul scatternd_mul
scatternd_none scatternd_none
scatternd_max
scatternd_min
select_module select_module
sigmoid sigmoid
sign sign
...@@ -239,6 +242,7 @@ register_migraphx_ops( ...@@ -239,6 +242,7 @@ register_migraphx_ops(
transpose transpose
unary_not unary_not
undefined undefined
unique
unknown unknown
unsqueeze unsqueeze
where where
...@@ -264,6 +268,51 @@ endif() ...@@ -264,6 +268,51 @@ endif()
target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>) target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_link_libraries(migraphx PUBLIC Threads::Threads) target_link_libraries(migraphx PUBLIC Threads::Threads)
function(check_execution_par RESULT)
set(CMAKE_REQUIRED_LIBRARIES ${ARGN})
set(CMAKE_REQUIRED_FLAGS)
if(NOT MSVC)
set(CMAKE_REQUIRED_FLAGS "-std=c++17")
endif()
string(MD5 _flags_hash "${CMAKE_REQUIRED_FLAGS} ${CMAKE_REQUIRED_LIBRARIES}")
set(_source "
#include <execution>
int main() {
int* i = nullptr;
std::sort(std::execution::par, i, i);
}
")
check_cxx_source_compiles("${_source}" _has_execution_${_flags_hash})
set(${RESULT} ${_has_execution_${_flags_hash}} PARENT_SCOPE)
endfunction()
set(MIGRAPHX_HAS_EXECUTORS_DEFAULT Off)
find_package(TBB QUIET)
if(TBB_FOUND)
check_execution_par(TBB_HAS_EXECUTION_PAR TBB::tbb)
if(TBB_HAS_EXECUTION_PAR)
list(APPEND PACKAGE_DEPENDS PACKAGE TBB)
target_link_libraries(migraphx PUBLIC TBB::tbb)
set(MIGRAPHX_HAS_EXECUTORS_DEFAULT On)
message(STATUS "Using TBB for parallel execution")
endif()
else()
check_execution_par(HAS_EXECUTION_PAR)
if(HAS_EXECUTION_PAR)
set(MIGRAPHX_HAS_EXECUTORS_DEFAULT On)
endif()
endif()
option(MIGRAPHX_HAS_EXECUTORS "C++ supports parallel executors" ${MIGRAPHX_HAS_EXECUTORS_DEFAULT})
if(MIGRAPHX_HAS_EXECUTORS)
message("Parallel STL enabled")
target_compile_definitions(migraphx PUBLIC MIGRAPHX_HAS_EXECUTORS=1)
else()
message("Parallel STL disabled")
target_compile_definitions(migraphx PUBLIC MIGRAPHX_HAS_EXECUTORS=0)
endif()
find_package(nlohmann_json 3.8.0 REQUIRED) find_package(nlohmann_json 3.8.0 REQUIRED)
target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json) target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json)
migraphx_generate_export_header(migraphx) migraphx_generate_export_header(migraphx)
......
...@@ -21,10 +21,13 @@ ...@@ -21,10 +21,13 @@
* ************************************************************************ */ * ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP #define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#include <type_traits>
#if defined(__GNUC__) && !defined(__clang__) #if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing" #pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif #endif
#include <migraphx/requires.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) // NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
...@@ -32,7 +35,10 @@ ...@@ -32,7 +35,10 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <typename To, typename From> template <typename To,
typename From,
MIGRAPHX_REQUIRES(std::is_trivially_copyable<To>{} and
std::is_trivially_copyable<From>{})>
inline constexpr To bit_cast(From fr) noexcept inline constexpr To bit_cast(From fr) noexcept
{ {
static_assert(sizeof(To) == sizeof(From)); static_assert(sizeof(To) == sizeof(From));
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp> #include <migraphx/dyn_output.hpp>
#include <migraphx/par.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -95,7 +96,7 @@ struct binary : op_name<Derived> ...@@ -95,7 +96,7 @@ struct binary : op_name<Derived>
{ {
argument result{dyn_out.computed_shape}; argument result{dyn_out.computed_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) { visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
std::transform(input1.begin(), par_transform(input1.begin(),
input1.end(), input1.end(),
input2.begin(), input2.begin(),
output.begin(), output.begin(),
......
...@@ -70,7 +70,8 @@ struct pooling ...@@ -70,7 +70,8 @@ struct pooling
// 2 smaller than the input tensor rank (NCHW layout) // 2 smaller than the input tensor rank (NCHW layout)
std::vector<std::size_t> lengths = {1, 1}; std::vector<std::size_t> lengths = {1, 1};
// Dilations are not supported at this time. // Spacing between the elements of the pooling kernel. Must be the same ndim as lengths.
std::vector<std::size_t> dilations = {1, 1};
// ceiling mode is a flag affecting output size // ceiling mode is a flag affecting output size
// or equivalently, placements of the pooling kernel. // or equivalently, placements of the pooling kernel.
...@@ -99,6 +100,7 @@ struct pooling ...@@ -99,6 +100,7 @@ struct pooling
f(self.padding_mode, "padding_mode"), f(self.padding_mode, "padding_mode"),
f(self.stride, "stride"), f(self.stride, "stride"),
f(self.lengths, "lengths"), f(self.lengths, "lengths"),
f(self.dilations, "dilations"),
f(self.ceil_mode, "ceil_mode"), f(self.ceil_mode, "ceil_mode"),
f(self.lp_order, "lp_order"), f(self.lp_order, "lp_order"),
f(self.dyn_global, "dyn_global")); f(self.dyn_global, "dyn_global"));
...@@ -112,14 +114,17 @@ struct pooling ...@@ -112,14 +114,17 @@ struct pooling
return; return;
if((padding_mode != default_ and padding.size() != stride.size() and if((padding_mode != default_ and padding.size() != stride.size() and
(padding.size()) != stride.size() * 2) or (padding.size()) != stride.size() * 2) or
stride.size() != lengths.size()) stride.size() != lengths.size() or dilations.size() != lengths.size())
{ {
MIGRAPHX_THROW("POOLING: inconsistent attribute sizes"); MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
} }
if(std::any_of(lengths.begin(), lengths.end(), [&](auto i) { return (i == 0); }) or
std::any_of(stride.begin(), stride.end(), [&](auto i) { return (i == 0); })) const auto is_zero = [](auto el) { return el == 0; };
if(std::any_of(lengths.begin(), lengths.end(), is_zero) or
std::any_of(stride.begin(), stride.end(), is_zero) or
std::any_of(dilations.begin(), dilations.end(), is_zero))
{ {
MIGRAPHX_THROW("POOLING: size 0 pooling kernel or stride"); MIGRAPHX_THROW("POOLING: size 0 pooling kernel or stride or dilations");
} }
// TODO: update lowering to run the reference // TODO: update lowering to run the reference
...@@ -142,6 +147,11 @@ struct pooling ...@@ -142,6 +147,11 @@ struct pooling
value attributes() const { return {{"normalize_padding", "padding"}}; } value attributes() const { return {{"normalize_padding", "padding"}}; }
inline std::size_t dilate_dim(std::size_t dim, std::size_t dilation) const
{
return 1 + dilation * (dim - 1);
}
std::vector<std::size_t> calc_spatial_dim_out(const std::vector<std::size_t>& input_lens, std::vector<std::size_t> calc_spatial_dim_out(const std::vector<std::size_t>& input_lens,
std::size_t kdims) const std::size_t kdims) const
{ {
...@@ -151,8 +161,9 @@ struct pooling ...@@ -151,8 +161,9 @@ struct pooling
std::size_t padding_factor = 2 * padding[i]; std::size_t padding_factor = 2 * padding[i];
if(padding.size() == 2 * kdims) if(padding.size() == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims]; padding_factor = padding[i] + padding[i + kdims];
std::size_t dilated_length = dilate_dim(lengths[i], dilations[i]);
std::size_t dim_size; std::size_t dim_size;
if(input_lens[i + 2] + padding_factor < lengths[i]) if(input_lens[i + 2] + padding_factor < dilated_length)
{ {
if(padding_mode == default_) if(padding_mode == default_)
MIGRAPHX_THROW("POOLING: not enough padding for the given kernel size"); MIGRAPHX_THROW("POOLING: not enough padding for the given kernel size");
...@@ -162,7 +173,7 @@ struct pooling ...@@ -162,7 +173,7 @@ struct pooling
} }
else else
{ {
dim_size = input_lens[i + 2] + padding_factor - lengths[i]; dim_size = input_lens[i + 2] + padding_factor - dilated_length;
} }
std::size_t len = std::size_t len =
(ceil_mode) (ceil_mode)
...@@ -331,6 +342,7 @@ struct pooling ...@@ -331,6 +342,7 @@ struct pooling
int start = static_cast<int>(idx_o[dim] * stride[d_2]) - int start = static_cast<int>(idx_o[dim] * stride[d_2]) -
static_cast<int>(padding_vals[d_2]); static_cast<int>(padding_vals[d_2]);
int end; int end;
std::size_t dilated_kernel_dim = dilate_dim(kernel_dims[d_2], dilations[d_2]);
// NOLINT // NOLINT
if(count_include_pad and ceil_mode and (mode != pooling_mode::max)) if(count_include_pad and ceil_mode and (mode != pooling_mode::max))
{ {
...@@ -340,15 +352,14 @@ struct pooling ...@@ -340,15 +352,14 @@ struct pooling
// padding. Clip out-of-bounds indexes but not padding. // padding. Clip out-of-bounds indexes but not padding.
// Check if this kernel extends beyond the padding at end of dimension // Check if this kernel extends beyond the padding at end of dimension
end = std::min(start + kernel_dims[d_2], end = std::min(start + dilated_kernel_dim,
in_lens[dim] + static_cast<int>(padding_vals[d_2])); in_lens[dim] + static_cast<int>(padding_vals[d_2]));
} }
else else
{ {
// In non-ceiling mode, when // In non-ceiling mode, when
// count_include_pad is false, or for max pooling, clip off padding. // count_include_pad is false, or for max pooling, clip off padding.
end = std::min(start + kernel_dims[d_2], in_lens[dim]); end = std::min(start + dilated_kernel_dim, in_lens[dim]);
start = std::max(start, 0);
} }
win_start.push_back(start); win_start.push_back(start);
if(end < start) if(end < start)
...@@ -366,6 +377,16 @@ struct pooling ...@@ -366,6 +377,16 @@ struct pooling
// for each element in the window... // for each element in the window...
shape_for_each(win_shape, [&](const auto& idx_w) { shape_for_each(win_shape, [&](const auto& idx_w) {
// Skip elements that belong to the dilated area
for(size_t axis = 0; axis < idx_w.size(); ++axis)
{
if(idx_w[axis] % dilations[axis])
{
pool_size -= 1;
return;
}
}
// the coordinates of this element // the coordinates of this element
auto idx = idx_o; auto idx = idx_o;
...@@ -390,7 +411,15 @@ struct pooling ...@@ -390,7 +411,15 @@ struct pooling
// this is a padding element. Padding locations // this is a padding element. Padding locations
// don't contribute to average or max pooling total but can play in // don't contribute to average or max pooling total but can play in
// lpnorm pooling. // lpnorm pooling.
output_val = op(output_val, 0); if(mode == pooling_mode::lpnorm)
{
output_val = op(output_val, op.template init<Type>());
}
if(mode == pooling_mode::average)
{
// Ignore padding
pool_size -= 1;
}
} }
}); });
output[i] = Type(op.final(output_val, pool_size)); output[i] = Type(op.final(output_val, pool_size));
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,23 +21,26 @@ ...@@ -21,23 +21,26 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_GATHER_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_GATHER_HPP #define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MAX_HPP
#include <migraphx/argument.hpp> #include <migraphx/op/scatternd_op.hpp>
#include <migraphx/gpu/device/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace op {
namespace device {
argument MIGRAPHX_DEVICE_EXPORT struct scatternd_max : scatternd_op<scatternd_max>
gather(hipStream_t stream, argument result, argument arg1, argument arg2, int64_t axis); {
scatternd_max() {}
} // namespace device auto reduction() const
} // namespace gpu {
return [](auto& x, const auto& y) { x = std::max(x, y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,25 +21,27 @@ ...@@ -21,25 +21,27 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/gpu/gather.hpp> #ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MIN_HPP
#include <migraphx/gpu/context.hpp> #define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MIN_HPP
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace op {
shape hip_gather::compute_shape(std::vector<shape> inputs) const struct scatternd_min : scatternd_op<scatternd_min>
{ {
inputs.pop_back(); scatternd_min() {}
return op.normalize_compute_shape(inputs);
}
argument hip_gather::compute(context& ctx, const shape&, const std::vector<argument>& args) const auto reduction() const
{ {
return device::gather(ctx.get_stream().get(), args.back(), args[0], args[1], op.axis); return [](auto& x, const auto& y) { x = std::min(x, y); };
} }
};
} // namespace gpu } // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif
...@@ -121,7 +121,8 @@ struct scatternd_op : op_name<Derived> ...@@ -121,7 +121,8 @@ struct scatternd_op : op_name<Derived>
auto k = indices_shape.lens().back(); auto k = indices_shape.lens().back();
auto q = indices_shape.ndim(); auto q = indices_shape.ndim();
auto r = dyn_out.computed_shape.ndim(); auto r = dyn_out.computed_shape.ndim();
par_for(updates_shape.elements(), [&](const auto i) { for(auto i = 0u; i < updates_shape.elements(); ++i)
{
auto updates_idx = updates_std.multi(i); auto updates_idx = updates_std.multi(i);
std::vector<std::size_t> indices_idx(q, 0); std::vector<std::size_t> indices_idx(q, 0);
std::copy( std::copy(
...@@ -135,7 +136,7 @@ struct scatternd_op : op_name<Derived> ...@@ -135,7 +136,7 @@ struct scatternd_op : op_name<Derived>
std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k); std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
self.reduction()(output[dyn_out.computed_shape.index(out_idx)], updates[i]); self.reduction()(output[dyn_out.computed_shape.index(out_idx)], updates[i]);
}); }
}); });
}); });
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp> #include <migraphx/dyn_output.hpp>
#include <migraphx/par.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -84,7 +85,7 @@ struct unary : op_name<Derived> ...@@ -84,7 +85,7 @@ struct unary : op_name<Derived>
argument result{dyn_out.computed_shape}; argument result{dyn_out.computed_shape};
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
std::transform(input.begin(), par_transform(input.begin(),
input.end(), input.end(),
output.begin(), output.begin(),
static_cast<const Derived&>(*this).apply()); static_cast<const Derived&>(*this).apply());
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_UNIQUE_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNIQUE_HPP
#include <migraphx/shape_for_each.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/tune_axis.hpp>
#include <utility>
#include <map>
#include <limits>
#include <optional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// https://onnx.ai/onnx/operators/onnx__Unique.html
// The Onnx spec refers to numpy specification, used as a reference:
// https://numpy.org/doc/stable/reference/generated/numpy.unique.html
// Input : Given an array of elements : X.
// Output(s) :
// 1. Find the unique elements (Y) of input (X).
//
// There are three outputs in addition to the unique elements in Y:
// 2. the indices of the input array that give the unique values
// 3. the indices of the unique array that reconstruct the input array
// 4. the number of times each unique value comes up in the input array
// Optional Attribute: 'Sorted' = 1 for sorted; = 0 for unsorted.
// Onnx specification makes 'sorted' a default, while Numpy always sorts.
//
// Optional Attribute: 'Axis' is 'None' (default) or a valid int < rank(X).
// Negative values are allowed.
//
// Numpy has the following important note on Axis:
// ------------------------------------------------------------------
// When an axis is specified the subarrays indexed by the axis are
// sorted. This is done by making the specified axis the first
// dimension of the array (move the axis to the first dimension to
// keep the order of the other axes) and then flattening the subarrays
// in C order. The flattened subarrays are then viewed as a structured
// type with each element given a label, with the effect that we end
// up with a 1-D array of structured types that can be treated in the
// same way as any other 1-D array. The result is that the flattened
// subarrays are sorted in lexicographic order starting with the first
// element.
// ------------------------------------------------------------------
struct unique
{
template <class T>
auto make_idx_less_fn(const T& data, size_t chunk_sz) const
{
return [&data, chunk_sz](auto idx1, auto idx2) {
return std::lexicographical_compare(data.begin() + idx1,
data.begin() + idx1 + chunk_sz,
data.begin() + idx2,
data.begin() + idx2 + chunk_sz);
};
}
// CASE SORTED:
//
// To process into a sorted unique series of elements/chunks:
// Chunk size == 1 means a simple element; >1 means a flat representation.
// Steps: first go through the input elements/chunks for uniqueness.
// At the end of this processing, per the sorted sequence of unique elements:
// update/create data structures: y, y_indices, x_rev_indices, y_count
//
// INPUT x: [2, 1, 1, 3, 4, 3], attr_sorted = 1;
// OUTPUT(s): indices..
// y_indices: [1, 0, 3, 4] --- first incidence, in terms of index in sequence x
// x_rev_indices: [1, 0, 0, 2, 3, 2] --- x seen in terms of indices of unique sequence y
// y_count: [2, 1, 2, 1] -- count at each y_index. sum = len(x)
// NOTE: y [1, 2, 3, 4] --- the unique output is constructed from x[y_indices[...]]
template <class T>
auto sorted_uniq_indices(const T& input_data, size_t chunk_sz) const
{
struct y_info
{
size_t y_idx;
size_t x_idx;
size_t ct = 0;
};
auto idx_less_fn = make_idx_less_fn(input_data, chunk_sz);
std::map<size_t, y_info, decltype(idx_less_fn)> uniq_val_map(idx_less_fn);
std::tuple<std::vector<std::size_t>, std::vector<std::size_t>, std::vector<std::size_t>> rv;
auto& [y_indices, x_rev_indices, y_count] = rv;
// go through all the elements and find the unique elements..
size_t count_x = input_data.size();
for(size_t f_idx = 0, x_idx = 0; f_idx < count_x; f_idx += chunk_sz, x_idx++)
{
y_info entry = {.y_idx = uniq_val_map.size(), .x_idx = x_idx};
auto [itr, added_new] = uniq_val_map.insert({f_idx, entry});
itr->second.ct++;
x_rev_indices.push_back(itr->second.y_idx);
}
std::vector<std::size_t> y2x_indices(uniq_val_map.size());
y_indices.resize(uniq_val_map.size());
y_count.resize(uniq_val_map.size());
size_t idx = 0;
// the unique elements are now sorted:
// post-processing for all the return indices.
for(const auto& v : uniq_val_map)
{
y2x_indices[v.second.y_idx] = idx;
y_indices[idx] = v.second.x_idx;
y_count[idx] = v.second.ct;
idx++;
}
// update x_rev_indices as per the sorted order of y_indices
for(auto& i : x_rev_indices)
i = y2x_indices[i];
return rv;
}
// CASE UNSORTED:
//
// To process into an un-sorted unique series of elements/chunks:
// For chunk size = 1 is a simple element, else use a flat representation of a tensor obj
// Go through the input elements/chunks one by one with inline processing of indices..
// INPUT x: [2, 1, 1, 3, 4, 3], attr_sorted = 0;
// OUTPUT(s): indices..
// y_indices: [0, 1, 3, 4] --- first incidence, in terms of index in sequence x
// x_rev_indices: [0, 1, 1, 2, 3, 2] --- x seen in terms of indices of unique sequence y
// y_count: [1, 2, 2, 1] -- count at each y_index. sum = len(x)
// NOTE: y [2, 1, 3, 4] --- the unique output is constructed from x[y_indices[...]]
// Output data structures: y_indices, x_rev_indices, y_count are processed inline.
template <class T>
auto unsorted_uniq_indices(const T& input_data, size_t chunk_sz) const
{
auto idx_less_fn = make_idx_less_fn(input_data, chunk_sz);
std::map<size_t, size_t, decltype(idx_less_fn)> uniq_val_map(idx_less_fn);
// rv is used for NVRO below..
std::tuple<std::vector<std::size_t>, std::vector<std::size_t>, std::vector<std::size_t>> rv;
auto& [y_indices, x_rev_indices, y_count] = rv;
// go through all the elements and add the unique elements into the map..
// inline processing for outputs: y_indices, x_rev_indices, y_count
size_t count_x = input_data.size();
for(size_t f_idx = 0; f_idx < count_x; f_idx += chunk_sz)
{
auto [itr, added_new] = uniq_val_map.insert({f_idx, y_indices.size()});
if(added_new)
{
y_count.push_back(0);
y_indices.push_back(x_rev_indices.size());
}
y_count[itr->second]++;
x_rev_indices.push_back(itr->second);
}
return rv;
}
// Axis. Default: none. Range: [-rank, rank-1]
std::optional<int64_t> axis;
// Sorted, Default: 1= sorted. 0 = unsorted.
bool sorted = true;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"), f(self.sorted, "sorted"));
}
std::string name() const { return "unique"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto& sh_x = inputs[0];
auto lens_x = sh_x.lens();
size_t dim_x = sh_x.ndim();
size_t max_uniq_ct = sh_x.elements();
std::vector<shape::dynamic_dimension> d_out;
if(axis)
{
int64_t t_axis = migraphx::tune_axis(dim_x, *axis, name());
if(t_axis != 0)
MIGRAPHX_THROW("Unique: Only supports axis = 0 or None");
d_out = sh_x.to_dynamic().dyn_dims();
// only axis = 0 is supported:
max_uniq_ct = lens_x[0];
// min = 1 unique element; max = full dimension along axis 0
d_out[0] = {1, max_uniq_ct};
}
else
{
d_out.push_back({1, max_uniq_ct});
}
shape sh_y = {sh_x.type(), d_out};
// The three outputted Indices are just 1-D:
shape sh_idx{shape::int64_type, {d_out[0]}};
return {{sh_y, sh_idx, sh_idx, sh_idx}};
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
auto sh_x = args.front().get_shape();
auto lens_x = sh_x.lens();
shape output_shape = dyn_out.computed_shape;
auto vec_ss = output_shape.sub_shapes();
auto ct_x = sh_x.elements();
shape sh_y = {vec_ss[0].type(), {ct_x}};
shape sh_idx = {vec_ss[1].type(), {ct_x}};
shape sh_x_idx = {vec_ss[1].type(), {ct_x}};
argument res_y{sh_y};
argument res_y_idx{sh_idx};
argument res_x_rev_idx{sh_idx};
argument res_y_ct_idx{sh_idx};
std::vector<size_t> out_y_idx;
std::vector<size_t> out_x_rev_idx;
std::vector<size_t> out_y_ct;
// If axis is not none, for >1D tensors, we have to consider
// then, the uniqueness of chunks of sub-tensors: a subsequence of built-ins..
// For a built-in type, chunk_sz is of course = 1
size_t chunk_sz = 1;
if(axis)
chunk_sz = ct_x / lens_x[0]; // axis = 0 is supported.
visit_all(args.front(), res_y)([&](auto x, auto y_flat) {
using o_type = typename decltype(x)::value_type;
std::vector<o_type> x_in(x.begin(), x.end());
std::tie(out_y_idx, out_x_rev_idx, out_y_ct) =
sorted ? sorted_uniq_indices(x_in, chunk_sz)
: unsorted_uniq_indices(x_in, chunk_sz);
const auto uniq_ct = out_y_idx.size();
// construct y from x[indices] in flattened form
// later we reshape y to the final shape..
auto y_dst = y_flat.begin();
for(size_t idx = 0; idx < uniq_ct; idx++)
y_dst = copy_n(x_in.begin() + out_y_idx[idx] * chunk_sz, chunk_sz, y_dst);
std::vector<size_t> lens_y;
// if axis is specified:
// the output shape keeps the n-1 dimensions of x
if(axis)
{
lens_y = lens_x;
lens_y[0] = uniq_ct;
}
else
{
lens_y = {uniq_ct};
}
sh_y = {sh_y.type(), lens_y};
sh_idx = {sh_idx.type(), {uniq_ct}};
});
visit_all(res_y_idx, res_x_rev_idx, res_y_ct_idx)(
[&](auto y_indices, auto x_rev_indices, auto y_count) {
std::copy(out_y_idx.begin(), out_y_idx.end(), y_indices.begin());
std::copy(out_x_rev_idx.begin(), out_x_rev_idx.end(), x_rev_indices.begin());
std::copy(out_y_ct.begin(), out_y_ct.end(), y_count.begin());
sh_x_idx = {sh_idx.type(), {out_x_rev_idx.size()}};
});
return {{res_y.reshape(sh_y),
res_y_idx.reshape(sh_idx),
res_x_rev_idx.reshape(sh_x_idx),
res_y_ct_idx.reshape(sh_idx)}};
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -119,6 +119,8 @@ ...@@ -119,6 +119,8 @@
#include <migraphx/op/scatternd_add.hpp> #include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_none.hpp> #include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_mul.hpp> #include <migraphx/op/scatternd_mul.hpp>
#include <migraphx/op/scatternd_max.hpp>
#include <migraphx/op/scatternd_min.hpp>
#include <migraphx/op/sigmoid.hpp> #include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp> #include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp> #include <migraphx/op/sinh.hpp>
...@@ -137,6 +139,7 @@ ...@@ -137,6 +139,7 @@
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/op/unary_not.hpp> #include <migraphx/op/unary_not.hpp>
#include <migraphx/op/undefined.hpp> #include <migraphx/op/undefined.hpp>
#include <migraphx/op/unique.hpp>
#include <migraphx/op/unknown.hpp> #include <migraphx/op/unknown.hpp>
#include <migraphx/op/unsqueeze.hpp> #include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/where.hpp> #include <migraphx/op/where.hpp>
......
...@@ -21,46 +21,113 @@ ...@@ -21,46 +21,113 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/shape.hpp> #ifndef MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
#include <migraphx/argument.hpp> #define MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
#include <migraphx/clamp.hpp>
#include <migraphx/gpu/device/nary.hpp> #include <migraphx/config.hpp>
#include <migraphx/gpu/device/pad.hpp> #if MIGRAPHX_HAS_EXECUTORS
#include <migraphx/gpu/device/tensor.hpp> #include <execution>
#include <migraphx/gpu/device/launch.hpp> #else
#include <migraphx/float_equal.hpp> #include <migraphx/simple_par_for.hpp>
#endif
#include <algorithm>
#include <mutex>
#include <vector>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument namespace detail {
pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads)
{
std::size_t nelements = arg1.get_shape().elements();
hip_visit_all(result, arg1)([&](auto output, auto input) {
using type = typename decltype(output)::value_type;
using hip_index = typename decltype(output)::hip_index;
type device_val = pad_clamp<host_type<type>>(value);
gs_launch(stream, result.get_shape().elements())(
[=](auto i) __device__ { output.data()[i] = device_val; });
hip_index offsets; struct exception_list
std::copy(pads.begin(), pads.begin() + offsets.size(), offsets.begin()); {
gs_launch(stream, nelements)([=](auto i) __device__ { std::vector<std::exception_ptr> exceptions;
auto idx = input.get_shape().multi(i); std::mutex m;
for(std::size_t j = 0; j < offsets.size(); j++) void add_exception()
{ {
idx[j] += offsets[j]; std::lock_guard<std::mutex> guard(m);
exceptions.push_back(std::current_exception());
} }
output[idx] = input.data()[i]; template <class F>
}); auto collect(F f)
}); {
return result; return [f, this](auto&&... xs) {
try
{
f(std::forward<decltype(xs)>(xs)...);
}
catch(...)
{
this->add_exception();
}
};
}
void throw_if_exception() const
{
if(not exceptions.empty())
std::rethrow_exception(exceptions.front());
}
};
} // namespace detail
template <class InputIt, class OutputIt, class UnaryOperation>
OutputIt par_transform(InputIt first1, InputIt last1, OutputIt d_first, UnaryOperation unary_op)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::transform(std::execution::par, first1, last1, d_first, std::move(unary_op));
#else
simple_par_for(last1 - first1, [&](auto i) { d_first[i] = unary_op(first1[i]); });
return d_first + (last1 - first1);
#endif
}
template <class InputIt1, class InputIt2, class OutputIt, class BinaryOperation>
OutputIt par_transform(
InputIt1 first1, InputIt1 last1, InputIt2 first2, OutputIt d_first, BinaryOperation binary_op)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::transform(
std::execution::par, first1, last1, first2, d_first, std::move(binary_op));
#else
simple_par_for(last1 - first1, [&](auto i) { d_first[i] = binary_op(first1[i], first2[i]); });
return d_first + (last1 - first1);
#endif
}
template <class InputIt, class UnaryFunction>
void par_for_each(InputIt first, InputIt last, UnaryFunction f)
{
#if MIGRAPHX_HAS_EXECUTORS
// Propagate the exception
detail::exception_list ex;
std::for_each(std::execution::par, first, last, ex.collect(std::move(f)));
ex.throw_if_exception();
#else
simple_par_for(last - first, [&](auto i) { f(first[i]); });
#endif
}
template <class... Ts>
auto par_copy_if(Ts&&... xs)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::copy_if(std::execution::par, std::forward<Ts>(xs)...);
#else
return std::copy_if(std::forward<Ts>(xs)...);
#endif
}
template <class... Ts>
auto par_sort(Ts&&... xs)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::sort(std::execution::par, std::forward<Ts>(xs)...);
#else
return std::sort(std::forward<Ts>(xs)...);
#endif
} }
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
...@@ -24,93 +24,23 @@ ...@@ -24,93 +24,23 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP #define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#include <thread> #include <migraphx/par.hpp>
#include <cmath> #include <migraphx/ranges.hpp>
#include <algorithm>
#include <vector>
#include <cassert>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct joinable_thread : std::thread
{
template <class... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...) // NOLINT
{
}
joinable_thread& operator=(joinable_thread&& other) = default;
joinable_thread(joinable_thread&& other) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
template <class F>
auto thread_invoke(std::size_t i, std::size_t tid, F f) -> decltype(f(i, tid))
{
f(i, tid);
}
template <class F>
auto thread_invoke(std::size_t i, std::size_t, F f) -> decltype(f(i))
{
f(i);
}
template <class F>
void par_for_impl(std::size_t n, std::size_t threadsize, F f)
{
if(threadsize <= 1)
{
for(std::size_t i = 0; i < n; i++)
thread_invoke(i, 0, f);
}
else
{
std::vector<joinable_thread> threads(threadsize);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0;
std::size_t tid = 0;
std::generate(threads.begin(), threads.end(), [=, &work, &tid] {
auto result = joinable_thread([=] {
std::size_t start = work;
std::size_t last = std::min(n, work + grainsize);
for(std::size_t i = start; i < last; i++)
{
thread_invoke(i, tid, f);
}
});
work += grainsize;
++tid;
return result;
});
assert(work >= n);
}
}
template <class F> template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f) void par_for(std::size_t n, F f)
{ {
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(), using iterator = basic_iota_iterator<id, std::size_t>;
n / std::max<std::size_t>(1, min_grain)); par_for_each(iterator{0, {}}, iterator{n, {}}, f);
par_for_impl(n, threadsize, f);
} }
template <class F> template <class F>
void par_for(std::size_t n, F f) void par_for(std::size_t n, std::size_t, F f)
{ {
const int min_grain = 8; par_for(n, f);
par_for(n, min_grain, f);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <string> #include <string>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -21,47 +21,99 @@ ...@@ -21,47 +21,99 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/shape.hpp> #ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLE_PAR_FOR_HPP
#include <migraphx/argument.hpp> #define MIGRAPHX_GUARD_RTGLIB_SIMPLE_PAR_FOR_HPP
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/gpu/device/tensor.hpp> #include <thread>
#include <migraphx/gpu/device/launch.hpp> #include <cmath>
#include <migraphx/gpu/device/types.hpp> #include <algorithm>
#include <vector>
#include <cassert>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int64_t axis) struct joinable_thread : std::thread
{ {
const auto& input_shape = arg1.get_shape(); template <class... Xs>
auto lens = input_shape.lens(); joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...) // NOLINT
auto axis_dim_size = lens[axis]; {
lens[axis] = arg2.get_shape().elements(); }
shape out_comp_shape{result.get_shape().type(), lens};
std::size_t nelements = result.get_shape().elements();
visit_all(result, arg1)([&](auto output, auto input_v) { joinable_thread& operator=(joinable_thread&& other) = default;
hip_visit_views(input_v, out_comp_shape)([&](auto input, auto out_comp) { joinable_thread(joinable_thread&& other) = default;
arg2.visit([&](auto indices) {
const auto* indices_ptr = device_cast(indices.data()); ~joinable_thread()
auto* output_ptr = device_cast(output.data()); {
gs_launch(stream, nelements, 256)([=](auto i) __device__ { if(this->joinable())
auto idx = out_comp.multi(i); this->join();
auto in_index = indices_ptr[idx[axis]]; }
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index; };
idx[axis] = in_index;
output_ptr[i] = input[idx]; template <class F>
}); auto thread_invoke(std::size_t i, std::size_t tid, F f) -> decltype(f(i, tid))
}); {
f(i, tid);
}
template <class F>
auto thread_invoke(std::size_t i, std::size_t, F f) -> decltype(f(i))
{
f(i);
}
template <class F>
void simple_par_for_impl(std::size_t n, std::size_t threadsize, F f)
{
if(threadsize <= 1)
{
for(std::size_t i = 0; i < n; i++)
thread_invoke(i, 0, f);
}
else
{
std::vector<joinable_thread> threads(threadsize);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0;
std::size_t tid = 0;
std::generate(threads.begin(), threads.end(), [=, &work, &tid] {
auto result = joinable_thread([=] {
std::size_t start = work;
std::size_t last = std::min(n, work + grainsize);
for(std::size_t i = start; i < last; i++)
{
thread_invoke(i, tid, f);
}
}); });
work += grainsize;
++tid;
return result;
}); });
assert(work >= n);
}
}
return result; template <class F>
void simple_par_for(std::size_t n, std::size_t min_grain, F f)
{
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
n / std::max<std::size_t>(1, min_grain));
simple_par_for_impl(n, threadsize, f);
}
template <class F>
void simple_par_for(std::size_t n, F f)
{
const int min_grain = 8;
simple_par_for(n, min_grain, f);
} }
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -24,21 +24,21 @@ ...@@ -24,21 +24,21 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#define MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP #define MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#include <utility>
#include <cstdint>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
inline int tune_axis(const int n_dim, const int axis, const std::string& op_name = "OPERATOR") inline int tune_axis(int n_dim, int axis, const std::string& op_name = "OPERATOR")
{ {
if(axis >= n_dim or std::abs(axis) > n_dim) if(axis < 0)
{ axis += n_dim;
if(axis < 0 or axis >= n_dim)
MIGRAPHX_THROW(to_upper(op_name) + ": axis is out of range."); MIGRAPHX_THROW(to_upper(op_name) + ": axis is out of range.");
}
return (axis < 0) ? axis + n_dim : axis; return axis;
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,27 +21,26 @@ ...@@ -21,27 +21,26 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP #include <migraphx/config.hpp>
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP #include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/argument.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/gpu/device/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace onnx {
namespace device {
value handle_pooling_values(const op_desc& opd,
onnx_parser::node_info info,
const shape& in_shape,
value values);
argument MIGRAPHX_DEVICE_EXPORT pad(hipStream_t stream, instruction_ref add_pooling_op(const op_desc& opd, onnx_parser::node_info info, instruction_ref l0);
argument result,
argument arg1,
float value,
std::vector<std::int64_t> pads);
} // namespace device } // namespace onnx
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
// //
// Copyright (c) ONNX Project Contributors. // SPDX-License-Identifier: Apache-2.0
// Licensed under the MIT license.
syntax = "proto2"; syntax = "proto2";
...@@ -27,13 +27,6 @@ package onnx_for_migraphx; ...@@ -27,13 +27,6 @@ package onnx_for_migraphx;
// Notes // Notes
// //
// Release
//
// We are still in the very early stage of defining ONNX. The current
// version of ONNX is a starting point. While we are actively working
// towards a complete spec, we would like to get the community involved
// by sharing our working version of ONNX.
//
// Protobuf compatibility // Protobuf compatibility
// //
// To simplify framework compatibility, ONNX is defined using the subset of protobuf // To simplify framework compatibility, ONNX is defined using the subset of protobuf
...@@ -92,15 +85,28 @@ enum Version { ...@@ -92,15 +85,28 @@ enum Version {
// - Add sparse initializers // - Add sparse initializers
IR_VERSION_2019_9_19 = 0x0000000000000006; IR_VERSION_2019_9_19 = 0x0000000000000006;
// IR VERSION 7 published on <TBD> // IR VERSION 7 published on May 8, 2020
// - Add support to allow function body graph to rely on multiple external opreator sets.
// - Add a list to promote inference graph's initializers to global and // - Add a list to promote inference graph's initializers to global and
// mutable variables. Global variables are visible in all graphs of the // mutable variables. Global variables are visible in all graphs of the
// stored models. // stored models.
// - Add message TrainingInfoProto to store initialization // - Add message TrainingInfoProto to store initialization
// method and training algorithm. The execution of TrainingInfoProto // method and training algorithm. The execution of TrainingInfoProto
// can modify the values of mutable variables. // can modify the values of mutable variables.
// - Make inference graph callable from TrainingInfoProto via GraphCall operator. // - Implicitly add inference graph into each TrainingInfoProto's algorithm.
IR_VERSION = 0x0000000000000007; IR_VERSION_2020_5_8 = 0x0000000000000007;
// IR VERSION 8 published on July 30, 2021
// Introduce TypeProto.SparseTensor
// Introduce TypeProto.Optional
// Added a list of FunctionProtos local to the model
// Deprecated since_version and operator status from FunctionProto
IR_VERSION_2021_7_30 = 0x0000000000000008;
// IR VERSION 9 published on TBD
// Added AttributeProto to FunctionProto so that default attribute values can be set.
// Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
IR_VERSION = 0x0000000000000009;
} }
// Attributes // Attributes
...@@ -121,6 +127,7 @@ message AttributeProto { ...@@ -121,6 +127,7 @@ message AttributeProto {
TENSOR = 4; TENSOR = 4;
GRAPH = 5; GRAPH = 5;
SPARSE_TENSOR = 11; SPARSE_TENSOR = 11;
TYPE_PROTO = 13;
FLOATS = 6; FLOATS = 6;
INTS = 7; INTS = 7;
...@@ -128,6 +135,7 @@ message AttributeProto { ...@@ -128,6 +135,7 @@ message AttributeProto {
TENSORS = 9; TENSORS = 9;
GRAPHS = 10; GRAPHS = 10;
SPARSE_TENSORS = 12; SPARSE_TENSORS = 12;
TYPE_PROTOS = 14;
} }
// The name field MUST be present for this version of the IR. // The name field MUST be present for this version of the IR.
...@@ -159,6 +167,7 @@ message AttributeProto { ...@@ -159,6 +167,7 @@ message AttributeProto {
optional SparseTensorProto sparse_tensor = 22; // sparse tensor value optional SparseTensorProto sparse_tensor = 22; // sparse tensor value
// Do not use field below, it's deprecated. // Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph // optional ValueProto v = 12; // value - subsumes everything but graph
optional TypeProto tp = 14; // type proto
repeated float floats = 7; // list of floats repeated float floats = 7; // list of floats
repeated int64 ints = 8; // list of ints repeated int64 ints = 8; // list of ints
...@@ -166,6 +175,7 @@ message AttributeProto { ...@@ -166,6 +175,7 @@ message AttributeProto {
repeated TensorProto tensors = 10; // list of tensors repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph repeated GraphProto graphs = 11; // list of graph
repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
repeated TypeProto type_protos = 15;// list of type protos
} }
// Defines information on value, including the name, the type, and // Defines information on value, including the name, the type, and
...@@ -211,7 +221,7 @@ message NodeProto { ...@@ -211,7 +221,7 @@ message NodeProto {
// TrainingInfoProto stores information for training a model. // TrainingInfoProto stores information for training a model.
// In particular, this defines two functionalities: an initialization-step // In particular, this defines two functionalities: an initialization-step
// and a training-algorithm-step. Initialization resets the model // and a training-algorithm-step. Initialization resets the model
// back to its original state as if no training has been consumed. // back to its original state as if no training has been performed.
// Training algorithm improves the model based on input data. // Training algorithm improves the model based on input data.
// //
// The semantics of the initialization-step is that the initializers // The semantics of the initialization-step is that the initializers
...@@ -224,8 +234,8 @@ message NodeProto { ...@@ -224,8 +234,8 @@ message NodeProto {
// training algorithm's step. After the execution of a // training algorithm's step. After the execution of a
// TrainingInfoProto.algorithm, the initializers specified by "update_binding" // TrainingInfoProto.algorithm, the initializers specified by "update_binding"
// may be immediately updated. If the targeted training algorithm contains // may be immediately updated. If the targeted training algorithm contains
// consecutive update stages (such as block coordinate descent methods), // consecutive update steps (such as block coordinate descent methods),
// the user needs to create a TrainingInfoProto for each stage. // the user needs to create a TrainingInfoProto for each step.
message TrainingInfoProto { message TrainingInfoProto {
// This field describes a graph to compute the initial tensors // This field describes a graph to compute the initial tensors
// upon starting the training process. Initialization graph has no input // upon starting the training process. Initialization graph has no input
...@@ -239,20 +249,38 @@ message TrainingInfoProto { ...@@ -239,20 +249,38 @@ message TrainingInfoProto {
// iteration to zero. // iteration to zero.
// //
// By default, this field is an empty graph and its evaluation does not // By default, this field is an empty graph and its evaluation does not
// produce any output. // produce any output. Thus, no initializer would be changed by default.
optional GraphProto initialization = 1; optional GraphProto initialization = 1;
// This field represents a training algorithm step. Given required inputs, // This field represents a training algorithm step. Given required inputs,
// it computes outputs to update initializers in its own or inference graph's // it computes outputs to update initializers in its own or inference graph's
// initializer lists. In general, this graph contains loss node, gradient node, // initializer lists. In general, this field contains loss node, gradient node,
// optimizer node, increment of iteration count, and some calls to the inference // optimizer node, increment of iteration count.
// graph.
// //
// The field algorithm.node is the only place the user can use GraphCall // An execution of the training algorithm step is performed by executing the
// operator. The only callable graph is the one stored in ModelProto.graph. // graph obtained by combining the inference graph (namely "ModelProto.graph")
// and the "algorithm" graph. That is, the actual the actual
// input/initializer/output/node/value_info/sparse_initializer list of
// the training graph is the concatenation of
// "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
// and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
// in that order. This combined graph must satisfy the normal ONNX conditions.
// Now, let's provide a visualization of graph combination for clarity.
// Let the inference graph (i.e., "ModelProto.graph") be
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
// and the "algorithm" graph be
// tensor_d -> Add -> tensor_e
// The combination process results
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
//
// Notice that an input of a node in the "algorithm" graph may reference the
// output of a node in the inference graph (but not the other way round). Also, inference
// node cannot reference inputs of "algorithm". With these restrictions, inference graph
// can always be run independently without training information.
// //
// By default, this field is an empty graph and its evaluation does not // By default, this field is an empty graph and its evaluation does not
// produce any output. // produce any output. Evaluating the default training step never
// update any initializers.
optional GraphProto algorithm = 2; optional GraphProto algorithm = 2;
// This field specifies the bindings from the outputs of "initialization" to // This field specifies the bindings from the outputs of "initialization" to
...@@ -284,23 +312,16 @@ message TrainingInfoProto { ...@@ -284,23 +312,16 @@ message TrainingInfoProto {
// be multiple key-value pairs in "update_binding". // be multiple key-value pairs in "update_binding".
// //
// The initializers appears as keys in "update_binding" are considered // The initializers appears as keys in "update_binding" are considered
// mutable and globally-visible variables. This implies some behaviors // mutable variables. This implies some behaviors
// as described below. // as described below.
// //
// 1. We have only unique keys in all "update_binding"s so that two global // 1. We have only unique keys in all "update_binding"s so that two
// variables may not have the same name. This ensures that one // variables may not have the same name. This ensures that one
// global variable is assigned up to once. // variable is assigned up to once.
// 2. The keys must appear in names of "ModelProto.graph.initializer" or // 2. The keys must appear in names of "ModelProto.graph.initializer" or
// "TrainingInfoProto.algorithm.initializer". // "TrainingInfoProto.algorithm.initializer".
// 3. The values must be output names of "algorithm". // 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
// 4. If an optional input of a graph is omitted when using GraphCall, the // 4. Mutable variables are initialized to the value specified by the
// global variable with the same name may be used.
// 5. When using GraphCall, the users always can pass values to optional
// inputs of the called graph even if the associated initializers appears
// as keys in "update_binding"s.
// 6. The graphs in TrainingInfoProto's can use global variables as
// their operator inputs.
// 7. Mutable variables are initialized to the value specified by the
// corresponding initializer, and then potentially updated by // corresponding initializer, and then potentially updated by
// "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s. // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
// //
...@@ -375,13 +396,31 @@ message ModelProto { ...@@ -375,13 +396,31 @@ message ModelProto {
// //
// If this field is empty, the training behavior of the model is undefined. // If this field is empty, the training behavior of the model is undefined.
repeated TrainingInfoProto training_info = 20; repeated TrainingInfoProto training_info = 20;
// A list of function protos local to the model.
//
// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
// In case of any conflicts the behavior (whether the model local functions are given higher priority,
// or standard opserator sets are given higher priotity or this is treated as error) is defined by
// the runtimes.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto and other model local FunctionProtos.
// Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
// or by 2 FunctionProtos then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same for every node in the function body.
//
// One FunctionProto can reference other FunctionProto in the model, however, recursive reference
// is not allowed.
repeated FunctionProto functions = 25;
}; };
// StringStringEntryProto follows the pattern for cross-proto-version maps. // StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps // See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto { message StringStringEntryProto {
optional string key = 1; optional string key = 1;
optional string value= 2; optional string value = 2;
}; };
message TensorAnnotation { message TensorAnnotation {
...@@ -409,8 +448,9 @@ message GraphProto { ...@@ -409,8 +448,9 @@ message GraphProto {
optional string name = 2; // namespace Graph optional string name = 2; // namespace Graph
// A list of named tensor values, used to specify constant inputs of the graph. // A list of named tensor values, used to specify constant inputs of the graph.
// Each TensorProto entry must have a distinct name (within the list) that // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name.
// MAY also appear in the input list. // The name MUST be unique across both initializer and sparse_initializer,
// but the name MAY also appear in the input list.
repeated TensorProto initializer = 5; repeated TensorProto initializer = 5;
// Initializers (see above) stored in sparse format. // Initializers (see above) stored in sparse format.
...@@ -433,13 +473,8 @@ message GraphProto { ...@@ -433,13 +473,8 @@ message GraphProto {
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
repeated TensorAnnotation quantization_annotation = 14; repeated TensorAnnotation quantization_annotation = 14;
// DO NOT USE the following fields, they were deprecated from earlier versions. reserved 3, 4, 6 to 9;
// repeated string input = 3; reserved "ir_version", "producer_version", "producer_tag", "domain";
// repeated string output = 4;
// optional int64 ir_version = 6;
// optional int64 producer_version = 7;
// optional string producer_tag = 8;
// optional string domain = 9;
} }
// Tensors // Tensors
...@@ -474,6 +509,17 @@ message TensorProto { ...@@ -474,6 +509,17 @@ message TensorProto {
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16; BFLOAT16 = 16;
// Non-IEEE floating-point format based on papers
// FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
// 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
// Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
// The computation usually happens inside a block quantize / dequantize
// fused by the runtime.
FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
// Future extensions go here. // Future extensions go here.
} }
...@@ -507,11 +553,11 @@ message TensorProto { ...@@ -507,11 +553,11 @@ message TensorProto {
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64. // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true]; repeated float float_data = 4 [packed = true];
// For int32, uint8, int8, uint16, int16, bool, and float16 values // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
// float16 values must be bit-wise converted to an uint16_t prior // float16 and float8 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer. // to writing to the buffer.
// When this field is present, the data_type field MUST be // When this field is present, the data_type field MUST be
// INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
repeated int32 int32_data = 5 [packed = true]; repeated int32 int32_data = 5 [packed = true];
// For strings. // For strings.
...@@ -589,6 +635,8 @@ message TensorProto { ...@@ -589,6 +635,8 @@ message TensorProto {
message SparseTensorProto { message SparseTensorProto {
// The sequence of non-default values are encoded as a tensor of shape [NNZ]. // The sequence of non-default values are encoded as a tensor of shape [NNZ].
// The default-value is zero for numeric tensors, and empty-string for string tensors. // The default-value is zero for numeric tensors, and empty-string for string tensors.
// values must have a non-empty name present which serves as a name for SparseTensorProto
// when used in sparse_initializer list.
optional TensorProto values = 1; optional TensorProto values = 1;
// The indices of the non-default values, which may be stored in one of two formats. // The indices of the non-default values, which may be stored in one of two formats.
...@@ -619,7 +667,7 @@ message TensorShapeProto { ...@@ -619,7 +667,7 @@ message TensorShapeProto {
// Standard denotation can optionally be used to denote tensor // Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure // dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor. // that operations are applied to the correct axis of a tensor.
// Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
// for pre-defined dimension denotations. // for pre-defined dimension denotations.
optional string denotation = 3; optional string denotation = 3;
}; };
...@@ -656,6 +704,23 @@ message TypeProto { ...@@ -656,6 +704,23 @@ message TypeProto {
optional TypeProto value_type = 2; optional TypeProto value_type = 2;
}; };
// wrapper for Tensor, Sequence, or Map
message Optional {
// The type and optional shape of the element wrapped.
// This field MUST be present for this version of the IR.
// Possible values correspond to OptionalProto.DataType enum
optional TypeProto elem_type = 1;
};
message SparseTensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
optional int32 elem_type = 1;
optional TensorShapeProto shape = 2;
}
oneof value { oneof value {
// The type of a tensor. // The type of a tensor.
...@@ -672,11 +737,18 @@ message TypeProto { ...@@ -672,11 +737,18 @@ message TypeProto {
// The type of a map. // The type of a map.
Map map_type = 5; Map map_type = 5;
// The type of an optional.
Optional optional_type = 9;
// Type of the sparse tensor
SparseTensor sparse_tensor_type = 8;
} }
// An optional denotation can be used to denote the whole // An optional denotation can be used to denote the whole
// type with a standard semantic description as to what is // type with a standard semantic description as to what is
// stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
// for pre-defined type denotations. // for pre-defined type denotations.
optional string denotation = 6; optional string denotation = 6;
} }
...@@ -696,7 +768,67 @@ message OperatorSetIdProto { ...@@ -696,7 +768,67 @@ message OperatorSetIdProto {
optional int64 version = 2; optional int64 version = 2;
} }
// Operator/function status.
enum OperatorStatus {
EXPERIMENTAL = 0;
STABLE = 1;
}
message FunctionProto {
// The name of the function, similar usage of op_type in OperatorProto.
// Combined with FunctionProto.domain, this forms the unique identity of
// the FunctionProto.
optional string name = 1;
// Deprecated since IR Version 8
// optional int64 since_version = 2;
reserved 2;
reserved "since_version";
// Deprecated since IR Version 8
// optional OperatorStatus status = 3;
reserved 3;
reserved "status";
// The inputs and outputs of the function.
repeated string input = 4;
repeated string output = 5;
// The attribute parameters of the function.
// It is for function parameters without default values.
repeated string attribute = 6;
// The attribute protos of the function.
// It is for function attributes with default values.
// A function attribute shall be represented either as
// a string attribute or an AttributeProto, not both.
repeated AttributeProto attribute_proto = 11;
// The nodes in the function.
repeated NodeProto node = 7;
// A human-readable documentation for this function. Markdown is allowed.
optional string doc_string = 8;
// The OperatorSets this function body (graph) relies on.
//
// All nodes in the function body (graph) will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets. This means at most one version can be relied
// for one domain.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
// and ModelProto then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same.
repeated OperatorSetIdProto opset_import = 9;
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
// the FunctionProto.
optional string domain = 10;
}
// For using protobuf-lite // For using protobuf-lite
option optimize_for = LITE_RUNTIME; option optimize_for = LITE_RUNTIME;
\ No newline at end of file
...@@ -34,7 +34,9 @@ ...@@ -34,7 +34,9 @@
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/filesystem.hpp> #include <migraphx/filesystem.hpp>
#include <migraphx/op/unknown.hpp> #include <migraphx/op/unknown.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <onnx.pb.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -484,6 +486,8 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const ...@@ -484,6 +486,8 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
case onnx::AttributeProto::TENSORS: case onnx::AttributeProto::TENSORS:
case onnx::AttributeProto::SPARSE_TENSOR: case onnx::AttributeProto::SPARSE_TENSOR:
case onnx::AttributeProto::SPARSE_TENSORS: case onnx::AttributeProto::SPARSE_TENSORS:
case onnx::AttributeProto::TYPE_PROTOS:
case onnx::AttributeProto::TYPE_PROTO:
case onnx::AttributeProto::GRAPHS: return {}; case onnx::AttributeProto::GRAPHS: return {};
} }
MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type())); MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type()));
...@@ -545,6 +549,18 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const ...@@ -545,6 +549,18 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
case onnx::TensorProto::DOUBLE: case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, t.double_data()); return create_literal(shape::double_type, dims, t.double_data());
case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data()); case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data());
case onnx::TensorProto::FLOAT8E4M3FNUZ: {
std::vector<int32_t> data_int32(t.int32_data().begin(), t.int32_data().end());
std::vector<migraphx::fp8::fp8e4m3fnuz> data_fp8;
std::transform(data_int32.begin(),
data_int32.end(),
std::back_inserter(data_fp8),
[](float raw_val) { return migraphx::fp8::fp8e4m3fnuz{raw_val}; });
return create_literal(shape::fp8e4m3fnuz_type, dims, data_fp8);
}
case onnx::TensorProto::FLOAT8E5M2FNUZ:
case onnx::TensorProto::FLOAT8E5M2:
case onnx::TensorProto::FLOAT8E4M3FN:
case onnx::TensorProto::UNDEFINED: case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::STRING: case onnx::TensorProto::STRING:
case onnx::TensorProto::COMPLEX64: case onnx::TensorProto::COMPLEX64:
...@@ -609,6 +625,13 @@ shape::type_t get_type(int dtype) ...@@ -609,6 +625,13 @@ shape::type_t get_type(int dtype)
case 11: return shape::double_type; case 11: return shape::double_type;
case 12: return shape::uint32_type; case 12: return shape::uint32_type;
case 13: return shape::uint64_type; case 13: return shape::uint64_type;
case 18: return shape::fp8e4m3fnuz_type;
case 14:
case 15:
case 16:
case 17:
case 19:
case 20:
default: { default: {
MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported"); MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment