Commit 65a00640 authored by Chao Liu

fix bug in tensor adaptor

parent fc148cef
......@@ -45,6 +45,7 @@ __host__ __device__ constexpr auto make_cluster_descriptor(
return ClusterDescriptor<Lengths, decltype(order)>{};
}
#if 1
template <typename Lengths,
typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor_v2(
......@@ -64,9 +65,10 @@ __host__ __device__ constexpr auto make_cluster_descriptor_v2(
constexpr auto up_dim_new_top_ids = Sequence<0>{};
return make_simple_tensor_adaptor(
return make_single_stage_tensor_adaptor(
make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
}
#endif
} // namespace ck
#endif
......@@ -1282,7 +1282,7 @@ struct DynamicFreeze
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0,
"wrong! inconsistent # of dimension");
idx_low = low_idx_;
......@@ -1299,7 +1299,7 @@ struct DynamicFreeze
const UpIdx& idx_up_new,
Number<Hack>)
{
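// the frozen lower index never moves, so the lower index diff is always zero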
idx_diff_low(Number<0>{}) = index_t{Number<0>{}};
idx_diff_low(Number<0>{}) = 0;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
......@@ -1328,5 +1328,90 @@ struct DynamicFreeze
}
};
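// Vectorize: the upper index counts vectors of size vector_size_, and the lower index is
// the corresponding element offset, i.e. idx_low = vector_size_ * idx_up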
template <typename VectorSize, typename UpLength>
struct DynamicVectorize
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
using UpLengths = decltype(make_tuple(UpLength{}));
UpLengths up_lengths_;
VectorSize vector_size_;
__host__ __device__ constexpr DynamicVectorize() = default;
__host__ __device__ constexpr DynamicVectorize(const VectorSize& vector_size,
const UpLength& up_length)
: up_lengths_{make_tuple(up_length)}, vector_size_{vector_size}
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
__host__ __device__ void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const
{
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}];
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& idx_up_new,
Number<Hack>) const
{
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = Number<0>{};
idx_diff_low(I0) = vector_size_ * idx_diff_up[I0];
idx_low += idx_diff_low;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpLengths>::value;
}
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicVectorize, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("}");
}
};
} // namespace ck
#endif
......@@ -74,5 +74,12 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i
return DynamicFreeze<LowerIndex>{low_idx};
}
template <typename VectorSize, typename UpLength>
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
const UpLength& up_length)
{
return DynamicVectorize<VectorSize, UpLength>{vector_size, up_length};
}
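// A minimal usage sketch (values are hypothetical, for illustration only): splitting a
// length-8 dimension into vectors of 4 gives an upper length of 2, and upper index i maps
// to lower element offset 4 * i:
//
//   const auto vec = make_vectorize_transform(Number<4>{}, Number<2>{});
//   // CalculateLowerIndex: idx_low = 4 * idx_up, so upper {0, 1} -> lower {0, 4}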
} // namespace ck
#endif
......@@ -235,15 +235,31 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
constexpr index_t ndim_low =
TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension();
// get the min of all lower dimensions, excluding bottom dimensions (because their ids
// will be matched with top ids from adaptor0)
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
adaptor1_min_hidden_id =
math::min(adaptor1_min_hidden_id,
TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value);
constexpr index_t low_dim_hidden_id =
TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value;
bool is_bottom_dim = false;
static_for<0, TensorAdaptor1::GetNumOfBottomDimension(), 1>{}([&](auto i) {
if constexpr(low_dim_hidden_id ==
TensorAdaptor1::GetBottomDimensionHiddenIds()[i])
{
is_bottom_dim = true;
}
});
if(!is_bottom_dim)
{
adaptor1_min_hidden_id = math::min(adaptor1_min_hidden_id, low_dim_hidden_id);
}
});
constexpr index_t ndim_up =
TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension();
// get the min of all upper dimensions
static_for<0, ndim_up, 1>{}([&](auto idim_up) {
adaptor1_min_hidden_id =
math::min(adaptor1_min_hidden_id,
......@@ -255,7 +271,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
}();
constexpr index_t adaptor1_hidden_id_shift =
adaptor1_min_hidden_id - adaptor0_max_hidden_id + 1;
adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;
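// e.g. if adaptor0's largest hidden id is 5 and adaptor1's smallest non-bottom hidden id
// is 1, the shift is 5 + 1 - 1 = 5, so adaptor1's hidden ids {1, 2, 3} become {6, 7, 8}
// and cannot collide with adaptor0's ids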
constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension();
......@@ -276,7 +292,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
// shift hidden id so every dim id is unique
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
low_dim_hidden_ids_1_mod(idim_low_1) -= adaptor1_hidden_id_shift;
low_dim_hidden_ids_1_mod(idim_low_1) += adaptor1_hidden_id_shift;
});
// match hidden id
......@@ -322,7 +338,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
// shift hidden id
static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
up_dim_hidden_ids_1_mod(idim_up_1) -= adaptor1_hidden_id_shift;
up_dim_hidden_ids_1_mod(idim_up_1) += adaptor1_hidden_id_shift;
});
return up_dim_hidden_ids_1_mod;
......@@ -344,21 +360,21 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
// top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
constexpr auto top_dim_hidden_ids =
TensorAdaptor1::GetTopDimensionHiddenIds() - Number<adaptor1_hidden_id_shift>{};
TensorAdaptor1::GetTopDimensionHiddenIds() + Number<adaptor1_hidden_id_shift>{};
// put everything together
return TensorAdaptor<decltype(all_transforms),
decltype(all_low_dim_hidden_idss),
decltype(all_up_dim_hidden_idss),
decltype(bottom_dim_hidden_ids),
decltype(top_dim_hidden_ids)>{all_transforms};
return TensorAdaptor<remove_cv_t<decltype(all_transforms)>,
remove_cv_t<decltype(all_low_dim_hidden_idss)>,
remove_cv_t<decltype(all_up_dim_hidden_idss)>,
remove_cv_t<decltype(bottom_dim_hidden_ids)>,
remove_cv_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
__host__ __device__ constexpr auto make_simple_tensor_adaptor(const Transforms& transforms,
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
LowerDimensionOldTopIdss,
UpperDimensionNewTopIdss)
{
......@@ -400,11 +416,19 @@ __host__ __device__ constexpr auto make_simple_tensor_adaptor(const Transforms&
constexpr auto top_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number<ndim_old_top>{};
return TensorAdaptor<Transforms,
decltype(low_dim_hidden_idss),
decltype(up_dim_hidden_idss),
decltype(bottom_dim_hidden_ids),
decltype(top_dim_hidden_ids)>{transforms};
return TensorAdaptor<remove_cv_t<Transforms>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(bottom_dim_hidden_ids)>,
remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
}
template <typename X,
typename... Xs,
typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
}
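// e.g. chain_tensor_adaptors(a, b, c) is evaluated as
// chain_tensor_adaptors(a, chain_tensor_adaptors(b, c)); the composed adaptor keeps the
// bottom dimensions of the first adaptor and the top dimensions of the last one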
} // namespace ck
......
......@@ -7,377 +7,6 @@
namespace ck {
// C[M, N] += transpose(A[K, M]) * B[K, N]
// A and B are visible to the whole block, C is distributed among the threads
// Assume:
// 1. A:
// 1. ABlockDesc is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BBlockDesc is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CThreadDesc is known at compile-time
// 2. CThreadBuffer is StaticBuffer
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename ABlockDesc,
typename BBlockDesc,
typename CThreadDesc,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t KPerThreadLoop,
index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster,
index_t ThreadGemmADataPerRead_M,
index_t ThreadGemmBDataPerRead_N,
typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
BBlockDesc::IsKnownAtCompileTime() &&
CThreadDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
{
struct MatrixIndex
{
index_t row;
index_t col;
};
public:
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v1r1()
: c_thread_begin_mtx_idx_{GetBeginOfCThreadDesc(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.row)},
b_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.col)}
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
CThreadDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
MLevel1ThreadCluster * NLevel1ThreadCluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! BlockSize and thread cluster size not consistent\n");
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
constexpr index_t M = ABlockDesc{}.GetLength(I1); // A is transposed
constexpr index_t N = BBlockDesc{}.GetLength(I1);
static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
"wrong! Cannot evenly divide work among");
static_assert(CThreadDesc{}.GetLength(I0) == GetCThreadDescLengths()[I0] &&
CThreadDesc{}.GetLength(I1) == GetCThreadDescLengths()[I1],
"wrong! CThreadDesc lengths is wrong");
}
__device__ static constexpr auto GetCThreadDescLengths()
{
constexpr auto I1 = Number<1>{};
constexpr index_t M = ABlockDesc{}.GetLength(I1); // A is transposed
constexpr index_t N = BBlockDesc{}.GetLength(I1);
constexpr index_t MRepeat =
M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
constexpr index_t NRepeat =
N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);
return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
}
__device__ static MatrixIndex GetBeginOfCThreadDesc(index_t thread_id)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto thread_cluster_idx =
c_thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;
return MatrixIndex{
thread_cluster_idx[I0] * MPerLevel0Cluster + thread_cluster_idx[I2] * MPerThreadSubC,
thread_cluster_idx[I1] * NPerLevel0Cluster + thread_cluster_idx[I3] * NPerThreadSubC};
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run_pipelined_2x2(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BBlockBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value,
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto a_block_mtx = ABlockDesc{};
constexpr auto b_block_mtx = BBlockDesc{};
constexpr auto c_thread_mtx_desc = CThreadDesc{};
constexpr auto K = a_block_mtx.GetLength(I0);
constexpr auto MPerThread = c_thread_mtx_desc.GetLength(I0);
constexpr auto NPerThread = c_thread_mtx_desc.GetLength(I1);
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert(MRepeat == 2 && NRepeat == 2, "wrong! only support 2x2 pipeline");
// thread A-sub, B-sub
constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}),
make_tuple(Number<MPerThread>{}, Number<1>{}));
constexpr auto b_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}),
make_tuple(Number<NPerThread>{}, Number<1>{}));
constexpr auto c_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
make_tuple(Number<NPerThread>{}, Number<1>{}));
auto a_thread_buf = make_static_buffer<FloatA>(a_thread_mtx_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<FloatB>(b_thread_mtx_desc_.GetElementSpaceSize());
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1r1<FloatA,
FloatB,
FloatC,
decltype(a_thread_sub_mtx),
decltype(b_thread_sub_mtx),
decltype(c_thread_sub_mtx)>{};
// read A_sub_0
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I0),
a_block_buf,
a_thread_mtx_desc_,
make_tuple(I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I0),
b_block_buf,
b_thread_mtx_desc_,
make_tuple(I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, Number<NPerLevel1Cluster>{}),
b_block_buf,
b_thread_mtx_desc_,
make_tuple(I0, Number<NPerThreadSubC>{}),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, Number<MPerLevel1Cluster>{}),
a_block_buf,
a_thread_mtx_desc_,
make_tuple(I0, Number<MPerThreadSubC>{}),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0),
b_thread_buf,
make_tuple(I0, I0),
c_thread_buf,
make_tuple(I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0),
b_thread_buf,
make_tuple(I0, Number<NPerThreadSubC>{}),
c_thread_buf,
make_tuple(I0, Number<NPerThreadSubC>{}));
// loop over rest of k
static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
// read A_sub_0
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0),
a_block_buf,
a_thread_mtx_desc_,
make_tuple(I0, I0),
a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, Number<MPerThreadSubC>{}),
b_thread_buf,
make_tuple(I0, I0),
c_thread_buf,
make_tuple(Number<MPerThreadSubC>{}, I0));
// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0),
b_block_buf,
b_thread_mtx_desc_,
make_tuple(I0, I0),
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, Number<MPerThreadSubC>{}),
b_thread_buf,
make_tuple(I0, Number<NPerThreadSubC>{}),
c_thread_buf,
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, Number<NPerLevel1Cluster>{}),
b_block_buf,
b_thread_mtx_desc_,
make_tuple(I0, Number<NPerThreadSubC>{}),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, Number<MPerLevel1Cluster>{}),
a_block_buf,
a_thread_mtx_desc_,
make_tuple(I0, Number<MPerThreadSubC>{}),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0),
b_thread_buf,
make_tuple(I0, I0),
c_thread_buf,
make_tuple(I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0),
b_thread_buf,
make_tuple(I0, Number<NPerThreadSubC>{}),
c_thread_buf,
make_tuple(I0, Number<NPerThreadSubC>{}));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, Number<MPerThreadSubC>{}),
b_thread_buf,
make_tuple(I0, I0),
c_thread_buf,
make_tuple(Number<MPerThreadSubC>{}, I0));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, Number<MPerThreadSubC>{}),
b_thread_buf,
make_tuple(I0, Number<NPerThreadSubC>{}),
c_thread_buf,
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t MPerThread = CThreadDesc{}.GetLength(I0);
constexpr index_t NPerThread = CThreadDesc{}.GetLength(I1);
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
if constexpr(MRepeat == 2 && NRepeat == 2)
{
Run_pipelined_2x2(a_block_buf, b_block_buf, c_thread_buf);
}
else
{
Run_naive(a_block_buf, b_block_buf, c_thread_buf);
}
#else
Run_naive(a_block_buf, b_block_buf, c_thread_buf);
#endif
}
private:
static constexpr auto c_thread_cluster_desc_ =
make_cluster_descriptor_v2(Sequence<MLevel1ThreadCluster,
NLevel1ThreadCluster,
MLevel0ThreadCluster,
NLevel0ThreadCluster>{},
Sequence<0, 1, 2, 3>{});
static constexpr auto a_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, CThreadDesc{}.GetLength(Number<0>{})));
static constexpr auto b_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, CThreadDesc{}.GetLength(Number<1>{})));
using AThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA,
ABlockDesc,
decltype(a_thread_mtx_desc_),
Sequence<KPerThreadLoop, MPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_M,
AddressSpace::Generic,
AddressSpace::Vgpr,
1>;
using BThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
FloatB,
BBlockDesc,
decltype(b_thread_mtx_desc_),
Sequence<KPerThreadLoop, NPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmBDataPerRead_N,
AddressSpace::Generic,
AddressSpace::Vgpr,
1>;
MatrixIndex c_thread_begin_mtx_idx_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visible to the whole block, C is distributed among the threads
// Assume:
......@@ -399,7 +28,7 @@ template <index_t BlockSize,
typename CThreadDesc,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThreadLoop,
index_t KPerThread,
index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster,
......@@ -410,7 +39,7 @@ template <index_t BlockSize,
BBlockDesc::IsKnownAtCompileTime() &&
CThreadDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r1
struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
{
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
......@@ -422,7 +51,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r1
static constexpr auto I3 = Number<3>{};
public:
__device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r1()
__device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1()
: c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
a_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
......@@ -433,8 +62,9 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r1
CThreadDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == c_thread_cluster_desc_.GetElementSize(),
"wrong! wrong blocksize");
static_assert(BlockSize == MLevel0ThreadCluster * MLevel1ThreadCluster *
NLevel0ThreadCluster * NLevel1ThreadCluster,
"wrong! blocksize and cluster size not consistent");
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
......@@ -442,26 +72,42 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r1
__device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
{
const auto thread_cluster_idx =
c_thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
constexpr index_t MPerLevel0Cluster = M1PerThread * MLevel0ThreadCluster;
constexpr index_t NPerLevel0Cluster = N1PerThread * NLevel0ThreadCluster;
return make_multi_index(
0,
thread_cluster_idx[I0] * MPerLevel0Cluster + thread_cluster_idx[I2] * M1PerThread,
0,
thread_cluster_idx[I1] * NPerLevel0Cluster + thread_cluster_idx[I3] * N1PerThread);
}
__host__ __device__ static constexpr auto GetCThreadClusterDescriptor()
{
return make_cluster_descriptor_v2(Sequence<MLevel1ThreadCluster,
constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
// 4-d data space into 4-d thread space
constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
make_tuple(make_vectorize_transform(M0, 1),
make_vectorize_transform(M1PerThread, M1 / M1PerThread),
make_vectorize_transform(N0, 1),
make_vectorize_transform(N1PerThread, N1 / N1PerThread)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// thread position in the 4-d thread space
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(
make_freeze_transform(make_multi_index(0)),
make_unmerge_transform(make_tuple(MLevel1ThreadCluster, MLevel0ThreadCluster)),
make_freeze_transform(make_multi_index(0)),
make_unmerge_transform(make_tuple(NLevel1ThreadCluster, NLevel0ThreadCluster))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));
// 4-d thread space to 1-d thread space
constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MLevel1ThreadCluster,
NLevel1ThreadCluster,
MLevel0ThreadCluster,
NLevel0ThreadCluster>{},
Sequence<0, 1, 2, 3>{});
NLevel0ThreadCluster))),
make_tuple(Sequence<0, 2, 1, 3>{}),
make_tuple(Sequence<0>{}));
constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2);
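// the chained adaptor maps the 1-d thread id (its top index) to this thread's 4-d
// (M0, M1, N0, N1) data origin (its bottom index)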
return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
......@@ -479,13 +125,13 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r1
decltype(a_thread_desc_),
decltype(b_thread_desc_),
CThreadDesc,
Sequence<KPerThreadLoop>,
Sequence<KPerThread>,
Sequence<M0_, M1PerThread>,
Sequence<N0_, N1PerThread>>{};
constexpr index_t K = ABlockDesc{}.GetLength(I0);
static_for<0, K, KPerThreadLoop>{}([&](auto k) {
static_for<0, K, KPerThread>{}([&](auto k) {
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0),
a_block_buf,
......@@ -510,25 +156,23 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r1
}
private:
static constexpr auto c_thread_cluster_desc_ = GetCThreadClusterDescriptor();
static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1);
static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1);
// A[K, M0, M1]
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<M0_>{}, Number<M1PerThread>{}));
make_tuple(Number<KPerThread>{}, Number<M0_>{}, Number<M1PerThread>{}));
// B[K, N0, N1]
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<N0_>{}, Number<N1PerThread>{}));
make_tuple(Number<KPerThread>{}, Number<N0_>{}, Number<N1PerThread>{}));
using AThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<KPerThreadLoop, M0_, M1PerThread>,
Sequence<KPerThread, M0_, M1PerThread>,
Sequence<0, 1, 2>,
2,
AThreadCopyScalarPerVector_M1,
......@@ -541,7 +185,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r1
FloatB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<KPerThreadLoop, N0_, N1PerThread>,
Sequence<KPerThread, N0_, N1PerThread>,
Sequence<0, 1, 2>,
2,
BThreadCopyScalarPerVector_N1,
......@@ -576,7 +220,7 @@ template <index_t BlockSize,
typename CThreadDesc,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThreadLoop,
index_t KPerThread,
index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster,
......@@ -587,7 +231,7 @@ template <index_t BlockSize,
BBlockDesc::IsKnownAtCompileTime() &&
CThreadDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2
struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
{
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
......@@ -599,7 +243,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2
static constexpr auto I3 = Number<3>{};
public:
__device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2()
__device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2()
: c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
a_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
......@@ -610,8 +254,9 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2
CThreadDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == c_thread_cluster_desc_.GetElementSize(),
"wrong! wrong blocksize");
static_assert(BlockSize == MLevel0ThreadCluster * MLevel1ThreadCluster *
NLevel0ThreadCluster * NLevel1ThreadCluster,
"wrong! blocksize and cluster size not consistent");
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
......@@ -624,26 +269,42 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2
__device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
{
const auto thread_cluster_idx =
c_thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
constexpr index_t MPerLevel0Cluster = M1PerThread * MLevel0ThreadCluster;
constexpr index_t NPerLevel0Cluster = N1PerThread * NLevel0ThreadCluster;
return make_multi_index(
0,
thread_cluster_idx[I0] * MPerLevel0Cluster + thread_cluster_idx[I2] * M1PerThread,
0,
thread_cluster_idx[I1] * NPerLevel0Cluster + thread_cluster_idx[I3] * N1PerThread);
}
__host__ __device__ static constexpr auto GetCThreadClusterDescriptor()
{
return make_cluster_descriptor_v2(Sequence<MLevel1ThreadCluster,
constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
// 4-d data space into 4-d thread space
constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
make_tuple(make_vectorize_transform(M0, 1),
make_vectorize_transform(M1PerThread, M1 / M1PerThread),
make_vectorize_transform(N0, 1),
make_vectorize_transform(N1PerThread, N1 / N1PerThread)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// thread position in the 4-d thread space
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(
make_freeze_transform(make_multi_index(0)),
make_unmerge_transform(make_tuple(MLevel1ThreadCluster, MLevel0ThreadCluster)),
make_freeze_transform(make_multi_index(0)),
make_unmerge_transform(make_tuple(NLevel1ThreadCluster, NLevel0ThreadCluster))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));
// 4-d thread space to 1-d thread space
constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MLevel1ThreadCluster,
NLevel1ThreadCluster,
MLevel0ThreadCluster,
NLevel0ThreadCluster>{},
Sequence<0, 1, 2, 3>{});
NLevel0ThreadCluster))),
make_tuple(Sequence<0, 2, 1, 3>{}),
make_tuple(Sequence<0>{}));
constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2);
return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
......@@ -661,7 +322,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2
decltype(a_thread_desc_),
decltype(b_thread_desc_),
CThreadDesc,
Sequence<KPerThreadLoop>,
Sequence<KPerThread>,
Sequence<1, M1PerThread>,
Sequence<1, N1PerThread>>{};
......@@ -716,7 +377,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2
make_tuple(I0, I0, I1, I0));
// loop over rest of k
static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
static_for<KPerThread, K, KPerThread>{}([&](auto k) {
// read A_sub_0
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0),
......@@ -800,25 +461,23 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2
}
private:
static constexpr auto c_thread_cluster_desc_ = GetCThreadClusterDescriptor();
static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1);
static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1);
// A[K, M0, M1]
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<M0_>{}, Number<M1PerThread>{}));
make_tuple(Number<KPerThread>{}, Number<M0_>{}, Number<M1PerThread>{}));
// B[K, N0, N1]
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<N0_>{}, Number<N1PerThread>{}));
make_tuple(Number<KPerThread>{}, Number<N0_>{}, Number<N1PerThread>{}));
using AThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<KPerThreadLoop, 1, M1PerThread>,
Sequence<KPerThread, 1, M1PerThread>,
Sequence<0, 1, 2>,
2,
AThreadCopyScalarPerVector_M1,
......@@ -831,7 +490,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2
FloatB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<KPerThreadLoop, 1, N1PerThread>,
Sequence<KPerThread, 1, N1PerThread>,
Sequence<0, 1, 2>,
2,
BThreadCopyScalarPerVector_N1,
......
......@@ -59,453 +59,6 @@ __global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_gl
}
#endif
#if 0
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<MPerThread>{},
Number<NPerThread>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global,
const CBlockClusterDesc& c_block_cluster_desc,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
const auto K = a_k_m_global_desc.GetLength(I0);
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this forces m/n_block_data_idx_on_global into SGPR
const index_t m_block_data_idx_on_global =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_global =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<MPerThread>{},
Number<NPerThread>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock>,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AddressSpace::Global,
AddressSpace::Lds,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k_m_global_desc,
make_multi_index(0, m_block_data_idx_on_global),
a_k_m_block_desc,
make_multi_index(0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock>,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1>,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
AddressSpace::Global,
AddressSpace::Lds,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k_n_global_desc,
make_multi_index(0, n_block_data_idx_on_global),
b_k_n_block_desc,
make_multi_index(0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_m0m1_n0n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}));
const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v1r1<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k_m_block_desc),
decltype(b_k_n_block_desc),
decltype(c_m0m1_n0n1_thread_desc),
MPerThread,
NPerThread,
KPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
MPerThread,
NPerThread>{};
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
auto c_thread_buf =
make_static_buffer<FloatAcc>(c_m0m1_n0n1_thread_desc.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_m0m1_n0n1_thread_desc),
Sequence<MRepeat * MPerThread, NRepeat * NPerThread>>{}
.Run(c_m0m1_n0n1_thread_desc, make_tuple(I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{};
// hack to control index calculation when moving the slice window for A and B matrix
// threadwise copy
constexpr auto a_k_m_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_k_n_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
FloatAB* p_a_block_even = p_a_block_double;
FloatAB* p_b_block_even = p_b_block_double;
FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
auto a_block_odd_buf = make_dynamic_buffer(p_a_block_odd);
auto b_block_odd_buf = make_dynamic_buffer(p_b_block_odd);
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
}
if constexpr(HasMainKBlockLoop)
{
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left
{
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if 1 iteration is left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
// define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy
constexpr auto c_m0_m1_n0_n1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<MRepeat>{},
Number<MPerThread>{},
Number<NRepeat>{},
Number<NPerThread>{}));
// calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfCThreadDesc(get_thread_local_1d_id());
const index_t m_thread_data_on_global =
m_block_data_idx_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_idx_on_global + c_thread_mtx_on_block.col;
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
constexpr auto tmp = make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation,
1,
true>(c_m0_m1_n0_n1_global_desc,
make_multi_index(m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1))
.Run(c_m0_m1_n0_n1_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
}
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global,
const CBlockClusterDesc& c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
Run(a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_block_cluster_desc,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
};
#else
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
......@@ -721,7 +274,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
const auto blockwise_gemm =
BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2<BlockSize,
BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
......@@ -952,7 +505,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
};
#endif
} // namespace ck
#endif