Commit c47fd466 authored by Shucai Xiao

clang format

parent 980ca67d
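
Note: every hunk in this commit is a whitespace-only clang-format pass: `if (` becomes `if(`, `template<...>` gains a space after the keyword, one-parameter-per-line signatures are re-packed up to the column limit, and consecutive assignments are aligned. The project's .clang-format file is not part of this commit, so the sketch below is only an assumption: the option names are real clang-format options, but the values are inferred from the diff.

# Hypothetical .clang-format sketch (values inferred from this diff,
# not taken from the repository).
BasedOnStyle: LLVM
ColumnLimit: 100                   # long signatures re-packed onto ~100-column lines
SpaceBeforeParens: Never           # "if (x)" -> "if(x)"
SpaceAfterTemplateKeyword: true    # "template<class T>" -> "template <class T>"
AlignConsecutiveAssignments: true  # aligns the apply_map["..."] = ... block
BinPackParameters: true            # pack parameters instead of one per line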
@@ -34,7 +34,7 @@ struct quant_dot
         const shape& a = inputs.at(0);
         const shape& b = inputs.at(1);
         auto t         = a.type();
-        if (t != shape::int8_type)
+        if(t != shape::int8_type)
         {
             MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
         }
@@ -48,16 +48,16 @@ struct quant_dot
         if(!std::equal(
               a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
         {
-            MIGRAPHX_THROW("QUANT_DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) +
-                           "} x {" + to_string_range(b.lens()) + "}");
+            MIGRAPHX_THROW("QUANT_DOT: batch size of A and B mismatch: {" +
+                           to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
         }

         std::size_t dim_0 = a.lens().size() - 2;
         std::size_t dim_1 = a.lens().size() - 1;
         if(a.lens()[dim_1] != b.lens()[dim_0])
         {
-            MIGRAPHX_THROW("QUANT_DOT: inner dimensions do not match: {" + to_string_range(a.lens()) +
-                           "} x {" + to_string_range(b.lens()) + "}");
+            MIGRAPHX_THROW("QUANT_DOT: inner dimensions do not match: {" +
+                           to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
         }

         auto out_lens = a.lens();
@@ -69,7 +69,7 @@ struct quant_dot
                            "}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
         }

-        if (inputs.size() == 3 && inputs.at(2).type() != shape::int32_type)
+        if(inputs.size() == 3 && inputs.at(2).type() != shape::int32_type)
         {
             MIGRAPHX_THROW("QUANT_DOT: operand C type must be int32");
         }
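Note on the validation above: the `std::equal` call walks both shapes from the back, skipping the last two entries, so it checks that every leading (batch) dimension of A matches B; `dim_0`/`dim_1` then enforce the usual inner-dimension constraint of a matrix product. A self-contained sketch of the same logic, using a hypothetical helper on plain vectors rather than MIGraphX shapes:

#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the checks in quant_dot::compute_shape.
bool dot_shapes_compatible(const std::vector<std::size_t>& a,
                           const std::vector<std::size_t>& b)
{
    if(a.size() < 2 || b.size() < 2)
        return false;
    // All dimensions except the last two are batch dimensions; the
    // four-iterator std::equal also fails when the batch ranks differ.
    bool batches_match = std::equal(a.rbegin() + 2, a.rend(), b.rbegin() + 2, b.rend());
    // Inner dimensions: columns of A must equal rows of B.
    bool inner_match = a[a.size() - 1] == b[b.size() - 2];
    return batches_match && inner_match;
}

// e.g. dot_shapes_compatible({2, 3, 4}, {2, 4, 5}) -> true: [3x4] * [4x5] per batch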
@@ -45,12 +45,8 @@ struct is_fast_gemm_type<float> : std::true_type
 };

 template <class T, class F>
-void migemm_impl(tensor_view<T> cmat,
-                 tensor_view<T> amat,
-                 tensor_view<T> bmat,
-                 F alpha,
-                 F beta,
-                 std::true_type)
+void migemm_impl(
+    tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::true_type)
 {
     visit_mat(amat, [&](const auto& a) {
         visit_mat(bmat, [&](const auto& b) {
@@ -67,12 +63,8 @@ void migemm_impl(tensor_view<T> cmat,
 }

 template <class T, class F>
-void migemm_impl(tensor_view<T> cmat,
-                 tensor_view<T> amat,
-                 tensor_view<T> bmat,
-                 F alpha,
-                 F beta,
-                 std::false_type)
+void migemm_impl(
+    tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::false_type)
 {
     std::size_t n_dims = cmat.get_shape().lens().size();
     std::size_t dim_0  = n_dims - 2;
@@ -96,8 +88,7 @@ void migemm_impl(tensor_view<T> cmat,
 }

 template <class T, class F>
-void migemm_impl(
-    tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
+void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
 {
     auto lens = amat.get_shape().lens();
     bool batch_mul =
@@ -113,9 +104,8 @@ void migemm_impl(
     }
 }

-template<class F>
-void migemm(
-    const argument& c_arg, const argument& a_arg, const argument& b_arg, F alpha, F beta)
+template <class F>
+void migemm(const argument& c_arg, const argument& a_arg, const argument& b_arg, F alpha, F beta)
 {
     visit_all(c_arg, a_arg, b_arg)(
         [&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); });
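Note: the reformatted `migemm_impl` overloads are a tag-dispatch pattern: `is_fast_gemm_type<T>` (true for `float`, per the hunk header above) is turned into a `std::true_type`/`std::false_type` argument, and overload resolution picks either the optimized matrix path or the generic loop at compile time. A minimal sketch of the pattern, with illustrative names rather than MIGraphX's:

#include <iostream>
#include <type_traits>

// Trait playing the role of is_fast_gemm_type<T>: only float takes the fast path.
template <class T>
struct is_fast : std::false_type
{
};
template <>
struct is_fast<float> : std::true_type
{
};

// The tag parameter selects the overload at compile time.
template <class T>
void gemm_impl(T, std::true_type)
{
    std::cout << "fast path\n";
}
template <class T>
void gemm_impl(T, std::false_type)
{
    std::cout << "generic path\n";
}

template <class T>
void gemm(T x)
{
    // The trait value becomes a tag object, so no branch exists at runtime.
    gemm_impl(x, is_fast<T>{});
}

int main()
{
    gemm(1.0f); // fast path
    gemm(1);    // generic path
}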
@@ -8,9 +8,8 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace cpu {

-template<class T>
-void migemm(
-    const argument& c_arg, const argument& a_arg, const argument& b_arg, T alpha, T beta);
+template <class T>
+void migemm(const argument& c_arg, const argument& a_arg, const argument& b_arg, T alpha, T beta);

 } // namespace cpu
 } // namespace MIGRAPHX_INLINE_NS
@@ -435,15 +435,13 @@ struct cpu_quant_gemm
         argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}};
         argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}};
         arg_0.visit([&](auto output) {
-            args.at(0).visit([&](auto input) {
-                std::copy(input.begin(), input.end(), output.begin());
-            });
+            args.at(0).visit(
+                [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
         });

         arg_1.visit([&](auto output) {
-            args.at(1).visit([&](auto input) {
-                std::copy(input.begin(), input.end(), output.begin());
-            });
+            args.at(1).visit(
+                [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
         });

         if(args.size() == 3)
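Note on the hunk above: `cpu_quant_gemm` widens both int8 inputs into freshly allocated int32 buffers before the integer gemm runs, presumably so products can accumulate without overflow; the nested `visit` calls obtain typed views, and `std::copy` performs the element-wise widening. The conversion step in isolation, with plain vectors standing in for MIGraphX arguments:

#include <algorithm>
#include <cstdint>
#include <vector>

int main()
{
    std::vector<std::int8_t> input = {1, -2, 3, -4};

    // Widening copy: each int8_t converts to int32_t on assignment, as in
    // the std::copy between the visited views above.
    std::vector<std::int32_t> widened(input.size());
    std::copy(input.begin(), input.end(), widened.begin());
}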
@@ -879,7 +877,7 @@ struct cpu_apply
         apply_map["im2col"]      = extend_op<cpu_im2col, op::im2col>();
         apply_map["convolution"] = extend_op<cpu_convolution, op::convolution>();
         apply_map["dot"]         = extend_op<cpu_gemm, op::dot>();
-        apply_map["quant_dot"] = extend_op<cpu_quant_gemm, op::quant_dot>();
+        apply_map["quant_dot"]   = extend_op<cpu_quant_gemm, op::quant_dot>();
         apply_map["batch_norm_inference"] =
             extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
         apply_map["lrn"] = extend_op<cpu_lrn, op::lrn>();
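Note: the final hunk only re-aligns the new `quant_dot` entry, but it shows the lowering pattern: `cpu_apply` keeps a map from operator names to factory callables built by `extend_op<CpuOp, Op>()`, so adding a CPU implementation is a single registration. A stripped-down sketch of such a registry, with hypothetical types rather than the MIGraphX API:

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

int main()
{
    // Name -> handler registry, analogous to apply_map in cpu_apply.
    std::unordered_map<std::string, std::function<void()>> apply_map;
    apply_map["dot"]       = [] { std::cout << "lower dot to cpu_gemm\n"; };
    apply_map["quant_dot"] = [] { std::cout << "lower quant_dot to cpu_quant_gemm\n"; };

    // Lowering looks an instruction's operator up by name and applies the handler.
    apply_map.at("quant_dot")();
}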