Commit 9e1091cd authored by aska-0096's avatar aska-0096
Browse files

multiple fix, try ait compile

parent f677f702
......@@ -418,13 +418,14 @@ struct BlockwiseGemmWMMA
}
protected:
// A[K0, M0, M1, M2, K1]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<WmmaK / A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}));
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<WmmaK / A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}),
make_tuple(Number<A_K1>{}, Number<WmmaK>{}, Number<A_K1>{}, Number<A_K1>{}, Number<1>{}));
// B[K0, N0, N1, N2, K1]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<WmmaK / B_K1>{}, Number<NRepeat>{}, I1, I1, Number<B_K1>{}));
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<WmmaK / B_K1>{}, Number<NRepeat>{}, I1, I1, Number<B_K1>{}),
make_tuple(Number<B_K1>{}, Number<WmmaK>{}, Number<B_K1>{}, Number<B_K1>{}, Number<1>{}));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
......
......@@ -136,8 +136,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
// Bug: blocksize 128, Tile 128x128x64, Repeat 8x2 Failure
// Bug, MNK vector load check not implemented correctly
static constexpr auto AEnableLds_manu = true;
static constexpr auto BEnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
......@@ -725,6 +725,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// Batch Offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
// for checking vector load/store
// index_t MRaw_;
// index_t NRaw_;
// index_t KRaw_;
};
// Invoker
......
......@@ -98,9 +98,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
// Bug: blocksize 128, Tile 128x128x64, Repeat 8x2 Failure
// Bug: gemm.TileDesc(64, 32, 64, 64, 8, 0, 16, 16, 2, 2), failed
static constexpr auto AEnableLds_manu = true;
static constexpr auto BEnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
......@@ -108,7 +106,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
// Describe how data read from Global memory
// Describe how data read from Global memory
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
{
......
......@@ -92,8 +92,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
// Bug: blocksize 128, Tile 128x128x64, Repeat 8x2 Failure
static constexpr auto BEnableLds_manu = true;
static constexpr auto BEnableLds_manu = false;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu;
......
......@@ -170,12 +170,16 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
static constexpr auto AEnableLds = NWaves == 1 ? false : true;
static constexpr auto BEnableLds = MWaves == 1 ? false : true;
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = false;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu;
// Force enable LDS if uncommented following
// AEnableLds = true;
// BEnableLds = true;
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment