Commit 5e6cca6f authored by carlushuang

Merge remote-tracking branch 'origin/develop' into cpu_avx2

parents afc7d431 3956085d
@@ -42,7 +42,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
     libnuma-dev \
     libpthread-stubs0-dev \
     llvm-amdgpu \
-    miopengemm \
     pkg-config \
     python \
     python3 \
@@ -51,19 +50,15 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
     python-pip \
     python3-pip \
     software-properties-common \
-    sqlite3 \
     wget \
     rocm-dev \
     rocm-device-libs \
-    rocm-opencl \
-    rocm-opencl-dev \
     rocm-cmake \
-    rocblas \
     vim \
     zlib1g-dev \
     openssh-server \
-    kmod \
-    mysql-client && \
+    clang-format-10 \
+    kmod && \
     apt-get clean && \
     rm -rf /var/lib/apt/lists/*
...
@@ -204,7 +204,7 @@ pipeline {
         stage('Clang Format') {
             agent{ label rocmnode("nogpu") }
             environment{
-                execute_cmd = "find . -iname \'*.h\' \
+                execute_cmd = "find .. -iname \'*.h\' \
                                 -o -iname \'*.hpp\' \
                                 -o -iname \'*.cpp\' \
                                 -o -iname \'*.h.in\' \
...
 add_example_executable(example_conv2d_fwd_xdl_bias_relu conv2d_fwd_xdl_bias_relu.cpp)
+target_link_libraries(example_conv2d_fwd_xdl_bias_relu PRIVATE conv_fwd_util)
 add_example_executable(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp)
+target_link_libraries(example_conv2d_fwd_xdl_bias_relu_add PRIVATE conv_fwd_util)
 add_example_executable(example_convnd_fwd_xdl convnd_fwd_xdl.cpp)
+target_link_libraries(example_convnd_fwd_xdl PRIVATE conv_fwd_util)
 add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
+target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_fwd_util)
 add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
+target_link_libraries(example_convnd_fwd_xdl_fp16 PRIVATE conv_fwd_util)
 add_example_executable(example_conv2d_bwd_data_xdl conv2d_bwd_data_xdl.cpp)
+target_link_libraries(example_conv2d_bwd_data_xdl PRIVATE conv_fwd_util)
 add_example_executable(example_conv2d_bwd_weight_xdl conv2d_bwd_weight_xdl.cpp)
+target_link_libraries(example_conv2d_bwd_weight_xdl PRIVATE conv_fwd_util)
@@ -72,8 +72,13 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::
                                                                   8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
 // clang-format on
-using ReferenceConvBwdWeightInstance = ck::tensor_operation::host::
-    ReferenceConvBwdWeight<InDataType, WeiDataType, OutDataType, InElementOp, WeiElementOp, OutElementOp>;
+using ReferenceConvBwdWeightInstance =
+    ck::tensor_operation::host::ReferenceConvBwdWeight<InDataType,
+                                                       WeiDataType,
+                                                       OutDataType,
+                                                       InElementOp,
+                                                       WeiElementOp,
+                                                       OutElementOp>;
 int main(int argc, char* argv[])
 {
...
@@ -3,7 +3,6 @@
 #include <initializer_list>
 #include <cstdlib>
 #include <getopt.h>
-#include <half.hpp>
 #include "check_err.hpp"
 #include "config.hpp"
@@ -27,10 +26,6 @@ using InDataType = ck::half_t;
 using OutDataType = ck::half_t;
 using AccDataType = float;
-using HostInDataType  = half_float::half;
-using HostOutDataType = half_float::half;
-using HostAccDataType = float;
 constexpr int Rank         = 4;
 constexpr int NumReduceDim = 3;
@@ -306,9 +301,9 @@ int main(int argc, char* argv[])
     if(args.do_verification)
     {
-        ReductionHost<HostInDataType,
-                      HostAccDataType,
-                      HostOutDataType,
+        ReductionHost<InDataType,
+                      AccDataType,
+                      OutDataType,
                       ReduceOpId,
                       Rank,
                       NumReduceDim,
@@ -316,11 +311,8 @@ int main(int argc, char* argv[])
                       NeedIndices>
             hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
-        hostReduce.Run(alpha,
-                       reinterpret_cast<const HostInDataType*>(in.mData.data()),
-                       beta,
-                       reinterpret_cast<HostOutDataType*>(out_ref.mData.data()),
-                       out_indices_ref.mData.data());
+        hostReduce.Run(
+            alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data());
     };
     const auto i_inLengths = to_int_vector(args.inLengths);
...
 add_example_executable(example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp)
+target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_fwd_util)
@@ -39,6 +39,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
+    static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
@@ -71,7 +73,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
-        return make_tuple(0, waveId_m, xdlops_a_idx[I1], Number<KPack>{} * xdlops_a_idx[I0]);
+        return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
     }
     __device__ static auto CalculateBThreadOriginDataIndex()
@@ -82,7 +84,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
-        return make_tuple(0, waveId_n, xdlops_b_idx[I1], Number<KPack>{} * xdlops_b_idx[I0]);
+        return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
     }
     template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
@@ -273,7 +275,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                 make_tuple(I0, I0, I0, I0),
                 b_thread_buf);
-            static_for<0, KPerBlock, KPack * xdlops_gemm.K0PerXdlops>{}([&](auto k) {
+            static_for<0, KPerThread, KPack>{}([&](auto k) {
                 vector_type<FloatAB, KPack> a_thread_vec;
                 vector_type<FloatAB, KPack> b_thread_vec;
@@ -300,13 +302,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     }
     private:
-    // A[M0, M1, M2, KPerBlock]
+    // A[M0, M1, M2, KPerThread]
     static constexpr auto a_thread_desc_ =
-        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerBlock>{}));
+        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
-    // B[N0, N1, N2, KPerBlock]
+    // B[N0, N1, N2, KPerThread]
     static constexpr auto b_thread_desc_ =
-        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerBlock>{}));
+        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
     // C[M, N, NumRegXdlops]
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
@@ -316,7 +318,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         FloatAB,
         decltype(a_block_desc_m0_m1_m2_k),
         decltype(a_thread_desc_),
-        Sequence<1, 1, 1, KPerBlock>,
+        Sequence<1, 1, 1, KPerThread>,
         Sequence<0, 1, 2, 3>,
         3,
         A_K1,
@@ -326,7 +328,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         FloatAB,
         decltype(b_block_desc_n0_n1_n2_k),
         decltype(b_thread_desc_),
-        Sequence<1, 1, 1, KPerBlock>,
+        Sequence<1, 1, 1, KPerThread>,
         Sequence<0, 1, 2, 3>,
         3,
         B_K1,
...
@@ -54,6 +54,7 @@ __global__ void
     const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
     const Block2CTileMap block_2_ctile_map)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
@@ -88,6 +89,25 @@ __global__ void
         c_grid_desc_mblock_mperblock_nblock_nperblock,
         d_grid_desc_mblock_mperblock,
         block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = p_d0_grid;
+    ignore = p_d1_grid;
+    ignore = batch_count;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = d0_reduce_op;
+    ignore = d1_reduce_op;
+    ignore = a_grid_desc_ak0_m_ak1;
+    ignore = b_grid_desc_bk0_n_bk1;
+    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = d_grid_desc_mblock_mperblock;
+    ignore = compute_base_ptr_of_batch_;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 template <typename ALayout,
...
@@ -16,6 +16,31 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
+/*
+ * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
+ *
+ * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of the A, B, C
+ * matrices given the batch index. For example, ComputePtrOffsetOfStridedBatch() computes the
+ * offsets of evenly strided batches, but we can easily extend to other layouts. The returned
+ * offset can be either \p index_t or \p long_index_t. If it returns \p long_index_t, we are not
+ * subject to the 2GB limitation.
+ *
+ * \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in the id of a workgroup and
+ * returns the 2D index of the tile that it computes. \see
+ * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
+ *
+ * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
+ * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
+ * descriptor (like in BatchedGEMM), or use their own grid descriptors (as in GroupedGemm). \link
+ * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
+ * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
+ * the pointer offset into \p ComputePtrOffsetOfStridedBatch.
+ *
+ * \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
+ * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion) to
+ * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
+ *
+ */
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
@@ -25,7 +50,7 @@ template <typename GridwiseGemm,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          typename ComputeBasePrtOfBatch,
+          typename ComputePtrOffsetOfBatch,
           typename Block2CTileMap,
           bool HasMainKBlockLoop>
 __global__ void
@@ -43,19 +68,20 @@ __global__ void
     const AElementwiseOperation a_element_op,
     const BElementwiseOperation b_element_op,
     const CElementwiseOperation c_element_op,
-    const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
+    const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
     const Block2CTileMap block_2_ctile_map)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
     const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -70,6 +96,20 @@ __global__ void
         b_element_op,
         c_element_op,
         block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = batch_count;
+    ignore = a_grid_desc_k0_m_k1;
+    ignore = b_grid_desc_k0_n_k1;
+    ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = compute_ptr_offset_of_batch;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 template <typename ADataType,
@@ -241,26 +281,26 @@ struct DeviceBatchedGemmXdl
         return globalblockid_to_m0_n0_block_cluster_adaptor;
     }
-    struct ComputeBasePtrOfStridedBatch
+    struct ComputePtrOffsetOfStridedBatch
     {
-        ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
-                                     index_t BatchStrideB,
-                                     index_t BatchStrideC)
+        ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
+                                       index_t BatchStrideB,
+                                       index_t BatchStrideC)
             : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
         {
         }
-        __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
+        __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
         {
             return g_idx * static_cast<long_index_t>(BatchStrideA_);
         }
-        __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
+        __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
         {
             return g_idx * static_cast<long_index_t>(BatchStrideB_);
         }
-        __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
+        __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
        {
             return g_idx * static_cast<long_index_t>(BatchStrideC_);
         }
@@ -344,9 +384,9 @@ struct DeviceBatchedGemmXdl
               DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
               c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
               c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
-              compute_base_ptr_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
-                                         b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
-                                         c_grid_desc_m_n_.GetElementSpaceSize()},
+              compute_ptr_offset_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
+                                           b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
+                                           c_grid_desc_m_n_.GetElementSpaceSize()},
               block_2_ctile_map_{},
               M01_{M01},
               N01_{N01},
@@ -373,7 +413,7 @@ struct DeviceBatchedGemmXdl
         BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
         CGridDesc_M_N c_grid_desc_m_n_;
         CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
-        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
+        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
         Block2CTileMap block_2_ctile_map_;
         index_t M01_;
         index_t N01_;
@@ -433,7 +473,7 @@ struct DeviceBatchedGemmXdl
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    ComputeBasePtrOfStridedBatch,
+                    ComputePtrOffsetOfStridedBatch,
                     remove_reference_t<Block2CTileMap>,
                     true>;
@@ -452,7 +492,7 @@ struct DeviceBatchedGemmXdl
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_,
-                    arg.compute_base_ptr_of_batch_,
+                    arg.compute_ptr_offset_of_batch_,
                     arg.block_2_ctile_map_);
             }
             else
@@ -467,7 +507,7 @@ struct DeviceBatchedGemmXdl
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    ComputeBasePtrOfStridedBatch,
+                    ComputePtrOffsetOfStridedBatch,
                     remove_reference_t<Block2CTileMap>,
                     false>;
@@ -486,7 +526,7 @@ struct DeviceBatchedGemmXdl
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_,
-                    arg.compute_base_ptr_of_batch_,
+                    arg.compute_ptr_offset_of_batch_,
                     arg.block_2_ctile_map_);
             }
...
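The doxygen comment added above notes that ComputePtrOffsetOfBatch only has to map a batch index to A/B/C pointer offsets, so offset schemes other than an even stride can be plugged in. Below is a minimal sketch of such a functor, not part of this commit: the name ComputePtrOffsetOfBatchGather and the per-batch offset tables are hypothetical, and the tables are assumed to live in device-visible memory.

// Illustrative sketch only: a ComputePtrOffsetOfBatch-style functor that gathers
// per-batch offsets from tables instead of assuming one even stride.
// Returning long_index_t keeps offsets valid for batched buffers larger than 2GB.
struct ComputePtrOffsetOfBatchGather
{
    ComputePtrOffsetOfBatchGather(const long_index_t* a_offsets,
                                  const long_index_t* b_offsets,
                                  const long_index_t* c_offsets)
        : a_offsets_(a_offsets), b_offsets_(b_offsets), c_offsets_(c_offsets)
    {
    }

    // g_idx is the batch index the kernel derives from the workgroup id.
    __host__ __device__ long_index_t GetAPtrOffset(index_t g_idx) const { return a_offsets_[g_idx]; }
    __host__ __device__ long_index_t GetBPtrOffset(index_t g_idx) const { return b_offsets_[g_idx]; }
    __host__ __device__ long_index_t GetCPtrOffset(index_t g_idx) const { return c_offsets_[g_idx]; }

    const long_index_t* a_offsets_; // per-batch element offsets into the A buffer
    const long_index_t* b_offsets_; // per-batch element offsets into the B buffer
    const long_index_t* c_offsets_; // per-batch element offsets into the C buffer
};

A kernel instantiated with a type like this would compute a_batch_offset exactly as kernel_batched_gemm_xdlops_v2r3 does above; only the source of the offsets changes from g_idx * BatchStride to a table lookup.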
@@ -18,6 +18,9 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
+/*
+ * \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3() \endlink.
+ */
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
@@ -49,6 +52,7 @@ __global__ void
     const CElementwiseOperation c_element_op,
     const Block2CTileMap block_2_ctile_map)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
@@ -73,6 +77,23 @@ __global__ void
         b_element_op,
         c_element_op,
         block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = num_batches;
+    ignore = a_batch_stride;
+    ignore = b_batch_stride;
+    ignore = c_batch_stride;
+    ignore = a_grid_desc_k0_m_k1;
+    ignore = b_grid_desc_k0_n_k1;
+    ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 // specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
...
@@ -48,6 +48,7 @@ __global__ void
     const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
     const Block2CTileMap block_2_ctile_map)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
@@ -66,6 +67,23 @@ __global__ void
         c_grid_desc_mblock_mperblock_nblock_nperblock,
         d_grid_desc_mblock_mperblock,
         block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = p_d0_grid;
+    ignore = p_d1_grid;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = d0_reduce_op;
+    ignore = d1_reduce_op;
+    ignore = a_grid_desc_ak0_m_ak1;
+    ignore = b_grid_desc_bk0_n_bk1;
+    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = d_grid_desc_mblock_mperblock;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 template <typename FloatAB,
...
@@ -38,6 +38,7 @@ __global__ void
         c_grid_desc_mblock_mperblock_nblock_nperblock,
     const Block2CTileMap block_2_ctile_map)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
@@ -51,6 +52,18 @@ __global__ void
         b_grid_desc_bk0_n_bk1,
         c_grid_desc_mblock_mperblock_nblock_nperblock,
         block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = a_grid_desc_ak0_m_ak1;
+    ignore = b_grid_desc_bk0_n_bk1;
+    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 template <typename FloatAB,
...
@@ -39,6 +39,7 @@ __global__ void
     const CElementwiseOperation c_element_op,
     const Block2CTileMap block_2_ctile_map)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
@@ -52,6 +53,18 @@ __global__ void
         b_element_op,
         c_element_op,
         block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = a_grid_desc_k0_m_k1;
+    ignore = b_grid_desc_k0_n_k1;
+    ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 template <typename GridwiseGemm,
@@ -74,6 +87,7 @@ __global__ void
     const BElementwiseOperation b_element_op,
     const CElementwiseOperation c_element_op)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t block_id = get_block_1d_id();
@@ -126,6 +140,13 @@ __global__ void
         gemm_desc_ptr[group_id].block_2_ctile_map_,
         block_id_grp);
 #endif
+#else
+    ignore = gemm_desc_;
+    ignore = group_count;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 template <index_t BlockSize,
...
@@ -37,6 +37,7 @@ __global__ void
     const CElementwiseOperation c_element_op,
    const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -53,6 +54,18 @@ __global__ void
         b_element_op,
         c_element_op,
         c_block_cluster_adaptor);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = a_b_k0_m_k1_grid_desc;
+    ignore = b_b_k0_n_k1_grid_desc;
+    ignore = c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = c_block_cluster_adaptor;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 template <index_t BlockSize,
...
@@ -39,6 +39,7 @@ __global__ void
     const CElementwiseOperation c_element_op,
     const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -55,6 +56,18 @@ __global__ void
         b_element_op,
         c_element_op,
         c_block_cluster_adaptor);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = a_b_k0_m_k1_grid_desc;
+    ignore = b_b_k0_n_k1_grid_desc;
+    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = c_block_cluster_adaptor;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 template <index_t BlockSize,
...
@@ -42,6 +42,7 @@ __global__ void
     const CElementwiseOperation c_element_op,
     const Block2CTileMap block_2_ctile_map)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainK0BlockLoop>(
@@ -56,6 +57,18 @@ __global__ void
         b_element_op,
         c_element_op,
         block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = a_grid_desc_ak0_m_ak1;
+    ignore = b_grid_desc_bk0_n_bk1;
+    ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 template <
...