Commit 15f1d4ad authored by Anthony Chang

compute y dot dy

parent a3e487ca
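This commit adds the row-wise dot product between the forward output Y and its incoming gradient dY to the attention backward kernel. For every output row i the kernel needs D_i = sum_o Y[i, o] * dY[i, o], which later enters the softmax backward step dS[i, j] = P[i, j] * (dP[i, j] - D_i) noted in the comments below. The host-side sketch here only illustrates that quantity; names, layout and types are assumptions, not code from this repository.

#include <cstddef>
#include <vector>

// Reference row-wise reduction D[i] = sum_o y[i, o] * dy[i, o] over row-major [M x O] buffers.
// Hypothetical helper for illustration; the kernel computes the same per-row sum tile by tile,
// accumulating thread-private partials over O tiles and combining them through LDS atomic adds.
std::vector<float> RowwiseYDotYGrad(const std::vector<float>& y,
                                    const std::vector<float>& dy,
                                    std::size_t M,
                                    std::size_t O)
{
    std::vector<float> d(M, 0.0f);
    for(std::size_t i = 0; i < M; ++i)
    {
        for(std::size_t o = 0; o < O; ++o)
        {
            d[i] += y[i * O + o] * dy[i * O + o];
        }
    }
    return d;
}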
@@ -369,7 +369,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, n] = m
+// ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
}
// calculate y & log-sum-exp beforehand
...
@@ -35,7 +35,7 @@ template <typename GridwiseGemm,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1,
-typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename LSEGridDescriptor_M,
typename VGradGridDescriptor_N_O,
typename YGradGridDesc_M0_O_M1,
@@ -65,7 +65,7 @@ __global__ void
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
-const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const LSEGridDescriptor_M lse_grid_desc_m,
// const QGradGridDescriptor_M_K qgrad_grid_desc_m_k, // TODO ANT: add dQ/dK args
@@ -329,6 +329,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
// TODO ANT: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// extract subsequence and shuffle them.
const index_t num_dims = NumDimG + NumDimN + NumDimO;
// 0, 1, .. NumDimG - 1
@@ -372,31 +374,21 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
Sequence<padder.PadN, padder.PadO>{});
}
-template <typename YGridDesc_M_O, typename Number>
-static auto MakeYGradGridDescriptor_M0_O_M1(const YGridDesc_M_O& ygrad_grid_desc_m_o,
-                                            const Number& M1)
+template <typename YGridDesc_M_O>
+static auto MakeYGradGridDescriptor_M0_O_M1(const YGridDesc_M_O& ygrad_grid_desc_m_o)
{
const auto M = ygrad_grid_desc_m_o.GetLength(I0);
const auto O = ygrad_grid_desc_m_o.GetLength(I1);
-const auto M0 = M / M1;
+const auto Y_M0 = M / Y_M1;
return transform_tensor_descriptor(
ygrad_grid_desc_m_o,
-make_tuple(make_unmerge_transform(make_tuple(M0, M1)), make_pass_through_transform(O)),
+make_tuple(make_unmerge_transform(make_tuple(Y_M0, Y_M1)), make_pass_through_transform(O)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
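// Note: with this change the split of dY's M dimension is the compile-time tile constant Y_M1
// instead of a runtime argument; the (M, O) descriptor is unmerged into (M0 = M / Y_M1, O, Y_M1),
// which is the M0_O_M1 layout the gridwise kernel consumes for its dY loads.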
-// can we construct YGrad_m0_o_m1 from Y_m_o?
-// static auto MakeYGradGridDescriptor_M0_O_M1(const std::vector<index_t>&
-// y_gs_ms_os_lengths_vec,
-// const std::vector<index_t>&
-// y_gs_ms_os_strides_vec)
-// {
-// }
//
// dP = dY * V^T
//
@@ -410,6 +402,61 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
Number<Y_O1>{});
}
// V in Gemm B position
static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths_vec,
const std::vector<index_t>& v_gs_os_ns_strides_vec)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
// TODO ANT: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// extract subsequence and shuffle them.
const index_t num_dims = NumDimG + NumDimN + NumDimO;
// 0, 1, .. NumDimG - 1
std::vector<index_t> gs_ids(NumDimG);
std::iota(gs_ids.begin(), gs_ids.end(), 0);
// NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1
std::vector<index_t> os_ids(NumDimO);
std::iota(os_ids.begin(), os_ids.end(), NumDimG);
// NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1
std::vector<index_t> ns_ids(NumDimN);
std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO);
std::vector<index_t> ids_old2new;
ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims);
for(int i = 0; i < num_dims; i++)
{
index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new];
}
const auto v_grid_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
.second;
const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw,
make_tuple(NPerBlock, Gemm1NPerBlock),
Sequence<padder.PadN, padder.PadO>{});
// N_O to O0_N_O1; to refactor
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
}
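// Illustration (assumed rank, not from this repo): with NumDimG = 2, NumDimO = 1, NumDimN = 1 the
// input order is [g0, g1, o, n], ids_old2new becomes {0, 1, 3, 2}, and the copy loop above
// produces lengths/strides in [g0, g1, n, o] order before the padded N_O descriptor is built and
// re-split into O0_N_O1 for the Gemm B position.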
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
// static auto MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock()
//
// dQ = alpha * dS * K
//
@@ -460,7 +507,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
-using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
+using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
@@ -468,8 +515,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using VGradGridDesc_N_O = decltype(MakeVGradGridDescriptor_N_O({}, {}));
-using YGradGridDesc_M0_O_M1 =
-decltype(MakeYGradGridDescriptor_M0_O_M1(CGridDesc_M_N{}, Number<Y_M1>{}));
+using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
constexpr static auto make_MaskOutPredicate()
{
@@ -547,7 +593,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
B1GridDesc_BK0_N_BK1,
-CGridDesc_M_N,
+YGridDesc_M_O,
LSEGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
@@ -647,7 +693,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
-c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
+y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
// dV = P^T * dY
@@ -655,7 +701,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
/* PTrans descriptor will be constructed in kernel */
ygrad_grid_desc_m0_o_m1_{
-DeviceOp::MakeYGradGridDescriptor_M0_O_M1(c_grid_desc_m_n_, Number<Y_M1>{})},
+DeviceOp::MakeYGradGridDescriptor_M0_O_M1(y_grid_desc_m_o_)},
// batch offsets
a_grid_desc_g_m_k_{
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
@@ -665,8 +711,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
-c_grid_desc_mblock_mperblock_nblock_nperblock_{},
+y_grid_desc_mblock_mperblock_oblock_operblock_{},
-block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
+block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(y_grid_desc_m_o_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
@@ -704,12 +750,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
b1_grid_desc_bk0_n_bk1_,
-c_grid_desc_m_n_,
+y_grid_desc_m_o_,
block_2_ctile_map_))
{
-c_grid_desc_mblock_mperblock_nblock_nperblock_ =
+y_grid_desc_mblock_mperblock_oblock_operblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-c_grid_desc_m_n_);
+y_grid_desc_m_o_);
}
Print();
}
@@ -754,14 +800,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
-CGridDesc_M_N c_grid_desc_m_n_;
+YGridDesc_M_O y_grid_desc_m_o_;
LSEGridDesc_M lse_grid_desc_m_;
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
-typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-c_grid_desc_mblock_mperblock_nblock_nperblock_;
+typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
+y_grid_desc_mblock_mperblock_oblock_operblock_;
VGradGridDesc_N_O vgrad_grid_desc_n_o_;
YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1_;
@@ -803,7 +849,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
const index_t grid_size =
-arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
+arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_) * arg.batch_count_;
std::cout << "grid size = " << grid_size << '\n';
// Gemm0_K
const auto K =
@@ -824,7 +870,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::B1GridDesc_BK0_N_BK1,
-typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M,
DeviceOp::VGradGridDesc_N_O,
DeviceOp::YGradGridDesc_M0_O_M1,
@@ -855,7 +901,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
-arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_,
arg.vgrad_grid_desc_n_o_,
arg.ygrad_grid_desc_m0_o_m1_,
@@ -909,8 +955,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
-const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
-const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
+const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
+const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
@@ -960,7 +1006,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
-arg.c_grid_desc_m_n_,
+arg.y_grid_desc_m_o_,
arg.block_2_ctile_map_);
}
...
@@ -211,6 +211,56 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
using VGradGemmTile_N_O_M = VGradGemmTile_N_O_M_<>; // tune later
// PGrad Gemm
struct PGradGemmTile_M_N_O_
{
};
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static constexpr index_t SrcScalarPerVetor = 16 / sizeof(DataType);
static constexpr auto ThreadClusterLength_O = Number<BlockSliceLength_O_ / SrcScalarPerVetor>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVetor>{};
static constexpr auto ThreadSliceLength_M =
Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
// static_assert(BlockSliceLength_O_ % SrcScalarPerVetor == 0, "");
// static_assert(BlockSize_ % ThreadClusterLength_O == 0, "");
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
using BlockwiseSumReduce = PartitionedBlockwiseReduction<
FloatGemmAcc,
BlockSize_,
Sequence<ThreadClusterLength_M, ThreadClusterLength_O>,
Sequence<0, 1>,
reduce::Add,
false>; // propagateNaN
// using ThreadReduceSrcDesc_M_O = decltype(make_naive_tensor_descriptor_packed(
// make_tuple(ThreadSliceLength_M, ThreadSliceLength_O)));
// using ThreadReduceDstDesc_M =
// decltype(make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceLength_M)));
// using ThreadwiseSumReduce =
// ThreadwiseReduction<FloatGemmAcc,
// ThreadReduceSrcDesc_M_O,
// ThreadReduceDstDesc_M,
// reduce::Add,
// false>; // propagateNaN
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
ThreadSliceLength_M * ThreadSliceLength_O,
true>;
using DstBufType =
StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
};
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
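// Example (illustrative numbers only, not a tuned configuration): with DataType = half_t,
// BlockSize = 256, MPerBlock = 128 and Gemm1NPerBlock = 128, SrcScalarPerVetor = 16 / 2 = 8,
// ThreadClusterLength_O = 128 / 8 = 16, ThreadClusterLength_M = 256 / 16 = 16, and both thread
// slice lengths come out to 8, so each thread reduces a private 8x8 patch of the 128x128 Y/dY
// block; the static_asserts above confirm the cluster exactly tiles the block.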
// QGrad Gemm
// KGrad Gemm
@@ -402,14 +452,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const auto MBlock = M / MPerBlock;
const auto NBlock = N / Gemm1NPerBlock;
-const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
+const auto y_grid_desc_mblock_mperblock_oblock_operblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
-return c_grid_desc_mblock_mperblock_nblock_nperblock;
+return y_grid_desc_mblock_mperblock_oblock_operblock;
}
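// e.g. (illustrative sizes) a padded Y of M = 512, O = 256 with MPerBlock = 128 and
// Gemm1NPerBlock = 128 becomes a (MBlock, MPerBlock, OBlock, OPerBlock) = (4, 128, 2, 128) view,
// which is the descriptor the per-block Y/dY copies below index into.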
__host__ __device__ static constexpr auto
@@ -437,7 +487,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_grid_desc_m_n);
}
-using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
+using YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
@@ -497,7 +547,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
__device__ static void Run(const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid,
const DataType* __restrict__ p_b1_grid,
-const DataType* __restrict__ p_c_grid,
+const DataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid,
@@ -512,8 +562,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
-const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
-c_grid_desc_mblock_mperblock_nblock_nperblock,
+const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
+y_grid_desc_mblock_mperblock_oblock_operblock,
const LSEGridDesc_M& lse_grid_desc_m,
const VGradGridDescriptor_N_O& vgrad_grid_desc_n_o,
const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
@@ -526,8 +576,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
-const auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -535,14 +585,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, vgrad_grid_desc_n_o.GetElementSpaceSize());
-// divide block work by [M, N]
+// divide block work by [M, O]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
-make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
-c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
+make_tuple(y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I0),
+y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2))))
{
return;
}
@@ -1217,7 +1267,175 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_multi_index(NPerBlock))[I0]);
}
#endif
-constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;
//
// dP
//
constexpr auto y_thread_desc_m0_m1_o0_o1 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
YDotYGrad_M_O::ThreadSliceLength_M,
I1,
YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto y_thread_cluster_desc =
make_cluster_descriptor(Sequence<I1,
YDotYGrad_M_O::ThreadClusterLength_M,
I1,
YDotYGrad_M_O::ThreadClusterLength_O>{},
Sequence<0, 1, 2, 3>{});
const auto y_thread_cluster_idx =
y_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto y_thread_data_on_block_idx =
y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
const auto y_thread_data_on_grid_idx =
make_multi_index(
block_work_idx[I0], I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
y_thread_data_on_block_idx;
// performs double duty for both y and ygrad
auto yygrad_threadwise_copy =
ThreadwiseTensorSliceTransfer_v2<DataType,
DataType,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
Sequence<0, 1, 2, 3>,
3, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVetor, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */,
true /* InvalidElementAsNaN */>(
y_grid_desc_mblock_mperblock_oblock_operblock, y_thread_data_on_grid_idx);
auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{};
auto y_dot_ygrad_block_accum_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatGemmAcc*>(p_shared), MPerBlock);
constexpr auto y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl =
make_naive_tensor_descriptor(make_tuple(I1, P_M0, P_M1, P_M2),
make_tuple(P_M0 * P_M1 * P_M2, P_M1 * P_M2, P_M2, I1));
// y_dot_ygrad thread buffer for calculating sgrad; reuse LSE thread descriptor because
// per-thread LSE data and y_dot_ygrad is tiled the same way
constexpr auto y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl =
lse_thread_desc_mblock_mrepeat_mwave_mperxdl;
auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatGemmAcc,
FloatGemmAcc,
decltype(y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl),
decltype(y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl),
Sequence<1, m0, m1, m2>,
Sequence<0, 1, 2, 3>,
3,
m2,
1,
false>{y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(I0, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl
// clear accum buffers
y_dot_ygrad_thread_accum_buf.Clear();
y_dot_ygrad_block_accum_buf.Clear();
#if 0
if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("lds before accum\n");
if(hipBlockIdx_x == 0)
{
debug::print_shared(y_dot_ygrad_block_accum_buf.p_data_, MPerBlock);
}
#endif
auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
//
// calculate y dot ygrad
//
index_t oblock_idx = 0;
do
{
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
y_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
y_thread_buf);
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
ygrad_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
ygrad_thread_buf);
static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
static_for<0, YDotYGrad_M_O::ThreadSliceLength_O, 1>{}([&](auto iO) {
constexpr auto offset =
y_thread_desc_m0_m1_o0_o1.CalculateOffset(make_multi_index(I0, iM, I0, iO));
y_dot_ygrad_thread_accum_buf(iM) +=
y_thread_buf[Number<offset>{}] * ygrad_thread_buf[Number<offset>{}];
});
});
#if 0
if (hipThreadIdx_x % 32 < 4 && hipBlockIdx_x == 0)
{
printf("bid %zd tid %zd, oblock_idx %d, y_thread_buf[0:3] = %f %f %f %f, ygrad_thread_buf[0:3] = %f %f %f %f\n",
hipBlockIdx_x,
hipThreadIdx_x,
oblock_idx,
(float)y_thread_buf[I0],
(float)y_thread_buf[I1],
(float)y_thread_buf[I2],
(float)y_thread_buf[I3],
(float)ygrad_thread_buf[I0],
(float)ygrad_thread_buf[I1],
(float)ygrad_thread_buf[I2],
(float)ygrad_thread_buf[I3]);
}
#endif
yygrad_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
make_multi_index(0, 0, 1, 0));
oblock_idx++;
} while(oblock_idx < y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2));
// blockwise reduction using atomic_add
block_sync_lds();
static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM;
y_dot_ygrad_block_accum_buf.AtomicAdd(idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM]);
});
block_sync_lds();
#if 1
if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("lds after accum\n");
if(hipBlockIdx_x == 0)
{
debug::print_shared(y_dot_ygrad_block_accum_buf.p_data_, MPerBlock);
}
#endif
// distribute to threads
y_dot_ygrad_thread_copy_lds_to_vgpr.Run(
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
y_dot_ygrad_block_accum_buf,
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl,
make_tuple(I0, I0, I0, I0),
y_dot_ygrad_thread_buf);
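// Recap of the y dot dy pipeline above: each thread accumulates a partial dot product over its
// private (M, O) sub-tile while the do/while loop walks every O tile of the row block; the
// partials are then combined across threads via atomic adds into the MPerBlock-wide LDS buffer,
// and finally copied back into VGPRs with the same (mblock, mrepeat, mwave, mperxdl) tiling used
// for the LSE values so the per-row sums line up with the acc0 thread layout.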
#if 0
if(hipBlockIdx_x < 4 && hipThreadIdx_x % 32 < 4)
{
printf("bid %zd tid %zd, y_m0_m1_o0_o1 = %d, %d, %d, %d\n",
hipBlockIdx_x,
hipThreadIdx_x,
y_thread_data_on_grid_idx[I0],
y_thread_data_on_grid_idx[I1],
y_thread_data_on_grid_idx[I2],
y_thread_data_on_grid_idx[I3]);
}
#endif
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
lse_grid_buf,
@@ -1348,6 +1566,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
SubThreadBlock<BlockSize> p_thread_copy_subgroup(blockwise_gemm.GetWaveIdx()[I0],
blockwise_gemm.GetWaveIdx()[I1]);
constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;
static_assert(sfc_p_m0_n0_m1_n1_m2_n2.GetNumOfAccess() == num_vgrad_gemm_loop, "");
vgrad_acc_thread_buf.Clear();
@@ -1450,7 +1669,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// TODO ANT:
// shuffle dQ and write
-if constexpr(false)
+#if 0
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
@@ -1646,6 +1865,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
});
}
#endif
}
};
...
@@ -143,6 +143,16 @@ struct DynamicBuffer
}
}
__host__ __device__ void Clear()
{
static_assert(GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong! only local data share is supported");
for(index_t i = get_thread_local_1d_id(); i < element_space_size_; i += get_block_size())
{
Set(i, true, T{0});
}
}
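// Usage note: Clear() is a block-cooperative zero fill; every thread of the workgroup strides
// over the LDS buffer in steps of the block size, so all threads must call it and a barrier
// (e.g. block_sync_lds()) is required before the cleared buffer is read, as the y dot dy
// reduction in this commit does before its atomic adds.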
template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
@@ -302,7 +312,9 @@ struct DynamicBuffer
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
-static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");
+static_assert(GetAddressSpace() == AddressSpaceEnum::Global ||
GetAddressSpace() == AddressSpaceEnum::Lds,
"only support global mem or local data share");
#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
@@ -319,7 +331,7 @@ struct DynamicBuffer
bool constexpr use_amd_buffer_addressing = false;
#endif
-if constexpr(use_amd_buffer_addressing)
+if constexpr(use_amd_buffer_addressing && GetAddressSpace() == AddressSpaceEnum::Global)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
...