Commit 050184cb authored by Umang Yadav's avatar Umang Yadav
Browse files

revert some changes

parent 3f213325
...@@ -44,10 +44,9 @@ struct quant_dot ...@@ -44,10 +44,9 @@ struct quant_dot
const shape& a = inputs.at(0); const shape& a = inputs.at(0);
const shape& b = inputs.at(1); const shape& b = inputs.at(1);
auto t = a.type(); auto t = a.type();
std::set<migraphx::shape::type_t> suppported_types = {shape::int8_type, shape::fp8e4m3fnuz_type}; if(t != shape::int8_type)
if(not contains(suppported_types, t))
{ {
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t and fp8e4m3fnuz_type"); MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
} }
if(not std::all_of( if(not std::all_of(
...@@ -74,10 +73,6 @@ struct quant_dot ...@@ -74,10 +73,6 @@ struct quant_dot
auto out_lens = a.lens(); auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1]; out_lens[dim_1] = b.lens()[dim_1];
if(t == shape::fp8e4m3fnuz_type)
{
return {shape::float_type, out_lens};
} // else int8 gemm
return {shape::int32_type, out_lens}; return {shape::int32_type, out_lens};
} }
}; };
......
...@@ -183,11 +183,6 @@ struct find_nested_convert ...@@ -183,11 +183,6 @@ struct find_nested_convert
auto x = ins->inputs().front(); auto x = ins->inputs().front();
auto input = x->inputs().front(); auto input = x->inputs().front();
while(input->name() == "convert")
{
input = input->inputs().front();
}
if(ins->get_shape() != input->get_shape()) if(ins->get_shape() != input->get_shape())
return; return;
......
...@@ -69,8 +69,7 @@ struct ck_gemm ...@@ -69,8 +69,7 @@ struct ck_gemm
static bool is_ck_supported_type(shape::type_t t) static bool is_ck_supported_type(shape::type_t t)
{ {
return contains( return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
{shape::half_type, shape::int8_type, shape::int32_type, shape::fp8e4m3fnuz_type}, t);
} }
}; };
MIGRAPHX_REGISTER_OP(ck_gemm); MIGRAPHX_REGISTER_OP(ck_gemm);
......
...@@ -180,9 +180,12 @@ struct gemm_impl ...@@ -180,9 +180,12 @@ struct gemm_impl
ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc; ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc;
arg_type = get_type(input_shapes[0].type()); arg_type = get_type(input_shapes[0].type());
output_type = get_type(input_shapes[2].type()); output_type = arg_type;
compute_type = if(output_type == rocblas_datatype_i8_r)
output_type; // not valid for ex3 BETA APIs. it has different type and set differently. {
output_type = rocblas_datatype_i32_r;
}
compute_type = output_type;
if(compute_fp32) if(compute_fp32)
{ {
if(arg_type == rocblas_datatype_f16_r) if(arg_type == rocblas_datatype_f16_r)
......
...@@ -112,7 +112,7 @@ struct rocblas_gemm ...@@ -112,7 +112,7 @@ struct rocblas_gemm
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
if(this->name() == "gpu::gemm" or output_shape.type() == migraphx::shape::float_type) if(this->name() == "gpu::gemm")
{ {
gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx); gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
} }
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <migraphx/ref/lowering.hpp> #include <migraphx/ref/lowering.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
...@@ -308,46 +307,19 @@ struct ref_quant_gemm ...@@ -308,46 +307,19 @@ struct ref_quant_gemm
{ {
argument result{output_shape}; argument result{output_shape};
// first, convert the args[0] and args[1] from int8_t to int32_t // first, convert the args[0] and args[1] from int8_t to int32_t
argument arg_0{{output_shape.type(), {args.at(0).get_shape().lens()}}}; argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}};
argument arg_1{{output_shape.type(), {args.at(1).get_shape().lens()}}}; argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}};
if(output_shape.type() == migraphx::shape::float_type)
{
arg_0.visit([&](auto output) { arg_0.visit([&](auto output) {
args.at(0).visit([&](auto input) { args.at(0).visit(
std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) { [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
return static_cast<float>(x);
});
});
}); });
arg_1.visit([&](auto output) { arg_1.visit([&](auto output) {
args.at(1).visit([&](auto input) { args.at(1).visit(
std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) { [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
return static_cast<float>(x);
});
});
});
migemm(result, arg_0, arg_1, 1.0f, 0.0f);
}
else if(output_shape.type() == migraphx::shape::int32_type)
{
arg_0.visit([&](auto output) {
args.at(0).visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) {
return static_cast<int32_t>(x);
});
});
}); });
arg_1.visit([&](auto output) {
args.at(1).visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) {
return static_cast<int32_t>(x);
});
});
});
migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0}); migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0});
}
return result; return result;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment