Unverified Commit 3d24a21c authored by Paul Fultz II, committed by GitHub

Add a pass to remove unsupported data types (#738)



* Add eliminate_data_type pass

* Formatting

* Auto convert quant ops

* Formatting

* Flip the order of decompose

* Compute max size differently

* Formatting

* Clamp values in convert

* Formatting

* Fix loss of precision in reduce

* Formatting

* Fix bugs in reduction

* Fix accumulator type in reference softmax implementation

* Formatting

* Update convert test

* Remove unused variables

* Remove unnecessary quant_dot check

* Formatting

* Add tests

* Formatting

* Remove unused code

* Remove duplicate ops

* Remove blaze dependency

* Use a set since shape::type_t is not hashable on gcc 5

* Formatting
Co-authored-by: Shucai Xiao <shucai@gmail.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 5d0ca2a6
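
At a high level, the new eliminate_data_type pass rewrites any instruction whose type is in a given set of unsupported types: its inputs are converted to a chosen compute type, the operator runs there, and the result is converted back (the new test file further down shows the exact rewrite). A minimal sketch of how a backend constructs the pass, following the CPU target change in this diff; the include and the surrounding pass list are abbreviated:

#include <migraphx/eliminate_data_type.hpp>

// Everything except float and double is unsupported on this target...
std::set<migraphx::shape::type_t> unsupported_types(migraphx::shape::types().begin(),
                                                    migraphx::shape::types().end());
unsupported_types.erase(migraphx::shape::type_t::float_type);
unsupported_types.erase(migraphx::shape::type_t::double_type);
// ...so compute those operations in float and convert the results back afterwards.
migraphx::eliminate_data_type pass{unsupported_types, migraphx::shape::type_t::float_type};
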
......@@ -10,7 +10,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
#ifdef USE_DNNL
struct context
{
void finish() const {}
......@@ -27,30 +26,6 @@ struct context
this->bulk_execute(n, 256, f);
}
};
#else
struct context
{
void finish() const {}
template <class F>
void bulk_execute(std::size_t n, std::size_t min_grain, F f)
{
const auto threadsize =
std::min<std::size_t>(std::thread::hardware_concurrency(), n / min_grain);
std::size_t grainsize = std::ceil(static_cast<double>(n) / threadsize);
par_for(threadsize, 1, [&](auto tid) {
std::size_t work = tid * grainsize;
f(work, std::min(n, work + grainsize));
});
}
template <class F>
void bulk_execute(std::size_t n, F f)
{
this->bulk_execute(n, 256, f);
}
};
#endif
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
......
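
Both the DNNL and the (removed) fallback context expose the same bulk_execute interface, so backends stay agnostic of which one is compiled in. A minimal usage sketch; the data and the doubling loop are illustrative, not from the patch:

#include <migraphx/cpu/context.hpp>
#include <vector>

migraphx::cpu::context ctx;
std::vector<float> data(1 << 16, 1.0f);
// bulk_execute splits [0, n) into per-thread chunks of at least min_grain elements
// (256 by default); the callback receives the [start, end) range it owns.
ctx.bulk_execute(data.size(), [&](std::size_t start, std::size_t end) {
    for(std::size_t i = start; i < end; i++)
        data[i] *= 2.0f;
});
ctx.finish();
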
......@@ -7,7 +7,6 @@
#include <migraphx/register_op.hpp>
#include <migraphx/check_shapes.hpp>
#include <unordered_map>
#ifdef USE_DNNL
#include <dnnl.hpp>
#include <migraphx/errors.hpp>
......@@ -132,6 +131,7 @@ struct dnnl_op : auto_register_op<Derived>
std::unordered_map<int, dnnl::memory::desc> result;
result[DNNL_ARG_DST] = to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size()));
auto m = self.arg_map(inputs.size());
assert(m.size() >= inputs.size());
for(int i = 0; i < inputs.size(); i++)
{
result[m[i]] = to_dnnl_memory_desc(self.adjust_shape(inputs[i], i));
......@@ -166,6 +166,7 @@ struct dnnl_op : auto_register_op<Derived>
// Compensate for allocation
inputs.pop_back();
const auto& self = static_cast<const Derived&>(*this);
auto name = self.name();
auto md = to_memory_desc(output_shape, inputs);
auto prim = get_primitive(md);
auto arg_lookup = self.arg_map(inputs.size());
......@@ -177,9 +178,13 @@ struct dnnl_op : auto_register_op<Derived>
auto debug_md = to_memory_desc(output_shape, to_shapes(debug_args));
for(auto&& p : debug_md)
{
if(md.count(p.first) == 0)
MIGRAPHX_THROW(name +
": Missing memory descriptor for: " + std::to_string(p.first));
if(p.second == md.at(p.first))
continue;
MIGRAPHX_THROW("Memory descriptor has changed for: " + std::to_string(p.first));
MIGRAPHX_THROW(name +
": Memory descriptor has changed for: " + std::to_string(p.first));
}
#endif
std::unordered_map<int, dnnl::memory> m;
......@@ -221,5 +226,3 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive>
} // namespace migraphx
#endif
#endif
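
The arg_map hook pairs each MIGraphX input with a DNNL argument id, and the assert added above only checks that it covers every input. A hedged sketch of what a typical override looks like for a two-input op, modelled on the dnnl_mul and dnnl_pooling definitions later in this diff (this exact override is an illustration, not code from the patch):

// One DNNL argument id per MIGraphX input, in input order.
std::vector<int> arg_map(int /*input_count*/) const
{
    return {DNNL_ARG_SRC_0, DNNL_ARG_SRC_1};
}
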
#ifndef MIGRAPHX_GUARD_RTGLIB_CPU_GEMM_HPP
#define MIGRAPHX_GUARD_RTGLIB_CPU_GEMM_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta);
void migemm(const argument& c_arg,
const argument& a_arg,
const argument& b_arg,
int32_t alpha,
int32_t beta);
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -2,15 +2,12 @@
#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_PARALLEL_HPP
#include <migraphx/config.hpp>
#if USE_DNNL
#include <omp.h>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
#if USE_DNNL
template <class F>
void parallel_for_impl(std::size_t n, std::size_t threadsize, F f)
{
......@@ -43,7 +40,6 @@ void parallel_for(std::size_t n, F f)
const int min_grain = 8;
parallel_for(n, min_grain, f);
}
#endif
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -24,7 +24,6 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/clamp.hpp>
#include <migraphx/cpu/migemm.hpp>
#include <migraphx/cpu/context.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/make_op.hpp>
......@@ -564,37 +563,15 @@ struct cpu_apply
});
}
void extend_dnnl_extend_op(const std::string& op_name,
const std::string& cpu_name,
const std::string& dnnl_name)
{
apply_map.emplace(op_name, [=](instruction_ref ins) {
auto&& op = ins->get_operator();
if(has_op(dnnl_name) and ins->get_shape().type() == shape::type_t::float_type)
return replace(ins, make_op(dnnl_name, op.to_value()));
return replace(ins, make_op(cpu_name, op.to_value()));
});
}
void extend_dnnl_extend_op(const std::string& op_name, const std::string& dnnl_name)
{
apply_map.emplace(op_name, [=](instruction_ref ins) {
auto&& op = ins->get_operator();
if(has_op(dnnl_name) and ins->get_shape().type() == shape::type_t::float_type)
return replace(ins, make_op(dnnl_name, op.to_value()));
return ins;
});
}
void init()
{
create_output_names();
extend_dnnl_extend_op("add", "cpu::add", "dnnl::add");
extend_dnnl_extend_op("mul", "cpu::mul", "dnnl::mul");
extend_dnnl_extend_op("convolution", "cpu::convolution", "dnnl::convolution");
extend_dnnl_extend_op("dot", "cpu::dot", "dnnl::dot");
extend_dnnl_extend_op("relu", "cpu::relu", "dnnl::relu");
extend_dnnl_extend_op("concat", "dnnl::concat");
extend_op("add", "dnnl::add", true);
extend_op("mul", "dnnl::mul", true);
extend_op("convolution", "dnnl::convolution", true);
extend_op("dot", "dnnl::dot", true);
extend_op("relu", "dnnl::relu", true);
extend_op("contiguous", "cpu::contiguous", true);
extend_op("deconvolution", "cpu::deconvolution");
extend_op("elu", "cpu::elu");
......
#include <migraphx/cpu/migemm.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/shape_for_each.hpp>
#include <blaze/math/CustomMatrix.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
template <class T, class F>
void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::false_type)
{
std::size_t n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
auto k = amat.get_shape().lens()[dim_1];
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
auto a_idx = c_idx;
auto b_idx = c_idx;
double s = 0.0;
dfor(k)([&](auto kk) {
a_idx[dim_1] = b_idx[dim_0] = kk;
s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
});
cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
});
}
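
For reference, the kernel above is the standard GEMM update, with the inner product accumulated in a double before it is scaled and written back:

$$C_{ij} \leftarrow \alpha \sum_{k} A_{ik} B_{kj} + \beta \, C_{ij}$$
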
template <class T, class F>
void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
{
migemm_impl(cmat, amat, bmat, alpha, beta, std::false_type{});
}
template <class F>
void migemm_tpl(
const argument& c_arg, const argument& a_arg, const argument& b_arg, F alpha, F beta)
{
visit_all(c_arg, a_arg, b_arg)(
[&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); });
}
void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
{
migemm_tpl(c_arg, a_arg, b_arg, alpha, beta);
}
void migemm(const argument& c_arg,
const argument& a_arg,
const argument& b_arg,
int32_t alpha,
int32_t beta)
{
migemm_tpl(c_arg, a_arg, b_arg, alpha, beta);
}
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
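
A short usage sketch of the float overload; constructing an argument over a caller-owned buffer is an assumption about the MIGraphX API made for illustration, not part of this patch:

#include <migraphx/argument.hpp>
#include <migraphx/cpu/migemm.hpp>
#include <vector>

int main()
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2}};
    std::vector<float> a_data{1, 2, 3, 4}, b_data{5, 6, 7, 8}, c_data(4, 0.0f);
    // Assumption: argument{shape, pointer} wraps an existing buffer.
    migraphx::argument a{s, a_data.data()};
    migraphx::argument b{s, b_data.data()};
    migraphx::argument c{s, c_data.data()};
    // c = 1.0 * (a x b) + 0.0 * c
    migraphx::cpu::migemm(c, a, b, 1.0f, 0.0f);
}
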
......@@ -6,9 +6,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
template struct cpu_binary<op::mul>;
#if USE_DNNL
struct dnnl_mul : dnnl_extend_op<dnnl_mul, dnnl::binary, op::mul>
{
dnnl::binary::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
......@@ -19,7 +16,6 @@ struct dnnl_mul : dnnl_extend_op<dnnl_mul, dnnl::binary, op::mul>
m.at(DNNL_ARG_DST)};
}
};
#endif
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -123,7 +123,6 @@ struct cpu_pooling : auto_register_op<cpu_pooling<Op>>
template struct cpu_pooling<avg_pool>;
template struct cpu_pooling<max_pool>;
#if USE_DNNL
struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::pooling>
{
std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC}; }
......@@ -141,7 +140,6 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
to_dnnl_dims(op.padding)};
}
};
#endif
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -6,9 +6,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
template struct cpu_unary<op::relu>;
#if USE_DNNL
struct dnnl_relu : dnnl_extend_op<dnnl_relu, dnnl::eltwise_forward, op::relu>
{
dnnl::eltwise_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
......@@ -18,7 +15,6 @@ struct dnnl_relu : dnnl_extend_op<dnnl_relu, dnnl::eltwise_forward, op::relu>
m.at(DNNL_ARG_SRC_0)};
}
};
#endif
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -8,6 +8,7 @@
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/memory_coloring.hpp>
......@@ -36,7 +37,12 @@ std::string target::name() const { return "cpu"; }
std::vector<pass> target::get_passes(migraphx::context&, const compile_options&) const
{
std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
unsupported_types.erase(shape::type_t::double_type);
unsupported_types.erase(shape::type_t::float_type);
return {normalize_ops{},
eliminate_data_type{unsupported_types, shape::type_t::float_type},
dead_code_elimination{},
decompose{},
dead_code_elimination{},
simplify_reshapes{},
......
......@@ -44,7 +44,7 @@ struct mean
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{
return static_cast<T>(x / item_num);
return x / static_cast<T>(item_num);
}
};
......
......@@ -43,6 +43,15 @@ struct rocblas_gemm
batch_not_transposed(inputs[0].strides());
batch_not_transposed(inputs[1].strides());
std::size_t kdim = inputs[0].lens().size() - 1;
// k must be a multiple of 4
if(op.name() == "quant_dot" && (inputs[0].lens()[kdim] % 4) != 0)
{
MIGRAPHX_THROW("GPU_GEMM: size of A {" + to_string_range(inputs[0].lens()) +
"} and B {" + to_string_range(inputs[1].lens()) +
"} must be multiple of 4 for int8 type");
}
return op.compute_shape(in_shapes);
}
......
......@@ -799,7 +799,7 @@ struct ref_softmax : auto_register_op<ref_softmax<Op>>
shape batch_shape{shape::int32_type, batch_lens};
visit_all(result, args[0])([&](auto output, auto input) {
using value_type = typename decltype(input)::value_type;
using value_type = accumulator_type<typename decltype(input)::value_type>;
std::vector<value_type> batch_max(batch_shape.elements(),
std::numeric_limits<value_type>::lowest());
std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
......@@ -808,7 +808,8 @@ struct ref_softmax : auto_register_op<ref_softmax<Op>>
for(std::size_t j = 0; j < n_dims; ++j)
{
idx[tuned_axis] = j;
batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end()));
batch_max[i] =
std::max<value_type>(batch_max[i], input(idx.begin(), idx.end()));
}
for(std::size_t j = 0; j < n_dims; ++j)
......
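
The reference softmax now accumulates its running max and sum in accumulator_type<T> rather than T itself; that trait's definition is outside this diff. A hedged sketch of what such a trait typically looks like, with hypothetical names and mappings rather than MIGraphX's actual definition:

#include <migraphx/half.hpp>

// Promote narrow types to a wider accumulator so reductions such as the
// softmax max/sum do not overflow or lose precision.
template <class T>
struct accumulator { using type = T; };                      // default: accumulate in T
template <>
struct accumulator<migraphx::half> { using type = float; };  // assumed mapping
template <class T>
using accumulator_type = typename accumulator<T>::type;
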
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::module& m, std::set<migraphx::shape::type_t> types)
{
migraphx::run_passes(
m,
{migraphx::eliminate_data_type{std::move(types), migraphx::shape::float_type},
migraphx::eliminate_identity{},
migraphx::dead_code_elimination{}});
}
TEST_CASE(simple)
{
migraphx::shape s{migraphx::shape::int8_type, {2, 2}};
migraphx::module mm1;
{
auto x = mm1.add_parameter("x", s);
auto y = mm1.add_parameter("y", s);
mm1.add_instruction(migraphx::make_op("add"), x, y);
}
run_pass(mm1, {migraphx::shape::int8_type});
migraphx::module mm2;
{
auto x = mm2.add_parameter("x", s);
auto y = mm2.add_parameter("y", s);
auto floatx = mm2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), x);
auto floaty = mm2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), y);
auto add = mm2.add_instruction(migraphx::make_op("add"), floatx, floaty);
mm2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), add);
}
EXPECT(mm1 == mm2);
}
TEST_CASE(quant)
{
migraphx::shape s{migraphx::shape::int8_type, {2, 2}};
migraphx::module mm1;
{
auto x = mm1.add_parameter("x", s);
auto y = mm1.add_parameter("y", s);
mm1.add_instruction(migraphx::make_op("quant_dot"), x, y);
}
run_pass(mm1, {migraphx::shape::int8_type});
migraphx::module mm2;
{
auto x = mm2.add_parameter("x", s);
auto y = mm2.add_parameter("y", s);
auto floatx = mm2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), x);
auto floaty = mm2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), y);
auto add = mm2.add_instruction(migraphx::make_op("dot"), floatx, floaty);
mm2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), add);
}
EXPECT(mm1 == mm2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -935,12 +935,6 @@ TEST_CASE(quant_dot_2args)
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 3}};
migraphx::shape s_m2{migraphx::shape::int8_type, {3, 8}};
throws_shape(migraphx::make_op("quant_dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}};
......
......@@ -12,19 +12,19 @@ struct test_convert : verify_program<test_convert>
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {8, 24}};
migraphx::shape sb{migraphx::shape::float_type, {24, 6}};
migraphx::shape sa{migraphx::shape::int8_type, {8, 24}};
migraphx::shape sb{migraphx::shape::int8_type, {24, 6}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto ia = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pa);
auto ib = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pb);
mm->add_instruction(migraphx::make_op("quant_dot"), ia, ib);
mm->add_instruction(migraphx::make_op("dot"), ia, ib);
return p;
};
......