Commit 03874cbd authored by illsilin's avatar illsilin
Browse files

initial navi4x enablement

parent e60c5aea
......@@ -27,7 +27,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
if(GPU_TARGETS MATCHES "gfx11")
add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
......
list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102)
list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102 gfx1103)
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
......@@ -6,7 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp)
endif()
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx94")
set(target 1)
endif()
endforeach()
......
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
if(GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
endif()
list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102)
list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102 gfx1103)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
......
......@@ -58,6 +58,9 @@
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__
#endif
#if defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
......@@ -70,6 +73,9 @@
#elif defined(__gfx11__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
#elif defined(__gfx12__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
// FMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
......@@ -84,6 +90,10 @@
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11
#elif defined(__gfx12__)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8_GFX12
#endif
// MFMA instruction
......@@ -104,7 +114,7 @@
// WMMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_WMMA
#elif defined(__gfx11__) // for GPU code
#elif defined(__gfx11__) // || defined(__gfx12__) // for GPU code
#define CK_USE_AMD_WMMA
#endif
......
......@@ -84,5 +84,8 @@ inline bool is_navi3_supported()
return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
}
inline bool is_navi4_supported()
{
return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
}
} // namespace ck
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
......@@ -16,4 +16,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance)
set(target 1)
endif()
endforeach()
\ No newline at end of file
endforeach()
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment