"...composable_kernel_rocm.git" did not exist on "ca42e9101d7eb1930dad87407dcf4d36693ecf65"
Commit c26b46de authored by Anthony Chang

format

parent 15f1d4ad
@@ -45,7 +45,6 @@ Kernel outputs:
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -384,7 +384,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         return transform_tensor_descriptor(
             ygrad_grid_desc_m_o,
-            make_tuple(make_unmerge_transform(make_tuple(Y_M0, Y_M1)), make_pass_through_transform(O)),
+            make_tuple(make_unmerge_transform(make_tuple(Y_M0, Y_M1)),
+                       make_pass_through_transform(O)),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }
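For orientation, the reformatted call above builds the (M0, O, M1) view of the M x O ygrad matrix: the unmerge transform splits the row index m into (m0, m1) with m = m0 * M1 + m1, while O passes through unchanged. A minimal standalone sketch of the index arithmetic this encodes, with assumed extents (this is not the CK API, just the mapping):

    #include <cassert>

    int main()
    {
        const int M1 = 4;         // assumed Y_M1, the fast sub-dimension of M
        const int M = 32, O = 64; // assumed ygrad extents; M divisible by M1

        for(int m = 0; m < M; ++m)
        {
            const int m0 = m / M1; // slow coordinate produced by the unmerge
            const int m1 = m % M1; // fast coordinate produced by the unmerge
            for(int o = 0; o < O; ++o)
            {
                // (m0, o, m1) addresses the same element as (m, o) in a row-major M x O buffer
                assert((m0 * M1 + m1) * O + o == m * O + o);
            }
        }
        return 0;
    }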
@@ -456,7 +457,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
     //
     // static auto MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock()
     //
     // dQ = alpha * dS * K
     //
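The comment above records the formula this pass computes, dQ = alpha * dS * K, i.e. dQ[m][k] = alpha * sum_n dS[m][n] * K[n][k]. As a plain reference point in the spirit of the reference_batched_gemm header included earlier (names and layouts here are assumptions, not the kernel's code):

    #include <vector>

    // dQ[m][k] = alpha * sum_n dS[m][n] * K[n][k]; all matrices row-major
    void reference_dq(const std::vector<float>& dS, // M x N
                      const std::vector<float>& K,  // N x KDim
                      std::vector<float>& dQ,       // M x KDim
                      int M, int N, int KDim, float alpha)
    {
        for(int m = 0; m < M; ++m)
            for(int k = 0; k < KDim; ++k)
            {
                float acc = 0.f;
                for(int n = 0; n < N; ++n)
                    acc += dS[m * N + n] * K[n * KDim + k];
                dQ[m * KDim + k] = alpha * acc;
            }
    }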
@@ -700,8 +700,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
               vgrad_grid_desc_n_o_{DeviceOp::MakeVGradGridDescriptor_N_O(
                   b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
               /* PTrans descriptor will be constructed in kernel */
-              ygrad_grid_desc_m0_o_m1_{
-                  DeviceOp::MakeYGradGridDescriptor_M0_O_M1(y_grid_desc_m_o_)},
+              ygrad_grid_desc_m0_o_m1_{DeviceOp::MakeYGradGridDescriptor_M0_O_M1(y_grid_desc_m_o_)},
               // batch offsets
               a_grid_desc_g_m_k_{
                   Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
@@ -214,43 +214,22 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     // PGrad Gemm
     struct PGradGemmTile_M_N_O_
     {
     };

     template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
     struct YDotYGrad_M_O_
     {
         static constexpr index_t SrcScalarPerVetor = 16 / sizeof(DataType);
-        static constexpr auto ThreadClusterLength_O = Number<BlockSliceLength_O_ / SrcScalarPerVetor>{};
+        static constexpr auto ThreadClusterLength_O =
+            Number<BlockSliceLength_O_ / SrcScalarPerVetor>{};
         static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
         static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVetor>{};
         static constexpr auto ThreadSliceLength_M =
             Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};

-        // static_assert(BlockSliceLength_O_ % SrcScalarPerVetor == 0, "");
-        // static_assert(BlockSize_ % ThreadClusterLength_O == 0, "");
         static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
         static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");

-        using BlockwiseSumReduce = PartitionedBlockwiseReduction<
-            FloatGemmAcc,
-            BlockSize_,
-            Sequence<ThreadClusterLength_M, ThreadClusterLength_O>,
-            Sequence<0, 1>,
-            reduce::Add,
-            false>; // propagateNaN
-
-        // using ThreadReduceSrcDesc_M_O = decltype(make_naive_tensor_descriptor_packed(
-        //     make_tuple(ThreadSliceLength_M, ThreadSliceLength_O)));
-        // using ThreadReduceDstDesc_M =
-        //     decltype(make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceLength_M)));
-        // using ThreadwiseSumReduce =
-        //     ThreadwiseReduction<FloatGemmAcc,
-        //                         ThreadReduceSrcDesc_M_O,
-        //                         ThreadReduceDstDesc_M,
-        //                         reduce::Add,
-        //                         false>; // propagateNaN
-
         using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
                                         DataType,
                                         ThreadSliceLength_M * ThreadSliceLength_O,
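To make the constants above concrete: with an assumed BlockSize of 256, a 128 x 128 (M x O) block slice, and a 2-byte DataType, each thread loads 16 / sizeof(DataType) = 8 elements per vector along O, giving a 16 x 16 thread cluster with an 8 x 8 slice per thread. A compile-time sketch of that arithmetic (example numbers are assumptions, not CK defaults):

    #include <cstddef>

    int main()
    {
        constexpr std::size_t BlockSize          = 256; // assumed
        constexpr std::size_t BlockSliceLength_M = 128; // assumed
        constexpr std::size_t BlockSliceLength_O = 128; // assumed
        constexpr std::size_t SizeofDataType     = 2;   // e.g. fp16

        constexpr std::size_t SrcScalarPerVector    = 16 / SizeofDataType;                     // 8
        constexpr std::size_t ThreadClusterLength_O = BlockSliceLength_O / SrcScalarPerVector; // 16
        constexpr std::size_t ThreadClusterLength_M = BlockSize / ThreadClusterLength_O;       // 16
        constexpr std::size_t ThreadSliceLength_O   = SrcScalarPerVector;                      // 8
        constexpr std::size_t ThreadSliceLength_M =
            BlockSliceLength_M * ThreadClusterLength_O / BlockSize;                            // 8

        // the same invariants the struct's static_asserts enforce
        static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O, "");
        static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M, "");
        return 0;
    }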
@@ -1271,11 +1250,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         //
         // dP
         //
-        constexpr auto y_thread_desc_m0_m1_o0_o1 =
-            make_naive_tensor_descriptor_packed(make_tuple(I1,
-                                                           YDotYGrad_M_O::ThreadSliceLength_M,
-                                                           I1,
-                                                           YDotYGrad_M_O::ThreadSliceLength_O));
+        constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
+            I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));

         constexpr auto y_thread_cluster_desc =
             make_cluster_descriptor(Sequence<I1,
                                              YDotYGrad_M_O::ThreadClusterLength_M,
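The y_dot_ygrad buffers wired up in the next hunks hold a per-row dot product of Y with its gradient, which the kernel's comments say feeds the sgrad computation. A scalar reference of that quantity, under the assumption that it is the row-wise sum D[m] = sum_o Y[m][o] * dY[m][o] (illustrative only, not the kernel's code):

    #include <vector>

    // D[m] = sum_o Y[m][o] * dY[m][o]; Y and dY row-major M x O
    std::vector<float> reference_y_dot_ygrad(const std::vector<float>& Y,
                                             const std::vector<float>& dY,
                                             int M, int O)
    {
        std::vector<float> D(M, 0.f);
        for(int m = 0; m < M; ++m)
            for(int o = 0; o < O; ++o)
                D[m] += Y[m * O + o] * dY[m * O + o];
        return D;
    }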
@@ -1316,10 +1292,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         constexpr auto y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl =
             make_naive_tensor_descriptor(make_tuple(I1, P_M0, P_M1, P_M2),
                                          make_tuple(P_M0 * P_M1 * P_M2, P_M1 * P_M2, P_M2, I1));

-        // y_dot_ygrad thread buffer for calculating sgrad; reuse LSE thread descriptor because
-        // per-thread LSE data and y_dot_ygrad is tiled the same way
         constexpr auto y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl =
-            lse_thread_desc_mblock_mrepeat_mwave_mperxdl;
+            lse_thread_desc_mblock_mrepeat_mwave_mperxdl; // reuse LSE thread descriptor because
+                                                          // per-thread LSE data and y_dot_ygrad is
+                                                          // tiled the same way
+
+        // TODO ANT: dP Gemm can reuse first blockwise gemm and pipeline

         auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
             FloatGemmAcc,
@@ -1404,7 +1382,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         block_sync_lds();

         static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
             const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM;
-            y_dot_ygrad_block_accum_buf.AtomicAdd(idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM]);
+            y_dot_ygrad_block_accum_buf.AtomicAdd(
+                idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM]);
         });

         block_sync_lds();
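The pattern above has every thread fold its partial slice sum into a shared per-row accumulator with an atomic add, fenced by LDS barriers on both sides. A host-side analogue using std::atomic, standing in for the LDS buffer's AtomicAdd (requires C++20 for fetch_add on float; the CK buffer API is only mimicked here):

    #include <atomic>
    #include <vector>

    // one accumulator slot per block row; all threads owning pieces of row
    // idx_on_block add their partial y_dot_ygrad sums into the same slot
    void atomic_accumulate(std::vector<std::atomic<float>>& block_accum,
                           int idx_on_block, float partial_sum)
    {
        block_accum[idx_on_block].fetch_add(partial_sum, std::memory_order_relaxed);
    }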
@@ -1416,7 +1395,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         }
 #endif

-        // distribute to threads
+        // distribute y_dot_ygrad to threads; LDS accum buffer can be safely accessed after barrier
         y_dot_ygrad_thread_copy_lds_to_vgpr.Run(
             y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
             y_dot_ygrad_block_accum_buf,
@@ -1667,9 +1646,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop

         // TODO ANT:
         // shuffle dQ and write
 #if 0
         {
             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                               Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
@@ -1865,7 +1844,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 }
             });
         }
 #endif
     }
 };