Unverified Commit 3685906c authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Merge branch 'develop' into refactor_auto_pad_conv

parents 2b936b13 e19f78ae
...@@ -104,6 +104,7 @@ MIGRAPHX_DEVICE_MATH(floor, ::floor) ...@@ -104,6 +104,7 @@ MIGRAPHX_DEVICE_MATH(floor, ::floor)
MIGRAPHX_DEVICE_MATH(isnan, ::isnan) MIGRAPHX_DEVICE_MATH(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH(log, ::log) MIGRAPHX_DEVICE_MATH(log, ::log)
MIGRAPHX_DEVICE_MATH(pow, ::pow) MIGRAPHX_DEVICE_MATH(pow, ::pow)
MIGRAPHX_DEVICE_MATH(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH(round, ::round) MIGRAPHX_DEVICE_MATH(round, ::round)
MIGRAPHX_DEVICE_MATH(rsqrt, ::rsqrt) MIGRAPHX_DEVICE_MATH(rsqrt, ::rsqrt)
MIGRAPHX_DEVICE_MATH(sin, ::sin) MIGRAPHX_DEVICE_MATH(sin, ::sin)
...@@ -111,6 +112,7 @@ MIGRAPHX_DEVICE_MATH(sinh, ::sinh) ...@@ -111,6 +112,7 @@ MIGRAPHX_DEVICE_MATH(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH(sqrt, ::sqrt) MIGRAPHX_DEVICE_MATH(sqrt, ::sqrt)
MIGRAPHX_DEVICE_MATH(tan, ::tan) MIGRAPHX_DEVICE_MATH(tan, ::tan)
MIGRAPHX_DEVICE_MATH(tanh, ::tanh) MIGRAPHX_DEVICE_MATH(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH(fmod, ::fmod)
// Float overloads // Float overloads
MIGRAPHX_DEVICE_MATH_FOR(float, acos, ::acosf) MIGRAPHX_DEVICE_MATH_FOR(float, acos, ::acosf)
...@@ -126,6 +128,7 @@ MIGRAPHX_DEVICE_MATH_FOR(float, sin, ::sinf) ...@@ -126,6 +128,7 @@ MIGRAPHX_DEVICE_MATH_FOR(float, sin, ::sinf)
MIGRAPHX_DEVICE_MATH_FOR(float, sinh, ::sinhf) MIGRAPHX_DEVICE_MATH_FOR(float, sinh, ::sinhf)
MIGRAPHX_DEVICE_MATH_FOR(float, tan, ::tanf) MIGRAPHX_DEVICE_MATH_FOR(float, tan, ::tanf)
MIGRAPHX_DEVICE_MATH_FOR(float, tanh, ::tanhf) MIGRAPHX_DEVICE_MATH_FOR(float, tanh, ::tanhf)
MIGRAPHX_DEVICE_MATH_FOR(float, fmod, ::fmodf)
// Builtin half functions // Builtin half functions
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs) MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs)
...@@ -148,11 +151,13 @@ MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf) ...@@ -148,11 +151,13 @@ MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor) MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor)
MIGRAPHX_DEVICE_MATH_HALF(isnan, ::isnan) MIGRAPHX_DEVICE_MATH_HALF(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow) MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
MIGRAPHX_DEVICE_MATH_HALF(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH_HALF(round, ::round) MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin) MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin)
MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh) MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan) MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh) MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
// Map math functions to hip half2 functions // Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats // The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
...@@ -226,11 +231,13 @@ MIGRAPHX_DEVICE_MATH_VEC(cosh) ...@@ -226,11 +231,13 @@ MIGRAPHX_DEVICE_MATH_VEC(cosh)
MIGRAPHX_DEVICE_MATH_VEC(erf) MIGRAPHX_DEVICE_MATH_VEC(erf)
MIGRAPHX_DEVICE_MATH_VEC(exp) MIGRAPHX_DEVICE_MATH_VEC(exp)
MIGRAPHX_DEVICE_MATH_VEC(floor) MIGRAPHX_DEVICE_MATH_VEC(floor)
MIGRAPHX_DEVICE_MATH_VEC(fmod)
MIGRAPHX_DEVICE_MATH_VEC(isnan) MIGRAPHX_DEVICE_MATH_VEC(isnan)
MIGRAPHX_DEVICE_MATH_VEC(log) MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(max) MIGRAPHX_DEVICE_MATH_VEC(max)
MIGRAPHX_DEVICE_MATH_VEC(min) MIGRAPHX_DEVICE_MATH_VEC(min)
MIGRAPHX_DEVICE_MATH_VEC(pow) MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(remainder)
MIGRAPHX_DEVICE_MATH_VEC(round) MIGRAPHX_DEVICE_MATH_VEC(round)
MIGRAPHX_DEVICE_MATH_VEC(rsqrt) MIGRAPHX_DEVICE_MATH_VEC(rsqrt)
MIGRAPHX_DEVICE_MATH_VEC(sin) MIGRAPHX_DEVICE_MATH_VEC(sin)
......
...@@ -81,77 +81,21 @@ struct miopen_apply ...@@ -81,77 +81,21 @@ struct miopen_apply
(void)i; (void)i;
} }
const std::unordered_set<std::string>& get_rocblas_fp32_archs()
{
static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
return supported_archs;
}
void init() void init()
{ {
assert(mod != nullptr); assert(mod != nullptr);
assert(pass != nullptr); assert(pass != nullptr);
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38 auto& ctx = get_context();
auto& ctx = get_context(); int8_x4_format = get_int8_x4_format(ctx);
const auto device_name = trim(split_string(get_device_name(), ':').front()); compute_fp32 = get_compute_fp32_flag();
if(contains(get_rocblas_fp32_archs(), device_name))
compute_fp32 = true;
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
#endif
offload_copy = (mod->name() == "main") ? pass->offload_copy : false; offload_copy = (mod->name() == "main") ? pass->offload_copy : false;
add_generic_op("acos");
add_generic_op("acosh");
add_generic_op("add");
add_generic_op("asin");
add_generic_op("asinh");
add_generic_op("atan");
add_generic_op("atanh");
add_generic_op("ceil");
add_generic_op("contiguous"); add_generic_op("contiguous");
add_generic_op("cos");
add_generic_op("cosh");
add_generic_op("div");
add_generic_op("equal");
add_generic_op("erf");
add_generic_op("exp");
add_generic_op("floor");
add_generic_op("greater");
add_generic_op("less");
add_generic_op("log");
add_generic_op("logical_and");
add_generic_op("logical_or");
add_generic_op("logical_xor");
add_generic_op("max");
add_generic_op("min");
add_generic_op("mul");
add_generic_op("not");
add_generic_op("pow");
add_generic_op("prelu");
add_generic_op("recip");
add_generic_op("relu");
add_generic_op("round");
add_generic_op("rsqrt");
add_generic_op("sigmoid");
add_generic_op("sign");
add_generic_op("sin");
add_generic_op("sinh");
add_generic_op("sqdiff");
add_generic_op("sqrt");
add_generic_op("sub");
add_generic_op("tan");
add_generic_op("tanh");
add_generic_op("where");
add_extend_op("abs");
add_extend_op("argmax"); add_extend_op("argmax");
add_extend_op("argmin"); add_extend_op("argmin");
add_extend_op("clip");
add_extend_op("convert");
add_extend_op("elu"); add_extend_op("elu");
add_extend_op("gather"); add_extend_op("gather");
add_extend_op("leaky_relu"); add_extend_op("leaky_relu");
......
...@@ -35,6 +35,12 @@ namespace { ...@@ -35,6 +35,12 @@ namespace {
template <class Derived, std::size_t N> template <class Derived, std::size_t N>
struct layernorm_base struct layernorm_base
{ {
float epsilon = 1e-12f;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.epsilon, "epsilon"));
}
shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
{ {
std::size_t nargs = 1; std::size_t nargs = 1;
...@@ -62,6 +68,7 @@ struct layernorm_base ...@@ -62,6 +68,7 @@ struct layernorm_base
struct layernorm : layernorm_base<layernorm, 0> struct layernorm : layernorm_base<layernorm, 0>
{ {
std::string name() const { return "gpu::prelayernorm"; } std::string name() const { return "gpu::prelayernorm"; }
}; };
MIGRAPHX_REGISTER_OP(layernorm); MIGRAPHX_REGISTER_OP(layernorm);
...@@ -80,8 +87,9 @@ struct find_layernorm ...@@ -80,8 +87,9 @@ struct find_layernorm
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
auto eps = r.instructions["eps"]->eval().at<float>();
m.replace_instruction(ins, layernorm{}, x_ins); m.replace_instruction(ins, layernorm{eps}, x_ins);
} }
}; };
...@@ -96,8 +104,9 @@ struct find_add_layernorm ...@@ -96,8 +104,9 @@ struct find_add_layernorm
{ {
auto ins = r.result; auto ins = r.result;
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto eps = r.instructions["eps"]->eval().at<float>();
m.replace_instruction(ins, add_layernorm{}, add_ins->inputs()); m.replace_instruction(ins, add_layernorm{eps}, add_ins->inputs());
} }
}; };
} // namespace } // namespace
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/gpu/quant_convolution.hpp> #include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/device/convert.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
......
...@@ -21,7 +21,13 @@ ...@@ -21,7 +21,13 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <unordered_set>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/rocblas.hpp> #include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -41,6 +47,33 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s) ...@@ -41,6 +47,33 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s)
return rb; return rb;
} }
const std::unordered_set<std::string>& get_rocblas_fp32_archs()
{
static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
return supported_archs;
}
bool get_compute_fp32_flag()
{
bool compute_fp32 = false;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
const auto device_name = trim(split_string(get_device_name(), ':').front());
if(contains(get_rocblas_fp32_archs(), device_name))
compute_fp32 = true;
#endif
return compute_fp32;
}
bool get_int8_x4_format(context& ctx)
{
bool int8_x4_format = true;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
#endif
return int8_x4_format;
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/softmax.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
return op.normalize_compute_shape({inputs.at(0)});
}
argument hip_softmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
auto n_dim = args.front().get_shape().lens().size();
auto tuned_axis = tune_axis(n_dim, op.axis, op.name());
device::softmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -347,7 +347,7 @@ void tf_parser::parse_node(const std::string& name) ...@@ -347,7 +347,7 @@ void tf_parser::parse_node(const std::string& name)
// input was from a node with multiple outputs // input was from a node with multiple outputs
if(contains(input_name, ':')) if(contains(input_name, ':'))
{ {
input_name = input_name.substr(0, input.find(':')); input_name.resize(input.find(':'));
} }
else else
{ {
......
...@@ -511,14 +511,7 @@ void print_value(std::ostream& os, const std::vector<value>& x) ...@@ -511,14 +511,7 @@ void print_value(std::ostream& os, const std::vector<value>& x)
os << "}"; os << "}";
} }
void print_value(std::ostream& os, const value::binary& x) void print_value(std::ostream& os, const value::binary& x) { os << x; }
{
// Convert binary to integers
std::vector<int> v(x.begin(), x.end());
os << "{";
os << to_string_range(v);
os << "}";
}
std::ostream& operator<<(std::ostream& os, const value& d) std::ostream& operator<<(std::ostream& os, const value& d)
{ {
......
...@@ -40,6 +40,10 @@ ...@@ -40,6 +40,10 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
#include "make_precompile_op.hpp"
// Treat some operators as compilable to enable lowering
MIGRAPHX_GPU_TEST_PRECOMPILE("add", "mul", "convert")
void run_lowering(migraphx::program& p, bool offload_copy = false) void run_lowering(migraphx::program& p, bool offload_copy = false)
{ {
...@@ -118,7 +122,7 @@ TEST_CASE(no_copy_dead_param) ...@@ -118,7 +122,7 @@ TEST_CASE(no_copy_dead_param)
auto xb = mm->add_instruction(migraphx::make_op("hip::allocate", {{"shape", to_value(s)}})); auto xb = mm->add_instruction(migraphx::make_op("hip::allocate", {{"shape", to_value(s)}}));
auto gx = mm->add_instruction(migraphx::make_op("hip::copy_to_gpu"), x, xb); auto gx = mm->add_instruction(migraphx::make_op("hip::copy_to_gpu"), x, xb);
auto ab = mm->add_instruction(migraphx::make_op("hip::allocate", {{"shape", to_value(s)}})); auto ab = mm->add_instruction(migraphx::make_op("hip::allocate", {{"shape", to_value(s)}}));
auto sum = mm->add_instruction(migraphx::make_op("gpu::add"), gx, gx, ab); auto sum = mm->add_instruction(make_precompile_op("add"), gx, gx, ab);
auto r = mm->add_instruction(migraphx::make_op("hip::copy_from_gpu"), sum); auto r = mm->add_instruction(migraphx::make_op("hip::copy_from_gpu"), sum);
mm->add_return({r}); mm->add_return({r});
......
...@@ -307,12 +307,14 @@ TEST_CASE(compile_math) ...@@ -307,12 +307,14 @@ TEST_CASE(compile_math)
"erf(x)", "erf(x)",
"exp(x)", "exp(x)",
"floor(x)", "floor(x)",
"fmod(x, x)",
"isnan(x)", "isnan(x)",
"log(x)", "log(x)",
"max(x, x)", "max(x, x)",
"min(x, x)", "min(x, x)",
"pow(x, 0)", "pow(x, 0)",
"pow(x, x)", "pow(x, x)",
"remainder(x,x)",
"round(x)", "round(x)",
"rsqrt(x)", "rsqrt(x)",
"sin(x)", "sin(x)",
......
...@@ -21,63 +21,46 @@ ...@@ -21,63 +21,46 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/gpu/device/gelu.hpp> #ifndef MIGRAPHX_GUARD_TEST_GPU_MAKE_PRECOMPILE_OP_HPP
#include <migraphx/gpu/device/nary.hpp> #define MIGRAPHX_GUARD_TEST_GPU_MAKE_PRECOMPILE_OP_HPP
#include <migraphx/gpu/device/types.hpp>
#include <cmath>
namespace migraphx { #include <migraphx/operation.hpp>
inline namespace MIGRAPHX_INLINE_NS { #include <migraphx/gpu/compiler.hpp>
namespace gpu { #include <migraphx/make_op.hpp>
namespace device {
// x * 0.5 * (1.0 + erf(x / sqrt(2.0))) // NOLINTNEXTLINE
template <class T> #define MIGRAPHX_GPU_TEST_PRECOMPILE(...) \
auto gelu_fn(T x) __device__ struct test_compiler : migraphx::gpu::compiler<test_compiler> \
{ { \
return x * 0.5 * (1 + ::erf(x * M_SQRT1_2)); std::vector<std::string> names() const { return {__VA_ARGS__}; } \
} \
template <class... Ts> \
migraphx::operation compile_op(Ts&&...) const \
{ \
MIGRAPHX_THROW("Not compilable"); \
} \
\
template <class... Ts> \
migraphx::gpu::compiler_replace compile(Ts&&...) const \
{ \
MIGRAPHX_THROW("Not compilable"); \
} \
};
// 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * pow(x, 3)))) inline migraphx::operation make_precompile_op(migraphx::rank<0>, const migraphx::operation& op)
template <class T>
auto gelu_fn_new(T x) __device__
{ {
return 0.5 * x * (1 + tanh(sqrt(M_2_PI) * (x + 0.044715 * x * x * x))); return migraphx::make_op("gpu::precompile_op", {{"op", migraphx::to_value(op)}});
} }
void gelu(hipStream_t stream, const argument& result, const argument& arg) inline migraphx::operation make_precompile_op(migraphx::rank<1>, const std::string& name)
{ {
nary(stream, result, arg)([](auto x) __device__ { return gelu_fn(to_hip_type(x)); }); return make_precompile_op(migraphx::rank<0>{}, migraphx::make_op(name));
} }
void gelu_new(hipStream_t stream, const argument& result, const argument& arg) template <class T>
{ auto make_precompile_op(const T& x)
nary(stream, result, arg)([](auto x) __device__ { return gelu_fn_new(to_hip_type(x)); });
}
void add_gelu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2)
{
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ {
auto sum = to_hip_type(x + y);
return gelu_fn(sum);
});
}
void add_gelu_new(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2)
{ {
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return make_precompile_op(migraphx::rank<1>{}, x);
auto sum = to_hip_type(x + y);
return gelu_fn(sum);
});
} }
} // namespace device #endif // MIGRAPHX_GUARD_TEST_GPU_MAKE_PRECOMPILE_OP_HPP
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -37,10 +37,6 @@ ...@@ -37,10 +37,6 @@
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <test.hpp> #include <test.hpp>
using migraphx::trim;
// m test_gpu_mlir && ./bin/test_gpu_mlir
struct mlir_gpu_target : migraphx::gpu::target struct mlir_gpu_target : migraphx::gpu::target
{ {
std::string name() const { return "mlir"; } std::string name() const { return "mlir"; }
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/adjust_allocation.hpp> #include <migraphx/adjust_allocation.hpp>
#include <migraphx/gpu/pack_int8_args.hpp> #include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/rocblas.hpp> #include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/replace_allocate.hpp> #include <migraphx/replace_allocate.hpp>
...@@ -38,10 +39,13 @@ ...@@ -38,10 +39,13 @@
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <test.hpp> #include <test.hpp>
#include "make_precompile_op.hpp"
void run_passes(migraphx::module& m) // Treat some operators as compilable to enable lowering
MIGRAPHX_GPU_TEST_PRECOMPILE("add", "mul", "convert")
void run_passes(migraphx::module& m, migraphx::gpu::context& ctx)
{ {
auto ctx = migraphx::gpu::context{};
migraphx::run_passes(m, migraphx::run_passes(m,
{migraphx::auto_contiguous{}, {migraphx::auto_contiguous{},
migraphx::gpu::lowering{&ctx, false}, migraphx::gpu::lowering{&ctx, false},
...@@ -52,18 +56,6 @@ void run_passes(migraphx::module& m) ...@@ -52,18 +56,6 @@ void run_passes(migraphx::module& m)
migraphx::dead_code_elimination{}}); migraphx::dead_code_elimination{}});
} }
bool get_int8_x4_format()
{
bool int8_x4_format = true;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
auto ctx = migraphx::gpu::context{};
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
#endif
return int8_x4_format;
}
TEST_CASE(quant_dot) TEST_CASE(quant_dot)
{ {
auto create_module = [] { auto create_module = [] {
...@@ -102,11 +94,13 @@ TEST_CASE(quant_dot) ...@@ -102,11 +94,13 @@ TEST_CASE(quant_dot)
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m2_shape)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m2_shape)}}));
packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), l2, alloc); packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), l2, alloc);
} }
auto gemm = auto gemm = m.add_instruction(
m.add_instruction(migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}), migraphx::make_op("gpu::quant_gemm",
l1, {{"int8_x4_format", int8_x4},
packa, {"compute_fp32", migraphx::gpu::get_compute_fp32_flag()}}),
gemm_alloc); l1,
packa,
gemm_alloc);
auto beta_broadcast = m.add_instruction( auto beta_broadcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", m3_shape.lens()}}), beta); migraphx::make_op("multibroadcast", {{"out_lens", m3_shape.lens()}}), beta);
...@@ -116,19 +110,19 @@ TEST_CASE(quant_dot) ...@@ -116,19 +110,19 @@ TEST_CASE(quant_dot)
m.add_instruction(migraphx::make_op("gpu::contiguous"), beta_broadcast, beta_alloc); m.add_instruction(migraphx::make_op("gpu::contiguous"), beta_broadcast, beta_alloc);
auto mul_alloc = m.add_instruction( auto mul_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}}));
auto m3_beta = auto m3_beta = m.add_instruction(make_precompile_op("mul"), l3, beta_contiguous, mul_alloc);
m.add_instruction(migraphx::make_op("gpu::mul"), l3, beta_contiguous, mul_alloc); auto gemm_add = m.add_instruction(make_precompile_op("add"), gemm, m3_beta, output);
auto gemm_add = m.add_instruction(migraphx::make_op("gpu::add"), gemm, m3_beta, output);
m.add_return({gemm_add}); m.add_return({gemm_add});
return m; return m;
}; };
auto m1 = create_module(); auto m1 = create_module();
run_passes(m1); auto ctx = migraphx::gpu::context{};
run_passes(m1, ctx);
bool flag = get_int8_x4_format(); bool int8_x4 = migraphx::gpu::get_int8_x4_format(ctx);
auto m2 = create_optimized_int8_x4(flag); auto m2 = create_optimized_int8_x4(int8_x4);
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
...@@ -187,21 +181,23 @@ TEST_CASE(quant_dot_trans) ...@@ -187,21 +181,23 @@ TEST_CASE(quant_dot_trans)
// back result to int8 // back result to int8
auto tl1_convert_alloc = m.add_instruction(migraphx::make_op( auto tl1_convert_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(alpha_contiguous->get_shape())}})); "hip::allocate", {{"shape", migraphx::to_value(alpha_contiguous->get_shape())}}));
auto tl1_convert = m.add_instruction( auto tl1_convert =
migraphx::make_op("gpu::convert", {{"target_type", alpha->get_shape().type()}}), m.add_instruction(make_precompile_op(migraphx::make_op(
conta, "convert", {{"target_type", alpha->get_shape().type()}})),
tl1_convert_alloc); conta,
auto mul_alloc = m.add_instruction(migraphx::make_op( tl1_convert_alloc);
auto mul_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(tl1_convert->get_shape())}})); "hip::allocate", {{"shape", migraphx::to_value(tl1_convert->get_shape())}}));
auto tl1_alpha_int32 = m.add_instruction( auto tl1_alpha_int32 =
migraphx::make_op("gpu::mul"), alpha_contiguous, tl1_convert, mul_alloc); m.add_instruction(make_precompile_op("mul"), alpha_contiguous, tl1_convert, mul_alloc);
// convert mul_res to int8 // convert mul_res to int8
auto tl1_alpha_int8_alloc = m.add_instruction(migraphx::make_op( auto tl1_alpha_int8_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(conta->get_shape())}})); "hip::allocate", {{"shape", migraphx::to_value(conta->get_shape())}}));
auto tl1_alpha_int8 = m.add_instruction( auto tl1_alpha_int8 =
migraphx::make_op("gpu::convert", {{"target_type", conta->get_shape().type()}}), m.add_instruction(make_precompile_op(migraphx::make_op(
tl1_alpha_int32, "convert", {{"target_type", conta->get_shape().type()}})),
tl1_alpha_int8_alloc); tl1_alpha_int32,
tl1_alpha_int8_alloc);
auto packb = contb; auto packb = contb;
if(int8_x4) if(int8_x4)
...@@ -211,21 +207,24 @@ TEST_CASE(quant_dot_trans) ...@@ -211,21 +207,24 @@ TEST_CASE(quant_dot_trans)
packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), contb, allocpb); packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), contb, allocpb);
} }
auto gemm = auto gemm = m.add_instruction(
m.add_instruction(migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}), migraphx::make_op("gpu::quant_gemm",
tl1_alpha_int8, {{"int8_x4_format", int8_x4},
packb, {"compute_fp32", migraphx::gpu::get_compute_fp32_flag()}}),
output); tl1_alpha_int8,
packb,
output);
m.add_return({gemm}); m.add_return({gemm});
return m; return m;
}; };
auto m1 = create_module(); auto m1 = create_module();
bool flag = get_int8_x4_format(); auto ctx = migraphx::gpu::context{};
auto m2 = create_optimized_int8_x4(flag); run_passes(m1, ctx);
run_passes(m1); bool int8_x4 = migraphx::gpu::get_int8_x4_format(ctx);
auto m2 = create_optimized_int8_x4(int8_x4);
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
...@@ -292,11 +291,13 @@ TEST_CASE(quant_dot_pad) ...@@ -292,11 +291,13 @@ TEST_CASE(quant_dot_pad)
packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), pl2, alloc); packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), pl2, alloc);
} }
auto gemm = auto gemm = m.add_instruction(
m.add_instruction(migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}), migraphx::make_op("gpu::quant_gemm",
pl1, {{"int8_x4_format", int8_x4},
packa, {"compute_fp32", migraphx::gpu::get_compute_fp32_flag()}}),
gemm_alloc); pl1,
packa,
gemm_alloc);
auto beta_broadcast = auto beta_broadcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), beta); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), beta);
...@@ -306,18 +307,18 @@ TEST_CASE(quant_dot_pad) ...@@ -306,18 +307,18 @@ TEST_CASE(quant_dot_pad)
m.add_instruction(migraphx::make_op("gpu::contiguous"), beta_broadcast, beta_alloc); m.add_instruction(migraphx::make_op("gpu::contiguous"), beta_broadcast, beta_alloc);
auto mul_alloc = m.add_instruction( auto mul_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}}));
auto m3_beta = auto m3_beta = m.add_instruction(make_precompile_op("mul"), l3, beta_contiguous, mul_alloc);
m.add_instruction(migraphx::make_op("gpu::mul"), l3, beta_contiguous, mul_alloc); auto gemm_add = m.add_instruction(make_precompile_op("add"), gemm, m3_beta, output);
auto gemm_add = m.add_instruction(migraphx::make_op("gpu::add"), gemm, m3_beta, output);
m.add_return({gemm_add}); m.add_return({gemm_add});
return m; return m;
}; };
auto m1 = create_module(); auto m1 = create_module();
bool flag = get_int8_x4_format(); auto ctx = migraphx::gpu::context{};
auto m2 = create_optimized_int8_x4(flag); run_passes(m1, ctx);
run_passes(m1); bool int8_x4 = migraphx::gpu::get_int8_x4_format(ctx);
auto m2 = create_optimized_int8_x4(int8_x4);
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
...@@ -396,14 +397,15 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -396,14 +397,15 @@ TEST_CASE(quant_dot_trans_pad)
// back result to int8 // back result to int8
auto tl1_convert_alloc = m.add_instruction(migraphx::make_op( auto tl1_convert_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(alpha_contiguous->get_shape())}})); "hip::allocate", {{"shape", migraphx::to_value(alpha_contiguous->get_shape())}}));
auto tl1_convert = m.add_instruction( auto tl1_convert =
migraphx::make_op("gpu::convert", {{"target_type", alpha->get_shape().type()}}), m.add_instruction(make_precompile_op(migraphx::make_op(
conta, "convert", {{"target_type", alpha->get_shape().type()}})),
tl1_convert_alloc); conta,
auto mul_alloc = m.add_instruction(migraphx::make_op( tl1_convert_alloc);
auto mul_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(tl1_convert->get_shape())}})); "hip::allocate", {{"shape", migraphx::to_value(tl1_convert->get_shape())}}));
auto tl1_alpha_int32 = m.add_instruction( auto tl1_alpha_int32 =
migraphx::make_op("gpu::mul"), alpha_contiguous, tl1_convert, mul_alloc); m.add_instruction(make_precompile_op("mul"), alpha_contiguous, tl1_convert, mul_alloc);
// convert mul_res to int8 // convert mul_res to int8
auto tl1_alpha_int8_alloc = m.add_instruction(migraphx::make_op( auto tl1_alpha_int8_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(conta->get_shape())}})); "hip::allocate", {{"shape", migraphx::to_value(conta->get_shape())}}));
...@@ -415,10 +417,11 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -415,10 +417,11 @@ TEST_CASE(quant_dot_trans_pad)
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps1)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps1)}}));
} }
auto tl1_alpha_int8 = m.add_instruction( auto tl1_alpha_int8 =
migraphx::make_op("gpu::convert", {{"target_type", conta->get_shape().type()}}), m.add_instruction(make_precompile_op(migraphx::make_op(
tl1_alpha_int32, "convert", {{"target_type", conta->get_shape().type()}})),
tl1_alpha_int8_alloc); tl1_alpha_int32,
tl1_alpha_int8_alloc);
auto pa = tl1_alpha_int8; auto pa = tl1_alpha_int8;
if(int8_x4) if(int8_x4)
...@@ -438,17 +441,23 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -438,17 +441,23 @@ TEST_CASE(quant_dot_trans_pad)
} }
auto gemm = m.add_instruction( auto gemm = m.add_instruction(
migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}), pa, packb, output); migraphx::make_op("gpu::quant_gemm",
{{"int8_x4_format", int8_x4},
{"compute_fp32", migraphx::gpu::get_compute_fp32_flag()}}),
pa,
packb,
output);
m.add_return({gemm}); m.add_return({gemm});
return m; return m;
}; };
auto m1 = create_module(); auto m1 = create_module();
bool flag = get_int8_x4_format(); auto ctx = migraphx::gpu::context{};
auto m2 = create_optimized_int8_x4(flag); run_passes(m1, ctx);
run_passes(m1); bool int8_x4 = migraphx::gpu::get_int8_x4_format(ctx);
auto m2 = create_optimized_int8_x4(int8_x4);
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
......
...@@ -724,7 +724,7 @@ TEST_CASE(test39) ...@@ -724,7 +724,7 @@ TEST_CASE(test39)
auto sub_modules = p.get_modules(); auto sub_modules = p.get_modules();
std::reverse(sub_modules.begin(), sub_modules.end()); std::reverse(sub_modules.begin(), sub_modules.end());
for(auto& smod : sub_modules) for(const auto& smod : sub_modules)
{ {
run_pass(*smod); run_pass(*smod);
} }
......
batch_norm_1d_test:
7
x
scale
bias
mean
variancey"BatchNormalizationbatch_norm_1d_testZ
x




Z
scale

Z
bias

Z
mean

Z
variance

b
y




B
\ No newline at end of file
batch_norm_2d_test:
7
x
scale
bias
mean
variancey"BatchNormalizationbatch_norm_2d_testZ
x




Z
scale

Z
bias

Z
mean

Z
variance

b
y




B
\ No newline at end of file
batch_norm_3d_test:
J
x
scale
bias
mean
variancey"BatchNormalization*
epsilon75batch_norm_3d_testZ
x






Z
scale


Z
bias


Z
mean


Z
variance


b
y






B
\ No newline at end of file
batch_norm_flat_test:
J
x
scale
bias
mean
variancey"BatchNormalization*
epsilon75batch_norm_flat_testZ
x


Z
scale

Z
bias

Z
mean

Z
variance

b
y


B
\ No newline at end of file
!batch_norm_invalid_bias_rank_test:
7
x
scale
bias
mean
variancey"BatchNormalization!batch_norm_invalid_bias_rank_testZ
x




Z
scale

Z
bias


Z
mean

Z
variance

b
y




B
\ No newline at end of file
batch_norm_invalid_rank_test:
7
x
scale
bias
mean
variancey"BatchNormalizationbatch_norm_invalid_rank_testZ
x


Z
scale

Z
bias

Z
mean

Z
variance

b
y


B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment