Unverified Commit b031c76c authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into ci_60

parents c8aa00bf aac4e950
...@@ -89,7 +89,7 @@ requests==2.28.2 ...@@ -89,7 +89,7 @@ requests==2.28.2
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core==0.30.0 rocm-docs-core==0.30.1
# via -r requirements.in # via -r requirements.in
smmap==5.0.0 smmap==5.0.0
# via gitdb # via gitdb
......
...@@ -23,10 +23,9 @@ ...@@ -23,10 +23,9 @@
##################################################################################### #####################################################################################
google/protobuf@v3.19.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off google/protobuf@v3.19.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off
nlohmann/json@v3.8.0 nlohmann/json@v3.8.0
live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f47538285178908ea5f2cab46bb6a8f55b1edb6e03224a82dbc1a3212
ROCmSoftwarePlatform/half@rocm-5.6.0 ROCmSoftwarePlatform/half@rocm-5.6.0
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/rocMLIR@a6880f1e6daec99876cd6a4820fbc69c57216401 -DBUILD_FAT_LIBROCKCOMPILER=On ROCmSoftwarePlatform/rocMLIR@ee3ca1eff247c287e855bc8588e9384b0e17abcf -DBUILD_FAT_LIBROCKCOMPILER=On
...@@ -32,8 +32,8 @@ ...@@ -32,8 +32,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class T, class F> template <class T, class U, class F>
void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta) void gemm(tensor_view<T> cmat, tensor_view<U> amat, tensor_view<U> bmat, F alpha, F beta)
{ {
std::size_t n_dims = cmat.get_shape().lens().size(); std::size_t n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2; std::size_t dim_0 = n_dims - 2;
...@@ -52,7 +52,8 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha ...@@ -52,7 +52,8 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha
double s = 0.0; double s = 0.0;
dfor(k)([&](auto kk) { dfor(k)([&](auto kk) {
a_idx[dim_1] = b_idx[dim_0] = kk; a_idx[dim_1] = b_idx[dim_0] = kk;
s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end()); s += static_cast<double>(amat(a_idx.begin(), a_idx.end())) *
static_cast<double>(bmat(b_idx.begin(), b_idx.end()));
}); });
cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta; cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
}); });
......
...@@ -44,9 +44,11 @@ struct quant_dot ...@@ -44,9 +44,11 @@ struct quant_dot
const shape& a = inputs.at(0); const shape& a = inputs.at(0);
const shape& b = inputs.at(1); const shape& b = inputs.at(1);
auto t = a.type(); auto t = a.type();
if(t != shape::int8_type) std::set<migraphx::shape::type_t> suppported_types = {shape::int8_type,
shape::fp8e4m3fnuz_type};
if(not contains(suppported_types, t))
{ {
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t"); MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t and fp8e4m3fnuz_type");
} }
if(not std::all_of( if(not std::all_of(
...@@ -73,6 +75,10 @@ struct quant_dot ...@@ -73,6 +75,10 @@ struct quant_dot
auto out_lens = a.lens(); auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1]; out_lens[dim_1] = b.lens()[dim_1];
if(t == shape::fp8e4m3fnuz_type)
{
return {shape::float_type, out_lens};
} // else int8 gemm
return {shape::int32_type, out_lens}; return {shape::int32_type, out_lens};
} }
}; };
......
...@@ -669,6 +669,15 @@ void module::finalize(std::vector<context>& contexts) ...@@ -669,6 +669,15 @@ void module::finalize(std::vector<context>& contexts)
smod->finalize(contexts); smod->finalize(contexts);
} }
} }
#ifndef BUILD_DEV
if(std::any_of(this->begin(), this->end(), [](const auto i) {
return i.get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
{
std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
"incorrect final outputs\n";
}
#endif
// Warn when an instruction is not normalized // Warn when an instruction is not normalized
auto ins = std::find_if(begin(), end(), [](auto& i) { return i.need_normalization(); }); auto ins = std::find_if(begin(), end(), [](auto& i) { return i.need_normalization(); });
......
...@@ -183,6 +183,11 @@ struct find_nested_convert ...@@ -183,6 +183,11 @@ struct find_nested_convert
auto x = ins->inputs().front(); auto x = ins->inputs().front();
auto input = x->inputs().front(); auto input = x->inputs().front();
while(input->name() == "convert")
{
input = input->inputs().front();
}
if(ins->get_shape() != input->get_shape()) if(ins->get_shape() != input->get_shape())
return; return;
......
...@@ -195,7 +195,7 @@ struct gemm_impl ...@@ -195,7 +195,7 @@ struct gemm_impl
ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc; ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc;
arg_type = get_type(input_shapes[0].type()); arg_type = get_type(input_shapes[0].type());
output_type = arg_type; output_type = get_type(input_shapes[2].type());
if(output_type == rocblas_datatype_i8_r) if(output_type == rocblas_datatype_i8_r)
{ {
output_type = rocblas_datatype_i32_r; output_type = rocblas_datatype_i32_r;
......
...@@ -112,7 +112,7 @@ struct rocblas_gemm ...@@ -112,7 +112,7 @@ struct rocblas_gemm
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
if(this->name() == "gpu::gemm") if(this->name() == "gpu::gemm" or output_shape.type() == migraphx::shape::float_type)
{ {
gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx); gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
} }
......
...@@ -110,6 +110,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -110,6 +110,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
if(not gpu::rocblas_fp8_available()) if(not gpu::rocblas_fp8_available())
{ {
unsupported_fp8_ops.insert("dot"); unsupported_fp8_ops.insert("dot");
unsupported_fp8_ops.insert("quant_dot");
} }
// MIOpen doesn't have support for fp8 pooling yet. // MIOpen doesn't have support for fp8 pooling yet.
unsupported_fp8_ops.insert("pooling"); unsupported_fp8_ops.insert("pooling");
......
...@@ -25,18 +25,13 @@ ...@@ -25,18 +25,13 @@
add_library(migraphx_ref add_library(migraphx_ref
target.cpp target.cpp
lowering.cpp lowering.cpp
gemm.cpp
) )
set_target_properties(migraphx_ref PROPERTIES EXPORT_NAME ref) set_target_properties(migraphx_ref PROPERTIES EXPORT_NAME ref)
rocm_set_soversion(migraphx_ref ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx_ref ${MIGRAPHX_SO_VERSION})
find_path(BLAZE_INCLUDE blaze/Blaze.h)
rocm_clang_tidy_check(migraphx_ref) rocm_clang_tidy_check(migraphx_ref)
target_link_libraries(migraphx_ref PRIVATE Threads::Threads) target_link_libraries(migraphx_ref PRIVATE Threads::Threads)
target_link_libraries(migraphx_ref PUBLIC migraphx) target_link_libraries(migraphx_ref PUBLIC migraphx)
target_include_directories(migraphx_ref SYSTEM PRIVATE ${BLAZE_INCLUDE})
target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS)
migraphx_generate_export_header(migraphx_ref) migraphx_generate_export_header(migraphx_ref)
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/ref/gemm.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/par_for.hpp>
#include <blaze/math/CustomMatrix.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace ref {
// Shorthand for the blaze matrix adaptor used by the fast GEMM path: a
// blaze::CustomMatrix over externally supplied storage, declared with the
// `unaligned` and `unpadded` policies (no alignment or row padding is assumed
// for the wrapped buffer).
template <class T>
using matrix = blaze::CustomMatrix<T, blaze::unaligned, blaze::unpadded>; // NOLINT
// Build a `matrix<T>` over the buffer of `x`, using only the two innermost
// axes of its shape (the rank itself is not checked here).
//
// For a shape reporting `transposed()`, the extents are passed in swapped
// order and the spacing comes from the innermost stride; otherwise the
// extents are passed as-is and the spacing comes from the second-innermost
// stride. The caller is expected to re-apply the transposition on the
// returned matrix when needed (see visit_mat).
template <class T>
static auto make_mat(tensor_view<T> x)
{
    const auto& s = x.get_shape();

    const std::size_t rank     = s.lens().size();
    const std::size_t row_axis = rank - 2;
    const std::size_t col_axis = rank - 1;

    // Pick which axis plays the "outer" (first extent + spacing) role.
    const bool flipped      = s.transposed();
    const std::size_t outer = flipped ? col_axis : row_axis;
    const std::size_t inner = flipped ? row_axis : col_axis;

    return matrix<T>{x.data(), s.lens()[outer], s.lens()[inner], s.strides()[outer]};
}
// Invoke `f` with a blaze matrix describing `x`.
// Non-transposed shapes receive the matrix from make_mat directly; transposed
// shapes receive `blaze::trans(...)` of it, so `f` must accept either type
// (e.g. a generic lambda).
template <class T, class F>
static void visit_mat(tensor_view<T> x, F f)
{
    auto m = make_mat(x);
    if(not x.get_shape().transposed())
    {
        f(m);
        return;
    }
    f(blaze::trans(m));
}
// Compile-time tag used to choose a GEMM implementation: it is
// std::true_type-like only for `float` (exact match, cv-qualifiers are not
// stripped) and std::false_type-like for every other element type.
template <class T>
struct is_fast_gemm_type : std::is_same<T, float>
{
};
// Fast GEMM path (selected with a std::true_type tag): evaluates
//     C = beta * C + alpha * A * B
// through blaze matrix expressions on the two innermost axes.
//
// A and B go through visit_mat so that transposed inputs are seen as
// transposed blaze expressions; C is wrapped with make_mat only.
// NOTE(review): C's own transposed() flag is therefore not re-applied here -
// presumably C always has a standard layout; confirm against callers.
template <class T, class F>
void migemm_impl(
    tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::true_type)
{
    visit_mat(amat, [&](const auto& lhs) {
        visit_mat(bmat, [&](const auto& rhs) {
            auto out = make_mat(cmat);
            // Scale the existing contents of C first.
            out = beta * out;
            // Nothing to accumulate when alpha is zero, so skip the product.
            if(alpha == 0.0)
                return;
            out = out + alpha * lhs * rhs;
        });
    });
}
// Generic GEMM fallback (selected with a std::false_type tag): evaluates
//     C = alpha * (A x B) + beta * C
// one output element at a time, for every element of C.
//
// The multi-index of each C element is reused for A and B with only the
// contraction axis overwritten, so any leading (batch) indices are taken
// from C unchanged.
//
// cmat  - output view; read (scaled by beta) and overwritten in place.
// amat  - left operand; its innermost extent is the contraction length k.
// bmat  - right operand; its second-innermost extent must equal k.
// alpha - scale applied to the accumulated product.
// beta  - scale applied to the previous value of C.
//
// Shape compatibility is only checked with assert (debug builds).
template <class T, class F>
void migemm_impl(
    tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::false_type)
{
    std::size_t n_dims = cmat.get_shape().lens().size();
    std::size_t dim_0  = n_dims - 2;
    std::size_t dim_1  = n_dims - 1;
    auto k             = amat.get_shape().lens()[dim_1];

    assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
    assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
    assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);

    auto cs = cmat.get_shape();
    par_for(cs.elements(), [&](auto i) {
        auto c_idx = cs.multi(i);
        auto a_idx = c_idx;
        auto b_idx = c_idx;
        double s   = 0.0;
        dfor(k)([&](auto kk) {
            a_idx[dim_1] = b_idx[dim_0] = kk;
            // Widen both operands to double BEFORE multiplying. Forming the
            // product in T first can overflow for integer element types and
            // rounds early for narrow floating-point types, which defeats the
            // purpose of the double accumulator.
            s += static_cast<double>(amat(a_idx.begin(), a_idx.end())) *
                 static_cast<double>(bmat(b_idx.begin(), b_idx.end()));
        });
        cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
    });
}
// Entry point that picks a GEMM implementation for the given views.
//
// The product of every extent of A except the two innermost ones is computed;
// when it is exactly 1 (i.e. a single matrix multiply) the implementation is
// chosen by is_fast_gemm_type<T>, otherwise the generic element-wise
// implementation is always used.
template <class T, class F>
void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
{
    const auto a_lens = amat.get_shape().lens();

    // Number of matrices stacked in the leading axes of A.
    const std::size_t batch_count = std::accumulate(
        a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());

    if(batch_count != 1)
    {
        migemm_impl(cmat, amat, bmat, alpha, beta, std::false_type{});
        return;
    }
    migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
}
// Type-erased front end: visits the three arguments together via visit_all
// and runs migemm_impl on the resulting typed views, forwarding the alpha and
// beta scale factors unchanged.
template <class F>
void migemm_tpl(
    const argument& c_arg, const argument& a_arg, const argument& b_arg, F alpha, F beta)
{
    auto run_gemm = [&](auto cmat, auto amat, auto bmat) {
        migemm_impl(cmat, amat, bmat, alpha, beta);
    };
    visit_all(c_arg, a_arg, b_arg)(run_gemm);
}
// GEMM with float scale factors: forwards to migemm_tpl, which updates the
// buffer behind c_arg as  C = beta * C + alpha * (A x B).
//   c_arg - output argument; its previous contents are scaled by beta.
//   a_arg - left operand.
//   b_arg - right operand.
//   alpha - scale applied to the product A x B.
//   beta  - scale applied to the previous contents of C.
void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
{
migemm_tpl(c_arg, a_arg, b_arg, alpha, beta);
}
// GEMM with int32_t scale factors: forwards to migemm_tpl, which updates the
// buffer behind c_arg as  C = beta * C + alpha * (A x B).
//   c_arg - output argument; its previous contents are scaled by beta.
//   a_arg - left operand.
//   b_arg - right operand.
//   alpha - scale applied to the product A x B.
//   beta  - scale applied to the previous contents of C.
void migemm(const argument& c_arg,
const argument& a_arg,
const argument& b_arg,
int32_t alpha,
int32_t beta)
{
migemm_tpl(c_arg, a_arg, b_arg, alpha, beta);
}
} // namespace ref
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_CPU_GEMM_HPP
#define MIGRAPHX_GUARD_RTGLIB_CPU_GEMM_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace ref {
// Reference GEMM entry points. Both overloads update the buffer behind c_arg
// as  C = beta * C + alpha * (A x B); they differ only in the type of the
// scale factors.
//   c_arg - output argument (written through despite the const reference).
//   a_arg - left operand.
//   b_arg - right operand.

// Overload taking float scale factors.
void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta);
// Overload taking int32_t scale factors.
void migemm(const argument& c_arg,
const argument& a_arg,
const argument& b_arg,
int32_t alpha,
int32_t beta);
} // namespace ref
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -44,7 +44,6 @@ ...@@ -44,7 +44,6 @@
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp> #include <migraphx/par_dfor.hpp>
#include <migraphx/clamp.hpp> #include <migraphx/clamp.hpp>
#include <migraphx/ref/gemm.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp> #include <migraphx/tune_axis.hpp>
...@@ -283,8 +282,8 @@ struct ref_gemm ...@@ -283,8 +282,8 @@ struct ref_gemm
argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{dyn_out.computed_shape}; argument result{dyn_out.computed_shape};
migemm(result, args[0], args[1], 1.0f, 0.0f); visit_all(result, args[0], args[1])(
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
return result; return result;
} }
}; };
...@@ -306,24 +305,14 @@ struct ref_quant_gemm ...@@ -306,24 +305,14 @@ struct ref_quant_gemm
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
// first, convert the args[0] and args[1] from int8_t to int32_t result.visit([&](auto cmat) {
argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}}; visit_all(args.at(0), args.at(1))(
argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}}; [&](auto amat, auto bmat) { return gemm(cmat, amat, bmat, 1.0f, 0.0f); });
arg_0.visit([&](auto output) {
args.at(0).visit(
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
});
arg_1.visit([&](auto output) {
args.at(1).visit(
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
}); });
migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0});
return result; return result;
} }
}; };
MIGRAPHX_REGISTER_OP(ref_gemm) MIGRAPHX_REGISTER_OP(ref_gemm)
template <class Op> template <class Op>
......
...@@ -24,19 +24,23 @@ ...@@ -24,19 +24,23 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1> template <typename DType, typename CType>
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 8, 2}}; auto dtype = migraphx::shape::get_type<DType>{};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}}; auto ctype = migraphx::shape::get_type<CType>{};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}}; migraphx::shape m1_shape{dtype, {3, 2, 8, 2}};
migraphx::shape m2_shape{dtype, {3, 2, 7, 8}};
migraphx::shape m3_shape{ctype, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = mm->add_instruction( auto tl1 = mm->add_instruction(
...@@ -45,7 +49,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1> ...@@ -45,7 +49,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
auto tl2 = mm->add_instruction( auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2); migraphx::add_apply_alpha_beta(
*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), CType{3}, CType{2});
return p; return p;
} }
}; };
template struct batch_quant_dot_1<int8_t, int32_t>;
template struct batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -25,23 +25,31 @@ ...@@ -25,23 +25,31 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2> template <typename DType, typename CType>
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 8}}; auto dtype = migraphx::shape::get_type<DType>{};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}}; auto ctype = migraphx::shape::get_type<CType>{};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
migraphx::shape m1_shape{dtype, {3, 2, 2, 8}};
migraphx::shape m2_shape{dtype, {3, 2, 8, 7}};
migraphx::shape m3_shape{ctype, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3); migraphx::add_apply_alpha_beta(
*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{3});
return p; return p;
} }
}; };
template struct batch_quant_dot_2<int8_t, int32_t>;
template struct batch_quant_dot_2<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -27,14 +27,15 @@ ...@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct batch_quant_dot_3 : verify_program<batch_quant_dot_3> template <migraphx::shape::type_t DType>
struct batch_quant_dot_3 : verify_program<batch_quant_dot_3<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 6}}; migraphx::shape m1_shape{DType, {3, 2, 2, 6}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 6, 7}}; migraphx::shape m2_shape{DType, {3, 2, 6, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
...@@ -42,3 +43,5 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3> ...@@ -42,3 +43,5 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
return p; return p;
} }
}; };
template struct batch_quant_dot_3<migraphx::shape::int8_type>;
template struct batch_quant_dot_3<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,14 +27,15 @@ ...@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct batch_quant_dot_4 : verify_program<batch_quant_dot_4> template <migraphx::shape::type_t DType>
struct batch_quant_dot_4 : verify_program<batch_quant_dot_4<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 4, 6, 3}}; migraphx::shape m1_shape{DType, {2, 4, 6, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 2, 6, 3}}; migraphx::shape m2_shape{DType, {7, 2, 6, 3}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
...@@ -46,3 +47,5 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4> ...@@ -46,3 +47,5 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
return p; return p;
} }
}; };
template struct batch_quant_dot_4<migraphx::shape::int8_type>;
template struct batch_quant_dot_4<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,14 +27,15 @@ ...@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct batch_quant_dot_5 : verify_program<batch_quant_dot_5> template <migraphx::shape::type_t DType>
struct batch_quant_dot_5 : verify_program<batch_quant_dot_5<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 7, 2}}; migraphx::shape m1_shape{DType, {3, 2, 7, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 5, 7}}; migraphx::shape m2_shape{DType, {3, 2, 5, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
...@@ -48,3 +49,5 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5> ...@@ -48,3 +49,5 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
return p; return p;
} }
}; };
template struct batch_quant_dot_5<migraphx::shape::int8_type>;
template struct batch_quant_dot_5<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -67,15 +67,27 @@ int main(int argc, const char* argv[]) ...@@ -67,15 +67,27 @@ int main(int argc, const char* argv[])
{ {
run_verify rv; run_verify rv;
rv.add_validation_for("gpu", &validate_gpu); rv.add_validation_for("gpu", &validate_gpu);
rv.disable_test_for("cpu",
{"test_if_lp", rv.disable_test_for("cpu", {
"test_if_param", "test_if_lp", "test_if_param", "test_if_literal", "test_select_module_add",
"test_if_literal", "test_select_module_reduce", "test_select_module_conv", "test_split_single_dyn_dim",
"test_select_module_add", "test_instancenorm_large_3d<migraphx::shape::float_type>",
"test_select_module_reduce", "test_instancenorm_large_3d<migraphx::shape::half_type>",
"test_select_module_conv", // these tests are disabled due issue of lossy downcast, see issue#2517
"test_split_single_dyn_dim", #if defined(__GNUC__) and !defined(__clang__)
"test_instancenorm_large_3d<migraphx::shape::float_type>", "batch_quant_dot_1<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, float>",
"test_instancenorm_large_3d<migraphx::shape::half_type>"}); "quant_dot_3args_4<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, float>",
"quant_dot_3args_5<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, float>",
#else
"batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>",
"quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>",
"quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>"
#endif
});
rv.disable_test_for("gpu",
{// These pass on MI300 but fail on others, same issue as CPU.
"batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>",
"quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>",
"quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>"});
rv.run(argc, argv); rv.run(argc, argv);
} }
...@@ -25,23 +25,31 @@ ...@@ -25,23 +25,31 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1> template <typename DType, typename CType>
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; auto ctype = migraphx::shape::get_type<CType>();
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m1_shape{dtype, {2, 8}};
migraphx::shape m2_shape{dtype, {8, 7}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1); migraphx::add_apply_alpha_beta(
*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{1});
return p; return p;
} }
}; };
template struct quant_dot_3args_1<int8_t, int32_t>;
template struct quant_dot_3args_1<migraphx::fp8::fp8e4m3fnuz, float>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment