Commit 6b041227 authored by guangzlu
Browse files

modified device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp

parent 3b58c3ec
...@@ -140,6 +140,7 @@ template <index_t NumDimG, ...@@ -140,6 +140,7 @@ template <index_t NumDimG,
typename BDataType, typename BDataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
typename ZDataType,
typename LSEDataType, typename LSEDataType,
typename Acc0BiasDataType, typename Acc0BiasDataType,
typename Acc1BiasDataType, typename Acc1BiasDataType,
...@@ -207,6 +208,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -207,6 +208,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BDataType, BDataType,
B1DataType, B1DataType,
CDataType, CDataType,
ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, Acc0BiasDataType,
Acc1BiasDataType, Acc1BiasDataType,
...@@ -246,6 +248,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -246,6 +248,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BDataType, BDataType,
B1DataType, B1DataType,
CDataType, CDataType,
ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, Acc0BiasDataType,
Acc1BiasDataType, Acc1BiasDataType,
...@@ -295,6 +298,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -295,6 +298,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
Number<B1K1>{}); Number<B1K1>{});
} }
// Builds the grid descriptor for the Z tensor by forwarding the given
// lengths and strides unchanged to Transform::MakeCGridDescriptor_M_N.
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec,
                                    const std::vector<index_t>& z_gs_ms_ns_strides_vec)
{
    auto z_desc_m_n =
        Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec);
    return z_desc_m_n;
}
static auto MakeLSEGridDescriptor_M(index_t MRaw) static auto MakeLSEGridDescriptor_M(index_t MRaw)
{ {
const auto lse_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); const auto lse_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
...@@ -325,10 +334,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -325,10 +334,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
constexpr static auto make_MaskOutPredicate() constexpr static auto make_MaskOutPredicate()
{ {
...@@ -408,6 +420,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -408,6 +420,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
B1GridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1,
CGridDesc_M_N, CGridDesc_M_N,
ZGridDesc_M_N,
LSEGridDesc_M, LSEGridDesc_M,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
...@@ -465,6 +478,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -465,6 +478,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
const B1DataType* p_b1_grid_; const B1DataType* p_b1_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
ZDataType* p_z_grid;
LSEDataType* p_lse_grid_; LSEDataType* p_lse_grid_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
...@@ -473,6 +487,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -473,6 +487,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
ZGridDesc_M_N z_grid_desc_m_n_;
LSEGridDesc_M lse_grid_desc_m_; LSEGridDesc_M lse_grid_desc_m_;
// batch & stride // batch & stride
...@@ -511,6 +526,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -511,6 +526,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
std::vector<const void*> p_b_vec, std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec, std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<std::vector<const void*>> p_acc1_biases_vec,
...@@ -550,6 +566,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -550,6 +566,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]); const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]); const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]); const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
const auto p_z_grid = static_cast<ZDataType*>(p_z_vec[i]);
const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]); const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]);
const auto& problem_desc = problem_desc_vec[i]; const auto& problem_desc = problem_desc_vec[i];
...@@ -562,6 +579,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -562,6 +579,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides); problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N( const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides); problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
// Z descriptor: read the z_gs_ms_ns_* problem fields, matching the parameter
// naming of MakeZGridDescriptor_M_N and the fields used for z_grid_desc_g_m_n.
// The previous z_gs_ms_os_* spelling mirrored the C-tensor field suffix.
const auto z_grid_desc_m_n = MakeZGridDescriptor_M_N(
    problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto lse_grid_desc_m = const auto lse_grid_desc_m =
DeviceOp::MakeLSEGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]); DeviceOp::MakeLSEGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]);
...@@ -573,6 +592,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -573,6 +592,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides); problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N( const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides); problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment