Commit 050184cb authored by Umang Yadav's avatar Umang Yadav
Browse files

revert some changes

parent 3f213325
......@@ -44,10 +44,9 @@ struct quant_dot
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
std::set<migraphx::shape::type_t> suppported_types = {shape::int8_type, shape::fp8e4m3fnuz_type};
if(not contains(suppported_types, t))
if(t != shape::int8_type)
{
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t and fp8e4m3fnuz_type");
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
}
if(not std::all_of(
......@@ -74,10 +73,6 @@ struct quant_dot
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
if(t == shape::fp8e4m3fnuz_type)
{
return {shape::float_type, out_lens};
} // else int8 gemm
return {shape::int32_type, out_lens};
}
};
......
......@@ -183,11 +183,6 @@ struct find_nested_convert
auto x = ins->inputs().front();
auto input = x->inputs().front();
while(input->name() == "convert")
{
input = input->inputs().front();
}
if(ins->get_shape() != input->get_shape())
return;
......
......@@ -69,8 +69,7 @@ struct ck_gemm
static bool is_ck_supported_type(shape::type_t t)
{
return contains(
{shape::half_type, shape::int8_type, shape::int32_type, shape::fp8e4m3fnuz_type}, t);
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
}
};
MIGRAPHX_REGISTER_OP(ck_gemm);
......
......@@ -180,9 +180,12 @@ struct gemm_impl
ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc;
arg_type = get_type(input_shapes[0].type());
output_type = get_type(input_shapes[2].type());
compute_type =
output_type; // not valid for ex3 BETA APIs. it has different type and set differently.
output_type = arg_type;
if(output_type == rocblas_datatype_i8_r)
{
output_type = rocblas_datatype_i32_r;
}
compute_type = output_type;
if(compute_fp32)
{
if(arg_type == rocblas_datatype_f16_r)
......
......@@ -112,7 +112,7 @@ struct rocblas_gemm
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{
if(this->name() == "gpu::gemm" or output_shape.type() == migraphx::shape::float_type)
if(this->name() == "gpu::gemm")
{
gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
}
......
......@@ -24,7 +24,6 @@
#include <migraphx/ref/lowering.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/convolution.hpp>
......@@ -308,46 +307,19 @@ struct ref_quant_gemm
{
argument result{output_shape};
// first, convert the args[0] and args[1] from int8_t to int32_t
argument arg_0{{output_shape.type(), {args.at(0).get_shape().lens()}}};
argument arg_1{{output_shape.type(), {args.at(1).get_shape().lens()}}};
if(output_shape.type() == migraphx::shape::float_type)
{
arg_0.visit([&](auto output) {
args.at(0).visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) {
return static_cast<float>(x);
});
});
});
argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}};
argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}};
arg_0.visit([&](auto output) {
args.at(0).visit(
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
});
arg_1.visit([&](auto output) {
args.at(1).visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) {
return static_cast<float>(x);
});
});
});
migemm(result, arg_0, arg_1, 1.0f, 0.0f);
}
else if(output_shape.type() == migraphx::shape::int32_type)
{
arg_0.visit([&](auto output) {
args.at(0).visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) {
return static_cast<int32_t>(x);
});
});
});
arg_1.visit([&](auto output) {
args.at(1).visit(
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
});
arg_1.visit([&](auto output) {
args.at(1).visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) {
return static_cast<int32_t>(x);
});
});
});
migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0});
}
migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0});
return result;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment