"src/vscode:/vscode.git/clone" did not exist on "06fb090581efa63ad6cfff8f08bfec4ea0e857c3"
Unverified Commit 52585d4f authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into enable_navi_32_ci

parents f0370072 d8011adf
......@@ -40,6 +40,8 @@ namespace op {
* 2. use_rank (default) vs use_len:
* `use_rank` sets the max value/index of the attribute as the rank of lens.
* `use_lens` sets the max value/index as the corresponding value in lens at the axes index.
* Uses the dynamic_dimension.max value for dynamic shapes. Returns the original vector
* (no normalization) if any of dynamic_dimension[axes] are not fixed.
* 3. `clip_min` vs. `not_clip_min` (default):
* Clip values less than the minimum to the minimum or not.
* 4. `include_min` vs. `exclude_min` (default):
......
......@@ -30,11 +30,11 @@
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <cmath>
#include <fenv.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct quantizelinear
{
std::string name() const { return "quantizelinear"; }
......@@ -71,26 +71,26 @@ struct quantizelinear
{
y_zero_point = args.at(2);
}
argument result{output_shape};
auto rounding_mode = fegetround();
fesetround(FE_TONEAREST);
visit_all(result, y_zero_point)([&](auto output, auto zero_pts) {
visit_all(x, y_scale)([&](auto input, auto scales) {
using quant_type = typename decltype(output)::value_type;
auto min_value = std::numeric_limits<quant_type>::min();
auto max_value = std::numeric_limits<quant_type>::max();
par_for(output_shape.elements(), [&](auto i) {
int64_t quantized = static_cast<int64_t>(std::round(input[i] / scales[i])) +
int64_t quantized = static_cast<int64_t>(std::nearbyint(input[i] / scales[i])) +
static_cast<int64_t>(zero_pts[i]);
output[i] = std::max(static_cast<int64_t>(min_value),
std::min(static_cast<int64_t>(max_value), quantized));
});
});
});
fesetround(rounding_mode);
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
This diff is collapsed.
......@@ -84,6 +84,7 @@
#include <migraphx/op/mod.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/nearbyint.hpp>
#include <migraphx/op/neg.hpp>
#include <migraphx/op/nonmaxsuppression.hpp>
#include <migraphx/op/nonzero.hpp>
......@@ -110,7 +111,6 @@
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/roialign.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter_add.hpp>
......
......@@ -66,15 +66,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{
if(input_shape.dynamic())
{
// return the unchanged `vec` if the dynamic_dimensions at `axes` are not fixed
if(std::any_of(axes.begin(), axes.end(), [&](auto ax) {
return not input_shape.dyn_dims().at(ax).is_fixed();
}))
{
return vec;
}
std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
const auto& dd = input_shape.dyn_dims().at(i);
if(not dd.is_fixed())
{
MIGRAPHX_THROW(
"NORMALIZE_ATTR: 'use_lens' on a non-fixed dynamic dimension, axis=" +
std::to_string(i));
}
return dd.max;
return input_shape.dyn_dims().at(i).max;
});
}
else
......
......@@ -97,10 +97,11 @@ struct onnx_parser
shape::dynamic_dimension default_dyn_dim_value = {1, 1};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
bool use_dyn_output = false;
bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10;
int64_t opset_version = 13;
bool use_dyn_output = false;
bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10;
int64_t limit_max_iterations = std::numeric_limits<uint16_t>::max();
int64_t opset_version = 13;
std::unordered_map<std::string, op_func> ops;
......
......@@ -67,6 +67,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
}
parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations;
parser.limit_max_iterations = options.limit_max_iterations;
parser.use_dyn_output = options.use_dyn_output;
if(options.print_program_on_error)
......
......@@ -60,7 +60,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Neg", "neg"},
{"Reciprocal", "recip"},
{"Relu", "relu"},
{"Round", "round"},
{"Round", "nearbyint"},
{"Sigmoid", "sigmoid"},
{"Sign", "sign"},
{"Sin", "sin"},
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_isinf : op_parser<parse_isinf>
{
std::vector<op_desc> operators() const { return {{"IsInf", "isinf"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
bool detect_negative = true;
bool detect_positive = true;
if(contains(info.attributes, "detect_negative"))
{
detect_negative = static_cast<bool>(
parser.parse_value(info.attributes.at("detect_negative")).at<int>());
}
if(contains(info.attributes, "detect_positive"))
{
detect_positive = static_cast<bool>(
parser.parse_value(info.attributes.at("detect_positive")).at<int>());
}
auto x_shape = args[0]->get_shape();
if(not detect_negative and not detect_positive)
{
return info.add_instruction(
make_op("multibroadcast", {{"out_lens", x_shape.lens()}}),
info.add_literal(migraphx::literal{migraphx::shape{shape::bool_type}, {false}}));
}
auto is_inf = info.add_instruction(make_op("isinf"), args[0]);
if(detect_negative and detect_positive)
{
return is_inf;
}
auto zero_l = info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {0}});
auto mb_zero =
info.add_instruction(make_op("multibroadcast", {{"out_lens", x_shape.lens()}}), zero_l);
auto cond = info.add_broadcastable_binary_op(
detect_negative ? "less" : "greater", args[0], mb_zero);
if(cond->get_shape().type() != shape::bool_type)
{
cond =
info.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), cond);
}
return info.add_instruction(make_op("logical_and"), is_inf, cond);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -58,6 +58,16 @@ struct parse_loop : op_parser<parse_loop>
}
}
// cap max_iter because loop uses static shapes with max_iter size and huge numbers
// here can cause overflow
if(max_iterations > parser.limit_max_iterations)
{
std::cerr << "WARNING: PARSE_LOOP max_iterations exceeds the maximum loop "
"iterations limit, it will be changed from "
<< max_iterations << " to " << parser.limit_max_iterations << ".\n";
max_iterations = parser.limit_max_iterations;
}
// condition input is empty
if(args.at(1)->name() == "undefined")
{
......
......@@ -181,6 +181,76 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
return nearest_mode;
}
static std::vector<double> get_scales(const onnx_parser::attribute_map& attr)
{
std::vector<double> scales;
if(contains(attr, "scales"))
{
copy(attr.at("scales").floats(), std::back_inserter(scales));
}
return scales;
}
static void parse_args(const std::vector<instruction_ref>& args,
const std::vector<size_t>& in_lens,
const std::string& op_name,
std::vector<double>& vec_scale,
std::vector<std::size_t>& out_lens)
{
for(const auto& arg : args)
{
if(arg->name() == "undefined" or arg == args.front())
{
continue;
}
// skipped empty input
auto lens = arg->get_shape().lens();
if(lens.empty())
{
continue;
}
auto type = arg->get_shape().type();
// output size
if(type == shape::int64_type)
{
auto arg_out_s = arg->eval();
check_arg_empty(arg_out_s,
"PARSE_" + op_name + ": dynamic output size is not supported!");
arg_out_s.visit([&](const auto& ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_lens.size())
{
MIGRAPHX_THROW("PARSE_" + op_name +
": specified output size does not match input size");
}
// compute the scale
vec_scale.resize(in_lens.size());
std::transform(in_lens.begin(),
in_lens.end(),
out_lens.begin(),
vec_scale.begin(),
[](auto iss, auto oss) { return 1.0 * oss / iss; });
}
else
{
// scale input
if(lens[0] == in_lens.size())
{
auto arg_scale = arg->eval();
check_arg_empty(arg_scale,
"PARSE_" + op_name + ": dynamic input scale is not supported!");
arg_scale.visit([&](const auto& v) { vec_scale.assign(v.begin(), v.end()); });
}
}
}
}
struct parse_resize : op_parser<parse_resize>
{
std::vector<op_desc> operators() const { return {{"Resize"}, {"Upsample"}}; }
......@@ -214,72 +284,30 @@ struct parse_resize : op_parser<parse_resize>
std::vector<std::size_t> out_lens(in_lens.size());
// scale
std::vector<double> vec_scale;
std::vector<double> vec_scale = get_scales(info.attributes);
for(const auto& arg : args)
// If `scales` was not an attribute, it must be an input
if(vec_scale.empty())
{
if(arg->name() == "undefined" or arg == args.front())
{
continue;
}
// skipped empty input
auto lens = arg->get_shape().lens();
if(lens.empty())
{
continue;
}
auto type = arg->get_shape().type();
// output size
if(type == shape::int64_type)
{
auto arg_out_s = arg->eval();
check_arg_empty(arg_out_s,
"PARSE_" + opd.op_name + ": dynamic output size is not supported!");
arg_out_s.visit([&](const auto& ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_lens.size())
{
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": specified output size does not match input size");
}
// Depending on the args, it *must* populate the `vec_scale`, and might populate
// `out_lens`
parse_args(args, in_lens, opd.op_name, vec_scale, out_lens);
}
// compute the scale
vec_scale.resize(in_lens.size());
std::transform(in_lens.begin(),
in_lens.end(),
out_lens.begin(),
vec_scale.begin(),
[](auto iss, auto oss) { return 1.0 * oss / iss; });
}
else
{
if(in_lens.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_" + opd.op_name + ": ranks of input and scale are different!");
}
// scale input
if(lens[0] == in_lens.size())
{
auto arg_scale = arg->eval();
check_arg_empty(arg_scale,
"PARSE_" + opd.op_name +
": dynamic input scale is not supported!");
arg_scale.visit([&](const auto& v) { vec_scale.assign(v.begin(), v.end()); });
if(in_lens.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": ranks of input and scale are different!");
}
std::transform(in_lens.begin(),
in_lens.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) {
return static_cast<std::size_t>(idx * scale);
});
}
}
// if the output was not calculated yet, we update it based on the scales
if(all_of(out_lens.cbegin(), out_lens.cend(), [](auto o) { return o == 0; }))
{
std::transform(
in_lens.begin(),
in_lens.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
}
shape out_s{in_s.type(), out_lens};
......@@ -288,7 +316,6 @@ struct parse_resize : op_parser<parse_resize>
// reshape input to one-dimension
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
args[0] = info.make_contiguous(args[0]);
auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
if(mode == "nearest")
......
......@@ -46,6 +46,9 @@ struct parse_slice : op_parser<parse_slice>
void always_insert(instruction_ref arg) { op_args.insert(op_args.begin(), arg); }
/**
* Either insert argument into `this->op_args` or return the constant value of the argument
*/
std::vector<int64_t> insert(instruction_ref arg)
{
std::vector<int64_t> result;
......@@ -144,16 +147,15 @@ struct parse_slice : op_parser<parse_slice>
sd.op.axes = axes;
}
if(not sd.steps.empty())
if(std::any_of(sd.steps.begin(), sd.steps.end(), [](auto s) { return s != 1; }))
{
if(sd.op.starts.empty() or sd.op.ends.empty())
MIGRAPHX_THROW("PARSE_SLICE: steps and variable starts and ends is not supported");
MIGRAPHX_THROW(
"PARSE_SLICE: steps and variable starts and/or ends is not supported");
if(sd.op.axes.empty())
MIGRAPHX_THROW("PARSE_SLICE: steps and variable axes is not supported");
}
assert(sd.steps.empty() or sd.steps.size() == sd.op.axes.size());
// If any axes have negative step, prepare to add a "reverse" op
for(auto i : range(sd.steps.size()))
{
......
......@@ -472,7 +472,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
map_dyn_input_dims,
bool skip_unknown_operators,
bool print_program_on_error,
int64_t max_loop_iterations) {
int64_t max_loop_iterations,
int64_t limit_max_iterations) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.default_dyn_dim_value = default_dyn_dim_value;
......@@ -481,6 +482,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations;
options.limit_max_iterations = limit_max_iterations;
return migraphx::parse_onnx(filename, options);
},
"Parse onnx file",
......@@ -492,7 +494,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
py::arg("max_loop_iterations") = 10,
py::arg("limit_max_iterations") = std::numeric_limits<uint16_t>::max());
m.def(
"parse_onnx_buffer",
......
......@@ -47,7 +47,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x);
}
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
auto add_zero_point = m.insert_instruction(ins, make_op("nearbyint"), div);
if(ins->inputs().size() == 3)
{
......
......@@ -24,6 +24,7 @@
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -131,10 +132,53 @@ struct find_const_4in_slice
}
};
/**
* Simplify dimensions_of to a literal when the input arugment has a static shape
* or the dynamic dimensions from `start` to `end` are fixed.
*/
struct find_static_dimensions_of
{
auto matcher() const { return match::name("dimensions_of")(); }
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto input = ins->inputs().at(0);
auto dimensions_of_value = ins->get_operator().to_value();
auto start = dimensions_of_value.at("start").to<std::size_t>();
auto end = dimensions_of_value.at("end").to<std::size_t>();
if(input->get_shape().dynamic())
{
// check if dynamic dimensions from start to end are fixed
auto dds = input->get_shape().dyn_dims();
if(std::any_of(dds.begin() + start, dds.begin() + end, [](auto dd) {
return not dd.is_fixed();
}))
{
return;
}
}
std::size_t output_ndim = end - start;
std::vector<int64_t> vec_shape(output_ndim);
migraphx::shape s(migraphx::shape::int64_type, {output_ndim});
std::vector<std::size_t> input_lens = input->get_shape().to_static(1).lens();
std::transform(input_lens.begin() + start,
input_lens.begin() + end,
vec_shape.begin(),
[](auto i) { return int64_t(i); });
migraphx::shape output_shape{migraphx::shape::int64_type, {end - start}};
auto lit_ins = m.add_literal(migraphx::literal{output_shape, vec_shape});
m.replace_instruction(ins, lit_ins);
}
};
void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(
m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{});
match::find_matches(m,
find_static_2in_broadcasts{},
find_static_dimensions_of{},
find_const_3in_slice{},
find_const_4in_slice{});
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -647,8 +647,8 @@ struct find_broadcast_transpose
{
auto transpose = r.result;
auto transpose_lens = transpose->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front();
auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front();
// scalar transformation does not need extra transpose
if(not input->get_shape().scalar())
{
......
# ####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
......@@ -245,10 +245,14 @@ else()
endif()
# Check miopen find mode api
include(CheckLibraryExists)
get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
get_target_property(ROCBLAS_LOCATION roc::rocblas LOCATION)
check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# Beta API for automated GEMM tuning
check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
......@@ -271,6 +275,13 @@ else()
message(STATUS "MIOpen does not have find mode api")
endif()
if(HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_TUNING_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphx is using Beta API of rocBLAS")
else()
message(STATUS "rocBLAS does not have User Tuning Beta API")
endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
......
......@@ -168,6 +168,7 @@ struct compile_plan
}
const compiled_result& benchmark(problem_cache& pc) const
{
const auto trace_level = value_of(MIGRAPHX_TRACE_BENCHMARKING{});
if(results.empty())
MIGRAPHX_THROW("No configs to tune");
if(results.size() == 1)
......@@ -178,9 +179,10 @@ struct compile_plan
}
if(not config)
MIGRAPHX_THROW("Multiple kernels without config");
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl;
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
if(trace_level > 0)
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl;
if(trace_level > 1)
std::cout << "Problem: " << config->problem << std::endl;
std::vector<double> times;
times.reserve(results.size());
......@@ -189,22 +191,23 @@ struct compile_plan
config->solutions.begin(),
std::back_inserter(times),
[&](const auto& cr, const auto& solution) {
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
if(trace_level > 1)
std::cout << "Benchmarking solution: " << solution << std::endl;
if(not cr.has_value())
{
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
if(trace_level > 1)
std::cout << "No binary" << std::endl;
return std::numeric_limits<double>::max();
}
auto t = time_op(
*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20);
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
if(trace_level > 1)
std::cout << t << "ms" << std::endl;
return t;
});
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
if(trace_level > 0)
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
pc.insert(preop.name(), config->problem, config->solutions.at(i));
if(not results[i].has_value())
MIGRAPHX_THROW("No valid tuned compilation.");
......
This diff is collapsed.
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -40,9 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
void blas_shape(const shape& s);
shape transpose_batch(const shape& s, unsigned trans_batch);
void blas_shape(const shape& s);
template <class Op>
struct rocblas_gemm
......@@ -52,6 +51,7 @@ struct rocblas_gemm
float beta = 0;
bool compute_fp32 = false;
unsigned trans_batch = 0;
int32_t solution_idx = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -60,7 +60,8 @@ struct rocblas_gemm
pack(f(self.alpha, "alpha"),
f(self.beta, "beta"),
f(self.compute_fp32, "compute_fp32"),
f(self.trans_batch, "trans_batch")));
f(self.trans_batch, "trans_batch"),
f(self.solution_idx, "solution_idx")));
}
std::string name() const
......@@ -76,6 +77,8 @@ struct rocblas_gemm
{
std::vector<shape> in_shapes(inputs);
in_shapes.pop_back();
// When input shapes are A, B, C the GEMM equation is C  =  α AB+ β C where α, β are
// scalars
check_shapes{in_shapes, *this}.has(2, 3);
blas_shape(inputs[0]);
blas_shape(inputs[1]);
......@@ -111,11 +114,12 @@ struct rocblas_gemm
{
if(this->name() == "gpu::gemm")
{
gemm(ctx, output_shape, args, alpha, beta, compute_fp32);
gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
}
else
{
gemm(ctx, output_shape, args, int32_t(alpha), int32_t(beta), compute_fp32);
gemm_compute(
ctx, output_shape, args, int32_t(alpha), int32_t(beta), compute_fp32, solution_idx);
}
return args.back();
}
......@@ -124,6 +128,33 @@ struct rocblas_gemm
{
return shapes.size() - 1;
}
void finalize(context& ctx, const shape& output_shape, const std::vector<shape>& input_shapes)
{
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
if(enabled(MIGRAPHX_ENABLE_GEMM_TUNING{}) or ctx.get_exhaustive_tune_flag())
{
if(this->name() == "gpu::gemm")
{
solution_idx = gemm_finalize(
ctx, output_shape, input_shapes, alpha, beta, compute_fp32, solution_idx);
}
else
{
solution_idx = gemm_finalize(ctx,
output_shape,
input_shapes,
int32_t(alpha),
int32_t(beta),
compute_fp32,
solution_idx);
}
}
#else
// suppress compiler warnings
(void)ctx, (void)output_shape, (void)input_shapes;
#endif
}
};
} // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment