"vscode:/vscode.git/clone" did not exist on "643ea61a98d1f98850bb9da83caec3fe5026a4d8"
Commit b79df771 authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into cpu_avx2

parents 05d38218 63914743
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "ck/utility/common_header.hpp"
#include "tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "cluster_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "threadwise_tensor_slice_transfer_v6r3.hpp" #include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp"
namespace ck { namespace ck {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp"
namespace ck {
// Thread-group level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
// 2. Same VectorDim and ScalerPerVector for all sources and destinations
// 3. DstInMemOps are per destination tensor
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
//
// Does following things to avoid scratch memory issue
// 1. Pass tensor descritpors by reference (or tuple of references)
// 2. Does not keep reference to tensor descriptor
// 3. Does not construct new tensor coordinate when call Run()
template <typename ThreadGroup,
typename SrcDatas,
typename DstDatas,
typename SrcDescs,
typename DstDescs,
typename ElementwiseOperation,
typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
typename SliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
typename ThreadTransferDstResetCoordinateAfterRunFlags>
struct ThreadGroupTensorSliceTransfer_v7
{
static constexpr index_t nDim =
remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();
using Index = MultiIndex<nDim>;
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
__device__ constexpr ThreadGroupTensorSliceTransfer_v7(
const SrcDescs& src_descs,
const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src_descs,
StaticallyIndexedArray<Index, nSrc>{},
dst_descs,
StaticallyIndexedArray<Index, nDst>{},
element_op)
{
static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
"wrong!");
static_for<0, nSrc, 1>{}([&](auto i) {
static_assert(
nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
"wrong!");
});
static_for<0, nDst, 1>{}([&](auto i) {
static_assert(
nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
"wrong!");
});
static_assert(nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
const auto src_thread_slice_origins = generate_tuple(
[&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
Number<nSrc>{});
const auto dst_thread_slice_origins = generate_tuple(
[&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
Number<nDst>{});
threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
}
}
template <typename SrcBuffers, typename DstBuffers>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(src_descs, src_bufs, dst_descs, dst_bufs);
}
}
template <index_t ISrc>
__device__ void
MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
}
}
template <index_t IDst>
__device__ void
MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v7<SrcDatas,
DstDatas,
SrcDescs,
DstDescs,
ElementwiseOperation,
DstInMemOps,
decltype(thread_slice_lengths),
DimAccessOrder,
VectorDim,
ScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRunFlags,
ThreadTransferDstResetCoordinateAfterRunFlags>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CONVOLUTION_BACKWARD_DATA_SPECIALIZATION #ifndef CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
#define CONVOLUTION_BACKWARD_DATA_SPECIALIZATION #define CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
namespace ck { namespace ck {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CONVOLUTION_FORWARD_SPECIALIZATION #ifndef CONVOLUTION_FORWARD_SPECIALIZATION
#define CONVOLUTION_FORWARD_SPECIALIZATION #define CONVOLUTION_FORWARD_SPECIALIZATION
...@@ -15,7 +18,7 @@ enum struct ConvolutionForwardSpecialization ...@@ -15,7 +18,7 @@ enum struct ConvolutionForwardSpecialization
OddC, OddC,
}; };
inline std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecialization& s) inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s)
{ {
switch(s) switch(s)
{ {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "device.hpp" #include <vector>
#include "device_base.hpp"
#include "common_header.hpp" #include "ck/utility/common_header.hpp"
#include "gridwise_5ary_Elementwise_1d.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "tensor_layout.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp" #include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -29,7 +35,7 @@ template <typename ADataType, ...@@ -29,7 +35,7 @@ template <typename ADataType,
index_t DScalarPerVector, index_t DScalarPerVector,
index_t EScalarPerVector, index_t EScalarPerVector,
index_t FScalarPerVector> index_t FScalarPerVector>
struct Device5AryElementwise : public BaseOperator struct Device5AryElementwise : public DeviceElementwise<5, 1, NDim, ElementwiseFunctor>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -262,12 +268,8 @@ struct Device5AryElementwise : public BaseOperator ...@@ -262,12 +268,8 @@ struct Device5AryElementwise : public BaseOperator
return true; return true;
}; };
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(std::array<const void*, 5> p_inputs,
const BDataType* p_b, std::array<void*, 1> p_outputs,
const CDataType* p_c,
const DDataType* p_d,
const EDataType* p_e,
FDataType* p_f,
std::vector<index_t> lengths, std::vector<index_t> lengths,
std::vector<index_t> a_strides, std::vector<index_t> a_strides,
std::vector<index_t> b_strides, std::vector<index_t> b_strides,
...@@ -277,12 +279,12 @@ struct Device5AryElementwise : public BaseOperator ...@@ -277,12 +279,12 @@ struct Device5AryElementwise : public BaseOperator
std::vector<index_t> f_strides, std::vector<index_t> f_strides,
ElementwiseFunctor functor) ElementwiseFunctor functor)
{ {
return Argument{p_a, return Argument{static_cast<const ADataType*>(p_inputs[0]),
p_b, static_cast<const BDataType*>(p_inputs[1]),
p_c, static_cast<const CDataType*>(p_inputs[2]),
p_d, static_cast<const DDataType*>(p_inputs[3]),
p_e, static_cast<const EDataType*>(p_inputs[4]),
p_f, static_cast<FDataType*>(p_outputs[0]),
lengths, lengths,
a_strides, a_strides,
b_strides, b_strides,
...@@ -293,39 +295,57 @@ struct Device5AryElementwise : public BaseOperator ...@@ -293,39 +295,57 @@ struct Device5AryElementwise : public BaseOperator
functor}; functor};
} }
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument>
const void* p_b, MakeArgumentPointer(std::array<const void*, 5> p_inputs,
const void* p_c, std::array<void*, 1> p_outputs,
const void* p_d, std::vector<index_t> lengths,
const void* p_e, std::vector<std::vector<index_t>> input_strides,
void* p_f, std::vector<std::vector<index_t>> output_strides,
std::vector<index_t> lengths, ElementwiseFunctor functor) override
std::vector<index_t> a_strides,
std::vector<index_t> b_strides,
std::vector<index_t> c_strides,
std::vector<index_t> d_strides,
std::vector<index_t> e_strides,
std::vector<index_t> f_strides,
ElementwiseFunctor functor)
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_inputs[1]),
static_cast<const CDataType*>(p_c), static_cast<const CDataType*>(p_inputs[2]),
static_cast<const DDataType*>(p_d), static_cast<const DDataType*>(p_inputs[3]),
static_cast<const EDataType*>(p_e), static_cast<const EDataType*>(p_inputs[4]),
static_cast<FDataType*>(p_f), static_cast<FDataType*>(p_outputs[0]),
lengths, lengths,
a_strides, input_strides[0],
b_strides, input_strides[1],
c_strides, input_strides[2],
d_strides, input_strides[3],
e_strides, input_strides[4],
f_strides, output_strides[0],
functor); functor);
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); } std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "Device5aryElementwise"
<< "<"
<< "NDim = " << NDim
<< "MPerThread = " << MPerThread
<< "AScalarPerVector = " << AScalarPerVector
<< "BScalarPerVector = " << BScalarPerVector
<< "CScalarPerVector = " << CScalarPerVector
<< "DScalarPerVector = " << DScalarPerVector
<< "EScalarPerVector = " << EScalarPerVector
<< "FScalarPerVector = " << FScalarPerVector
<< ">";
// clang-format on
return str.str();
}
}; // namespace device }; // namespace device
} // namespace device } // namespace device
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <string> #include <string>
#include "stream_config.hpp" #include "ck/stream_config.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -15,6 +18,8 @@ struct BaseArgument ...@@ -15,6 +18,8 @@ struct BaseArgument
BaseArgument& operator=(const BaseArgument&) = default; BaseArgument& operator=(const BaseArgument&) = default;
virtual ~BaseArgument() {} virtual ~BaseArgument() {}
void* p_workspace_ = nullptr;
}; };
struct BaseInvoker struct BaseInvoker
...@@ -42,7 +47,11 @@ struct BaseOperator ...@@ -42,7 +47,11 @@ struct BaseOperator
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
virtual void SetWorkSpacePointer(BaseArgument*, void*) const {} virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
{
assert(p_arg);
p_arg->p_workspace_ = p_workspace;
}
virtual ~BaseOperator() {} virtual ~BaseOperator() {}
}; };
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceBatchedGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t BatchStrideA,
ck::index_t BatchStrideB,
ck::index_t BatchStrideC,
ck::index_t Batch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceBatchedGemmPtr = std::unique_ptr<DeviceBatchedGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <vector>
#include "device_base.hpp" #include "device_base.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
struct BatchedGemmCPermuteDesc
{
ck::index_t G0_, G1_, M_, N_;
ck::index_t stride_G0_, stride_G1_, stride_M_, stride_N_;
};
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
struct DeviceGemmBias : public BaseOperator struct DeviceBatchedGemmCPermute : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
const void* p_bias,
void* p_c, void* p_c,
ck::index_t M, index_t M,
ck::index_t N, index_t N,
ck::index_t K, index_t K,
ck::index_t StrideA, index_t stride_A,
ck::index_t StrideB, index_t stride_B,
ck::index_t StrideC, BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0; CElementwiseOperation c_element_op,
ck::index_t BatchCount) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
...@@ -32,8 +40,8 @@ struct DeviceGemmBias : public BaseOperator ...@@ -32,8 +40,8 @@ struct DeviceGemmBias : public BaseOperator
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
using DeviceGemmBiasPtr = std::unique_ptr< using DeviceBatchedGemmCPermutePtr = std::unique_ptr<
DeviceGemmBias<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>; DeviceBatchedGemmCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ComputePtrOffsetOfBatch,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_c_permute_xdl(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const index_t batch_count,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
ck::Tuple<>{},
p_c_grid + c_batch_offset,
p_shared,
a_element_op,
b_element_op,
c_element_op,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
ck::StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
0>{},
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = batch_count;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif
}
template <typename ALayout,
typename BLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
static auto
MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
{
const auto c_grid_desc_mraw_nraw = [&]() {
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(stride_M, stride_N));
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
}
static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0,
index_t G1,
index_t MRaw,
index_t NRaw,
index_t stride_G0,
index_t stride_G1,
index_t stride_M,
index_t stride_N)
{
const auto e_grid_desc_g0_g1_mraw_nraw = [&]() {
return make_naive_tensor_descriptor(
make_tuple(G0, G1, MRaw, NRaw),
make_tuple(stride_G0, stride_G1, stride_M, stride_N));
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(
e_grid_desc_g0_g1_mraw_nraw,
make_tuple(make_pass_through_transform(G0),
make_pass_through_transform(G1),
make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
e_grid_desc_g0_g1_mraw_nraw,
make_tuple(make_pass_through_transform(G0),
make_pass_through_transform(G1),
make_right_pad_transform(MRaw, MPad),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
e_grid_desc_g0_g1_mraw_nraw,
make_tuple(make_pass_through_transform(G0),
make_pass_through_transform(G1),
make_pass_through_transform(MRaw),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
}
else
{
// not pad M or N
return e_grid_desc_g0_g1_mraw_nraw;
}
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1));
using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1));
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch(index_t Batchstride_A,
index_t Batchstride_B,
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n)
: Batchstride_A_(Batchstride_A),
Batchstride_B_(Batchstride_B),
e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(Batchstride_A_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(Batchstride_B_);
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
const index_t G1 = e_grid_desc_g0_g1_m_n_.GetLength(I1);
index_t b0 = g_idx / G1;
index_t b1 = g_idx - b0 * G1; // g_idx % G1
return e_grid_desc_g0_g1_m_n_.CalculateOffset(make_multi_index(b0, b1, 0, 0));
}
private:
index_t Batchstride_A_;
index_t Batchstride_B_;
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
};
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType, // CShuffleDataType,
ck::Tuple<>, // DsDataType,
CDataType, // EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
NumPrefetch,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
using Block2CTileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t BatchCount)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
BatchCount_(BatchCount),
a_grid_desc_k0_m_k1_{
DeviceBatchedGemmCPermuteXdl::MakeAGridDescriptor_AK0_M_AK1(M, K, stride_A)},
b_grid_desc_k0_n_k1_{
DeviceBatchedGemmCPermuteXdl::MakeBGridDescriptor_BK0_N_BK1(K, N, stride_B)},
c_grid_desc_m_n_{DeviceBatchedGemmCPermuteXdl::MakeCGridDescriptor_M_N(
batched_gemm_c_permute_desc.M_,
batched_gemm_c_permute_desc.N_,
batched_gemm_c_permute_desc.stride_M_,
batched_gemm_c_permute_desc.stride_N_)},
e_grid_desc_g0_g1_m_n_{DeviceBatchedGemmCPermuteXdl::MakeEGridDescriptor_G0_G1_M_N(
batched_gemm_c_permute_desc.G0_,
batched_gemm_c_permute_desc.G1_,
batched_gemm_c_permute_desc.M_,
batched_gemm_c_permute_desc.N_,
batched_gemm_c_permute_desc.stride_G0_,
batched_gemm_c_permute_desc.stride_G1_,
batched_gemm_c_permute_desc.stride_M_,
batched_gemm_c_permute_desc.stride_N_)},
c_grid_desc_mblock_mperblock_nblock_nperblock{},
compute_ptr_offset_of_batch_{
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
e_grid_desc_g0_g1_m_n_},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
index_t BatchCount_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock;
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
Block2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceBatchedGemmCPermuteXdl::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid "
"setting");
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_gemm_c_permute_xdl<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceBatchedGemmCPermuteXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceBatchedGemmCPermuteXdl::BGridDesc_K0_N_K1>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ComputePtrOffsetOfStridedBatch,
remove_reference_t<Block2CTileMap>,
has_main_k_block_loop_>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
ave_time = launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t BatchCount)
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
stride_A,
stride_B,
batched_gemm_c_permute_desc,
a_element_op,
b_element_op,
c_element_op,
BatchCount};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t BatchCount) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
stride_A,
stride_B,
batched_gemm_c_permute_desc,
a_element_op,
b_element_op,
c_element_op,
BatchCount);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedGemmCPermuteXdl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "device.hpp"
#include "device_gemm_reduce.hpp" #include "ck/utility/common_header.hpp"
#include "common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "tensor_layout.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -17,16 +23,16 @@ namespace device { ...@@ -17,16 +23,16 @@ namespace device {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
typename DPtrsGlobal, typename ReducePtrsGlobal,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename DxsInElementwiseOperation, typename ReduceInElementwiseOperations,
typename DxsAccElementwiseOperation, typename ReduceAccElementwiseOperations,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename DGridDescriptor_MBlock_MPerBlock, typename ReduceGridDescriptor_MBlock_MPerBlock,
typename ComputeBasePrtOfBatch, typename ComputeBasePrtOfBatch,
typename Block2CTileMap, typename Block2CTileMap,
bool HasMainK0BlockLoop> bool HasMainK0BlockLoop>
...@@ -38,18 +44,18 @@ __global__ void ...@@ -38,18 +44,18 @@ __global__ void
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
DPtrsGlobal p_ds_grid, ReducePtrsGlobal p_reduces_grid,
const index_t batch_count, const index_t batch_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const DxsInElementwiseOperation dxs_in_element_op, const ReduceInElementwiseOperations reduce_in_element_ops,
const DxsAccElementwiseOperation dxs_out_element_op, const ReduceAccElementwiseOperations reduce_out_element_ops,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock, const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
...@@ -65,10 +71,10 @@ __global__ void ...@@ -65,10 +71,10 @@ __global__ void
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) { static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In))); static_cast<long_index_t>(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In)));
p_ds_grid(In) = p_ds_grid(In) + d_batch_offset; p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset;
}); });
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -76,36 +82,36 @@ __global__ void ...@@ -76,36 +82,36 @@ __global__ void
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_batch_offset, GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset, p_c_grid + c_batch_offset,
p_ds_grid, p_reduces_grid,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
dxs_in_element_op, reduce_in_element_ops,
dxs_out_element_op, reduce_out_element_ops,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_mblock_mperblock, reduce_grid_desc_mblock_mperblock,
block_2_ctile_map); block_2_ctile_map);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_c_grid; ignore = p_c_grid;
ignore = p_ds_grid; ignore = p_reduces_grid;
ignore = batch_count; ignore = batch_count;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = c_element_op; ignore = c_element_op;
ignore = dxs_in_element_op; ignore = reduce_in_element_ops;
ignore = dxs_out_element_op; ignore = reduce_out_element_ops;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = d_grid_desc_mblock_mperblock; ignore = reduce_grid_desc_mblock_mperblock;
ignore = compute_base_ptr_of_batch_; ignore = compute_base_ptr_of_batch_;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
#endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__)) #endif
} }
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle // Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
...@@ -120,14 +126,14 @@ template <typename ALayout, ...@@ -120,14 +126,14 @@ template <typename ALayout,
typename GemmAccDataType, typename GemmAccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename ReduceAccDataType, typename ReduceAccDataType,
typename DPtrsGlobal, typename ReducePtrsGlobal,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename DxsReduceOperation, typename ReduceOperations,
typename DxsInElementwiseOperation, typename ReduceInElementwiseOperations,
typename DxsAccElementwiseOperation, typename ReduceAccElementwiseOperations,
typename DGlobalMemoryDataOperation, typename ReduceGlobalMemoryDataOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
...@@ -162,12 +168,7 @@ template <typename ALayout, ...@@ -162,12 +168,7 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()>
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation>
{ {
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle; using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
...@@ -440,7 +441,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -440,7 +441,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
} }
// assume D is packed tensor // assume D is packed tensor
static auto MakeDGridDescriptor_M(index_t MRaw) static auto MakeReduceGridDescriptor_M(index_t MRaw)
{ {
const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
...@@ -468,7 +469,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -468,7 +469,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1)); using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
...@@ -521,19 +522,19 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -521,19 +522,19 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
ReduceAccDataType, ReduceAccDataType,
DPtrsGlobal, ReducePtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DxsReduceOperation, ReduceOperations,
DxsInElementwiseOperation, ReduceInElementwiseOperations,
DxsAccElementwiseOperation, ReduceAccElementwiseOperations,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation, ReduceGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
CGridDesc_M_N, CGridDesc_M_N,
DGridDesc_M, ReduceGridDesc_M,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -576,7 +577,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -576,7 +577,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
Argument(const ADataType* p_a_grid, Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid, const BDataType* p_b_grid,
CDataType* p_c_grid, CDataType* p_c_grid,
DPtrsGlobal p_ds_grid, ReducePtrsGlobal p_reduces_grid,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
...@@ -586,31 +587,31 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -586,31 +587,31 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op, ReduceInElementwiseOperations reduce_in_element_ops,
DxsAccElementwiseOperation dxs_out_element_op, ReduceAccElementwiseOperations reduce_out_element_ops,
index_t BatchCount) index_t Batch)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
p_ds_grid_{p_ds_grid}, p_reduces_grid_{p_reduces_grid},
BatchCount_(BatchCount), Batch_(Batch),
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)}, reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{}, reduce_grid_desc_mblock_mperblock_{},
compute_base_ptr_of_batch_{ compute_base_ptr_of_batch_{
type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()), type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()), type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()), type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())}, type_convert<index_t>(reduce_grid_desc_m_.GetElementSpaceSize())},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
dxs_in_element_op_{dxs_in_element_op}, reduce_in_element_ops_{reduce_in_element_ops},
dxs_out_element_op_{dxs_out_element_op} reduce_out_element_ops_{reduce_out_element_ops}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
...@@ -621,8 +622,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -621,8 +622,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_); c_grid_desc_m_n_);
d_grid_desc_mblock_mperblock_ = reduce_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_); GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
} }
} }
...@@ -630,22 +631,23 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -630,22 +631,23 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
DPtrsGlobal p_ds_grid_; ReducePtrsGlobal p_reduces_grid_;
index_t BatchCount_; index_t Batch_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
DGridDesc_M d_grid_desc_m_; ReduceGridDesc_M reduce_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_; typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
reduce_grid_desc_mblock_mperblock_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
DxsInElementwiseOperation dxs_in_element_op_; ReduceInElementwiseOperations reduce_in_element_ops_;
DxsAccElementwiseOperation dxs_out_element_op_; ReduceAccElementwiseOperations reduce_out_element_ops_;
}; };
// Invoker // Invoker
...@@ -657,7 +659,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -657,7 +659,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
{ {
#if 0 #if 0
{ {
std::cout << "arg.BatchCount_ = " << arg.BatchCount_ << std::endl; std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
std::cout << "arg.a_grid_desc_ak0_m_ak1_{" std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
...@@ -672,7 +674,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -672,7 +674,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}" std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}"
<< std::endl; << std::endl;
} }
#endif #endif
...@@ -686,7 +688,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -686,7 +688,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
} }
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_; arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_;
const auto K = const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
...@@ -698,16 +700,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -698,16 +700,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
DPtrsGlobal, ReducePtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DxsInElementwiseOperation, ReduceInElementwiseOperations,
DxsAccElementwiseOperation, ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
true>; true>;
...@@ -721,17 +723,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -721,17 +723,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.p_ds_grid_, arg.p_reduces_grid_,
arg.BatchCount_, arg.Batch_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.dxs_in_element_op_, arg.reduce_in_element_ops_,
arg.dxs_out_element_op_, arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_, arg.reduce_grid_desc_mblock_mperblock_,
arg.compute_base_ptr_of_batch_, arg.compute_base_ptr_of_batch_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
...@@ -741,16 +743,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -741,16 +743,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
DPtrsGlobal, ReducePtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DxsInElementwiseOperation, ReduceInElementwiseOperations,
DxsAccElementwiseOperation, ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
false>; false>;
...@@ -764,17 +766,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -764,17 +766,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.p_ds_grid_, arg.p_reduces_grid_,
arg.BatchCount_, arg.Batch_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.dxs_in_element_op_, arg.reduce_in_element_ops_,
arg.dxs_out_element_op_, arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_, arg.reduce_grid_desc_mblock_mperblock_,
arg.compute_base_ptr_of_batch_, arg.compute_base_ptr_of_batch_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
...@@ -818,77 +820,153 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -818,77 +820,153 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
} }
} }
static auto MakeArgument(const ADataType* p_a, static constexpr int NumReduce = ReduceOperations::Size();
const BDataType* p_b, static auto MakeArgument(const void* p_a,
CDataType* p_c, const void* p_b,
DPtrsGlobal p_dxs, const void* p_bias,
index_t MRaw, std::array<const void*, 0> p_ds,
index_t NRaw, void* p_c,
index_t KRaw, std::array<void*, NumReduce> p_reduces,
index_t StrideA, ck::index_t M,
index_t StrideB, ck::index_t N,
index_t StrideC, ck::index_t K,
AElementwiseOperation a_element_op, ck::index_t StrideA,
BElementwiseOperation b_element_op, ck::index_t StrideB,
CElementwiseOperation c_element_op, ck::index_t StrideC,
DxsInElementwiseOperation dxs_in_element_op, std::array<ck::index_t, 0> StrideDs,
DxsAccElementwiseOperation dxs_out_element_op, std::array<void*, 3> gemm_element_ops,
index_t BatchCount) std::array<void*, 0> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op,
index_t Batch)
{ {
return Argument{p_a, (void)p_bias;
p_b, (void)p_ds;
p_c, (void)StrideDs;
p_dxs, (void)d_element_ops;
MRaw,
NRaw, ReducePtrsGlobal reduce_tuple = generate_tuple(
KRaw, [&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
reduce_tuple,
M,
N,
K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
dxs_in_element_op, reduce_in_element_ops,
dxs_out_element_op, reduce_out_element_ops,
BatchCount}; Batch};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument>
const void* p_b, MakeArgumentPointer(const void* p_a,
void* p_c, const void* p_b,
DPtrsGlobal p_dxs, const void* p_bias,
index_t MRaw, std::array<const void*, 0> p_ds,
index_t NRaw, void* p_c,
index_t KRaw, std::array<void*, NumReduce> p_reduces,
index_t StrideA, ck::index_t M,
index_t StrideB, ck::index_t N,
index_t StrideC, ck::index_t K,
AElementwiseOperation a_element_op, ck::index_t StrideA,
BElementwiseOperation b_element_op, ck::index_t StrideB,
CElementwiseOperation c_element_op, ck::index_t StrideC,
DxsInElementwiseOperation dxs_in_element_op, std::array<ck::index_t, 0> StrideDs,
DxsAccElementwiseOperation dxs_out_element_op, std::array<void*, 3> gemm_element_ops,
index_t BatchCount) override std::array<void*, 0> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op,
index_t Batch = 1) override
{ {
(void)p_bias;
(void)p_ds;
(void)StrideDs;
(void)d_element_ops;
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
p_dxs, reduce_tuple,
MRaw, M,
NRaw, N,
KRaw, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
dxs_in_element_op, reduce_in_element_ops,
dxs_out_element_op, reduce_out_element_ops,
BatchCount); Batch);
} }
// polymorphic // polymorphic
......
#ifndef DEVICE_BATCHED_GEMM_XDL_HPP // SPDX-License-Identifier: MIT
#define DEVICE_BATCHED_GEMM_XDL_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "device.hpp"
#include "device_base.hpp" #include "ck/utility/common_header.hpp"
#include "device_gemm.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "common_header.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -109,7 +113,7 @@ __global__ void ...@@ -109,7 +113,7 @@ __global__ void
ignore = c_element_op; ignore = c_element_op;
ignore = compute_ptr_offset_of_batch; ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif
} }
template <typename ADataType, template <typename ADataType,
...@@ -147,8 +151,15 @@ template <typename ADataType, ...@@ -147,8 +151,15 @@ template <typename ADataType,
bool BBlockLdsAddExtraN, bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector> ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceBatchedGemmXdl struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -330,26 +341,26 @@ struct DeviceBatchedGemmXdl ...@@ -330,26 +341,26 @@ struct DeviceBatchedGemmXdl
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t Batch,
index_t M01, index_t M01,
index_t N01, index_t N01,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op)
index_t BatchCount)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
BatchCount_(BatchCount), Batch_(Batch),
a_grid_desc_k0_m_k1_{ a_grid_desc_k0_m_k1_{
DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)}, DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)},
b_grid_desc_k0_n_k1_{ b_grid_desc_k0_n_k1_{
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)}, DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)}, c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
compute_ptr_offset_of_batch_{ compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC},
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
block_2_ctile_map_{ block_2_ctile_map_{
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
M01_{M01}, M01_{M01},
...@@ -372,7 +383,7 @@ struct DeviceBatchedGemmXdl ...@@ -372,7 +383,7 @@ struct DeviceBatchedGemmXdl
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
index_t BatchCount_; index_t Batch_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
...@@ -416,7 +427,7 @@ struct DeviceBatchedGemmXdl ...@@ -416,7 +427,7 @@ struct DeviceBatchedGemmXdl
} }
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_; arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_;
const auto K = const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
...@@ -447,7 +458,7 @@ struct DeviceBatchedGemmXdl ...@@ -447,7 +458,7 @@ struct DeviceBatchedGemmXdl
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.BatchCount_, arg.Batch_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
...@@ -481,7 +492,7 @@ struct DeviceBatchedGemmXdl ...@@ -481,7 +492,7 @@ struct DeviceBatchedGemmXdl
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.BatchCount_, arg.Batch_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
...@@ -532,10 +543,13 @@ struct DeviceBatchedGemmXdl ...@@ -532,10 +543,13 @@ struct DeviceBatchedGemmXdl
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t Batch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op)
index_t BatchCount)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -546,12 +560,15 @@ struct DeviceBatchedGemmXdl ...@@ -546,12 +560,15 @@ struct DeviceBatchedGemmXdl
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
Batch,
1, 1,
1, 1,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op};
BatchCount};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -566,10 +583,13 @@ struct DeviceBatchedGemmXdl ...@@ -566,10 +583,13 @@ struct DeviceBatchedGemmXdl
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t Batch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op) override
index_t BatchCount) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -580,12 +600,15 @@ struct DeviceBatchedGemmXdl ...@@ -580,12 +600,15 @@ struct DeviceBatchedGemmXdl
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
Batch,
1, 1,
1, 1,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op);
BatchCount);
} }
// polymorphic // polymorphic
...@@ -616,4 +639,3 @@ struct DeviceBatchedGemmXdl ...@@ -616,4 +639,3 @@ struct DeviceBatchedGemmXdl
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include "device.hpp" #include "ck/device_utility/device_prop.hpp"
#include "device_base.hpp" #include "ck/device_utility/kernel_launch.hpp"
#include "gridwise_binary_elementwise_1d.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -20,7 +26,7 @@ template <typename ADataType, ...@@ -20,7 +26,7 @@ template <typename ADataType,
index_t AScalarPerVector, index_t AScalarPerVector,
index_t BScalarPerVector, index_t BScalarPerVector,
index_t CScalarPerVector> index_t CScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator struct DeviceBinaryElementwise : public DeviceElementwise<2, 1, NDim, ElementwiseFunctor>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -193,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -193,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator
return true; return true;
}; };
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, virtual std::unique_ptr<BaseArgument>
const void* p_b, MakeArgumentPointer(std::array<const void*, 2> p_inputs,
void* p_c, std::array<void*, 1> p_outputs,
std::vector<index_t> lengths, std::vector<index_t> lengths,
std::vector<index_t> a_strides, std::vector<std::vector<index_t>> input_strides,
std::vector<index_t> b_strides, std::vector<std::vector<index_t>> output_strides,
std::vector<index_t> c_strides, ElementwiseFunctor functor) override
ElementwiseFunctor functor)
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_inputs[1]),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_outputs[0]),
lengths, lengths,
a_strides, input_strides[0],
b_strides, input_strides[1],
c_strides, output_strides[0],
functor); functor);
} }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); } std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
// polymorphic
std::string GetTypeString() const override std::string GetTypeString() const override
{ {
auto str = std::stringstream(); auto str = std::stringstream();
...@@ -221,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -221,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator
// clang-format off // clang-format off
str << "DeviceBinaryElementwise" str << "DeviceBinaryElementwise"
<< "<" << "<"
<< "NDim = " << NDim
<< "MPerThread = " << MPerThread << "MPerThread = " << MPerThread
<< "AScalarPerVector = " << AScalarPerVector
<< "BScalarPerVector = " << BScalarPerVector
<< "CScalarPerVector = " << CScalarPerVector
<< ">"; << ">";
// clang-format on // clang-format on
......
/******************************************************************************* // SPDX-License-Identifier: MIT
* // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#pragma once #pragma once
#include "device_base.hpp" #include "device_base.hpp"
......
/******************************************************************************* // SPDX-License-Identifier: MIT
* // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "device.hpp"
#include "device_gemm.hpp" #include "ck/utility/common_header.hpp"
#include "device_cgemm.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "common_header.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/device_cgemm.hpp"
#include "gridwise_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "gridwise_binary_elementwise_1d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
#include "tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -557,11 +538,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -557,11 +538,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
float ave_time = 0; float ave_time = 0;
using Add = using Add = ck::tensor_operation::element_wise::Add;
ck::tensor_operation::binary_element_wise::Add<CDataType, CDataType, CDataType>; using Subtract = ck::tensor_operation::element_wise::Subtract;
using Substract = ck::tensor_operation::binary_element_wise:: using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType,
Substract<CDataType, CDataType, CDataType>;
using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType,
CDataType, CDataType,
CDataType, CDataType,
CDataType, CDataType,
...@@ -573,19 +552,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -573,19 +552,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
AScalarPerVector, AScalarPerVector,
BScalarPerVector, BScalarPerVector,
CScalarPerVector>; CScalarPerVector>;
using GridwiseBinSubstract = GridwiseBinaryElementwise_1D<CDataType, using GridwiseBinSubtract = GridwiseBinaryElementwise_1D<CDataType,
CDataType, CDataType,
CDataType, CDataType,
CDataType, CDataType,
CGridDesc_M, CGridDesc_M,
CGridDesc_M, CGridDesc_M,
CGridDesc_M, CGridDesc_M,
Substract, Subtract,
MPerThread, MPerThread,
AScalarPerVector, AScalarPerVector,
BScalarPerVector, BScalarPerVector,
CScalarPerVector>; CScalarPerVector>;
const auto add_kernel = kernel_binary_elementwise_1d<GridwiseBinAdd, const auto add_kernel = kernel_binary_elementwise_1d<GridwiseBinAdd,
CDataType, CDataType,
CDataType, CDataType,
CDataType, CDataType,
...@@ -593,14 +572,14 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -593,14 +572,14 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CGridDesc_M, CGridDesc_M,
CGridDesc_M, CGridDesc_M,
Add>; Add>;
const auto substract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubstract, const auto subtract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubtract,
CDataType, CDataType,
CDataType, CDataType,
CDataType, CDataType,
CGridDesc_M, CGridDesc_M,
CGridDesc_M, CGridDesc_M,
CGridDesc_M, CGridDesc_M,
Substract>; Subtract>;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
...@@ -653,7 +632,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -653,7 +632,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// c_real = aux - aux_2 // c_real = aux - aux_2
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(stream_config,
substract_kernel, subtract_kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -663,7 +642,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -663,7 +642,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_m_, arg.c_grid_desc_m_,
arg.c_grid_desc_m_, arg.c_grid_desc_m_,
arg.c_grid_desc_m_, arg.c_grid_desc_m_,
Substract{}); Subtract{});
ave_time += ave_time +=
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
...@@ -764,7 +743,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -764,7 +743,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// c_real = aux - aux_2 // c_real = aux - aux_2
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(stream_config,
substract_kernel, subtract_kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -774,7 +753,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -774,7 +753,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_m_, arg.c_grid_desc_m_,
arg.c_grid_desc_m_, arg.c_grid_desc_m_,
arg.c_grid_desc_m_, arg.c_grid_desc_m_,
Substract{}); Subtract{});
ave_time += ave_time +=
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Tensor Contraction:
// input : A
// input : B
// input : D0, D1, ...
// output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[M0, M1, M2, ..., K0, K1, K2, ...]
// B[N0, N1, N2, ..., K0, K1, K2, ...]
// D[M0, M1, M2, ..., N0, N1, N2, ...]
// E[M0, M1, M2, ..., N0, N1, N2, ...]
template <index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceContractionMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
std::vector<index_t> a_ms_ks_lengths,
std::vector<index_t> a_ms_ks_strides,
std::vector<index_t> b_ns_ks_lengths,
std::vector<index_t> b_ns_ks_strides,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides,
std::vector<index_t> e_ms_ns_lengths,
std::vector<index_t> e_ms_ns_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatDsPointer,
typename FloatE,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_contraction_multiple_d_xdl_cshuffle(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatDsPointer p_ds_grid,
FloatE* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_etile_map;
#endif
}
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
// Tensor Contraction:
// input : A
// input : B
// input : D0, D1, ...
// output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[M0, M1, M2, ..., K0, K1, K2, ...]
// B[N0, N1, N2, ..., K0, K1, K2, ...]
// D[M0, M1, M2, ..., N0, N1, N2, ...]
// E[M0, M1, M2, ..., N0, N1, N2, ...]
template <index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
typename ADataType,
typename BDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceContractionMultipleD_Xdl_CShuffle
: public DeviceContractionMultipleD<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// Assume: A[M0, M1, M2, ..., K0, K1, K2, ...]
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_ms_ks_lengths_vec,
const std::vector<index_t>& a_ms_ks_strides_vec)
{
assert(a_ms_ks_lengths_vec.size() == NumDimM + NumDimK &&
a_ms_ks_strides_vec.size() == NumDimM + NumDimK);
const auto to_tuple = [&](auto& vec, auto num) {
return generate_tuple([&](auto i) { return vec[i]; }, num);
};
const auto a_ms_ns_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number<NumDimM + NumDimK>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
// dimension Ids for K0, K1, ...
constexpr auto kDimIds =
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(a_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto kLengths = get_container_subset(a_ms_ns_lengths, kDimIds);
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const auto a_grid_desc_ms_ks =
make_naive_tensor_descriptor(a_ms_ns_lengths, a_ms_ks_strides);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
a_grid_desc_ms_ks,
make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
make_tuple(mDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto MRaw = a_grid_desc_mraw_kraw.GetLength(I0);
const auto KRaw = a_grid_desc_mraw_kraw.GetLength(I1);
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
// Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_ns_ks_lengths_vec,
const std::vector<index_t>& b_ns_ks_strides_vec)
{
assert(b_ns_ks_lengths_vec.size() == NumDimN + NumDimK &&
b_ns_ks_strides_vec.size() == NumDimN + NumDimK);
const auto to_tuple = [&](auto& vec, auto num) {
return generate_tuple([&](auto i) { return vec[i]; }, num);
};
const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_vec, Number<NumDimN + NumDimK>{});
const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_vec, Number<NumDimN + NumDimK>{});
// dimension Ids for N0, N1, ...
constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
// dimension Ids for K0, K1, ...
constexpr auto kDimIds =
typename arithmetic_sequence_gen<NumDimN, NumDimN + NumDimK, 1>::type{};
// lengths for K0, K1, ...
const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);
// lengths for N0, N1, ...
const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
const auto b_grid_desc_ns_ks =
make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
b_grid_desc_ns_ks,
make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
make_tuple(nDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto NRaw = b_grid_desc_nraw_kraw.GetLength(I0);
const auto KRaw = b_grid_desc_nraw_kraw.GetLength(I1);
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
// assume E[M0, M1, M2, ..., N0, N1, N2...]
static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_ms_ns_lengths_vec,
const std::vector<index_t>& e_ms_ns_strides_vec)
{
assert(e_ms_ns_lengths_vec.size() == NumDimM + NumDimN &&
e_ms_ns_strides_vec.size() == NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto num) {
return generate_tuple([&](auto i) { return vec[i]; }, num);
};
const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_vec, Number<NumDimM + NumDimN>{});
const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_vec, Number<NumDimM + NumDimN>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
const auto e_grid_desc_ms_ns =
make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);
// transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
e_grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto MRaw = e_grid_desc_mraw_nraw.GetLength(I0);
const auto NRaw = e_grid_desc_mraw_nraw.GetLength(I1);
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(e_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
e_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
e_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return e_grid_desc_mraw_nraw;
}
}
using AGridDesc_AK0_M_AK1 =
decltype(MakeAGridDescriptor_AK0_M_AK1(std::vector<index_t>{}, std::vector<index_t>{}));
using BGridDesc_BK0_N_BK1 =
decltype(MakeBGridDescriptor_BK0_N_BK1(std::vector<index_t>{}, std::vector<index_t>{}));
using EGridDesc_M_N =
decltype(MakeEGridDescriptor_M_N(std::vector<index_t>{}, std::vector<index_t>{}));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
EGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a_grid,
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
std::vector<index_t> a_ms_ns_lengths,
std::vector<index_t> a_ms_ks_strides,
std::vector<index_t> b_ns_ks_lengths,
std::vector<index_t> b_ns_ks_strides,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides,
std::vector<index_t> e_ms_ns_lengths,
std::vector<index_t> e_ms_ns_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, // FIXME
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_ms_ns_lengths, a_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_ns_ks_lengths, b_ns_ks_strides)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_mz_stride_{},
a_kz_stride_{},
b_nz_stride_{},
b_kz_stride_{},
ds_nz_stride_{},
e_nz_stride_{}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
const auto d_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d_grid_desc_m_n);
});
}
// for sanity check of vector memory access
a_mz_stride_ = a_ms_ks_strides[NumDimM - 1];
a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1];
b_nz_stride_ = b_ns_ks_strides[NumDimN - 1];
b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1];
for(index_t i = 0; i < NumDTensor; ++i)
{
ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1];
}
e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1];
}
// private:
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// tensor descriptors
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// Strides for the last M/N/K dimensions of A/B/Ds/E
// for sanity check of vector load/store
index_t a_mz_stride_;
index_t a_kz_stride_;
index_t b_nz_stride_;
index_t b_kz_stride_;
std::array<index_t, NumDTensor> ds_nz_stride_;
index_t e_mz_stride_;
index_t e_nz_stride_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
ck::StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_);
};
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
ave_time = launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
return false;
}
// check vector access
static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
(BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
"wrong!");
// vector memory access of A: could be on M or AK1 dimension
if constexpr(ABlockTransferSrcVectorDim == 1)
{
if(!(arg.a_mz_stride_ == 1 &&
arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
if(!(arg.a_kz_stride_ == 1 &&
arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
// vector memory access of B: could be on N or BK1 dimension
if constexpr(BBlockTransferSrcVectorDim == 1)
{
if(!(arg.b_nz_stride_ == 1 &&
arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
if(!(arg.b_kz_stride_ == 1 &&
arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
// vector memory access of Ds: always on NPerBlock dimension
bool valid_d_access = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
if(!(arg.ds_nz_stride_[i] == 1 &&
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
CDEBlockTransferScalarPerVector_NPerBlock ==
0))
{
valid_d_access = false;
}
});
if(valid_d_access == false)
{
return false;
}
// vector memory access of E: always on NPerBlock dimension
if(!(arg.e_nz_stride_ == 1 &&
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
CDEBlockTransferScalarPerVector_NPerBlock ==
0))
{
return false;
}
return true;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
std::vector<index_t> a_ms_ns_lengths,
std::vector<index_t> a_ms_ks_strides,
std::vector<index_t> b_ns_ks_lengths,
std::vector<index_t> b_ns_ks_strides,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides,
std::vector<index_t> e_ms_ns_lengths,
std::vector<index_t> e_ms_ns_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_e,
a_ms_ns_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
ds_ms_ns_lengths,
ds_ms_ns_strides,
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
std::vector<index_t> a_ms_ns_lengths,
std::vector<index_t> a_ms_ks_strides,
std::vector<index_t> b_ns_ks_lengths,
std::vector<index_t> b_ns_ks_strides,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides,
std::vector<index_t> e_ms_ns_lengths,
std::vector<index_t> e_ms_ns_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_ms_ns_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
ds_ms_ns_lengths,
ds_ms_ns_strides,
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceContractionMultipleD_Xdl_CShuffle"
<< "<"
<< NumDimM << ", "
<< NumDimN << ", "
<< NumDimK << ", "
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< ABlockTransferSrcVectorDim << ", "
<< BBlockTransferSrcVectorDim
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#ifndef DEVICE_CONV2D_WRW_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP // SPDX-License-Identifier: MIT
#define DEVICE_CONV2D_WRW_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "device.hpp"
#include "device_base.hpp" #include "ck/utility/common_header.hpp"
#include "device_conv_backward_weight.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "convolution_forward_specialization.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "common_header.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp"
#include "tensor_descriptor.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "gridwise_gemm_xdlops_bwd_weight.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -773,4 +777,3 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -773,4 +777,3 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
#ifndef DEVICE_CONV2D_BWD_DATA_XDL_NHWC_KYXC_NHWK_HPP // SPDX-License-Identifier: MIT
#define DEVICE_CONV2D_BWD_DATA_XDL_NHWC_KYXC_NHWK_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "device.hpp"
#include "device_base.hpp" #include "ck/utility/common_header.hpp"
#include "device_conv_bwd_data.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "convolution_backward_data_specialization.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "common_header.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
#include "tensor_descriptor.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp" #include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -821,4 +824,3 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -821,4 +824,3 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_ADD_NHWC_KYXC_NHWK_HPP // SPDX-License-Identifier: MIT
#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_ADD_NHWC_KYXC_NHWK_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "device.hpp"
#include "device_base.hpp" #include "ck/utility/common_header.hpp"
#include "device_conv_fwd_bias_activation_add.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "convolution_forward_specialization.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "common_header.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp"
#include "tensor_descriptor.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp"
#include "gridwise_gemm_xdlops_v3r3.hpp" #include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -963,4 +966,3 @@ struct ...@@ -963,4 +966,3 @@ struct
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment