"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "bd503d893fe29004bc264eb7fb1f81af1ce7197b"
Commit 135eb63e authored by Alan Turner's avatar Alan Turner
Browse files

Move common CK device functions to ck.hpp

parent 791addfb
......@@ -154,6 +154,29 @@ struct ck_add
}
};
// In CK, the B matrix is ordered as N,K instead of K,N
template <class Dims>
constexpr auto ck_transposeb_dims(Dims dims)
{
return unpack(dims, [](auto k, auto n) { return make_const_array(n, k); });
}
template <class Tensor>
using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens),
ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
template <class T>
struct ck_gemm_softmax_gemm_settings
{
T scale{};
};
template <class... Ts>
constexpr ck_gemm_softmax_gemm_settings<Ts...> make_ck_gemm_softmax_gemm_settings(Ts... xs)
{
return {xs...};
}
#ifdef MIGRAPHX_CK_CHECK
#define MIGRAPHX_CK_STATIC_ASSERT static_assert
#else
......
......@@ -33,17 +33,6 @@
namespace migraphx {
// In CK, the B matrix is ordered as N,K instead of K,N
template <class Dims>
constexpr auto ck_transposeb_dims(Dims dims)
{
return unpack(dims, [](auto k, auto n) { return make_const_array(n, k); });
}
template <class Tensor>
using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens),
ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
template <class G, class E, class A, class B, class... Ds>
__device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
{
......
......@@ -33,29 +33,6 @@
namespace migraphx {
// In CK, the B matrix is ordered as N,K instead of K,N
template <class Dims>
constexpr auto ck_transposeb_dims(Dims dims)
{
return unpack(dims, [](auto k, auto n) { return make_const_array(n, k); });
}
template <class Tensor>
using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens),
ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
template <class T>
struct ck_gemm_softmax_gemm_settings
{
T scale{};
};
template <class... Ts>
constexpr ck_gemm_softmax_gemm_settings<Ts...> make_ck_gemm_softmax_gemm_settings(Ts... xs)
{
return {xs...};
}
template <class G, class C, class A, class B, class B1, class Settings>
__device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1, Settings s)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment