Commit 71f5dd1d authored by aska-0096's avatar aska-0096
Browse files

fix bug

parent 6c97a1e2
...@@ -22,7 +22,6 @@ include(TargetFlags) ...@@ -22,7 +22,6 @@ include(TargetFlags)
list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip) list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip)
option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF) option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF)
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
...@@ -30,12 +29,6 @@ if(USE_BITINT_EXTENSION_INT4) ...@@ -30,12 +29,6 @@ if(USE_BITINT_EXTENSION_INT4)
message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
endif() endif()
if(USE_OPT_NAVI3X)
add_compile_options(-mcumode)
add_compile_options(-mno-wavefrontsize64)
message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}")
endif()
## Threads ## Threads
set(THREADS_PREFER_PTHREAD_FLAG ON) set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
......
...@@ -74,8 +74,8 @@ using DeviceConvFwdInstance = ...@@ -74,8 +74,8 @@ using DeviceConvFwdInstance =
8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1 8, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN true, // BBlockLdsExtraN
1, 4,
1, 2,
S<1, 32, 1, 8>, S<1, 32, 1, 8>,
8>; 8>;
......
...@@ -431,6 +431,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -431,6 +431,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
constexpr auto b_block_desc_k0perblock_nperblock_k1 = constexpr auto b_block_desc_k0perblock_nperblock_k1 =
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
constexpr auto a_block_space_size_aligned = math::integer_least_multiple( constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
...@@ -439,8 +442,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -439,8 +442,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
constexpr auto b_block_space_size_aligned = math::integer_least_multiple( constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned * sizeof(ADataType) + constexpr auto c_block_space_size_aligned = math::integer_least_multiple(
b_block_space_size_aligned * sizeof(BDataType)); cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize(),
max_lds_align);
return math::max((a_block_space_size_aligned * sizeof(ADataType) +
b_block_space_size_aligned * sizeof(BDataType)),
c_block_space_size_aligned * sizeof(CShuffleDataType));
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment