Commit ab768083 authored by Shucai Xiao

First implementation of calling the GPU int8 gemm correctly.

parent a9469403
@@ -24,7 +24,7 @@ struct quant_dot
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
-        return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
+        return pack(f(as_number(self.alpha), "alpha"), f(as_number(self.beta), "beta"));
     }
     std::string name() const { return "quant_dot"; }
@@ -60,11 +60,8 @@ struct quant_dot
                            to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
         }
-        // all dims need to be multiple of 4
-        auto m = a.lens()[dim_0];
-        auto n = b.lens()[dim_1];
-        auto k = a.lens()[dim_1];
-        if((m % 4) != 0 or (n % 4) != 0 or (k % 4) != 0)
+        // k must be a multiple of 4
+        if((a.lens()[dim_1] % 4) != 0)
         {
             MIGRAPHX_THROW("QUANT_DOT: size of A {" + to_string_range(a.lens()) + "} and B {" +
                            to_string_range(b.lens()) + "} must be multiple of 4 for int8 type");
......
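For context on the relaxed check: the int8 kernel only needs the inner dimension k packed in groups of four, so m and n no longer have to be multiples of 4. A minimal sketch of what now passes and what still throws (the shape values are mine, not from the commit):

```cpp
#include <migraphx/shape.hpp>

int main()
{
    // Accepted after this change: only k (= 8) must be a multiple of 4.
    migraphx::shape a{migraphx::shape::int8_type, {3, 8}}; // m = 3, k = 8
    migraphx::shape b{migraphx::shape::int8_type, {8, 5}}; // k = 8, n = 5
    // The old check would have thrown here, since m = 3 and n = 5 are not
    // multiples of 4; only a non-multiple k (say, a {3, 6} x {6, 5} gemm)
    // still triggers MIGRAPHX_THROW.
    (void)a;
    (void)b;
}
```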
@@ -18,7 +18,7 @@ T as_number(T x)
     return x;
 }
 inline int32_t as_number(int8_t x) { return static_cast<int32_t>(x); }
-inline uint32_t as_number(uint8_t x) { return static_cast<uint8_t>(x); }
+inline uint32_t as_number(uint8_t x) { return static_cast<uint32_t>(x); }
 template <class T>
 struct tensor_view
......
@@ -69,6 +69,7 @@ add_library(migraphx_gpu
     lrn.cpp
     schedule_model.cpp
     adjust_allocation.cpp
+    pack_int8_args.cpp
 )
 set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
 rocm_clang_tidy_check(migraphx_gpu)
......
@@ -15,7 +15,7 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg)
 {
     auto output_shape = result.get_shape();
     auto dim_0 = output_shape.lens().size() - 2;
-    std::size_t ldb = output_shape.strides()[dim_0];
+    std::size_t lda = output_shape.strides()[dim_0];
     visit_all(result, arg)([&](auto output, auto input) {
         std::size_t nelements = output_shape.elements();
         auto* out_ptr = device_cast(output.data());
@@ -25,9 +25,9 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg)
         gs_launch(stream, nelements)([=](auto ii) {
             const size_t nb = 4;
             auto idx = desc.multi(ii);
-            std::size_t i_m = idx[0];
-            std::size_t i_k = idx[1];
-            out_ptr[i_k % nb + (i_m + (i_k / nb) * ldb) * nb] = in_ptr[i_m + i_k * ldb];
+            std::size_t i_m = idx[1];
+            std::size_t i_k = idx[0];
+            out_ptr[i_k % nb + (i_m + (i_k / nb) * lda) * nb] = in_ptr[i_m + i_k * lda];
         });
     });
 });
@@ -37,7 +37,7 @@ void pack_b(hipStream_t stream, const argument& result, const argument& arg)
 {
     auto output_shape = result.get_shape();
     auto dim_1 = output_shape.lens().size() - 1;
-    std::size_t lda = output_shape.strides()[dim_1];
+    std::size_t ldb = output_shape.strides()[dim_1];
     visit_all(result, arg)([&](auto output, auto input) {
         std::size_t nelements = output_shape.elements();
         auto* out_ptr = device_cast(output.data());
@@ -47,9 +47,9 @@ void pack_b(hipStream_t stream, const argument& result, const argument& arg)
         gs_launch(stream, nelements)([=](auto ii) {
             const size_t nb = 4;
             auto idx = desc.multi(ii);
-            std::size_t i_n = idx[0];
-            std::size_t i_k = idx[1];
-            out_ptr[i_k % nb + (i_n + (i_k / nb) * lda) * nb] = in_ptr[i_n + i_k * lda];
+            std::size_t i_n = idx[1];
+            std::size_t i_k = idx[0];
+            out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb] = in_ptr[i_n + i_k * ldb];
         });
     });
 });
......
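Both kernels apply the same index transform; only the operand and leading dimension differ. A host-side reference of that transform may make the layout easier to see: consecutive groups of four k-elements of a column-major matrix become contiguous ("int8x4"), which is why k must be a multiple of 4. This is a sketch I wrote for illustration (`pack_int8x4` is not part of the commit), assuming lda >= m and k % 4 == 0:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference for the packing the GPU kernels perform on a column-major
// m x k matrix with leading dimension lda.
std::vector<int8_t>
pack_int8x4(const std::vector<int8_t>& in, std::size_t m, std::size_t k, std::size_t lda)
{
    const std::size_t nb = 4;
    std::vector<int8_t> out(in.size());
    for(std::size_t i_k = 0; i_k < k; ++i_k)
        for(std::size_t i_m = 0; i_m < m; ++i_m)
            // Same mapping as the device code: element (i_m, i_k) lands in
            // the (i_k % 4)-th slot of the 4-wide group for row block i_k / 4.
            out[i_k % nb + (i_m + (i_k / nb) * lda) * nb] = in[i_m + i_k * lda];
    return out;
}
```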
#ifndef MIGRAPHX_GUARD_RTGLIB_PACK_INT8_ARGS_HPP
#define MIGRAPHX_GUARD_RTGLIB_PACK_INT8_ARGS_HPP

#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

/**
 * Compile pass that inserts packing buffers for the arguments of int8
 * quant_gemm instructions.
 */
struct pack_int8_args
{
    std::string name() const { return "pack_int8_args"; }
    void apply(program& p) const;
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
@@ -97,6 +97,7 @@ struct miopen_apply
         add_generic_op<hip_min>("min");
         add_extend_op<miopen_gemm, op::dot>("dot");
+        add_extend_op<miopen_quant_gemm, op::quant_dot>("quant_dot");
         add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
         add_extend_op<hip_concat, op::concat>("concat");
         add_extend_op<miopen_softmax, op::softmax>("softmax");
@@ -109,7 +110,6 @@ struct miopen_apply
         add_quant_convolution_op();
         add_pooling_op();
         add_batch_norm_inference_op();
-        add_quant_gemm_op();
     }

     void apply()
@@ -263,35 +263,6 @@ struct miopen_apply
                                output);
         });
     }
-
-    void add_quant_gemm_op()
-    {
-        apply_map.emplace("quant_gemm", [=](instruction_ref ins) {
-            auto&& op   = any_cast<op::quant_dot>(ins->get_operator());
-            auto output = insert_allocation(ins, ins->get_shape());
-            std::vector<instruction_ref> refs = ins->inputs();
-            refs.push_back(output);
-            // Need another two buffers for packed data buffer
-            auto shape_a = refs.at(0)->get_shape();
-            if(shape_a.transposed())
-            {
-                auto pack_a = insert_allocation(ins, shape_a);
-                refs.push_back(pack_a);
-                std::swap(refs.back(), refs.at(0));
-            }
-            auto shape_b = refs.at(1)->get_shape();
-            if(!shape_b.transposed())
-            {
-                auto pack_b = insert_allocation(ins, shape_b);
-                refs.push_back(pack_b);
-                std::swap(refs.back(), refs.at(1));
-            }
-            return prog->replace_instruction(ins, miopen_quant_gemm{op}, refs);
-        });
-    }
 };

 void lowering::apply(program& p) const { miopen_apply{&p, ctx}.apply(); }
......
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/gpu/hip.hpp>
#include <algorithm>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

void pack_int8_args::apply(program& p) const
{
    for(auto ins : iterator_for(p))
    {
        if(ins->name() != "gpu::quant_gemm")
            continue;

        auto inputs  = ins->inputs();
        auto shape_a = inputs.at(0)->get_shape();
        if(shape_a.type() != shape::int8_type)
            continue;

        // A transposed A needs a scratch buffer for its packed copy; the
        // original argument is swapped to the back of the input list.
        if(shape_a.transposed())
        {
            auto pack_a = p.insert_instruction(ins, hip_allocate{shape_a});
            inputs.push_back(pack_a);
            swap(inputs.at(0), inputs.back());
        }

        // B is packed in the opposite case, i.e. when it is not transposed.
        auto shape_b = inputs.at(1)->get_shape();
        if(!shape_b.transposed())
        {
            auto pack_b = p.insert_instruction(ins, hip_allocate{shape_b});
            inputs.push_back(pack_b);
            swap(inputs.at(1), inputs.back());
        }

        instruction::replace(ins, ins->get_operator(), ins->get_shape(), inputs);
    }
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
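This pass reproduces the buffer logic that was removed from lowering.cpp above, but as a standalone compile pass. A toy model of the argument rewrite it performs, with `std::string` standing in for `instruction_ref` (illustrative only, not the commit's code):

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

int main()
{
    // A quant_gemm lowered with a transposed A has inputs {A, B, out}.
    std::vector<std::string> inputs{"A", "B", "out"};
    inputs.push_back("pack_a");             // scratch buffer for the packed copy of A
    std::swap(inputs.at(0), inputs.back()); // original A moves to the back

    // The gemm now sees the packed operand first and still reaches the
    // unpacked data through the trailing argument.
    assert((inputs == std::vector<std::string>{"pack_a", "B", "out", "A"}));
}
```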
@@ -53,7 +53,17 @@ rb_type<T>* to_rocblas_type(T* x)
 shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
 {
-    std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
+    std::vector<shape> input_shapes(inputs);
+    if(!inputs.at(1).transposed())
+    {
+        input_shapes.pop_back();
+    }
+    if(inputs.at(0).transposed())
+    {
+        input_shapes.pop_back();
+    }
+    input_shapes.pop_back();
     check_shapes{input_shapes}.not_broadcasted();
     return op.compute_shape(input_shapes);
 }
......
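The pops undo, from the back, exactly what pack_int8_args appended. A toy walk-through under the argument order the pass produces, with strings standing in for shapes (my example, not from the commit):

```cpp
#include <cassert>
#include <string>
#include <vector>

int main()
{
    // A transposed, B not transposed: the pass appended both packing buffers
    // and swapped the originals to the back.
    std::vector<std::string> shapes{"pack_a", "pack_b", "out", "A", "B"};
    bool a_transposed = true, b_transposed = false;

    if(!b_transposed)
        shapes.pop_back(); // drop the original B kept for packing
    if(a_transposed)
        shapes.pop_back(); // drop the original A kept for packing
    shapes.pop_back();     // drop the output allocation

    // What remains matches the original {A, B} shapes, ready for
    // op.compute_shape.
    assert((shapes == std::vector<std::string>{"pack_a", "pack_b"}));
}
```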
@@ -21,6 +21,7 @@
 #include <migraphx/gpu/concat_gpu_opt.hpp>
 #include <migraphx/gpu/schedule_model.hpp>
 #include <migraphx/gpu/adjust_allocation.hpp>
+#include <migraphx/gpu/pack_int8_args.hpp>
 #include <migraphx/eliminate_pad.hpp>
 #include <migraphx/schedule.hpp>
@@ -70,6 +71,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
         eliminate_allocation{"hip::allocate"},
         check_context<context>{},
         dead_code_elimination{},
+        pack_int8_args{},
+        dead_code_elimination{},
         eliminate_identity{}
     };
     // clang-format on
......