Unverified commit a4f72a31 authored by Adam Osewski, committed by GitHub

Grouped Gemm with looping over the tiles. (#788)



* Introduce LocalBlockToCTileMap.

* Change the signature of the CalculateBottomIndex() function so that it no
longer accepts any argument. The B2C map, which is already passed as an
argument to the kernel Run function, now computes the block's local ID
outside the kernel, at the __global__ entry-point function.
The LocalB2C map stores the local block ID as a member.

* Use LocalBlockToCTileMap in device ops.

* First draft of tile loop work distribution.

* Fix typo.

* Simplify kernel arguments.

Calculate descriptors & B2C maps on the device.

* Use looping kernel.

* Fix B2C constructor.

* Fix Navi21 errors.

* Calculate tile start/end in device kernel.

* Change Run API to accept user provided workspace buffer.

* Add new line at EOF.

* Move Gemm KernelArguments to device op interface.

* Remove unused code.

* Update API.

* Launch a grid whose size is the minimum of the occupancy-limited block count and the tile count (see the sketch after this list).

* Go back to using constant memory for GEMM descriptors.

* Remove unused code.

* Add default virtual method implementation.

* Update comments to conform with doxygen style.

* Fix doc style and unused parameters.

* Add thread cluster lengths to kernel name.

* Remove the old split-K implementation and replace it with the tile-looping one.

* Modify instances.

* Set KPerBlock to 64.
* Maximize the vector load size wherever possible.

* Fix instances cluster lengths.

* Change comment style.

* Use 128b store where possible in instances.

* Update test cases, since KPerBlock has doubled.

* Update output stream operator for Sequence.

* Add pipeline version to GroupedGEMM device op type string.

* Fix pipeline version type logging.

* Fix input tensors type after merge.

* Fix compiler error.

* Fix output stream operator for Pipeline version.

* Store using 128b.

* Add a set of instances with KPerBlock 32/64.

* Limit the number of instances.

* Remove commented out instances.

* Fix function name.

* Limit the number of instances.

Add the pipeline version to the regular instances.

* Change the thread cluster layout for reading the B tensor.

* Disable failing instances.
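
The tile-looping distribution referenced above, in a nutshell: size the grid to the smaller of the occupancy-limited block count and the total tile count, then let every block stride over the remaining tiles. The sketch below is illustrative HIP code, not the CK kernel; process_tile, grouped_gemm_tile_loop, and launch are hypothetical names.

// Illustrative sketch of tile-looping work distribution (not CK code).
#include <hip/hip_runtime.h>
#include <algorithm>

__device__ void process_tile(int tile_id) { /* per-tile GEMM work would go here */ }

__global__ void grouped_gemm_tile_loop(int total_tiles)
{
    // Each block walks over tiles with a stride equal to the grid size.
    for(int tile_id = blockIdx.x; tile_id < total_tiles; tile_id += gridDim.x)
    {
        process_tile(tile_id);
    }
}

void launch(int total_tiles, int block_size)
{
    int num_cu = 0, max_blocks_per_cu = 0;
    (void)hipDeviceGetAttribute(&num_cu, hipDeviceAttributeMultiprocessorCount, 0);
    (void)hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_blocks_per_cu, grouped_gemm_tile_loop, block_size, 0);

    // Grid size is min(occupancy, tile count), as described in the bullet above.
    const int grid_size = std::min(num_cu * max_blocks_per_cu, total_tiles);
    grouped_gemm_tile_loop<<<grid_size, block_size>>>(total_tiles);
}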

---------
Co-authored-by: Adam Osewski <aosewski@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Jing Zhang <jizha@amd.com>
parent 98c80714
...
@@ -8,6 +8,57 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
+///
+/// @brief Structure representing a single GEMM problem's arguments.
+///
+/// A pointer to a vector of these structures is passed
+/// to the GroupedGEMM entry-point kernel.
+///
+struct GroupedGemmKernelArguments
+{
+    __host__ __device__ GroupedGemmKernelArguments(const void* p_a_grid_,
+                                                   const void* p_b_grid_,
+                                                   void* p_c_grid_,
+                                                   index_t M_,
+                                                   index_t N_,
+                                                   index_t K_,
+                                                   index_t StrideA_,
+                                                   index_t StrideB_,
+                                                   index_t StrideC_)
+        : p_a_grid{p_a_grid_},
+          p_b_grid{p_b_grid_},
+          p_c_grid{p_c_grid_},
+          M{M_},
+          N{N_},
+          K{K_},
+          StrideA{StrideA_},
+          StrideB{StrideB_},
+          StrideC{StrideC_}
+    {
+    }
+
+    const void* p_a_grid;
+    const void* p_b_grid;
+    void* p_c_grid;
+    index_t M;
+    index_t N;
+    index_t K;
+    index_t StrideA;
+    index_t StrideB;
+    index_t StrideC;
+
+    void Print() const
+    {
+        std::cout << "arg {"
+                  << "M:" << M << ", "
+                  << "N:" << N << ", "
+                  << "K:" << K << ", "
+                  << "SA:" << StrideA << ", "
+                  << "SB:" << StrideB << ", "
+                  << "SC:" << StrideC << "}" << std::endl;
+    }
+};
+
 template <typename ALayout,
           typename BLayout,
           typename DsLayout,
@@ -31,7 +82,28 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout,
                                                           BElementwiseOperation,
                                                           CElementwiseOperation>
 {
-    virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
+    //----------------------------------------------------------------------------------------------
+    /// @brief Sets the k-batch size.
+    ///
+    /// @param      p_arg   Pointer to the Argument we're going to change.
+    /// @param[in]  kbatch  The kbatch value.
+    ///
+    virtual void SetKBatchSize([[maybe_unused]] BaseArgument* p_arg,
+                               [[maybe_unused]] index_t kbatch) const
+    {
+    }
+
+    //----------------------------------------------------------------------------------------------
+    /// @brief Sets the device kernel arguments pointer.
+    ///
+    /// @param      p_arg              Pointer to the Argument we're going to update.
+    /// @param[in]  p_dev_kernel_args  Pointer to the device memory that contains the kernel
+    ///                                arguments.
+    ///
+    virtual void SetDeviceKernelArgs([[maybe_unused]] BaseArgument* p_arg,
+                                     [[maybe_unused]] const void* p_dev_kernel_args) const
+    {
+    }
 };
 
 } // namespace device
...
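A possible host-side flow for the two additions above, shown only as a sketch: build one GroupedGemmKernelArguments per group, copy the vector to device memory, and hand the device pointer to the op through SetDeviceKernelArgs(). The helper name, the include path, and the surrounding op/argument objects are assumptions; error handling is omitted.

// Hypothetical helper: upload per-group GEMM descriptors and return the device pointer.
#include <hip/hip_runtime.h>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp" // path assumed

using ck::tensor_operation::device::GroupedGemmKernelArguments;

void* upload_kernel_args(const std::vector<GroupedGemmKernelArguments>& gemm_descs)
{
    void* p_dev_kernel_args = nullptr;
    const size_t nbytes = gemm_descs.size() * sizeof(GroupedGemmKernelArguments);
    (void)hipMalloc(&p_dev_kernel_args, nbytes);
    (void)hipMemcpy(p_dev_kernel_args, gemm_descs.data(), nbytes, hipMemcpyHostToDevice);
    return p_dev_kernel_args;
}

// Later, given a DeviceGroupedGemmSplitK-derived op 'gemm' and its argument 'p_arg':
//     gemm.SetDeviceKernelArgs(p_arg.get(), upload_kernel_args(gemm_descs));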
...
@@ -22,22 +22,22 @@ template <typename InDataType,
           index_t NumReduceDim>
 struct DeviceSoftmax : public BaseOperator
 {
-    //
-    // @brief      Makes a pointer to Argument class.
-    //
-    // @param[in]  inLengths           Input tensor extent(s) from high to low dimension
-    // @param[in]  inStrides           Input tensor stride(s) from high to low dimension
-    // @param[in]  reduceDims          The dimension(s) the normalization operation is applied
-    // @param[in]  alpha               double type value
-    // @param[in]  beta                double type value
-    // @param[in]  in_dev              Typeless const pointer in device memory storing the input
-    //                                 tensor
-    // @param      out_dev             Typeless pointer in device memory storing the output tensor
-    // @param[in]  in_elementwise_op   The input elementwise operation.
-    // @param[in]  acc_elementwise_op  The accumulation elementwise operation.
-    //
-    // @return     Unique pointer to the Argument class.
-    //
+    ///
+    /// @brief      Makes a pointer to Argument class.
+    ///
+    /// @param[in]  inLengths           Input tensor extent(s) from high to low dimension
+    /// @param[in]  inStrides           Input tensor stride(s) from high to low dimension
+    /// @param[in]  reduceDims          The dimension(s) the normalization operation is applied
+    /// @param[in]  alpha               double type value
+    /// @param[in]  beta                double type value
+    /// @param[in]  in_dev              Typeless const pointer in device memory storing the input
+    ///                                 tensor
+    /// @param      out_dev             Typeless pointer in device memory storing the output tensor
+    /// @param[in]  in_elementwise_op   The input elementwise operation.
+    /// @param[in]  acc_elementwise_op  The accumulation elementwise operation.
+    ///
+    /// @return     Unique pointer to the Argument class.
+    ///
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const std::vector<index_t> inLengths,
                         const std::vector<index_t> inStrides,
...
...
@@ -168,7 +168,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                     stream_config.stream_id_));
 
                 ave_time = launch_and_time_kernel(
-                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map);
+                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
             };
 
             if(has_main_k0_block_loop)
...
...
@@ -157,22 +157,22 @@ __global__ void
     }
 }
 
 } // namespace
 
-//
-// @brief Device Convolution operation.
-//
-// Supports:
-// @li         Forward convolution with up to 3 spatial dimensions
-// @li         Input tensor in GNWC data format
-// @li         Weight tensor in GKXC data format
-// @li         Output tensor in GNWK data format
-//
-// 1D:
-// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
-// 2D:
-// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
-// 3D:
-// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
-//
+///
+/// @brief Device Convolution operation.
+///
+/// Supports:
+/// @li         Forward convolution with up to 3 spatial dimensions
+/// @li         Input tensor in GNWC data format
+/// @li         Weight tensor in GKXC data format
+/// @li         Output tensor in GNWK data format
+///
+/// 1D:
+/// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
+/// 2D:
+/// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
+/// 3D:
+/// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
+///
 template <index_t NDimSpatial,
           typename ADataType,
           typename BDataType,
...
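The 1D formula in the comment above corresponds to a short reference loop. The following is a self-contained, illustrative implementation (stride 1, no padding, single group), not the CK reference operator:

// Naive 1D reference for: out[N, Wo, K] = sum over X, C of in[N, Wo + X, C] * wei[K, X, C].
#include <vector>

std::vector<float> conv1d_nwc_kxc_nwk(const std::vector<float>& in,  // [N, Wi, C]
                                      const std::vector<float>& wei, // [K, X, C]
                                      int N, int Wi, int C, int K, int X)
{
    const int Wo = Wi - X + 1; // stride 1, no padding
    std::vector<float> out(static_cast<size_t>(N) * Wo * K, 0.f); // [N, Wo, K]

    for(int n = 0; n < N; ++n)
        for(int wo = 0; wo < Wo; ++wo)
            for(int k = 0; k < K; ++k)
                for(int x = 0; x < X; ++x)
                    for(int c = 0; c < C; ++c)
                        out[(n * Wo + wo) * K + k] +=
                            in[(n * Wi + (wo + x)) * C + c] * wei[(k * X + x) * C + c];
    return out;
}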
...
@@ -154,22 +154,22 @@ __global__ void
 
 } // namespace
 
-//
-// @brief Device Convolution operation.
-//
-// Supports:
-// @li         Forward convolution with up to 3 spatial dimensions
-// @li         Input tensor in GNWC data format
-// @li         Weight tensor in GKXC data format
-// @li         Output tensor in GNWK data format
-//
-// 1D:
-// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
-// 2D:
-// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
-// 3D:
-// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
-//
+///
+/// @brief Device Convolution operation.
+///
+/// Supports:
+/// @li         Forward convolution with up to 3 spatial dimensions
+/// @li         Input tensor in GNWC data format
+/// @li         Weight tensor in GKXC data format
+/// @li         Output tensor in GNWK data format
+///
+/// 1D:
+/// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
+/// 2D:
+/// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
+/// 3D:
+/// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
+///
 template <
     index_t NDimSpatial,
     typename ADataType,
...
...
@@ -150,22 +150,22 @@ __global__ void
 
 } // namespace
 
-//
-// @brief Device Convolution operation.
-//
-// Supports:
-// @li         Forward convolution with up to 3 spatial dimensions
-// @li         Input tensor in GNWC data format
-// @li         Weight tensor in GKXC data format
-// @li         Output tensor in GNWK data format
-//
-// 1D:
-// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
-// 2D:
-// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
-// 3D:
-// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
-//
+///
+/// @brief Device Convolution operation.
+///
+/// Supports:
+/// @li         Forward convolution with up to 3 spatial dimensions
+/// @li         Input tensor in GNWC data format
+/// @li         Weight tensor in GKXC data format
+/// @li         Output tensor in GNWK data format
+///
+/// 1D:
+/// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
+/// 2D:
+/// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
+/// 3D:
+/// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
+///
 template <index_t NDimSpatial,
           typename ALayout,
           typename BLayout,
...
...
@@ -348,24 +348,24 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
                         acc_elementwise_op};
     };
 
-    //
-    // @brief      Makes a pointer to Argument class.
-    //
-    // @param[in]  inLengths           Input tensor extent(s) from high to low dimension
-    // @param[in]  inStrides           Input tensor stride(s) from high to low dimension
-    // @param[in]  reduceDims          The dimension(s) the normalization operation is applied
-    // @param[in]  alpha               Typeless pointer in host memory storing the alpha scaling
-    //                                 value as type AccDataType
-    // @param[in]  beta                Typeless pointer in host memory storing the beta scaling
-    //                                 value as type AccDataType
-    // @param[in]  in_dev              Typeless const pointer in device memory storing the input
-    //                                 tensor
-    // @param      out_dev             Typeless pointer in device memory storing the output tensor
-    // @param[in]  in_elementwise_op   The input elementwise operation.
-    // @param[in]  acc_elementwise_op  The accumulation elementwise operation.
-    //
-    // @return     Unique pointer to the Argument class.
-    //
+    ///
+    /// @brief      Makes a pointer to Argument class.
+    ///
+    /// @param[in]  inLengths           Input tensor extent(s) from high to low dimension
+    /// @param[in]  inStrides           Input tensor stride(s) from high to low dimension
+    /// @param[in]  reduceDims          The dimension(s) the normalization operation is applied
+    /// @param[in]  alpha               Typeless pointer in host memory storing the alpha scaling
+    ///                                 value as type AccDataType
+    /// @param[in]  beta                Typeless pointer in host memory storing the beta scaling
+    ///                                 value as type AccDataType
+    /// @param[in]  in_dev              Typeless const pointer in device memory storing the input
+    ///                                 tensor
+    /// @param      out_dev             Typeless pointer in device memory storing the output tensor
+    /// @param[in]  in_elementwise_op   The input elementwise operation.
+    /// @param[in]  acc_elementwise_op  The accumulation elementwise operation.
+    ///
+    /// @return     Unique pointer to the Argument class.
+    ///
     std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
                                                       const std::vector<index_t> inStrides,
                                                       const std::vector<int> reduceDims,
...
...
@@ -271,7 +271,8 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
     {
     }
 
-    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
+    __host__ __device__ constexpr index_t
+    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
     {
         const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
         const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
@@ -624,24 +625,36 @@ struct OffsettedBlockToCTileMap
     index_t block_start_;
 };
 
-/**
- * @brief Simple tile mapping which creates 3D grid of block of threads.
- *
- * @paragraph Description
- *            This Block-to-C-tile-map creates a 3D grid (n_blocks, m_blocks, z_blocks) of thread
- *            blocks. The first 2D are regular 2D tiles created by division of output GEMM
- *            dimensions by corresponding tile size. The third dimension (Z) is a k-split dimension,
- *            which denotes the number of blocks we use to divide work on GEMM K dimension onto.
- *
- * @tparam MPerBlock Output block tile size in M dimension.
- * @tparam NPerBlock Output block tile size in N dimension.
- */
+///
+/// @brief Simple tile mapping which creates 3D grid of block of threads.
+///
+/// @paragraph Description
+///            This Block-to-C-tile-map creates a 3D grid (n_blocks, m_blocks, z_blocks) of thread
+///            blocks. The first 2D are regular 2D tiles created by division of output GEMM
+///            dimensions by corresponding tile size. The third dimension (Z) is a k-split
+///            dimension, which denotes the number of blocks we use to divide work on GEMM K
+///            dimension onto.
+///
+/// @tparam MPerBlock Output block tile size in M dimension.
+/// @tparam NPerBlock Output block tile size in N dimension.
+///
 template <index_t MPerBlock, index_t NPerBlock>
 struct BlockToCTileMap_3DGrid_KSplit
 {
     __host__ __device__ BlockToCTileMap_3DGrid_KSplit() = default;
 
+    ///
+    /// @brief Constructs a new instance.
+    ///
+    /// @param[in]  top_idx  Swallows blockIdx (unused).
+    ///
+    /// @tparam     TopIdx   The type of block index.
+    ///
+    template <typename TopIdx>
+    __host__ __device__ BlockToCTileMap_3DGrid_KSplit([[maybe_unused]] TopIdx top_idx)
+    {
+    }
+
     __host__ __device__ constexpr auto
     CalculateGridSize(index_t M, index_t N, index_t k_split) const
     {
@@ -652,8 +665,7 @@ struct BlockToCTileMap_3DGrid_KSplit
         return std::make_tuple(N0, M0, k_split);
     }
 
-    template <typename TopIdx>
-    __device__ constexpr auto CalculateBottomIndex(const TopIdx&) const
+    __device__ constexpr auto CalculateBottomIndex() const
     {
         return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x);
     }
@@ -672,6 +684,53 @@ struct BlockToCTileMap_3DGrid_KSplit
     }
 };
 
+///
+/// @brief Block-to-CTile map which lets an external mechanism set up the local block id.
+///
+/// For example, this type can easily be used to implement a tile-looping work
+/// distribution scheme.
+///
+/// @tparam UnderlyingBlockToCTileMap The type of the local tile map.
+///
+template <typename UnderlyingBlockToCTileMap>
+struct LocalBlockToCTileMap
+{
+    using underlying_type = UnderlyingBlockToCTileMap;
+
+    __host__ __device__ LocalBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
+                                             index_t local_id)
+        : block_to_ctile_map_{block_to_ctile_map}, local_block_id_{local_id}
+    {
+    }
+
+    __host__ __device__ constexpr auto CalculateBottomIndex() const
+    {
+        return block_to_ctile_map_.CalculateBottomIndex(make_multi_index(local_block_id_));
+    }
+
+    template <typename CTileIdx, typename CTileDim>
+    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
+                                             const CTileDim& c_tile_dim) const
+    {
+        return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    {
+        return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
+    {
+        return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
+    }
+
+    UnderlyingBlockToCTileMap block_to_ctile_map_;
+    index_t local_block_id_;
+};
+
 enum StreamKReductionStrategy
 {
     Atomic = 0, // sk block use atomic to do reduction
...
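The intent of LocalBlockToCTileMap can be illustrated with a small standalone analogue: the block id is pinned at construction, so CalculateBottomIndex() needs no argument and the id can come from a tile-looping counter instead of blockIdx. TileMap2D and LocalMap below are toy stand-ins, not CK types.

#include <array>
#include <iostream>

// Toy underlying map: linear block id -> (m0, n0) output-tile coordinates.
struct TileMap2D
{
    int n_tiles_n; // number of tiles along N
    std::array<int, 2> CalculateBottomIndex(int block_id) const
    {
        return {block_id / n_tiles_n, block_id % n_tiles_n};
    }
};

// Same pattern as LocalBlockToCTileMap: the block id is fixed at construction,
// so the lookup takes no argument at the call site.
template <typename Underlying>
struct LocalMap
{
    Underlying map;
    int local_block_id;

    std::array<int, 2> CalculateBottomIndex() const
    {
        return map.CalculateBottomIndex(local_block_id);
    }
};

int main()
{
    TileMap2D map{/*n_tiles_n=*/4};
    LocalMap<TileMap2D> local{map, /*local_block_id=*/6};
    const auto idx = local.CalculateBottomIndex();
    std::cout << "tile (m0, n0) = (" << idx[0] << ", " << idx[1] << ")\n"; // prints (1, 2)
}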
...
@@ -4,6 +4,8 @@
 #pragma once
 
 #include <iostream>
+#include <ostream>
+#include <string>
 
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
@@ -42,4 +44,20 @@ constexpr auto GridwiseGemmPipeline_Selector()
     }
 }
 
+inline std::string getPipelineVersionString(const PipelineVersion& pv)
+{
+    switch(pv)
+    {
+    case PipelineVersion::v1: return "PipelineVersion::v1";
+    case PipelineVersion::v2: return "PipelineVersion::v2";
+    default: return "Unrecognized pipeline version!";
+    }
+}
+
 } // namespace ck
+
+inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion pv)
+{
+    os << ck::getPipelineVersionString(pv);
+    return os;
+}
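Minimal usage of the new stream operator; the include path is assumed:

// Prints "PipelineVersion::v1"; unknown values map to "Unrecognized pipeline version!".
#include <iostream>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" // path assumed

int main() { std::cout << ck::PipelineVersion::v1 << std::endl; }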
...
@@ -27,8 +27,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
-                                         const Block2CTileMap& b2c_map)
+    kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
@@ -36,11 +35,12 @@ __global__ void
 
     __shared__ uint8_t p_shared[shared_size];
 
+    Block2CTileMap b2c_map{get_block_1d_id()};
+
     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
         karg, static_cast<void*>(p_shared), b2c_map);
 #else
     ignore = karg;
-    ignore = b2c_map;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 
@@ -541,15 +541,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
     }
 
-    // return block_id to C matrix tile idx (m0, n0) mapping
-    template <typename CGridDesc>
-    __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
-        const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
-    {
-        return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc>(
-            c_m_n_grid_desc, 8, KBatch);
-    }
-
     __host__ __device__ static constexpr auto
     GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
     {
@@ -575,18 +566,28 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     template <bool HasMainKBlockLoop,
               InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              typename Block2CTileMap>
-    __device__ static void Run(const Argument& karg,
+    __device__ static void Run(const FloatA* p_a_grid,
+                               const FloatB* p_b_grid,
+                               FloatC* p_c_grid,
+                               index_t M,
+                               index_t N,
+                               index_t K,
+                               index_t StrideA,
+                               index_t StrideB,
+                               index_t StrideC,
+                               index_t MPadded,
+                               index_t NPadded,
+                               index_t KPadded,
+                               index_t K0,
+                               index_t k_batch,
                                void* __restrict__ p_shared_block,
                                const Block2CTileMap& block_2_ctile_map)
     {
-        const FloatA* p_a_grid = karg.p_a_grid;
-        const FloatB* p_b_grid = karg.p_b_grid;
-        FloatC* p_c_grid = karg.p_c_grid;
-        const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
-            karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded);
-        const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
-            karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded);
-        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
+        const auto a_b_k0_m_k1_grid_desc =
+            MakeAGridDescriptor_KBatch_K0_M_K1(M, MPadded, K, StrideA, k_batch, K0, KPadded);
+        const auto b_b_k0_n_k1_grid_desc =
+            MakeBGridDescriptor_KBatch_K0_N_K1(K, NPadded, N, StrideB, k_batch, K0, KPadded);
+        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(M, N, StrideC);
 
         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
             MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
@@ -602,8 +603,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
         // divide block work by [KBatch, M, N]
-        const auto block_work_idx =
-            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+        const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex();
 
         if(!block_2_ctile_map.ValidCTileIndex(
                block_work_idx,
@@ -1010,6 +1010,34 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         }
     }
 
+    template <bool HasMainKBlockLoop,
+              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+              typename Block2CTileMap>
+    __device__ static void Run(const Argument& karg,
+                               void* __restrict__ p_shared_block,
+                               const Block2CTileMap& block_2_ctile_map)
+    {
+        Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, Block2CTileMap>(karg.p_a_grid,
+                                                                           karg.p_b_grid,
+                                                                           karg.p_c_grid,
+                                                                           karg.M,
+                                                                           karg.N,
+                                                                           karg.K,
+                                                                           karg.StrideA,
+                                                                           karg.StrideB,
+                                                                           karg.StrideC,
+                                                                           karg.MPadded,
+                                                                           karg.NPadded,
+                                                                           karg.KPadded,
+                                                                           karg.K0,
+                                                                           karg.k_batch,
+                                                                           p_shared_block,
+                                                                           block_2_ctile_map);
+    }
+
+    static constexpr auto GetMPerBlock() { return MPerBlock; }
+    static constexpr auto GetNPerBlock() { return NPerBlock; }
+
     static std::string GetTypeString()
     {
         auto str = std::stringstream();
...
...
@@ -897,3 +897,14 @@ template <index_t NSize, index_t I>
 using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
 
 } // namespace ck
+
+template <ck::index_t... Is>
+std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
+{
+    using S = ck::Sequence<Is...>;
+    os << "{";
+    ck::static_for<0, S::Size() - ck::Number<1>{}, 1>{}(
+        [&](auto i) { os << S::At(i).value << ", "; });
+    os << S::At(S::Size() - ck::Number<1>{}).value << "}";
+    return os;
+}
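Minimal usage of the Sequence stream operator, e.g. when logging thread cluster lengths; the include path is assumed:

// Prints "{4, 64, 1}".
#include <iostream>
#include "ck/utility/sequence.hpp" // path assumed

int main() { std::cout << ck::Sequence<4, 64, 1>{} << std::endl; }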
...
@@ -14,27 +14,27 @@ namespace ck {
 namespace tensor_operation {
 namespace host {
 
-//
-// @brief      Reference implementation for forward convolution.
-//
-// @paragraph
-//             Tensor descriptor in GNCHW/GKCXY/GNKHW dimensional order
-//             Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
-//             as long as dimensions in tensor descriptor is in GNCHW order
-//
-// @tparam     InDataType               Input tensor data type.
-// @tparam     WeiDataType              Weights tensor data type.
-// @tparam     OutDataType              Output tensor data type.
-// @tparam     InElementwiseOperation   Functor for input tensor elementwise
-//                                      operation.
-// @tparam     WeiElementwiseOperation  Functor for weights tensor elementwise
-//                                      operation.
-// @tparam     NDimSpatial              Number of spatial dimensions.
-//
-// input descriptor in [G, N, C, Do, Ho, Wo] order
-// weight descriptor in [G, K, C, Z, Y, X] order
-// output descriptor in [G, N, K, Di, Hi, Wi] order
-// physical layout is irrelevant
+///
+/// @brief      Reference implementation for forward convolution.
+///
+/// @paragraph
+///             Tensor descriptor in GNCHW/GKCXY/GNKHW dimensional order
+///             Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
+///             as long as dimensions in tensor descriptor is in GNCHW order
+///
+/// @tparam     InDataType               Input tensor data type.
+/// @tparam     WeiDataType              Weights tensor data type.
+/// @tparam     OutDataType              Output tensor data type.
+/// @tparam     InElementwiseOperation   Functor for input tensor elementwise
+///                                      operation.
+/// @tparam     WeiElementwiseOperation  Functor for weights tensor elementwise
+///                                      operation.
+/// @tparam     NDimSpatial              Number of spatial dimensions.
+///
+/// input descriptor in [G, N, C, Do, Ho, Wo] order
+/// weight descriptor in [G, K, C, Z, Y, X] order
+/// output descriptor in [G, N, K, Di, Hi, Wi] order
+/// physical layout is irrelevant
 template <ck::index_t NDimSpatial,
           typename InDataType,
           typename WeiDataType,
...
...
@@ -12,7 +12,9 @@ cmake
      -save-temps=$PWD" \
 -D CMAKE_BUILD_TYPE=Release \
 -D BUILD_DEV=ON \
--D GPU_TARGETS="gfx908;gfx90a;gfx940" \
+-D GPU_TARGETS="gfx90a" \
 -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
 -D USE_BITINT_EXTENSION_INT4=OFF \
 ${MY_PROJECT_SOURCE}
+
+#-D GPU_TARGETS="gfx908;gfx90a;gfx940" \