Commit 1dd11890 authored by turneram's avatar turneram
Browse files

Formatting

parent 07167910
...@@ -72,7 +72,6 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p) ...@@ -72,7 +72,6 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
)__migraphx__"; )__migraphx__";
struct ck_gemm_compiler : compiler<ck_gemm_compiler> struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{ {
std::vector<std::string> names() const { return {"ck_gemm"}; } std::vector<std::string> names() const { return {"ck_gemm"}; }
......
...@@ -213,8 +213,8 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t) ...@@ -213,8 +213,8 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
AScalarPerVector, AScalarPerVector,
BScalarPerVector, BScalarPerVector,
CScalarPerVector>; CScalarPerVector>;
auto op = Add{}; auto op = Add{};
GridwiseBinEltwise::Run(a_t.data(), b_t.data(), c_t.data(), a_desc, b_desc, c_desc, op); GridwiseBinEltwise::Run(a_t.data(), b_t.data(), c_t.data(), a_desc, b_desc, c_desc, op);
} }
......
...@@ -164,12 +164,12 @@ struct Div ...@@ -164,12 +164,12 @@ struct Div
}; };
}; };
using InDataTypeTuple = ck::Tuple<ABDataType, ABDataType>; using InDataTypeTuple = ck::Tuple<ABDataType, ABDataType>;
using OutDataTypeTuple = ck::Tuple<CDataType>; using OutDataTypeTuple = ck::Tuple<CDataType>;
using ElementwiseOperation = Add; using ElementwiseOperation = Add;
static constexpr auto MPerThread = 8; static constexpr auto MPerThread = 8;
using InScalarPerVectorSeq = ck::Sequence<1, 8>; using InScalarPerVectorSeq = ck::Sequence<1, 8>;
using OutScalarPerVectorSeq = ck::Sequence<8>; using OutScalarPerVectorSeq = ck::Sequence<8>;
// using DeviceElementwiseAddInstance = // using DeviceElementwiseAddInstance =
// ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>, // ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>,
...@@ -186,7 +186,7 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t) ...@@ -186,7 +186,7 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
// auto idx = make_index(); // auto idx = make_index();
constexpr auto a_lens = get_shape_c<T>{}.lens; constexpr auto a_lens = get_shape_c<T>{}.lens;
constexpr auto a_strides = get_shape_c<T>{}.strides; constexpr auto a_strides = get_shape_c<T>{}.strides;
constexpr ck::index_t ndim = a_lens.size(); constexpr ck::index_t ndim = a_lens.size();
constexpr auto b_lens = get_shape_c<U>{}.lens; constexpr auto b_lens = get_shape_c<U>{}.lens;
constexpr auto b_strides = get_shape_c<U>{}.strides; constexpr auto b_strides = get_shape_c<U>{}.strides;
constexpr ck::index_t b_ndim = b_lens.size(); constexpr ck::index_t b_ndim = b_lens.size();
...@@ -197,47 +197,46 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t) ...@@ -197,47 +197,46 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
using DeviceElementwiseAddInstance = using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>, ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>,
ck::Tuple<CDataType>, ck::Tuple<CDataType>,
Add, Add,
ndim, ndim,
8, 8,
ck::Sequence<1, 8>, ck::Sequence<1, 8>,
ck::Sequence<8>>; ck::Sequence<8>>;
using shapes_t = std::array<ck::index_t, 3>; using shapes_t = std::array<ck::index_t, 3>;
//shapes_t lengths_abc; // shapes_t lengths_abc;
//copy(c_lens.begin(), c_lens.end(), lengths_abc); // copy(c_lens.begin(), c_lens.end(), lengths_abc);
shapes_t lengths_abc = {c_lens[0], c_lens[1], c_lens[2]}; shapes_t lengths_abc = {c_lens[0], c_lens[1], c_lens[2]};
//constexpr auto lengths_abc = static_cast<shapes_t>(c_lens[0], c_lens[1], c_lens[2]); // constexpr auto lengths_abc = static_cast<shapes_t>(c_lens[0], c_lens[1], c_lens[2]);
constexpr auto strides_a = static_cast<shapes_t>(a_strides); constexpr auto strides_a = static_cast<shapes_t>(a_strides);
constexpr auto strides_b = static_cast<shapes_t>(b_strides); constexpr auto strides_b = static_cast<shapes_t>(b_strides);
constexpr auto strides_c = static_cast<shapes_t>(c_strides); constexpr auto strides_c = static_cast<shapes_t>(c_strides);
std::array<const void*, 2> input = {a_t.data(), std::array<const void*, 2> input = {a_t.data(), b_t.data()};
b_t.data()};
std::array<void*, 1> output = {c_t.data()}; std::array<void*, 1> output = {c_t.data()};
auto ck_add = DeviceElementwiseAddInstance{}; auto ck_add = DeviceElementwiseAddInstance{};
auto argument = ck_add.MakeArgumentPointer( auto argument = ck_add.MakeArgumentPointer(
lengths_abc, {strides_a, strides_b}, {strides_c}, input, output, Add{}); lengths_abc, {strides_a, strides_b}, {strides_c}, input, output, Add{});
using InGrid1dDescTuple = decltype(ck_add.GenerateInOutGrid1dDescTuple(ck::Number<ndim>{})); using InGrid1dDescTuple = decltype(ck_add.GenerateInOutGrid1dDescTuple(ck::Number<ndim>{}));
using OutGrid1dDescTuple = decltype(ck_add.GenerateInOutGrid1dDescTuple(ck::Number<ndim>{})); using OutGrid1dDescTuple = decltype(ck_add.GenerateInOutGrid1dDescTuple(ck::Number<ndim>{}));
using InDataTypePointerTuple = decltype(ck_add.GenerateInDataTypePointerTuple()); using InDataTypePointerTuple = decltype(ck_add.GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(ck_add.GenerateOutDataTypePointerTuple()); using OutDataTypePointerTuple = decltype(ck_add.GenerateOutDataTypePointerTuple());
using GridwiseElementwise = ck::GridwiseElementwise_1D<InGrid1dDescTuple, using GridwiseElementwise = ck::GridwiseElementwise_1D<InGrid1dDescTuple,
OutGrid1dDescTuple, OutGrid1dDescTuple,
InDataTypePointerTuple, InDataTypePointerTuple,
OutDataTypePointerTuple, OutDataTypePointerTuple,
ElementwiseOperation, ElementwiseOperation,
MPerThread, MPerThread,
InScalarPerVectorSeq, InScalarPerVectorSeq,
OutScalarPerVectorSeq>; OutScalarPerVectorSeq>;
GridwiseElementwise::Run(argument.in_grid_1d_desc_tuple_, GridwiseElementwise::Run(argument.in_grid_1d_desc_tuple_,
argument.out_grid_1d_desc_tuple_, argument.out_grid_1d_desc_tuple_,
argument.in_dev_buffers_, argument.in_dev_buffers_,
argument.out_dev_buffers_, argument.out_dev_buffers_,
argument.elementwise_op_); argument.elementwise_op_);
} }
} // namespace migraphx } // namespace migraphx
......
...@@ -60,19 +60,27 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t) ...@@ -60,19 +60,27 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
if(idx.global == 0) if(idx.global == 0)
{ {
printf("a_grid_desc_k0_m0_m1_k1{%i, %i, %i}\n", int(a_grid_desc_k0_m_k1.GetLength(I0)), int(a_grid_desc_k0_m_k1.GetLength(I1)), int(a_grid_desc_k0_m_k1.GetLength(I2))); printf("a_grid_desc_k0_m0_m1_k1{%i, %i, %i}\n",
printf("b_grid_desc_k0_n0_n1_k1{%i, %i, %i}\n", int(b_grid_desc_k0_n_k1.GetLength(I0)), int(b_grid_desc_k0_n_k1.GetLength(I1)), int(b_grid_desc_k0_n_k1.GetLength(I2))); int(a_grid_desc_k0_m_k1.GetLength(I0)),
printf("c_grid_desc_m_n{%i, %i}\n", int(c_grid_desc_m_n.GetLength(I0)), int(c_grid_desc_m_n.GetLength(I1))); int(a_grid_desc_k0_m_k1.GetLength(I1)),
int(a_grid_desc_k0_m_k1.GetLength(I2)));
printf("b_grid_desc_k0_n0_n1_k1{%i, %i, %i}\n",
int(b_grid_desc_k0_n_k1.GetLength(I0)),
int(b_grid_desc_k0_n_k1.GetLength(I1)),
int(b_grid_desc_k0_n_k1.GetLength(I2)));
printf("c_grid_desc_m_n{%i, %i}\n",
int(c_grid_desc_m_n.GetLength(I0)),
int(c_grid_desc_m_n.GetLength(I1)));
} }
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1; AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1; BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1;
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11; CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11;
DefaultBlock2CTileMap block_2_ctile_map; DefaultBlock2CTileMap block_2_ctile_map;
if(true or GridwiseGemm::CheckValidity( if(true or
a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n)) GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n))
{ {
//printf("Is valid\n"); // printf("Is valid\n");
a_grid_desc_k0_m0_m1_k1 = a_grid_desc_k0_m0_m1_k1 =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1); GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
b_grid_desc_k0_n0_n1_k1 = b_grid_desc_k0_n0_n1_k1 =
...@@ -83,79 +91,86 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t) ...@@ -83,79 +91,86 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
} }
else else
{ {
//printf("Not valid\n"); // printf("Not valid\n");
} }
if(idx.global == 0) if(idx.global == 0)
{ {
printf("a_grid_desc_k0_m0_m1_k1{%i, %i, %i}\n", int(a_grid_desc_k0_m0_m1_k1.GetLength(I0)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I1)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I2))); printf("a_grid_desc_k0_m0_m1_k1{%i, %i, %i}\n",
printf("b_grid_desc_k0_n0_n1_k1{%i, %i, %i}\n", int(b_grid_desc_k0_n0_n1_k1.GetLength(I0)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I1)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I2))); int(a_grid_desc_k0_m0_m1_k1.GetLength(I0)),
printf("c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}\n", int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0)), int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I1))); int(a_grid_desc_k0_m0_m1_k1.GetLength(I1)),
int(a_grid_desc_k0_m0_m1_k1.GetLength(I2)));
printf("b_grid_desc_k0_n0_n1_k1{%i, %i, %i}\n",
int(b_grid_desc_k0_n0_n1_k1.GetLength(I0)),
int(b_grid_desc_k0_n0_n1_k1.GetLength(I1)),
int(b_grid_desc_k0_n0_n1_k1.GetLength(I2)));
printf("c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}\n",
int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0)),
int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I1)));
} }
const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0); const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop = const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
constexpr bool HasMainKBlockLoop = true; constexpr bool HasMainKBlockLoop = true;
constexpr bool HasDoubleTailKBlockLoop = true; constexpr bool HasDoubleTailKBlockLoop = true;
GridwiseGemm::Run(a_t.data(), GridwiseGemm::Run(a_t.data(),
b_t.data(), b_t.data(),
c_t.data(), c_t.data(),
p_t.data(), p_t.data(),
a_grid_desc_k0_m0_m1_k1, a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1, b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11, c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map, block_2_ctile_map,
ck::integral_constant<bool, HasMainKBlockLoop>{}, ck::integral_constant<bool, HasMainKBlockLoop>{},
ck::integral_constant<bool, HasDoubleTailKBlockLoop>{}); ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
constexpr bool HasMainKBlockLoop = true; constexpr bool HasMainKBlockLoop = true;
constexpr bool HasDoubleTailKBlockLoop = false; constexpr bool HasDoubleTailKBlockLoop = false;
GridwiseGemm::Run(a_t.data(), GridwiseGemm::Run(a_t.data(),
b_t.data(), b_t.data(),
c_t.data(), c_t.data(),
p_t.data(), p_t.data(),
a_grid_desc_k0_m0_m1_k1, a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1, b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11, c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map, block_2_ctile_map,
ck::integral_constant<bool, HasMainKBlockLoop>{}, ck::integral_constant<bool, HasMainKBlockLoop>{},
ck::integral_constant<bool, HasDoubleTailKBlockLoop>{}); ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
constexpr bool HasMainKBlockLoop = false; constexpr bool HasMainKBlockLoop = false;
constexpr bool HasDoubleTailKBlockLoop = true; constexpr bool HasDoubleTailKBlockLoop = true;
GridwiseGemm::Run(a_t.data(), GridwiseGemm::Run(a_t.data(),
b_t.data(), b_t.data(),
c_t.data(), c_t.data(),
p_t.data(), p_t.data(),
a_grid_desc_k0_m0_m1_k1, a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1, b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11, c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map, block_2_ctile_map,
ck::integral_constant<bool, HasMainKBlockLoop>{}, ck::integral_constant<bool, HasMainKBlockLoop>{},
ck::integral_constant<bool, HasDoubleTailKBlockLoop>{}); ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
else else
{ {
constexpr bool HasMainKBlockLoop = false; constexpr bool HasMainKBlockLoop = false;
constexpr bool HasDoubleTailKBlockLoop = false; constexpr bool HasDoubleTailKBlockLoop = false;
GridwiseGemm::Run(a_t.data(), GridwiseGemm::Run(a_t.data(),
b_t.data(), b_t.data(),
c_t.data(), c_t.data(),
p_t.data(), p_t.data(),
a_grid_desc_k0_m0_m1_k1, a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1, b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11, c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map, block_2_ctile_map,
ck::integral_constant<bool, HasMainKBlockLoop>{}, ck::integral_constant<bool, HasMainKBlockLoop>{},
ck::integral_constant<bool, HasDoubleTailKBlockLoop>{}); ck::integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
} }
......
...@@ -37,7 +37,7 @@ template <class T, class U, class V, class W> ...@@ -37,7 +37,7 @@ template <class T, class U, class V, class W>
__device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t) __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
{ {
static gemm tp{}; static gemm tp{};
using GridwiseGemm = decltype(tp.gg); using GridwiseGemm = decltype(tp.gg);
constexpr auto alens = get_shape_c<T>{}.lens; constexpr auto alens = get_shape_c<T>{}.lens;
constexpr auto m = alens[0]; constexpr auto m = alens[0];
constexpr auto k = alens[1]; constexpr auto k = alens[1];
...@@ -53,38 +53,51 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t) ...@@ -53,38 +53,51 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
if(idx.global == 0) if(idx.global == 0)
printf("%i %i %i, %i %i %i\n", int(m), int(n), int(k), int(as), int(bs), int(cs)); printf("%i %i %i, %i %i %i\n", int(m), int(n), int(k), int(as), int(bs), int(cs));
constexpr auto a_grid_desc_ak0_m_ak1 = tp.MakeAGridDescriptor_AK0_M_AK1(static_cast<ck::index_t>(m), static_cast<ck::index_t>(k), static_cast<ck::index_t>(as)); constexpr auto a_grid_desc_ak0_m_ak1 = tp.MakeAGridDescriptor_AK0_M_AK1(
constexpr auto b_grid_desc_bk0_n_bk1 = tp.MakeBGridDescriptor_BK0_N_BK1(static_cast<ck::index_t>(k), static_cast<ck::index_t>(n), static_cast<ck::index_t>(bs)); static_cast<ck::index_t>(m), static_cast<ck::index_t>(k), static_cast<ck::index_t>(as));
constexpr auto c_grid_desc_m_n = tp.MakeCGridDescriptor_M_N(static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(cs)); constexpr auto b_grid_desc_bk0_n_bk1 = tp.MakeBGridDescriptor_BK0_N_BK1(
static_cast<ck::index_t>(k), static_cast<ck::index_t>(n), static_cast<ck::index_t>(bs));
constexpr auto c_grid_desc_m_n = tp.MakeCGridDescriptor_M_N(
static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(cs));
/* constexpr */ auto block_2_ctile_map = tp.MakeDefaultBlock2CTileMap(c_grid_desc_m_n); /* constexpr */ auto block_2_ctile_map = tp.MakeDefaultBlock2CTileMap(c_grid_desc_m_n);
if(idx.global == 0) if(idx.global == 0)
{ {
printf("a_grid_desc_ak0_m_ak1{%i, %i, %i}\n", int(a_grid_desc_ak0_m_ak1.GetLength(I0)), int(a_grid_desc_ak0_m_ak1.GetLength(I1)), int(a_grid_desc_ak0_m_ak1.GetLength(I2))); printf("a_grid_desc_ak0_m_ak1{%i, %i, %i}\n",
printf("b_grid_desc_bk0_n_bk1{%i, %i, %i}\n", int(b_grid_desc_bk0_n_bk1.GetLength(I0)), int(b_grid_desc_bk0_n_bk1.GetLength(I1)), int(b_grid_desc_bk0_n_bk1.GetLength(I2))); int(a_grid_desc_ak0_m_ak1.GetLength(I0)),
printf("c_grid_desc_m_n{%i, %i}\n", int(c_grid_desc_m_n.GetLength(I0)), int(c_grid_desc_m_n.GetLength(I1))); int(a_grid_desc_ak0_m_ak1.GetLength(I1)),
int(a_grid_desc_ak0_m_ak1.GetLength(I2)));
printf("b_grid_desc_bk0_n_bk1{%i, %i, %i}\n",
int(b_grid_desc_bk0_n_bk1.GetLength(I0)),
int(b_grid_desc_bk0_n_bk1.GetLength(I1)),
int(b_grid_desc_bk0_n_bk1.GetLength(I2)));
printf("c_grid_desc_m_n{%i, %i}\n",
int(c_grid_desc_m_n.GetLength(I0)),
int(c_grid_desc_m_n.GetLength(I1)));
} }
GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock{}; c_grid_desc_mblock_mperblock_nblock_nperblock{};
if(true or GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, if(true or
b_grid_desc_bk0_n_bk1, GridwiseGemm::CheckValidity(
c_grid_desc_m_n, a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_m_n, block_2_ctile_map))
block_2_ctile_map))
{ {
c_grid_desc_mblock_mperblock_nblock_nperblock = c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
c_grid_desc_m_n);
} }
// if(idx.global == 0) // if(idx.global == 0)
// { // {
// printf("a_grid_desc_k0_m0_m1_k1{%i, %i, %i}\n", int(a_grid_desc_k0_m0_m1_k1.GetLength(I0)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I1)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I2))); // printf("a_grid_desc_k0_m0_m1_k1{%i, %i, %i}\n",
// printf("b_grid_desc_k0_n0_n1_k1{%i, %i, %i}\n", int(b_grid_desc_k0_n0_n1_k1.GetLength(I0)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I1)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I2))); // int(a_grid_desc_k0_m0_m1_k1.GetLength(I0)), int(a_grid_desc_k0_m0_m1_k1.GetLength(I1)),
// printf("c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}\n", int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0)), int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I1))); // int(a_grid_desc_k0_m0_m1_k1.GetLength(I2))); printf("b_grid_desc_k0_n0_n1_k1{%i, %i,
// %i}\n", int(b_grid_desc_k0_n0_n1_k1.GetLength(I0)),
// int(b_grid_desc_k0_n0_n1_k1.GetLength(I1)), int(b_grid_desc_k0_n0_n1_k1.GetLength(I2)));
// printf("c_grid_desc_m0_m10_m11_n0_n10_n11{%i, %i}\n",
// int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0)),
// int(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I1)));
// } // }
const auto K = const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
auto a_element_op = tp.a_element_op; auto a_element_op = tp.a_element_op;
auto b_element_op = tp.b_element_op; auto b_element_op = tp.b_element_op;
auto c_element_op = tp.c_element_op; auto c_element_op = tp.c_element_op;
...@@ -93,31 +106,31 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t) ...@@ -93,31 +106,31 @@ __device__ void ck_gemm(const T& a_t, const U& b_t, const V& c_t, const W& p_t)
{ {
constexpr bool HasMainKBlockLoop = true; constexpr bool HasMainKBlockLoop = true;
GridwiseGemm::template Run<HasMainKBlockLoop>(a_t.data(), GridwiseGemm::template Run<HasMainKBlockLoop>(a_t.data(),
b_t.data(), b_t.data(),
c_t.data(), c_t.data(),
p_t.data(), p_t.data(),
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map); block_2_ctile_map);
} }
else else
{ {
constexpr bool HasMainKBlockLoop = false; constexpr bool HasMainKBlockLoop = false;
GridwiseGemm::template Run<HasMainKBlockLoop>(a_t.data(), GridwiseGemm::template Run<HasMainKBlockLoop>(a_t.data(),
b_t.data(), b_t.data(),
c_t.data(), c_t.data(),
p_t.data(), p_t.data(),
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map); block_2_ctile_map);
} }
} }
......
...@@ -51,8 +51,8 @@ static constexpr auto I5 = ck::Number<5>{}; ...@@ -51,8 +51,8 @@ static constexpr auto I5 = ck::Number<5>{};
static constexpr ck::index_t K1 = 1; static constexpr ck::index_t K1 = 1;
static constexpr auto K1Number = ck::Number<K1>{}; static constexpr auto K1Number = ck::Number<K1>{};
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
// using ALayout = Row; // using ALayout = Row;
// using BLayout = Row; // using BLayout = Row;
// using CLayout = Row; // using CLayout = Row;
...@@ -86,7 +86,7 @@ using S = ck::Sequence<Is...>; ...@@ -86,7 +86,7 @@ using S = ck::Sequence<Is...>;
// static constexpr ck::index_t NPerBlock = 128; // static constexpr ck::index_t NPerBlock = 128;
// static constexpr ck::index_t KPerBlock = 32; // static constexpr ck::index_t KPerBlock = 32;
// static constexpr ck::index_t AK1 = 8; // static constexpr ck::index_t AK1 = 8;
// static constexpr ck::index_t BK1 = 2; // static constexpr ck::index_t BK1 = 2;
// static constexpr ck::index_t MPerXDL = 32; // static constexpr ck::index_t MPerXDL = 32;
// static constexpr ck::index_t NPerXDL = 32; // static constexpr ck::index_t NPerXDL = 32;
// static constexpr ck::index_t MXdlPerWave = 4; // static constexpr ck::index_t MXdlPerWave = 4;
...@@ -126,7 +126,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt ...@@ -126,7 +126,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt
{ {
} }
__host__ __device__ constexpr ck::index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const __host__ __device__ constexpr ck::index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{ {
const auto M0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); const auto M0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); const auto N0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
...@@ -156,7 +157,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt ...@@ -156,7 +157,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt
ck::index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; ck::index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return ck::make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, return ck::make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt); idx_N0_M01_local / M01_adapt);
} }
template <typename CTileIdx, typename CTileDim> template <typename CTileIdx, typename CTileDim>
...@@ -166,7 +167,10 @@ struct BlockToCTileMap_M00_N0_M01Adapt ...@@ -166,7 +167,10 @@ struct BlockToCTileMap_M00_N0_M01Adapt
return true; // always valid provided that user gets grid size from CalculateGridSize() return true; // always valid provided that user gets grid size from CalculateGridSize()
} }
__host__ __device__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } __host__ __device__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
private: private:
ck::index_t M01_; ck::index_t M01_;
...@@ -217,7 +221,8 @@ template <typename ALayout, ...@@ -217,7 +221,8 @@ template <typename ALayout,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()> ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
struct TuningParams struct TuningParams
{ {
static constexpr auto MakeAGridDescriptor_AK0_M_AK1(ck::index_t MRaw, ck::index_t KRaw, ck::index_t StrideA) static constexpr auto
MakeAGridDescriptor_AK0_M_AK1(ck::index_t MRaw, ck::index_t KRaw, ck::index_t StrideA)
{ {
const auto a_grid_desc_mraw_kraw = [&]() { const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(ck::is_same_v<ck::tensor_layout::gemm::RowMajor, ALayout>) if constexpr(ck::is_same_v<ck::tensor_layout::gemm::RowMajor, ALayout>)
...@@ -239,88 +244,90 @@ struct TuningParams ...@@ -239,88 +244,90 @@ struct TuningParams
const auto KPad = K - KRaw; const auto KPad = K - KRaw;
if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding || if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding) GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
{ {
// pad both M and K // pad both M and K
//assert(K % AK1 == 0); // assert(K % AK1 == 0);
const auto AK0 = K / AK1; const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = const auto a_grid_desc_m_k = transform_tensor_descriptor(
transform_tensor_descriptor(a_grid_desc_mraw_kraw, a_grid_desc_mraw_kraw,
ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad), ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
ck::make_right_pad_transform(KRaw, KPad)), ck::make_right_pad_transform(KRaw, KPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}), ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
transform_tensor_descriptor(a_grid_desc_m_k, a_grid_desc_m_k,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)), ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_pass_through_transform(M)), ck::make_pass_through_transform(M)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}), ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return a_grid_desc_ak0_m_ak1; return a_grid_desc_ak0_m_ak1;
} }
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MPadding || else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding) GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
{ {
// pad M, but not K // pad M, but not K
//assert(KRaw % AK1 == 0); // assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1; const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
transform_tensor_descriptor(a_grid_desc_mraw_kraw, a_grid_desc_mraw_kraw,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)), ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_right_pad_transform(MRaw, MPad)), ck::make_right_pad_transform(MRaw, MPad)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}), ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return a_grid_desc_ak0_m_ak1; return a_grid_desc_ak0_m_ak1;
} }
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::KPadding || else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding) GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding)
{ {
// pad K, but not M // pad K, but not M
//assert(K % AK1 == 0); // assert(K % AK1 == 0);
const auto AK0 = K / AK1; const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor( const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw, a_grid_desc_mraw_kraw,
ck::make_tuple(ck::make_pass_through_transform(MRaw), ck::make_right_pad_transform(KRaw, KPad)), ck::make_tuple(ck::make_pass_through_transform(MRaw),
ck::make_right_pad_transform(KRaw, KPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}), ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
transform_tensor_descriptor(a_grid_desc_m_k, a_grid_desc_m_k,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)), ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_pass_through_transform(MRaw)), ck::make_pass_through_transform(MRaw)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}), ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return a_grid_desc_ak0_m_ak1; return a_grid_desc_ak0_m_ak1;
} }
else else
{ {
// not pad M or K // not pad M or K
//assert(KRaw % AK1 == 0); // assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1; const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
transform_tensor_descriptor(a_grid_desc_mraw_kraw, a_grid_desc_mraw_kraw,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)), ck::make_tuple(make_unmerge_transform(ck::make_tuple(AK0, AK1)),
ck::make_pass_through_transform(MRaw)), ck::make_pass_through_transform(MRaw)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}), ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return a_grid_desc_ak0_m_ak1; return a_grid_desc_ak0_m_ak1;
} }
} }
static constexpr auto MakeBGridDescriptor_BK0_N_BK1(ck::index_t KRaw, ck::index_t NRaw, ck::index_t StrideB) static constexpr auto
MakeBGridDescriptor_BK0_N_BK1(ck::index_t KRaw, ck::index_t NRaw, ck::index_t StrideB)
{ {
const auto b_grid_desc_nraw_kraw = [&]() { const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
...@@ -342,88 +349,90 @@ struct TuningParams ...@@ -342,88 +349,90 @@ struct TuningParams
const auto KPad = K - KRaw; const auto KPad = K - KRaw;
if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding || if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding) GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
{ {
// pad both N and K // pad both N and K
//assert(K % BK1 == 0); // assert(K % BK1 == 0);
const auto BK0 = K / BK1; const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = const auto b_grid_desc_n_k = transform_tensor_descriptor(
transform_tensor_descriptor(b_grid_desc_nraw_kraw, b_grid_desc_nraw_kraw,
ck::make_tuple(ck::make_right_pad_transform(NRaw, NPad), ck::make_tuple(ck::make_right_pad_transform(NRaw, NPad),
ck::make_right_pad_transform(KRaw, KPad)), ck::make_right_pad_transform(KRaw, KPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}), ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
transform_tensor_descriptor(b_grid_desc_n_k, b_grid_desc_n_k,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)), ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_pass_through_transform(N)), ck::make_pass_through_transform(N)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}), ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return b_grid_desc_bk0_n_bk1; return b_grid_desc_bk0_n_bk1;
} }
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NPadding || else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding) GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding)
{ {
// pad N, but not K // pad N, but not K
//assert(KRaw % BK1 == 0); // assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1; const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 = const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
transform_tensor_descriptor(b_grid_desc_nraw_kraw, b_grid_desc_nraw_kraw,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)), ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_right_pad_transform(NRaw, NPad)), ck::make_right_pad_transform(NRaw, NPad)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}), ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return b_grid_desc_bk0_n_bk1; return b_grid_desc_bk0_n_bk1;
} }
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::KPadding || else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding) GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding)
{ {
// pad K, but not N // pad K, but not N
//assert(K % BK1 == 0); // assert(K % BK1 == 0);
const auto BK0 = K / BK1; const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor( const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw, b_grid_desc_nraw_kraw,
ck::make_tuple(ck::make_pass_through_transform(NRaw), ck::make_right_pad_transform(KRaw, KPad)), ck::make_tuple(ck::make_pass_through_transform(NRaw),
ck::make_right_pad_transform(KRaw, KPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}), ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
transform_tensor_descriptor(b_grid_desc_n_k, b_grid_desc_n_k,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)), ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_pass_through_transform(NRaw)), ck::make_pass_through_transform(NRaw)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}), ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return b_grid_desc_bk0_n_bk1; return b_grid_desc_bk0_n_bk1;
} }
else else
{ {
// not pad N or K // not pad N or K
//assert(KRaw % BK1 == 0); // assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1; const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 = const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
transform_tensor_descriptor(b_grid_desc_nraw_kraw, b_grid_desc_nraw_kraw,
ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)), ck::make_tuple(make_unmerge_transform(ck::make_tuple(BK0, BK1)),
ck::make_pass_through_transform(NRaw)), ck::make_pass_through_transform(NRaw)),
ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}), ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
return b_grid_desc_bk0_n_bk1; return b_grid_desc_bk0_n_bk1;
} }
} }
static constexpr auto MakeCGridDescriptor_M_N(ck::index_t MRaw, ck::index_t NRaw, ck::index_t StrideC) static constexpr auto
MakeCGridDescriptor_M_N(ck::index_t MRaw, ck::index_t NRaw, ck::index_t StrideC)
{ {
const auto c_grid_desc_mraw_nraw = [&]() { const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value) if constexpr(is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value)
...@@ -445,32 +454,35 @@ struct TuningParams ...@@ -445,32 +454,35 @@ struct TuningParams
const auto NPad = N - NRaw; const auto NPad = N - NRaw;
if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding || if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding) GemmSpec == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
{ {
// pad M and N // pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw, return transform_tensor_descriptor(
ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad), c_grid_desc_mraw_nraw,
ck::make_right_pad_transform(NRaw, NPad)), ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}), ck::make_right_pad_transform(NRaw, NPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
} }
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MPadding || else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding) GemmSpec == ck::tensor_operation::device::GemmSpecialization::MKPadding)
{ {
// pad M, but not N // pad M, but not N
return transform_tensor_descriptor( return transform_tensor_descriptor(
c_grid_desc_mraw_nraw, c_grid_desc_mraw_nraw,
ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad), ck::make_pass_through_transform(NRaw)), ck::make_tuple(ck::make_right_pad_transform(MRaw, MPad),
ck::make_pass_through_transform(NRaw)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}), ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
} }
else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NPadding || else if constexpr(GemmSpec == ck::tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding) GemmSpec == ck::tensor_operation::device::GemmSpecialization::NKPadding)
{ {
// pad N, but not M // pad N, but not M
return transform_tensor_descriptor( return transform_tensor_descriptor(
c_grid_desc_mraw_nraw, c_grid_desc_mraw_nraw,
ck::make_tuple(ck::make_pass_through_transform(MRaw), ck::make_right_pad_transform(NRaw, NPad)), ck::make_tuple(ck::make_pass_through_transform(MRaw),
ck::make_right_pad_transform(NRaw, NPad)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}), ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
} }
...@@ -544,7 +556,7 @@ struct TuningParams ...@@ -544,7 +556,7 @@ struct TuningParams
}; };
using gemm = TuningParams using gemm = TuningParams
// clang-format off // clang-format off
//| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
......
...@@ -52,7 +52,7 @@ static constexpr auto K1Number = ck::Number<K1>{}; ...@@ -52,7 +52,7 @@ static constexpr auto K1Number = ck::Number<K1>{};
// Short names for the CK GEMM layout tags, followed by the layout chosen for
// each of the A, B and C operands. All three are set to Row here; the trailing
// comment on ALayout records Col as the alternative that is not in use.
// NOTE(review): each statement appears twice on its line (old and new side of
// a formatting diff) -- confirm against the real header before relying on it.
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using ALayout = Row;//Col; using ALayout = Row; // Col;
using BLayout = Row; using BLayout = Row;
using CLayout = Row; using CLayout = Row;
...@@ -216,39 +216,39 @@ using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); ...@@ -216,39 +216,39 @@ using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm names one instantiation of ck::GridwiseGemmDl_km_kn_mn_v1r3.
// The template arguments, in order, are: the block size, the A / accumulator /
// C data types, ck::InMemoryDataOperationEnum::Set as the memory-operation
// argument, the three grid descriptor types (A, B, C), the per-block and
// per-thread tile sizes, the M1N1 thread-cluster shapes, the A and B block
// transfer parameters, and the C thread transfer parameters.
// Every identifier used as an argument is declared outside this fragment.
// NOTE(review): each line carries the same token twice (old and new side of a
// formatting diff) -- confirm against the real header before compiling.
using GridwiseGemm = using GridwiseGemm =
ck::GridwiseGemmDl_km_kn_mn_v1r3<BlockSize, ck::GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
ADataType, ADataType,
AccDataType, AccDataType,
CDataType, CDataType,
ck::InMemoryDataOperationEnum::Set, ck::InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
K0PerBlock, K0PerBlock,
M1PerThread, M1PerThread,
N1PerThread, N1PerThread,
KPerThread, KPerThread,
M1N1ThreadClusterM1Xs, M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs, M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>; CThreadTransferDstScalarPerVector>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
...@@ -267,8 +267,7 @@ using BGridDesc_K0_N0_N1_K1 = ...@@ -267,8 +267,7 @@ using BGridDesc_K0_N0_N1_K1 =
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
// Types captured with decltype from two static GridwiseGemm factories, each
// called with a default-constructed CGridDesc_M_N:
//   - MakeCGridDescriptor_M0_M10_M11_N0_N10_N11 -> CGridDesc_M0_M10_M11_N0_N10_N11
//   - MakeDefaultBlock2CTileMap                 -> DefaultBlock2CTileMap
// NOTE(review): the text below interleaves the old (two-line) and new
// (one-line) formatting of the same declarations -- confirm against the real
// header before compiling.
using CGridDesc_M0_M10_M11_N0_N10_N11 = using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap = decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
} // namespace migraphx } // namespace migraphx
#endif #endif
...@@ -37,8 +37,10 @@ ...@@ -37,8 +37,10 @@
// migraphx::shape m2_shape{migraphx::shape::float_type, {4096, 4096}}; // migraphx::shape m2_shape{migraphx::shape::float_type, {4096, 4096}};
// auto l1 = mm->add_parameter("1", m1_shape); // auto l1 = mm->add_parameter("1", m1_shape);
// auto l2 = mm->add_parameter("2", m2_shape); // auto l2 = mm->add_parameter("2", m2_shape);
// // l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); // // l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}),
// // l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); // l1);
// // l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}),
// l2);
// mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2); // mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2);
...@@ -54,15 +56,15 @@ struct test_ck_gemm : verify_program<test_ck_gemm> ...@@ -54,15 +56,15 @@ struct test_ck_gemm : verify_program<test_ck_gemm>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::half_type, {2, 3}}; migraphx::shape m1_shape{migraphx::shape::half_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::half_type, {3, 4}}; migraphx::shape m2_shape{migraphx::shape::half_type, {3, 4}};
std::vector<float> v1(2*3, 1); std::vector<float> v1(2 * 3, 1);
std::iota(v1.begin(), v1.end(), 1); std::iota(v1.begin(), v1.end(), 1);
std::vector<float> v2(3*4, 1); std::vector<float> v2(3 * 4, 1);
//std::iota(v2.begin(), v2.end(), 1); // std::iota(v2.begin(), v2.end(), 1);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, v1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, v1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, v2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, v2});
// auto l1 = mm->add_parameter("1", m1_shape); // auto l1 = mm->add_parameter("1", m1_shape);
// auto l2 = mm->add_parameter("2", m2_shape); // auto l2 = mm->add_parameter("2", m2_shape);
//l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); // l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
// l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); // l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2); mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment