Commit c47fd466 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 980ca67d
......@@ -34,7 +34,7 @@ struct quant_dot
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
if (t != shape::int8_type)
if(t != shape::int8_type)
{
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
}
......@@ -48,16 +48,16 @@ struct quant_dot
if(!std::equal(
a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
{
MIGRAPHX_THROW("QUANT_DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
MIGRAPHX_THROW("QUANT_DOT: batch size of A and B mismatch: {" +
to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
}
std::size_t dim_0 = a.lens().size() - 2;
std::size_t dim_1 = a.lens().size() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0])
{
MIGRAPHX_THROW("QUANT_DOT: inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
MIGRAPHX_THROW("QUANT_DOT: inner dimensions do not match: {" +
to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
}
auto out_lens = a.lens();
......@@ -69,7 +69,7 @@ struct quant_dot
"}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
}
if (inputs.size() == 3 && inputs.at(2).type() != shape::int32_type)
if(inputs.size() == 3 && inputs.at(2).type() != shape::int32_type)
{
MIGRAPHX_THROW("QUANT_DOT: operand C type must be int32");
}
......
......@@ -45,12 +45,8 @@ struct is_fast_gemm_type<float> : std::true_type
};
template <class T, class F>
void migemm_impl(tensor_view<T> cmat,
tensor_view<T> amat,
tensor_view<T> bmat,
F alpha,
F beta,
std::true_type)
void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::true_type)
{
visit_mat(amat, [&](const auto& a) {
visit_mat(bmat, [&](const auto& b) {
......@@ -67,12 +63,8 @@ void migemm_impl(tensor_view<T> cmat,
}
template <class T, class F>
void migemm_impl(tensor_view<T> cmat,
tensor_view<T> amat,
tensor_view<T> bmat,
F alpha,
F beta,
std::false_type)
void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::false_type)
{
std::size_t n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
......@@ -96,8 +88,7 @@ void migemm_impl(tensor_view<T> cmat,
}
template <class T, class F>
void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
{
auto lens = amat.get_shape().lens();
bool batch_mul =
......@@ -113,9 +104,8 @@ void migemm_impl(
}
}
template<class F>
void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, F alpha, F beta)
template <class F>
void migemm(const argument& c_arg, const argument& a_arg, const argument& b_arg, F alpha, F beta)
{
visit_all(c_arg, a_arg, b_arg)(
[&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); });
......
......@@ -8,9 +8,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
template<class T>
void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, T alpha, T beta);
template <class T>
void migemm(const argument& c_arg, const argument& a_arg, const argument& b_arg, T alpha, T beta);
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -435,15 +435,13 @@ struct cpu_quant_gemm
argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}};
argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}};
arg_0.visit([&](auto output) {
args.at(0).visit([&](auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
args.at(0).visit(
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
});
arg_1.visit([&](auto output) {
args.at(1).visit([&](auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
args.at(1).visit(
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
});
if(args.size() == 3)
......@@ -879,7 +877,7 @@ struct cpu_apply
apply_map["im2col"] = extend_op<cpu_im2col, op::im2col>();
apply_map["convolution"] = extend_op<cpu_convolution, op::convolution>();
apply_map["dot"] = extend_op<cpu_gemm, op::dot>();
apply_map["quant_dot"] = extend_op<cpu_quant_gemm, op::quant_dot>();
apply_map["quant_dot"] = extend_op<cpu_quant_gemm, op::quant_dot>();
apply_map["batch_norm_inference"] =
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["lrn"] = extend_op<cpu_lrn, op::lrn>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment