Commit 15f1d4ad authored by Anthony Chang

compute y dot dy

parent a3e487ca
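This change adds, inside the attention backward kernel, the per-row reduction D_m = sum_o Y[m, o] * dY[m, o], which the softmax backward step later consumes as the dY_i dot Y_i term in dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i). A minimal host-side reference for the quantity being computed, assuming flat row-major [M, O] buffers (the function name and layout are illustrative, not part of the commit):

#include <cstddef>
#include <vector>

// Reference only: D[m] = sum over o of Y[m, o] * dY[m, o]; row-major [M, O] buffers assumed.
std::vector<float> ComputeYDotYGrad(const std::vector<float>& y,
                                    const std::vector<float>& ygrad,
                                    std::size_t M,
                                    std::size_t O)
{
    std::vector<float> d(M, 0.0f);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t o = 0; o < O; ++o)
            d[m] += y[m * O + o] * ygrad[m * O + o];
    return d;
}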
@@ -369,7 +369,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, n] = m
// ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
}
// calculate y & log-sum-exp beforehand
@@ -35,7 +35,7 @@ template <typename GridwiseGemm,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename LSEGridDescriptor_M,
typename VGradGridDescriptor_N_O,
typename YGradGridDesc_M0_O_M1,
@@ -65,7 +65,7 @@ __global__ void
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const LSEGridDescriptor_M lse_grid_desc_m,
// const QGradGridDescriptor_M_K qgrad_grid_desc_m_k, // TODO ANT: add dQ/dK args
@@ -329,6 +329,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
// TODO ANT: This will be much easier when inputs are Gs, Ms, Ns, Os, so there's no need to
// extract the subsequences and shuffle them.
const index_t num_dims = NumDimG + NumDimN + NumDimO;
// 0, 1, .. NumDimG - 1
@@ -372,31 +374,21 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
Sequence<padder.PadN, padder.PadO>{});
}
template <typename YGridDesc_M_O, typename Number>
static auto MakeYGradGridDescriptor_M0_O_M1(const YGridDesc_M_O& ygrad_grid_desc_m_o,
const Number& M1)
template <typename YGridDesc_M_O>
static auto MakeYGradGridDescriptor_M0_O_M1(const YGridDesc_M_O& ygrad_grid_desc_m_o)
{
const auto M = ygrad_grid_desc_m_o.GetLength(I0);
const auto O = ygrad_grid_desc_m_o.GetLength(I1);
const auto M0 = M / M1;
const auto Y_M0 = M / Y_M1;
return transform_tensor_descriptor(
ygrad_grid_desc_m_o,
make_tuple(make_unmerge_transform(make_tuple(M0, M1)), make_pass_through_transform(O)),
make_tuple(make_unmerge_transform(make_tuple(Y_M0, Y_M1)), make_pass_through_transform(O)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
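The transform above only reshapes the [M, O] view of dY: the row index is unmerged as m = m0 * Y_M1 + m1 and o passes through, so element (m, o) lands at (m / Y_M1, o, m % Y_M1) in the resulting [M0, O, M1] descriptor. A small index-level sketch, with Y_M1 = 2 assumed purely for illustration:

// Index-level sketch of the unmerge above; Y_M1 = 2 is an assumed example value.
struct YGradIdx_M0_O_M1 { int m0, o, m1; };

inline YGradIdx_M0_O_M1 to_m0_o_m1(int m, int o, int y_m1 = 2)
{
    return {m / y_m1, o, m % y_m1}; // e.g. (m = 5, o = 7) -> (m0 = 2, o = 7, m1 = 1)
}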
// can we construct YGrad_m0_o_m1 from Y_m_o?
// static auto MakeYGradGridDescriptor_M0_O_M1(const std::vector<index_t>&
// y_gs_ms_os_lengths_vec,
// const std::vector<index_t>&
// y_gs_ms_os_strides_vec)
// {
// }
//
// dP = dY * V^T
//
@@ -410,6 +402,61 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
Number<Y_O1>{});
}
// V in Gemm B position
static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths_vec,
const std::vector<index_t>& v_gs_os_ns_strides_vec)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
// TODO ANT: This will be much easier when inputs are Gs, Ms, Ns, Os, so there's no need to
// extract the subsequences and shuffle them.
const index_t num_dims = NumDimG + NumDimN + NumDimO;
// 0, 1, .. NumDimG - 1
std::vector<index_t> gs_ids(NumDimG);
std::iota(gs_ids.begin(), gs_ids.end(), 0);
// NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1
std::vector<index_t> os_ids(NumDimO);
std::iota(os_ids.begin(), os_ids.end(), NumDimG);
// NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1
std::vector<index_t> ns_ids(NumDimN);
std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO);
std::vector<index_t> ids_old2new;
ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims);
for(int i = 0; i < num_dims; i++)
{
index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new];
}
const auto v_grid_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
.second;
const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw,
make_tuple(NPerBlock, Gemm1NPerBlock),
Sequence<padder.PadN, padder.PadO>{});
// N_O to O0_N_O1; to refactor
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
}
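MakeVGridDescriptor_O0_N_O1 (and MakeVGradGridDescriptor_N_O above, whose comment is identical) does the same index bookkeeping: the incoming lengths/strides arrive in [Gs..., Os..., Ns...] order and are re-read in [Gs..., Ns..., Os...] order. Note that ids_old2new is indexed by the new position and stores the old position to read from. A worked example with assumed ranks, for illustration only:

// Assumed ranks: NumDimG = 2, NumDimO = 1, NumDimN = 1
//   gs_ids      = {0, 1}
//   os_ids      = {2}           // old positions of the O dims
//   ns_ids      = {3}           // old positions of the N dims
//   ids_old2new = {0, 1, 3, 2}  // new position i reads old position ids_old2new[i]
// With old lengths {G0, G1, O, N}, v_gs_ns_os_lengths_vec becomes {G0, G1, N, O}.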
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
// static auto MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock()
//
// dQ = alpha * dS * K
//
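The dY_i dot Y_i term above is the reason this commit computes y dot dy at all. Assuming the usual forward definitions for this kernel family (S = Q * K^T, P = softmax(S), Y = P * V), the softmax-backward rowsum collapses onto the output row:

// Softmax backward:  dS_i_j = P_i_j .* (dP_i_j - sum_j' P_i_j' * dP_i_j')
// With dP = dY * V^T (PGrad gemm) and Y = P * V:
//   sum_j P_i_j * dP_i_j = sum_j P_i_j * sum_o dY_i_o * V_j_o
//                        = sum_o dY_i_o * (sum_j P_i_j * V_j_o)
//                        = sum_o dY_i_o * Y_i_o
//                        = dY_i dot Y_i
// so only the per-row values D_i = dY_i dot Y_i are needed, not a full P .* dP rowsum.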
@@ -460,7 +507,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
@@ -468,8 +515,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using VGradGridDesc_N_O = decltype(MakeVGradGridDescriptor_N_O({}, {}));
using YGradGridDesc_M0_O_M1 =
decltype(MakeYGradGridDescriptor_M0_O_M1(CGridDesc_M_N{}, Number<Y_M1>{}));
using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
constexpr static auto make_MaskOutPredicate()
{
@@ -547,7 +593,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
B1GridDesc_BK0_N_BK1,
CGridDesc_M_N,
YGridDesc_M_O,
LSEGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
@@ -647,7 +693,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
// dV = P^T * dY
@@ -655,7 +701,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
/* PTrans descriptor will be constructed in kernel */
ygrad_grid_desc_m0_o_m1_{
DeviceOp::MakeYGradGridDescriptor_M0_O_M1(c_grid_desc_m_n_, Number<Y_M1>{})},
DeviceOp::MakeYGradGridDescriptor_M0_O_M1(y_grid_desc_m_o_)},
// batch offsets
a_grid_desc_g_m_k_{
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
@@ -665,8 +711,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
y_grid_desc_mblock_mperblock_oblock_operblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(y_grid_desc_m_o_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
@@ -704,12 +750,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
b1_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
y_grid_desc_m_o_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
y_grid_desc_mblock_mperblock_oblock_operblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
y_grid_desc_m_o_);
}
Print();
}
@@ -754,14 +800,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
YGridDesc_M_O y_grid_desc_m_o_;
LSEGridDesc_M lse_grid_desc_m_;
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_;
VGradGridDesc_N_O vgrad_grid_desc_n_o_;
YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1_;
@@ -803,7 +849,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_) * arg.batch_count_;
std::cout << "grid size = " << grid_size << '\n';
// Gemm0_K
const auto K =
@@ -824,7 +870,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M,
DeviceOp::VGradGridDesc_N_O,
DeviceOp::YGradGridDesc_M0_O_M1,
@@ -855,7 +901,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_,
arg.vgrad_grid_desc_n_o_,
arg.ygrad_grid_desc_m0_o_m1_,
@@ -909,8 +955,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
@@ -960,7 +1006,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.y_grid_desc_m_o_,
arg.block_2_ctile_map_);
}
@@ -211,6 +211,56 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
using VGradGemmTile_N_O_M = VGradGemmTile_N_O_M_<>; // tune later
// PGrad Gemm
struct PGradGemmTile_M_N_O_
{
};
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType);
static constexpr auto ThreadClusterLength_O = Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVector>{};
static constexpr auto ThreadSliceLength_M =
Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
// static_assert(BlockSliceLength_O_ % SrcScalarPerVector == 0, "");
// static_assert(BlockSize_ % ThreadClusterLength_O == 0, "");
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
using BlockwiseSumReduce = PartitionedBlockwiseReduction<
FloatGemmAcc,
BlockSize_,
Sequence<ThreadClusterLength_M, ThreadClusterLength_O>,
Sequence<0, 1>,
reduce::Add,
false>; // propagateNaN
// using ThreadReduceSrcDesc_M_O = decltype(make_naive_tensor_descriptor_packed(
// make_tuple(ThreadSliceLength_M, ThreadSliceLength_O)));
// using ThreadReduceDstDesc_M =
// decltype(make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceLength_M)));
// using ThreadwiseSumReduce =
// ThreadwiseReduction<FloatGemmAcc,
// ThreadReduceSrcDesc_M_O,
// ThreadReduceDstDesc_M,
// reduce::Add,
// false>; // propagateNaN
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
ThreadSliceLength_M * ThreadSliceLength_O,
true>;
using DstBufType =
StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
};
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
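Plugging one concrete (assumed) instantiation into YDotYGrad_M_O_ helps sanity-check the static_asserts; the numbers below take BlockSize = 256, MPerBlock = 128, Gemm1NPerBlock = 128 and a 2-byte DataType purely for illustration:

// Assumed example: BlockSize = 256, MPerBlock = 128, Gemm1NPerBlock = 128, sizeof(DataType) = 2
//   SrcScalarPerVector    = 16 / 2         = 8
//   ThreadClusterLength_O = 128 / 8        = 16
//   ThreadClusterLength_M = 256 / 16       = 16
//   ThreadSliceLength_O   = 8
//   ThreadSliceLength_M   = 128 * 16 / 256 = 8
// Checks: 16 * 8 == 128 == BlockSliceLength_O_ and 16 * 8 == 128 == BlockSliceLength_M_,
// i.e. each thread loads an 8 x 8 tile of y and of dy and keeps 8 per-row partial sums.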
// QGrad Gemm
// KGrad Gemm
@@ -402,14 +452,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const auto MBlock = M / MPerBlock;
const auto NBlock = N / Gemm1NPerBlock;
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
const auto y_grid_desc_mblock_mperblock_oblock_operblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
return y_grid_desc_mblock_mperblock_oblock_operblock;
}
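The renamed descriptor simply block-tiles the [M, O] output. With assumed tile sizes MPerBlock = 128 and Gemm1NPerBlock = 128 (illustration only):

// (m, o) -> (mblock, mperblock, oblock, operblock) = (m / 128, m % 128, o / 128, o % 128)
// The y-dot-dy loop below walks the oblock axis (the I2 length) while staying on its own mblock.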
__host__ __device__ static constexpr auto
@@ -437,7 +487,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_grid_desc_m_n);
}
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
using YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
@@ -497,7 +547,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
__device__ static void Run(const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid,
const DataType* __restrict__ p_b1_grid,
const DataType* __restrict__ p_c_grid,
const DataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid,
@@ -512,8 +562,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
const LSEGridDesc_M& lse_grid_desc_m,
const VGradGridDescriptor_N_O& vgrad_grid_desc_n_o,
const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
@@ -526,8 +576,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -535,14 +585,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, vgrad_grid_desc_n_o.GetElementSpaceSize());
// divide block work by [M, N]
// divide block work by [M, O]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
make_tuple(y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I0),
y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2))))
{
return;
}
@@ -1217,7 +1267,175 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_multi_index(NPerBlock))[I0]);
}
#endif
constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;
//
// dP
//
constexpr auto y_thread_desc_m0_m1_o0_o1 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
YDotYGrad_M_O::ThreadSliceLength_M,
I1,
YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto y_thread_cluster_desc =
make_cluster_descriptor(Sequence<I1,
YDotYGrad_M_O::ThreadClusterLength_M,
I1,
YDotYGrad_M_O::ThreadClusterLength_O>{},
Sequence<0, 1, 2, 3>{});
const auto y_thread_cluster_idx =
y_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto y_thread_data_on_block_idx =
y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
const auto y_thread_data_on_grid_idx =
make_multi_index(
block_work_idx[I0], I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
y_thread_data_on_block_idx;
// performs double duty for both y and ygrad
auto yygrad_threadwise_copy =
ThreadwiseTensorSliceTransfer_v2<DataType,
DataType,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
Sequence<0, 1, 2, 3>,
3, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */,
true /* InvalidElementAsNaN */>(
y_grid_desc_mblock_mperblock_oblock_operblock, y_thread_data_on_grid_idx);
auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{};
auto y_dot_ygrad_block_accum_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatGemmAcc*>(p_shared), MPerBlock);
constexpr auto y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl =
make_naive_tensor_descriptor(make_tuple(I1, P_M0, P_M1, P_M2),
make_tuple(P_M0 * P_M1 * P_M2, P_M1 * P_M2, P_M2, I1));
// y_dot_ygrad thread buffer for calculating sgrad; reuse LSE thread descriptor because
// per-thread LSE data and y_dot_ygrad are tiled the same way
constexpr auto y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl =
lse_thread_desc_mblock_mrepeat_mwave_mperxdl;
auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatGemmAcc,
FloatGemmAcc,
decltype(y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl),
decltype(y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl),
Sequence<1, m0, m1, m2>,
Sequence<0, 1, 2, 3>,
3,
m2,
1,
false>{y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(I0, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl
// clear accum buffers
y_dot_ygrad_thread_accum_buf.Clear();
y_dot_ygrad_block_accum_buf.Clear();
#if 0
if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("lds before accum\n");
if(hipBlockIdx_x == 0)
{
debug::print_shared(y_dot_ygrad_block_accum_buf.p_data_, MPerBlock);
}
#endif
auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
//
// calculate y dot ygrad
//
index_t oblock_idx = 0;
do
{
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
y_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
y_thread_buf);
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
ygrad_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
ygrad_thread_buf);
static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
static_for<0, YDotYGrad_M_O::ThreadSliceLength_O, 1>{}([&](auto iO) {
constexpr auto offset =
y_thread_desc_m0_m1_o0_o1.CalculateOffset(make_multi_index(I0, iM, I0, iO));
y_dot_ygrad_thread_accum_buf(iM) +=
y_thread_buf[Number<offset>{}] * ygrad_thread_buf[Number<offset>{}];
});
});
#if 0
if (hipThreadIdx_x % 32 < 4 && hipBlockIdx_x == 0)
{
printf("bid %zd tid %zd, oblock_idx %d, y_thread_buf[0:3] = %f %f %f %f, ygrad_thread_buf[0:3] = %f %f %f %f\n",
hipBlockIdx_x,
hipThreadIdx_x,
oblock_idx,
(float)y_thread_buf[I0],
(float)y_thread_buf[I1],
(float)y_thread_buf[I2],
(float)y_thread_buf[I3],
(float)ygrad_thread_buf[I0],
(float)ygrad_thread_buf[I1],
(float)ygrad_thread_buf[I2],
(float)ygrad_thread_buf[I3]);
}
#endif
yygrad_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
make_multi_index(0, 0, 1, 0));
oblock_idx++;
} while(oblock_idx < y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2));
// blockwise reduction using atomic_add
block_sync_lds();
static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM;
y_dot_ygrad_block_accum_buf.AtomicAdd(idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM]);
});
block_sync_lds();
#if 1
if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("lds after accum\n");
if(hipBlockIdx_x == 0)
{
debug::print_shared(y_dot_ygrad_block_accum_buf.p_data_, MPerBlock);
}
#endif
// distribute to threads
y_dot_ygrad_thread_copy_lds_to_vgpr.Run(
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
y_dot_ygrad_block_accum_buf,
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl,
make_tuple(I0, I0, I0, I0),
y_dot_ygrad_thread_buf);
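Taken together, the y-dot-dy computation runs in three stages: per-thread partial sums accumulated while the copy window walks every O-block, an AtomicAdd combine into one LDS accumulator per row of the M-tile, and a redistribution back to VGPRs in the same [mblock, mrepeat, mwave, mperxdl] layout as the LSE values so sgrad can consume both identically. The stand-alone HIP sketch below shows the same pattern with plain indexing instead of CK descriptors; the kernel name, the blockDim = (T, MPB) launch shape, and divisibility of M and O by the tile sizes are all assumptions:

#include <hip/hip_runtime.h>

// Illustrative stand-alone sketch of the block-level y dot dy reduction; the real
// kernel distributes work through CK thread-cluster descriptors instead.
template <int MPB> // rows per block; assumed launch shape: blockDim = (T, MPB)
__global__ void y_dot_ygrad_sketch(const float* __restrict__ y,     // [M, O] row-major
                                   const float* __restrict__ ygrad, // [M, O] row-major
                                   float* __restrict__ d,           // [M]
                                   int O)
{
    __shared__ float lds_accum[MPB];
    if(threadIdx.x == 0)
        lds_accum[threadIdx.y] = 0.0f; // cf. y_dot_ygrad_block_accum_buf.Clear()
    __syncthreads();

    const int m = blockIdx.x * MPB + threadIdx.y;
    float partial = 0.0f;
    for(int o = threadIdx.x; o < O; o += blockDim.x) // cf. the do/while over O-blocks
        partial += y[m * O + o] * ygrad[m * O + o];

    atomicAdd(&lds_accum[threadIdx.y], partial); // cf. AtomicAdd into the LDS accumulator
    __syncthreads();

    if(threadIdx.x == 0)
        d[m] = lds_accum[threadIdx.y]; // cf. the LDS-to-VGPR redistribution / write-out
}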
#if 0
if(hipBlockIdx_x < 4 && hipThreadIdx_x % 32 < 4)
{
printf("bid %zd tid %zd, y_m0_m1_o0_o1 = %d, %d, %d, %d\n",
hipBlockIdx_x,
hipThreadIdx_x,
y_thread_data_on_grid_idx[I0],
y_thread_data_on_grid_idx[I1],
y_thread_data_on_grid_idx[I2],
y_thread_data_on_grid_idx[I3]);
}
#endif
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
lse_grid_buf,
@@ -1348,6 +1566,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
SubThreadBlock<BlockSize> p_thread_copy_subgroup(blockwise_gemm.GetWaveIdx()[I0],
blockwise_gemm.GetWaveIdx()[I1]);
constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;
static_assert(sfc_p_m0_n0_m1_n1_m2_n2.GetNumOfAccess() == num_vgrad_gemm_loop, "");
vgrad_acc_thread_buf.Clear();
@@ -1450,7 +1669,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// TODO ANT:
// shuffle dQ and write
if constexpr(false)
#if 0
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
@@ -1646,6 +1865,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
});
}
#endif
}
};
@@ -143,6 +143,16 @@ struct DynamicBuffer
}
}
__host__ __device__ void Clear()
{
static_assert(GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong! only local data share is supported");
for(index_t i = get_thread_local_1d_id(); i < element_space_size_; i += get_block_size())
{
Set(i, true, T{0});
}
}
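Clear() is a cooperative, block-wide operation: every thread strides through the LDS window, so callers must synchronize before the cleared buffer is read or atomically updated. The fragment below mirrors the call pattern of the gridwise kernel in this commit (illustrative fragment only; the names are those used by that kernel):

// Illustrative call pattern (see the gridwise kernel above):
auto accum_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
    static_cast<FloatGemmAcc*>(p_shared), MPerBlock);
accum_buf.Clear();                           // all threads cooperatively zero the LDS window
block_sync_lds();                            // the zeros must be visible before accumulation
accum_buf.AtomicAdd(row_idx, true, partial); // per-thread partial sums accumulate per row
block_sync_lds();                            // then the reduced values can be read back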
template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
@@ -302,7 +312,9 @@ struct DynamicBuffer
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");
static_assert(GetAddressSpace() == AddressSpaceEnum::Global ||
GetAddressSpace() == AddressSpaceEnum::Lds,
"only support global mem or local data share");
#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
@@ -319,7 +331,7 @@ struct DynamicBuffer
bool constexpr use_amd_buffer_addressing = false;
#endif
if constexpr(use_amd_buffer_addressing)
if constexpr(use_amd_buffer_addressing && GetAddressSpace() == AddressSpaceEnum::Global)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;