Commit 6a9d7b64 authored by aska-0096's avatar aska-0096
Browse files

temp save

parent d4adc71a
......@@ -37,13 +37,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
GemmDefault,
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
64, // KPerBlock
16, // NPerBlock
32, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
1, // M Repeat
8, // N-Repeat
1, // N-Repeat
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -51,7 +51,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
8,
8,
true,
S<4, 64, 1>,
S<4, 16, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
......@@ -59,8 +59,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
8,
true,
1, // C shuffle (M Repeat) Per store
4, // C shuffle (N Repeat) Per store
S<1, 64, 1, 4>,
1, // C shuffle (N Repeat) Per store
S<1, 128, 1, 2>,
8>;
// clang-format on
......
......@@ -94,12 +94,14 @@ using DeviceGemmInstance =
TensorSpecB1,
TensorSpecC,
256,
// Gemm 0
128, // MPerBlock
128, // LPerBlock
4, // K0PerBlock
32, // KPerBlock
8, // K1
// Gemm 1
64, // NPerBlock
4, // L0PerBlock
32, // LPerBlock
8, // L1
16, // MPerWMMA
16, // LPerWMMA
......
......@@ -53,10 +53,10 @@ template <index_t NumDimG,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t LPerBlock,
ck::index_t K0PerBlock, // K0 * K1 = Gemm0 GEMM_K Dim
ck::index_t K1, //
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t NPerBlock,
ck::index_t L0PerBlock,
ck::index_t LPerBlock,
ck::index_t L1,
ck::index_t MPerWMMA,
ck::index_t LPerWMMA,
......@@ -128,8 +128,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static constexpr index_t NumDimGemm1N = NumDimN;
static constexpr index_t NumDimGemm1K = NumDimL;
static constexpr index_t KPerBlock = K0PerBlock * K1;
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle;
static constexpr auto I0 = Number<0>{};
......@@ -137,6 +135,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
static constexpr auto AEnableLds = LWaves == 1 ? false : true;
// static constexpr auto B0EnableLds = MWaves == 1 ? false : true;
// static constexpr auto B1EnableLds = MWaves == 1 ? false : true;
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
......@@ -146,13 +153,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B1Spec,
CSpec>;
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
static auto MakeAGridDescriptor(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
if constexpr(AEnableLds)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
Number<K1>{});
}
else
{
return Transform::MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
WmmaK, Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{}, Number<K1>{})
}
}
static auto MakeB0GridDescriptor_BK0_L_BK1(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
......@@ -170,7 +187,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
Number<L1>{});
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor_BK0_L_BK1({}, {}));
using B1GridDesc_BL0_N_BL1 = decltype(MakeB1GridDescriptor_BL0_N_BL1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
......@@ -250,17 +267,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
// InMemory Data Descriptor
AGridDesc_AK0_M_AK1,
AGridDesc,
B0GridDesc_BK0_L_BK1,
B1GridDesc_BL0_N_BL1,
CGridDesc_M_N,
// Tiling Family
MPerBlock,
LPerBlock,
K0PerBlock, // K0 * K1 = Gemm0 GEMM_K Dim
K1, //
KPerBlock,
K1,
NPerBlock,
L0PerBlock,
LPerBlock,
L1,
MPerWMMA,
LPerWMMA,
......@@ -277,6 +294,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
true,
AEnableLds,
ABlockLdsAddExtraM,
B0BlockTransferThreadClusterLengths_K0_L_K1,
B0BlockTransferThreadClusterArrangeOrder,
......@@ -285,6 +303,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B0BlockTransferSrcScalarPerVector,
B0BlockTransferDstScalarPerVector_K1,
true,
B0EnableLds,
B0BlockLdsAddExtraL,
B1BlockTransferThreadClusterLengths_L0_N_L1,
B1BlockTransferThreadClusterArrangeOrder,
......@@ -293,6 +312,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_L1,
false,
B1EnableLds,
B1BlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
......@@ -338,7 +358,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b0_grid_desc_bk0_l_bk1_{DeviceOp::MakeB0GridDescriptor_BK0_L_BK1(
b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
b1_grid_desc_bl0_n_bl1_{DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(
......@@ -404,7 +424,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
CDataType* p_c_grid_;
// Tensor Descriptors
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
AGridDesc a_grid_desc_ak0_m_ak1_;
B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1_;
B1GridDesc_BL0_N_BL1 b1_grid_desc_bl0_n_bl1_;
CGridDesc_M_N c_grid_desc_m_n_;
......@@ -463,7 +483,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B0DataType,
B1DataType,
CDataType,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::AGridDesc,
DeviceOp::B0GridDesc_BK0_L_BK1,
DeviceOp::B1GridDesc_BL0_N_BL1,
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -741,11 +761,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<< BlockSize << ", "
<< MPerBlock << ", "
<< LPerBlock << ", "
<< K0PerBlock << ", "
<< KPerBlock << ", "
<< K1 << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< L0PerBlock << ", "
<< LPerBlock << ", "
<< L1
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
......
......@@ -343,7 +343,7 @@ struct GridwiseGemmPipeline_v1<1, false, true>
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_block_buf = a_block_buf_switch;
// a_block_buf = a_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}
......
......@@ -130,8 +130,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto B_K0 = BGridDesc_K0_N_K1{}.GetLength(I0);
static constexpr auto B_K1 = BGridDesc_K0_N_K1{}.GetLength(I2);
// FIX ME: To be deprecated
static constexpr auto K1 = Number<K1Value>{};
......@@ -273,6 +271,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
__host__ __device__ static constexpr auto
MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&)
{
constexpr auto B_K0 = BBlockDesc_BK0_N_BK1{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_BK0_N_BK1{}.GetLength(I2);
return transform_tensor_descriptor(
BBlockDesc_BK0_N_BK1{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
......@@ -528,8 +529,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
}
}();
// printf("---------------K = %d\n", K);
constexpr auto a_block_desc = MakeABlockDescriptor();
constexpr auto b_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
......@@ -703,7 +702,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/*******************************************************************************/
// Shift Per SUB_K
constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
// printf("a_block_slice_copy_step FirstKdim = %d\n", a_block_slice_copy_step[I0]);
constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();
// gridwise GEMM pipeline
......
......@@ -1395,34 +1395,28 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
// apply element-wise operation
element_op_(v_this_row, src_buf[Number<src_offset>{}]);
// if (get_thread_local_1d_id() < 16)
// printf("tid: %03d, RawData: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&v_this_row)) ); apply intra-row swizzle permute
if constexpr(IntraRowSwizzlePerm)
{
temp = __builtin_amdgcn_permlane16( // 0x76543210, 0xfedcba98
temp,
type_convert<int>(v_this_row),
0xb3a29180,
0xf7e6d5c4,
1,
0);
// temp = __builtin_amdgcn_permlane16(
// temp,
// type_convert<int>(v_this_row),
// 0xb3a29180,
// 0xf7e6d5c4,
// 1,
// 0);
v_this_row = type_convert<SrcData>(temp);
// if (get_thread_local_1d_id() < 16)
// printf("tid: %03d, SwiData: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&v_this_row)) );
}
// apply inter-row permute.
temp = __builtin_amdgcn_permlanex16(temp,
type_convert<int>(v_this_row),
LowEightRowlaneIdx,
HighEightRowLaneIdx,
1,
0);
// temp = __builtin_amdgcn_permlanex16(temp,
// type_convert<int>(v_this_row),
// LowEightRowlaneIdx,
// HighEightRowLaneIdx,
// 1,
// 0);
v_theother_row = type_convert<SrcData>(temp);
// printf("tid: %03d, PermData: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&v_theother_row)) );
if(get_thread_local_1d_id() % 32 < 16)
{
// apply type convert
......
......@@ -179,6 +179,26 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename AGridDesc_M_K, typename Number>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
const AGridDesc_M_K& a_grid_desc_m_k, const Number& WmmaK, const Number& MRepeat,
const Number& MWaves, const Number& MPerWmma, const Number& AK1)
{
const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlcok;
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AKWmma = K / WmmaK;
constexpr auto AKRow = WmmaK / K1;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AKWmma, Number<AKRow>{}, AK1)),
make_unmerge_transform(
make_tuple(M0 * MRepeat, MWaves, MPerWmma))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
//
// B (alias of B0)
//
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment