Unverified Commit 941d1f7c authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merging the gfx12 code into public repo. (#1362)

parent a32b1bc6
...@@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> ...@@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
} }
}; };
// gfx12
/********************************WAVE32 MODE***********************************************/
#if defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32_gfx12;
template <>
struct intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16>
{
template <class FloatC>
__device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
{
// * Inline assembly need to elimate the duplicated data load, compiler won't help you
// delete them.
// amd_assembly_wmma_f32_16x16x16_f16_w32(
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
#if defined(__gfx12__)
reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12;
template <>
struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12<16, 16>
{
template <class FloatC>
__device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx12__)
reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12;
template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp>
{
template <class FloatC>
__device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx12__)
reg_c.template AsType<int32x8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
neg_a,
bit_cast<int32x2_t>(reg_a),
neg_b,
bit_cast<int32x2_t>(reg_b),
reg_c.template AsType<int32x8_t>()[Number<0>{}],
clamp);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
} // namespace ck } // namespace ck
#endif #endif
...@@ -203,7 +203,7 @@ struct vector_type<T, 1> ...@@ -203,7 +203,7 @@ struct vector_type<T, 1>
} }
}; };
int static err = 0; __device__ int static err = 0;
template <typename T> template <typename T>
struct vector_type<T, 2> struct vector_type<T, 2>
{ {
......
...@@ -10,12 +10,20 @@ namespace ck { ...@@ -10,12 +10,20 @@ namespace ck {
__device__ void block_sync_lds() __device__ void block_sync_lds()
{ {
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM #if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#ifdef __gfx12__
asm volatile("\
s_wait_dscnt 0x0 \n \
s_barrier_signal -1 \n \
s_barrier_wait -1 \
" ::);
#else
// asm volatile("\ // asm volatile("\
// s_waitcnt lgkmcnt(0) \n \ // s_waitcnt lgkmcnt(0) \n \
// s_barrier \ // s_barrier \
// " ::); // " ::);
__builtin_amdgcn_s_waitcnt(0xc07f); __builtin_amdgcn_s_waitcnt(0xc07f);
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
#endif
#else #else
__syncthreads(); __syncthreads();
#endif #endif
...@@ -23,11 +31,20 @@ __device__ void block_sync_lds() ...@@ -23,11 +31,20 @@ __device__ void block_sync_lds()
__device__ void block_sync_lds_direct_load() __device__ void block_sync_lds_direct_load()
{ {
#ifdef __gfx12__
asm volatile("\
s_wait_vmcnt 0x0 \n \
s_wait_dscnt 0x0 \n \
s_barrier_signal -1 \n \
s_barrier_wait -1 \
" ::);
#else
asm volatile("\ asm volatile("\
s_waitcnt vmcnt(0) \n \ s_waitcnt vmcnt(0) \n \
s_waitcnt lgkmcnt(0) \n \ s_waitcnt lgkmcnt(0) \n \
s_barrier \ s_barrier \
" ::); " ::);
#endif
} }
__device__ void s_nop() __device__ void s_nop()
......
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__ #define __gfx11__
#endif #endif
#if defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
...@@ -155,7 +158,7 @@ ...@@ -155,7 +158,7 @@
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__) // for GPU code #elif defined(__gfx103__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx11__) // for GPU code #elif defined(__gfx11__) || defined(__gfx12__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif #endif
......
...@@ -59,7 +59,7 @@ function(add_instance_library INSTANCE_NAME) ...@@ -59,7 +59,7 @@ function(add_instance_library INSTANCE_NAME)
endforeach() endforeach()
# Do not build WMMA instances if gfx11 targets are not on the target list # Do not build WMMA instances if gfx11 targets are not on the target list
foreach(source IN LISTS ARGN) foreach(source IN LISTS ARGN)
if(NOT INST_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
message("removing wmma instance ${source} ") message("removing wmma instance ${source} ")
list(REMOVE_ITEM ARGN "${source}") list(REMOVE_ITEM ARGN "${source}")
endif() endif()
...@@ -177,7 +177,7 @@ FOREACH(subdir_path ${dir_list}) ...@@ -177,7 +177,7 @@ FOREACH(subdir_path ${dir_list})
message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.") message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.")
set(add_inst 0) set(add_inst 0)
endif() endif()
if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11")) if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12"))
message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.") message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.")
set(add_inst 0) set(add_inst 0)
endif() endif()
...@@ -185,11 +185,11 @@ FOREACH(subdir_path ${dir_list}) ...@@ -185,11 +185,11 @@ FOREACH(subdir_path ${dir_list})
message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.") message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.")
set(add_inst 0) set(add_inst 0)
endif() endif()
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx9")) if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9"))
message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.") message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.")
set(add_inst 0) set(add_inst 0)
endif() endif()
if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.") message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.")
set(add_inst 0) set(add_inst 0)
endif() endif()
......
...@@ -59,7 +59,7 @@ if(GPU_TARGETS MATCHES "gfx9") ...@@ -59,7 +59,7 @@ if(GPU_TARGETS MATCHES "gfx9")
endif() endif()
if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9") if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
endif() endif()
...@@ -134,7 +134,7 @@ if(GPU_TARGETS MATCHES "gfx9") ...@@ -134,7 +134,7 @@ if(GPU_TARGETS MATCHES "gfx9")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
endif() endif()
if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
endif() endif()
......
...@@ -60,7 +60,7 @@ function(add_test_executable TEST_NAME) ...@@ -60,7 +60,7 @@ function(add_test_executable TEST_NAME)
endif() endif()
endforeach() endforeach()
foreach(source IN LISTS ARGN) foreach(source IN LISTS ARGN)
if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma")
message("removing wmma test ${source} ") message("removing wmma test ${source} ")
list(REMOVE_ITEM ARGN "${source}") list(REMOVE_ITEM ARGN "${source}")
endif() endif()
...@@ -139,7 +139,7 @@ function(add_gtest_executable TEST_NAME) ...@@ -139,7 +139,7 @@ function(add_gtest_executable TEST_NAME)
endif() endif()
endforeach() endforeach()
foreach(source IN LISTS ARGN) foreach(source IN LISTS ARGN)
if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma")
message("removing wmma test ${source} ") message("removing wmma test ${source} ")
list(REMOVE_ITEM ARGN "${source}") list(REMOVE_ITEM ARGN "${source}")
endif() endif()
......
...@@ -44,7 +44,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -44,7 +44,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
} }
} }
if(ck::is_gfx11_supported()) if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{ {
// on gfx11 only support for 3d is implemented // on gfx11 only support for 3d is implemented
if constexpr(NDimSpatial{} != 3) if constexpr(NDimSpatial{} != 3)
......
...@@ -140,10 +140,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) ...@@ -140,10 +140,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele]; p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele];
} }
#ifdef __gfx12__
asm volatile("\
s_wait_dscnt 0x0 \n \
s_barrier_signal -1 \n \
s_barrier_wait -1 \
" ::);
#else
asm volatile("\ asm volatile("\
s_waitcnt lgkmcnt(0) \n \ s_waitcnt lgkmcnt(0) \n \
s_barrier \ s_barrier \
" ::); " ::);
#endif
for(int ele = 0; ele < 16; ++ele) for(int ele = 0; ele < 16; ++ele)
{ {
...@@ -155,10 +163,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) ...@@ -155,10 +163,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8]; a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8];
} }
#ifdef __gfx12__
asm volatile("\
s_wait_dscnt 0x0 \n \
s_barrier_signal -1 \n \
s_barrier_wait -1 \
" ::);
#else
asm volatile("\ asm volatile("\
s_waitcnt lgkmcnt(0) \n \ s_waitcnt lgkmcnt(0) \n \
s_barrier \ s_barrier \
" ::); " ::);
#endif
// sync threads, similar to mma_sync // sync threads, similar to mma_sync
// __syncthreads(); // __syncthreads();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment