"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "a63f66f735f229d5f52523a88acf632485e51d99"
Commit c5fd087e authored by aska-0096's avatar aska-0096
Browse files

Attn, skip b lds

parent 6e28a8ac
...@@ -180,27 +180,57 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -180,27 +180,57 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static auto MakeB0GridDescriptor(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec, static auto MakeB0GridDescriptor(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
const std::vector<index_t>& b0_gs_ls_ks_strides_vec) const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
{ {
return Transform::MakeB0GridDescriptor_BK0_N_BK1( if constexpr(B0EnableLds)
Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, b0_gs_ls_ks_strides_vec), {
Number<K1>{}); return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
b0_gs_ls_ks_strides_vec),
Number<K1>{});
}
else
{
return Transform::MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
b0_gs_ls_ks_strides_vec),
Number<WmmaK>{},
Number<LRepeat>{},
Number<LWaves>{},
Number<LPerWmma>{},
Number<K1>{});
}
} }
static auto MakeB1GridDescriptor_BL0_N_BL1(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec, static auto MakeB1GridDescriptor(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
const std::vector<index_t>& b1_gs_ns_ls_strides_vec) const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
{ {
return Transform::MakeB1GridDescriptor_BK0_N_BK1( if constexpr(B1EnableLds)
Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, b1_gs_ns_ls_strides_vec), {
Number<L1>{}); return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
b1_gs_ns_ls_strides_vec),
Number<L1>{});
}
else
{
return Transform::MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
b1_gs_ns_ls_strides_vec),
Number<WmmaK>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWmma>{},
Number<L1>{});
}
} }
using AGridDesc = decltype(MakeAGridDescriptor({}, {})); using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor({}, {})); using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
using B1GridDesc_BL0_N_BL1 = decltype(MakeB1GridDescriptor_BL0_N_BL1({}, {})); using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
constexpr static auto make_MaskOutPredicate() constexpr static auto make_MaskOutPredicate()
{ {
...@@ -274,8 +304,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -274,8 +304,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
// InMemory Data Descriptor // InMemory Data Descriptor
AGridDesc, AGridDesc,
B0GridDesc_BK0_L_BK1, B0GridDesc,
B1GridDesc_BL0_N_BL1, B1GridDesc,
CGridDesc_M_N, CGridDesc_M_N,
// Tiling Family // Tiling Family
MPerBlock, MPerBlock,
...@@ -364,10 +394,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -364,10 +394,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
p_b1_grid_{p_b1_grid}, p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b0_grid_desc_bk0_l_bk1_{ b0_grid_desc{
DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
b1_grid_desc_bl0_n_bl1_{DeviceOp::MakeB1GridDescriptor_BL0_N_BL1( b1_grid_desc{
b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
c_grid_desc_m_n_{ c_grid_desc_m_n_{
Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
a_grid_desc_g_m_k_{ a_grid_desc_g_m_k_{
...@@ -410,11 +440,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -410,11 +440,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
ignore = acc1_biases_gs_ms_ns_lengths; ignore = acc1_biases_gs_ms_ns_lengths;
ignore = acc1_biases_gs_ms_ns_strides; ignore = acc1_biases_gs_ms_ns_strides;
if(GridwiseOp::CheckValidity(a_grid_desc, if(GridwiseOp::CheckValidity(
b0_grid_desc_bk0_l_bk1_, a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_))
b1_grid_desc_bl0_n_bl1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{ {
c_grid_desc_mblock_mperblock_nblock_nperblock_ = c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...@@ -430,8 +457,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -430,8 +457,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
// Tensor Descriptors // Tensor Descriptors
AGridDesc a_grid_desc; AGridDesc a_grid_desc;
B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1_; B0GridDesc b0_grid_desc;
B1GridDesc_BL0_N_BL1 b1_grid_desc_bl0_n_bl1_; B1GridDesc b1_grid_desc;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
...@@ -498,8 +525,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -498,8 +525,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B1DataType, B1DataType,
CDataType, CDataType,
DeviceOp::AGridDesc, DeviceOp::AGridDesc,
DeviceOp::B0GridDesc_BK0_L_BK1, DeviceOp::B0GridDesc,
DeviceOp::B1GridDesc_BL0_N_BL1, DeviceOp::B1GridDesc,
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation, AElementwiseOperation,
B0ElementwiseOperation, B0ElementwiseOperation,
...@@ -521,8 +548,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -521,8 +548,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
arg.p_b1_grid_, arg.p_b1_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc, arg.a_grid_desc,
arg.b0_grid_desc_bk0_l_bk1_, arg.b0_grid_desc,
arg.b1_grid_desc_bl0_n_bl1_, arg.b1_grid_desc,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_, arg.a_element_op_,
arg.b0_element_op_, arg.b0_element_op_,
...@@ -582,8 +609,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -582,8 +609,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
} }
if(!GridwiseOp::CheckValidity(arg.a_grid_desc, if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b0_grid_desc_bk0_l_bk1_, arg.b0_grid_desc,
arg.b1_grid_desc_bl0_n_bl1_, arg.b1_grid_desc,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_)) arg.block_2_ctile_map_))
{ {
......
...@@ -719,7 +719,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -719,7 +719,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
// Thread-wise copy // Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>( auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc.GetElementSpaceSize()); b_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical // Limitation: NumDim of Src and Dst descriptor should be identical
......
...@@ -247,6 +247,34 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm ...@@ -247,6 +247,34 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
template <typename BGridDesc_L_K,
typename WmmaK,
typename LRepeat,
typename LWaves,
typename LPerWmma,
typename BK1>
__host__ __device__ static constexpr auto
MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
const BGridDesc_L_K& b_grid_desc_l_k,
const WmmaK&,
const LRepeat&,
const LWaves&,
const LPerWmma&,
const BK1&)
{
const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
const auto K = b_grid_desc_l_k.GetLength(I1);
const auto BKWmma = K / WmmaK{};
constexpr auto BKRow = WmmaK{} / BK1{};
return transform_tensor_descriptor(
b_grid_desc_l_k,
make_tuple(make_unmerge_transform(make_tuple(BKWmma, BKRow, BK1{})),
make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
// //
// B1 // B1
// //
...@@ -288,6 +316,34 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm ...@@ -288,6 +316,34 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
template <typename BGridDesc_N_L,
typename WmmaL,
typename NRepeat,
typename NWaves,
typename NPerWmma,
typename BL1>
__host__ __device__ static constexpr auto
MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
const BGridDesc_N_L& b_grid_desc_n_l,
const WmmaL&,
const NRepeat&,
const NWaves&,
const NPerWmma&,
const BL1&)
{
const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
const auto L = b_grid_desc_n_l.GetLength(I1);
const auto BLWmma = L / WmmaL{};
constexpr auto BLRow = WmmaL{} / BL1{};
return transform_tensor_descriptor(
b_grid_desc_n_l,
make_tuple(make_unmerge_transform(make_tuple(BLWmma, BLRow, BL1{})),
make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
// //
// C // C
// //
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment