Commit 4c683df4 authored by illsilin
Browse files

add support for more dl kernels on navi4

parent 69ad91b2
...@@ -70,9 +70,7 @@ ...@@ -70,9 +70,7 @@
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__) #elif defined(__gfx103__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx11__) #elif defined(__gfx11__) || defined(__gfx12__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#elif defined(__gfx12__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif #endif
...@@ -85,14 +83,10 @@ ...@@ -85,14 +83,10 @@
#define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8 #define CK_USE_AMD_V_DOT4_I32_I8
#elif defined(__gfx11__) #elif defined(__gfx11__) || defined(__gfx12__)
#define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11 #define CK_USE_AMD_V_DOT4_I32_I8_GFX11
#elif defined(__gfx12__)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8_GFX12
#endif #endif
// MFMA instruction // MFMA instruction
......
...@@ -70,8 +70,9 @@ __global__ void ...@@ -70,8 +70,9 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
defined(__gfx12__))
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported()) ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported())
{ {
bool pass = true; bool pass = true;
pass = pass && arg.K_ % K1 == 0; pass = pass && arg.K_ % K1 == 0;
......
...@@ -1394,7 +1394,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl ...@@ -1394,7 +1394,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
{ {
// check device // check device
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() || if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
ck::is_navi3_supported())) ck::is_navi3_supported() || ck::is_navi4_supported()))
{ {
return false; return false;
} }
......
...@@ -537,7 +537,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout, ...@@ -537,7 +537,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
} }
if(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() || if(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
ck::is_navi3_supported()) ck::is_navi3_supported() || ck::is_navi4_supported())
{ {
return GridwiseGemm::CheckValidity( return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
......
...@@ -50,8 +50,9 @@ __global__ void ...@@ -50,8 +50,9 @@ __global__ void
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
defined(__gfx12__))
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
...@@ -552,7 +553,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout, ...@@ -552,7 +553,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported()) ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported())
{ {
return GridwiseGemm::CheckValidity( return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_); arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
......
...@@ -90,8 +90,9 @@ __global__ void ...@@ -90,8 +90,9 @@ __global__ void
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__))
// offset base pointer for each work-group // offset base pointer for each work-group
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -666,7 +667,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -666,7 +667,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// check device // check device
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported())) ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported()))
{ {
return false; return false;
} }
......
...@@ -107,7 +107,7 @@ __global__ void ...@@ -107,7 +107,7 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx11__)) defined(__gfx11__) || defined(__gfx12__))
// offset base pointer for each work-group // offset base pointer for each work-group
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -602,7 +602,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -602,7 +602,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
// check device // check device
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() || if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
ck::is_navi3_supported())) ck::is_navi3_supported() || ck::is_navi4_supported()))
{ {
return false; return false;
} }
......
...@@ -39,8 +39,9 @@ __global__ void ...@@ -39,8 +39,9 @@ __global__ void
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op) const CDEElementwiseOperation cde_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__)) defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \
defined(__gfx12__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
...@@ -668,7 +669,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout, ...@@ -668,7 +669,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
} }
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported()) ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported())
{ {
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{ {
......
...@@ -12,7 +12,7 @@ __device__ void block_sync_lds() ...@@ -12,7 +12,7 @@ __device__ void block_sync_lds()
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM #if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#ifdef __gfx12__ #ifdef __gfx12__
asm volatile("\ asm volatile("\
s_wait_idle \n \ s_wait_dscnt 0x0 \n \
s_barrier_signal -1 \n \ s_barrier_signal -1 \n \
s_barrier_wait -1 \ s_barrier_wait -1 \
" ::); " ::);
...@@ -31,7 +31,8 @@ __device__ void block_sync_lds_direct_load() ...@@ -31,7 +31,8 @@ __device__ void block_sync_lds_direct_load()
{ {
#ifdef __gfx12__ #ifdef __gfx12__
asm volatile("\ asm volatile("\
s_wait_idle \n \ s_wait_vmcnt 0x0 \n \
s_wait_dscnt 0x0 \n \
s_barrier_signal -1 \n \ s_barrier_signal -1 \n \
s_barrier_wait -1 \ s_barrier_wait -1 \
" ::); " ::);
......
...@@ -55,7 +55,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -55,7 +55,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
} }
} }
if(ck::is_navi3_supported()) if(ck::is_navi3_supported() || ck::is_navi4_supported())
{ {
// on navi3x only support for 3d is implemented // on navi3x only support for 3d is implemented
if constexpr(NDimSpatial{} != 3) if constexpr(NDimSpatial{} != 3)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment