Commit a98d86d9 authored by Umang Yadav's avatar Umang Yadav
Browse files

Merge branch 'rocblas_fp8' into rocblas_mlir_fp8

parents a3d4b013 7e80f627
......@@ -28,7 +28,14 @@ MACRO_EXPANSION = YES
OUTPUT_DIRECTORY = docBin
PREDEFINED = DOXYGEN
PREDEFINED = \
DOXYGEN \
MIGRAPHX_EXPORT= \
MIGRAPHX_API_EXPORT= \
MIGRAPHX_GPU_EXPORT= \
MIGRAPHX_CPU_EXPORT= \
MIGRAPHX_ONNX_EXPORT= \
MIGRAPHX_TF_EXPORT= \
PROJECT_NAME = MIGraphX
......
......@@ -5,26 +5,36 @@ shape
-----
.. doxygenstruct:: migraphx::internal::shape
:members:
:undoc-members:
literal
-------
.. doxygenstruct:: migraphx::internal::literal
:members:
:undoc-members:
argument
--------
.. doxygenstruct:: migraphx::internal::argument
:members:
:undoc-members:
raw_data
--------
.. doxygenstruct:: migraphx::internal::raw_data
:members:
:undoc-members:
.. doxygenfunction:: migraphx::internal::visit_all
.. doxygenfunction:: template<class T, class ...Ts> auto migraphx::internal::visit_all(T &&x, Ts&&... xs)
tensor_view
-----------
.. doxygenstruct:: migraphx::internal::tensor_view
:members:
:undoc-members:
......@@ -18,8 +18,8 @@ Directions for building MIGraphX from source can be found in the main README fil
Adding Two Literals
--------------------
A program is a collection of modules, which are collections of instructions to be executed when calling `eval <migraphx::program::eval>`.
Each instruction has an associated `operation <migraphx::operation>` which represents the computation to be performed by the instruction.
A program is a collection of modules, which are collections of instructions to be executed when calling :cpp:any:`eval <migraphx::internal::program::eval>`.
Each instruction has an associated :cpp:any:`operation <migraphx::internal::operation>` which represents the computation to be performed by the instruction.
We start with a snippet of the simple ``add_two_literals()`` function::
......@@ -41,14 +41,14 @@ We start with a snippet of the simple ``add_two_literals()`` function::
auto result = p.eval({}).back();
std::cout << "add_two_literals: 1 + 2 = " << result << "\n";
We start by creating a simple ``migraphx::program`` object and then getting a pointer to the main module of it.
We start by creating a simple :cpp:any:`migraphx::program <migraphx::internal::program>` object and then getting a pointer to the main module of it.
The program is a collection of ``modules`` that start executing from the main module, so instructions are added to the modules rather than directly onto the program object.
We then use the `add_literal <migraphx::program::add_literal>` function to add an instruction that stores the literal number ``1`` while returning an `instruction_ref <migraphx::instruction_ref>`.
The returned `instruction_ref <migraphx::instruction_ref>` can be used in another instruction as an input.
We use the same `add_literal <migraphx::program::add_literal>` function to add a ``2`` to the program.
We then use the :cpp:any:`add_literal <migraphx::internal::program::add_literal>` function to add an instruction that stores the literal number ``1`` while returning an :cpp:any:`instruction_ref <migraphx::internal::instruction_ref>`.
The returned :cpp:any:`instruction_ref <migraphx::internal::instruction_ref>` can be used in another instruction as an input.
We use the same :cpp:any:`add_literal <migraphx::internal::program::add_literal>` function to add a ``2`` to the program.
After creating the literals, we then create the instruction to add the numbers together.
This is done by using the `add_instruction <migraphx::program::add_instruction>` function with the ``"add"`` `operation <migraphx::program::operation>` created by `make_op <migraphx::program::make_op>` along with the previous `add_literal` `instruction_ref <migraphx::instruction_ref>` for the input arguments of the instruction.
Finally, we can run this `program <migraphx::program>` by compiling it for the reference target (CPU) and then running it with `eval <migraphx::program::eval>`
This is done by using the :cpp:any:`add_instruction <migraphx::internal::program::add_instruction>` function with the ``"add"`` :cpp:any:`operation <migraphx::internal::program::operation>` created by :cpp:any:`make_op <migraphx::internal::program::make_op>` along with the previous `add_literal` :cpp:any:`instruction_ref <migraphx::internal::instruction_ref>` for the input arguments of the instruction.
Finally, we can run this :cpp:any:`program <migraphx::internal::program>` by compiling it for the reference target (CPU) and then running it with :cpp:any:`eval <migraphx::internal::program::eval>`
The result is then retreived and printed to the console.
We can compile the program for the GPU as well, but the file will have to be moved to the ``test/gpu/`` directory and the correct target must be included::
......@@ -76,8 +76,8 @@ We can modify the program to take an input parameter ``x``, as seen in the ``add
p.compile(migraphx::ref::target{});
This adds a parameter of type ``int32``, and compiles it for the CPU.
To run the program, we need to pass the parameter as a ``parameter_map`` when we call `eval <migraphx::program::eval>`.
We create the ``parameter_map`` by setting the ``x`` key to an `argument <migraphx::argument>` object with an ``int`` data type::
To run the program, we need to pass the parameter as a ``parameter_map`` when we call :cpp:any:`eval <migraphx::internal::program::eval>`.
We create the ``parameter_map`` by setting the ``x`` key to an :cpp:any:`argument <migraphx::internal::argument>` object with an ``int`` data type::
// create a parameter_map object for passing a value to the "x" parameter
std::vector<int> data = {4};
......@@ -92,7 +92,7 @@ We create the ``parameter_map`` by setting the ``x`` key to an `argument <migrap
Handling Tensor Data
---------------------
In the previous examples we have only been dealing with scalars, but the `shape <migraphx::shape>` class can describe multi-dimensional tensors.
In the previous examples we have only been dealing with scalars, but the :cpp:any:`shape <migraphx::internal::shape>` class can describe multi-dimensional tensors.
For example, we can compute a simple convolution::
migraphx::program p;
......@@ -109,7 +109,7 @@ For example, we can compute a simple convolution::
Here we create two parameters for both the ``input`` and ``weights``.
In the previous examples, we created simple literals, however, most programs will take data from allocated buffers (usually on the GPU).
In this case, we can create `argument <migraphx::argument>` objects directly from the pointers to the buffers::
In this case, we can create :cpp:any:`argument <migraphx::internal::argument>` objects directly from the pointers to the buffers::
// Compile the program
p.compile(migraphx::ref::target{});
......@@ -133,8 +133,8 @@ In this case, we can create `argument <migraphx::argument>` objects directly fro
EXPECT(migraphx::verify::verify_rms_range(results_vector, sol));
An `argument <migraphx::argument>` can handle memory buffers from either the GPU or the CPU.
By default when running the `program <migraphx::program>`, buffers are allocated on the corresponding target.
An :cpp:any:`argument <migraphx::internal::argument>` can handle memory buffers from either the GPU or the CPU.
By default when running the :cpp:any:`program <migraphx::internal::program>`, buffers are allocated on the corresponding target.
When compiling for the CPU, the buffers by default will be allocated on the CPU.
When compiling for the GPU, the buffers by default will be allocated on the GPU.
With the option ``offload_copy=true`` set while compiling for the GPU, the buffers will be located on the CPU.
......@@ -143,7 +143,7 @@ With the option ``offload_copy=true`` set while compiling for the GPU, the buffe
Importing From ONNX
--------------------
A `program <migraphx::program>` can be built directly from an onnx file using the MIGraphX ONNX parser.
A :cpp:any:`program <migraphx::internal::program>` can be built directly from an onnx file using the MIGraphX ONNX parser.
This makes it easier to use neural networks directly from other frameworks.
In this case, there is an ``parse_onnx`` function::
......
......@@ -5,6 +5,8 @@ operation
---------
.. doxygenstruct:: migraphx::internal::operation
:members:
:undoc-members:
.. doxygenfunction:: migraphx::internal::is_context_free
......@@ -14,3 +16,5 @@ operators
---------
.. doxygennamespace:: migraphx::internal::op
:members:
:undoc-members:
......@@ -5,63 +5,82 @@ pass
----
.. doxygenstruct:: migraphx::internal::pass
:members:
:undoc-members:
dead_code_elimination
---------------------
.. doxygenstruct:: migraphx::internal::dead_code_elimination
:members:
:undoc-members:
eliminate_common_subexpression
------------------------------
.. doxygenstruct:: migraphx::internal::eliminate_common_subexpression
:members:
:undoc-members:
eliminate_concat
----------------
.. doxygenstruct:: migraphx::internal::eliminate_concat
:members:
:undoc-members:
eliminate_contiguous
--------------------
.. doxygenstruct:: migraphx::internal::eliminate_contiguous
:members:
:undoc-members:
eliminate_identity
------------------
.. doxygenstruct:: migraphx::internal::eliminate_identity
:members:
:undoc-members:
eliminate_pad
-------------
.. doxygenstruct:: migraphx::internal::eliminate_pad
:members:
:undoc-members:
propagate_constant
------------------
.. doxygenstruct:: migraphx::internal::propagate_constant
rewrite_batchnorm
-----------------
.. doxygenstruct:: migraphx::internal::rewrite_batchnorm
:members:
:undoc-members:
rewrite_rnn
-----------
.. doxygenstruct:: migraphx::internal::rewrite_rnn
:members:
:undoc-members:
schedule
--------
.. doxygenstruct:: migraphx::internal::schedule
:members:
:undoc-members:
simplify_algebra
----------------
.. doxygenstruct:: migraphx::internal::simplify_algebra
:members:
:undoc-members:
simplify_reshapes
-----------------
.. doxygenstruct:: migraphx::internal::simplify_reshapes
:members:
:undoc-members:
......@@ -5,6 +5,8 @@ instruction
-----------
.. doxygenstruct:: migraphx::internal::instruction
:members:
:undoc-members:
instruction_ref
---------------
......@@ -17,6 +19,8 @@ program
-------
.. doxygenstruct:: migraphx::internal::program
:members:
:undoc-members:
parse_onnx
----------
......
......@@ -5,14 +5,20 @@ target
------
.. doxygenstruct:: migraphx::internal::target
:members:
:undoc-members:
gpu::target
-----------
.. doxygenstruct:: migraphx::internal::gpu::target
:members:
:undoc-members:
cpu::target
-----------
.. doxygenstruct:: migraphx::internal::cpu::target
:members:
:undoc-members:
......@@ -8,45 +8,65 @@ shape
.. doxygenenum:: migraphx_shape_datatype_t
.. doxygenstruct:: migraphx::shape
:members:
:undoc-members:
argument
--------
.. doxygenstruct:: migraphx::argument
:members:
:undoc-members:
target
------
.. doxygenstruct:: migraphx::target
:members:
:undoc-members:
program
-------
.. doxygenstruct:: migraphx::program_parameter_shapes
:members:
:undoc-members:
.. doxygenstruct:: migraphx::program_parameters
:members:
:undoc-members:
.. doxygenstruct:: migraphx_compile_options
:members:
:undoc-members:
.. doxygenstruct:: migraphx::program
:members:
:undoc-members:
quantize
--------
.. doxygenstruct:: migraphx::quantize_op_names
:members:
:undoc-members:
.. doxygenfunction:: migraphx::quantize_fp16(const program&)
.. doxygenfunction:: migraphx::quantize_fp16(const program&, const quantize_op_names&)
.. doxygenstruct:: migraphx::quantize_int8_options
:members:
:undoc-members:
.. doxygenfunction:: migraphx::quantize_int8
.. doxygenfunction::migraphx::quantize_int8
parse_onnx
----------
.. doxygenstruct:: migraphx::onnx_options
:members:
:undoc-members:
.. doxygenfunction:: migraphx::parse_onnx(const char *)
......@@ -63,16 +83,18 @@ parse_onnx
load
----
.. doxygenstruct:: migraphx_file_options
.. doxygenstruct:: migraphx::file_options
:members:
:undoc-members:
.. doxygenfunction:: migraphx::load(const char *)
.. doxygenfunction:: migraphx::load(const char *, migraphx_file_options)
.. doxygenfunction:: migraphx::load(const char *, const file_options&)
save
----
.. doxygenfunction:: migraphx::save(const program&, const char *)
.. doxygenfunction:: migraphx::save(const program&, const char *, migraphx_file_options)
.. doxygenfunction:: migraphx::save(const program&, const char *, const file_options&)
......@@ -49,7 +49,6 @@ add_library(migraphx
eliminate_concat.cpp
eliminate_contiguous.cpp
eliminate_data_type.cpp
eliminate_fp8.cpp
eliminate_identity.cpp
eliminate_pad.cpp
env.cpp
......
......@@ -31,6 +31,72 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void insert_convert_to_supported_type(module& m,
instruction_ref ins,
migraphx::shape::type_t target_type,
std::set<migraphx::shape::type_t> unsupported_types)
{
migraphx::shape::type_t orig_type = ins->get_shape().type();
std::vector<instruction_ref> inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](const auto& i) {
if(contains(unsupported_types, i->get_shape().type()))
{
return m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}),
i);
}
else
{
return i;
}
});
// if no change
if(inputs == ins->inputs())
return;
auto op = ins->get_operator();
auto attributes = op.attributes();
if(attributes.contains("general_data_type"))
{
op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
}
auto new_ins = m.insert_instruction(ins, op, inputs);
if(orig_type == shape::tuple_type)
{
auto orig_outs = ins->outputs();
if(not std::all_of(orig_outs.begin(), orig_outs.end(), [&](const auto out_ins) {
return out_ins->name() == "get_tuple_elem";
}))
MIGRAPHX_THROW(
"eliminate_data_type: Instruction with tuple output doesn't have all its "
"usages as get_tuple_elem instruction");
std::transform(
orig_outs.begin(), orig_outs.end(), orig_outs.begin(), [&](const auto out_ins) {
auto gte_ins = m.insert_instruction(ins, out_ins->get_operator(), new_ins);
auto orig_out_type = out_ins->get_shape().type();
if(contains(unsupported_types, orig_out_type))
{
auto gte_convert = m.insert_instruction(
ins, make_op("convert", {{"target_type", orig_out_type}}), gte_ins);
return m.replace_instruction(out_ins, gte_convert);
}
else
{
return m.replace_instruction(out_ins, gte_ins);
}
});
}
else
{
auto convert_back_ins = m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}),
new_ins);
m.replace_instruction(ins, convert_back_ins);
}
}
void eliminate_data_type::apply(module& m) const
{
static const std::vector<std::string> skip_op_names = {"convert",
......@@ -42,31 +108,17 @@ void eliminate_data_type::apply(module& m) const
"scatternd_add",
"scatternd_mul",
"scatternd_none"};
if(unsupported_types.empty())
return;
for(auto ins : iterator_for(m))
{
if(ins->name()[0] == '@')
continue;
if(contains(skip_op_names, ins->name()))
continue;
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
if(types.count(i->get_shape().type()) == 0)
return i;
return m.insert_instruction(ins, make_op("convert", {{"target_type", target_type}}), i);
});
if(inputs == ins->inputs())
if(contains(skip_op_names, ins->name()) and not contains(unsupported_ops, ins->name()))
continue;
auto op = ins->get_operator();
auto attributes = op.attributes();
if(attributes.contains("general_data_type"))
{
op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
}
auto old_type = ins->get_shape().type();
auto out = m.insert_instruction(ins, op, inputs);
auto convert =
m.insert_instruction(ins, make_op("convert", {{"target_type", old_type}}), out);
m.replace_instruction(ins, convert);
if(contains(unsupported_ops, "all") or contains(unsupported_ops, ins->name()))
insert_convert_to_supported_type(m, ins, target_type, unsupported_types);
}
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <iterator>
#include <utility>
#include <migraphx/eliminate_fp8.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void eliminate_fp8::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
if(not contains(op_names, ins->name()))
continue;
migraphx::shape::type_t orig_type = ins->get_shape().type();
std::vector<instruction_ref> orig_inputs = ins->inputs();
std::vector<instruction_ref> new_inputs;
std::transform(orig_inputs.begin(),
orig_inputs.end(),
std::back_inserter(new_inputs),
[&](const auto& i) {
return m.insert_instruction(
ins,
migraphx::make_op(
"convert", {{"target_type", migraphx::to_value(target_type)}}),
i);
});
auto op = ins->get_operator();
auto attributes = op.attributes();
if(attributes.contains("general_data_type"))
{
op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
}
auto new_ins = m.insert_instruction(ins, op, new_inputs);
auto convert_back_ins = m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}),
new_ins);
m.replace_instruction(ins, convert_back_ins);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -40,8 +40,9 @@ struct module;
*/
struct MIGRAPHX_EXPORT eliminate_data_type
{
std::set<shape::type_t> types;
std::set<shape::type_t> unsupported_types;
shape::type_t target_type;
std::set<std::string> unsupported_ops = {"all"};
std::string name() const { return "eliminate_data_type"; }
void apply(module& m) const;
};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_FP8_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_FP8_HPP
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <set>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
This will insert convert operators for the operators that are not implemented for FP8 dtypes
*/
struct MIGRAPHX_EXPORT eliminate_fp8
{
// TODO: Add all device ops as a later PR and add tests for those.
std::set<std::string> op_names;
shape::type_t target_type = migraphx::shape::float_type;
std::string name() const { return "eliminate_fp8"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
This diff is collapsed.
......@@ -34,7 +34,9 @@
#include <migraphx/file_buffer.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/env.hpp>
#include <onnx.pb.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -484,6 +486,8 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
case onnx::AttributeProto::TENSORS:
case onnx::AttributeProto::SPARSE_TENSOR:
case onnx::AttributeProto::SPARSE_TENSORS:
case onnx::AttributeProto::TYPE_PROTOS:
case onnx::AttributeProto::TYPE_PROTO:
case onnx::AttributeProto::GRAPHS: return {};
}
MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type()));
......@@ -545,6 +549,18 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, t.double_data());
case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data());
case onnx::TensorProto::FLOAT8E4M3FNUZ: {
std::vector<int32_t> data_int32(t.int32_data().begin(), t.int32_data().end());
std::vector<migraphx::fp8::fp8e4m3fnuz> data_fp8;
std::transform(data_int32.begin(),
data_int32.end(),
std::back_inserter(data_fp8),
[](float raw_val) { return migraphx::fp8::fp8e4m3fnuz{raw_val}; });
return create_literal(shape::fp8e4m3fnuz_type, dims, data_fp8);
}
case onnx::TensorProto::FLOAT8E5M2FNUZ:
case onnx::TensorProto::FLOAT8E5M2:
case onnx::TensorProto::FLOAT8E4M3FN:
case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::STRING:
case onnx::TensorProto::COMPLEX64:
......@@ -609,6 +625,13 @@ shape::type_t get_type(int dtype)
case 11: return shape::double_type;
case 12: return shape::uint32_type;
case 13: return shape::uint64_type;
case 18: return shape::fp8e4m3fnuz_type;
case 14:
case 15:
case 16:
case 17:
case 19:
case 20:
default: {
MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
}
......
......@@ -126,7 +126,6 @@ add_library(migraphx_gpu
fuse_ck.cpp
fuse_mlir.cpp
fuse_ops.cpp
gather.cpp
gemm_impl.cpp
hip.cpp
kernel.cpp
......@@ -140,7 +139,6 @@ add_library(migraphx_gpu
nonzero.cpp
pack_args.cpp
prefuse_ops.cpp
pad.cpp
perfdb.cpp
pooling.cpp
reverse.cpp
......@@ -168,12 +166,10 @@ endfunction()
register_migraphx_gpu_ops(hip_
argmax
argmin
gather
logsoftmax
loop
multinomial
nonzero
pad
prefix_scan_sum
reverse
scatter
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int64_t axis)
{
const auto& input_shape = arg1.get_shape();
auto lens = input_shape.lens();
auto axis_dim_size = lens[axis];
lens[axis] = arg2.get_shape().elements();
shape out_comp_shape{result.get_shape().type(), lens};
std::size_t nelements = result.get_shape().elements();
visit_all(result, arg1)([&](auto output, auto input_v) {
hip_visit_views(input_v, out_comp_shape)([&](auto input, auto out_comp) {
arg2.visit([&](auto indices) {
const auto* indices_ptr = device_cast(indices.data());
auto* output_ptr = device_cast(output.data());
gs_launch(stream, nelements, 256)([=](auto i) __device__ {
auto idx = out_comp.multi(i);
auto in_index = indices_ptr[idx[axis]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
idx[axis] = in_index;
output_ptr[i] = input[idx];
});
});
});
});
return result;
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/clamp.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/pad.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/float_equal.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument
pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads)
{
std::size_t nelements = arg1.get_shape().elements();
hip_visit_all(result, arg1)([&](auto output, auto input) {
using type = typename decltype(output)::value_type;
using hip_index = typename decltype(output)::hip_index;
type device_val = pad_clamp<host_type<type>>(value);
gs_launch(stream, result.get_shape().elements())(
[=](auto i) __device__ { output.data()[i] = device_val; });
hip_index offsets;
std::copy(pads.begin(), pads.begin() + offsets.size(), offsets.begin());
gs_launch(stream, nelements)([=](auto i) __device__ {
auto idx = input.get_shape().multi(i);
for(std::size_t j = 0; j < offsets.size(); j++)
{
idx[j] += offsets[j];
}
output[idx] = input.data()[i];
});
});
return result;
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -114,10 +114,7 @@ struct mlir_op
}
if(ins->name() == "@return")
{
auto s = ins_shapes[ins->inputs().at(0)].with_type(type);
if(not s.standard())
MIGRAPHX_THROW("MLIR doesnt support non-standard output");
return s;
return ins_shapes[ins->inputs().at(0)].with_type(type);
}
std::vector<shape> input_shapes;
input_shapes.resize(ins->inputs().size());
......@@ -139,9 +136,16 @@ get_fusable_input_op_stream(instruction_ref lower_input)
{
instruction_ref upper_input = lower_input;
std::vector<operation> op_stream;
while(
contains({"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
upper_input->name()))
while(contains({"slice",
"transpose",
"multibroadcast",
"broadcast",
"contiguous",
"reshape",
"squeeze",
"flatten",
"unsqueeze"},
upper_input->name()))
{
operation op = upper_input->get_operator();
if(contains({"squeeze", "flatten", "unsqueeze"}, upper_input->name()))
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/gather.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/gather.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_gather::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
return op.normalize_compute_shape(inputs);
}
argument hip_gather::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
return device::gather(ctx.get_stream().get(), args.back(), args[0], args[1], op.axis);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment