Commit 3db3fe42 authored by Anthony Chang

dP

parent c26b46de
@@ -364,12 +364,29 @@ int run(int argc, char* argv[])
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
break;
-default:
+case 4:
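// cases 4-6 (and the new default) keep Q all-ones with diagonal K and V so the
// forward results stay easy to predict by hand; they differ only in the fill
// pattern used for dY when exercising the backward pass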
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2});
break;
case 5:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
// ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
break;
case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
break;
default:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); // dy[g0, g1, m, o]
}
// calculate y & log-sum-exp beforehand
@@ -30,7 +30,7 @@ template <typename DataType,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
-typename B1GridDesc_BK0_N_BK1,
+typename VGridDesc_N0_O_N1,
typename CGridDesc_M_N,
typename LSEGridDesc_M,
index_t NumGemmKPrefetchStage,
@@ -186,36 +186,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
[](auto I) { return GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2().At(I); },
Number<4>{});
}
// template <typename PBlockDesc_M0_N_M1>
// __host__ __device__ static constexpr auto
// MakePMmaTileDescriptor_N0_N1_N2_M(const PBlockDesc_M0_N_M1&)
// {
// constexpr auto lengths = GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2();
// return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<lengths[I0], lengths[I2],
// lengths[I4]>(
// PBlockDesc_M0_N_M1{});
// }
// template <typename BBlockDesc_BK0_N_BK1>
// __host__ __device__ static constexpr auto
// MakeYGradMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
// {
// constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
// return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
// BBlockDesc_BK0_N_BK1{});
// }
};
using VGradGemmTile_N_O_M = VGradGemmTile_N_O_M_<>; // tune later
// PGrad Gemm
struct PGradGemmTile_M_N_O_
{
};
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
@@ -363,7 +337,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
+const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
@@ -374,7 +348,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
-const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
+const auto Gemm1N = v_grid_desc_n0_o_n1.GetLength(I1);
if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
{
@@ -472,6 +446,81 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
// PGrad Gemm has the same layout as P Gemm (A row-major B col-major)
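// in the attention backward pass dP[m, n] = sum_o dY[m, o] * V[n, o]: dY takes the
// A slot (M x O) and V the B slot (N x O), with O acting as the reduction (gemm-K)
// dimension, mirroring the Q/K operand roles of the forward P gemm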
struct PGradGemmTile_M_N_O
{
private:
static constexpr auto ygrad_block_desc_o0_m_o1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto v_block_desc_o0_n_o1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
public:
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
DataType,
FloatGemmAcc,
decltype(ygrad_block_desc_o0_m_o1),
decltype(v_block_desc_o0_n_o1),
decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(ygrad_block_desc_o0_m_o1)),
decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(v_block_desc_o0_n_o1)),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
true>;
// Should have made all input tensors 2D and transform them into appropriate 3D form in
// kernel to make things more concise - if we can get the compiler to behave
template <typename YGradGridDesc_M0_O_M1_>
__device__ static const auto
MakeYGradGridDesc_O0_M_O1(const YGradGridDesc_M0_O_M1_& ygrad_grid_desc_m0_o_m1)
{
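// reinterpret dY from [M0, O, M1] to [O0, M = M0 * M1, O1 = AK1]: O is unmerged into
// (O0, O1) to match the A block descriptor's K0/K1 split, and (M0, M1) are merged
// back into a single M dimension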
const auto M0 = ygrad_grid_desc_m0_o_m1.GetLength(I0);
const auto O = ygrad_grid_desc_m0_o_m1.GetLength(I1);
const auto M1 = ygrad_grid_desc_m0_o_m1.GetLength(I2);
constexpr auto Y_O1 = AK1;
const auto Y_O0 = O / Y_O1;
const auto ygrad_grid_desc_o0_m_o1 = transform_tensor_descriptor(
ygrad_grid_desc_m0_o_m1,
make_tuple(make_unmerge_transform(make_tuple(Y_O0, Y_O1)),
make_merge_transform_v3_division_mod(make_tuple(M0, M1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return ygrad_grid_desc_o0_m_o1;
}
template <typename VGridDesc_N0_O_N1_>
__device__ static const auto
MakeVGridDesc_O0_N_O1(const VGridDesc_N0_O_N1_& v_grid_desc_n0_o_n1)
{
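// same reinterpretation for V: [N0, O, N1] -> [O0, N = N0 * N1, O1 = BK1], so V can
// be consumed as the B operand with O as the gemm-K dimension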
const auto N0 = v_grid_desc_n0_o_n1.GetLength(I0);
const auto O = v_grid_desc_n0_o_n1.GetLength(I1);
const auto N1 = v_grid_desc_n0_o_n1.GetLength(I2);
constexpr auto V_O1 = BK1;
const auto V_O0 = O / V_O1;
const auto v_grid_desc_o0_n_o1 = transform_tensor_descriptor(
v_grid_desc_n0_o_n1,
make_tuple(make_unmerge_transform(make_tuple(V_O0, V_O1)),
make_merge_transform_v3_division_mod(make_tuple(N0, N1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return v_grid_desc_o0_n_o1;
}
};
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
@@ -525,7 +574,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
typename YGradGridDesc_M0_O_M1>
__device__ static void Run(const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid,
-const DataType* __restrict__ p_b1_grid,
+const DataType* __restrict__ p_v_grid,
const DataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid,
@@ -540,7 +589,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
+const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
const LSEGridDesc_M& lse_grid_desc_m,
@@ -553,8 +602,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
-const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-    p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
+const auto v_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+    p_v_grid, v_grid_desc_n0_o_n1.GetElementSpaceSize());
const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -784,7 +833,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
B1BlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
-decltype(b1_grid_desc_bk0_n_bk1),
+decltype(v_grid_desc_n0_o_n1),
decltype(b1_block_desc_bk0_n_bk1),
B1BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
@@ -797,7 +846,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
B1ThreadTransferSrcResetCoordinateAfterRun,
true, // DstResetCoord
NumGemmKPrefetchStage>(
-b1_grid_desc_bk0_n_bk1,
+v_grid_desc_n0_o_n1,
make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
b1_element_op,
b1_block_desc_bk0_n_bk1,
@@ -1298,6 +1347,83 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// tiled the same way
// TODO ANT: dP Gemm can reuse first blockwise gemm and pipeline
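// dP gemm setup: dY and V are streamed through the same A/B LDS block buffers used by
// the forward S = Q * K^T gemm (a_block_desc / b_block_desc are reused), and the
// blockwise gemm accumulates dP over the O dimension into pgrad_acc_thread_buf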
const auto ygrad_grid_desc_o0_m_o1 =
PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
const auto v_grid_desc_o0_n_o1 =
PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);
// A matrix blockwise copy
auto pgrad_gemm_tile_ygrad_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
decltype(ygrad_grid_desc_o0_m_o1),
decltype(a_block_desc_ak0_m_ak1), // reuse block buf
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>(
ygrad_grid_desc_o0_m_o1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
tensor_operation::element_wise::PassThrough{},
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto pgrad_gemm_tile_v_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
decltype(v_grid_desc_o0_n_o1),
decltype(b_block_desc_bk0_n_bk1), // reuse block buf
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>(
v_grid_desc_o0_n_o1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
tensor_operation::element_wise::PassThrough{},
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto pgrad_blockwise_gemm = typename PGradGemmTile_M_N_O::BlockwiseGemm{};
auto pgrad_acc_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();
const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), 0, 0);
const auto pgrad_gemm_tile_v_block_reset_copy_step =
make_multi_index(-v_grid_desc_o0_n_o1.GetLength(I0), NPerBlock, 0);
const index_t num_o_block_main_loop = __builtin_amdgcn_readfirstlane(
(ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
KPerBlock);
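// num_o_block_main_loop: the dP gemm consumes the whole O extent in KPerBlock-sized
// chunks, relying on the size K == size O assumption noted further down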
auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatGemmAcc,
@@ -1525,7 +1651,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
#endif
-// softmax
+// P_i: = softmax(S_i:)
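// with the pre-computed log-sum-exp statistics each probability is recovered directly
// as P_ij = exp(S_ij - LSE_i), so no fresh row-max / row-sum reduction is needed here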
blockwise_softmax.RunWithPreCalcStats(acc_thread_buf, lse_thread_buf);
#if 0
@@ -1628,13 +1754,58 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
#endif
}); // end gemm dV
-// atomic_add vgrad
+// atomic_add dV
vgrad_thread_copy_vgpr_to_global.Run(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
vgrad_acc_thread_buf,
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
vgrad_grid_buf);
// gemm dP
pgrad_acc_thread_buf.Clear();
#if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("j loop idx %d, tid %zd, clear dP[0:3] = %f, %f, %f, %f\n",
gemm1_k_block_outer_index,
hipThreadIdx_x,
pgrad_acc_thread_buf[I0],
pgrad_acc_thread_buf[I1],
pgrad_acc_thread_buf[I2],
pgrad_acc_thread_buf[I3]);
}
#endif
block_sync_lds();
// assume size K == size O so has main block loop
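// the same gridwise gemm pipeline that produced S = Q * K^T is re-run here to
// accumulate dP, with dY and V flowing through the reused A/B block buffers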
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
ygrad_grid_desc_o0_m_o1,
a_block_desc_ak0_m_ak1, // reuse
pgrad_gemm_tile_ygrad_blockwise_copy,
ygrad_grid_buf, // A operand of the dP gemm: dY read from global memory
a_block_buf, // reuse
a_block_slice_copy_step, // reuse
v_grid_desc_o0_n_o1,
b_block_desc_bk0_n_bk1, // reuse
pgrad_gemm_tile_v_blockwise_copy,
v_grid_buf,
b_block_buf, // reuse
b_block_slice_copy_step, // reuse
pgrad_blockwise_gemm,
pgrad_acc_thread_buf,
num_o_block_main_loop);
#if 1
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("j loop idx %d, tid %zd, dP[0:3] = %f, %f, %f, %f\n",
gemm1_k_block_outer_index,
hipThreadIdx_x,
pgrad_acc_thread_buf[I0],
pgrad_acc_thread_buf[I1],
pgrad_acc_thread_buf[I2],
pgrad_acc_thread_buf[I3]);
}
#endif
// move slice window
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
a_block_reset_copy_step); // rewind K
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1,
@@ -1643,6 +1814,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
ygrad_block_reset_copy_step); // rewind M
vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, vgrad_block_slice_copy_step); // step N
pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1, pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O
pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
v_grid_desc_o0_n_o1,
pgrad_gemm_tile_v_block_reset_copy_step); // rewind O and step N
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop