Commit 5cb59d36 authored by Jing Zhang

resolve conflicts

parents 7e3a5613 7e147c64
......@@ -113,7 +113,7 @@ message("checking which targets are supported")
#Setting GPU_TARGETS on command line will override this list
if(NOT PROFILER_ONLY)
rocm_check_target_ids(DEFAULT_GPU_TARGETS
TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200")
TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
else()
add_definitions(-DPROFILER_ONLY)
set(GPU_TARGETS "" CACHE STRING "" FORCE)
......@@ -129,7 +129,7 @@ else()
elseif(GPU_ARCH MATCHES "gfx11")
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102")
elseif(GPU_ARCH MATCHES "gfx12")
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200")
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201")
else()
message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12")
endif()
......
......@@ -496,7 +496,7 @@ def Build_CK(Map conf=[:]){
def navi_node = 0
def mi300_node = 0
gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel-internal') {
gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel-internal') {
try {
(retimage, image) = getDockerImage(conf)
withDockerContainer(image: image, args: dockerOpts) {
......@@ -602,7 +602,7 @@ def process_results(Map conf=[:]){
def variant = env.STAGE_NAME
def retimage
gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel-internal') {
gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel-internal') {
try {
(retimage, image) = getDockerImage(conf)
}
......
......@@ -66,7 +66,7 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
#-Werror
-Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt
......
......@@ -45,3 +45,5 @@ for sphinx_var in ROCmDocs.SPHINX_VARS:
extensions += ['sphinxcontrib.bibtex']
bibtex_bibfiles = ['refs.bib']
cpp_id_attributes = ["__global__", "__device__", "__host__"]
......@@ -64,31 +64,31 @@ Advanced examples:
Layout
-------------------------------------
.. doxygenstruct:: ck::wrapper::Layout
.. doxygenstruct:: Layout
-------------------------------------
Layout helpers
-------------------------------------
.. doxygenfile:: layout_utils.hpp
.. doxygenfile:: include/ck/wrapper/utils/layout_utils.hpp
-------------------------------------
Tensor
-------------------------------------
.. doxygenstruct:: ck::wrapper::Tensor
.. doxygenstruct:: Tensor
-------------------------------------
Tensor helpers
-------------------------------------
.. doxygenfile:: tensor_utils.hpp
.. doxygenfile:: include/ck/wrapper/utils/tensor_utils.hpp
.. doxygenfile:: tensor_partition.hpp
.. doxygenfile:: include/ck/wrapper/utils/tensor_partition.hpp
-------------------------------------
Operations
-------------------------------------
.. doxygenfile:: copy.hpp
.. doxygenfile:: gemm.hpp
.. doxygenfile:: include/ck/wrapper/operations/copy.hpp
.. doxygenfile:: include/ck/wrapper/operations/gemm.hpp
......@@ -27,7 +27,8 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102" OR GPU_TARGETS MATCHES "gfx1200")
if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
......@@ -53,12 +54,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
......@@ -71,3 +66,12 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1)
endif()
endforeach()
add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
......@@ -155,12 +155,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
break;
case 3:
ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
case 4:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
break;
case 5:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k);
......
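The hunk above swaps the FillConstant initializers in cases 3 and 4 for FillUniformDistributionIntegerValue with a degenerate [1, 1] range. A uniform integer draw over [1, 1] can only produce 1, so the initialization is unchanged in effect; the sketch below demonstrates that property with plain <random> rather than CK's fill helpers, purely to stay self-contained.

```cpp
// Illustrative only: a uniform integer distribution over [1, 1] can only yield
// 1, so the replacement fill behaves like the removed FillConstant{1}. Plain
// <random> is used here instead of ck::utils to keep the sketch self-contained.
#include <cassert>
#include <random>

int main()
{
    std::mt19937 gen{0};
    std::uniform_int_distribution<int> ones{1, 1};   // degenerate range, as in cases 3/4
    std::uniform_int_distribution<int> other{-5, 5}; // the non-constant operand

    for(int i = 0; i < 1000; ++i)
        assert(ones(gen) == 1); // every draw is exactly 1

    (void)other;
    return 0;
}
```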
list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102 gfx1200)
list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
......@@ -6,7 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp)
endif()
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx94")
set(target 1)
endif()
endforeach()
......
list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102 gfx1200)
list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
......
......@@ -58,7 +58,7 @@
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__
#endif
#if defined(__gfx1200__)
#if defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif
......@@ -104,13 +104,6 @@
#define CK_USE_AMD_MFMA_GFX940
#endif
// WMMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_WMMA
#elif defined(__gfx11__) // for GPU code
#define CK_USE_AMD_WMMA
#endif
// buffer load
#define CK_USE_AMD_BUFFER_LOAD 1
......
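The family macros above collapse the per-target __gfx11xx__/__gfx12xx__ compiler defines into __gfx11__ and __gfx12__, which the rest of the library uses to guard device code (the unconditional CK_USE_AMD_WMMA block is dropped in the same file). A minimal sketch of that guard pattern, with a hypothetical kernel name that is not part of the diff:

```cpp
// Hypothetical illustration, not part of the diff: guarding device code with
// the architecture-family macros defined above. During host compilation
// __HIP_DEVICE_COMPILE__ is not set, so the body stays visible to the host
// pass; during device compilation it is only emitted for gfx11/gfx12 targets.
#include <hip/hip_runtime.h>

__global__ void my_kernel(float* out)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
    out[threadIdx.x] = static_cast<float>(threadIdx.x); // real work would go here
#else
    (void)out; // unsupported target: compile an empty kernel body
#endif
}
```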
......@@ -85,6 +85,9 @@ inline bool is_navi3_supported()
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
}
inline bool is_navi4_supported() { return ck::get_device_name() == "gfx1200"; }
inline bool is_navi4_supported()
{
return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
}
} // namespace ck
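With gfx1201 now recognised as Navi4, these helpers can drive host-side dispatch. A hedged sketch of such usage follows; the header path is an assumption and select_gemm_kernel is illustrative, while the is_*_supported helpers are the ones shown in the diff:

```cpp
// Illustrative only: host-side dispatch on the device-family helpers from the
// diff. The header path is assumed and select_gemm_kernel is hypothetical;
// is_xdl_supported/is_navi3_supported/is_navi4_supported are CK helpers.
#include <string>
#include "ck/host_utility/device_prop.hpp" // assumed location of the helpers

inline std::string select_gemm_kernel()
{
    if(ck::is_navi3_supported() || ck::is_navi4_supported())
        return "wmma";   // RDNA3/RDNA4 expose WMMA instructions
    if(ck::is_xdl_supported())
        return "xdlops"; // CDNA-class devices
    return "dl";         // generic fallback path
}
```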
......@@ -488,7 +488,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// sync point.
if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
{
#ifdef __gfx12__
asm volatile("\
s_barrier_signal -1 \n \
s_barrier_wait -1 \
" ::);
#else
asm volatile("s_barrier" ::);
#endif
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
......
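On gfx12 the single s_barrier instruction is replaced by an explicit signal/wait pair, which the hunk above selects with #ifdef __gfx12__. A small device helper that factors out the same pattern could look like the sketch below (the function name is illustrative, not CK API):

```cpp
// Illustrative helper, not CK API: issue a workgroup barrier using the gfx12
// split-barrier instructions when compiling for gfx12, and the legacy single
// s_barrier otherwise. The asm mirrors the hunk above; compile with hipcc.
__device__ inline void block_sync_barrier()
{
#ifdef __gfx12__
    asm volatile("s_barrier_signal -1\n"
                 "s_barrier_wait -1" ::);
#else
    asm volatile("s_barrier" ::);
#endif
}
```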
......@@ -137,8 +137,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = false;
static constexpr auto AEnableLds_manu = true;
static constexpr auto BEnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
......
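Flipping AEnableLds_manu/BEnableLds_manu to true makes the OR on the last two kept lines always evaluate to true, so LDS staging is now used unconditionally for both operands regardless of the MWaves heuristic or NumPrefetch. A standalone sketch of that resolution, with made-up parameter values:

```cpp
// Illustrative only: how the flags in the hunk resolve once the manual
// override is true. The parameter values here are made up.
#include <cstdio>

int main()
{
    constexpr int  NumPrefetch     = 1;     // hypothetical template parameter
    constexpr bool AEnableLds_auto = false; // heuristic says LDS is not needed
    constexpr bool AEnableLds_manu = true;  // manual override from the diff
    constexpr bool AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);

    static_assert(AEnableLds, "with the manual flag set, LDS is always used");
    std::printf("AEnableLds = %d\n", AEnableLds);
    return 0;
}
```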
......@@ -70,8 +70,9 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
defined(__gfx12__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported())
ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported())
{
bool pass = true;
pass = pass && arg.K_ % K1 == 0;
......
......@@ -1394,7 +1394,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
{
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
ck::is_navi3_supported()))
ck::is_navi3_supported() || ck::is_navi4_supported()))
{
return false;
}
......
......@@ -50,8 +50,9 @@ __global__ void
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
defined(__gfx12__))
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
......@@ -552,7 +553,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported())
ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported())
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
......
......@@ -101,8 +101,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = false;
static constexpr auto AEnableLds_manu = true;
static constexpr auto BEnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
......
......@@ -90,8 +90,9 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -666,7 +667,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported()))
ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported()))
{
return false;
}
......
......@@ -107,7 +107,7 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx11__))
defined(__gfx11__) || defined(__gfx12__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -602,7 +602,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
ck::is_navi3_supported()))
ck::is_navi3_supported() || ck::is_navi4_supported()))
{
return false;
}
......
......@@ -39,8 +39,9 @@ __global__ void
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \
defined(__gfx12__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
......@@ -668,7 +669,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
}
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported())
ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported())
{
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
......