Commit 2f88070a authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'develop' of...

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into e2e_kernellib
parents 0bb08f4b b076a02a
......@@ -60,7 +60,8 @@ __global__ void
const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......@@ -554,7 +555,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}
......
......@@ -51,7 +51,8 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......@@ -490,7 +491,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}
......
......@@ -428,7 +428,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return false;
}
}
else if(ck::get_device_name() == "gfx90a")
else if(ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
......
......@@ -574,7 +574,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}
......
......@@ -648,7 +648,8 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}
......
......@@ -37,7 +37,8 @@ __global__ void
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
......@@ -703,7 +704,8 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}
......
......@@ -130,7 +130,8 @@ __global__ void
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......
......@@ -78,7 +78,8 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
......
......@@ -155,7 +155,8 @@ __global__ void
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
......@@ -810,7 +811,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
return false;
}
}
else if(get_device_name() == "gfx90a")
else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
......
......@@ -135,7 +135,8 @@ __global__ void
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -684,7 +685,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
return false;
}
}
else if(get_device_name() == "gfx90a")
else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
......
......@@ -38,7 +38,8 @@ __global__ void
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
......
......@@ -56,6 +56,12 @@ struct PassThrough
y = type_convert<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
{
y = type_convert<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
......@@ -86,6 +92,23 @@ struct UnaryConvert
}
};
struct ConvertBF16RTN
{
// convert to bf16 using round to nearest (rtn)
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(is_same<Y, bhalf_t>::value, "Data type is not supported by this operation!");
// check X datatype
static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
"Data type is not supported by this operation!");
y = bf16_convert_rtn<Y>(x);
}
};
struct Scale
{
__host__ __device__ Scale(float scale) : scale_(scale) {}
......
......@@ -587,4 +587,52 @@ struct OffsettedBlockToCTileMap
index_t block_start_;
};
/**
* @brief Simple tile mapping which creates 3D grid of block of threads.
*
* @paragraph Description
* This Block-to-C-tile-map creates a 3D grid (n_blocks, m_blocks, z_blocks) of thread
* blocks. The first 2D are regular 2D tiles created by division of output GEMM
* dimenions by corresponding tile size. The third dimension (Z) is a k-split dimension,
* which denotes the number of blocks we use to divide work on GEMM K dimension onto.
*
* @tparam MPerBlock Output block tile size in M dimension.
* @tparam NPerBlock Output block tile size in N dimension.
*/
template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_3DGrid_KSplit
{
__host__ __device__ BlockToCTileMap_3DGrid_KSplit() = default;
__host__ __device__ constexpr auto
CalculateGridSize(index_t M, index_t N, index_t k_split) const
{
// Create 3D grid
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return std::make_tuple(N0, M0, k_split);
}
template <typename TopIdx>
__device__ constexpr auto CalculateBottomIndex(const TopIdx&) const
{
return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
};
} // namespace ck
......@@ -66,7 +66,8 @@ __global__ void
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......
......@@ -96,7 +96,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// ABDataTypeAdjusted -> ABDataType throughout this file
#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#if CK_WORKAROUND_DENORM_FIX
using ABDataTypeAdjusted =
conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
#else
......
......@@ -54,7 +54,8 @@ __global__ void
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......
......@@ -44,7 +44,8 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......
......@@ -57,7 +57,8 @@ __global__ void
const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// TODO ANT: separate into MMA + Epilogue
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment