Commit a96585d6 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'ref_gemm_par' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into layernorm_half2

parents c1ec929c 9b19b73f
...@@ -106,7 +106,11 @@ bool argument::empty() const { return not m_data.get and m_data.sub.empty(); } ...@@ -106,7 +106,11 @@ bool argument::empty() const { return not m_data.get and m_data.sub.empty(); }
const shape& argument::get_shape() const { return this->m_shape; } const shape& argument::get_shape() const { return this->m_shape; }
argument argument::reshape(const shape& s) const { return {s, this->m_data}; } argument argument::reshape(const shape& s) const
{
assert(s.element_space() <= this->get_shape().element_space());
return {s, this->m_data};
}
argument::data_t argument::data_t::share() const argument::data_t argument::data_t::share() const
{ {
......
...@@ -36,7 +36,6 @@ struct as_shape ...@@ -36,7 +36,6 @@ struct as_shape
{ {
return args.front().reshape(output_shape); return args.front().reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -67,7 +67,6 @@ struct broadcast ...@@ -67,7 +67,6 @@ struct broadcast
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -51,7 +51,6 @@ struct flatten ...@@ -51,7 +51,6 @@ struct flatten
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -69,7 +69,6 @@ struct multibroadcast ...@@ -69,7 +69,6 @@ struct multibroadcast
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -75,7 +75,6 @@ struct reshape ...@@ -75,7 +75,6 @@ struct reshape
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -40,7 +40,6 @@ struct scalar ...@@ -40,7 +40,6 @@ struct scalar
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -84,7 +84,6 @@ struct squeeze ...@@ -84,7 +84,6 @@ struct squeeze
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -72,8 +72,6 @@ struct step ...@@ -72,8 +72,6 @@ struct step
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -56,7 +56,6 @@ struct transpose ...@@ -56,7 +56,6 @@ struct transpose
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -81,7 +81,6 @@ struct unsqueeze ...@@ -81,7 +81,6 @@ struct unsqueeze
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -226,6 +226,8 @@ struct shape ...@@ -226,6 +226,8 @@ struct shape
const std::vector<shape>& sub_shapes() const; const std::vector<shape>& sub_shapes() const;
std::size_t element_space() const; std::size_t element_space() const;
std::size_t element_space() const;
private: private:
shape(std::shared_ptr<shape_impl> pimpl); shape(std::shared_ptr<shape_impl> pimpl);
std::shared_ptr<const shape_impl> impl; std::shared_ptr<const shape_impl> impl;
......
#include <migraphx/ref/gemm.hpp> #include <migraphx/ref/gemm.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/par_for.hpp>
#include <blaze/math/CustomMatrix.h> #include <blaze/math/CustomMatrix.h>
namespace migraphx { namespace migraphx {
...@@ -74,8 +74,10 @@ void migemm_impl( ...@@ -74,8 +74,10 @@ void migemm_impl(
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]); assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]); assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]); assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
auto cs = cmat.get_shape();
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) { par_for(cs.elements(), [&](auto i) {
auto c_idx = cs.multi(i);
auto a_idx = c_idx; auto a_idx = c_idx;
auto b_idx = c_idx; auto b_idx = c_idx;
double s = 0.0; double s = 0.0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment