"vscode:/vscode.git/clone" did not exist on "a61b6472d25bf9e69365f30f6784d4019f4c578f"
Unverified Commit 180f16f9 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Add support for more Navi2x and Navi3x models. (#1152)

* add support for navi2x and navi3x models

* fix syntax

* use common macro for different mi300 architectures
parent 171ca260
......@@ -39,9 +39,8 @@ __global__ void
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
......@@ -668,26 +667,24 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
return false;
}
const std::string device_name = ck::get_device_name();
// TODO add newer Navi arch
if(device_name != "gfx906" and device_name != "gfx908" and device_name != "gfx90a" and
device_name != "gfx1030")
{
return false;
}
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported())
{
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_))
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
return false;
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_))
{
return false;
}
}
return true;
}
else
{
return false;
}
return true;
}
// polymorphic
......
......@@ -44,7 +44,7 @@ __global__ void
const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
......
......@@ -39,7 +39,7 @@ __global__ void
const CDEElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
......
......@@ -50,7 +50,7 @@ __global__ void
const CDEElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
......
......@@ -35,7 +35,7 @@ __global__ void
const index_t group_count)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
......
......@@ -57,7 +57,7 @@ __global__ void
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
......
......@@ -67,7 +67,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......
......@@ -28,8 +28,7 @@ __global__ void
#endif
kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx103__) || defined(__gfx11__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane(
......
......@@ -54,8 +54,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -148,8 +147,7 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
// printf("entry kernel launch");
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
......@@ -244,8 +242,7 @@ __global__ void
const CDEElementwiseOperation cde_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,
......@@ -274,7 +271,7 @@ __global__ void
ignore = b_element_op;
ignore = cde_element_op;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx1100__ ))
#endif // end of if (defined(__gfx11__ ))
}
template < // DataType Family
......
......@@ -55,8 +55,7 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx940__) || \
defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......
......@@ -55,7 +55,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......
......@@ -49,8 +49,7 @@ __global__ void
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......@@ -75,7 +74,7 @@ __global__ void
ignore = b_element_op;
ignore = c_element_op;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx1100__))
#endif // end of if (defined(__gfx11__))
}
template <index_t BlockSize,
......
......@@ -25,7 +25,7 @@ __global__ void
kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(
......@@ -50,7 +50,7 @@ __global__ void
typename GridwiseGemm::Problem problem)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
......
......@@ -26,7 +26,7 @@ __global__ void
kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -54,7 +54,7 @@ __global__ void
typename GridwiseGemm::Problem problem)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......
......@@ -58,7 +58,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// TODO ANT: separate into MMA + Epilogue
......
......@@ -167,7 +167,7 @@ __global__ void
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......
......@@ -45,7 +45,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
......
......@@ -38,7 +38,7 @@ __global__ void
typename GridwiseGemm::Block2CTileMap block_mapping)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
......
......@@ -39,7 +39,7 @@ __global__ void
const CGridDesc_M_N c_grid_desc_m_n)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......@@ -70,7 +70,7 @@ __global__ void
kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const auto a_grid_desc_k0_m_k1 =
......
......@@ -43,7 +43,7 @@ __global__ void
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
defined(__gfx94__))
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment