Commit c6891e12 authored by rocking

Merge branch 'develop' into standalone-layernorm

parents f591ad27 8e374781
@@ -64,8 +64,16 @@ template <
         is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
         is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
     bool> = false>
-struct DeviceGemmDl
-    : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
+struct DeviceGemmDl : public DeviceGemm<ALayout,
+                                        BLayout,
+                                        CLayout,
+                                        ADataType,
+                                        BDataType,
+                                        CDataType,
+                                        AElementwiseOperation,
+                                        BElementwiseOperation,
+                                        CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
@@ -534,8 +542,7 @@ struct DeviceGemmDl
                         index_t StrideC,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        index_t /* KBatch */ = 1) override
+                        CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
...
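The hunks above (and the DeviceGemmXdl, DeviceGemm_Xdl_CShuffle and split-K hunks further down) all make the same change: the abstract DeviceGemm/DeviceGemmSplitK base now carries the layouts and data types, so a type-erased base-class pointer fully describes the GEMM problem. The sketch below is an illustrative analogue in plain C++, not the CK headers; RowMajor, PassThrough, MyGemm and the other names here are stand-ins.

// Illustrative analogue of the refactoring pattern: template the abstract base on
// the full problem description so a base-class pointer still knows layouts and types.
#include <memory>

struct RowMajor {};     // hypothetical layout tags
struct ColumnMajor {};
struct PassThrough {};  // hypothetical elementwise op

// Before: the base only knew the elementwise operations.
template <typename AOp, typename BOp, typename COp>
struct DeviceGemmOld
{
    virtual ~DeviceGemmOld() = default;
};

// After: layouts and data types are part of the base, mirroring
// DeviceGemm<ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AOp, BOp, COp>.
template <typename ALayout, typename BLayout, typename CLayout,
          typename ADataType, typename BDataType, typename CDataType,
          typename AOp, typename BOp, typename COp>
struct DeviceGemmNew
{
    virtual ~DeviceGemmNew() = default;
};

// A concrete kernel derives from the fully described base, as DeviceGemmDl does above.
template <typename ALayout, typename BLayout, typename CLayout,
          typename ADataType, typename BDataType, typename CDataType>
struct MyGemm : DeviceGemmNew<ALayout, BLayout, CLayout,
                              ADataType, BDataType, CDataType,
                              PassThrough, PassThrough, PassThrough>
{
};

int main()
{
    // The base-class pointer alone now pins down the whole GEMM problem type.
    std::unique_ptr<DeviceGemmNew<RowMajor, ColumnMajor, RowMajor,
                                  float, float, float,
                                  PassThrough, PassThrough, PassThrough>>
        gemm = std::make_unique<MyGemm<RowMajor, ColumnMajor, RowMajor, float, float, float>>();
    return 0;
}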
@@ -16,12 +16,20 @@ namespace device {
 // output : E[M, N]
 // C = a_op(A) * b_op(B)
 // E = cde_op(C, D0, D1, ...)
-template <ck::index_t NumDTensor,
+template <typename ALayout,
+          typename BLayout,
+          typename DELayout,
+          typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename EDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CDEElementwiseOperation>
 struct DeviceGemmMultipleD : public BaseOperator
 {
+    static constexpr index_t NumDTensor = DsDataType::Size();
+
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
...
@@ -41,14 +49,26 @@ struct DeviceGemmMultipleD : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

-template <ck::index_t NumDTensor,
+template <typename ALayout,
+          typename BLayout,
+          typename DELayout,
+          typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename EDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation>
-using DeviceGemmMultipleDPtr = std::unique_ptr<DeviceGemmMultipleD<NumDTensor,
+          typename CDEElementwiseOperation>
+using DeviceGemmMultipleDPtr = std::unique_ptr<DeviceGemmMultipleD<ALayout,
+                                                                   BLayout,
+                                                                   DELayout,
+                                                                   ADataType,
+                                                                   BDataType,
+                                                                   DsDataType,
+                                                                   EDataType,
                                                                    AElementwiseOperation,
                                                                    BElementwiseOperation,
-                                                                   CElementwiseOperation>>;
+                                                                   CDEElementwiseOperation>>;

 } // namespace device
 } // namespace tensor_operation
...
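A side effect of the new DeviceGemmMultipleD signature is that the D-tensor count is no longer a separate template parameter; it is derived from the type list via `static constexpr index_t NumDTensor = DsDataType::Size();`. Below is a minimal standard-library analogue of that idea, using std::tuple and std::tuple_size in place of ck::Tuple and its Size(); the names and the two-tensor example are hypothetical.

// Illustrative analogue only: derive the D-tensor count from the type list instead
// of passing it as an extra template parameter.
#include <cstddef>
#include <tuple>

template <typename DsDataType>
struct DeviceGemmMultipleDSketch
{
    // std::tuple_size stands in for CK's DsDataType::Size() here.
    static constexpr std::size_t NumDTensor = std::tuple_size<DsDataType>::value;
};

// Two auxiliary D tensors (e.g. bias and residual), both fp32 in this made-up case.
using Ds = std::tuple<float, float>;
static_assert(DeviceGemmMultipleDSketch<Ds>::NumDTensor == 2,
              "count follows directly from the type list");

int main() { return 0; }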
@@ -96,7 +96,7 @@ namespace device {
 // E = cde_op(C, D0, D1, ...)
 template <typename ALayout,
           typename BLayout,
-          typename CDELayout,
+          typename DELayout,
           typename ADataType,
           typename BDataType,
           typename GemmAccDataType,
...
@@ -137,7 +137,13 @@ template <typename ALayout,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType::Size(),
+struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
+                                                                     BLayout,
+                                                                     DELayout,
+                                                                     ADataType,
+                                                                     BDataType,
+                                                                     DsDataType,
+                                                                     EDataType,
                                                                      AElementwiseOperation,
                                                                      BElementwiseOperation,
                                                                      CDEElementwiseOperation>
...
@@ -360,12 +366,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
     static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
     {
         const auto c_grid_desc_mraw_nraw = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, CDELayout>::value)
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                     make_tuple(StrideE, I1));
             }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CDELayout>::value)
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DELayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                     make_tuple(I1, StrideE));
...
@@ -2,13 +2,16 @@
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include <iostream>

 #include "device_base.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {

+// FIXME: DeviceGemmReduce type need to well define the problem
 template <ck::index_t NumDTensor, ck::index_t NumReduce>
 struct DeviceGemmReduce : public BaseOperator
 {
...
@@ -2,6 +2,7 @@
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#include <iostream>
 #include <vector>
...
@@ -11,7 +12,13 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-template <typename AElementwiseOperation,
+template <typename ALayout,
+          typename BLayout,
+          typename CLayout,
+          typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation>
 struct DeviceGemmSplitK : public BaseOperator
...
@@ -33,11 +40,24 @@ struct DeviceGemmSplitK : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

-template <typename AElementwiseOperation,
+template <typename ALayout,
+          typename BLayout,
+          typename CLayout,
+          typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation>
-using DeviceGemmSplitKPtr = std::unique_ptr<
-    DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
+using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout,
+                                                             BLayout,
+                                                             CLayout,
+                                                             ADataType,
+                                                             BDataType,
+                                                             CDataType,
+                                                             AElementwiseOperation,
+                                                             BElementwiseOperation,
+                                                             CElementwiseOperation>>;

 } // namespace device
 } // namespace tensor_operation
...
@@ -57,8 +57,15 @@ template <typename ADataType,
           ck::index_t CThreadTransferSrcDstVectorDim,
           ck::index_t CThreadTransferDstScalarPerVector,
           ck::index_t NumPrefetch = 1>
-struct DeviceGemmXdl
-    : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
+struct DeviceGemmXdl : public DeviceGemm<ALayout,
+                                         BLayout,
+                                         CLayout,
+                                         ADataType,
+                                         BDataType,
+                                         CDataType,
+                                         AElementwiseOperation,
+                                         BElementwiseOperation,
+                                         CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
@@ -487,8 +494,7 @@ struct DeviceGemmXdl
                         index_t StrideC,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        index_t /* KBatch */ = 1) override
+                        CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
...
@@ -65,8 +65,15 @@ template <typename ALayout,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceGemm_Xdl_CShuffle
-    : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
+struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
+                                                   BLayout,
+                                                   CLayout,
+                                                   ADataType,
+                                                   BDataType,
+                                                   CDataType,
+                                                   AElementwiseOperation,
+                                                   BElementwiseOperation,
+                                                   CElementwiseOperation>
 {
     using DeviceOp = DeviceGemm_Xdl_CShuffle;
...
@@ -622,8 +629,7 @@ struct DeviceGemm_Xdl_CShuffle
                         index_t StrideC,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        index_t /* KBatch */ = 1) override
+                        CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
...
@@ -56,8 +56,15 @@ template <typename ADataType,
           bool BBlockLdsAddExtraN,
           ck::index_t CThreadTransferSrcDstVectorDim,
           ck::index_t CThreadTransferDstScalarPerVector>
-struct DeviceGemmXdlSplitK
-    : public DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
+struct DeviceGemmXdlSplitK : public DeviceGemmSplitK<ALayout,
+                                                     BLayout,
+                                                     CLayout,
+                                                     ADataType,
+                                                     BDataType,
+                                                     CDataType,
+                                                     AElementwiseOperation,
+                                                     BElementwiseOperation,
+                                                     CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
@@ -58,8 +58,15 @@ template <typename ADataType,
           index_t CShuffleNRepeatPerShuffle,
           typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CBlockTransferScalarPerVector_NWaveNPerXDL>
-struct DeviceGemmXdlSplitKCShuffle
-    : public DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
+struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
+                                                             BLayout,
+                                                             CLayout,
+                                                             ADataType,
+                                                             BDataType,
+                                                             CDataType,
+                                                             AElementwiseOperation,
+                                                             BElementwiseOperation,
+                                                             CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
@@ -46,13 +46,22 @@ __global__ void
     const auto gemm_desc_ptr =
         reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));

-    index_t group_id = 0;
-    for(index_t i = 0; i < group_count; i++)
+    index_t left     = 0;
+    index_t right    = group_count;
+    index_t group_id = index_t((left + right) / 2);
+    while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
+             block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
+          left <= right)
     {
-        group_id =
-            (block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_)
-                ? i
-                : group_id;
+        if(block_id < gemm_desc_ptr[group_id].BlockStart_)
+        {
+            right = group_id;
+        }
+        else
+        {
+            left = group_id;
+        }
+        group_id = index_t((left + right) / 2);
     }

     GridwiseGemm::template Run<HasMainKBlockLoop>(
...
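The hunk above replaces the linear scan over gemm_descs with a bisection over each group's half-open tile range [BlockStart_, BlockEnd_). Below is a host-side sketch of the same lookup; GroupRange and the test ranges are made up for illustration, and block_id is assumed to fall inside one of the ranges, as it does for a launched grid.

// Host-side illustration of mapping a flat block id to its owning group by bisection.
#include <cassert>
#include <vector>

struct GroupRange
{
    int BlockStart_; // first tile owned by this group (inclusive)
    int BlockEnd_;   // one past the last tile owned by this group
};

int find_group_id(const std::vector<GroupRange>& descs, int block_id)
{
    int left     = 0;
    int right    = static_cast<int>(descs.size());
    int group_id = (left + right) / 2;
    // Same loop shape as the kernel: narrow [left, right) until group_id's range contains block_id.
    while(!(block_id >= descs[group_id].BlockStart_ && block_id < descs[group_id].BlockEnd_) &&
          left <= right)
    {
        if(block_id < descs[group_id].BlockStart_)
            right = group_id;
        else
            left = group_id;
        group_id = (left + right) / 2;
    }
    return group_id;
}

int main()
{
    // Three groups covering tiles [0,8), [8,20), [20,26).
    std::vector<GroupRange> descs = {{0, 8}, {8, 20}, {20, 26}};
    assert(find_group_id(descs, 3) == 0);
    assert(find_group_id(descs, 8) == 1);
    assert(find_group_id(descs, 25) == 2);
    return 0;
}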
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
struct DeviceNormalization : public BaseOperator
{
// inLengths: input tensor extent(s) from high to low dimension
// inStrides: input tensor stride(s) from high to low dimension
// reduceDims: the dimension(s) the normalization operation is applied
// alpha: typeless pointer in host memory storing the alpha scaling value of type AccDataType
// beta: typeless pointer in host memory storing the beta scaling value of type AccDataType
// in_dev: typeless const pointer in device memory storing the input tensor
// out_dev: typeless pointer in device memory storing the output tensor
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
const void* alpha,
const void* beta,
const void* in_dev,
void* out_dev) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual index_t GetRank() const = 0;
virtual index_t GetNumReduceDim() const = 0;
};
using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
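For reference, a minimal host-side sketch of how the new DeviceNormalization interface is meant to be driven, assuming a concrete instance (for example a DeviceSoftmax specialization with AccDataType = float) has already been created as norm_ptr; instance creation and the invoker's Run call are not shown, since only the argument and invoker factories appear in this header.

// Sketch under the assumptions above; lengths/strides/reduce_dims are hypothetical values.
#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_normalization.hpp"

void run_normalization(ck::tensor_operation::device::DeviceNormalizationPtr& norm_ptr,
                       const void* in_dev, // device buffer holding the input tensor
                       void* out_dev)      // device buffer receiving the output tensor
{
    // Problem description: a 2D [8, 2048] tensor, normalized along the last dimension.
    std::vector<ck::index_t> lengths = {8, 2048};
    std::vector<ck::index_t> strides = {2048, 1};
    std::vector<int> reduce_dims     = {1};

    // alpha/beta live in host memory and are passed as typeless pointers; the concrete
    // device op reinterprets them as its AccDataType (float assumed here).
    float alpha = 1.0f;
    float beta  = 0.0f;

    auto arg_ptr = norm_ptr->MakeArgumentPointer(
        lengths, strides, reduce_dims, &alpha, &beta, in_dev, out_dev);
    auto invoker_ptr = norm_ptr->MakeInvokerPointer();

    // The argument would then be handed to the invoker's Run interface
    // (declared in device_base.hpp, not reproduced here) to launch the kernel.
    (void)arg_ptr;
    (void)invoker_ptr;
}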
@@ -9,6 +9,7 @@
 #include "ck/utility/reduction_operator.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
+#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
...
@@ -33,8 +34,15 @@ template <typename InDataType,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
           index_t OutDstVectorSize>
-struct DeviceSoftmax : public BaseOperator
+struct DeviceSoftmax : public DeviceNormalization
 {
+    static constexpr index_t kRank         = Rank;
+    static constexpr index_t kNumReduceDim = NumReduceDim;
+
+    virtual index_t GetRank() const override { return kRank; }
+
+    virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }
+
     using PassThrough = tensor_operation::element_wise::PassThrough;

     // Used for freeloading of some handy functions from DeviceReduceMultiBlock
...
@@ -61,7 +69,7 @@ struct DeviceSoftmax : public BaseOperator
     using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));

-    using GridwiseReduce = GridwiseSoftmax_mk_to_mk<InDataType,
+    using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType,
                                                     OutDataType,
                                                     AccDataType,
                                                     GridDesc_M_K,
@@ -72,7 +80,22 @@ struct DeviceSoftmax : public BaseOperator
                                                     KThreadSliceSize,
                                                     InSrcVectorDim,
                                                     InSrcVectorSize,
-                                                    OutDstVectorSize>;
+                                                    OutDstVectorSize,
+                                                    false>;
+
+    using GridwiseSoftmaxSweepOnce = GridwiseSoftmax_mk_to_mk<InDataType,
+                                                              OutDataType,
+                                                              AccDataType,
+                                                              GridDesc_M_K,
+                                                              BlockSize,
+                                                              MThreadClusterSize,
+                                                              KThreadClusterSize,
+                                                              MThreadSliceSize,
+                                                              KThreadSliceSize,
+                                                              InSrcVectorDim,
+                                                              InSrcVectorSize,
+                                                              OutDstVectorSize,
+                                                              true>;

     struct Argument : public Reduction::Argument
     {
...
@@ -121,8 +144,19 @@ struct DeviceSoftmax : public BaseOperator
             const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
                 arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);

-            const auto kernel_main =
-                kernel_softmax<GridwiseReduce, InDataType, OutDataType, AccDataType, GridDesc_M_K>;
+            bool sweep_once =
+                in_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
+
+            const auto kernel_main = sweep_once ? kernel_softmax<GridwiseSoftmaxSweepOnce,
+                                                                 InDataType,
+                                                                 OutDataType,
+                                                                 AccDataType,
+                                                                 GridDesc_M_K>
+                                                : kernel_softmax<GridwiseSoftmaxGeneric,
+                                                                 InDataType,
+                                                                 OutDataType,
+                                                                 AccDataType,
+                                                                 GridDesc_M_K>;

             float avg_time = 0;
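The kernel selection above boils down to one comparison: if the reduce length fits inside a single K block tile, the sweep-once gridwise softmax is launched, otherwise the generic multi-pass one. A tiny standalone illustration with made-up tile sizes (KThreadClusterSize = 64, KThreadSliceSize = 8, so the threshold is 512):

// Sketch of the sweep-once decision rule; the tile sizes are hypothetical, not from an instance.
#include <cstdio>

int main()
{
    constexpr int KThreadClusterSize = 64;
    constexpr int KThreadSliceSize   = 8;
    constexpr int k_block_tile       = KThreadClusterSize * KThreadSliceSize; // 512

    for(int K : {256, 512, 1024})
    {
        bool sweep_once = K <= k_block_tile;
        std::printf("K = %4d -> %s kernel\n", K, sweep_once ? "sweep-once" : "generic");
    }
    return 0;
}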
@@ -167,24 +201,34 @@ struct DeviceSoftmax : public BaseOperator
         return true;
     };

+    // inLengths: input tensor extent(s) from high to low dimension
+    // inStrides: input tensor stride(s) from high to low dimension
+    // reduceDims: the dimension(s) the softmax normalization operate on
+    // alpha: typeless pointer in host memory storing the alpha scaling value as type AccDataType
+    // beta: typeless pointer in host memory storing the beta scaling value as type AccDataType
+    // in_dev: typeless const pointer in device memory storing the input tensor
+    // out_dev: typeless pointer in device memory storing the output tensor
     std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
                                                       const std::vector<index_t> inStrides,
                                                       const std::vector<int> reduceDims,
-                                                      AccDataType alpha,
-                                                      AccDataType beta,
+                                                      const void* alpha,
+                                                      const void* beta,
                                                       const void* in_dev,
-                                                      void* out_dev)
+                                                      void* out_dev) override
     {
         return std::make_unique<Argument>(inLengths,
                                           inStrides,
                                           reduceDims,
-                                          alpha,
-                                          beta,
+                                          *static_cast<const AccDataType*>(alpha),
+                                          *static_cast<const AccDataType*>(beta),
                                           static_cast<const InDataType*>(in_dev),
                                           static_cast<OutDataType*>(out_dev));
     };

-    std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); };
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>();
+    };

     std::string GetTypeString() const override
     {
...
@@ -11,8 +11,8 @@ namespace element_wise {
 struct Add
 {
-    template <typename T>
-    __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
+    template <typename Y, typename X0, typename X1>
+    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;

     template <>
     __host__ __device__ constexpr void
...
@@ -28,6 +28,13 @@ struct Add
         y = x0 + x1;
     };

+    template <>
+    __host__ __device__ constexpr void
+    operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
+    {
+        y = type_convert<half_t>(x0) + x1;
+    };
+
     // Question: should half_t be supported ?
     template <>
     __host__ __device__ constexpr void
...
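Widening the primary template from a single T to independent Y/X0/X1 is what makes the new mixed-type specialization above possible. Below is an illustrative analogue in plain C++; AddSketch is a stand-in, not the CK functor.

// Illustrative analogue: with per-slot template parameters, mixed operand types compile.
#include <cassert>

struct AddSketch
{
    // Before: template <typename T> void operator()(T& y, const T& x0, const T& x1);
    // After: each slot may have its own type, so e.g. a float accumulator can be added
    // to a half-precision operand in the real code.
    template <typename Y, typename X0, typename X1>
    constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        y = static_cast<Y>(x0) + static_cast<Y>(x1);
    }
};

int main()
{
    AddSketch add{};
    float y = 0.f;
    add(y, 1.5, 2.0f); // double x0, float x1 -> float y; a mixed-type call now compiles
    assert(y == 3.5f);
    return 0;
}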
@@ -49,7 +49,8 @@ template <typename InDataType,
           index_t KThreadSliceSize,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
-          index_t OutDstVectorSize>
+          index_t OutDstVectorSize,
+          bool SweepOnce>
 struct GridwiseSoftmax_mk_to_mk
 {
     static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
...
@@ -75,19 +76,6 @@ struct GridwiseSoftmax_mk_to_mk
     using ThreadReduceDstDesc_M =
         decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

-    using BlockwiseMaxReduce = PartitionedBlockwiseReduction<AccDataType,
-                                                             BlockSize,
-                                                             ThreadClusterLengths_M_K,
-                                                             ThreadClusterArrangeOrder,
-                                                             reduce::Max,
-                                                             false>; // PropagateNan
-
-    using ThreadwiseMaxReduce = ThreadwiseReduction<AccDataType,
-                                                    ThreadReduceSrcDesc_M_K,
-                                                    ThreadReduceDstDesc_M,
-                                                    reduce::Max,
-                                                    false>; // PropagateNan
-
     using PassThroughOp = tensor_operation::element_wise::PassThrough;

     static constexpr auto I0 = Number<0>{};
...
@@ -105,6 +93,11 @@ struct GridwiseSoftmax_mk_to_mk
                                AccDataType beta,
                                OutDataType* const __restrict__ p_out_value_global)
     {
+        if constexpr(SweepOnce)
+        {
+            num_k_block_tile_iteration = 1;
+        }
+
         // LDS
         __shared__ AccDataType p_reduce_work_buffer[BlockSize];
...
@@ -149,6 +142,20 @@ struct GridwiseSoftmax_mk_to_mk
         constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
             make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

+        // Normally, 0 as invalid element value is adequate since 0 makes no contribution to
+        // accumulated result. However, in stable softmax, all values 0s or not are subtracted by
+        // another value_max. As numbers become non-zero, effectively it allows invalid values to
+        // slip through and contribute to the accumulated result.
+        //
+        // The trick here is leveraging the fact that many math functions (add, sub, exp, ...)
+        // propagate NaNs when operands have NaNs involved. By initialiing invalid element value
+        // with NaN, an invalid value doing math manipulations is still NaN, which in turn can still
+        // be identified as an invalid value. We can then discard the invalid values which
+        // originally failed the bound check during accumulation. This allows to ignore values that
+        // failed bound check even after multiple math manipulations.
+        //
+        // NOTE: reset coordinate after every step because the same threadwise copy will sweep
+        // through global memory 3 times back and forth
         auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
                                                                     AccDataType,
                                                                     GridDesc_M_K,
...
@@ -158,7 +165,8 @@ struct GridwiseSoftmax_mk_to_mk
                                                                     InSrcVectorDim,
                                                                     InSrcVectorSize,
                                                                     1,
-                                                                    false>(
+                                                                    true /* ResetCoordAfterRun */,
+                                                                    true /* InvalidElementAsNaN */>(
             in_grid_desc_m_k,
             make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                              block_local_id * reduceSizePerBlock +
...
@@ -198,21 +206,39 @@ struct GridwiseSoftmax_mk_to_mk
                              block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
             PassThroughOp{});

-        constexpr auto in_thread_copy_fwd_step = make_multi_index(0, K_BlockTileSize);
-        constexpr auto in_thread_copy_bwd_step = make_multi_index(0, -K_BlockTileSize);
+        constexpr auto in_thread_copy_fwd_step =
+            make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
+        constexpr auto in_thread_copy_bwd_step =
+            make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);

         ///
         /// max(x)
         ///
-        const auto in_global_val_buf_oob_non_zero = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_in_value_global,
-            in_grid_desc_m_k.GetElementSpaceSize(),
-            reduce::Max::template GetIdentityValue<InDataType>());
+        using BlockwiseMaxReduce = PartitionedBlockwiseReduction<
+            AccDataType,
+            BlockSize,
+            ThreadClusterLengths_M_K,
+            ThreadClusterArrangeOrder,
+            reduce::Max,
+            false, // param ignored
+            detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
+
+        using ThreadwiseMaxReduce =
+            ThreadwiseReduction<AccDataType,
+                                ThreadReduceSrcDesc_M_K,
+                                ThreadReduceDstDesc_M,
+                                reduce::Max,
+                                false, // param ignored
+                                detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
+
+        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_value_global, in_grid_desc_m_k.GetElementSpaceSize());

         index_t reducedTiles = 0;
         do
         {
             threadwise_src_load.Run(in_grid_desc_m_k,
-                                    in_global_val_buf_oob_non_zero,
+                                    in_global_val_buf,
                                     thread_buffer_desc,
                                     make_tuple(I0, I0),
                                     in_thread_buf);
...
@@ -232,26 +258,6 @@ struct GridwiseSoftmax_mk_to_mk
         ///
         /// sum(exp(x - max(x)))
         ///
-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
-        });
-
-        // Normally, 0 as invalid element value is adequate since 0 makes no contribution to
-        // accumulated result. However, in stable softmax, all values 0s or not are subtracted by
-        // another value_max. As numbers become non-zero, effectively it allows invalid values to
-        // slip through and contribute to the accumulated result.
-        //
-        // The trick here is leveraging the fact that many math functions (add, sub, exp, ...)
-        // propagate NaNs when operands have NaNs involved. By initialiing invalid element value
-        // with NaN, an invalid value doing math manipulations is still NaN, which in turn can still
-        // be identified as an invalid value. We can then discard the invalid values which
-        // originally failed the bound check during accumulation. This allows to ignore values that
-        // failed bound check even after multiple math manipulations.
-        const auto in_global_val_buf_oob_nan =
-            make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
-                                                          in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          NumericLimits<InDataType>::QuietNaN());
-
         using BlockwiseSumReduce = PartitionedBlockwiseReduction<
             AccDataType,
             BlockSize,
@@ -271,23 +277,26 @@ struct GridwiseSoftmax_mk_to_mk
         reducedTiles = 0;
         do
         {
-            threadwise_src_load.Run(in_grid_desc_m_k,
-                                    in_global_val_buf_oob_nan,
-                                    thread_buffer_desc,
-                                    make_tuple(I0, I0),
-                                    in_thread_buf);
+            if constexpr(!SweepOnce)
+            {
+                threadwise_src_load.Run(in_grid_desc_m_k,
+                                        in_global_val_buf,
+                                        thread_buffer_desc,
+                                        make_tuple(I0, I0),
+                                        in_thread_buf);
+            }

             // do element-wise pre-reduction operation
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                 static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                     constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
-                    in_thread_buf(Number<offset>{}) =
+                    out_thread_buf(Number<offset>{}) =
                         math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
                 });
             });

-            ThreadwiseSumReduce::Reduce(in_thread_buf, accu_value_buf);
+            ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf);

             threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
...
@@ -308,12 +317,15 @@ struct GridwiseSoftmax_mk_to_mk
         if(float_equal_zero{}(beta))
         {
             do
             {
-                threadwise_src_load.Run(in_grid_desc_m_k,
-                                        in_global_val_buf_oob_nan,
-                                        thread_buffer_desc,
-                                        make_tuple(I0, I0),
-                                        in_thread_buf);
+                if constexpr(!SweepOnce)
+                {
+                    threadwise_src_load.Run(in_grid_desc_m_k,
+                                            in_global_val_buf,
+                                            thread_buffer_desc,
+                                            make_tuple(I0, I0),
+                                            in_thread_buf);
+                }

                 static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                     // out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
...
@@ -340,18 +352,27 @@ struct GridwiseSoftmax_mk_to_mk
         }
         else
         {
+            StaticBuffer<AddressSpaceEnum::Vgpr,
+                         AccDataType,
+                         MThreadSliceSize * KThreadSliceSize,
+                         true>
+                in_prior_dst_buf;
             do
             {
-                threadwise_src_load.Run(in_grid_desc_m_k,
-                                        in_global_val_buf_oob_nan,
-                                        thread_buffer_desc,
-                                        make_tuple(I0, I0),
-                                        in_thread_buf);
+                if constexpr(!SweepOnce)
+                {
+                    threadwise_src_load.Run(in_grid_desc_m_k,
+                                            in_global_val_buf,
+                                            thread_buffer_desc,
+                                            make_tuple(I0, I0),
+                                            in_thread_buf);
+                }

                 threadwise_dst_load.Run(out_grid_desc_m_k,
                                         out_global_val_buf,
                                         thread_buffer_desc,
                                         make_tuple(I0, I0),
-                                        out_thread_buf);
+                                        in_prior_dst_buf);

                 static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                     // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
                     static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
...
@@ -360,7 +381,7 @@ struct GridwiseSoftmax_mk_to_mk
                         out_thread_buf(Number<offset>{}) =
                             alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
                                 accu_value_buf(iM) +
-                            beta * out_thread_buf(Number<offset>{});
+                            beta * in_prior_dst_buf(Number<offset>{});
                     });
                 });
...
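The gridwise changes above hinge on the NaN-masking scheme described in the relocated comment: out-of-bounds elements are loaded as NaN, and the max/sum accumulations skip NaN operands, so padding cannot contaminate the stable softmax. Below is a scalar, host-side illustration of that behavior; it is not the blockwise/threadwise reduction code, and the 5-of-8 padded tile is made up.

// Scalar illustration of NaN-masked stable softmax over a padded tile.
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main()
{
    // 5 valid elements padded to a tile of 8; padding is filled with quiet NaN.
    const float nan = std::numeric_limits<float>::quiet_NaN();
    std::vector<float> tile = {1.f, 2.f, 3.f, 4.f, 5.f, nan, nan, nan};

    float max_v = -std::numeric_limits<float>::infinity();
    for(float v : tile)
        if(!std::isnan(v) && v > max_v) // analogue of AccumulateWithNanIgnore for reduce::Max
            max_v = v;

    float sum = 0.f;
    for(float v : tile)
    {
        float e = std::exp(v - max_v); // NaN stays NaN through sub/exp...
        if(!std::isnan(e))             // ...and is dropped at accumulation time
            sum += e;
    }

    for(float v : tile)
        if(!std::isnan(v))
            std::printf("%.4f ", std::exp(v - max_v) / sum); // stable softmax of the valid lanes
    std::printf("\n");
    return 0;
}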
@@ -30,6 +30,8 @@ struct ThreadwiseReduction
     static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");

+    using Op = OpReduce;
+
     template <typename SrcBufferType, typename DstBufferType>
     __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
     {
...
@@ -236,9 +236,14 @@ template <typename SrcData,
           index_t SrcScalarPerVector,
           index_t SrcScalarStrideInVector,
           bool SrcResetCoordinateAfterRun,
+          bool InvalidElementAsNaN = false,
           typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
 struct ThreadwiseTensorSliceTransfer_v2
 {
+    static_assert((InvalidElementAsNaN && !std::is_integral<DstData>::value) ||
+                      (!InvalidElementAsNaN),
+                  "Filling invalid element as NaN is only for floating point types");
+
     static constexpr index_t nDim = SliceLengths::Size();

     using Index = MultiIndex<nDim>;
...
@@ -318,8 +323,18 @@ struct ThreadwiseTensorSliceTransfer_v2
                     dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
                                              i * src_scalar_step_in_vector);

-                dst_buf(Number<dst_offset>{}) =
-                    type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
+                if constexpr(InvalidElementAsNaN)
+                {
+                    dst_buf(Number<dst_offset>{}) =
+                        is_src_valid
+                            ? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
+                            : NumericLimits<DstData>::QuietNaN();
+                }
+                else
+                {
+                    dst_buf(Number<dst_offset>{}) =
+                        type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
+                }
             });

             if constexpr(idx_1d.value != num_access - 1)
...
@@ -12,21 +12,27 @@ template <typename T, typename Enable = void>
 struct PrintAsType;

 template <typename T>
-struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::value>
+struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::type>
 {
     using type = float;
+    __host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast<type>(p)); }
 };

 template <>
 struct PrintAsType<ck::half_t, void>
 {
     using type = float;
+    __host__ __device__ static void Print(const ck::half_t& p)
+    {
+        printf("%.3f ", static_cast<type>(p));
+    }
 };

 template <typename T>
-struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::value>
+struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::type>
 {
     using type = int;
+    __host__ __device__ static void Print(const T& p) { printf("%d ", static_cast<type>(p)); }
 };

 } // namespace detail
...
@@ -41,7 +47,6 @@ struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::value
 template <typename T, index_t element_stride = 1, index_t row_bytes = 128>
 __device__ void print_shared(T const* p_shared, index_t num_elements)
 {
-    using PrintType = typename detail::PrintAsType<T>::type;
     constexpr index_t row_elements = row_bytes / sizeof(T);
     static_assert((element_stride >= 1 && element_stride <= row_elements),
                   "element_stride should between [1, row_elements]");
...
@@ -63,7 +68,7 @@ __device__ void print_shared(T const* p_shared, index_t num_elements)
         printf("elem %5d: ", i);
         for(index_t j = 0; j < row_elements; j += element_stride)
         {
-            printf("%.0f ", static_cast<PrintType>(p_shared[i + j]));
+            detail::PrintAsType<T>::Print(p_shared[i + j]);
         }
         printf("\n");
...
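The enable_if fix above is easy to miss: the member that SFINAE keys on is ::type, not ::value, so with ::value the partial specializations could never be selected and the primary template stayed undefined. A small standalone check of the corrected pattern; PrintAsTypeSketch is a stand-in type, not the CK struct.

// Verifies that enable_if<cond>::type-based partial specializations are selected as intended.
#include <type_traits>

template <typename T, typename Enable = void>
struct PrintAsTypeSketch; // primary template, deliberately undefined

template <typename T>
struct PrintAsTypeSketch<T, typename std::enable_if<std::is_floating_point<T>::value>::type>
{
    using type = float;
};

template <typename T>
struct PrintAsTypeSketch<T, typename std::enable_if<std::is_integral<T>::value>::type>
{
    using type = int;
};

// With ::type, each specialization matches its intended category.
static_assert(std::is_same<PrintAsTypeSketch<double>::type, float>::value, "float path");
static_assert(std::is_same<PrintAsTypeSketch<short>::type, int>::value, "int path");

int main() { return 0; }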
@@ -148,6 +148,8 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
 template <typename T>
 __device__ T exp(T x);

+// TODO: add f16 support using v_exp_f16
+
 template <>
 __device__ float exp<float>(float x)
 {
...