Commit 7a302cc9 authored by Anthony Chang

debugging dQ; suspected K mat not properly loaded

parent b637c77d
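For reference while reading the hunks below, these are the attention-backward identities that the test-case comments and the new dQ gemm appear to implement (a sketch in the comments' own notation; V is treated as N x O here, while the code stores it [O, N], so the in-file comments write dP = dO * V; the per-row reduction rowsum(dO ∘ O) is what the kernel calls y_dot_ygrad):

\[
P = \mathrm{softmax}(Q K^T), \qquad O = P V, \qquad dP = dO\, V^T,
\]
\[
dS = P \circ \bigl(dP - \mathrm{rowsum}(dO \circ O)\bigr), \qquad dV = P^T dO, \qquad dQ = dS\, K.
\]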
@@ -104,7 +104,7 @@ using DeviceGemmInstance =
TensorSpecY,
1,
256,
256, // MPerBlock
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
@@ -114,7 +114,7 @@ using DeviceGemmInstance =
2, // B1K1
32, // MPerXDL
32, // NPerXDL
2, // MXdlPerWave
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
@@ -375,18 +375,34 @@ int run(int argc, char* argv[])
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
// dO dot O = [0; 1; 2; ...]
break;
case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = [0, 1, 2, ...; 0, 1, 2, ...; ...]
// dO dot O = [127.5; ...]
// dS = P * (dP - dO dot O)
//    = 0.0039 * [0 - 127.5, 1 - 127.5, ..., 255 - 127.5] per row
break;
default:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); // dy[g0, g1, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = ones
// dS = P * (dP - (dO dot O))
// = 0.0039 * ones * (ones - 0.0039*256)
// = 0.0039 * ones * (ones - 1)
// = 0
}
// calculate y & log-sum-exp beforehand
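A quick standalone check of the arithmetic in the case comments above (illustration only, not part of this commit; plain double math with no CK dependencies). For the default case with M = N = K = O = 256, Q = ones, K = V = identity, and dY = ones, dS must come out exactly zero:

```cpp
#include <cassert>
#include <cstdio>

int main()
{
    constexpr double n = 256.0;
    const double p    = 1.0 / n;     // P = softmax(ones) = 1/256 per element
    const double o    = p;           // O = P * V with V = I
    const double dp   = 1.0;         // dP = dY * V^T with dY = ones and V = I
    const double do_o = n * 1.0 * o; // rowsum(dY .* O) = 256 * (1/256) = 1
    const double ds   = p * (dp - do_o);
    std::printf("P=%f dP=%f dO.O=%f dS=%f\n", p, dp, do_o, ds);
    assert(ds == 0.0); // matches the "// = 0" comment in the default case
    return 0;
}
```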
@@ -28,8 +28,8 @@ template <typename DataType,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename QGridDesc_K0_M_K1,
typename KGridDesc_K0_N_K1,
typename VGridDesc_N0_O_N1,
typename CGridDesc_M_N,
typename LSEGridDesc_M,
@@ -335,8 +335,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
CheckValidity(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
@@ -345,9 +345,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto Gemm1N = v_grid_desc_n0_o_n1.GetLength(I1);
if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
@@ -446,7 +446,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
// PGrad Gemm has the same layout as P Gemm (A row-major B col-major)
// PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
struct PGradGemmTile_M_N_O
{
private:
@@ -521,6 +521,82 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
};
// QGrad Gemm has the same layout as the Y = P * V Gemm (A held in acc VGPRs, B row-major)
struct QGradGemmTile_M_K_N
{
template <typename QGridDesc_K0_M_K1_>
__device__ static const auto
MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(const QGridDesc_K0_M_K1_& q_grid_desc_k0_m_k1)
{
const auto K0 = q_grid_desc_k0_m_k1.GetLength(I0);
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto K1 = q_grid_desc_k0_m_k1.GetLength(I2);
const auto K = K0 * K1;
const auto MBlock = M / MPerBlock;
const auto KBlock = K / Gemm1NPerBlock; // NOTE: QGrad gemm is similar to Y gemm
const auto q_grid_desc_m_k = transform_tensor_descriptor(
q_grid_desc_k0_m_k1,
make_tuple(make_pass_through_transform(M),
make_merge_transform_v3_division_mod(make_tuple(K0, K1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
q_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(KBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
}
template <typename SGradThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4_>
__device__ static const auto
MakeSGradThreadDesc_N0_M_N1(const SGradThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4_& sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4)
{
constexpr auto m0 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto n0 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto m1 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto n1 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto m2 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto n2 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto n3 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto n4 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto sgrad_thread_desc_n0_m_n1 = transform_tensor_descriptor(
sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
make_pass_through_transform(n4)),
make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return sgrad_thread_desc_n0_m_n1;
}
template <typename KGridDesc_K0_N_K1_>
__device__ static const auto
MakeKGridDesc_N0_K_N1(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1)
{
const auto K_K0 = k_grid_desc_k0_n_k1.GetLength(I0);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K_K1 = k_grid_desc_k0_n_k1.GetLength(I2);
constexpr auto K_N1 = BK1;
const auto K_N0 = N / K_N1;
const auto k_grid_desc_n0_k_n1 = transform_tensor_descriptor(
k_grid_desc_k0_n_k1,
make_tuple(make_unmerge_transform(make_tuple(K_N0, K_N1)),
make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return k_grid_desc_n0_k_n1;
}
};
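// Reference semantics of the descriptors above (illustration only, not used by the kernel):
// the dQ gemm computes dQ = dS * K, with dS coming from the softmax-backward step and K
// re-read in [N0, K, N1] order (N1 = BK1) so it can serve as the B operand. A naive
// host-side equivalent, assuming plain row-major dS[M][N] and K[N][KDim] buffers:
//
//   for(int m = 0; m < M; ++m)
//       for(int k = 0; k < KDim; ++k)
//       {
//           float acc = 0.f;
//           for(int n = 0; n < N; ++n)
//               acc += ds[m * N + n] * k_mat[n * KDim + k];
//           dq[m * KDim + k] += acc; // accumulated over key blocks, like qgrad_thread_buf
//       }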
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
@@ -572,8 +648,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
typename C0MatrixMask,
typename VGradGridDescriptor_N_O,
typename YGradGridDesc_M0_O_M1>
__device__ static void Run(const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid,
__device__ static void Run(const DataType* __restrict__ p_q_grid,
const DataType* __restrict__ p_k_grid,
const DataType* __restrict__ p_v_grid,
const DataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid,
@@ -587,8 +663,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const AccElementwiseOperation& acc_element_op,
const B1ElementwiseOperation& b1_element_op,
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
@@ -598,10 +674,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_k_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
const auto v_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_v_grid, v_grid_desc_n0_o_n1.GetElementSpaceSize());
const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -612,6 +688,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, vgrad_grid_desc_n_o.GetElementSpaceSize());
auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
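// dQ has the same shape and layout as Q, so its global buffer reuses the element space
// size of q_grid_desc_k0_m_k1.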
// divide block work by [M, O]
const auto block_work_idx =
@@ -653,7 +731,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
ABlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(q_grid_desc_k0_m_k1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
@@ -666,7 +744,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
q_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
@@ -684,7 +762,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
BBlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(k_grid_desc_k0_n_k1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
@@ -697,7 +775,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
k_grid_desc_k0_n_k1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
b_element_op,
b_block_desc_bk0_n_bk1,
@@ -746,9 +824,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
const auto a_block_reset_copy_step =
make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
const auto b_block_reset_copy_step =
make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
make_multi_index(-k_grid_desc_k0_n_k1.GetLength(I0), NPerBlock, 0);
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
@@ -757,11 +835,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
LoopScheduler::Default>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
(q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2)) /
KPerBlock);
//
// set up Gemm1
// set up O / dQ Gemm
//
// Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
@@ -811,47 +889,47 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A1 matrix blockwise copy
auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatGemmAcc,
DataType,
decltype(acc_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough,
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
Sequence<1, 0, 2>,
2,
n4>{tensor_operation::element_wise::PassThrough{}};
// auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
// FloatGemmAcc,
// DataType,
// decltype(acc_thread_desc_k0_m_k1),
// decltype(a1_thread_desc_k0_m_k1),
// tensor_operation::element_wise::PassThrough,
// Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
// Sequence<1, 0, 2>,
// 2,
// n4>{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy
auto b1_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<B1K0, Gemm1NPerBlock, B1K1>,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
decltype(v_grid_desc_n0_o_n1),
decltype(b1_block_desc_bk0_n_bk1),
B1BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
B1BlockTransferSrcVectorDim,
2,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
1,
1,
B1ThreadTransferSrcResetCoordinateAfterRun,
true, // DstResetCoord
NumGemmKPrefetchStage>(
v_grid_desc_n0_o_n1,
make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
b1_element_op,
b1_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// auto b1_blockwise_copy =
// ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
// BElementwiseOperation,
// tensor_operation::element_wise::PassThrough,
// InMemoryDataOperationEnum::Set,
// Sequence<B1K0, Gemm1NPerBlock, B1K1>,
// B1BlockTransferThreadClusterLengths_BK0_N_BK1,
// B1BlockTransferThreadClusterArrangeOrder,
// DataType,
// DataType,
// decltype(v_grid_desc_n0_o_n1),
// decltype(b1_block_desc_bk0_n_bk1),
// B1BlockTransferSrcAccessOrder,
// Sequence<1, 0, 2>,
// B1BlockTransferSrcVectorDim,
// 2,
// B1BlockTransferSrcScalarPerVector,
// B1BlockTransferDstScalarPerVector_BK1,
// 1,
// 1,
// B1ThreadTransferSrcResetCoordinateAfterRun,
// true, // DstResetCoord
// NumGemmKPrefetchStage>(
// v_grid_desc_n0_o_n1,
// make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
// b1_element_op,
// b1_block_desc_bk0_n_bk1,
// make_multi_index(0, 0, 0),
// tensor_operation::element_wise::PassThrough{});
auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
a1_thread_desc_k0_m_k1.GetElementSpaceSize());
@@ -944,7 +1022,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
decltype(thread_slice_desc_m_n)>{};
const index_t num_gemm1_k_block_outer_loop =
b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
// Initialize C
@@ -1178,7 +1256,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
VGradGemmTile_N_O_M::GemmMPack,
true>{}; // TransposeC
auto vgrad_acc_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();
auto vgrad_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();
// HACK: for the unmerge transform, the length of the highest dim is irrelevant, so we put
// the dummy variable I1 there
@@ -1351,7 +1429,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const auto v_grid_desc_o0_n_o1 =
PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);
// A matrix blockwise copy
// dP Gemm A matrix blockwise copy
auto pgrad_gemm_tile_ygrad_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
@@ -1382,7 +1460,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
// dP Gemm B matrix blockwise copy
auto pgrad_gemm_tile_v_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
@@ -1454,6 +1532,81 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
//
// dQ
//
const auto k_grid_desc_n0_k_n1 =
QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);
auto qgrad_grid_desc_mblock_mperblock_kblock_kperblock =
QGradGemmTile_M_K_N::MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
q_grid_desc_k0_m_k1);
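// Note: the dQ gemm reuses the Y = P * V tiling, so Q's K (head) dimension is tiled by
// Gemm1NPerBlock here and K is re-read in [N0, K, N1] order (N1 = BK1) to serve as the
// B operand.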
// dQ Gemm A matrix blockwise copy
auto qgrad_gemm_tile_sgrad_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatGemmAcc,
DataType,
decltype(acc_thread_desc_k0_m_k1), // reuse desc
decltype(a1_thread_desc_k0_m_k1), // reuse desc
tensor_operation::element_wise::PassThrough,
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
Sequence<1, 0, 2>,
2,
n4>{tensor_operation::element_wise::PassThrough{}};
// dQ Gemm B matrix blockwise copy
auto qgrad_gemm_tile_k_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<B1K0, Gemm1NPerBlock, B1K1>, // reuse from V
B1BlockTransferThreadClusterLengths_BK0_N_BK1, // reuse from V
B1BlockTransferThreadClusterArrangeOrder, // reuse from V
DataType,
DataType,
decltype(k_grid_desc_n0_k_n1),
decltype(b1_block_desc_bk0_n_bk1), // reuse from V
B1BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
B1BlockTransferSrcVectorDim,
2,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
1,
1,
B1ThreadTransferSrcResetCoordinateAfterRun,
true, // DstResetCoord
NumGemmKPrefetchStage>(
k_grid_desc_n0_k_n1,
make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
b1_element_op,
b1_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto qgrad_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
DataType,
FloatGemmAcc,
decltype(a1_thread_desc_k0_m_k1),
decltype(b1_block_desc_bk0_n_bk1),
decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)),
decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)),
MPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
Gemm1NXdlPerWave,
Gemm1KPack,
true, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<DataType, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();
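// The dQ blockwise gemm reuses the Gemm1 (P * V) configuration: an MPerBlock x
// Gemm1NPerBlock output tile accumulated Gemm1KPerBlock at a time over the shared N
// dimension, with A (dS) held in VGPRs and B (K tiles) staged through LDS via
// b1_block_desc_bk0_n_bk1.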
//
// calculate y dot ygrad
//
@@ -1528,7 +1681,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_tuple(I0, I0, I0, I0),
y_dot_ygrad_thread_buf);
#if 0
#if 1
if(hipBlockIdx_x < 4 && hipThreadIdx_x % 32 < 4)
{
printf("bid %zd tid %zd, y_m0_m1_o0_o1 = %d, %d, %d, %d\n",
@@ -1547,6 +1700,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_tuple(I0, I0, I0, I0),
lse_thread_buf);
// Initialize dQ
qgrad_thread_buf.Clear();
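// Unlike dV, which is written out with atomic_add each outer iteration, dQ accumulates
// contributions from every key block in qgrad_thread_buf and is only written to global
// memory after the outer loop finishes.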
// gemm1 K loop
index_t gemm1_k_block_outer_index = 0;
do
@@ -1559,16 +1716,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
continue;
}
// gemm0
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(q_grid_desc_k0_m_k1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
q_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
k_grid_desc_k0_n_k1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
k_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
@@ -1673,7 +1830,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;
static_assert(sfc_p_m0_n0_m1_n1_m2_n2.GetNumOfAccess() == num_vgrad_gemm_loop, "");
vgrad_acc_thread_buf.Clear();
vgrad_thread_buf.Clear();
// TODO ANT: single buffer prefetch pipeline
static_for<0, num_vgrad_gemm_loop, 1>{}([&](auto vgrad_gemm_loop_idx) { // gemm dV
@@ -1736,7 +1893,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
#endif
block_sync_lds(); // sync before read
vgrad_blockwise_gemm.Run(p_block_buf, ygrad_block_buf, vgrad_acc_thread_buf);
vgrad_blockwise_gemm.Run(p_block_buf, ygrad_block_buf, vgrad_thread_buf);
#if 0
if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
@@ -1745,10 +1902,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
gemm1_k_block_outer_index,
vgrad_gemm_loop_idx.value,
hipThreadIdx_x,
vgrad_acc_thread_buf[I0],
vgrad_acc_thread_buf[I1],
vgrad_acc_thread_buf[I2],
vgrad_acc_thread_buf[I3]);
vgrad_thread_buf[I0],
vgrad_thread_buf[I1],
vgrad_thread_buf[I2],
vgrad_thread_buf[I3]);
}
#endif
}); // end gemm dV
@@ -1756,7 +1913,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// atomic_add dV
vgrad_thread_copy_vgpr_to_global.Run(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
vgrad_acc_thread_buf,
vgrad_thread_buf,
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
vgrad_grid_buf);
@@ -1779,10 +1936,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
pgrad_blockwise_gemm,
pgrad_thread_buf,
num_o_block_main_loop);
#if 0
#if 1
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("j loop idx %d, tid %zd, dP[0:3] = %f, %f, %f, %f\n",
printf("outer j loop idx %d, tid %zd, dP[0:3] = %f, %f, %f, %f\n",
gemm1_k_block_outer_index,
hipThreadIdx_x,
pgrad_thread_buf[I0],
@@ -1806,13 +1963,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I1];
// dS and P have the same thread buf layout
sgrad_thread_buf(i) =
acc_thread_buf[i] * (pgrad_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}]);
acc_thread_buf[i] * (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
});
#if 0
#if 1
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("j loop idx %d, tid %zd, dS[0:3] = %f, %f, %f, %f\n",
printf("outer j loop idx %d, tid %zd, dS[0:3] = %f, %f, %f, %f\n",
gemm1_k_block_outer_index,
hipThreadIdx_x,
sgrad_thread_buf[I0],
......@@ -1822,10 +1979,90 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
#endif
// gemm dQ
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy the pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But that is impossible to implement given that
// the A1 source buffer is a static buffer holding the output of the first GEMM and
// requires a constexpr offset by design. Therefore, we pass the tensor coordinate
// offset explicitly in Run() below.
// preload data into LDS
qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
b1_block_slice_copy_step);
block_sync_lds(); // wait for previous LDS read
qgrad_gemm_tile_k_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
#if 0
if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("lds dQ gemm K matrix\n");
if(hipBlockIdx_x == 0)
{
debug::print_shared(b1_block_buf.p_data_,
(index_t)b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
}
#endif
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
qgrad_gemm_tile_sgrad_blockwise_copy.Run(
acc_thread_desc_k0_m_k1,
make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
sgrad_thread_buf,
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
#if 0
if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("inner j loop idx %d, tid %zd, dS downcast[0:3] = %f, %f, %f, %f\n",
i.value,
hipThreadIdx_x,
(float)a1_thread_buf[I0],
(float)a1_thread_buf[I1],
(float)a1_thread_buf[I2],
(float)a1_thread_buf[I3]);
}
#endif
qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
block_sync_lds();
qgrad_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, qgrad_thread_buf);
block_sync_lds();
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
b1_block_slice_copy_step);
qgrad_gemm_tile_k_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
});
}
// tail
{
qgrad_gemm_tile_sgrad_blockwise_copy.Run(
acc_thread_desc_k0_m_k1,
make_tuple(
Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0),
sgrad_thread_buf,
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
block_sync_lds();
qgrad_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, qgrad_thread_buf);
}
} // end gemm dQ
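// The dQ gemm above follows the same single-LDS-buffer pipeline as Gemm1: preload one K
// tile into LDS; each inner step copies a dS slice out of the acc VGPRs, issues the
// global read of the next K tile, runs the blockwise gemm on the tile already in LDS,
// then writes the prefetched tile; the tail handles the final slice without another
// prefetch.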
// move slice window
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
a_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_k0_m_k1,
a_block_reset_copy_step); // rewind K
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1,
b_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_k0_n_k1,
b_block_reset_copy_step); // rewind K and step N
ygrad_blockwise_copy.MoveSrcSliceWindow(ygrad_grid_desc_m0_o_m1,
ygrad_block_reset_copy_step); // rewind M
@@ -1841,7 +2078,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// TODO ANT:
// shuffle dQ and write
#if 0
#if 1
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("tid %zd, dQ[0:3] = %f, %f, %f, %f\n",
hipThreadIdx_x,
qgrad_thread_buf[I0],
qgrad_thread_buf[I1],
qgrad_thread_buf[I2],
qgrad_thread_buf[I3]);
}
#endif
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
@@ -1968,7 +2217,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatCShuffle, // typename SrcData,
DataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(qgrad_grid_desc_mblock_mperblock_kblock_kperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
@@ -1976,7 +2225,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
qgrad_grid_desc_mblock_mperblock_kblock_kperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
@@ -2013,7 +2262,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
qgrad_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
c_shuffle_block_buf);
@@ -2024,20 +2273,26 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
qgrad_grid_desc_mblock_mperblock_kblock_kperblock,
qgrad_grid_buf);
#if 0
if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("lds dQ shuffle loop %d\n", access_id.value);
if(hipBlockIdx_x == 1)
{
debug::print_shared(c_shuffle_block_buf.p_data_,
(index_t)c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
}
#endif
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
qgrad_grid_desc_mblock_mperblock_kblock_kperblock, c_global_step);
}
});
}
#endif
}
};