Commit eceea10a authored by Anthony Chang

clean up

parent 4ee34028
......@@ -216,21 +216,19 @@ int main(int argc, char* argv[])
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
// b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
b0_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
// b1_n_o.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
}
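// Hedged illustration (an assumption about the generator interface, not part of this change):
// GenerateTensorValue appears to call the supplied functor once per element index, so a custom
// initializer only needs a variadic call operator. A minimal sketch with a hypothetical
// GeneratorTensorConstant (defined at namespace scope, since local classes cannot have
// template members):
//
//     struct GeneratorTensorConstant
//     {
//         float value;
//         template <typename... Is>
//         float operator()(Is...) const { return value; }
//     };
//
//     a_m_k.GenerateTensorValue(GeneratorTensorConstant{0.5f});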
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
......@@ -308,15 +306,6 @@ int main(int argc, char* argv[])
ref_gemm1_invoker.Run(ref_gemm1_argument);
// LogRangeAsType<float>(std::cout << "a_m_k: ", a_m_k.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "b0_k_n : ", b0_k_n.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "b1_n_o : ", b1_n_o.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "c_m_o_device_result : ", c_m_o_device_result.mData, ",") << std::endl;
std::cout << "b0_k_n(0, 0) = " << (float)b0_k_n(0, 0) << ", b0_k_n(1, 0) = " << (float)b0_k_n(1, 0)
<< ", b0_k_n(0, 1) = " << (float)b0_k_n(0, 1) << ", b0_k_n(1, 1) = " << (float)b0_k_n(1, 1)
<< std::endl;
return ck::utils::check_err(c_m_o_device_result.mData, c_m_o_host_result.mData) ? 0 : 1;
}
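// Hedged scalar reference for the fused GEMM + GEMM verified above (an illustrative sketch only;
// the example itself checks against ref_gemm0/ref_gemm1, and the function name here is
// hypothetical). Row-major layouts are assumed:
// C1[m][o] = sum_n (sum_k A[m][k] * B0[k][n]) * B1[n][o]
void reference_gemm_gemm_rowmajor(const float* A, const float* B0, const float* B1, float* C1,
                                  int M, int N, int K, int O)
{
    for(int m = 0; m < M; ++m)
        for(int o = 0; o < O; ++o)
        {
            float acc1 = 0.f;
            for(int n = 0; n < N; ++n)
            {
                float acc0 = 0.f; // first GEMM's result for element (m, n)
                for(int k = 0; k < K; ++k)
                    acc0 += A[m * K + k] * B0[k * N + n];
                acc1 += acc0 * B1[n * O + o]; // second GEMM consumes acc0 immediately
            }
            C1[m * O + o] = acc1;
        }
}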
......@@ -158,22 +158,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
#if 0
if(!TransposeC && hipThreadIdx_x % 32 < 8)
{
printf("bid %zd tid %zd, a_mma = %d, %d, %d, %d, b_mma = %d, %d, %d, %d\n",
hipBlockIdx_x,
hipThreadIdx_x,
a_origin[Number<0>{}],
a_origin[Number<1>{}],
a_origin[Number<2>{}],
a_origin[Number<3>{}],
b_origin[Number<0>{}],
b_origin[Number<1>{}],
b_origin[Number<2>{}],
b_origin[Number<3>{}]);
}
#endif
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
......@@ -81,21 +81,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
#if 0
if (std::is_same<Sequence<16,64,2>, BlockSliceLengths>::value)
{
auto s = src_block_slice_origin + thread_data_idx_begin;
auto d = dst_block_slice_origin + thread_data_idx_begin;
printf("bid %zd tid %zd, src origin %d %d %d, dst origin %d %d %d\n",
hipBlockIdx_x, hipThreadIdx_x,
s[Number<0>{}],
s[Number<1>{}],
s[Number<2>{}],
d[Number<0>{}],
d[Number<1>{}],
d[Number<2>{}]);
}
#endif
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
......@@ -162,7 +162,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit from DeviceGemmGemm subtype
{
using DeviceOp = DeviceGemmGemm_Xdl_CShuffle;
......@@ -553,7 +553,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
true,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
......@@ -561,7 +561,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
true,
BBlockLdsExtraN,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
......@@ -655,24 +655,6 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
// TODO ANT: block id to ctilemap should infer acc0tile map
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
......@@ -685,7 +667,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
// TODO ANT: K for gemm1
// Gemm0_K
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
......@@ -728,7 +710,8 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
arg.compute_base_ptr_of_batch_);
};
// TODO ANT: handle tail loops for gemm0 & gemm1
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to be concerned with Gemm0's loop
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
ave_time = launch_kernel(integral_constant<bool, true>{});
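// Hedged sketch of the full dispatch (only the true branch is visible in this hunk; the else
// branch below is an assumption). The Gemm0 K tail decision becomes a compile-time flag for the
// kernel, while Gemm1's inner trip count NPerBlock / Gemm1KPerBlock is already a compile-time
// constant and needs no runtime dispatch:
//
//     if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
//         ave_time = launch_kernel(integral_constant<bool, true>{});
//     else
//         ave_time = launch_kernel(integral_constant<bool, false>{});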
......@@ -50,7 +50,7 @@ template <typename FloatAB,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
bool AThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
......@@ -58,7 +58,7 @@ template <typename FloatAB,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t BBlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
......@@ -75,6 +75,9 @@ template <typename FloatAB,
LoopScheduler LoopSched>
struct GridwiseBatchedGemmGemm_Xdl_CShuffle
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -91,8 +94,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
// Gemm1
static constexpr auto AccK1 = Number<4>{}; // TODO ANT: get from mfma_type.mfma_group_size
static constexpr auto AccK0 = Number<NPerBlock / AccK1.value>{};
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{};
......@@ -148,7 +149,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
// Sequence<Gemm1NXdlPerWave, Gemm1NWaves, NPerXdl>{}.foo(); // <2, 1, 32>
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, Gemm1NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
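// Illustrative numbers (assumed tuning values, not taken from this diff): with
// Gemm1NPerBlock = 64, Gemm1NXdlPerWave = 2 and NPerXdl = 32, Gemm1NWaves = 64 / (2 * 32) = 1,
// so the MMA tile descriptor is built over <Gemm1NXdlPerWave, Gemm1NWaves, NPerXdl> = <2, 1, 32>.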
......@@ -169,18 +169,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
// template <typename BlockwiseGemm>
// __host__ __device__ static constexpr auto
// GetAccBlockDescriptor_AK0PerBlock_MPerBlock_AK1(const BlockwiseGemm& blockwise_gemm)
// {
// constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
// blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// return make_naive_tensor_descriptor(
// make_tuple(B1K0, Number<Gemm1NPerBlock>{}, B1BK1),
// make_tuple(Number<Gemm1NPerBlock + B1BlockLdsExtraN>{} * B1K1, B1K1, I1));
// }
__host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B1 matrix in LDS memory, dst of blockwise copy
......@@ -266,26 +254,21 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
return false;
}
if(!(NPerBlock % Gemm1KPerBlock == 0))
{
return false;
}
// check gridwise gemm pipeline
// check gemm0 gridwise gemm pipeline
const auto num_gemm0_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
{
return false;
}
const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
// check gemm1 gridwise gemm pipeline
if(!(NPerBlock % Gemm1KPerBlock == 0))
{
return false;
}
const auto num_gemm1_k_outer_loop = N / NPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_outer_loop))
const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
{
return false;
}
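// Illustrative numbers for the two pipeline checks above (assumed tile sizes, not taken from
// this diff): with K = 256 and KPerBlock = 32, num_gemm0_k_loop = 8; with NPerBlock = 128 and
// Gemm1KPerBlock = 32, NPerBlock % Gemm1KPerBlock == 0 holds and num_gemm1_k_inner_loop = 4.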
......@@ -301,7 +284,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
return true;
}
// TODO ANT: also consider gemm1 loop
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
......@@ -395,11 +377,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// for n in N0: // gemm1 summation loop
// for k in K0: // gemm0 summation loop
// acc0 += A[m][k] * B0[k][n] // acc0[m][n]
// acc1 += acc0 * B1[n][o] // acc1[m][o]
//
// set up Gemm0
//
......@@ -425,8 +402,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // TODO ANT: check if false
true,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
......@@ -456,8 +433,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
BBlockTransferDstScalarPerVector_BK1,
1,
1,
true, // TODO ANT: check if false
true,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
......@@ -466,12 +443,17 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// Fused Gemm+Gemm pipeline
// for n in N0:
// for k in K0:
// acc[m][n] += A[m][k] * B0[k][n]
// acc1[m][o] += acc[m][n] * B1[n][o]
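// In other words, each output element is C1[m][o] = sum_n (sum_k A[m][k] * B0[k][n]) * B1[n][o];
// the second GEMM consumes the first GEMM's accumulator tile directly from registers
// (AccVGPR -> VGPR) without ever writing acc to global memory.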
// sanity check
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
// TODO ANT: to refactor: blockwise gemm output layout
// TODO ANT: interwave scheduling
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
BlockSize,
FloatAB,
......@@ -509,8 +491,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
const auto b_block_reset_copy_step = make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopScheduler::Default>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
......@@ -520,7 +503,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
// set up Gemm1
//
// Acc matrix threadwise copy: AccVGPR to VGPR and downcast to A data type
// Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
......@@ -533,47 +516,49 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto a1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / n4, 0, 0);
constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
// acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m
// n4 -> k1
// NOTE: had to use merge_v3, otherwise it produces compilation errors
constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor(
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)), // NOTE: had to use merge_v3 or it will spit out weird errors
make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
make_pass_through_transform(n4)),
make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
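// Illustrative shape (assumed values): if (n0, n1, n2, n3, n4) = (2, 1, 2, 2, 4) and
// (m0, m1, m2) = (1, 1, 1), the merged view has k0 = n0 * n1 * n2 * n3 = 8, m = 1 and
// k1 = n4 = 4, i.e. the Gemm0 N dimension of the accumulator becomes the Gemm1 K dimension.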
// A1 thread descriptor for iterating Acc thread descriptor
// n2 num_groups_per_blk, n3 num_input_blks, n4 group_size // FIXME ANT: use block desc N3 instead of hardcoding
constexpr auto A1ThreadSlice = make_tuple(Number<Gemm1KPerBlock / n4 / 2>{}, Number<m0 * m1 * m2>{}, Number<n4>{});
constexpr index_t A1K0 = A1ThreadSlice[I0];
constexpr index_t A1K1 = A1ThreadSlice[I2];
// A1 matrix in AccVGPR
// N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
constexpr auto AccN3 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
constexpr auto A1ThreadSlice_K0_M_K1 =
make_tuple(Number<Gemm1KPerBlock / n4 / AccN3>{}, Number<m0 * m1 * m2>{}, Number<n4>{});
constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0];
constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1];
constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2];
constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor(
A1ThreadSlice,
make_tuple(A1ThreadSlice[I1] * A1ThreadSlice[I2], A1ThreadSlice[I2], I1));
// make_tuple(Number<A1K0>{}, Number<m0 * m1 * m2>{}, Number<n4>{}).foo(); // <8, 1, 4>
A1ThreadSlice_K0_M_K1,
make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));
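// Illustrative slice sizes (assumed values): with Gemm1KPerBlock = 32, n4 = group_size = 4 and
// AccN3 = num_input_blks = 2, A1ThreadSlice_K0_M_K1 = (32 / 4 / 2, m0 * m1 * m2, 4), and for
// m0 * m1 * m2 = M the naive descriptor strides are (M * 4, 4, 1).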
// B1 matrix in LDS memory, dst of blockwise copy
constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A1 matrix blockwise copy
// actually a threadwise copy; this variant needs to support RunRead() and RunWrite()
// TODO ANT: real blockwise copy from c_block_desc to c_thread_desc
// FIXME: this cannot copy from static_buffer to static_buffer, because v3r1 uses runtime
// integer offsets while static_buffer requires integral-constant offsets
auto a1_blockwise_copy =
ThreadwiseTensorSliceTransfer_v1r3_Static<FloatGemmAcc,
FloatAB,
decltype(acc_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1),
Sequence<A1K0, m0 * m1 * m2, A1K1>,
Sequence<1, 0, 2>,
2,
n4>{};
auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatGemmAcc,
FloatAB,
decltype(acc_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1),
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
Sequence<1, 0, 2>,
2,
n4>{};
// B1 matrix blockwise copy
auto b1_blockwise_copy =
......@@ -596,8 +581,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
B1BlockTransferDstScalarPerVector_BK1,
1,
1,
true, // TODO ANT: check if false
true,
B1ThreadTransferSrcResetCoordinateAfterRun,
true, // DstResetCoord
NumGemmKPrefetchStage>(
b1_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
......@@ -637,19 +622,19 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
false,
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
make_tuple(0, 0, 0, 0)
}; // TransposeC
make_tuple(0, 0, 0, 0)}; // TransposeC
auto c_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
const index_t num_gemm1_k_block_outer_loop = b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
const index_t num_gemm1_k_block_outer_loop =
b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
// Initialize C
c_thread_buf.Clear();
// gemm1 K loop
index_t gemm1_k_block_outer_index = 0;
// j loop
do
{
// gemm0
......@@ -668,88 +653,40 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
blockwise_gemm,
acc_thread_buf,
num_k_block_main_loop);
#if 0
if(hipThreadIdx_x == 0)
printf("gemm1_k_block_outer_index %d, num_gemm1_k_block_outer_loop %d\n",
gemm1_k_block_outer_index,
num_gemm1_k_block_outer_loop);
#endif
#if 0
if (hipBlockIdx_x == 0 && hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 8) {
static_for<0, acc_thread_buf.Size(), 1>{}([&](auto I) {
printf("bid %zd tid %zd, acc[%d] = %f\n", hipBlockIdx_x, hipThreadIdx_x, I.value, acc_thread_buf[I]);
});
}
#endif
// gemm1
{
// TODO: explore using a dynamic buffer for the a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy the pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). This is impossible to implement here because the A1
// source buffer is a static buffer holding the output of the first GEMM and requires
// constexpr offsets by design. Therefore, we pass the tensor coordinate offset explicitly
// in Run() below.
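// Hedged illustration of that design choice: instead of
//     a1_blockwise_copy.MoveSrcSliceWindow(acc_thread_desc_k0_m_k1, a1_block_slice_copy_step);
// the source window is advanced by passing a constexpr origin straight into Run(), as in the
// inner loop below:
//     a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
//                           make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
//                           acc_thread_buf,
//                           a1_thread_desc_k0_m_k1,
//                           make_tuple(I0, I0, I0),
//                           a1_thread_buf);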
// preload data into LDS
// FIXME ANT: do not need a1 copy here?
// a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
// make_tuple(I0, I0, I0),
// acc_thread_buf,
// a1_thread_desc_k0_m_k1,
// make_tuple(I0, I0, I0),
// a1_thread_buf
// );
#if 0
if (hipThreadIdx_x % 32 < 4) {
static_for<0, a1_thread_buf.Size(), 1>{}([&](auto I) {
printf("bid %zd tid %zd, iter %d, a1[%d] = %f\n", hipBlockIdx_x, hipThreadIdx_x, 0, I.value, (float)a1_thread_buf[I]);
});
}
#endif
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
// TODO ANT: how to access static buffer while using tensor coordinate?
// a1_blockwise_copy.MoveSrcSliceWindow(acc_thread_desc_k0_m_k1,
// a1_block_slice_copy_step);
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
b1_block_slice_copy_step);
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
#if 0
if (hipBlockIdx_x == 0)
{
debug::print_shared(b1_block_buf.p_data_, index_t(b1_block_desc_bk0_n_bk1.GetElementSpaceSize()));
}
#endif
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
make_tuple(Number<i * A1K0>{}, I0, I0),
make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
acc_thread_buf,
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf
);
#if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 8) {
static_for<0, a1_thread_buf.Size(), 1>{}([&](auto I) {
printf("bid %zd tid %zd, iter %d, a1[%d] = %f\n", hipBlockIdx_x, hipThreadIdx_x, i.value, I.value, (float)a1_thread_buf[I]);
});
}
#endif
a1_thread_buf);
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
block_sync_lds();
gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf);
#if 0
if (hipThreadIdx_x % 32 < 8) {
static_for<0, c_thread_buf.Size(), 1>{}([&](auto I) {
printf("bid %zd tid %zd, iter %d, c[%d] = %f\n", hipBlockIdx_x, hipThreadIdx_x, i.value, I.value, c_thread_buf[I]);
});
}
#endif
block_sync_lds();
// a1_blockwise_copy.MoveSrcSliceWindow(acc_thread_desc_k0_m_k1,
// a1_block_slice_copy_step);
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
b1_block_slice_copy_step);
......@@ -758,30 +695,26 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
}
// tail
{
a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
make_tuple(Number<(num_gemm1_k_block_inner_loop - 1) * A1K0>{}, I0, I0),
acc_thread_buf,
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
a1_blockwise_copy.Run(
acc_thread_desc_k0_m_k1,
make_tuple(
Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0),
acc_thread_buf,
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
block_sync_lds();
gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf);
}
} // end gemm1
#if 0
if (hipThreadIdx_x % 32 < 8) {
static_for<0, c_thread_buf.Size(), 1>{}([&](auto I) {
printf("bid %zd tid %zd, iter %d, c[%d] = %f\n", hipBlockIdx_x, hipThreadIdx_x, num_gemm1_k_block_inner_loop - 1, I.value, c_thread_buf[I]);
});
}
#endif
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1, a_block_reset_copy_step); // rewind K
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1, b_block_reset_copy_step); // rewind K and step N
// don't need to rewind b1
} while (++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
a_block_reset_copy_step); // rewind K
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1,
b_block_reset_copy_step); // rewind K and step N
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
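// Recap of the window bookkeeping per outer iteration (see the reset steps defined earlier):
// A's source window is rewound along K only, B0's window is rewound along K and advanced by
// NPerBlock along N so the next iteration reads the next N tile, while B1's window keeps
// advancing through its own K dimension and needs no reset.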
// shuffle C and write out
{
......@@ -1145,10 +1145,6 @@ struct ThreadwiseTensorSliceTransfer_v4
src_desc, src_data_coord);
// copy data from src_buf into src_tmp_vector
#if 0
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
#else
if constexpr(SrcBuffer::IsDynamicBuffer())
{
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
......@@ -1164,33 +1160,7 @@ struct ThreadwiseTensorSliceTransfer_v4
src_tmp_vector.template AsType<SrcData>()(i) =
src_buf[Number<src_offset>{}];
});
// if constexpr(StaticBufferTupleOfVector)
// {
// // constexpr auto offset_nd = SrcRefToOriginDisplacement{} + data_to_origin_disp_idx;
// // // offset_nd.foo();
// // constexpr auto offset = src_desc.CalculateOffset(offset_nd);
// // src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
// // src_buf.template GetAsType<src_vector_t>(Number<offset>{});
// static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
// // constexpr auto src_offset_nd = src_ref_to_origin_disp_idx +
// // data_to_origin_disp_idx + i * src_scalar_step_in_vector;
// // constexpr auto src_offset = src_desc.CalculateOffset(src_offset_nd);
// constexpr auto src_offset = src_desc.CalculateOffset(SrcRefToOriginDisplacement{});
// // SrcData s = src_buf[Number<src_offset>{}];
// SrcData s = src_buf[Number<0>{}];
// // apply type convert
// src_tmp_vector.template AsType<SrcData>()(i) = s;
// });
// }
// else
// {
// src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
// src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(),
// is_src_valid);
// }
}
#endif
// copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
......@@ -1236,16 +1206,14 @@ template <typename SrcData,
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
// InMemoryDataOperationEnum DstInMemOp,
// index_t DstScalarStrideInVector,
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v1r3_Static
struct ThreadwiseTensorSliceTransfer_StaticToStatic
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3_Static()
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic()
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time");
......@@ -30,17 +30,6 @@ enum struct MfmaInstr
mfma_f64_16x16x4f64
};
// template <typename T, bool TransposeC>
// struct mfma_base_type
// {
// template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
// __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
// {
// if constexpr (!TransposeC) T::run(a, b, reg_c);
// else T::run(b, a, reg_c);
// }
// };
template <MfmaInstr instr>
struct mfma_type;
......@@ -72,6 +72,7 @@ struct StaticBufferTupleOfVector
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
__host__ __device__ static constexpr index_t Size() { return s_per_buf; };
// Get S; I is the offset of S
template <index_t I>
......@@ -78,10 +78,4 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
template <index_t... Is>
__host__ __device__ constexpr Tuple<Number<Is>...> to_tuple(Sequence<Is...>)
{
return Tuple<Number<Is>...>(Number<Is>{}...);
}
} // namespace ck
......@@ -134,7 +134,7 @@ check_err(const std::vector<T>& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 128)
if(err_count < 5)
{
std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;