"...resnet50_tensorflow.git" did not exist on "c430556b2a2f3dd8301c18e58b0c186749e24c49"
Unverified commit aac4e950 authored by Umang Yadav, committed by GitHub

FP8 QuantDot operation (#2506)

parent 9d2003a2
......@@ -23,7 +23,6 @@
#####################################################################################
google/protobuf@v3.19.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off
nlohmann/json@v3.8.0
live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f47538285178908ea5f2cab46bb6a8f55b1edb6e03224a82dbc1a3212
ROCmSoftwarePlatform/half@rocm-5.6.0
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
......
......@@ -32,8 +32,8 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T, class F>
void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
template <class T, class U, class F>
void gemm(tensor_view<T> cmat, tensor_view<U> amat, tensor_view<U> bmat, F alpha, F beta)
{
std::size_t n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
......@@ -52,7 +52,8 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha
double s = 0.0;
dfor(k)([&](auto kk) {
a_idx[dim_1] = b_idx[dim_0] = kk;
s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
s += static_cast<double>(amat(a_idx.begin(), a_idx.end())) *
static_cast<double>(bmat(b_idx.begin(), b_idx.end()));
});
cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
});
......
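The widening casts above matter because the reference gemm keeps its running sum in double regardless of the element type: with int8 a narrow accumulator wraps around, and with fp8e4m3fnuz it would saturate or lose precision long before k elements are summed. A minimal standalone sketch (plain C++, not MIGraphX code) of what the casts prevent:

#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<std::int8_t> a(64, 100), b(64, 100); // k = 64
    std::int8_t narrow = 0;
    double wide = 0.0;
    for(std::size_t k = 0; k < a.size(); ++k)
    {
        narrow += a[k] * b[k]; // product promotes to int, then wraps on assignment
        wide += static_cast<double>(a[k]) * static_cast<double>(b[k]);
    }
    std::printf("narrow: %d  wide: %.0f\n", narrow, wide); // wide prints 640000
}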
......@@ -44,9 +44,11 @@ struct quant_dot
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
if(t != shape::int8_type)
std::set<migraphx::shape::type_t> supported_types = {shape::int8_type,
shape::fp8e4m3fnuz_type};
if(not contains(supported_types, t))
{
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t and fp8e4m3fnuz_type");
}
if(not std::all_of(
......@@ -73,6 +75,10 @@ struct quant_dot
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
if(t == shape::fp8e4m3fnuz_type)
{
return {shape::float_type, out_lens};
} // else int8 gemm
return {shape::int32_type, out_lens};
}
};
......
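For reference, the shape rule this hunk implements, written out as a standalone sketch (toy enum, not the migraphx::shape API): each supported input type maps to a wider accumulator type for the output.

#include <stdexcept>

enum class dtype { int8, fp8e4m3fnuz, int32, float32 };

// int8 inputs keep the existing int32 accumulation; fp8e4m3fnuz inputs
// accumulate into float, matching the new branch in compute_shape above.
dtype quant_dot_output_type(dtype input)
{
    switch(input)
    {
    case dtype::int8: return dtype::int32;
    case dtype::fp8e4m3fnuz: return dtype::float32;
    default: throw std::runtime_error("QUANT_DOT: unsupported input type");
    }
}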
......@@ -183,6 +183,11 @@ struct find_nested_convert
auto x = ins->inputs().front();
auto input = x->inputs().front();
while(input->name() == "convert")
{
input = input->inputs().front();
}
if(ins->get_shape() != input->get_shape())
return;
......
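The matcher above walks down a chain of convert instructions and only rewrites when the innermost producer already has the outer shape, so the chain is a round-trip no-op. A toy sketch of that rewrite (hypothetical node type, not the MIGraphX matcher API):

#include <memory>
#include <string>

struct node
{
    std::string name;
    std::string shape; // stands in for the full shape/type
    std::shared_ptr<node> input;
};

// Follow converts down to the first non-convert producer; if its shape
// matches the outermost result, bypass the whole chain, otherwise leave
// the instruction alone.
std::shared_ptr<node> collapse_converts(const std::shared_ptr<node>& ins)
{
    auto input = ins->input;
    while(input != nullptr and input->name == "convert")
        input = input->input;
    if(input == nullptr or ins->shape != input->shape)
        return ins;
    return input;
}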
......@@ -195,7 +195,7 @@ struct gemm_impl
ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc;
arg_type = get_type(input_shapes[0].type());
output_type = arg_type;
output_type = get_type(input_shapes[2].type());
if(output_type == rocblas_datatype_i8_r)
{
output_type = rocblas_datatype_i32_r;
......
......@@ -112,7 +112,7 @@ struct rocblas_gemm
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{
if(this->name() == "gpu::gemm")
if(this->name() == "gpu::gemm" or output_shape.type() == migraphx::shape::float_type)
{
gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
}
......
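In other words, a float-typed output (which after the compute_shape change can only come from fp8 inputs) is routed through the regular float-accumulation gemm path, and only int8 quant_dots keep the int32 path. A compact sketch of the predicate (toy types, not the rocBLAS wrapper):

#include <string>

enum class out_type { int32, float32 };

bool use_regular_gemm(const std::string& op_name, out_type t)
{
    // gpu::gemm always takes this path; a quant gemm joins it when its
    // output accumulates in float (the fp8 case).
    return op_name == "gpu::gemm" or t == out_type::float32;
}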
......@@ -110,6 +110,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
if(not gpu::rocblas_fp8_available())
{
unsupported_fp8_ops.insert("dot");
unsupported_fp8_ops.insert("quant_dot");
}
// MIOpen doesn't have support for fp8 pooling yet.
unsupported_fp8_ops.insert("pooling");
......
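A minimal sketch of the gating above (toy helper; the actual pass plumbing lives elsewhere in target.cpp): fp8 dot and quant_dot are allowed only when the rocBLAS build supports fp8, while fp8 pooling stays unsupported unconditionally.

#include <set>
#include <string>

std::set<std::string> collect_unsupported_fp8_ops(bool rocblas_fp8)
{
    std::set<std::string> unsupported;
    if(not rocblas_fp8)
    {
        unsupported.insert("dot");
        unsupported.insert("quant_dot");
    }
    unsupported.insert("pooling"); // no MIOpen fp8 pooling yet
    return unsupported;
}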
......@@ -25,18 +25,13 @@
add_library(migraphx_ref
target.cpp
lowering.cpp
gemm.cpp
)
set_target_properties(migraphx_ref PROPERTIES EXPORT_NAME ref)
rocm_set_soversion(migraphx_ref ${MIGRAPHX_SO_VERSION})
find_path(BLAZE_INCLUDE blaze/Blaze.h)
rocm_clang_tidy_check(migraphx_ref)
target_link_libraries(migraphx_ref PRIVATE Threads::Threads)
target_link_libraries(migraphx_ref PUBLIC migraphx)
target_include_directories(migraphx_ref SYSTEM PRIVATE ${BLAZE_INCLUDE})
target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS)
migraphx_generate_export_header(migraphx_ref)
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/ref/gemm.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/par_for.hpp>
#include <blaze/math/CustomMatrix.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace ref {
template <class T>
using matrix = blaze::CustomMatrix<T, blaze::unaligned, blaze::unpadded>; // NOLINT
template <class T>
static auto make_mat(tensor_view<T> x)
{
const auto& s = x.get_shape();
// assert(s.lens().size() == 2);
std::size_t n_dims = s.lens().size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
if(s.transposed())
return matrix<T>{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]};
return matrix<T>{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]};
}
template <class T, class F>
static void visit_mat(tensor_view<T> x, F f)
{
auto mat = make_mat(x);
if(x.get_shape().transposed())
f(blaze::trans(mat));
else
f(mat);
}
template <class T>
struct is_fast_gemm_type : std::false_type
{
};
template <>
struct is_fast_gemm_type<float> : std::true_type
{
};
template <class T, class F>
void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::true_type)
{
visit_mat(amat, [&](const auto& a) {
visit_mat(bmat, [&](const auto& b) {
auto c = make_mat(cmat);
c = beta * c;
// This is a simple optimization to avoid
// computing A * B when alpha is 0.0
if(alpha != 0.0)
{
c = c + alpha * a * b;
}
});
});
}
template <class T, class F>
void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::false_type)
{
std::size_t n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
auto k = amat.get_shape().lens()[dim_1];
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
auto cs = cmat.get_shape();
par_for(cs.elements(), [&](auto i) {
auto c_idx = cs.multi(i);
auto a_idx = c_idx;
auto b_idx = c_idx;
double s = 0.0;
dfor(k)([&](auto kk) {
a_idx[dim_1] = b_idx[dim_0] = kk;
s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
});
cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
});
}
template <class T, class F>
void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
{
auto lens = amat.get_shape().lens();
bool batch_mul =
std::accumulate(
lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()) == 1;
if(batch_mul)
{
migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
}
else
{
migemm_impl(cmat, amat, bmat, alpha, beta, std::false_type{});
}
}
template <class F>
void migemm_tpl(
const argument& c_arg, const argument& a_arg, const argument& b_arg, F alpha, F beta)
{
visit_all(c_arg, a_arg, b_arg)(
[&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); });
}
void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
{
migemm_tpl(c_arg, a_arg, b_arg, alpha, beta);
}
void migemm(const argument& c_arg,
const argument& a_arg,
const argument& b_arg,
int32_t alpha,
int32_t beta)
{
migemm_tpl(c_arg, a_arg, b_arg, alpha, beta);
}
} // namespace ref
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_CPU_GEMM_HPP
#define MIGRAPHX_GUARD_RTGLIB_CPU_GEMM_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace ref {
void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta);
void migemm(const argument& c_arg,
const argument& a_arg,
const argument& b_arg,
int32_t alpha,
int32_t beta);
} // namespace ref
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -44,7 +44,6 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/clamp.hpp>
#include <migraphx/ref/gemm.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
......@@ -283,8 +282,8 @@ struct ref_gemm
argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{dyn_out.computed_shape};
migemm(result, args[0], args[1], 1.0f, 0.0f);
visit_all(result, args[0], args[1])(
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
return result;
}
};
......@@ -306,24 +305,14 @@ struct ref_quant_gemm
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// first, convert the args[0] and args[1] from int8_t to int32_t
argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}};
argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}};
arg_0.visit([&](auto output) {
args.at(0).visit(
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
});
arg_1.visit([&](auto output) {
args.at(1).visit(
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
result.visit([&](auto cmat) {
visit_all(args.at(0), args.at(1))(
[&](auto amat, auto bmat) { return gemm(cmat, amat, bmat, 1.0f, 0.0f); });
});
migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0});
return result;
}
};
MIGRAPHX_REGISTER_OP(ref_gemm)
template <class Op>
......
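The deleted block allocated two int32 temporaries and copy-widened the int8 inputs before calling a single-type gemm; with gemm now templated on separate input and output types, the narrow buffers are read in place. A before/after sketch of the part that disappears:

#include <cstdint>
#include <vector>

// Old approach (removed above): widen-copy each int8 input into an int32
// buffer, then run a gemm where A, B, and C all share one type.
std::vector<std::int32_t> widen(const std::vector<std::int8_t>& in)
{
    return std::vector<std::int32_t>(in.begin(), in.end());
}
// New approach: no temporaries; gemm<T, U> reads int8/fp8 elements directly
// and widens each product to double inside the accumulation loop.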
......@@ -24,19 +24,23 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
template <typename DType, typename CType>
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1<DType, CType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto dtype = migraphx::shape::get_type<DType>{};
auto ctype = migraphx::shape::get_type<CType>{};
migraphx::shape m1_shape{dtype, {3, 2, 8, 2}};
migraphx::shape m2_shape{dtype, {3, 2, 7, 8}};
migraphx::shape m3_shape{ctype, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = mm->add_instruction(
......@@ -45,7 +49,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2);
migraphx::add_apply_alpha_beta(
*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), CType{3}, CType{2});
return p;
}
};
template struct batch_quant_dot_1<int8_t, int32_t>;
template struct batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>;
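The two `template struct` lines at the bottom do the actual work: an explicit instantiation definition forces the static registration member of verify_program to be emitted, so each line adds one concrete test per (DType, CType) pair. A toy model of the pattern (illustrative names, not the real test harness):

#include <cstdio>
#include <string>
#include <typeinfo>
#include <vector>

std::vector<std::string>& registry()
{
    static std::vector<std::string> tests;
    return tests;
}

template <class T>
struct register_test
{
    static bool done;
};

// The initializer runs once per instantiated type, at program start-up.
template <class T>
bool register_test<T>::done = (registry().push_back(typeid(T).name()), true);

// Explicit instantiation definitions: each line emits the static member's
// definition, and with it one registration.
template struct register_test<signed char>;
template struct register_test<float>;

int main()
{
    for(const auto& name : registry())
        std::printf("registered: %s\n", name.c_str());
}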
......@@ -25,23 +25,31 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
template <typename DType, typename CType>
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2<DType, CType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto dtype = migraphx::shape::get_type<DType>{};
auto ctype = migraphx::shape::get_type<CType>{};
migraphx::shape m1_shape{dtype, {3, 2, 2, 8}};
migraphx::shape m2_shape{dtype, {3, 2, 8, 7}};
migraphx::shape m3_shape{ctype, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3);
migraphx::add_apply_alpha_beta(
*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{3});
return p;
}
};
template struct batch_quant_dot_2<int8_t, int32_t>;
template struct batch_quant_dot_2<migraphx::fp8::fp8e4m3fnuz, float>;
......@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
template <migraphx::shape::type_t DType>
struct batch_quant_dot_3 : verify_program<batch_quant_dot_3<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 6}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 6, 7}};
migraphx::shape m1_shape{DType, {3, 2, 2, 6}};
migraphx::shape m2_shape{DType, {3, 2, 6, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
......@@ -42,3 +43,5 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
return p;
}
};
template struct batch_quant_dot_3<migraphx::shape::int8_type>;
template struct batch_quant_dot_3<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
template <migraphx::shape::type_t DType>
struct batch_quant_dot_4 : verify_program<batch_quant_dot_4<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 4, 6, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 2, 6, 3}};
migraphx::shape m1_shape{DType, {2, 4, 6, 3}};
migraphx::shape m2_shape{DType, {7, 2, 6, 3}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
......@@ -46,3 +47,5 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
return p;
}
};
template struct batch_quant_dot_4<migraphx::shape::int8_type>;
template struct batch_quant_dot_4<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
template <migraphx::shape::type_t DType>
struct batch_quant_dot_5 : verify_program<batch_quant_dot_5<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 7, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 5, 7}};
migraphx::shape m1_shape{DType, {3, 2, 7, 2}};
migraphx::shape m2_shape{DType, {3, 2, 5, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
......@@ -48,3 +49,5 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
return p;
}
};
template struct batch_quant_dot_5<migraphx::shape::int8_type>;
template struct batch_quant_dot_5<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -67,15 +67,27 @@ int main(int argc, const char* argv[])
{
run_verify rv;
rv.add_validation_for("gpu", &validate_gpu);
rv.disable_test_for("cpu",
{"test_if_lp",
"test_if_param",
"test_if_literal",
"test_select_module_add",
"test_select_module_reduce",
"test_select_module_conv",
"test_split_single_dyn_dim",
"test_instancenorm_large_3d<migraphx::shape::float_type>",
"test_instancenorm_large_3d<migraphx::shape::half_type>"});
rv.disable_test_for("cpu", {
"test_if_lp", "test_if_param", "test_if_literal", "test_select_module_add",
"test_select_module_reduce", "test_select_module_conv", "test_split_single_dyn_dim",
"test_instancenorm_large_3d<migraphx::shape::float_type>",
"test_instancenorm_large_3d<migraphx::shape::half_type>",
// these tests are disabled due to a lossy-downcast issue, see issue #2517
#if defined(__GNUC__) and !defined(__clang__)
"batch_quant_dot_1<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, float>",
"quant_dot_3args_4<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, float>",
"quant_dot_3args_5<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, float>",
#else
"batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>",
"quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>",
"quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>"
#endif
});
rv.disable_test_for("gpu",
{// These pass on MI300 but fail on other GPUs, same issue as on the CPU.
"batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>",
"quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>",
"quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>"});
rv.run(argc, argv);
}
......@@ -25,23 +25,31 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1>
template <typename DType, typename CType>
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1<DType, CType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto ctype = migraphx::shape::get_type<CType>();
auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m1_shape{dtype, {2, 8}};
migraphx::shape m2_shape{dtype, {8, 7}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1);
migraphx::add_apply_alpha_beta(
*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{1});
return p;
}
};
template struct quant_dot_3args_1<int8_t, int32_t>;
template struct quant_dot_3args_1<migraphx::fp8::fp8e4m3fnuz, float>;
......@@ -28,22 +28,29 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct quant_dot_3args_2 : verify_program<quant_dot_3args_2>
template <typename DType, typename CType>
struct quant_dot_3args_2 : verify_program<quant_dot_3args_2<DType, CType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto ctype = migraphx::shape::get_type<CType>();
auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m1_shape{dtype, {8, 2}};
migraphx::shape m2_shape{dtype, {8, 7}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3);
migraphx::add_apply_alpha_beta(
*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{3});
return p;
}
};
template struct quant_dot_3args_2<int8_t, int32_t>;
template struct quant_dot_3args_2<migraphx::fp8::fp8e4m3fnuz, float>;
......@@ -28,22 +28,28 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct quant_dot_3args_3 : verify_program<quant_dot_3args_3>
template <typename DType, typename CType>
struct quant_dot_3args_3 : verify_program<quant_dot_3args_3<DType, CType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto ctype = migraphx::shape::get_type<CType>();
auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m1_shape{dtype, {2, 8}};
migraphx::shape m2_shape{dtype, {7, 8}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3);
migraphx::add_apply_alpha_beta(
*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), CType{2}, CType{3});
return p;
}
};
template struct quant_dot_3args_3<int8_t, int32_t>;
template struct quant_dot_3args_3<migraphx::fp8::fp8e4m3fnuz, float>;