Commit 9e1091cd authored by aska-0096's avatar aska-0096
Browse files

multiple fix, try ait compile

parent f677f702
...@@ -418,13 +418,14 @@ struct BlockwiseGemmWMMA ...@@ -418,13 +418,14 @@ struct BlockwiseGemmWMMA
} }
protected: protected:
// A[K0, M0, M1, M2, K1] static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number<WmmaK / A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}),
make_tuple(Number<WmmaK / A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{})); make_tuple(Number<A_K1>{}, Number<WmmaK>{}, Number<A_K1>{}, Number<A_K1>{}, Number<1>{}));
// B[K0, N0, N1, N2, K1] // B[K0, N0, N1, N2, K1]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<WmmaK / B_K1>{}, Number<NRepeat>{}, I1, I1, Number<B_K1>{})); make_tuple(Number<WmmaK / B_K1>{}, Number<NRepeat>{}, I1, I1, Number<B_K1>{}),
make_tuple(Number<B_K1>{}, Number<WmmaK>{}, Number<B_K1>{}, Number<B_K1>{}, Number<1>{}));
// C[M, N, NumRegWMMA] // C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
......
...@@ -136,8 +136,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -136,8 +136,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally // If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false; // Bug, MNK vector load check not implemented correctly
// Bug: blocksize 128, Tile 128x128x64, Repeat 8x2 Failure static constexpr auto AEnableLds_manu = true;
static constexpr auto BEnableLds_manu = true; static constexpr auto BEnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu; static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
...@@ -725,6 +725,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -725,6 +725,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// Batch Offset // Batch Offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
// for checking vector load/store
// index_t MRaw_;
// index_t NRaw_;
// index_t KRaw_;
}; };
// Invoker // Invoker
......
...@@ -98,9 +98,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -98,9 +98,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally // If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false; static constexpr auto AEnableLds_manu = true;
// Bug: blocksize 128, Tile 128x128x64, Repeat 8x2 Failure
// Bug: gemm.TileDesc(64, 32, 64, 64, 8, 0, 16, 16, 2, 2), failed
static constexpr auto BEnableLds_manu = true; static constexpr auto BEnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu; static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
...@@ -108,7 +106,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -108,7 +106,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
// Describe how data read from Global memory
// Describe how data read from Global memory // Describe how data read from Global memory
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA) static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
{ {
......
...@@ -92,8 +92,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -92,8 +92,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
// If true, LDS is used unconditionally // If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false; static constexpr auto AEnableLds_manu = false;
// Bug: blocksize 128, Tile 128x128x64, Repeat 8x2 Failure static constexpr auto BEnableLds_manu = false;
static constexpr auto BEnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu; static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu; static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu;
......
...@@ -170,12 +170,16 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -170,12 +170,16 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16; static constexpr auto WmmaK = 16;
static constexpr auto AEnableLds = NWaves == 1 ? false : true; static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
static constexpr auto BEnableLds = MWaves == 1 ? false : true; static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = false;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu;
// Force enable LDS if uncommented following
// AEnableLds = true;
// BEnableLds = true;
static constexpr auto conv_to_gemm_transformer = static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{}; TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment