Commit e70a4d19 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents ce72f286 0dacd895
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseElementwise1dFunctor,
typename InGrid1dDescTuple,
typename OutGrid1dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
typename UnaryOperation,
typename Scale>
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple,
const OutGrid1dDescTuple out_grid_1d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const UnaryOperation unary_op,
const Scale scale_op)
{
GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple,
out_grid_1d_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
elementwise_op,
unary_op,
scale_op);
}
template <typename InGrid1dDescTuple,
typename OutGrid1dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
typename UnaryOperation,
typename Scale,
index_t MPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct GridwiseElementwise_1D
{
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size() &&
NumInput == InGrid1dDescTuple::Size() &&
NumOutput == OutGrid1dDescTuple::Size(),
"Tuple size is inconsistent with the number of in/out!");
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
__device__ static void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple,
const OutGrid1dDescTuple out_grid_1d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const UnaryOperation unary_op,
const Scale scale_op)
{
const index_t thread_global_id = get_thread_global_1d_id();
auto in_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
},
Number<NumInput>{});
auto out_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
},
Number<NumOutput>{});
auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumInput>{});
auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumOutput>{});
const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread);
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto M = in_grid_1d_desc_tuple[I0].GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * MPerThread;
const auto loop_step_index = make_multi_index(loop_step);
auto in_global_load_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return ThreadwiseTensorSliceTransfer_v2<DataType,
DataType,
decltype(in_grid_1d_desc_tuple[I]),
decltype(thread_buffer_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
InScalarPerVectorSeq::At(
I), // ScalarPerVector
1, // SrcScalarStrideInVector
false>{in_grid_1d_desc_tuple[I],
thread_global_offset};
},
Number<NumInput>{});
auto out_global_store_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return ThreadwiseTensorSliceTransfer_v1r3<DataType,
DataType,
decltype(thread_buffer_desc_m),
decltype(out_grid_1d_desc_tuple[I]),
PassThroughOp,
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
OutScalarPerVectorSeq::At(I),
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_1d_desc_tuple[I], thread_global_offset, PassThroughOp{});
},
Number<NumOutput>{});
index_t num_iter = M / (loop_step);
do
{
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I],
in_global_buf_tuple[I],
thread_buffer_desc_m,
make_tuple(I0),
in_thread_buf_tuple(I));
in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I],
loop_step_index);
});
static_for<0, MPerThread, 1>{}([&](auto iM) {
// get reference to in data
auto uop_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
// get reference to dst data
auto out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return out_thread_buf_tuple(I)(iM); },
Number<NumOutput>{});
unpack2(unary_op, uop_data_refs, uop_data_refs);
auto sop_in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
auto sop_out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
unpack2(scale_op, sop_out_data_refs, sop_in_data_refs);
const auto in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
unpack2(elementwise_op, out_data_refs, in_data_refs);
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).Run(thread_buffer_desc_m,
make_tuple(I0),
out_thread_buf_tuple[I],
out_grid_1d_desc_tuple[I],
out_global_buf_tuple(I));
out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I],
loop_step_index);
});
} while(--num_iter);
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
//
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseElementwise3dFunctor,
typename InGrid3dDescTuple,
typename OutGrid3dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation>
__global__ void kernel_elementwise_3d(const InGrid3dDescTuple in_grid_3d_desc_tuple,
const OutGrid3dDescTuple out_grid_3d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n,
const index_t num_threads_k)
{
GridwiseElementwise3dFunctor::Run(in_grid_3d_desc_tuple,
out_grid_3d_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
elementwise_op,
num_threads_m,
num_threads_n,
num_threads_k);
}
template <typename InGrid3dDescTuple,
typename OutGrid3dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct GridwiseElementwise_3D
{
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size() &&
NumInput == InGrid3dDescTuple::Size() &&
NumOutput == OutGrid3dDescTuple::Size(),
"Tuple size is inconsistent with the number of in/out!");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto thread_buffer_desc_mnk = make_naive_tensor_descriptor_packed(
make_tuple(Number<MPerThread>{}, Number<NPerThread>{}, Number<KPerThread>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
__device__ static void Run(const InGrid3dDescTuple in_grid_3d_desc_tuple,
const OutGrid3dDescTuple out_grid_3d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n,
const index_t num_threads_k)
{
auto in_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
MPerThread * NPerThread * KPerThread,
true>{};
},
Number<NumInput>{});
auto out_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
MPerThread * NPerThread * KPerThread,
true>{};
},
Number<NumOutput>{});
auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_3d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumInput>{});
auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_3d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumOutput>{});
const auto M = in_grid_3d_desc_tuple[I0].GetLength(I0);
const auto N = in_grid_3d_desc_tuple[I0].GetLength(I1);
const auto K = in_grid_3d_desc_tuple[I0].GetLength(I2);
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const index_t loop_step_k = num_threads_k * KPerThread;
const index_t thread_1d_id = get_thread_global_1d_id();
const index_t tid_m = thread_1d_id / (num_threads_n * num_threads_k);
const index_t tid_nk = thread_1d_id % (num_threads_n * num_threads_k);
const index_t tid_n = tid_nk / num_threads_k;
const index_t tid_k = tid_nk % num_threads_k;
const auto thread_global_offset =
make_multi_index(tid_m * MPerThread, tid_n * NPerThread, tid_k * KPerThread);
auto in_global_load_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return ThreadwiseTensorSliceTransfer_v2<
DataType,
DataType,
decltype(in_grid_3d_desc_tuple[I]),
decltype(thread_buffer_desc_mnk),
Sequence<MPerThread, NPerThread, KPerThread>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
01, // SrcVectorDim
InScalarPerVectorSeq::At(I), // InScalarPerVectorSeq::At(I), //
// ScalarPerVector
1, // SrcScalarStrideInVector
true>{in_grid_3d_desc_tuple[I], thread_global_offset};
},
Number<NumInput>{});
auto out_global_store_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return ThreadwiseTensorSliceTransfer_v1r3<
DataType,
DataType,
decltype(thread_buffer_desc_mnk),
decltype(out_grid_3d_desc_tuple[I]),
PassThroughOp,
Sequence<MPerThread, NPerThread, KPerThread>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
OutScalarPerVectorSeq::At(I), // OutScalarPerVectorSeq::At(I),
InMemoryDataOperationEnum::Set,
1,
true>(out_grid_3d_desc_tuple[I], thread_global_offset, PassThroughOp{});
},
Number<NumOutput>{});
index_t num_iter_m = M / (loop_step_m);
do
{
index_t num_iter_n = N / (loop_step_n);
do
{
index_t num_iter_k = K / (loop_step_k);
do
{
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).Run(in_grid_3d_desc_tuple[I],
in_global_buf_tuple[I],
thread_buffer_desc_mnk,
make_tuple(I0, I0, I0),
in_thread_buf_tuple(I));
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_3d_desc_tuple[I], make_multi_index(0, 0, loop_step_k));
});
static_for<0, MPerThread, 1>{}([&](auto iM) {
static_for<0, NPerThread, 1>{}([&](auto iN) {
static_for<0, KPerThread, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_mnk.CalculateOffset(make_tuple(iM, iN, iK));
// get reference to in data
const auto in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& {
return in_thread_buf_tuple(I)(Number<offset>{});
},
Number<NumInput>{});
// get referenec to dst data
auto out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& {
return out_thread_buf_tuple(I)(Number<offset>{});
},
Number<NumOutput>{});
unpack2(elementwise_op, out_data_refs, in_data_refs);
});
});
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).Run(thread_buffer_desc_mnk,
make_tuple(I0, I0, I0),
out_thread_buf_tuple[I],
out_grid_3d_desc_tuple[I],
out_global_buf_tuple(I));
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_3d_desc_tuple[I], make_multi_index(0, 0, loop_step_k));
});
} while(--num_iter_k);
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_3d_desc_tuple[I],
make_multi_index(0, loop_step_n, -(K / loop_step_k) * loop_step_k));
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_3d_desc_tuple[I],
make_multi_index(0, loop_step_n, -(K / loop_step_k) * loop_step_k));
});
} while(--num_iter_n);
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_3d_desc_tuple[I],
make_multi_index(loop_step_m,
-(N / loop_step_n) * loop_step_n,
-(K / loop_step_k) * loop_step_k));
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_3d_desc_tuple[I],
make_multi_index(loop_step_m,
-(N / loop_step_n) * loop_step_n,
-(K / loop_step_k) * loop_step_k));
});
} while(--num_iter_m);
}
};
} // namespace ck
...@@ -203,7 +203,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -203,7 +203,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// A desc for source in blockwise copy // A desc for source in blockwise copy
template <typename AGridDesc_M_K> template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{ {
const auto M = a_grid_desc_m_k.GetLength(I0); const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1); const auto K = a_grid_desc_m_k.GetLength(I1);
...@@ -219,17 +219,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -219,17 +219,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template <typename AsGridDesc_M_K> template <typename AsGridDesc_M_K>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k) MakeDefaultAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { return MakeAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); }, [&](auto i) { return MakeDefaultAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); },
Number<NumATensor>{}); Number<NumATensor>{});
} }
// B desc for source in blockwise copy // B desc for source in blockwise copy
template <typename BGridDesc_N_K> template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{ {
const auto N = b_grid_desc_n_k.GetLength(I0); const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1); const auto K = b_grid_desc_n_k.GetLength(I1);
...@@ -245,10 +245,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -245,10 +245,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template <typename BsGridDesc_N_K> template <typename BsGridDesc_N_K>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k) MakeDefaultBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { return MakeBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); }, [&](auto i) { return MakeDefaultBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); },
Number<NumBTensor>{}); Number<NumBTensor>{});
} }
...@@ -288,7 +288,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -288,7 +288,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// return block_id to E matrix tile idx (m0, n0) mapping // return block_id to E matrix tile idx (m0, n0) mapping
template <typename EGridDesc_M_N> template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{ {
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>( return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n); e_grid_desc_m_n);
...@@ -591,6 +591,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -591,6 +591,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{}); Number<NumATensor>{});
static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
"Src and Dst ScalarPerVector must be the same");
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock, ThisThreadBlock,
AsDataType, AsDataType,
...@@ -619,6 +622,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -619,6 +622,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{}); Number<NumBTensor>{});
static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
"Src and Dst ScalarPerVector must be the same");
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock, ThisThreadBlock,
BsDataType, BsDataType,
...@@ -1005,9 +1011,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -1005,9 +1011,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE); const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k); const auto as_grid_desc_ak0_m_ak1 = MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k);
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k); const auto bs_grid_desc_bk0_n_bk1 = MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n); MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
namespace ck {
template <typename GridwiseGemm,
typename ADataType,
typename BDataType,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load(
const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx940__) || \
defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_etile_map;
#endif
}
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AComputeDataType_,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferScalarPerVector,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferScalarPerVector,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v4,
typename BComputeDataType = AComputeDataType_>
struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
{
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
#if CK_WORKAROUND_DENORM_FIX
using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
#else
using AComputeDataType = AComputeDataType_;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, destination of blockwise copy.
return make_naive_tensor_descriptor(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, destination of blockwise copy.
return make_naive_tensor_descriptor(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment.
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
constexpr auto max_lds_align = math::lcm(AK1, BK1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle.
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max(
NumGemmKPrefetchStage * a_block_space_size_aligned * sizeof(AComputeDataType) +
NumGemmKPrefetchStage * b_block_space_size_aligned * sizeof(BComputeDataType),
c_block_size * sizeof(CShuffleDataType));
}
__host__ __device__ static auto
MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
__host__ __device__ static auto
MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
__host__ __device__ static auto
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
}
}();
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
__host__ __device__ static auto
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) { return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); },
Number<NumDTensor>{});
}
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
// A desc for source in blockwise copy.
__host__ __device__ static constexpr auto
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B desc for source in blockwise copy.
__host__ __device__ static constexpr auto
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// E desc for destination in blockwise copy.
__host__ __device__ static constexpr auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e_grid_desc_mblock_mperblock_nblock_nperblock;
}
// Ds desc for source in blockwise copy.
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
},
Number<NumDTensor>{});
}
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
}
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
using Block2ETileMap = remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
"KPerBlock must be divisible by AK1Value and BK1Value!");
static_assert(
std::is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough> &&
std::is_same_v<BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough>,
"Direct load transfers do not support elementwise operations other than passthrough.");
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto AK = a_grid_desc_m_k.GetLength(I1);
const auto BK = b_grid_desc_n_k.GetLength(I1);
// Check the consistency of descriptors.
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
{
return false;
}
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false;
}
// Check the tile size.
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
{
return false;
}
// Check gridwise gemm pipeline.
const auto num_k_loop = AK / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
// Check block-to-E-tile.
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{
return false;
}
// Check tensor size: cannot exceed 2GB.
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
using DsGridPointer = decltype(MakeDsGridPointer());
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
template <typename DataType>
__device__ static auto AllocateBlockBuffers(void* p_shared,
int32_t num_elems,
int32_t offset_elems,
int32_t max_lds_align)
{
const int32_t single_buffer_offset = math::integer_least_multiple(num_elems, max_lds_align);
return generate_tuple(
[&](auto i) {
const int32_t local_offset = i * single_buffer_offset;
return make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + local_offset + offset_elems, num_elems);
},
Number<NumGemmKPrefetchStage>{});
}
template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map)
{
// Elementwise operations are not supported for A and B, arguments left only for the API
// consistency.
(void)a_element_op;
(void)b_element_op;
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// Divide block work by [M, N].
const auto block_work_idx =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_etile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// This forces m/n_block_data_idx_on_grid into SGPR.
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
constexpr auto max_lds_align = math::lcm(AK1, BK1);
// A matrix in LDS memory, destination of blockwise copy.
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, destination of blockwise copy.
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ADataType,
AComputeDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcVectorDim,
2,
ABlockTransferScalarPerVector>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BDataType,
BComputeDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcVectorDim,
2,
BBlockTransferScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1),
MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
AComputeDataType,
BComputeDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment.
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buffers = AllocateBlockBuffers<AComputeDataType>(
p_shared, a_block_desc_ak0_m_ak1.GetElementSpaceSize(), 0, max_lds_align);
const auto b_buffers_offset = a_block_space_size_aligned * NumGemmKPrefetchStage;
auto b_block_buffers =
AllocateBlockBuffers<BComputeDataType>(p_shared,
b_block_desc_bk0_n_bk1.GetElementSpaceSize(),
b_buffers_offset,
max_lds_align);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buffers,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buffers,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// Shuffle C and write out.
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// Calculate the origin of thread output tensor on global memory.
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// Shuffle: threadwise copy C from VGPR to LDS.
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// A tuple of reference to C/Ds tensor descriptors.
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// A tuple of reference to C/Ds grid buffers.
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// A tuple of starting index of C/Ds blockwise copy.
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumDTensor>{}));
// Blockwise copy C/D/E between LDS and global.
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde_element_op};
// Space filling curve for threadwise C in VGPR before shuffle.
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// Space filling curve for shuffled blockwise C/D/E.
constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// Make sure it's safe to write to LDS.
block_sync_lds();
// Each thread write its data from VGPR to LDS.
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// Make sure it's safe to read from LDS.
block_sync_lds();
// Each block copy its data from LDS to global.
cde_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto cde_lds_and_global_step =
sfc_cde_block.GetForwardStep(access_id);
// Move on Ds.
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_lds_and_global_step);
});
// Move on E.
cde_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0,
cde_lds_and_global_step);
}
});
}
}
struct Argument : public tensor_operation::device::BaseArgument
{
Argument(const void* p_a_grid,
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_m_k_{MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_n_k_{MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)},
a_grid_desc_ak0_m_ak1_{MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
MRaw_{MRaw},
NRaw_{NRaw},
KRaw_{KRaw}
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
ds_grid_desc_m_n_(i) = MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
});
if(CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_);
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
}
}
void Print() const
{
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
}
// Pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// Tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// Tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// element-wise ops
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// For checking vector load/store
index_t MRaw_;
index_t NRaw_;
index_t KRaw_;
};
};
} // namespace ck
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp"
namespace ck { namespace ck {
...@@ -14,6 +15,8 @@ enum struct PipelineVersion ...@@ -14,6 +15,8 @@ enum struct PipelineVersion
{ {
v1, v1,
v2, v2,
// v3 is only used in the Stream-K implementation.
v4,
}; };
template <PipelineVersion PipelineVer, template <PipelineVersion PipelineVer,
...@@ -36,6 +39,10 @@ constexpr auto GridwiseGemmPipeline_Selector() ...@@ -36,6 +39,10 @@ constexpr auto GridwiseGemmPipeline_Selector()
{ {
return GridwiseGemmPipeline_v2{}; return GridwiseGemmPipeline_v2{};
} }
else if constexpr(PipelineVer == PipelineVersion::v4)
{
return GridwiseGemmPipeline_v4<NumPrefetch>{};
}
else else
{ {
std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl; std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace lds_direct_load {
__device__ void sched_barrier()
{
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
// When direct loads and `waitcnt` instructions are submitted using inline asm, the usage of
// `sched_barrier` is necessary to make sure no instructions that use the loaded memory
// are scheduled by the compiler before the `waitcnt` instruction.
__builtin_amdgcn_sched_barrier(0);
#endif
}
} // namespace lds_direct_load
namespace ck {
template <index_t NumPrefetch>
struct GridwiseGemmPipeline_v4;
// 1-stage prefetch
template <>
struct GridwiseGemmPipeline_v4<1>
{
static constexpr auto I0 = Number<0>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffers,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffers,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffers& a_block_bufs,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffers& b_block_bufs,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
static_assert(ABlockBuffers::Size() == 1 && BBlockBuffers::Size() == 1);
auto& a_block_buf = a_block_bufs.At(I0);
auto& b_block_buf = b_block_bufs.At(I0);
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds_direct_load();
lds_direct_load::sched_barrier();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds_direct_load();
lds_direct_load::sched_barrier();
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds_direct_load();
lds_direct_load::sched_barrier();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
// 2-stages prefetch
template <>
struct GridwiseGemmPipeline_v4<2>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
{
return num_loop % 2 == 0;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return (num_loop / 2) > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffers,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffers,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffers& a_block_bufs,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffers& b_block_bufs,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
static_assert(ABlockBuffers::Size() == 2 && BBlockBuffers::Size() == 2);
auto& a_block_buf1 = a_block_bufs.At(I0);
auto& a_block_buf2 = a_block_bufs.At(I1);
auto& b_block_buf1 = b_block_bufs.At(I0);
auto& b_block_buf2 = b_block_bufs.At(I1);
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf1);
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf1);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds_direct_load();
lds_direct_load::sched_barrier();
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf2);
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf2);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
blockwise_gemm.Run(a_block_buf1, b_block_buf1, c_thread_buf);
block_sync_lds_direct_load();
lds_direct_load::sched_barrier();
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf1);
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf1);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
blockwise_gemm.Run(a_block_buf2, b_block_buf2, c_thread_buf);
i += 2;
} while(i < (num_loop - 2));
}
// tail
{
block_sync_lds_direct_load();
lds_direct_load::sched_barrier();
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf2);
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf2);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
blockwise_gemm.Run(a_block_buf1, b_block_buf1, c_thread_buf);
block_sync_lds_direct_load();
lds_direct_load::sched_barrier();
blockwise_gemm.Run(a_block_buf2, b_block_buf2, c_thread_buf);
}
}
};
} // namespace ck
...@@ -996,6 +996,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext ...@@ -996,6 +996,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
} }
} }
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(problem.K0 % K0PerBlock == 0))
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{ {
if(problem.K % ABlockTransferSrcScalarPerVector != 0) if(problem.K % ABlockTransferSrcScalarPerVector != 0)
......
...@@ -136,7 +136,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -136,7 +136,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t MPadded; index_t MPadded;
index_t NPadded; index_t NPadded;
index_t KPadded; index_t KPadded;
index_t K0; index_t K0Padded;
index_t k_batch; index_t k_batch;
Argument(const FloatA* p_a_grid_, Argument(const FloatA* p_a_grid_,
...@@ -151,7 +151,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -151,7 +151,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t MPadded_, index_t MPadded_,
index_t NPadded_, index_t NPadded_,
index_t KPadded_, index_t KPadded_,
index_t K0_, index_t K0Padded_,
index_t k_batch_) index_t k_batch_)
: p_a_grid(p_a_grid_), : p_a_grid(p_a_grid_),
p_b_grid(p_b_grid_), p_b_grid(p_b_grid_),
...@@ -165,7 +165,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -165,7 +165,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
MPadded(MPadded_), MPadded(MPadded_),
NPadded(NPadded_), NPadded(NPadded_),
KPadded(KPadded_), KPadded(KPadded_),
K0(K0_), K0Padded(K0Padded_),
k_batch(k_batch_) k_batch(k_batch_)
{ {
} }
...@@ -182,7 +182,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -182,7 +182,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<< "MP:" << MPadded << ", " << "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", " << "NP:" << NPadded << ", "
<< "KP:" << KPadded << ", " << "KP:" << KPadded << ", "
<< "K0:" << K0 << ", " << "K0Padded:" << K0Padded << ", "
<< "KB:" << k_batch << "}" << std::endl; << "KB:" << k_batch << "}" << std::endl;
} }
}; };
...@@ -205,7 +205,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -205,7 +205,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
return math::integer_least_multiple(N, NPerBlock); return math::integer_least_multiple(N, NPerBlock);
} }
__host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1) __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1)
{ {
// k_batch * k0 * k0_per_block * k1 // k_batch * k0 * k0_per_block * k1
auto K_t = K_Batch * K0PerBlock * K1; auto K_t = K_Batch * K0PerBlock * K1;
...@@ -214,8 +214,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -214,8 +214,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{ {
auto K0 = CalculateK0(K, K_Batch); auto K0Padded = CalculateK0Padded(K, K_Batch);
return K_Batch * K0 * K1; return K_Batch * K0Padded * K1;
} }
__host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M,
...@@ -223,7 +223,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -223,7 +223,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t K, index_t K,
index_t StrideA, index_t StrideA,
index_t KBatch, index_t KBatch,
index_t K0, index_t K0Padded,
index_t KPad) index_t KPad)
{ {
const auto a_grid_desc_m_k = [&]() { const auto a_grid_desc_m_k = [&]() {
...@@ -237,21 +237,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -237,21 +237,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
} }
}(); }();
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{ {
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor( return transform_tensor_descriptor(
a_grid_desc_m_kpad, a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_right_pad_transform(M, MPad - M)), make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -259,8 +271,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -259,8 +271,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
else else
{ {
return transform_tensor_descriptor( return transform_tensor_descriptor(
a_grid_desc_m_kpad, a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_pass_through_transform(M)), make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -272,7 +284,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -272,7 +284,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t N, index_t N,
index_t StrideB, index_t StrideB,
index_t KBatch, index_t KBatch,
index_t K0, index_t K0Padded,
index_t KPad) index_t KPad)
{ {
const auto b_grid_desc_k_n = [&]() { const auto b_grid_desc_k_n = [&]() {
...@@ -286,21 +298,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -286,21 +298,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
} }
}(); }();
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{ {
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor( return transform_tensor_descriptor(
b_grid_desc_kpad_n, b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_right_pad_transform(N, NPad - N)), make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -308,8 +332,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -308,8 +332,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
else else
{ {
return transform_tensor_descriptor( return transform_tensor_descriptor(
b_grid_desc_kpad_n, b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_pass_through_transform(N)), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -398,6 +422,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -398,6 +422,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
return false; return false;
} }
} }
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
...@@ -410,6 +435,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -410,6 +435,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl; << std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
auto K_t = karg.k_batch * K0PerBlock * K1;
if(!(karg.K % K_t == 0))
{
#if DEBUG_LOG
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<< karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG #endif // DEBUG_LOG
return false; return false;
} }
...@@ -478,11 +522,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -478,11 +522,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
{ {
#if DEBUG_LOG #if DEBUG_LOG
std::cout std::cout << "Arg N (" << karg.N
<< "Arg N (" << karg.N << ") value is not a multiple of "
<< ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL (" "CBlockTransferScalarPerVector_NWaveNPerXDL ("
<< CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":" << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl; << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG #endif // DEBUG_LOG
return false; return false;
...@@ -493,25 +537,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -493,25 +537,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
{ {
#if DEBUG_LOG #if DEBUG_LOG
std::cout std::cout << "Arg M (" << karg.M
<< "Arg M (" << karg.M << ") value is not a multiple of "
<< ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL (" "CBlockTransferScalarPerVector_NWaveNPerXDL ("
<< CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":" << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl; << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG #endif // DEBUG_LOG
return false; return false;
} }
} }
const auto num_k_loop = karg.K0 / K0PerBlock; const auto num_k_loop = karg.K0Padded / K0PerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop)) if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{ {
#if DEBUG_LOG #if DEBUG_LOG
std::cout << "The number of k loops (" << num_k_loop std::cout << "The number of k loops (" << num_k_loop
<< ") value is not supported by GridwiseGemm Pipeline." << ") value is not supported by GridwiseGemm Pipeline."
<< " K0: " << karg.K0 << ", K0PerBlock: " << K0PerBlock << " " << __FILE__ << " K0Padded: " << karg.K0Padded << ", K0PerBlock: " << K0PerBlock << " "
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl; << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG #endif // DEBUG_LOG
return false; return false;
} }
...@@ -521,14 +565,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -521,14 +565,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__ __device__ static auto GetKPad(index_t K, index_t KBatch) __host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
{ {
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; const index_t K0Padded =
const index_t KPad = KBatch * K0 * K1; math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0Padded * K1;
return KPad; return KPad;
} }
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded)
{ {
const index_t num_loop = K0 / K0PerBlock; const index_t num_loop = K0Padded / K0PerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
} }
...@@ -595,9 +640,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -595,9 +640,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const FloatB* p_b_grid = karg.p_b_grid; const FloatB* p_b_grid = karg.p_b_grid;
FloatC* p_c_grid = karg.p_c_grid; FloatC* p_c_grid = karg.p_c_grid;
const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1( const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded); karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded);
const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1( const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded); karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
......
...@@ -21,6 +21,7 @@ template <typename InputGridDesc, ...@@ -21,6 +21,7 @@ template <typename InputGridDesc,
typename OutputGridDesc, typename OutputGridDesc,
typename OutputDataType, typename OutputDataType,
typename Block2ETileMap, typename Block2ETileMap,
typename ComputePtrOffsetOfStridedBatch,
typename GridwiseTensorRearrangeKernel> typename GridwiseTensorRearrangeKernel>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -30,13 +31,20 @@ __global__ void ...@@ -30,13 +31,20 @@ __global__ void
const InputDataType* __restrict__ p_in_global, const InputDataType* __restrict__ p_in_global,
const OutputGridDesc out_grid_desc, const OutputGridDesc out_grid_desc,
OutputDataType* __restrict__ p_out_global, OutputDataType* __restrict__ p_out_global,
const Block2ETileMap block_2_tile_map) const index_t batch_count,
const Block2ETileMap block_2_tile_map,
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \ defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
GridwiseTensorRearrangeKernel::Run( GridwiseTensorRearrangeKernel::Run(in_grid_desc,
in_grid_desc, p_in_global, out_grid_desc, p_out_global, block_2_tile_map); p_in_global,
out_grid_desc,
p_out_global,
batch_count,
block_2_tile_map,
compute_ptr_offset_of_batch);
#else #else
ignore = in_grid_desc; ignore = in_grid_desc;
ignore = p_in_global; ignore = p_in_global;
...@@ -56,7 +64,8 @@ template <typename InputGridDesc, ...@@ -56,7 +64,8 @@ template <typename InputGridDesc,
typename ThreadClusterLengths, typename ThreadClusterLengths,
index_t ScalarPerVector, index_t ScalarPerVector,
InMemoryDataOperationEnum DstInMemOp, InMemoryDataOperationEnum DstInMemOp,
typename Block2ETileMap> typename Block2ETileMap,
typename ComputePtrOffsetOfStridedBatch>
struct GridwiseTensorRearrange struct GridwiseTensorRearrange
{ {
...@@ -69,7 +78,9 @@ struct GridwiseTensorRearrange ...@@ -69,7 +78,9 @@ struct GridwiseTensorRearrange
const InputDataType* __restrict__ p_in_global, const InputDataType* __restrict__ p_in_global,
const OutputGridDesc& out_grid_desc, const OutputGridDesc& out_grid_desc,
OutputDataType* __restrict__ p_out_global, OutputDataType* __restrict__ p_out_global,
const Block2ETileMap& block_2_tile_map) const index_t batch_count,
const Block2ETileMap& block_2_tile_map,
const ComputePtrOffsetOfStridedBatch& compute_ptr_offset_of_batch)
{ {
const auto block_work_idx = const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
...@@ -80,12 +91,6 @@ struct GridwiseTensorRearrange ...@@ -80,12 +91,6 @@ struct GridwiseTensorRearrange
const index_t k_block_data_idx_on_grid = const index_t k_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock); __builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock);
// Global Memory
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc.GetElementSpaceSize());
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc.GetElementSpaceSize());
auto copy_global_to_global = auto copy_global_to_global =
ThreadGroupTensorSliceTransfer_v7<ThisThreadBlock, ThreadGroupTensorSliceTransfer_v7<ThisThreadBlock,
Tuple<InputDataType>, Tuple<InputDataType>,
...@@ -108,6 +113,22 @@ struct GridwiseTensorRearrange ...@@ -108,6 +113,22 @@ struct GridwiseTensorRearrange
make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)), make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx =
__builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
// Global Memory
const index_t a_batch_offset =
__builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const index_t c_batch_offset =
__builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize());
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global + c_batch_offset, out_grid_desc.GetElementSpaceSize());
copy_global_to_global.Run( copy_global_to_global.Run(
tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf)); tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf));
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
namespace ck {
// dgamma = reduce_sum(dy * (x - mean) * inv_std)
// dbeta = reduce_sum(dy)
template <typename DYDataType,
typename XDataType,
typename MeanInvStdDataType,
typename ComputeDataType,
typename DGammaDataType,
typename DBetaDataType,
typename GridDesc_M_K,
typename GridDesc_M,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t DYSrcVectorDim,
index_t DYSrcVectorSize,
index_t XSrcVectorDim,
index_t XSrcVectorSize,
index_t MeanInvStdSrcVectorDim,
index_t MeanInvStdSrcVectorSize,
index_t DGammaDstVectorSize,
index_t DBetaDstVectorSize>
struct GridwiseNormalizationBwdGammaBeta_mk_to_k
{
// if we just check ThreadSliceSize & VectorSize == 0, the performance may be poor
static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
(DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) ||
(XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!");
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using DYThreadBufferDimAccessOrder =
typename conditional<DYSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using XThreadBufferDimAccessOrder =
typename conditional<XSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using MeanInvStdThreadBufferDimAccessOrder =
typename conditional<MeanInvStdSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
using BlockwiseSumReduce = PartitionedBlockwiseReduction<ComputeDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
reduce::Add,
true>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k,
const GridDesc_M_K& x_grid_desc_m_k,
const GridDesc_M_K& mean_grid_desc_m_k,
const GridDesc_M_K& inv_std_grid_desc_m_k,
const GridDesc_M& dgamma_grid_desc_m,
const GridDesc_M& dbeta_grid_desc_m,
index_t num_k_block_tile_iteration,
const DYDataType* const __restrict__ p_dy_global,
const XDataType* const __restrict__ p_x_global,
const MeanInvStdDataType* const __restrict__ p_mean_global,
const MeanInvStdDataType* const __restrict__ p_inv_std_global,
DGammaDataType* const __restrict__ p_dgamma_global,
DBetaDataType* const __restrict__ p_dbeta_global)
{
// LDS
__shared__ ComputeDataType p_reduce_work_buffer[BlockSize];
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
// Global
const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize());
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize());
const auto inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize());
auto dgamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dgamma_global, dgamma_grid_desc_m.GetElementSpaceSize());
auto dbeta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbeta_global, dbeta_grid_desc_m.GetElementSpaceSize());
// VGPR
auto dy_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto x_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto mean_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto inv_std_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto dgamma_thread_buf =
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};
auto dbeta_thread_buf =
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
// IO
auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DYDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
DYThreadBufferDimAccessOrder,
DYSrcVectorDim,
DYSrcVectorSize,
1,
true>(
dy_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
XThreadBufferDimAccessOrder,
XSrcVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_mean_load =
ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
MeanInvStdThreadBufferDimAccessOrder,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
1,
true>(
mean_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_inv_std_load =
ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
MeanInvStdThreadBufferDimAccessOrder,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
1,
true>(
inv_std_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dgamma_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
DGammaDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
DGammaDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
dgamma_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dbeta_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
DBetaDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
DBetaDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
dbeta_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
dgamma_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
dbeta_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_mean_load.Run(mean_grid_desc_m_k,
mean_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
mean_thread_buf);
threadwise_inv_std_load.Run(inv_std_grid_desc_m_k,
inv_std_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
inv_std_thread_buf);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k,
thread_copy_fwd_step_m_k);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
constexpr auto offset_m =
Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
dgamma_thread_buf(offset_m) +=
dy_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
(x_thread_buf[offset_m_k] - mean_thread_buf[offset_m_k]);
dbeta_thread_buf(offset_m) += dy_thread_buf[offset_m_k];
});
});
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, dbeta_thread_buf(I));
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, dgamma_thread_buf(I));
});
if(thread_k_cluster_id == 0)
{
threadwise_dgamma_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dgamma_thread_buf,
dgamma_grid_desc_m,
dgamma_global_val_buf);
threadwise_dbeta_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dbeta_thread_buf,
dbeta_grid_desc_m,
dbeta_global_val_buf);
}
}
};
} // namespace ck
...@@ -944,4 +944,51 @@ amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thr ...@@ -944,4 +944,51 @@ amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thr
#endif #endif
} }
// Direct loads from global to LDS.
__device__ void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <typename T, index_t NumElemsPerThread>
__device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t global_offset,
T* lds_base_ptr,
const index_t lds_offset,
const bool is_valid,
const index_t src_element_space_size)
{
// Direct loads require that each thread reads and writes exactly a single DWORD.
constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
static_assert(bytes_per_thread == dword_bytes);
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
T* lds_ptr = lds_base_ptr + lds_offset;
auto const lds_ptr_sgpr =
__builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
asm volatile("s_mov_b32 m0, %0; \n\t"
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
"v"(global_offset_bytes),
"s"(src_resource));
#else
// LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
#endif
}
} // namespace ck } // namespace ck
...@@ -173,6 +173,26 @@ struct DynamicBuffer ...@@ -173,6 +173,26 @@ struct DynamicBuffer
} }
} }
template <typename DstBuffer, index_t NumElemsPerThread>
__host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf,
index_t src_offset,
index_t dst_offset,
bool is_valid_element) const
{
// Copy data from global to LDS memory using direct loads.
static_assert(GetAddressSpace() == AddressSpaceEnum::Global,
"Source data must come from a global memory buffer.");
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"Destination data must be stored in an LDS memory buffer.");
amd_direct_load_global_to_lds<T, NumElemsPerThread>(p_data_,
src_offset,
dst_buf.p_data_,
dst_offset,
is_valid_element,
element_space_size_);
}
template <typename X, template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type, typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value, typename scalar_type<remove_cvref_t<T>>::type>::value,
......
...@@ -19,6 +19,15 @@ __device__ void block_sync_lds() ...@@ -19,6 +19,15 @@ __device__ void block_sync_lds()
#endif #endif
} }
__device__ void block_sync_lds_direct_load()
{
asm volatile("\
s_waitcnt vmcnt(0) \n \
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
}
__device__ void s_nop() __device__ void s_nop()
{ {
#if 1 #if 1
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "functional4.hpp" #include "functional4.hpp"
#include "tuple.hpp" #include "tuple.hpp"
#include "is_detected.hpp"
namespace ck { namespace ck {
...@@ -33,6 +34,28 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& ...@@ -33,6 +34,28 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
ty); ty);
} }
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
return unpack2(
[&](auto... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
// Support any number of tuples to concat (also 1)
template <typename... X>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx)
{
return tx;
}
template <typename... X, typename... Tuples>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuples&... tuples)
{
return concat_tuple(tx, concat_tuple(tuples...));
}
namespace detail { namespace detail {
template <typename F, typename X, index_t... Is> template <typename F, typename X, index_t... Is>
...@@ -78,4 +101,69 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, ...@@ -78,4 +101,69 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
} }
// By default unroll to the flatten
template <index_t Depth = 0, index_t MaxDepth = -1>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element)
{
return element;
}
template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
__host__ __device__ constexpr auto UnrollNestedTuple(const T& element)
{
return make_tuple(element);
}
template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
{
if constexpr(Depth == MaxDepth)
{
return tuple;
}
else
{
return unpack(
[&](auto&&... ts) {
return concat_tuple(UnrollNestedTuple<Depth + 1, MaxDepth>(ts)...);
},
tuple);
}
}
template <typename... Ts>
__host__ __device__ constexpr auto TupleReverse(const Tuple<Ts...>& tuple)
{
return generate_tuple(
[&](auto i) {
using Idx = Number<Tuple<Ts...>::Size() - i - 1>;
return tuple.At(Idx{});
},
Number<Tuple<Ts...>::Size()>{});
}
// Reduce tuple values in specific range using Function
template <index_t Idx, index_t End, typename F, typename... Ts>
__host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
{
static_assert(Idx < End, "Wrong parameters for TupleReduce");
if constexpr(Idx + 1 == End)
{
return tuple.At(Number<Idx>{});
}
else
{
return f(tuple.At(Number<Idx>{}), TupleReduce<Idx + 1, End>(f, tuple));
}
}
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename... Ts>
__host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
{
return (is_detected<is_tuple, Ts>::value || ...);
}
} // namespace ck } // namespace ck
...@@ -95,11 +95,19 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_ ...@@ -95,11 +95,19 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32); return type_convert<bhalf_t>(x_fp32);
} }
// convert fp32 to fp8 // Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding
template <> template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x) inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{ {
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union union
{ {
float fval; float fval;
...@@ -108,70 +116,139 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x) ...@@ -108,70 +116,139 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
} val; } val;
val.fval = x; val.fval = x;
uint32_t ival = 0; uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0 ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival; val.i32val = ival;
return val.i8val[0]; return val.i8val[0]; // little endian
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr uint32_t rng = 0;
return utils:: return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng); rng);
#endif #endif
} }
// convert fp8 to fp32 // convert fp16 to fp8 with stochastic rounding
template <> template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x) inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval; // convert to float and use native converion
uint32_t i32val = static_cast<uint32_t>(x); return f8_convert_sr<f8_t>(type_convert<float>(x));
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x); constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif #endif
} }
// convert fp16 to fp8 // convert fp32 to bf8 with stochastic rounding
template <> template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x) inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp16 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion // convert to float and use native converion
return type_convert<f8_t>(type_convert<float>(x)); return f8_convert_sr<bf8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_rne(X x);
// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils:: return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
x, rng); rng);
#endif #endif
} }
// convert fp8 to fp16 // convert fp16 to fp8 with rounding to nearest even
template <> template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x) inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16 // convert to float and use native converion
return type_convert<half_t>(type_convert<float>(x)); return f8_convert_rne<f8_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x); constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif #endif
} }
// convert fp32 to bf8 // convert fp32 to bf8 with rounding to nearest even
template <> template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x) inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union union
...@@ -196,6 +273,116 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x) ...@@ -196,6 +273,116 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
#endif #endif
} }
// convert fp16 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_rne<bf8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
#else
return f8_convert_rne<f8_t>(x);
#endif
}
// convert fp8 to fp32
template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
#endif
}
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
const auto i16val = bit_cast<uint16_t>(x);
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else
constexpr bool negative_zero_nan = true;
const auto f8x2_v = vector_type<f8_t, 2>(x);
vector_type<float, 2> f32x2_v;
f32x2_v.template AsType<float>()(Number<0>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<0>{}]);
f32x2_v.template AsType<float>()(Number<1>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<1>{}]);
return f32x2_v.template AsType<float2_t>()[Number<0>{}];
#endif
}
template <>
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
{
const vector_type<float, 2> f32x2_v(x);
const auto y = __builtin_amdgcn_cvt_pkrtz(f32x2_v.template AsType<float>()[Number<0>{}],
f32x2_v.template AsType<float>()[Number<1>{}]);
return bit_cast<half2_t>(y);
}
// convert fp16 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
#else
return f8_convert_rne<f8_t>(x);
#endif
}
// convert fp8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
#endif
}
// convert fp32 to bf8
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x);
#else
return f8_convert_rne<bf8_t>(x);
#endif
}
// convert bf8 to fp32 // convert bf8 to fp32
template <> template <>
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x) inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
...@@ -216,17 +403,10 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x) ...@@ -216,17 +403,10 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
template <> template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x) inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if CK_USE_SR_F8_CONVERSION
// convert to float and use native converion return f8_convert_sr<bf8_t>(x);
return type_convert<bf8_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; return f8_convert_rne<bf8_t>(x);
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif #endif
} }
...@@ -299,104 +479,4 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h ...@@ -299,104 +479,4 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
return bf16_convert_rtn<bhalf_t>(x_fp32); return bf16_convert_rtn<bhalf_t>(x_fp32);
} }
// Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
#endif
}
// convert fp16 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp16 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
} // namespace ck } // namespace ck
...@@ -19,9 +19,7 @@ namespace host { ...@@ -19,9 +19,7 @@ namespace host {
* \brief Reference implementation for column to image. * \brief Reference implementation for column to image.
* *
* Input tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout. * Input tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
* Memory layout is the same.
* Output tensor descriptor has [G, N, C, Di, Hi, Wi] data layout. * Output tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C].
* *
* \tparam NDimSpatial Number of spatial dimensions. * \tparam NDimSpatial Number of spatial dimensions.
* \tparam ImageLayout Image Layout. * \tparam ImageLayout Image Layout.
...@@ -95,18 +93,19 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -95,18 +93,19 @@ struct ReferenceColumnToImage : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
if(!(arg.output_.GetNumOfDimension() == NDimSpatial + 3 && if(!(arg.output_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.input_.GetNumOfDimension() == 2)) arg.input_.GetNumOfDimension() == 3))
{ {
throw std::runtime_error("wrong! inconsistent dimension"); throw std::runtime_error("wrong! inconsistent dimension");
} }
const index_t G = arg.output_.GetLengths()[0];
const index_t N = arg.output_.GetLengths()[1]; const index_t N = arg.output_.GetLengths()[1];
const index_t C = arg.output_.GetLengths()[2]; const index_t C = arg.output_.GetLengths()[2];
if constexpr(NDimSpatial == 1) if constexpr(NDimSpatial == 1)
{ {
const index_t Wo = arg.output_spatial_lengths_[0]; const index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto n) { auto func = [&](auto g, auto n) {
for(index_t wo = 0; wo < Wo; ++wo) for(index_t wo = 0; wo < Wo; ++wo)
{ {
index_t row = n * Wo + wo; index_t row = n * Wo + wo;
...@@ -123,9 +122,10 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -123,9 +122,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3]) ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
{ {
float v_in = ck::type_convert<float>(arg.input_(row, column)); float v_in =
float v_out = ck::type_convert<float>(arg.output_(0, n, c, wi)); ck::type_convert<float>(arg.input_(g, row, column));
arg.output_(0, n, c, wi) = float v_out = ck::type_convert<float>(arg.output_(g, n, c, wi));
arg.output_(g, n, c, wi) =
ck::type_convert<OutDataType>(v_in + v_out); ck::type_convert<OutDataType>(v_in + v_out);
} }
column++; column++;
...@@ -134,7 +134,7 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -134,7 +134,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
} }
}; };
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
return 0; return 0;
} }
...@@ -143,7 +143,7 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -143,7 +143,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
const index_t Ho = arg.output_spatial_lengths_[0]; const index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1]; const index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto n) { auto func = [&](auto g, auto n) {
for(index_t ho = 0; ho < Ho; ++ho) for(index_t ho = 0; ho < Ho; ++ho)
{ {
for(index_t wo = 0; wo < Wo; ++wo) for(index_t wo = 0; wo < Wo; ++wo)
...@@ -176,10 +176,10 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -176,10 +176,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
arg.output_.GetLengths()[4]) arg.output_.GetLengths()[4])
{ {
float v_in = float v_in =
ck::type_convert<float>(arg.input_(row, column)); ck::type_convert<float>(arg.input_(g, row, column));
float v_out = ck::type_convert<float>( float v_out = ck::type_convert<float>(
arg.output_(0, n, c, hi, wi)); arg.output_(g, n, c, hi, wi));
arg.output_(0, n, c, hi, wi) = arg.output_(g, n, c, hi, wi) =
ck::type_convert<OutDataType>(v_in + v_out); ck::type_convert<OutDataType>(v_in + v_out);
} }
column++; column++;
...@@ -190,7 +190,7 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -190,7 +190,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
} }
}; };
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
return 0; return 0;
} }
...@@ -200,7 +200,7 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -200,7 +200,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
const index_t Ho = arg.output_spatial_lengths_[1]; const index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2]; const index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto n) { auto func = [&](auto g, auto n) {
for(index_t d_o = 0; d_o < Do; ++d_o) for(index_t d_o = 0; d_o < Do; ++d_o)
{ {
for(index_t ho = 0; ho < Ho; ++ho) for(index_t ho = 0; ho < Ho; ++ho)
...@@ -245,10 +245,10 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -245,10 +245,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
arg.output_.GetLengths()[5]) arg.output_.GetLengths()[5])
{ {
float v_in = ck::type_convert<float>( float v_in = ck::type_convert<float>(
arg.input_(row, column)); arg.input_(g, row, column));
float v_out = ck::type_convert<float>( float v_out = ck::type_convert<float>(
arg.output_(0, n, c, di, hi, wi)); arg.output_(g, n, c, di, hi, wi));
arg.output_(0, n, c, di, hi, wi) = arg.output_(g, n, c, di, hi, wi) =
ck::type_convert<OutDataType>(v_in + v_out); ck::type_convert<OutDataType>(v_in + v_out);
} }
column++; column++;
...@@ -261,7 +261,7 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -261,7 +261,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
} }
}; };
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
return 0; return 0;
} }
...@@ -303,8 +303,9 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -303,8 +303,9 @@ struct ReferenceColumnToImage : public device::BaseOperator
C * ck::accumulate_n<index_t>( C * ck::accumulate_n<index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(NDoHoWo) && if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) &&
arg.input_.GetLengths()[1] == static_cast<std::size_t>(CZYX))) arg.input_.GetLengths()[1] == static_cast<std::size_t>(NDoHoWo) &&
arg.input_.GetLengths()[2] == static_cast<std::size_t>(CZYX)))
{ {
return false; return false;
} }
......
...@@ -23,6 +23,7 @@ template <ck::index_t NumDimM, ...@@ -23,6 +23,7 @@ template <ck::index_t NumDimM,
typename BDataType, typename BDataType,
typename CDataType, typename CDataType,
typename AccDataType, typename AccDataType,
typename ComputeDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false> ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
...@@ -69,19 +70,24 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -69,19 +70,24 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
{ {
for(ck::index_t k1 = 0; k1 < K1; ++k1) for(ck::index_t k1 = 0; k1 < K1; ++k1)
{ {
// Simulate the possible casting when ComputeDataType is different than the
// A/B data types
ComputeDataType v_a_compute_input =
ck::type_convert<ComputeDataType>(arg.a_ms_ks_(m0, m1, k0, k1));
ComputeDataType v_b_compute_input =
ck::type_convert<ComputeDataType>(arg.b_ns_ks_(n0, n1, k0, k1));
AccDataType v_a; AccDataType v_a;
AccDataType v_b; AccDataType v_b;
arg.a_element_op_( arg.a_element_op_(v_a, ck::type_convert<AccDataType>(v_a_compute_input));
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1))); arg.b_element_op_(v_b, ck::type_convert<AccDataType>(v_b_compute_input));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b; v_acc += v_a * v_b;
} }
} }
arg.c_ms_ns_(m0, m1, n0, n1) = v_acc; arg.c_ms_ns_(m0, m1, n0, n1) = ck::type_convert<CDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_ms_ns, make_ParallelTensorFunctor(f_ms_ns,
......
...@@ -3,12 +3,23 @@ ...@@ -3,12 +3,23 @@
#pragma once #pragma once
#include <iostream> #include <cmath>
#include <cstdlib>
#include <numeric>
#include <type_traits> #include <type_traits>
#include <sstream> #include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -22,6 +33,7 @@ namespace host { ...@@ -22,6 +33,7 @@ namespace host {
// Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout // Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
// as long as dimensions in tensor descriptor is in GNCHW order // as long as dimensions in tensor descriptor is in GNCHW order
// //
// @tparam NDimSpatial Number of spatial dimensions.
// @tparam InDataType Input tensor data type. // @tparam InDataType Input tensor data type.
// @tparam WeiDataType Weights tensor data type. // @tparam WeiDataType Weights tensor data type.
// @tparam OutDataType Output tensor data type. // @tparam OutDataType Output tensor data type.
...@@ -29,7 +41,9 @@ namespace host { ...@@ -29,7 +41,9 @@ namespace host {
// operation. // operation.
// @tparam WeiElementwiseOperation Functor for weights tensor elementwise // @tparam WeiElementwiseOperation Functor for weights tensor elementwise
// operation. // operation.
// @tparam NDimSpatial Number of spatial dimensions. // @tparam NumAElementwiseTensor Number of A elementwise tensors.
// @tparam NumBElementwiseTensor Number of B elementwise tensors.
// @tparam NumDElementwiseTensor Number of D elementwise tensors.
// //
// input descriptor in [G, N, C, Do, Ho, Wo] order // input descriptor in [G, N, C, Do, Ho, Wo] order
// weight descriptor in [G, K, C, Z, Y, X] order // weight descriptor in [G, K, C, Z, Y, X] order
...@@ -42,25 +56,35 @@ template <ck::index_t NDimSpatial, ...@@ -42,25 +56,35 @@ template <ck::index_t NDimSpatial,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ck::index_t NumAElementwiseTensor = 0,
ck::index_t NumBElementwiseTensor = 0,
ck::index_t NumDElementwiseTensor = 0,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false> typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvFwd : public device::BaseOperator struct ReferenceConvFwd : public device::BaseOperator
{ {
// Argument // Argument
struct Argument : public device::BaseArgument struct Argument : public device::BaseArgument
{ {
Argument(const Tensor<InDataType>& input, Argument(
const Tensor<WeiDataType>& weight, const Tensor<InDataType>& input,
Tensor<OutDataType>& output, const Tensor<WeiDataType>& weight,
std::vector<ck::index_t> conv_filter_strides, Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_left_pads,
InElementwiseOperation in_element_op, std::vector<ck::index_t> input_right_pads,
WeiElementwiseOperation wei_element_op, InElementwiseOperation in_element_op,
OutElementwiseOperation out_element_op) WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const std::array<Tensor<InDataType>, NumAElementwiseTensor>& elementwise_a_tensors,
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors,
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors)
: input_{input}, : input_{input},
weight_{weight}, weight_{weight},
output_{output}, output_{output},
elementwise_a_tensors_{elementwise_a_tensors},
elementwise_b_tensors_{elementwise_b_tensors},
elementwise_d_tensors_{elementwise_d_tensors},
conv_strides_{conv_filter_strides}, conv_strides_{conv_filter_strides},
conv_dilations_{conv_filter_dilations}, conv_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads}, in_left_pads_{input_left_pads},
...@@ -75,6 +99,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -75,6 +99,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const Tensor<WeiDataType>& weight_; const Tensor<WeiDataType>& weight_;
Tensor<OutDataType>& output_; Tensor<OutDataType>& output_;
const std::array<Tensor<InDataType>, NumAElementwiseTensor>& elementwise_a_tensors_;
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_; std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_; std::vector<index_t> in_left_pads_;
...@@ -114,23 +142,43 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -114,23 +142,43 @@ struct ReferenceConvFwd : public device::BaseOperator
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{ {
float v_in; InDataType v_in;
float v_wei; WeiDataType v_wei;
arg.in_element_op_( ExecuteElementwiseOp(arg.in_element_op_,
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi))); arg.elementwise_a_tensors_,
Number<NumAElementwiseTensor>{},
arg.wei_element_op_( v_in,
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x))); arg.input_(g, n, c, wi),
g,
v_acc += v_in * v_wei; n,
c,
wi);
ExecuteElementwiseOp(arg.wei_element_op_,
arg.elementwise_b_tensors_,
Number<NumBElementwiseTensor>{},
v_wei,
arg.weight_(g, k, c, x),
g,
k,
c,
x);
v_acc +=
ck::type_convert<float>(v_in) * ck::type_convert<float>(v_wei);
} }
} }
} }
OutDataType v_acc_converted = ck::type_convert<OutDataType>(v_acc);
OutDataType v_out; OutDataType& v_out = arg.output_(g, n, k, wo);
arg.out_element_op_(v_out, ck::type_convert<OutDataType>(v_acc)); ExecuteElementwiseOp(arg.out_element_op_,
arg.output_(g, n, k, wo) = v_out; arg.elementwise_d_tensors_,
Number<NumDElementwiseTensor>{},
v_out,
v_acc_converted,
g,
n,
k,
wo);
}; };
make_ParallelTensorFunctor(func, make_ParallelTensorFunctor(func,
...@@ -167,24 +215,47 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -167,24 +215,47 @@ struct ReferenceConvFwd : public device::BaseOperator
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{ {
float v_in; InDataType v_in;
float v_wei; WeiDataType v_wei;
arg.in_element_op_( ExecuteElementwiseOp(arg.in_element_op_,
v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi))); arg.elementwise_a_tensors_,
Number<NumAElementwiseTensor>{},
arg.wei_element_op_( v_in,
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, y, x))); arg.input_(g, n, c, hi, wi),
g,
v_acc += v_in * v_wei; n,
c,
hi,
wi);
ExecuteElementwiseOp(arg.wei_element_op_,
arg.elementwise_b_tensors_,
Number<NumBElementwiseTensor>{},
v_wei,
arg.weight_(g, k, c, y, x),
g,
k,
c,
y,
x);
v_acc += ck::type_convert<float>(v_in) *
ck::type_convert<float>(v_wei);
} }
} }
} }
} }
OutDataType v_acc_converted = ck::type_convert<OutDataType>(v_acc);
OutDataType v_out; OutDataType& v_out = arg.output_(g, n, k, ho, wo);
arg.out_element_op_(v_out, ck::type_convert<OutDataType>(v_acc)); ExecuteElementwiseOp(arg.out_element_op_,
arg.output_(g, n, k, ho, wo) = v_out; arg.elementwise_d_tensors_,
Number<NumDElementwiseTensor>{},
v_out,
v_acc_converted,
g,
n,
k,
ho,
wo);
}; };
make_ParallelTensorFunctor(func, make_ParallelTensorFunctor(func,
...@@ -231,27 +302,51 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -231,27 +302,51 @@ struct ReferenceConvFwd : public device::BaseOperator
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) <
arg.input_.GetLengths()[5]) arg.input_.GetLengths()[5])
{ {
float v_in; InDataType v_in;
float v_wei; WeiDataType v_wei;
arg.in_element_op_(v_in, ExecuteElementwiseOp(arg.in_element_op_,
ck::type_convert<float>( arg.elementwise_a_tensors_,
arg.input_(g, n, c, di, hi, wi))); Number<NumAElementwiseTensor>{},
v_in,
arg.wei_element_op_( arg.input_(g, n, c, di, hi, wi),
v_wei, g,
ck::type_convert<float>(arg.weight_(g, k, c, z, y, x))); n,
c,
v_acc += v_in * v_wei; di,
hi,
wi);
ExecuteElementwiseOp(arg.wei_element_op_,
arg.elementwise_b_tensors_,
Number<NumBElementwiseTensor>{},
v_wei,
arg.weight_(g, k, c, z, y, x),
g,
k,
c,
z,
y,
x);
v_acc += ck::type_convert<float>(v_in) *
ck::type_convert<float>(v_wei);
} }
} }
} }
} }
} }
OutDataType v_acc_converted = ck::type_convert<OutDataType>(v_acc);
OutDataType v_out; OutDataType& v_out = arg.output_(g, n, k, d_o, ho, wo);
arg.out_element_op_(v_out, ck::type_convert<OutDataType>(v_acc)); ExecuteElementwiseOp(arg.out_element_op_,
arg.output_(g, n, k, d_o, ho, wo) = v_out; arg.elementwise_d_tensors_,
Number<NumDElementwiseTensor>{},
v_out,
v_acc_converted,
g,
n,
k,
d_o,
ho,
wo);
}; };
make_ParallelTensorFunctor(func, make_ParallelTensorFunctor(func,
...@@ -274,6 +369,36 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -274,6 +369,36 @@ struct ReferenceConvFwd : public device::BaseOperator
} }
}; };
template <typename... Args,
typename ElementwiseOp,
typename ElementwiseTensor,
typename NumTensor,
typename T>
static void ExecuteElementwiseOp(ElementwiseOp& elementwise_op,
ElementwiseTensor& elementwise_tensors,
NumTensor,
T& y,
const T& x,
Args... dims)
{
if constexpr(NumTensor::value == 0)
{
elementwise_op(y, x);
}
else if constexpr(NumTensor::value == 1)
{
elementwise_op(y, x, elementwise_tensors[0](dims...));
}
else if constexpr(NumTensor::value == 2)
{
elementwise_op(y, x, elementwise_tensors[0](dims...), elementwise_tensors[1](dims...));
}
else
{
throw std::runtime_error("ElementOp not supported in reference.");
}
}
static constexpr bool IsValidCompilationParameter() static constexpr bool IsValidCompilationParameter()
{ {
// TODO: properly implement this check // TODO: properly implement this check
...@@ -285,16 +410,20 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -285,16 +410,20 @@ struct ReferenceConvFwd : public device::BaseOperator
return NDimSpatial >= 1 && NDimSpatial <= 3; return NDimSpatial >= 1 && NDimSpatial <= 3;
} }
static auto MakeArgument(const Tensor<InDataType>& input, static auto MakeArgument(
const Tensor<WeiDataType>& weight, const Tensor<InDataType>& input,
Tensor<OutDataType>& output, const Tensor<WeiDataType>& weight,
std::vector<ck::index_t> conv_filter_strides, Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_left_pads,
InElementwiseOperation in_element_op, std::vector<ck::index_t> input_right_pads,
WeiElementwiseOperation wei_element_op, InElementwiseOperation in_element_op,
OutElementwiseOperation out_element_op) WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const std::array<Tensor<InDataType>, NumAElementwiseTensor>& elementwise_a_tensors = {},
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors = {},
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors = {})
{ {
return Argument{input, return Argument{input,
weight, weight,
...@@ -305,7 +434,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -305,7 +434,10 @@ struct ReferenceConvFwd : public device::BaseOperator
input_right_pads, input_right_pads,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op}; out_element_op,
elementwise_a_tensors,
elementwise_b_tensors,
elementwise_d_tensors};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DGammaDataType,
typename DBetaDataType,
typename DXDataType,
typename ComputeDataType>
struct ReferenceGroupnormBwd : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<DYDataType>& dy_nhwgc,
const Tensor<XDataType>& x_nhwgc,
const Tensor<GammaDataType>& gamma_gc,
const Tensor<MeanInvStdDataType>& mean_ng,
const Tensor<MeanInvStdDataType>& inv_std_ng,
Tensor<DGammaDataType>& dgamma_gc,
Tensor<DBetaDataType>& dbeta_gc,
Tensor<DXDataType>& dx_nhwgc,
const std::vector<index_t> lengths)
: dy_nhwgc_(dy_nhwgc),
x_nhwgc_(x_nhwgc),
gamma_gc_(gamma_gc),
mean_ng_(mean_ng),
inv_std_ng_(inv_std_ng),
dgamma_gc_(dgamma_gc),
dbeta_gc_(dbeta_gc),
dx_nhwgc_(dx_nhwgc),
lengths_(lengths)
{
}
const Tensor<DYDataType>& dy_nhwgc_;
const Tensor<XDataType>& x_nhwgc_;
const Tensor<GammaDataType>& gamma_gc_;
const Tensor<MeanInvStdDataType>& mean_ng_;
const Tensor<MeanInvStdDataType>& inv_std_ng_;
Tensor<DGammaDataType>& dgamma_gc_;
Tensor<DBetaDataType>& dbeta_gc_;
Tensor<DXDataType>& dx_nhwgc_;
std::vector<index_t> lengths_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
int N = arg.lengths_[0];
int H = arg.lengths_[1];
int W = arg.lengths_[2];
int G = arg.lengths_[3];
int C = arg.lengths_[4];
// Calculate dgamma and dbeta
for(int g = 0; g < G; ++g)
for(int c = 0; c < C; ++c)
{
ComputeDataType dgamma = 0;
ComputeDataType dbeta = 0;
for(int n = 0; n < N; ++n)
for(int h = 0; h < H; ++h)
for(int w = 0; w < W; ++w)
{
ComputeDataType dy =
ck::type_convert<ComputeDataType>(arg.dy_nhwgc_(n, h, w, g, c));
ComputeDataType x =
ck::type_convert<ComputeDataType>(arg.x_nhwgc_(n, h, w, g, c));
ComputeDataType mean =
ck::type_convert<ComputeDataType>(arg.mean_ng_(n, g));
ComputeDataType rstd =
ck::type_convert<ComputeDataType>(arg.inv_std_ng_(n, g));
dgamma += dy * rstd * (x - mean);
dbeta += dy;
}
arg.dgamma_gc_(g, c) = ck::type_convert<DGammaDataType>(dgamma);
arg.dbeta_gc_(g, c) = ck::type_convert<DBetaDataType>(dbeta);
}
// Calculate dx
int reduce_size = H * W * C;
for(int n = 0; n < N; ++n)
for(int g = 0; g < G; ++g)
{
ComputeDataType ds = 0;
ComputeDataType db = 0;
ComputeDataType mean = ck::type_convert<ComputeDataType>(arg.mean_ng_(n, g));
ComputeDataType rstd = ck::type_convert<ComputeDataType>(arg.inv_std_ng_(n, g));
for(int h = 0; h < H; ++h)
for(int w = 0; w < W; ++w)
for(int c = 0; c < C; ++c)
{
ComputeDataType dy =
ck::type_convert<ComputeDataType>(arg.dy_nhwgc_(n, h, w, g, c));
ComputeDataType x =
ck::type_convert<ComputeDataType>(arg.x_nhwgc_(n, h, w, g, c));
ComputeDataType gamma =
ck::type_convert<ComputeDataType>(arg.gamma_gc_(g, c));
ds += dy * gamma * x;
db += dy * gamma;
}
for(int h = 0; h < H; ++h)
for(int w = 0; w < W; ++w)
for(int c = 0; c < C; ++c)
{
ComputeDataType dy =
ck::type_convert<ComputeDataType>(arg.dy_nhwgc_(n, h, w, g, c));
ComputeDataType x =
ck::type_convert<ComputeDataType>(arg.x_nhwgc_(n, h, w, g, c));
ComputeDataType gamma =
ck::type_convert<ComputeDataType>(arg.gamma_gc_(g, c));
ComputeDataType b =
(db * mean - ds) * rstd * rstd * rstd / reduce_size;
ComputeDataType c1 = -b * mean - db * rstd / reduce_size;
arg.dx_nhwgc_(n, h, w, g, c) =
ck::type_convert<DXDataType>(dy * gamma * rstd + b * x + c1);
}
}
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<DYDataType>& dy_nhwgc,
const Tensor<XDataType>& x_nhwgc,
const Tensor<GammaDataType>& gamma_gc,
const Tensor<MeanInvStdDataType>& mean_ng,
const Tensor<MeanInvStdDataType>& inv_std_ng,
Tensor<DGammaDataType>& dgamma_gc,
Tensor<DBetaDataType>& dbeta_gc,
Tensor<DXDataType>& dx_nhwgc,
const std::vector<index_t> lengths)
{
return Argument{dy_nhwgc,
x_nhwgc,
gamma_gc,
mean_ng,
inv_std_ng,
dgamma_gc,
dbeta_gc,
dx_nhwgc,
lengths};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceGroupnormBwd"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
...@@ -19,9 +19,7 @@ namespace host { ...@@ -19,9 +19,7 @@ namespace host {
* \brief Reference implementation for image to column. * \brief Reference implementation for image to column.
* *
* Input tensor descriptor has [G, N, C, Di, Hi, Wi] data layout. * Input tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C]. * Output tensor descriptor has [G * N * Do * Ho * Wo, Z * Y * X * C] data layout.
* Output tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
* Memory layout is the same.
* *
* \tparam NDimSpatial Number of spatial dimensions. * \tparam NDimSpatial Number of spatial dimensions.
* \tparam ImageLayout Image Layout. * \tparam ImageLayout Image Layout.
...@@ -95,18 +93,19 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -95,18 +93,19 @@ struct ReferenceImageToColumn : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 && if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == 2)) arg.output_.GetNumOfDimension() == 3))
{ {
throw std::runtime_error("wrong! inconsistent dimension"); throw std::runtime_error("wrong! inconsistent dimension");
} }
const index_t G = arg.input_.GetLengths()[0];
const index_t N = arg.input_.GetLengths()[1]; const index_t N = arg.input_.GetLengths()[1];
const index_t C = arg.input_.GetLengths()[2]; const index_t C = arg.input_.GetLengths()[2];
if constexpr(NDimSpatial == 1) if constexpr(NDimSpatial == 1)
{ {
const index_t Wo = arg.output_spatial_lengths_[0]; const index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto n, auto wo) { auto func = [&](auto g, auto n, auto wo) {
index_t row = n * Wo + wo; index_t row = n * Wo + wo;
index_t column = 0; index_t column = 0;
...@@ -121,15 +120,15 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -121,15 +120,15 @@ struct ReferenceImageToColumn : public device::BaseOperator
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{ {
InDataType v_in = arg.input_(0, n, c, wi); InDataType v_in = arg.input_(g, n, c, wi);
arg.output_(row, column) = ck::type_convert<OutDataType>(v_in); arg.output_(g, row, column) = ck::type_convert<OutDataType>(v_in);
} }
column++; column++;
} }
} }
}; };
make_ParallelTensorFunctor(func, N, Wo)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency());
return 0; return 0;
} }
...@@ -138,7 +137,7 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -138,7 +137,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
const index_t Ho = arg.output_spatial_lengths_[0]; const index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1]; const index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto n, auto ho, auto wo) { auto func = [&](auto g, auto n, auto ho, auto wo) {
index_t row = n * Ho * Wo + ho * Wo + wo; index_t row = n * Ho * Wo + ho * Wo + wo;
index_t column = 0; index_t column = 0;
...@@ -162,8 +161,9 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -162,8 +161,9 @@ struct ReferenceImageToColumn : public device::BaseOperator
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{ {
InDataType v_in = arg.input_(0, n, c, hi, wi); InDataType v_in = arg.input_(g, n, c, hi, wi);
arg.output_(row, column) = ck::type_convert<OutDataType>(v_in); arg.output_(g, row, column) =
ck::type_convert<OutDataType>(v_in);
} }
column++; column++;
} }
...@@ -171,7 +171,7 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -171,7 +171,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
} }
}; };
make_ParallelTensorFunctor(func, N, Ho, Wo)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency());
return 0; return 0;
} }
...@@ -181,7 +181,7 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -181,7 +181,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
const index_t Ho = arg.output_spatial_lengths_[1]; const index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2]; const index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto n, auto d_o, auto ho, auto wo) { auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo; index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
index_t column = 0; index_t column = 0;
...@@ -213,8 +213,8 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -213,8 +213,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) <
arg.input_.GetLengths()[5]) arg.input_.GetLengths()[5])
{ {
InDataType v_in = arg.input_(0, n, c, di, hi, wi); InDataType v_in = arg.input_(g, n, c, di, hi, wi);
arg.output_(row, column) = arg.output_(g, row, column) =
ck::type_convert<OutDataType>(v_in); ck::type_convert<OutDataType>(v_in);
} }
column++; column++;
...@@ -224,7 +224,7 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -224,7 +224,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
} }
}; };
make_ParallelTensorFunctor(func, N, Do, Ho, Wo)( make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -267,8 +267,9 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -267,8 +267,9 @@ struct ReferenceImageToColumn : public device::BaseOperator
C * ck::accumulate_n<index_t>( C * ck::accumulate_n<index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(NDoHoWo) && if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(G) &&
arg.output_.GetLengths()[1] == static_cast<std::size_t>(CZYX))) arg.output_.GetLengths()[1] == static_cast<std::size_t>(NDoHoWo) &&
arg.output_.GetLengths()[2] == static_cast<std::size_t>(CZYX)))
{ {
return false; return false;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment