Commit 6f37331e authored by Paul's avatar Paul
Browse files

Merge commit '8c73c72e' into navi-reduce

parents d00fdf6e 8c73c72e
...@@ -29,4 +29,4 @@ pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build ...@@ -29,4 +29,4 @@ pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/rocMLIR@9e66e8050209f03349a41b6b497f0da2b285a53b -DBUILD_FAT_LIBROCKCOMPILER=On ROCmSoftwarePlatform/rocMLIR@a6880f1e6daec99876cd6a4820fbc69c57216401 -DBUILD_FAT_LIBROCKCOMPILER=On
...@@ -31,6 +31,72 @@ ...@@ -31,6 +31,72 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void insert_convert_to_supported_type(module& m,
instruction_ref ins,
migraphx::shape::type_t target_type,
std::set<migraphx::shape::type_t> unsupported_types)
{
migraphx::shape::type_t orig_type = ins->get_shape().type();
std::vector<instruction_ref> inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](const auto& i) {
if(contains(unsupported_types, i->get_shape().type()))
{
return m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}),
i);
}
else
{
return i;
}
});
// if no change
if(inputs == ins->inputs())
return;
auto op = ins->get_operator();
auto attributes = op.attributes();
if(attributes.contains("general_data_type"))
{
op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
}
auto new_ins = m.insert_instruction(ins, op, inputs);
if(orig_type == shape::tuple_type)
{
auto orig_outs = ins->outputs();
if(not std::all_of(orig_outs.begin(), orig_outs.end(), [&](const auto out_ins) {
return out_ins->name() == "get_tuple_elem";
}))
MIGRAPHX_THROW(
"eliminate_data_type: Instruction with tuple output doesn't have all its "
"usages as get_tuple_elem instruction");
std::transform(
orig_outs.begin(), orig_outs.end(), orig_outs.begin(), [&](const auto out_ins) {
auto gte_ins = m.insert_instruction(ins, out_ins->get_operator(), new_ins);
auto orig_out_type = out_ins->get_shape().type();
if(contains(unsupported_types, orig_out_type))
{
auto gte_convert = m.insert_instruction(
ins, make_op("convert", {{"target_type", orig_out_type}}), gte_ins);
return m.replace_instruction(out_ins, gte_convert);
}
else
{
return m.replace_instruction(out_ins, gte_ins);
}
});
}
else
{
auto convert_back_ins = m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}),
new_ins);
m.replace_instruction(ins, convert_back_ins);
}
}
void eliminate_data_type::apply(module& m) const void eliminate_data_type::apply(module& m) const
{ {
static const std::vector<std::string> skip_op_names = {"convert", static const std::vector<std::string> skip_op_names = {"convert",
...@@ -42,31 +108,17 @@ void eliminate_data_type::apply(module& m) const ...@@ -42,31 +108,17 @@ void eliminate_data_type::apply(module& m) const
"scatternd_add", "scatternd_add",
"scatternd_mul", "scatternd_mul",
"scatternd_none"}; "scatternd_none"};
if(unsupported_types.empty())
return;
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(ins->name()[0] == '@') if(ins->name()[0] == '@')
continue; continue;
if(contains(skip_op_names, ins->name())) if(contains(skip_op_names, ins->name()) and not contains(unsupported_ops, ins->name()))
continue;
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
if(types.count(i->get_shape().type()) == 0)
return i;
return m.insert_instruction(ins, make_op("convert", {{"target_type", target_type}}), i);
});
if(inputs == ins->inputs())
continue; continue;
auto op = ins->get_operator(); if(contains(unsupported_ops, "all") or contains(unsupported_ops, ins->name()))
auto attributes = op.attributes(); insert_convert_to_supported_type(m, ins, target_type, unsupported_types);
if(attributes.contains("general_data_type"))
{
op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
}
auto old_type = ins->get_shape().type();
auto out = m.insert_instruction(ins, op, inputs);
auto convert =
m.insert_instruction(ins, make_op("convert", {{"target_type", old_type}}), out);
m.replace_instruction(ins, convert);
} }
} }
......
...@@ -40,8 +40,9 @@ struct module; ...@@ -40,8 +40,9 @@ struct module;
*/ */
struct MIGRAPHX_EXPORT eliminate_data_type struct MIGRAPHX_EXPORT eliminate_data_type
{ {
std::set<shape::type_t> types; std::set<shape::type_t> unsupported_types;
shape::type_t target_type; shape::type_t target_type;
std::set<std::string> unsupported_ops = {"all"};
std::string name() const { return "eliminate_data_type"; } std::string name() const { return "eliminate_data_type"; }
void apply(module& m) const; void apply(module& m) const;
}; };
......
...@@ -126,7 +126,6 @@ add_library(migraphx_gpu ...@@ -126,7 +126,6 @@ add_library(migraphx_gpu
fuse_ck.cpp fuse_ck.cpp
fuse_mlir.cpp fuse_mlir.cpp
fuse_ops.cpp fuse_ops.cpp
gather.cpp
gemm_impl.cpp gemm_impl.cpp
hip.cpp hip.cpp
kernel.cpp kernel.cpp
...@@ -140,7 +139,6 @@ add_library(migraphx_gpu ...@@ -140,7 +139,6 @@ add_library(migraphx_gpu
nonzero.cpp nonzero.cpp
pack_args.cpp pack_args.cpp
prefuse_ops.cpp prefuse_ops.cpp
pad.cpp
perfdb.cpp perfdb.cpp
pooling.cpp pooling.cpp
reverse.cpp reverse.cpp
...@@ -168,12 +166,10 @@ endfunction() ...@@ -168,12 +166,10 @@ endfunction()
register_migraphx_gpu_ops(hip_ register_migraphx_gpu_ops(hip_
argmax argmax
argmin argmin
gather
logsoftmax logsoftmax
loop loop
multinomial multinomial
nonzero nonzero
pad
prefix_scan_sum prefix_scan_sum
reverse reverse
scatter scatter
...@@ -263,6 +259,8 @@ check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCAT ...@@ -263,6 +259,8 @@ check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCAT
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API) check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# Beta API for automated GEMM tuning # Beta API for automated GEMM tuning
check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API) check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
# rocblas FP8 API
check_library_exists(roc::rocblas "rocblas_gemm_strided_batched_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API)
set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "") set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
...@@ -292,10 +290,18 @@ else() ...@@ -292,10 +290,18 @@ else()
message(STATUS "rocBLAS does not have User Tuning Beta API") message(STATUS "rocBLAS does not have User Tuning Beta API")
endif() endif()
if(HAS_ROCBLAS_FP8_BETA_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphX is using Beta API of rocBLAS for FP8 computations")
else()
message(STATUS "rocBLAS does not have Fp8 Beta API")
endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels) target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
if(MIGRAPHX_USE_COMPOSABLEKERNEL) if(MIGRAPHX_USE_COMPOSABLEKERNEL)
target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library) target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1)
endif() endif()
add_subdirectory(driver) add_subdirectory(driver)
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int64_t axis)
{
const auto& input_shape = arg1.get_shape();
auto lens = input_shape.lens();
auto axis_dim_size = lens[axis];
lens[axis] = arg2.get_shape().elements();
shape out_comp_shape{result.get_shape().type(), lens};
std::size_t nelements = result.get_shape().elements();
visit_all(result, arg1)([&](auto output, auto input_v) {
hip_visit_views(input_v, out_comp_shape)([&](auto input, auto out_comp) {
arg2.visit([&](auto indices) {
const auto* indices_ptr = device_cast(indices.data());
auto* output_ptr = device_cast(output.data());
gs_launch(stream, nelements, 256)([=](auto i) __device__ {
auto idx = out_comp.multi(i);
auto in_index = indices_ptr[idx[axis]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
idx[axis] = in_index;
output_ptr[i] = input[idx];
});
});
});
});
return result;
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/clamp.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/pad.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/float_equal.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument
pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads)
{
std::size_t nelements = arg1.get_shape().elements();
hip_visit_all(result, arg1)([&](auto output, auto input) {
using type = typename decltype(output)::value_type;
using hip_index = typename decltype(output)::hip_index;
type device_val = pad_clamp<host_type<type>>(value);
gs_launch(stream, result.get_shape().elements())(
[=](auto i) __device__ { output.data()[i] = device_val; });
hip_index offsets;
std::copy(pads.begin(), pads.begin() + offsets.size(), offsets.begin());
gs_launch(stream, nelements)([=](auto i) __device__ {
auto idx = input.get_shape().multi(i);
for(std::size_t j = 0; j < offsets.size(); j++)
{
idx[j] += offsets[j];
}
output[idx] = input.data()[i];
});
});
return result;
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -114,10 +114,7 @@ struct mlir_op ...@@ -114,10 +114,7 @@ struct mlir_op
} }
if(ins->name() == "@return") if(ins->name() == "@return")
{ {
auto s = ins_shapes[ins->inputs().at(0)].with_type(type); return ins_shapes[ins->inputs().at(0)].with_type(type);
if(not s.standard())
MIGRAPHX_THROW("MLIR doesnt support non-standard output");
return s;
} }
std::vector<shape> input_shapes; std::vector<shape> input_shapes;
input_shapes.resize(ins->inputs().size()); input_shapes.resize(ins->inputs().size());
...@@ -139,8 +136,15 @@ get_fusable_input_op_stream(instruction_ref lower_input) ...@@ -139,8 +136,15 @@ get_fusable_input_op_stream(instruction_ref lower_input)
{ {
instruction_ref upper_input = lower_input; instruction_ref upper_input = lower_input;
std::vector<operation> op_stream; std::vector<operation> op_stream;
while( while(contains({"slice",
contains({"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"}, "transpose",
"multibroadcast",
"broadcast",
"contiguous",
"reshape",
"squeeze",
"flatten",
"unsqueeze"},
upper_input->name())) upper_input->name()))
{ {
operation op = upper_input->get_operator(); operation op = upper_input->get_operator();
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/gather.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/gather.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_gather::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
return op.normalize_compute_shape(inputs);
}
argument hip_gather::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
return device::gather(ctx.get_stream().get(), args.back(), args[0], args[1], op.axis);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -22,11 +22,14 @@ ...@@ -22,11 +22,14 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <rocblas/internal/rocblas-types.h>
#include <rocblas/rocblas.h> #include <rocblas/rocblas.h>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/gemm_impl.hpp> #include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/time.hpp> #include <migraphx/time.hpp>
#include <type_traits>
using microseconds = std::chrono::duration<double, std::micro>; using microseconds = std::chrono::duration<double, std::micro>;
...@@ -34,6 +37,20 @@ namespace migraphx { ...@@ -34,6 +37,20 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
/*
Regular rocBLAS API takes compute_type as `rocblas_datatype` enum value v/s "ex3" BETA API takes it
as `rocblas_computetype` enum value. `rb_compute_type` is faciliator to implictly cast integer enum
value to required type that can be used inside `common_args` generator.
*/
struct rb_compute_type
{
int type = 0;
rb_compute_type(rocblas_datatype t) : type(static_cast<int>(t)) {}
rb_compute_type(rocblas_computetype t) : type(static_cast<int>(t)) {}
operator rocblas_datatype() const { return static_cast<rocblas_datatype>(type); }
operator rocblas_computetype() const { return static_cast<rocblas_computetype>(type); }
};
// Convert rocBLAS datatypes to equivalent Migraphx data types // Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype get_type(shape::type_t type) rocblas_datatype get_type(shape::type_t type)
{ {
...@@ -46,7 +63,7 @@ rocblas_datatype get_type(shape::type_t type) ...@@ -46,7 +63,7 @@ rocblas_datatype get_type(shape::type_t type)
case shape::uint8_type: return rocblas_datatype_u8_r; case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r; case shape::int32_type: return rocblas_datatype_i32_r;
case shape::uint32_type: return rocblas_datatype_u32_r; case shape::uint32_type: return rocblas_datatype_u32_r;
case shape::fp8e4m3fnuz_type: case shape::fp8e4m3fnuz_type: return rocblas_datatype_f8_r;
case shape::tuple_type: case shape::tuple_type:
case shape::bool_type: case shape::bool_type:
case shape::uint16_type: case shape::uint16_type:
...@@ -183,12 +200,17 @@ struct gemm_impl ...@@ -183,12 +200,17 @@ struct gemm_impl
{ {
output_type = rocblas_datatype_i32_r; output_type = rocblas_datatype_i32_r;
} }
compute_type = output_type; compute_type = rb_compute_type{output_type};
if(compute_fp32) if(compute_fp32)
{ {
if(arg_type == rocblas_datatype_f16_r) if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r; compute_type = rocblas_datatype_f32_r;
} }
if(arg_type == rocblas_datatype_f8_r)
{
assert(get_type(input_shapes[1].type()) == rocblas_datatype_f8_r);
compute_type = rocblas_compute_type_f32;
}
auto a_lens = input_shapes[0].lens(); auto a_lens = input_shapes[0].lens();
auto b_lens = input_shapes[1].lens(); auto b_lens = input_shapes[1].lens();
...@@ -216,6 +238,34 @@ struct gemm_impl ...@@ -216,6 +238,34 @@ struct gemm_impl
} }
void run(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx = 0) const void run(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx = 0) const
{
#ifdef MIGRAPHX_USE_ROCBLAS_FP8_API
if(rocblas_fp8_available() and
std::any_of(input_args.begin(), input_args.end(), [](const auto i) {
return i.get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
{
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
common_args,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex3,
common_args,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
}
}
else
#endif
{ {
if(strided_batched) if(strided_batched)
{ {
...@@ -236,6 +286,7 @@ struct gemm_impl ...@@ -236,6 +286,7 @@ struct gemm_impl
gemm_flags); gemm_flags);
} }
} }
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API #ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
auto validate(context& ctx, const std::vector<shape>& input_shapes, int32_t solution_idx) const auto validate(context& ctx, const std::vector<shape>& input_shapes, int32_t solution_idx) const
...@@ -331,7 +382,6 @@ struct gemm_impl ...@@ -331,7 +382,6 @@ struct gemm_impl
num_matrices, num_matrices,
compute_type); compute_type);
} }
/** /**
* Helper method to create that subset of a long rocBLAS argument list that is common * Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "gemm_ex..." calls. * to multiple "gemm_ex..." calls.
...@@ -366,6 +416,7 @@ struct gemm_impl ...@@ -366,6 +416,7 @@ struct gemm_impl
ldd, ldd,
compute_type); compute_type);
} }
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API #ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/** /**
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index * Find best rocBLAS solution: Get list of solutions and try them all, returning the index
...@@ -481,8 +532,8 @@ struct gemm_impl ...@@ -481,8 +532,8 @@ struct gemm_impl
rocblas_int b_stride = 0; rocblas_int b_stride = 0;
rocblas_int c_stride = 0; rocblas_int c_stride = 0;
rocblas_int d_stride = 0; rocblas_int d_stride = 0;
rocblas_datatype compute_type = rocblas_datatype_f32_r;
rocblas_datatype arg_type = rocblas_datatype_f32_r; rocblas_datatype arg_type = rocblas_datatype_f32_r;
rb_compute_type compute_type = rocblas_datatype_f32_r;
rocblas_datatype output_type = rocblas_datatype_f32_r; rocblas_datatype output_type = rocblas_datatype_f32_r;
bool strided_batched = true; bool strided_batched = true;
bool is_3inputs = true; bool is_3inputs = true;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_GATHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_GATHER_HPP
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument MIGRAPHX_DEVICE_EXPORT
gather(hipStream_t stream, argument result, argument arg1, argument arg2, int64_t axis);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument MIGRAPHX_DEVICE_EXPORT pad(hipStream_t stream,
argument result,
argument arg1,
float value,
std::vector<std::int64_t> pads);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_GATHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_GATHER_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_gather
{
op::gather op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::gather"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAD_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/pad.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_pad
{
op::pad op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::pad"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -40,6 +40,8 @@ struct context; ...@@ -40,6 +40,8 @@ struct context;
MIGRAPHX_GPU_EXPORT bool get_compute_fp32_flag(); MIGRAPHX_GPU_EXPORT bool get_compute_fp32_flag();
MIGRAPHX_GPU_EXPORT bool rocblas_fp8_available();
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -501,9 +501,7 @@ class numeric_limits<fp8e5m2fnuz> ...@@ -501,9 +501,7 @@ class numeric_limits<fp8e5m2fnuz>
{ {
return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits());
} }
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
// want to make this distinction. For the floating points we would end up using lowest most of
// the times.
static constexpr __device__ fp8e5m2fnuz min() static constexpr __device__ fp8e5m2fnuz min()
{ {
return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits());
...@@ -528,9 +526,7 @@ class numeric_limits<fp8e5m2> ...@@ -528,9 +526,7 @@ class numeric_limits<fp8e5m2>
} }
static constexpr __device__ fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); } static constexpr __device__ fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); }
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
// want to make this distinction. For the floating points we would end up using lowest most of
// the times.
static constexpr __device__ fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); } static constexpr __device__ fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); }
static constexpr __device__ fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); } static constexpr __device__ fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); }
......
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
#include <mlir-c/Pass.h> #include <mlir-c/Pass.h>
#include <mlir-c/Support.h> #include <mlir-c/Support.h>
#include <mutex> #include <mutex>
#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3 #if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 4
#warning "Incompatible version of rocMLIR library used, disabling" #warning "Incompatible version of rocMLIR library used, disabling"
// Only undefine when not using cppcheck // Only undefine when not using cppcheck
#ifndef CPPCHECK #ifndef CPPCHECK
...@@ -319,31 +319,30 @@ struct mlir_program ...@@ -319,31 +319,30 @@ struct mlir_program
return result; return result;
} }
MlirType make_tensor(const shape& s) const MlirType make_mlir_shaped(const shape& s) const
{ {
if(not s.standard())
MIGRAPHX_THROW("MLIR expects all tensors to be in standard shape");
if(s.dynamic()) if(s.dynamic())
MIGRAPHX_THROW("MLIR does not support dynamic shapes"); MIGRAPHX_THROW("MLIR does not support dynamic shapes");
std::vector<int64_t> lens(s.lens().begin(), s.lens().end()); std::vector<int64_t> lens(s.lens().begin(), s.lens().end());
return mlirRankedTensorTypeGet( std::vector<int64_t> strides(s.strides().begin(), s.strides().end());
lens.size(), lens.data(), make_type(s.type()), mlirAttributeGetNull()); return rocmlirMIXRShapedTypeGet(
lens.size(), lens.data(), strides.data(), make_type(s.type()));
} }
template <class Range> template <class Range>
std::vector<MlirType> make_tensors(const Range& r) std::vector<MlirType> make_mlir_shapeds(const Range& r)
{ {
std::vector<MlirType> result; std::vector<MlirType> result;
std::transform(r.begin(), r.end(), std::back_inserter(result), [&](const auto& s) { std::transform(r.begin(), r.end(), std::back_inserter(result), [&](const auto& s) {
return make_tensor(s); return make_mlir_shaped(s);
}); });
return result; return result;
} }
MlirType make_function_type(const std::vector<shape>& inputs, const std::vector<shape>& outputs) MlirType make_function_type(const std::vector<shape>& inputs, const std::vector<shape>& outputs)
{ {
auto in = make_tensors(inputs); auto in = make_mlir_shapeds(inputs);
auto out = make_tensors(outputs); auto out = make_mlir_shapeds(outputs);
return mlirFunctionTypeGet(ctx.get(), in.size(), in.data(), out.size(), out.data()); return mlirFunctionTypeGet(ctx.get(), in.size(), in.data(), out.size(), out.data());
} }
...@@ -505,11 +504,7 @@ struct mlir_program ...@@ -505,11 +504,7 @@ struct mlir_program
mlir_operation_state& add_results(const std::vector<shape>& outputs) mlir_operation_state& add_results(const std::vector<shape>& outputs)
{ {
std::vector<shape> reshaped(outputs.size()); auto x = prog->make_mlir_shapeds(outputs);
std::transform(outputs.begin(), outputs.end(), reshaped.begin(), [](const shape& r) {
return shape{r.type(), r.lens()};
});
auto x = prog->make_tensors(reshaped);
if(not x.empty()) if(not x.empty())
{ {
mlirOperationStateAddResults(&op_state, x.size(), x.data()); mlirOperationStateAddResults(&op_state, x.size(), x.data());
...@@ -582,7 +577,7 @@ struct mlir_program ...@@ -582,7 +577,7 @@ struct mlir_program
std::vector<shape> outputs = m.get_output_shapes(); std::vector<shape> outputs = m.get_output_shapes();
std::vector<MlirLocation> arg_locs(inputs.size(), location); std::vector<MlirLocation> arg_locs(inputs.size(), location);
auto body_inputs = make_tensors(inputs); auto body_inputs = make_mlir_shapeds(inputs);
mlir_region region = mlirRegionCreate(); mlir_region region = mlirRegionCreate();
mlir_block fbody = mlirBlockCreate(body_inputs.size(), body_inputs.data(), arg_locs.data()); mlir_block fbody = mlirBlockCreate(body_inputs.size(), body_inputs.data(), arg_locs.data());
MlirBlock result = fbody.get(); MlirBlock result = fbody.get();
...@@ -608,7 +603,7 @@ struct mlir_program ...@@ -608,7 +603,7 @@ struct mlir_program
return "func.return"; return "func.return";
if(ins->name() == "@literal") if(ins->name() == "@literal")
{ {
return "tosa.const"; return "migraphx.literal";
} }
return "migraphx." + ins->name(); return "migraphx." + ins->name();
} }
...@@ -667,7 +662,8 @@ struct mlir_program ...@@ -667,7 +662,8 @@ struct mlir_program
if(ins->name() == "@literal") if(ins->name() == "@literal")
{ {
literal r = ins->get_literal(); literal r = ins->get_literal();
MlirType tensor_type = make_tensor(ins->get_shape()); MlirType shaped_type = make_mlir_shaped(ins->get_shape());
MlirType tensor_type = rocmlirMIXRShapedTypeAsTensor(shaped_type);
MlirAttribute mlir_value_attr = MlirAttribute mlir_value_attr =
mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data()); mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data());
ops.add_attributes({{"value", mlir_value_attr}}); ops.add_attributes({{"value", mlir_value_attr}});
...@@ -945,35 +941,7 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs) ...@@ -945,35 +941,7 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
auto param = m.get_parameter(name); auto param = m.get_parameter(name);
if(input.standard()) if(input.standard())
continue; continue;
auto lens = input.lens(); auto new_param = m.add_parameter(name + ".0", input);
auto strides = input.strides();
std::vector<operation> ops;
if(input.transposed())
{
auto perm = find_permutation(input);
auto iperm = invert_permutation(perm);
lens = reorder_dims(lens, iperm);
strides = reorder_dims(strides, iperm);
ops.push_back(make_op("transpose", {{"permutation", perm}}));
}
if(input.broadcasted())
{
std::transform(lens.begin(),
lens.end(),
strides.begin(),
lens.begin(),
[](auto len, auto stride) -> std::size_t {
if(stride == 0)
return 1;
return len;
});
ops.push_back(make_op("multibroadcast", {{"out_lens", input.lens()}}));
}
auto new_param =
std::accumulate(ops.begin(),
ops.end(),
m.add_parameter(name + ".0", shape{input.type(), lens}),
[&](auto x, auto op) { return m.insert_instruction(param, op, x); });
m.replace_instruction(param, new_param); m.replace_instruction(param, new_param);
m.remove_instruction(param); m.remove_instruction(param);
} }
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/pad.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/pad.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_pad::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
check_shapes{inputs, *this}.has(1).standard();
return op.compute_shape(inputs);
}
argument hip_pad::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
return device::pad(ctx.get_stream().get(), args.back(), args.front(), op.value, op.pads);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -53,6 +53,16 @@ bool get_compute_fp32_flag() ...@@ -53,6 +53,16 @@ bool get_compute_fp32_flag()
return (starts_with(device_name, "gfx9") and device_name >= "gfx908"); return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
} }
bool rocblas_fp8_available()
{
#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
return false;
#else
const auto device_name = trim(split_string(get_device_name(), ':').front());
return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
#endif
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -105,6 +105,21 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -105,6 +105,21 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::uint8_type); unsupported_types.erase(shape::type_t::uint8_type);
unsupported_types.erase(shape::type_t::int32_type); unsupported_types.erase(shape::type_t::int32_type);
unsupported_types.erase(shape::type_t::tuple_type); unsupported_types.erase(shape::type_t::tuple_type);
std::set<std::string> unsupported_fp8_ops = {};
if(not gpu::rocblas_fp8_available())
{
unsupported_fp8_ops.insert("dot");
}
// add all device kernels
unsupported_fp8_ops.insert("logsoftmax");
unsupported_fp8_ops.insert("nonzero");
unsupported_fp8_ops.insert("prefix_scan_sum");
unsupported_fp8_ops.insert("scatter_none");
unsupported_fp8_ops.insert("topk");
unsupported_fp8_ops.insert("rnn_var_sl_shift_output");
unsupported_fp8_ops.insert("multinomial");
unsupported_fp8_ops.insert("argmax");
unsupported_fp8_ops.insert("argmin");
// clang-format off // clang-format off
return return
{ {
...@@ -136,6 +151,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -136,6 +151,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops{}, prefuse_ops{},
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{}, auto_contiguous{},
eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8_ops},
dead_code_elimination{},
optimize_module{}, optimize_module{},
fuse_pointwise{}, fuse_pointwise{},
dead_code_elimination{}, dead_code_elimination{},
......
...@@ -141,9 +141,9 @@ TEST_CASE(conv) ...@@ -141,9 +141,9 @@ TEST_CASE(conv)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_convolution(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> %0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
return %0 : tensor<1x2x2x2xf32> return %0 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>
} }
} }
)__migraphx__"; )__migraphx__";
...@@ -160,15 +160,38 @@ module { ...@@ -160,15 +160,38 @@ module {
EXPECT(verify_mlir(m)); EXPECT(verify_mlir(m));
} }
TEST_CASE(conv_nhwc)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x1x24x8>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x1x32x8>) -> !migraphx.shaped<1x2x2x2xf32, 8x1x4x2> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x1x32x8>, <2x8x3x3xf32, 72x1x24x8> -> <1x2x2x2xf32, 8x1x4x2>
return %0 : !migraphx.shaped<1x2x2x2xf32, 8x1x4x2>
}
}
)__migraphx__";
migraphx::module m;
auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}, {128, 1, 32, 8}});
auto w = m.add_parameter("w", {migraphx::shape::float_type, {2, 8, 3, 3}, {72, 1, 24, 8}});
auto conv = m.add_instruction(migraphx::make_op("convolution"), x, w);
m.add_return({conv});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
TEST_CASE(conv_add_relu) TEST_CASE(conv_add_relu)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_convolution_add_relu(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { func.func @mlir_convolution_add_relu(%arg0: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg2: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> %0 = migraphx.convolution %arg2, %arg1 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
%1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> %1 = migraphx.add %0, %arg0 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
%2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> %2 = migraphx.relu %1 : <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
return %2 : tensor<1x2x2x2xf32> return %2 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>
} }
} }
)__migraphx__"; )__migraphx__";
...@@ -192,10 +215,10 @@ TEST_CASE(quant_dot_add) ...@@ -192,10 +215,10 @@ TEST_CASE(quant_dot_add)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_quant_dot_add(%arg0: tensor<1x5x4xi8>, %arg1: tensor<1x4x3xi8>, %arg2: tensor<1x5x3xi32>) -> tensor<1x5x3xi32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi32, 15x3x1>) -> !migraphx.shaped<1x5x3xi32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.quant_dot(%arg0, %arg1) : (tensor<1x5x4xi8>, tensor<1x4x3xi8>) -> tensor<1x5x3xi32> %0 = migraphx.quant_dot %arg0, %arg1 : <1x5x4xi8, 20x4x1>, <1x4x3xi8, 12x3x1> -> <1x5x3xi32, 15x3x1>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xi32>, tensor<1x5x3xi32>) -> tensor<1x5x3xi32> %1 = migraphx.add %0, %arg2 : <1x5x3xi32, 15x3x1>, <1x5x3xi32, 15x3x1> -> <1x5x3xi32, 15x3x1>
return %1 : tensor<1x5x3xi32> return %1 : !migraphx.shaped<1x5x3xi32, 15x3x1>
} }
} }
)__migraphx__"; )__migraphx__";
...@@ -219,10 +242,10 @@ TEST_CASE(dot_add) ...@@ -219,10 +242,10 @@ TEST_CASE(dot_add)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_dot_add(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { func.func @mlir_dot_add(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32> %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32> %1 = migraphx.add %0, %arg2 : <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
return %1 : tensor<1x5x3xf32> return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
} }
} }
)__migraphx__"; )__migraphx__";
...@@ -245,11 +268,11 @@ TEST_CASE(conv_int8_dequantize_quantize) ...@@ -245,11 +268,11 @@ TEST_CASE(conv_int8_dequantize_quantize)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: tensor<2x8x3x3xi8>, %arg1: tensor<1x8x4x4xi8>, %arg2: tensor<1x2x2x2xf32>, %arg3: tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: !migraphx.shaped<2x8x3x3xi8, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>, %arg2: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg3: !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>) -> !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.quant_convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xi8>, tensor<2x8x3x3xi8>) -> tensor<1x2x2x2xi32> %0 = migraphx.quant_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x3x3xi8, 72x9x3x1> -> <1x2x2x2xi32, 8x4x2x1>
%1 = migraphx.dequantizelinear(%0, %arg2, %arg3) : (tensor<1x2x2x2xi32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xf32> %1 = migraphx.dequantizelinear %0, %arg2, %arg3 : <1x2x2x2xi32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
%2 = migraphx.quantizelinear(%1, %arg2, %arg3) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32> %2 = migraphx.quantizelinear %1, %arg2, %arg3 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xi32, 8x4x2x1>
return %2 : tensor<1x2x2x2xi32> return %2 : !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>
} }
} }
)__migraphx__"; )__migraphx__";
...@@ -278,10 +301,10 @@ TEST_CASE(dot_convert) ...@@ -278,10 +301,10 @@ TEST_CASE(dot_convert)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_dot_convert(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>) -> tensor<1x5x3xf16> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { func.func @mlir_dot_convert(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>) -> !migraphx.shaped<1x5x3xf16, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32> %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
%1 = migraphx.convert(%0) {target_type = 1 : i64} : (tensor<1x5x3xf32>) -> tensor<1x5x3xf16> %1 = migraphx.convert %0 {target_type = 1 : i64} : <1x5x3xf32, 15x3x1> to <1x5x3xf16, 15x3x1>
return %1 : tensor<1x5x3xf16> return %1 : !migraphx.shaped<1x5x3xf16, 15x3x1>
} }
} }
)__migraphx__"; )__migraphx__";
...@@ -304,10 +327,10 @@ TEST_CASE(dot_where) ...@@ -304,10 +327,10 @@ TEST_CASE(dot_where)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_dot_where(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xi8>, %arg3: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32> %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
%1 = migraphx.where(%arg2, %0, %arg3) : (tensor<1x5x3xi8>, tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32> %1 = migraphx.where %arg2, %0, %arg3 : <1x5x3xi8, 15x3x1>, <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
return %1 : tensor<1x5x3xf32> return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
} }
} }
)__migraphx__"; )__migraphx__";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment