Commit c5fd087e authored by aska-0096's avatar aska-0096
Browse files

Attn, skip b lds

parent 6e28a8ac
......@@ -180,27 +180,57 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static auto MakeB0GridDescriptor(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, b0_gs_ls_ks_strides_vec),
Number<K1>{});
if constexpr(B0EnableLds)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
b0_gs_ls_ks_strides_vec),
Number<K1>{});
}
else
{
return Transform::MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
b0_gs_ls_ks_strides_vec),
Number<WmmaK>{},
Number<LRepeat>{},
Number<LWaves>{},
Number<LPerWmma>{},
Number<K1>{});
}
}
static auto MakeB1GridDescriptor_BL0_N_BL1(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
static auto MakeB1GridDescriptor(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, b1_gs_ns_ls_strides_vec),
Number<L1>{});
if constexpr(B1EnableLds)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
b1_gs_ns_ls_strides_vec),
Number<L1>{});
}
else
{
return Transform::MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
b1_gs_ns_ls_strides_vec),
Number<WmmaK>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWmma>{},
Number<L1>{});
}
}
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor({}, {}));
using B1GridDesc_BL0_N_BL1 = decltype(MakeB1GridDescriptor_BL0_N_BL1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
constexpr static auto make_MaskOutPredicate()
{
......@@ -274,8 +304,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
InMemoryDataOperationEnum::Set,
// InMemory Data Descriptor
AGridDesc,
B0GridDesc_BK0_L_BK1,
B1GridDesc_BL0_N_BL1,
B0GridDesc,
B1GridDesc,
CGridDesc_M_N,
// Tiling Family
MPerBlock,
......@@ -364,10 +394,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b0_grid_desc_bk0_l_bk1_{
b0_grid_desc{
DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
b1_grid_desc_bl0_n_bl1_{DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(
b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
b1_grid_desc{
DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
c_grid_desc_m_n_{
Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
a_grid_desc_g_m_k_{
......@@ -410,11 +440,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
ignore = acc1_biases_gs_ms_ns_lengths;
ignore = acc1_biases_gs_ms_ns_strides;
if(GridwiseOp::CheckValidity(a_grid_desc,
b0_grid_desc_bk0_l_bk1_,
b1_grid_desc_bl0_n_bl1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
if(GridwiseOp::CheckValidity(
a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......@@ -430,8 +457,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
// Tensor Descriptors
AGridDesc a_grid_desc;
B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1_;
B1GridDesc_BL0_N_BL1 b1_grid_desc_bl0_n_bl1_;
B0GridDesc b0_grid_desc;
B1GridDesc b1_grid_desc;
CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_G_M_K a_grid_desc_g_m_k_;
......@@ -498,8 +525,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B1DataType,
CDataType,
DeviceOp::AGridDesc,
DeviceOp::B0GridDesc_BK0_L_BK1,
DeviceOp::B1GridDesc_BL0_N_BL1,
DeviceOp::B0GridDesc,
DeviceOp::B1GridDesc,
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
B0ElementwiseOperation,
......@@ -521,8 +548,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
arg.p_b1_grid_,
arg.p_c_grid_,
arg.a_grid_desc,
arg.b0_grid_desc_bk0_l_bk1_,
arg.b1_grid_desc_bl0_n_bl1_,
arg.b0_grid_desc,
arg.b1_grid_desc,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b0_element_op_,
......@@ -582,8 +609,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
}
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b0_grid_desc_bk0_l_bk1_,
arg.b1_grid_desc_bl0_n_bl1_,
arg.b0_grid_desc,
arg.b1_grid_desc,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
......
......@@ -719,7 +719,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
// Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical
......
......@@ -247,6 +247,34 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename BGridDesc_L_K,
typename WmmaK,
typename LRepeat,
typename LWaves,
typename LPerWmma,
typename BK1>
__host__ __device__ static constexpr auto
MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
const BGridDesc_L_K& b_grid_desc_l_k,
const WmmaK&,
const LRepeat&,
const LWaves&,
const LPerWmma&,
const BK1&)
{
const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
const auto K = b_grid_desc_l_k.GetLength(I1);
const auto BKWmma = K / WmmaK{};
constexpr auto BKRow = WmmaK{} / BK1{};
return transform_tensor_descriptor(
b_grid_desc_l_k,
make_tuple(make_unmerge_transform(make_tuple(BKWmma, BKRow, BK1{})),
make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
//
// B1
//
......@@ -288,6 +316,34 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename BGridDesc_N_L,
typename WmmaL,
typename NRepeat,
typename NWaves,
typename NPerWmma,
typename BL1>
__host__ __device__ static constexpr auto
MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
const BGridDesc_N_L& b_grid_desc_n_l,
const WmmaL&,
const NRepeat&,
const NWaves&,
const NPerWmma&,
const BL1&)
{
const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
const auto L = b_grid_desc_n_l.GetLength(I1);
const auto BLWmma = L / WmmaL{};
constexpr auto BLRow = WmmaL{} / BL1{};
return transform_tensor_descriptor(
b_grid_desc_n_l,
make_tuple(make_unmerge_transform(make_tuple(BLWmma, BLRow, BL1{})),
make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
//
// C
//
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment