Commit c26b46de authored by Anthony Chang

format

parent 15f1d4ad
......@@ -45,7 +45,6 @@ Kernel outputs:
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......
......@@ -384,7 +384,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return transform_tensor_descriptor(
ygrad_grid_desc_m_o,
make_tuple(make_unmerge_transform(make_tuple(Y_M0, Y_M1)), make_pass_through_transform(O)),
make_tuple(make_unmerge_transform(make_tuple(Y_M0, Y_M1)),
make_pass_through_transform(O)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
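// Illustrative aside, not from the CK source: the unmerge above splits the flat M
// dimension into (M0, M1) with M = Y_M0 * Y_M1 while O is passed through, and the output
// dimension order is (M0, O, M1). Element (m0, o, m1) of the new descriptor therefore
// addresses element (m, o) = (m0 * Y_M1 + m1, o) of ygrad_grid_desc_m_o; assuming a
// packed row-major (M, O) layout, that corresponds to offset m * O + o.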
......@@ -456,7 +457,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
//
// static auto MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock()
//
// dQ = alpha * dS * K
//
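// Illustrative aside (standard scaled-dot-product backward, stated only for context):
// mirroring the dQ = alpha * dS * K formula above, the companion gradients are
// dK = alpha * dS^T * Q (by symmetry of S = Q * K^T) and dV = P^T * dY (since Y = P * V).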
......@@ -700,8 +700,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
vgrad_grid_desc_n_o_{DeviceOp::MakeVGradGridDescriptor_N_O(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
/* PTrans descriptor will be constructed in kernel */
ygrad_grid_desc_m0_o_m1_{
DeviceOp::MakeYGradGridDescriptor_M0_O_M1(y_grid_desc_m_o_)},
ygrad_grid_desc_m0_o_m1_{DeviceOp::MakeYGradGridDescriptor_M0_O_M1(y_grid_desc_m_o_)},
// batch offsets
a_grid_desc_g_m_k_{
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
......
......@@ -214,43 +214,22 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// PGrad Gemm
struct PGradGemmTile_M_N_O_
{
};
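// Illustrative aside (assumption based on the standard attention backward): this tile
// config presumably parameterizes the PGrad GEMM dP = dY * V^T, where dP is (M, N),
// dY is (M, O) and V is (N, O), matching the M/N/O naming of the struct.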
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static constexpr index_t SrcScalarPerVetor = 16 / sizeof(DataType);
static constexpr auto ThreadClusterLength_O = Number<BlockSliceLength_O_ / SrcScalarPerVetor>{};
static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVetor>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVetor>{};
static constexpr auto ThreadSliceLength_M =
Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
// static_assert(BlockSliceLength_O_ % SrcScalarPerVetor == 0, "");
// static_assert(BlockSize_ % ThreadClusterLength_O == 0, "");
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
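// Worked example (illustrative values, not from the original source): with BlockSize_ =
// 256, BlockSliceLength_M_ = BlockSliceLength_O_ = 128 and a 2-byte DataType,
// SrcScalarPerVetor = 16 / 2 = 8, so ThreadClusterLength_O = 128 / 8 = 16,
// ThreadClusterLength_M = 256 / 16 = 16, ThreadSliceLength_O = 8 and
// ThreadSliceLength_M = 128 * 16 / 256 = 8; each thread then reduces an 8 x 8 sub-tile
// and both static_asserts above hold (16 * 8 == 128 along O and along M).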
using BlockwiseSumReduce = PartitionedBlockwiseReduction<
FloatGemmAcc,
BlockSize_,
Sequence<ThreadClusterLength_M, ThreadClusterLength_O>,
Sequence<0, 1>,
reduce::Add,
false>; // propagateNaN
// using ThreadReduceSrcDesc_M_O = decltype(make_naive_tensor_descriptor_packed(
// make_tuple(ThreadSliceLength_M, ThreadSliceLength_O)));
// using ThreadReduceDstDesc_M =
// decltype(make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceLength_M)));
// using ThreadwiseSumReduce =
// ThreadwiseReduction<FloatGemmAcc,
// ThreadReduceSrcDesc_M_O,
// ThreadReduceDstDesc_M,
// reduce::Add,
// false>; // propagateNaN
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
ThreadSliceLength_M * ThreadSliceLength_O,
......@@ -1271,11 +1250,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
//
// dP
//
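// Illustrative background (standard softmax backward identity, not part of the original
// source): with P = softmax(S), Y = P * V and dP = dY * V^T, the row-wise correction
// term needed for dS is
//     D_i = sum_j P_ij * dP_ij = sum_o Y_io * dY_io
// i.e. the per-row dot product of y and ygrad (the y_dot_ygrad computed in this
// section), and then dS_ij = P_ij * (dP_ij - D_i).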
constexpr auto y_thread_desc_m0_m1_o0_o1 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
YDotYGrad_M_O::ThreadSliceLength_M,
I1,
YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto y_thread_cluster_desc =
make_cluster_descriptor(Sequence<I1,
YDotYGrad_M_O::ThreadClusterLength_M,
......@@ -1316,10 +1292,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr auto y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl =
make_naive_tensor_descriptor(make_tuple(I1, P_M0, P_M1, P_M2),
make_tuple(P_M0 * P_M1 * P_M2, P_M1 * P_M2, P_M2, I1));
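// Illustrative aside: with lengths (1, P_M0, P_M1, P_M2) and the strides above this LDS
// view is packed, offset(i0, i1, i2, i3) = (i1 * P_M1 + i2) * P_M2 + i3 (i0 is fixed at
// 0 within a block), holding one accumulated value per row of the block's
// (MRepeat, MWave, MPerXdl) decomposition of M.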
// y_dot_ygrad thread buffer for calculating sgrad; reuse LSE thread descriptor because
// per-thread LSE data and y_dot_ygrad are tiled the same way
constexpr auto y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl =
lse_thread_desc_mblock_mrepeat_mwave_mperxdl;
lse_thread_desc_mblock_mrepeat_mwave_mperxdl; // reuse LSE thread descriptor because
// per-thread LSE data and y_dot_ygrad are
// tiled the same way
// TODO ANT: dP Gemm can reuse first blockwise gemm and pipeline
auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatGemmAcc,
......@@ -1404,7 +1382,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
block_sync_lds();
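// Note: each thread owns ThreadSliceLength_M rows; for every such row it atomically adds
// its accumulated partial y*ygrad sum into the per-row LDS accumulator, so after the
// barrier that follows the loop the complete per-row sums can be read by all threads.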
static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM;
y_dot_ygrad_block_accum_buf.AtomicAdd(idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM]);
y_dot_ygrad_block_accum_buf.AtomicAdd(
idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM]);
});
block_sync_lds();
......@@ -1416,7 +1395,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
#endif
// distribute to threads
// distribute y_dot_ygrad to threads; LDS accum buffer can be safely accessed after barrier
y_dot_ygrad_thread_copy_lds_to_vgpr.Run(
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
y_dot_ygrad_block_accum_buf,
......@@ -1667,9 +1646,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
// TODO ANT:
// shuffle dQ and write
#if 0
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
......@@ -1865,7 +1844,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
});
}
#endif
}
};
......