Commit 09f365a7 authored by wangshaojie6

avoid bank conflicts for wrw for all instances

parent f9fde06a
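The commit packs the A/B LDS tiles as (k0, m0, m1, k1) with a few padding elements after every m1 row, so that successive m0 rows start in different LDS banks. Below is a minimal host-side sketch of the effect; it is not part of the commit and assumes 32 LDS banks of 4-byte words and fp16 elements, while reusing the ElePerBank = 64, K1 = 4, M1Padding = 4 values visible in the diff.

#include <cstdio>

int main()
{
    const int num_banks  = 32; // assumed LDS bank count
    const int bank_bytes = 4;  // assumed bank width in bytes
    const int elem_bytes = 2;  // fp16
    const int K1 = 4, M1 = 64 / K1, pad = 4; // M1PerBlock = ElePerBank / K1 = 16

    // Bank holding the first element of each m0 row, without and with padding.
    for(int m0 = 0; m0 < 4; ++m0)
    {
        const int bank_nopad = m0 * (M1 * K1) * elem_bytes / bank_bytes % num_banks;
        const int bank_pad   = m0 * (M1 * K1 + pad) * elem_bytes / bank_bytes % num_banks;
        printf("m0=%d: bank without pad = %2d, with pad = %2d\n", m0, bank_nopad, bank_pad);
    }
    // Without padding every row starts in bank 0 (64 elements * 2 B = 128 B =
    // one full sweep of 32 banks); the 4-element pad rotates rows by 2 banks.
    return 0;
}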
@@ -52,19 +52,19 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::
         32,             // NPerXdl
         2,              // MXdlPerWave
         2,              // NXdlPerWave
-        S<1, 4, 32, 2>, // ABlockTransferThreadClusterLengths_K0_M_K1
+        S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
         S<0, 3, 1, 2>,  // ABlockTransferThreadClusterArrangeOrder
         S<0, 2, 1, 3>,  // ABlockTransferSrcAccessOrder
         2,              // ABlockTransferSrcVectorDim
-        4,              // ABlockTransferSrcScalarPerVector
-        4,              // ABlockTransferDstScalarPerVector_K1
+        8,              // ABlockTransferSrcScalarPerVector
+        2,              // ABlockTransferDstScalarPerVector_K1
         true,           // ABlockLdsAddExtraM
-        S<1, 4, 32, 2>, // BBlockTransferThreadClusterLengths_K0_N_K1
+        S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
         S<0, 3, 1, 2>,  // BBlockTransferThreadClusterArrangeOrder
         S<0, 2, 1, 3>,  // BBlockTransferSrcAccessOrder
         2,              // BBlockTransferSrcVectorDim
-        4,              // BBlockTransferSrcScalarPerVector
-        4,              // BBlockTransferDstScalarPerVector_K1
+        8,              // BBlockTransferSrcScalarPerVector
+        2,              // BBlockTransferDstScalarPerVector_K1
         true,           // BBlockLdsAddExtraN
         1,              // CShuffleMXdlPerWavePerShuffle
         1,              // CShuffleNXdlPerWavePerShuffle
...
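Note on the retuned instance: both thread-cluster shapes describe the same block size, 1 x 4 x 32 x 2 = 256 = 1 x 4 x 16 x 4 threads, so only the work distribution changes. Each thread now reads 8 scalars per vector from global memory instead of 4 and writes 2 per vector along K1 into LDS instead of 4, presumably to suit the padded LDS layout introduced below.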
@@ -107,7 +107,7 @@
 // experimental feature: use __builtin_memcpy instead of pointer cast to access a vector from
 // pointer of scalar
-#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0
+#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 1
 
 // experimental feature: use __builtin_memcpy instead of union to do bit_cast
 #define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1
...
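This macro flip enables the memcpy-based path in DynamicBuffer (see the last hunk below). A minimal sketch of the two access styles, using illustrative plain C++ types rather than CK's vector_type and c_style_pointer_cast: a fixed-size memcpy is the aliasing-safe way to load a vector through a scalar pointer, and compilers typically lower it to the same vector instruction as the cast.

#include <cstring>

struct float4_vec { float data[4]; }; // stand-in for a CK vector type

// Pointer-cast style (macro == 0): relies on the type pun being tolerated.
float4_vec load_by_cast(const float* p)
{
    return *reinterpret_cast<const float4_vec*>(p);
}

// memcpy style (macro == 1): no strict-aliasing UB; a fixed-size memcpy
// like this is normally compiled down to a single vector load.
float4_vec load_by_memcpy(const float* p)
{
    float4_vec v;
    std::memcpy(&v, p, sizeof(v));
    return v;
}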
New file: merge_transform_for_wrw.hpp

#pragma once
#include "common_header.hpp"
#include "multi_index_transform.hpp"
namespace ck {
// Implementation of the "Merge" transformation primitive that uses division and mod,
// specialized for the backward-weight (wrw) path. It is intended for low_lengths that
// are known at compile time and are powers of 2; otherwise performance will be very poor.
template <typename LowLengths>
struct Merge_v3_division_mod_for_wrw
{
static constexpr index_t NDimLow = LowLengths::Size();
using LowerIndex = MultiIndex<NDimLow>;
using UpperIndex = MultiIndex<1>;
using LowLengthsScan =
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
LowLengths low_lengths_;
LowLengthsScan low_lengths_scan_;
UpLengths up_lengths_;
__host__ __device__ constexpr Merge_v3_division_mod_for_wrw() = default;
__host__ __device__ constexpr Merge_v3_division_mod_for_wrw(const LowLengths& low_lengths)
: low_lengths_{low_lengths},
low_lengths_scan_{
container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
{
static_assert(LowerIndex::Size() == NDimLow, "wrong!");
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
index_t tmp = idx_up[Number<0>{}];
// division and mod
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
idx_low(i) = tmp / this->low_lengths_scan_[i];
tmp %= this->low_lengths_scan_[i];
});
idx_low(Number<NDimLow - 1>{}) = tmp;
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_up_diff,
LowIdx& idx_low,
const UpIdx& idx_up_new,
Number<Hack>) const
{
static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = Number<0>{};
constexpr auto INm1 = Number<NDimLow - 1>{};
index_t tmp = idx_up_new[I0];
// Specialized wrw update: the whole merged offset is carried in the last
// (fastest-varying) lower dimension, so the per-dimension division/mod of
// the generic Merge update is skipped entirely.
idx_low(INm1) = tmp;
idx_diff_low(INm1) = idx_up_diff[I0];
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<LowLengths>::value &&
is_known_at_compile_time<LowLengthsScan>::value &&
is_known_at_compile_time<UpLengths>::value;
}
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
return true;
}
__host__ __device__ void Print() const
{
printf("{");
printf("Merge_v3_direct_division_mod_wrw, ");
printf("low_lengths_ ");
print_multi_index(low_lengths_);
printf("low_lengths_scan_ ");
print_multi_index(low_lengths_scan_);
printf("up_lengths_ ");
print_multi_index(up_lengths_);
printf("}");
}
};
template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v3_division_mod_for_wrw(const LowLengths& low_lengths)
{
return Merge_v3_division_mod_for_wrw<LowLengths>{low_lengths};
}
} // namespace ck
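A small standalone illustration (not from the commit) of the index math above: for low_lengths = (M0, M1) = (32, 16), container_reverse_exclusive_scan with multiplication yields low_lengths_scan_ = (16, 1), so CalculateLowerIndex maps an upper index u to (u / 16, u % 16). With power-of-2 lengths both operations reduce to a shift and a mask, which is why the header comment warns against other sizes.

#include <cassert>

int main()
{
    const int M0 = 32, M1 = 16; // low_lengths, both powers of 2
    for(int u = 0; u < M0 * M1; ++u)
    {
        const int m0 = u / M1;     // idx_low(0)   = tmp / low_lengths_scan_[0]
        const int m1 = u % M1;     // idx_low(N-1) = remaining tmp
        assert(m0 * M1 + m1 == u); // the merge is lossless
    }
    return 0;
}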
@@ -3,6 +3,7 @@
 #include "common_header.hpp"
 #include "multi_index_transform_helper.hpp"
+#include "merge_transform_for_wrw.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
@@ -10,6 +11,9 @@
 #include "blockwise_tensor_slice_transfer_v6r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 
+#define A_BLOCK_BANK_CONFLICT_FREE_WRW 1
+#define B_BLOCK_BANK_CONFLICT_FREE_WRW 1
+
 namespace ck {
 template <typename GridwiseGemm,
...
@@ -110,17 +114,46 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};
 
-    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
+    // M1 & N1: inner tile lengths along M and N
+    static constexpr auto ElePerBank = Number<64>{};
+    static constexpr auto M1PerBlock = Number<ElePerBank / K1Value>{};
+    static constexpr auto N1PerBlock = Number<ElePerBank / K1Value>{};
+
+    // M0 & N0: outer tile lengths along M and N
+    static constexpr auto M0PerBlock = Number<MPerBlock / M1PerBlock>{};
+    static constexpr auto N0PerBlock = Number<NPerBlock / N1PerBlock>{};
+
+    // padding (in elements) appended to each M1/N1 row to stagger LDS banks
+    static constexpr auto M1Padding = Number<4>{};
+    static constexpr auto N1Padding = M1Padding;
+
+    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
         constexpr auto max_lds_align = K1;
 
         // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_k0_m_k1_block_desc = [&]() {
+        constexpr auto a_block_desc_k0_m_k1 = [&]() {
             if constexpr(ABlockLdsExtraM)
             {
+#if A_BLOCK_BANK_CONFLICT_FREE_WRW
+                // padded 4-d layout (k0, m0, m1, k1): each m0 row spans
+                // M1PerBlock * K1 + M1Padding elements
+                constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor(
+                    make_tuple(
+                        Number<K0PerBlock>{}, Number<M0PerBlock>{}, Number<M1PerBlock>{}, K1),
+                    make_tuple(Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding),
+                               Number<M1PerBlock>{} * K1 + M1Padding,
+                               K1,
+                               I1));
+
+                // merge (m0, m1) back into m so callers still index (k0, m, k1)
+                constexpr auto a_block_desc_k0_m_k1_tmp = transform_tensor_descriptor(
+                    a_block_desc_k0_m0_m1_k1,
+                    make_tuple(make_pass_through_transform(Number<K0PerBlock>{}),
+                               make_merge_transform_v3_division_mod(
+                                   make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})),
+                               make_pass_through_transform(K1)),
+                    make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+
+                return a_block_desc_k0_m_k1_tmp;
+#else
                 return make_naive_tensor_descriptor(
                     make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                     make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
+#endif
             }
             else
             {
@@ -129,13 +162,78 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             }
         }();
 
+        return a_block_desc_k0_m_k1;
+    }
+
+    __host__ __device__ static constexpr auto GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1()
+    {
+        constexpr auto max_lds_align = K1;
+
+        // A matrix in LDS memory, dst of blockwise copy
+        constexpr auto a_block_desc_b_k0_m_k1 = [&]() {
+            if constexpr(ABlockLdsExtraM)
+            {
+#if A_BLOCK_BANK_CONFLICT_FREE_WRW
+                constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor(
+                    make_tuple(Number<1>{},
+                               Number<K0PerBlock>{},
+                               Number<M0PerBlock>{},
+                               Number<M1PerBlock>{},
+                               K1),
+                    make_tuple(Number<K0PerBlock>{} * Number<M0PerBlock>{} *
+                                   (Number<M1PerBlock>{} * K1 + M1Padding),
+                               Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding),
+                               Number<M1PerBlock>{} * K1 + M1Padding,
+                               K1,
+                               I1));
+
+                constexpr auto a_block_desc_b_k0_m_k1_tmp = transform_tensor_descriptor(
+                    a_block_desc_b_k0_m0_m1_k1,
+                    make_tuple(make_pass_through_transform(Number<1>{}),
+                               make_pass_through_transform(Number<K0PerBlock>{}),
+                               make_merge_transform_v3_division_mod_for_wrw(
+                                   make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})),
+                               make_pass_through_transform(K1)),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+
+                return a_block_desc_b_k0_m_k1_tmp;
+#else
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
+                    make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
+                               Number<MPerBlock + 1>{} * K1,
+                               K1,
+                               I1));
+#endif
+            }
+            else
+            {
+                return make_naive_tensor_descriptor_aligned(
+                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
+                    max_lds_align);
+            }
+        }();
+
+        return a_block_desc_b_k0_m_k1;
+    }
+
+    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
+    {
+        constexpr auto max_lds_align = K1;
+
         // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_k0_n_k1_block_desc = [&]() {
+        constexpr auto b_block_desc_k0_n_k1 = [&]() {
             if constexpr(BBlockLdsExtraN)
             {
+#if B_BLOCK_BANK_CONFLICT_FREE_WRW
+                constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor(
+                    make_tuple(
+                        Number<K0PerBlock>{}, Number<N0PerBlock>{}, Number<N1PerBlock>{}, K1),
+                    make_tuple(Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding),
+                               Number<N1PerBlock>{} * K1 + N1Padding,
+                               K1,
+                               I1));
+
+                constexpr auto b_block_desc_k0_n_k1_tmp = transform_tensor_descriptor(
+                    b_block_desc_k0_n0_n1_k1,
+                    make_tuple(make_pass_through_transform(Number<K0PerBlock>{}),
+                               make_merge_transform_v3_division_mod(
+                                   make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})),
+                               make_pass_through_transform(K1)),
+                    make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+
+                return b_block_desc_k0_n_k1_tmp;
+#else
                 return make_naive_tensor_descriptor(
                     make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                     make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
+#endif
             }
             else
             {
@@ -144,12 +242,65 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             }
         }();
 
+        return b_block_desc_k0_n_k1;
+    }
+
+    __host__ __device__ static constexpr auto GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1()
+    {
+        constexpr auto max_lds_align = K1;
+
+        // B matrix in LDS memory, dst of blockwise copy
+        constexpr auto b_block_desc_b_k0_n_k1 = [&]() {
+            if constexpr(BBlockLdsExtraN)
+            {
+#if B_BLOCK_BANK_CONFLICT_FREE_WRW
+                constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor(
+                    make_tuple(Number<1>{},
+                               Number<K0PerBlock>{},
+                               Number<N0PerBlock>{},
+                               Number<N1PerBlock>{},
+                               K1),
+                    make_tuple(Number<K0PerBlock>{} * Number<N0PerBlock>{} *
+                                   (Number<N1PerBlock>{} * K1 + N1Padding),
+                               Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding),
+                               Number<N1PerBlock>{} * K1 + N1Padding,
+                               K1,
+                               I1));
+
+                constexpr auto b_block_desc_b_k0_n_k1_tmp = transform_tensor_descriptor(
+                    b_block_desc_b_k0_n0_n1_k1,
+                    make_tuple(make_pass_through_transform(Number<1>{}),
+                               make_pass_through_transform(Number<K0PerBlock>{}),
+                               make_merge_transform_v3_division_mod_for_wrw(
+                                   make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})),
+                               make_pass_through_transform(K1)),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+
+                return b_block_desc_b_k0_n_k1_tmp;
+#else
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
+                    make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
+                               Number<NPerBlock + 1>{} * K1,
+                               K1,
+                               I1));
+#endif
+            }
+            else
+            {
+                return make_naive_tensor_descriptor_aligned(
+                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
+                    max_lds_align);
+            }
+        }();
+
+        return b_block_desc_b_k0_n_k1;
+    }
+
+    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
+    {
+        constexpr auto max_lds_align = K1;
+
+        // A matrix in LDS memory, dst of blockwise copy
+        constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();
+
+        // B matrix in LDS memory, dst of blockwise copy
+        constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();
+
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
-            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(a_b_k0_m_k1_block_desc.GetElementSpaceSize(),
+                                         max_lds_align);
 
         constexpr auto b_block_space_size =
-            math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(b_b_k0_n_k1_block_desc.GetElementSpaceSize(),
+                                         max_lds_align);
 
         constexpr auto c_block_size =
             GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
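To make the padded descriptors concrete: the kernel still indexes (k0, m, k1), and the merge transform splits m into (m0, m1) = (m / M1PerBlock, m % M1PerBlock) before applying the padded strides. Below is a host-side sketch with illustrative values (M0PerBlock, M1PerBlock, K1, M1Padding) = (8, 16, 4, 4); the real values depend on the instance.

#include <cstdio>

int main()
{
    const int K1 = 4, M1 = 16, M0 = 8, pad = 4;
    const int row = M1 * K1 + pad; // padded stride between m0 rows: 68, not 64

    auto padded_offset = [&](int k0, int m, int k1) {
        const int m0 = m / M1; // power-of-2 division -> shift
        const int m1 = m % M1; // power-of-2 mod -> mask
        return k0 * M0 * row + m0 * row + m1 * K1 + k1;
    };

    // Within one m0 row consecutive m are K1 apart; crossing an M1 boundary
    // also skips the pad, so rows start 68 (not 64) elements apart.
    printf("m=15 -> %d, m=16 -> %d\n", padded_offset(0, 15, 0), padded_offset(0, 16, 0));
    return 0;
}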
@@ -331,69 +482,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         constexpr auto max_lds_align = K1;
 
         // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_k0_m_k1_block_desc = [&]() {
-            if constexpr(ABlockLdsExtraM)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
-            }
-        }();
+        constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
 
-        constexpr auto a_b_k0_m_k1_block_desc = [&]() {
-            if constexpr(ABlockLdsExtraM)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
-                               Number<MPerBlock + 1>{} * K1,
-                               K1,
-                               I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    max_lds_align);
-            }
-        }();
+        constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();
 
         // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_k0_n_k1_block_desc = [&]() {
-            if constexpr(BBlockLdsExtraN)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-            }
-        }();
+        constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
 
-        constexpr auto b_b_k0_n_k1_block_desc = [&]() {
-            if constexpr(BBlockLdsExtraN)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                    make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
-                               Number<NPerBlock + 1>{} * K1,
-                               K1,
-                               I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                    max_lds_align);
-            }
-        }();
+        constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();
 
         // A matrix blockwise copy
         auto a_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4r1<BlockSize,
...
@@ -266,6 +266,9 @@ struct DynamicBuffer
                 __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
 #else
+                // if(get_block_1d_id() == 0){
+                //     printf("%d, tid=%d, i=%d\n", __LINE__, get_thread_local_1d_id(), i);
+                // }
                 *c_style_pointer_cast<X*>(&p_data_[i]) = x;
 #endif
             }
...