"IMG/cpio/ventoy/hook/ploplinux/disk_hook.sh" did not exist on "d5b829f8e8c8367b032b4bb57a8fc37701d42e17"
Unverified Commit 0662a9a3 authored by Brian Pickrell's avatar Brian Pickrell Committed by GitHub
Browse files

Merge branch 'develop' into dyn_resize_gather

parents b74d3a8f 35e5298e
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -24,28 +24,64 @@ ...@@ -24,28 +24,64 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP #define MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
#include <iterator>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
// Set this environment variable to "true" to perform GEMM tuning even when the
// --exhaustive-tune option isn't set. Can be used to skip slow convolution tuning.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_GEMM_TUNING);
using milliseconds = std::chrono::duration<double, std::milli>;
using microseconds = std::chrono::duration<double, std::micro>;
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
void gemm(context& ctx, /**
const shape& output_shape, * @brief Templated implementations of the compute() and finalize() methods of the Gemm operator.
const std::vector<argument>& args, * For each function there are overloads using either float or int32_t for the arguments
float alpha, * alpha and beta.
float beta, *
bool int8_x4_format, * @param ctx .
bool compute_fp32); * @param output_shape .
void gemm(context& ctx, * @param args .
const shape& output_shape, * @param alpha .
const std::vector<argument>& args, * @param beta .
int32_t alpha, * @param compute_fp32 .
int32_t beta, */
bool int8_x4_format, void gemm_compute(context& ctx,
bool compute_fp32); const shape& output_shape,
const std::vector<argument>& args,
float alpha,
float beta,
bool compute_fp32,
int32_t solution_idx);
void gemm_compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
int32_t alpha,
int32_t beta,
bool compute_fp32,
int32_t solution_idx);
int32_t gemm_finalize(context& ctx,
const shape& output_shape,
const std::vector<shape>& input_shapes,
float alpha,
float beta,
bool compute_fp32);
int32_t gemm_finalize(context& ctx,
const shape& output_shape,
const std::vector<shape>& input_shapes,
int32_t alpha,
int32_t beta,
bool compute_fp32,
int32_t solution_idx);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct miopen_int8_conv_pack
{
std::string name() const { return "gpu::int8_conv_pack"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -127,7 +127,7 @@ inline void set_tensor_descriptor(miopenTensorArgumentId_t name, ...@@ -127,7 +127,7 @@ inline void set_tensor_descriptor(miopenTensorArgumentId_t name,
} }
#endif #endif
inline tensor_descriptor make_tensor(const migraphx::shape& os, bool pack = false) inline tensor_descriptor make_tensor(const migraphx::shape& os)
{ {
auto s = os.normalize_standard(); auto s = os.normalize_standard();
auto t = make_obj<tensor_descriptor>(&miopenCreateTensorDescriptor); auto t = make_obj<tensor_descriptor>(&miopenCreateTensorDescriptor);
...@@ -142,23 +142,9 @@ inline tensor_descriptor make_tensor(const migraphx::shape& os, bool pack = fals ...@@ -142,23 +142,9 @@ inline tensor_descriptor make_tensor(const migraphx::shape& os, bool pack = fals
else if(s.type() == shape::int32_type) else if(s.type() == shape::int32_type)
d = miopenInt32; d = miopenInt32;
else if(s.type() == shape::int8_type) else if(s.type() == shape::int8_type)
{ d = miopenInt8;
if(pack)
{
// update the lens and corresponding strides
d = miopenInt8x4;
lens[1] = ((lens[1] + 3) / 4) * 4;
strides[0] = strides[1] * lens[1];
}
else
{
d = miopenInt8;
}
}
else else
{
MIGRAPHX_THROW("MAKE_TENSOR: unsupported type"); MIGRAPHX_THROW("MAKE_TENSOR: unsupported type");
}
miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data()); miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());
return t; return t;
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -40,8 +40,6 @@ struct context; ...@@ -40,8 +40,6 @@ struct context;
MIGRAPHX_GPU_EXPORT bool get_compute_fp32_flag(); MIGRAPHX_GPU_EXPORT bool get_compute_fp32_flag();
MIGRAPHX_GPU_EXPORT bool get_int8_x4_format(context& ctx);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape pack_int8_shape(const shape& s)
{
if(s.type() != shape::int8_type)
{
MIGRAPHX_THROW("PACK_INT8_ARGS: only process int8_type");
}
auto lens = s.lens();
auto strides = s.strides();
lens[1] = (lens[1] + 3) / 4 * 4;
strides[0] = strides[1] * lens[1];
return {s.type(), lens, strides};
}
shape miopen_int8_conv_pack::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{{inputs.at(0)}, *this}.has(1).standard();
return pack_int8_shape(inputs.at(0));
}
argument
miopen_int8_conv_pack::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
auto arg_desc = make_tensor(args[0].get_shape());
auto arg_desc_vec4 = make_tensor(args[0].get_shape(), true);
float alpha = 1;
float beta = 0;
// pack input to vec4 format
auto status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha,
arg_desc.get(),
args[0].implicit(),
&beta,
arg_desc_vec4.get(),
args[1].implicit());
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("INT8_CONV_PACK: transform input tensor failed");
}
return args[1];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/device/int8_gemm_pack.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_int8_gemm_pack_a::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{{inputs.at(0)}, *this}.has(1).not_broadcasted().packed();
return inputs.at(0);
}
argument
hip_int8_gemm_pack_a::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::int8_gemm_pack_a(ctx.get_stream().get(), args[1], args[0]);
return args[1];
}
shape hip_int8_gemm_pack_b::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{{inputs.at(0)}, *this}.has(1).not_broadcasted().packed();
return inputs.at(0);
}
argument
hip_int8_gemm_pack_b::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::int8_gemm_pack_b(ctx.get_stream().get(), args[1], args[0]);
return args[1];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -101,7 +101,9 @@ MIGRAPHX_DEVICE_MATH(erf, ::erf) ...@@ -101,7 +101,9 @@ MIGRAPHX_DEVICE_MATH(erf, ::erf)
MIGRAPHX_DEVICE_MATH(exp, ::exp) MIGRAPHX_DEVICE_MATH(exp, ::exp)
MIGRAPHX_DEVICE_MATH(floor, ::floor) MIGRAPHX_DEVICE_MATH(floor, ::floor)
MIGRAPHX_DEVICE_MATH(isnan, ::isnan) MIGRAPHX_DEVICE_MATH(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH(isinf, ::isinf)
MIGRAPHX_DEVICE_MATH(log, ::log) MIGRAPHX_DEVICE_MATH(log, ::log)
MIGRAPHX_DEVICE_MATH(nearbyint, ::nearbyint)
MIGRAPHX_DEVICE_MATH(pow, ::pow) MIGRAPHX_DEVICE_MATH(pow, ::pow)
MIGRAPHX_DEVICE_MATH(remainder, ::remainder) MIGRAPHX_DEVICE_MATH(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH(round, ::round) MIGRAPHX_DEVICE_MATH(round, ::round)
...@@ -135,6 +137,7 @@ MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, ceil, ::hceil) ...@@ -135,6 +137,7 @@ MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, ceil, ::hceil)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, cos, ::hcos) MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, cos, ::hcos)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, exp, ::hexp) MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, exp, ::hexp)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor) MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isinf, ::__hisinf)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isnan, ::__hisnan) MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isnan, ::__hisnan)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog) MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt) MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt)
...@@ -150,6 +153,7 @@ MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan) ...@@ -150,6 +153,7 @@ MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan)
MIGRAPHX_DEVICE_MATH_HALF(atanh, ::atanh) MIGRAPHX_DEVICE_MATH_HALF(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh) MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf) MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
MIGRAPHX_DEVICE_MATH_HALF(nearbyint, ::nearbyint)
MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow) MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
MIGRAPHX_DEVICE_MATH_HALF(remainder, ::remainder) MIGRAPHX_DEVICE_MATH_HALF(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH_HALF(round, ::round) MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
...@@ -229,10 +233,12 @@ MIGRAPHX_DEVICE_MATH_VEC(erf) ...@@ -229,10 +233,12 @@ MIGRAPHX_DEVICE_MATH_VEC(erf)
MIGRAPHX_DEVICE_MATH_VEC(exp) MIGRAPHX_DEVICE_MATH_VEC(exp)
MIGRAPHX_DEVICE_MATH_VEC(floor) MIGRAPHX_DEVICE_MATH_VEC(floor)
MIGRAPHX_DEVICE_MATH_VEC(fmod) MIGRAPHX_DEVICE_MATH_VEC(fmod)
MIGRAPHX_DEVICE_MATH_VEC(isinf)
MIGRAPHX_DEVICE_MATH_VEC(isnan) MIGRAPHX_DEVICE_MATH_VEC(isnan)
MIGRAPHX_DEVICE_MATH_VEC(log) MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(max) MIGRAPHX_DEVICE_MATH_VEC(max)
MIGRAPHX_DEVICE_MATH_VEC(min) MIGRAPHX_DEVICE_MATH_VEC(min)
MIGRAPHX_DEVICE_MATH_VEC(nearbyint)
MIGRAPHX_DEVICE_MATH_VEC(pow) MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(remainder) MIGRAPHX_DEVICE_MATH_VEC(remainder)
MIGRAPHX_DEVICE_MATH_VEC(round) MIGRAPHX_DEVICE_MATH_VEC(round)
......
...@@ -61,9 +61,8 @@ struct miopen_apply ...@@ -61,9 +61,8 @@ struct miopen_apply
const lowering* pass = nullptr; const lowering* pass = nullptr;
std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{}; std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{};
instruction_ref last{}; instruction_ref last{};
bool offload_copy = false; bool offload_copy = false;
bool int8_x4_format = true; bool compute_fp32 = false;
bool compute_fp32 = false;
context& get_context() const context& get_context() const
{ {
...@@ -84,10 +83,8 @@ struct miopen_apply ...@@ -84,10 +83,8 @@ struct miopen_apply
assert(mod != nullptr); assert(mod != nullptr);
assert(pass != nullptr); assert(pass != nullptr);
auto& ctx = get_context(); compute_fp32 = get_compute_fp32_flag();
int8_x4_format = get_int8_x4_format(ctx); offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false;
compute_fp32 = get_compute_fp32_flag();
offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false;
add_generic_op("contiguous"); add_generic_op("contiguous");
add_extend_op("argmax"); add_extend_op("argmax");
...@@ -231,18 +228,15 @@ struct miopen_apply ...@@ -231,18 +228,15 @@ struct miopen_apply
assert(refs.size() == 2); assert(refs.size() == 2);
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
refs.push_back(output); refs.push_back(output);
return mod->replace_instruction( return mod->replace_instruction(ins, rocblas_gemm<Op>{Op{}, 1, 0, compute_fp32}, refs);
ins, rocblas_gemm<Op>{Op{}, 1, 0, int8_x4_format, compute_fp32}, refs);
}); });
} }
void add_convolution_op(const std::string& name) void add_convolution_op(const std::string& name)
{ {
apply_map.emplace(name, [=](instruction_ref ins) { apply_map.emplace(name, [=](instruction_ref ins) {
operation conv = make_op( operation conv = make_op("gpu::" + name, {{"op", ins->get_operator().to_value()}});
"gpu::" + name, auto output = insert_allocation(ins, ins->get_shape());
{{"op", ins->get_operator().to_value()}, {"int8_x4_format", int8_x4_format}});
auto output = insert_allocation(ins, ins->get_shape());
return mod->replace_instruction(ins, return mod->replace_instruction(ins,
make_op("gpu::miopen_op", {{"op", to_value(conv)}}), make_op("gpu::miopen_op", {{"op", to_value(conv)}}),
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <iterator>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/permutation.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
static instruction_ref pad_ins(module& m, instruction_ref ins, int offset)
{
auto s = ins->get_shape();
auto lens = s.lens();
auto k = lens[lens.size() + offset];
auto pad_k = (k + 3) / 4 * 4;
auto pad_lens = lens;
pad_lens[lens.size() + offset] = pad_k;
auto ret_ins = ins;
if(pad_k != k)
{
std::vector<int64_t> pad_dims(lens.size() * 2, 0);
pad_dims[lens.size() + offset] = pad_k - k;
shape ps{s.type(), pad_lens};
auto ins_out =
m.insert_instruction(ins, make_op("hip::allocate", {{"shape", to_value(ps)}}));
auto pad = make_op("pad", {{"pads", pad_dims}});
ret_ins =
m.insert_instruction(std::next(ins), make_op("gpu::pad", pad.to_value()), ins, ins_out);
}
return ret_ins;
}
static std::vector<instruction_ref> pad_inputs(module& m, instruction_ref ins)
{
std::vector<instruction_ref> ret_inputs;
auto inputs = ins->inputs();
auto in0 = inputs.at(0);
auto sa = in0->get_shape();
bool transa = sa.transposed();
if(transa)
{
auto perm = find_permutation(sa);
auto val = in0->get_operator().to_value();
if(val.contains("dims"))
{
int offset = static_cast<int>(perm.back()) - static_cast<int>(perm.size());
auto t_in = in0->inputs().front();
auto p_in = pad_ins(m, t_in, offset);
auto dims = val.at("dims").to_vector<int64_t>();
auto r_in =
m.insert_instruction(ins, make_op("transpose", {{"permutation", dims}}), p_in);
ret_inputs.push_back(r_in);
}
else
{
shape cs{in0->get_shape().type(), in0->get_shape().lens()};
auto con_out =
m.insert_instruction(ins, make_op("hip::allocate", {{"shape", to_value(cs)}}));
auto cin0 = m.insert_instruction(ins, make_op("gpu::contiguous"), in0, con_out);
ret_inputs.push_back(pad_ins(m, cin0, -1));
}
}
else
{
ret_inputs.push_back(pad_ins(m, in0, -1));
}
auto in1 = inputs.at(1);
auto sb = in1->get_shape();
bool transb = sb.transposed();
if(transb)
{
auto perm = find_permutation(sb);
auto val = in1->get_operator().to_value();
if(val.contains("dims"))
{
int offset = static_cast<int>(perm[perm.size() - 2]) - static_cast<int>(perm.size());
auto t_in = in1->inputs().front();
auto p_in = pad_ins(m, t_in, offset);
auto dims = val.at("dims").to_vector<int64_t>();
auto r_in =
m.insert_instruction(ins, make_op("transpose", {{"permutation", dims}}), p_in);
ret_inputs.push_back(r_in);
}
else
{
shape cs{in1->get_shape().type(), in1->get_shape().lens()};
auto con_out =
m.insert_instruction(ins, make_op("hip::allocate", {{"shape", to_value(cs)}}));
auto cin1 = m.insert_instruction(ins, make_op("gpu::contiguous"), in1, con_out);
ret_inputs.push_back(pad_ins(m, cin1, -2));
}
}
else
{
ret_inputs.push_back(pad_ins(m, in1, -2));
}
std::copy(inputs.begin() + 2, inputs.end(), std::back_inserter(ret_inputs));
return ret_inputs;
}
void pack_int8_args::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
if(ins->name() == "gpu::quant_gemm")
{
auto val = ins->get_operator().to_value();
assert(val.contains("int8_x4_format"));
if(not val.at("int8_x4_format").to<bool>())
{
continue;
}
auto inputs = ins->inputs();
auto lens = inputs.at(0)->get_shape().lens();
// gemm need the k to be multiple of 4, so need packing that dimension
auto old_inputs = inputs;
if((lens.back() % 4) != 0)
{
inputs = pad_inputs(m, ins);
}
bool transa = inputs[0]->get_shape().transposed();
bool transb = inputs[1]->get_shape().transposed();
if(not transb)
{
auto packed_b = m.insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(inputs[1]->get_shape())}}));
auto output_b = m.insert_instruction(
ins, make_op("gpu::int8_gemm_pack_a"), {inputs[1], packed_b});
inputs[1] = output_b;
}
if(transa)
{
auto packed_a = m.insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(inputs[0]->get_shape())}}));
auto output_a = m.insert_instruction(
ins, make_op("gpu::int8_gemm_pack_b"), {inputs[0], packed_a});
inputs[0] = output_a;
}
if(inputs != old_inputs)
{
m.replace_instruction(ins, ins->get_operator(), inputs);
}
}
else if(ins->name() == "gpu::quant_convolution")
{
auto val = ins->get_operator().to_value();
if(not val.at("int8_x4_format").to<bool>())
{
continue;
}
auto inputs = ins->inputs();
auto packed_x = m.insert_instruction(
ins,
make_op("hip::allocate",
{{"shape", to_value(pack_int8_shape(inputs[0]->get_shape()))}}));
auto output_x =
m.insert_instruction(ins, make_op("gpu::int8_conv_pack"), {inputs[0], packed_x});
instruction::replace_argument(ins, inputs[0], output_x);
auto packed_w = m.insert_instruction(
ins,
make_op("hip::allocate",
{{"shape", to_value(pack_int8_shape(inputs[1]->get_shape()))}}));
auto output_w =
m.insert_instruction(ins, make_op("gpu::int8_conv_pack"), {inputs[1], packed_w});
instruction::replace_argument(ins, inputs[1], output_w);
}
}
}
shape pack_int8_args::pack_int8_shape(const shape& s) const
{
if(s.type() != shape::int8_type)
{
MIGRAPHX_THROW("PACK_INT8_ARGS: only process int8_type");
}
auto lens = s.lens();
auto strides = s.strides();
lens[1] = (lens[1] + 3) / 4 * 4;
strides[0] = strides[1] * lens[1];
return {s.type(), lens, strides};
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -53,19 +53,6 @@ bool get_compute_fp32_flag() ...@@ -53,19 +53,6 @@ bool get_compute_fp32_flag()
return (starts_with(device_name, "gfx9") and device_name >= "gfx908"); return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
} }
bool get_int8_x4_format(context& ctx)
{
#if ROCBLAS_VERSION_MAJOR >= 3
(void)(ctx);
return false;
#else
// int8x4 packed format is only available starting from rocblas-v2.38 and it is deprecated in
// v3.0 and will be removed in v4.0
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
return flag == rocblas_gemm_flags_pack_int8x4;
#endif
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -63,7 +63,6 @@ ...@@ -63,7 +63,6 @@
#include <migraphx/gpu/fuse_ops.hpp> #include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp> #include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/lowering.hpp> #include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/schedule_model.hpp> #include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/gpu/sync_device.hpp> #include <migraphx/gpu/sync_device.hpp>
#include <migraphx/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
...@@ -154,7 +153,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -154,7 +153,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
compile_miopen{&gctx}, compile_miopen{&gctx},
dead_code_elimination{}, dead_code_elimination{},
pack_int8_args{},
dead_code_elimination{}, dead_code_elimination{},
fuse_ops{&ctx, options.fast_math}, fuse_ops{&ctx, options.fast_math},
dead_code_elimination{}, dead_code_elimination{},
......
...@@ -25,5 +25,5 @@ ...@@ -25,5 +25,5 @@
#define MIGRAPHX_VERSION_MAJOR @PROJECT_VERSION_MAJOR@ #define MIGRAPHX_VERSION_MAJOR @PROJECT_VERSION_MAJOR@
#define MIGRAPHX_VERSION_MINOR @PROJECT_VERSION_MINOR@ #define MIGRAPHX_VERSION_MINOR @PROJECT_VERSION_MINOR@
#define MIGRAPHX_VERSION_PATCH @PROJECT_VERSION_PATCH@ #define MIGRAPHX_VERSION_PATCH @PROJECT_VERSION_PATCH@
#define MIGRAPHX_VERSION_TWEAK @PROJECT_VERSION_TWEAK@ #define MIGRAPHX_VERSION_TWEAK "@PROJECT_VERSION_TWEAK@"
// clang-format on // clang-format on
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
cmake_policy(SET CMP0057 NEW) cmake_policy(SET CMP0057 NEW)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
rocm_test_link_libraries(Threads::Threads migraphx migraphx_ref migraphx_onnx migraphx_tf) rocm_test_link_libraries(Threads::Threads migraphx migraphx_onnx migraphx_tf)
rocm_test_include_directories(include) rocm_test_include_directories(include)
set(MIGRAPHX_DISABLE_LARGE_BUFFER_TESTS Off CACHE BOOL "") set(MIGRAPHX_DISABLE_LARGE_BUFFER_TESTS Off CACHE BOOL "")
...@@ -146,7 +146,10 @@ endfunction() ...@@ -146,7 +146,10 @@ endfunction()
function(test_headers PREFIX) function(test_headers PREFIX)
file(GLOB HEADERS CONFIGURE_DEPENDS ${ARGN}) file(GLOB HEADERS CONFIGURE_DEPENDS ${ARGN})
if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
list(REMOVE_ITEM HEADERS
${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/ck.hpp)
endif()
foreach(HEADER ${HEADERS}) foreach(HEADER ${HEADERS})
file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER}) file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER})
string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME) string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME)
......
...@@ -30,6 +30,9 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR) ...@@ -30,6 +30,9 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR}) add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR})
add_dependencies(tests ${NAME}) add_dependencies(tests ${NAME})
add_dependencies(check ${NAME}) add_dependencies(check ${NAME})
if(WIN32)
target_compile_definitions(${NAME} PRIVATE _CRT_SECURE_NO_WARNINGS)
endif()
endfunction() endfunction()
# Workaround: C file dont work with clang-tidy right now, need a fix in rocm-cmake # Workaround: C file dont work with clang-tidy right now, need a fix in rocm-cmake
...@@ -41,6 +44,9 @@ function(add_c_api_test TEST_NAME TEST_SRC TEST_DIR) ...@@ -41,6 +44,9 @@ function(add_c_api_test TEST_NAME TEST_SRC TEST_DIR)
add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR}) add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR})
add_dependencies(tests ${NAME}) add_dependencies(tests ${NAME})
add_dependencies(check ${NAME}) add_dependencies(check ${NAME})
if(WIN32)
target_compile_definitions(${NAME} PRIVATE _CRT_SECURE_NO_WARNINGS)
endif()
endfunction() endfunction()
add_api_test(array_base test_array_base.cpp ${TEST_ONNX_DIR}) add_api_test(array_base test_array_base.cpp ${TEST_ONNX_DIR})
...@@ -57,10 +63,6 @@ add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR}) ...@@ -57,10 +63,6 @@ add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR}) add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR})
# GPU-based tests # GPU-based tests
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
find_package(hip)
add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR}) add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR})
target_link_libraries(test_api_gpu)
add_api_test(custom_op_gpu test_custom_op_gpu.cpp ${TEST_ONNX_DIR}) add_api_test(custom_op_gpu test_custom_op_gpu.cpp ${TEST_ONNX_DIR})
target_link_libraries(test_api_custom_op_gpu)
endif() endif()
...@@ -198,4 +198,29 @@ TEST_CASE(set_loop_default_iter_num) ...@@ -198,4 +198,29 @@ TEST_CASE(set_loop_default_iter_num)
EXPECT(out_shapes[1].lengths() == out_lens1); EXPECT(out_shapes[1].lengths() == out_lens1);
} }
TEST_CASE(set_loop_limit_iterations)
{
migraphx::onnx_options option;
option.set_default_loop_iterations(15);
option.set_limit_loop_iterations(10);
auto p = migraphx::parse_onnx("loop_default_test.onnx", option);
auto out_shapes = p.get_output_shapes();
std::vector<std::size_t> out_lens0 = {1};
EXPECT(out_shapes[0].lengths() == out_lens0);
std::vector<std::size_t> out_lens1 = {10, 1};
EXPECT(out_shapes[1].lengths() == out_lens1);
}
TEST_CASE(set_loop_limit_iterations2)
{
migraphx::onnx_options option;
option.set_limit_loop_iterations(10);
auto p = migraphx::parse_onnx("loop_test_implicit_tripcnt.onnx", option);
auto out_shapes = p.get_output_shapes();
std::vector<std::size_t> out_lens0 = {1};
EXPECT(out_shapes[0].lengths() == out_lens0);
std::vector<std::size_t> out_lens1 = {10, 1};
EXPECT(out_shapes[1].lengths() == out_lens1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -317,4 +317,59 @@ TEST_CASE(loop_test) ...@@ -317,4 +317,59 @@ TEST_CASE(loop_test)
} }
} }
TEST_CASE(loop_test_limit_max_iter)
{
auto run_prog = [&](int64_t limit_max_iterations) {
migraphx::onnx_options parse_options;
parse_options.set_limit_loop_iterations(limit_max_iterations);
auto p = migraphx::parse_onnx("loop_test_implicit_tripcnt.onnx", parse_options);
auto shapes_before = p.get_output_shapes();
migraphx::compile_options options;
options.set_offload_copy();
p.compile(migraphx::target("gpu"), options);
auto shapes_after = p.get_output_shapes();
CHECK(shapes_before.size() == 2);
CHECK(bool{shapes_before.front() == shapes_after.front()});
migraphx::program_parameters pp;
auto param_shapes = p.get_parameter_shapes();
auto aas = param_shapes["a"];
std::vector<float> xd = {1.0f};
pp.add("a", migraphx::argument(aas, xd.data()));
auto bbs = param_shapes["b"];
std::vector<float> yd = {2.0};
pp.add("b", migraphx::argument(bbs, yd.data()));
auto cs = param_shapes["keep_going_cond"];
bool cond = true;
pp.add("keep_going_cond", migraphx::argument(cs, &cond));
auto outputs = p.eval(pp);
auto output = outputs[0];
std::vector<std::vector<float>> ret;
ret.push_back(output.as_vector<float>());
output = outputs[1];
ret.push_back(output.as_vector<float>());
return ret;
};
{
auto result_vector = run_prog(5);
std::vector<float> gold0 = {2.0f};
EXPECT(result_vector.at(0) == gold0);
std::vector<float> gold1 = {-2, 4, 0, 0, 0};
EXPECT(result_vector.at(1) == gold1);
}
{
auto result_vector = run_prog(20);
std::vector<float> gold0 = {2.0f};
EXPECT(result_vector.at(0) == gold0);
std::vector<float> gold1 = {-2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(result_vector.at(1) == gold1);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -64,7 +64,7 @@ TEST_CASE(mul_literal_round_test) ...@@ -64,7 +64,7 @@ TEST_CASE(mul_literal_round_test)
auto l1 = mm->add_literal(1 / 0.00787402f); auto l1 = mm->add_literal(1 / 0.00787402f);
auto mul = mm->add_instruction(migraphx::make_op("mul"), l0, l1); auto mul = mm->add_instruction(migraphx::make_op("mul"), l0, l1);
auto round = mm->add_instruction(migraphx::make_op("round"), mul); auto round = mm->add_instruction(migraphx::make_op("nearbyint"), mul);
mm->add_return({round}); mm->add_return({round});
......
...@@ -152,6 +152,9 @@ TEST_CASE(int_quant_dot_tanh_fails) ...@@ -152,6 +152,9 @@ TEST_CASE(int_quant_dot_tanh_fails)
int main(int argc, const char* argv[]) int main(int argc, const char* argv[])
{ {
test::run(argc, argv); if(migraphx::gpu::mlir_enabled())
{
test::run(argc, argv);
}
return 0; return 0;
} }
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <iostream>
#include <vector>
#include <migraphx/gpu/gemm.hpp>
#include <hip/hip_runtime_api.h>
#include <migraphx/gpu/target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
// includes needed for run_lowering
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
// Abbreviated lowering; we don't need the usual cleanup passes for this test
void run_lowering(migraphx::program& p, bool offload_copy = false)
{
auto ctx = migraphx::gpu::context{};
migraphx::run_passes(
*p.get_main_module(),
{migraphx::auto_contiguous{}, migraphx::gpu::lowering{&ctx, offload_copy}});
}
/**
* Tests the automatic GEMM tuning feature. In the finalize() method of the gemm op,
* rocBLAS API functions are called to quickly benchmark all the GEMM solutions
* available in the currently installed rocBLAS library and choose the index of the fastest.
*/
TEST_CASE(gemm_tune_with_rocblas)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {4, 2}};
migraphx::shape sb{migraphx::shape::float_type, {2, 3}};
auto a = mm->add_parameter("a", sa);
auto b = mm->add_parameter("b", sb);
migraphx::operation dot_op = migraphx::make_op("dot");
mm->add_instruction(dot_op, a, b);
// lowering adds gemm implementation for dot operator
run_lowering(p);
migraphx::target gpu_t = migraphx::gpu::target{};
migraphx::compile_options options;
options.exhaustive_tune = true;
p.compile(gpu_t, options);
migraphx::value solution_idx(0);
for(auto ins : iterator_for(*p.get_main_module()))
{
if(ins->name() == "gpu::gemm")
{
auto gemm_op = migraphx::get_operation(ins);
// tuned solution index is not deterministic, but anything other than 0
// (default, invalid, or not available) is good.
// gemm_op.to_value().debug_print();
solution_idx = gemm_op.to_value()["solution_idx"];
break;
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
EXPECT(0 != solution_idx.to<std::size_t>());
#else
EXPECT(0 == solution_idx.to<std::size_t>());
#endif
}
// GEMM tuning of a strided-batch matrix; invokes rocblas_gemm_strided_batched_ex
TEST_CASE(gemm_tune_strided)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {4, 2, 2}};
migraphx::shape sb{migraphx::shape::float_type, {4, 2, 2}};
migraphx::shape s_output{migraphx::shape::float_type, {4, 2, 2}};
auto a = mm->add_parameter("a", sa);
auto b = mm->add_parameter("b", sb);
auto output = mm->add_parameter("out", s_output);
auto gemm_oper = migraphx::make_op("gpu::gemm", {{"beta", 2}});
mm->add_instruction(gemm_oper, a, b, output);
migraphx::target gpu_t = migraphx::gpu::target{};
migraphx::compile_options options;
options.exhaustive_tune = true;
p.compile(gpu_t, options);
migraphx::value solution_idx(0);
for(auto ins : iterator_for(*p.get_main_module()))
{
if(ins->name() == "gpu::gemm")
{
auto gemm_op = migraphx::get_operation(ins);
auto gemmv = gemm_op.to_value();
// tuned solution index is not deterministic, but anything other than 0
// (default, invalid, or not available) is good.
solution_idx = gemm_op.to_value()["solution_idx"];
break;
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
EXPECT(0 != solution_idx.to<std::size_t>());
#else
EXPECT(0 == solution_idx.to<std::size_t>());
#endif
}
// GEMM tuning of a strided-batch matrix; created by lowering
TEST_CASE(gemm_tune_strided_lowered)
{
migraphx::program p;
auto* mm = p.get_main_module();
// At time of writing this test, gemm_impl considers a shape is strided if it has
// at least three dimensions and the 3rd-to-last is nonzero, invoking
// rocblas_gemm_strided_batched_ex. Also, DOT operator requires all dimensions except the last
// two to be equal.
migraphx::shape sa{migraphx::shape::float_type, {4, 2, 5}};
migraphx::shape sb{migraphx::shape::float_type, {4, 5, 3}};
auto a = mm->add_parameter("a", sa);
auto b = mm->add_parameter("b", sb);
migraphx::operation dot_op = migraphx::make_op("dot");
mm->add_instruction(dot_op, a, b);
// lowering adds gemm implementation for dot operator
run_lowering(p);
migraphx::target gpu_t = migraphx::gpu::target{};
migraphx::compile_options options;
options.exhaustive_tune = true;
p.compile(gpu_t, options);
migraphx::value solution_idx(0);
for(auto ins : iterator_for(*p.get_main_module()))
{
if(ins->name() == "gpu::gemm")
{
auto gemm_op = migraphx::get_operation(ins);
// tuned solution index is not deterministic, but anything other than 0
// (default, invalid, or not available) is good.
solution_idx = gemm_op.to_value()["solution_idx"];
break;
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
EXPECT(0 != solution_idx.to<std::size_t>());
#else
EXPECT(0 == solution_idx.to<std::size_t>());
#endif
}
TEST_CASE(gemm_tune_invalid_sol_index)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {4, 2}};
migraphx::shape sb{migraphx::shape::float_type, {2, 3}};
migraphx::shape s_output{migraphx::shape::float_type, {4, 3}};
auto a = mm->add_parameter("a", sa);
auto b = mm->add_parameter("b", sb);
auto output = mm->add_parameter("out", s_output);
auto gemm_oper = migraphx::make_op("gpu::gemm", {{"solution_idx", 987654321}});
mm->add_instruction(gemm_oper, a, b, output);
migraphx::target gpu_t = migraphx::gpu::target{};
migraphx::compile_options options;
options.exhaustive_tune = true;
p.compile(gpu_t, options);
migraphx::value solution_idx(0);
for(auto ins : iterator_for(*p.get_main_module()))
{
if(ins->name() == "gpu::gemm")
{
auto gemm_op = migraphx::get_operation(ins);
auto gemmv = gemm_op.to_value();
// given invalid starting index, should return default 0
solution_idx = gemm_op.to_value()["solution_idx"];
break;
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
EXPECT(0 == solution_idx.to<std::size_t>());
#else
EXPECT(0 != solution_idx.to<std::size_t>());
#endif
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment