Unverified Commit 29deceb6 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #18 from ROCmSoftwarePlatform/merge-from-public

Merge from public
parents 91c1d147 c997bbf6
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
/**
* Transfer that uses direct load instructions to copy data from global to LDS memory.
*
* Traditional loads first copy data from global to registers, and then from registers to LDS.
* Direct loads do not need an intermediate step, data is copied directly from global to LDS,
* without the use of additional registers.
*
* However, the instruction has limitations:
* - each thread must copy exactly a single DWORD - 4 bytes;
* - threads within a single wavefront must write consecutive DWORDS into LDS,
* (data in global do not need to be contiguous, each thread might have its own offset).
*
* To make sure that all the transfers finished, the `waitcnt` instruction must be used with
* `vmcnt` instead of `lgkmcnt`.
*
* Limitations of the transfer class:
* - `SrcData` must be the same as `DstData` - no possibility to convert the data type in flight;
* - `DstVectorDim` must be the last dimension;
* - `SrcVectorDim` must be the last dimension if `ScalarPerVector` is greater than 1;
* - `ScalarPerVector` times the number of bytes of `DstData` must be equal to a single DWORD = 4B
* (for examlpe if `DstData` is fp32, then `ScalarPerVector` must be 1; if `DstData` is fp16,
* `ScalarPerVector` must be 2);
* - if `ScalarPerVector` is greater than 1, the contiguous dimension in src and dst must be
* the same dimension;
* - threads in a wavefront must write contiguous data to LDS (when wavefront size is 64,
* they must write 64 contiguous DWORDs) - `ThreadClusterLengths` must be prepared in such a way
* to guarantee that.
*/
template <typename ThreadGroup,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t ScalarPerVector>
struct ThreadGroupTensorSliceTransfer_DirectLoad
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
static constexpr auto block_slice_lengths = BlockSliceLengths{};
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
static constexpr auto thread_single_load_size = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
// It makes the whole wavefront load contiguous memory, what is required for direct loads.
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
static __device__ constexpr bool AreThreadClusterLengthsValid()
{
// Make sure that ThreadClusterLengths are set in a way that allows for contiguous writes to
// LDS by the threads from a single wavefront.
// Examples (assuming 64 threads in a wavefront, 128 in a thread block):
// 1. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8],
// data type = fp32 -> ScalarPerVector = 1
// INVALID: ThreadClusterLengths = [4, 4, 8] since in the first iteration, threads 0-31
// write [0, 0, 0] - [0, 3, 7] and thread 32 writes [1, 0, 0] instead of
// [0, 4, 0].
// VALID: ThreadClusterLengths = [2, 8, 8] or [1, 16, 8] since in the first iteration,
// threads 0-63 write [0, 0, 0] - [0, 7, 7] -> 64 consecutive elements (DWORDs).
// 2. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8],
// data type = fp16 -> ScalarPerVector = 2
// NOTE: ThreadClusterLengths must take into account that each thread writes two
// elements (single DWORD) along the contiguous dimension.
// INVALID: ThreadClusterLengths = [4, 4, 8] since each 8 threads would try to write
// 8 * 2 elements of K1PerBlock and there are only 8;
// ThreadClusterLengths = [4, 8, 4] since in the first iteration, threads 0-31
// write [0, 0, 0] - [0, 7, 7] (7 since each writes 2 elements) and thread 32
// writes [1, 0, 0] instead of [0, 8, 0].
// VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the
// first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive
// elements = 64 consecutive DWORDs.
int num_contiguous_dwords = 1;
bool is_contiguous = true;
static_for<0, nDim, 1>{}([&](auto i) {
if(is_contiguous)
{
num_contiguous_dwords *= thread_cluster_lengths[nDim - i - 1];
}
if(thread_slice_lengths[nDim - i - 1] > 1)
{
is_contiguous = false;
}
});
constexpr index_t wavefront_size = get_warp_size();
const bool wave_contiguous = num_contiguous_dwords % wavefront_size == 0;
bool thread_slice_lengths_correct = true;
static_for<0, nDim, 1>{}([&](auto i) {
if(thread_slice_lengths[i] <= 0)
{
thread_slice_lengths_correct = false;
}
});
return wave_contiguous && thread_slice_lengths_correct;
}
__device__ constexpr ThreadGroupTensorSliceTransfer_DirectLoad(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin)
{
static_assert(ck::is_same_v<SrcData, DstData>,
"Direct load transfer does not support datatypes conversion. Source and "
"destination data types must be the same.");
static_assert(
DstVectorDim == nDim - 1,
"Direct load transfer requires the destination vector dimension to be the last one.");
static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim,
"When loading more than one element per thread at once, the contiguous "
"dimension must be the same between source and destination.");
constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
static_assert(bytes_per_thread_load == dword_bytes,
"Direct load transfer requires each thread to load exactly a single "
"DWORD of data.");
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size(),
"Inconsistent number of dimensions across lengths and descriptors.");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"The number of threads cannot be less than the number of elements in "
"thread cluster lengths.");
static_assert(
AreThreadClusterLengthsValid(),
"Thread cluster lengths are incorrect. They must be set in a way that allows a single "
"wavefront to write contiguous DWORDs into LDS memory. ");
const auto thread_cluster_idx =
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
src_slice_origin_ = src_slice_origin_idx;
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
dst_slice_origin_ = dst_slice_origin_idx;
}
__device__ void ResetDstSliceWindow(const DstDesc& dst_desc)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_);
}
template <typename SrcBuffer, typename DstBuffer>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global,
"Source data must come from a global memory buffer.");
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"Destination data must be stored in an LDS memory buffer.");
static_assert(
ck::is_same_v<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>,
"SrcBuffer and SrcData data types must be consistent.");
static_assert(
ck::is_same_v<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>,
"DstBuffer and DstData data types must be consistent.");
constexpr auto dst_access_lengths = thread_slice_lengths;
const auto dst_forward_steps = generate_steps(dst_desc, 1);
const auto dst_backward_steps = generate_steps(dst_desc, -1);
const auto src_forward_steps = generate_steps(src_desc, 1);
const auto src_backward_steps = generate_steps(src_desc, -1);
// Loop over the destination block and copy data.
static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
const auto src_offset = src_coord_.GetOffset();
const auto dst_offset = dst_coord_.GetOffset();
// Check if src data is not in the logic padding area.
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
dst_buf, src_offset, dst_offset, is_src_valid);
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// Decide whether to move forward or backward.
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]);
}
else
{
move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]);
}
}
});
});
// Reset the destination slice since the entire buffer has been already filled.
ResetDstSliceWindow(dst_desc);
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
src_slice_origin_ = src_slice_origin_ + step;
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_);
}
template <typename DescType>
__device__ auto generate_steps(const DescType& desc, int sign)
{
return generate_tuple(
[&](auto i) {
Index step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
step_idx(j) = (i.value == j.value) ? sign * thread_steps[i] : 0;
});
return make_tensor_coordinate_step(desc, step_idx);
},
Number<nDim>{});
}
private:
static constexpr auto thread_cluster_desc_ = make_cluster_descriptor(ThreadClusterLengths{});
SrcCoord src_coord_;
DstCoord dst_coord_;
Index src_slice_origin_;
Index dst_slice_origin_;
};
} // namespace ck
......@@ -33,7 +33,8 @@ template <index_t NumDimM,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
typename CDEElementwiseOperation,
typename ComputeDataType = ADataType>
struct DeviceContractionMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <memory>
#include <array>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
typename UnaryOperation,
typename Scale,
index_t NumDim>
struct DeviceElementwise : public BaseOperator
{
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op,
UnaryOperation unary_op,
Scale scale_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; // namespace device
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
typename UnaryOperation,
typename Scale,
index_t NumDim>
using DeviceElementwisePtr = std::unique_ptr<DeviceElementwise<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
UnaryOperation,
Scale,
NumDim>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/utility/is_detected.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
/**
* \brief Grouped Convolution Forward
*
* \details
* input : input image A[G, N, C, Hi, Wi], A1[G, N, C, Hi, Wi]...
* input : weight B[G, K, C, Y, X], B1[G, K, C, Y, X]...
* input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
* output : output image E[G, N, K, Ho, Wo]
*
* C = a_op(A, A1...) * b_op(B, B1...)
* E = cde_op(C, D0, D1, ...)
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ALayout Input layout (also for a1, a2...).
* \tparam BLayout Weight layout (also for b1, b2...).
* \tparam DsLayout Ds layouts.
* \tparam ELayout Output layout.
* \tparam ADataType Input data type. Pass tuple if there is multiple A.
* \tparam BDataType Weight data type. Pass tuple if there is multiple B.
* \tparam DsDataType D data types.
* \tparam EDataType Output data type.
* \tparam AElementwiseOperation A elementwise operation.
* \tparam BElementwiseOperation B elementwise operation.
* \tparam CDEElementwiseOperation CDE elementwise operation.
* \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
*/
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename ComputeType =
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
Number<0>,
ADataType>())> // ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
{
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
static constexpr index_t NumDTensor = DsDataType::Size();
static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
// If DataType is tuple, user has to pass std::array with pointers.
using APointers =
std::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
using BPointers =
std::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
/**
* \brief Make argument pointer for grouped conv fwd.
*
* \param p_a A pointer to the input (std::array<const void*, NumA> with
pointers for multiple A).
* \param p_b A pointer to the weight (std::array<const void*, NumA> with
pointers for multiple B).
* \param p_ds A pointers to the Ds.
* \param p_e A pointers to the output.
* \param a_g_n_c_wis_lengths Input lengths [G, N, C, Spatial...] (for 3d).
* \param a_g_n_c_wis_strides Input strides [G, N, C, Spatial...] (for 3d).
* \param b_g_k_c_xs_lengths Weight lengths [G, K, C, Spatial...] (for 3d).
* \param b_g_k_c_xs_strides Weight strides [G, K, C, Spatial...] (for 3d).
* \param ds_g_n_k_wos_lengths Ds lengths [G, N, K, Spatial...] (for 3d).
* \param ds_g_n_k_wos_strides Ds strides [G, N, K, Spatial...] (for 3d).
* \param e_g_n_k_wos_lengths Output lengths [G, N, K, Spatial...] (for 3d).
* \param e_g_n_k_wos_strides Output strides [G, N, K, Spatial...] (for 3d).
* \param conv_filter_strides Convolution filter strides.
* \param conv_filter_dilations Convolution filter dilations.
* \param input_left_pads Input left paddings.
* \param input_right_pads Input right paddings.
* \param a_element_op A elementwise operation object.
* \param b_element_op B elementwise operation object.
* \param cde_element_op CDE elementwise operation object.
* \return Pointer to the argument.
*/
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
APointers p_a,
BPointers p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -3,21 +3,33 @@
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Convolution Forward:
// input : input image A[G, N, C, Hi, Wi],
// input : weight B[G, K, C, Y, X],
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
// output : output image E[G, N, K, Ho, Wo]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
/**
* \brief Grouped Convolution Forward
*
* \note This structure is deprecated (left for backwards compatibility). Please use
* DeviceGroupedConvFwdMultipleABD.
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ALayout Input layout (also for a1, a2...).
* \tparam BLayout Weight layout (also for b1, b2...).
* \tparam DsLayout Ds layouts.
* \tparam ELayout Output layout.
* \tparam ADataType Input data type. Pass tuple if there is multiple A.
* \tparam BDataType Weight data type. Pass tuple if there is multiple B.
* \tparam DsDataType D data types.
* \tparam EDataType Output data type.
* \tparam AElementwiseOperation A elementwise operation.
* \tparam BElementwiseOperation B elementwise operation.
* \tparam CDEElementwiseOperation CDE elementwise operation.
* \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
*/
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
......@@ -30,36 +42,25 @@ template <index_t NDimSpatial,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename ComputeType = ADataType>
struct DeviceGroupedConvFwdMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a, // input image
const void* p_b, // weight
const std::array<const void*, NumDTensor>& p_ds,
void* p_e, // output image
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
typename ComputeType =
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
Number<0>,
ADataType>())> // ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
using DeviceGroupedConvFwdMultipleD = DeviceGroupedConvFwdMultipleABD<NDimSpatial,
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
ComputeType>;
} // namespace device
} // namespace tensor_operation
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename DYDataType,
typename XDataType,
typename MeanInvStdDataType,
typename DGammaDataType,
typename DBetaDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceNormalizationBwdGammaBeta : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> dyStrides,
const std::vector<index_t> xStrides,
const std::vector<index_t> meanStrides,
const std::vector<index_t> invStdStrides,
const std::vector<index_t> outLengths,
const std::vector<index_t> dgammaStrides,
const std::vector<index_t> dbetaStrides,
const std::vector<index_t> reduceDims,
const void* p_dy,
const void* p_x,
const void* p_mean,
const void* p_invStd,
void* p_dgamma,
void* p_dbeta) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename DYDataType,
typename XDataType,
typename MeanInvStdDataType,
typename DGammaDataType,
typename DBetaDataType,
index_t Rank,
index_t NumReduceDim>
using DeviceNormalizationBwdGammaBetaPtr =
std::unique_ptr<DeviceNormalizationBwdGammaBeta<DYDataType,
XDataType,
MeanInvStdDataType,
DGammaDataType,
DBetaDataType,
Rank,
NumReduceDim>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -19,7 +19,7 @@ template <typename XDataType,
typename YElementwiseOperation,
index_t Rank,
index_t NumReduceDim>
struct DeviceNormalization : public BaseOperator
struct DeviceNormalizationFwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> lengths,
......@@ -50,14 +50,14 @@ template <typename XDataType,
typename YElementwiseOperation,
index_t Rank,
index_t NumReduceDim>
using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
YElementwiseOperation,
Rank,
NumReduceDim>>;
using DeviceNormalizationFwdPtr = std::unique_ptr<DeviceNormalizationFwd<XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
YElementwiseOperation,
Rank,
NumReduceDim>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -263,19 +263,18 @@ struct DeviceColumnToImageImpl
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, InputGridDesc>(
InputGridDesc{}))>;
using GridwiseTensorRearrangeKernel =
GridwiseTensorRearrange<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
BlockSize,
MPerBlock,
KPerBlock,
ThreadClusterLengths,
ScalarPerVector,
InMemoryDataOperationEnum::Add,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<I0>>;
using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
BlockSize,
MPerBlock,
KPerBlock,
ThreadClusterLengths,
ScalarPerVector,
InMemoryDataOperationEnum::Add,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<>>;
struct Argument : public BaseArgument
{
......@@ -453,7 +452,7 @@ struct DeviceColumnToImageImpl
std::vector<const InputDataType*> p_in_container_;
std::vector<OutputDataType*> p_out_container_;
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
};
struct Invoker : public BaseInvoker
......@@ -471,7 +470,7 @@ struct DeviceColumnToImageImpl
OutputGridDesc,
OutputDataType,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<I0>,
ComputePtrOffsetOfStridedBatch<>,
GridwiseTensorRearrangeKernel>;
// Execute each set of independent filters
......
......@@ -385,9 +385,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
// desc for blockwise copy
using AsGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(AsGridDesc_M_K{}))>;
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(
AsGridDesc_M_K{}))>;
using BsGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(BsGridDesc_N_K{}))>;
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(
BsGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
......@@ -397,7 +399,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeBlock2ETileMap(EGridDesc_M_N{}))>;
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument
struct Argument : public BaseArgument
......@@ -429,7 +431,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
bs_grid_desc_bk0_n_bk1_{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeBlock2ETileMap(e_grid_desc_m_n_)},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
......@@ -481,10 +483,10 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
block_2_etile_map_))
{
as_grid_desc_ak0_m_ak1_ =
GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_);
GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_);
bs_grid_desc_bk0_n_bk1_ =
GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_);
GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......
......@@ -145,7 +145,8 @@ template <index_t NumDimM,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
typename ComputeDataType = ADataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceContractionMultipleD_Xdl_CShuffle
: public DeviceContractionMultipleD<NumDimM,
NumDimN,
......@@ -156,7 +157,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
CDEElementwiseOperation,
ComputeDataType>
{
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
......@@ -310,8 +312,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
using ComputeDataType = ADataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
......@@ -595,7 +595,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
return false;
}
if(ck::get_device_name() != "gfx90a" && std::is_same<ADataType, double>::value)
if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" &&
ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" &&
std::is_same<ADataType, double>::value)
{
return false;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
index_t NumDim_m, // choose how to set dims
index_t NumDim_n,
index_t NumDim_k,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
NumDim_m + NumDim_n + NumDim_k>
{
static constexpr index_t NumDim = NumDim_m + NumDim_n + NumDim_k;
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size(),
"Tuple size is inconsistent with the number of in/out!");
static auto GenerateInDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(nullptr);
},
Number<NumInput>{});
}
static auto GenerateOutDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(nullptr);
},
Number<NumOutput>{});
}
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
template <typename Desc_MNK>
static auto PadDescriptor_MNK(Desc_MNK desc_mnk,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n,
index_t num_threads_k)
{
std::ignore = blockSize;
std::ignore = gridSize;
const auto m = desc_mnk.GetLength(I0);
const auto n = desc_mnk.GetLength(I1);
const auto k = desc_mnk.GetLength(I2);
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const index_t loop_step_k = num_threads_k * KPerThread;
const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m;
const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n;
const auto pad_k = math::integer_least_multiple(k, loop_step_k) - k;
const auto desc_mnk_pad =
transform_tensor_descriptor(desc_mnk,
make_tuple(make_right_pad_transform(m, pad_m),
make_right_pad_transform(n, pad_n),
make_right_pad_transform(k, pad_k)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return desc_mnk_pad;
}
static auto MakeDescriptor_MNK(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& stride,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n,
index_t num_threads_k)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type();
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDim_m, NumDim_m + NumDim_n, 1>::type();
constexpr auto kDimIds =
typename arithmetic_sequence_gen<NumDim_m + NumDim_n, NumDim, 1>::type();
const auto mLengths = get_container_subset(tupleOfShape, mDimIds);
const auto nLengths = get_container_subset(tupleOfShape, nDimIds);
const auto kLengths = get_container_subset(tupleOfShape, kDimIds);
// merge nd to 3d desc - [s0 * s1 * ...]
if constexpr(NumDim > 3)
{
const auto desc_mnk = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(mLengths),
make_merge_transform(nLengths),
make_merge_transform(kLengths)),
make_tuple(mDimIds, nDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return PadDescriptor_MNK(
desc_mnk, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k);
}
else
return PadDescriptor_MNK(
desc, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k);
}
template <index_t TupleSize>
static auto GenerateInOutGrid3dDescTuple(Number<TupleSize>)
{
return generate_tuple(
[&](auto) {
if constexpr(NumDim > 3)
{
return MakeDescriptor_MNK({1, 1, 1}, {1, 1, 1}, 1, 1, 1, 1, 1);
}
else
{
return MakeDescriptor_MNK({1}, {1}, 1, 1, 1, 1, 1);
};
},
Number<TupleSize>{});
}
using OutGrid3dDescTuple = decltype(GenerateInOutGrid3dDescTuple(Number<NumOutput>{}));
using InGrid3dDescTuple = decltype(GenerateInOutGrid3dDescTuple(Number<NumInput>{}));
using GridwiseElementwise = GridwiseElementwise_3D<InGrid3dDescTuple,
OutGrid3dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation,
MPerThread,
NPerThread,
KPerThread,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
: lengths_(lengths),
inStridesArray_(inStridesArray),
outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op),
blockSize_(256)
{
static_assert(NumDim_m > 0, "");
static_assert(NumDim_n > 0, "");
static_assert(NumDim_k > 0, "");
in_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(in_dev_buffers[I.value]);
},
Number<NumInput>{});
out_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(out_dev_buffers[I.value]);
},
Number<NumOutput>{});
}
InDataTypePointerTuple in_dev_buffers_;
OutDataTypePointerTuple out_dev_buffers_;
std::array<index_t, NumDim> lengths_;
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
ElementwiseOperation elementwise_op_;
index_t blockSize_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
index_t gridSize = getAvailableComputeUnitCount(stream_config) * arg.blockSize_;
index_t num_threads_m = gridSize / (16 * 16);
index_t num_threads_n = 16;
index_t num_threads_k = 16;
auto in_grid_3d_desc_tuple = generate_tuple(
[&](auto I) {
return MakeDescriptor_MNK(arg.lengths_,
arg.inStridesArray_[I.value],
gridSize,
arg.blockSize_,
num_threads_m,
num_threads_n,
num_threads_k);
},
Number<NumInput>{});
auto out_grid_3d_desc_tuple = generate_tuple(
[&](auto I) {
return MakeDescriptor_MNK(arg.lengths_,
arg.outStridesArray_[I.value],
gridSize,
arg.blockSize_,
num_threads_m,
num_threads_n,
num_threads_k);
},
Number<NumOutput>{});
const auto kernel = kernel_elementwise_3d<GridwiseElementwise,
InGrid3dDescTuple,
OutGrid3dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gridSize),
dim3(arg.blockSize_),
0,
in_grid_3d_desc_tuple,
out_grid_3d_desc_tuple,
arg.in_dev_buffers_,
arg.out_dev_buffers_,
arg.elementwise_op_,
num_threads_m,
num_threads_n,
num_threads_k);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false;
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector,
index_t vectorDim) {
if(strides[vectorDim] == 1 &&
(lengths[vectorDim] % scalarPerVector == 0 ||
lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
{
return true;
}
if(strides[vectorDim] >= scalarPerVector)
{
return true;
}
return false;
};
bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
valid = valid && IsScalarPerVectorValid(pArg->lengths_,
pArg->inStridesArray_[I.value],
InScalarPerVectorSeq::At(I),
NumDim_m - 1);
});
static_for<0, NumOutput, 1>{}([&](auto I) {
valid = valid && IsScalarPerVectorValid(pArg->lengths_,
pArg->outStridesArray_[I.value],
OutScalarPerVectorSeq::At(I),
NumDim - 1);
});
return valid;
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op) override
{
return std::make_unique<Argument>(lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
}; // namespace device
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
typename UnaryOperation,
typename Scale,
index_t NumDim,
index_t MPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
UnaryOperation,
Scale,
NumDim>
{
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size(),
"Tuple size is inconsistent with the number of in/out!");
static auto GenerateInDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(nullptr);
},
Number<NumInput>{});
};
static auto GenerateOutDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(nullptr);
},
Number<NumOutput>{});
};
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
template <typename Desc_M>
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
{
constexpr auto I0 = Number<0>{};
const auto m = desc_m.GetLength(I0);
const index_t loop_step = gridSize * blockSize * MPerThread;
const auto pad = math::integer_least_multiple(m, loop_step) - m;
const auto desc_m_pad =
transform_tensor_descriptor(desc_m,
make_tuple(make_right_pad_transform(m, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m_pad;
}
static auto MakeDescriptor_M(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& stride,
index_t gridSize,
index_t blockSize)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(NumDim > 1)
{
const auto desc_m = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim>{})),
make_tuple(Sequence<0>{}));
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
}
else
return PadDescriptor_M_1d(desc, gridSize, blockSize);
}
template <index_t TupleSize>
static auto GenerateInOutGrid1dDescTuple(Number<TupleSize>)
{
return generate_tuple(
[&](auto) {
if constexpr(NumDim > 1)
{
return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1);
}
else
{
return MakeDescriptor_M({1}, {1}, 1, 1);
};
},
Number<TupleSize>{});
};
using InGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number<NumInput>{}));
using OutGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number<NumOutput>{}));
using GridwiseElementwise = GridwiseElementwise_1D<InGrid1dDescTuple,
OutGrid1dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation,
UnaryOperation,
Scale,
MPerThread,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op,
UnaryOperation unary_op,
Scale scale_op)
: lengths_(lengths),
inStridesArray_(inStridesArray),
outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op),
unary_op_(unary_op),
scale_op_(scale_op),
blockSize_(256)
{
in_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(in_dev_buffers[I.value]);
},
Number<NumInput>{});
out_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(out_dev_buffers[I.value]);
},
Number<NumOutput>{});
}
InDataTypePointerTuple in_dev_buffers_;
OutDataTypePointerTuple out_dev_buffers_;
std::array<index_t, NumDim> lengths_;
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
ElementwiseOperation elementwise_op_;
UnaryOperation unary_op_;
Scale scale_op_;
index_t blockSize_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
index_t gridSize = getAvailableComputeUnitCount(stream_config);
auto in_grid_1d_desc_tuple = generate_tuple(
[&](auto I) {
return MakeDescriptor_M(
arg.lengths_, arg.inStridesArray_[I.value], gridSize, arg.blockSize_);
},
Number<NumInput>{});
auto out_grid_1d_desc_tuple = generate_tuple(
[&](auto I) {
return MakeDescriptor_M(
arg.lengths_, arg.outStridesArray_[I.value], gridSize, arg.blockSize_);
},
Number<NumOutput>{});
const auto kernel = kernel_elementwise_1d<GridwiseElementwise,
InGrid1dDescTuple,
OutGrid1dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation,
UnaryOperation,
Scale>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gridSize),
dim3(arg.blockSize_),
0,
in_grid_1d_desc_tuple,
out_grid_1d_desc_tuple,
arg.in_dev_buffers_,
arg.out_dev_buffers_,
arg.elementwise_op_,
arg.unary_op_,
arg.scale_op_);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
if(arg.lengths_.back() % MPerThread != 0)
return false;
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector) {
if(strides.back() == 1 && lengths.back() % scalarPerVector == 0)
return true;
if(strides.back() != 1 && scalarPerVector == 1)
return true;
return false;
};
bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(
arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
valid = false;
});
static_for<0, NumOutput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(
arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
valid = false;
});
return valid;
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto
MakeArgument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op,
UnaryOperation unary_op,
Scale scale_op)
{
return Argument{lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op,
unary_op,
scale_op};
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op,
UnaryOperation unary_op,
Scale scale_op) override
{
return std::make_unique<Argument>(lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op,
unary_op,
scale_op);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
}; // namespace device
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -305,9 +305,11 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
// desc for blockwise copy
using AsGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(AsGridDesc_M_K{}))>;
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(
AsGridDesc_M_K{}))>;
using BsGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(BsGridDesc_N_K{}))>;
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(
BsGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
......@@ -317,7 +319,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeBlock2ETileMap(EGridDesc_M_N{}))>;
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument
struct Argument : public BaseArgument
......@@ -349,7 +351,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
bs_grid_desc_bk0_n_bk1_{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeBlock2ETileMap(e_grid_desc_m_n_)},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
......@@ -407,10 +409,10 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
block_2_etile_map_))
{
as_grid_desc_ak0_m_ak1_ =
GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_);
GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_);
bs_grid_desc_bk0_n_bk1_ =
GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_);
GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferScalarPerVector,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferScalarPerVector,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v4,
typename ComputeDataType = EDataType>
struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
: public DeviceGemmMultipleD<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
static constexpr auto I1 = Number<1>{};
static constexpr index_t NumDTensor = DsDataType::Size();
using GridwiseGemm = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad<
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
GemmSpec,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferScalarPerVector,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferScalarPerVector,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer>;
using Argument = typename GridwiseGemm::Argument;
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load<
GridwiseGemm,
ADataType,
BDataType,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
typename GridwiseGemm::AGridDesc_AK0_M_AK1,
typename GridwiseGemm::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::Block2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_);
};
const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
if(!ck::is_lds_direct_load_supported())
{
return false;
}
// Check vector load/store.
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// Check vector load of A.
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
{
if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// Check vector load of B.
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
{
if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// Check vector load of Ds.
// For now, only the RowMajor layout is supported.
bool all_valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(!is_same_v<DLayout, Row>)
{
all_valid = false;
}
});
if(!all_valid)
{
return false;
}
// Check vector load of E.
// For now, only the RowMajor layout is supported.
if constexpr(is_same_v<ELayout, Row>)
{
if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
return false;
}
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_e,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
StrideDs,
StrideE,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
StrideDs,
StrideE,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{
{PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}};
// clang-format off
str << "DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferScalarPerVector << ", "
<< BBlockTransferScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -278,6 +278,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// clang-format off
str << "DeviceGemm_Xdl_CShuffle"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
......@@ -296,7 +297,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];;
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferScalarPerVector,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferScalarPerVector,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v4,
typename ComputeDataType = EDataType>
struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
BLayout,
ELayout,
ADataType,
BDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
static constexpr auto I1 = Number<1>{};
using GridwiseGemm = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad<
ALayout,
BLayout,
ck::Tuple<>,
ELayout,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<>,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
GemmSpec,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferScalarPerVector,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferScalarPerVector,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer>;
using Argument = typename GridwiseGemm::Argument;
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load<
GridwiseGemm,
ADataType,
BDataType,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
typename GridwiseGemm::AGridDesc_AK0_M_AK1,
typename GridwiseGemm::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::Block2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
// arg.p_ds_grid_,
ck::Tuple<>{},
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_);
};
const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
if(!ck::is_lds_direct_load_supported())
{
return false;
}
// Check vector load/store.
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// Check vector load of A.
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
{
if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// Check vector load of B.
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
{
if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// Check vector load of E.
// For now, only the RowMajor layout is supported.
if constexpr(is_same_v<ELayout, Row>)
{
if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
return false;
}
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
using EmptyDsPointers = std::array<const void*, 0>;
using EmptyDsStrides = std::array<ck::index_t, 0>;
return Argument{p_a,
p_b,
EmptyDsPointers{},
p_e,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
EmptyDsStrides{},
StrideE,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
using EmptyDsPointers = std::array<const void*, 0>;
using EmptyDsStrides = std::array<ck::index_t, 0>;
return std::make_unique<Argument>(p_a,
p_b,
EmptyDsPointers{},
p_e,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
EmptyDsStrides{},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{
{PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}};
// clang-format off
str << "DeviceGemm_Xdl_CShuffle_LdsDirectLoad"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferScalarPerVector << ", "
<< BBlockTransferScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -59,7 +59,8 @@ template <typename ADataType,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename ComputeType = CDataType,
PipelineVersion PipelineVer = PipelineVersion::v1>
PipelineVersion PipelineVer = PipelineVersion::v1,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
BLayout,
......@@ -79,7 +80,6 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
// TODO: should be exposed as Tparams.
static constexpr index_t NumGemmKPrefetchStage = 1;
static constexpr LoopScheduler LoopSched = make_default_loop_scheduler();
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
......@@ -141,7 +141,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t MPadded_,
index_t NPadded_,
index_t KPadded_,
index_t K0_,
index_t K0Padded_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
......@@ -158,7 +158,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
MPadded_,
NPadded_,
KPadded_,
K0_,
K0Padded_,
k_batch_),
a_element_op(a_element_op_),
b_element_op(b_element_op_),
......@@ -198,9 +198,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const auto b2c_map = DefaultBlock2CTileMap{};
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
const auto K0 = karg.K0;
const auto K0Padded = karg.K0Padded;
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded);
float ave_time = 0;
......@@ -342,7 +342,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K, KBatch),
GridwiseGemm::CalculateK0(K, KBatch),
GridwiseGemm::CalculateK0Padded(K, KBatch),
KBatch,
a_element_op,
b_element_op,
......@@ -378,7 +378,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K, KBatch),
GridwiseGemm::CalculateK0(K, KBatch),
GridwiseGemm::CalculateK0Padded(K, KBatch),
KBatch,
a_element_op,
b_element_op,
......@@ -392,7 +392,21 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
}
// polymorphic
std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); }
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
str << GridwiseGemm::GetTypeString() << " LoopScheduler: " << LoopSchedToString[LoopSched]
<< ", PipelineVersion: " << PipelineVersionToString[PipelineVer];
return str.str();
}
};
} // namespace device
......
......@@ -517,7 +517,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
// element-wise op
AElementwiseOp a_element_op_;
......@@ -579,7 +579,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
ComputePtrOffsetOfStridedBatch<NumDTensor>,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
has_main_loop>;
return launch_and_time_kernel(
......
......@@ -677,7 +677,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
std::vector<Block2ETileMap> block_2_etile_map_container_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
// element-wise op
AElementwiseOp a_element_op_;
......@@ -746,7 +746,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumDTensor>,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
has_main_loop>;
return launch_and_time_kernel(
......
......@@ -927,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
Block2CTileMap block_2_ctile_map_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
// element-wise op
OutElementwiseOperation a_element_op_;
......@@ -999,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch<I0>,
ComputePtrOffsetOfStridedBatch<>,
has_main_loop,
has_double_loop>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment