Unverified commit 15c89e81 authored by Anthony Chang, committed by GitHub

Standalone softmax kernel (#284)

* initial stub for standalone softmax

* start device_softmax_mk_to_mk as a wrapper to device_reduce_mk_to_m

* host softmax validates

* compiles; to implement beta scaling

* use NaN trick to efficiently ignore OOB values during sum of exponentials

* freeload device_reduce's utility functions

* clean up interface

* adding prior value (beta scaling)

* remove restriction related to perf considerations

* apply clang-format

* clean; disable diagnostics

* resolve conflicts

* add exp wrapper

* honor HostTensorDesc interface; allow implicit cast from different vector<T> type

* test softmax for fp16/fp32

* update readme

* amend commit NaN trick

* remove redundant param added during development

* format

* replace ScalarDataType with AccDataType

* separate out test programs by precision type

* move softmax sample code to its own folder

* format

* keep up with recent changes in reduction API

* remove extra header
parent be60d60d
@@ -5,14 +5,14 @@
# -D <xxx> : input 4-d tensor lengths
# -v <x> : verification (0=no, 1=yes)
#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
#arg2: time kernel (0=no, 1=yes)
./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1
```
Result
```
./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 0.282592 ms, 222.641 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
@@ -24,19 +24,18 @@ Perf: 0.282592 ms, 222.641 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSr
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
#arg3: time kernel (0=no, 1=yes)
./bin/example_reduce_blockwise_two_call 1 2 1
```
Result
```
./bin/example_reduce_blockwise_two_call 1 2 1
launch_and_time_kernel: grid_dim {204800, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
launch_and_time_kernel: grid_dim {6400, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 2.1791 ms, 771.42 GB/s, DeviceReduceBlockWise<256,M_C32_S1,K_C8_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1> => DeviceReduceBlockWise<256,M_C256_S1,K_C1_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1>
```
add_example_executable(example_softmax_blockwise softmax_blockwise.cpp)
\ No newline at end of file
# Instructions for ```example_softmax_blockwise```
## Run ```example_softmax_blockwise```
```bash
# -D <xxx> : input 3-d tensor lengths
# -v <x> : verification (0=no, 1=yes)
#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
#arg2: time kernel (0=no, 1=yes)
example_softmax_blockwise -D 4,128,2048 -v 1 1 1
```
Result
```
launch_and_time_kernel: grid_dim {64, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 0.0242877 ms, 259.039 GB/s, DeviceReduceSoftmax<256,M_C8_S1,K_C32_S8,InSrcVectorDim_1_InSrcVectorSize_8_OutDstVectorSize_8>
```
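As noted in the kernel source, the example computes a scaled softmax with an optional prior-output term:
```
out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
```
where max and sum are taken along the reduced dimension(s), and prior_out is the existing content of the output buffer (only read when beta != 0).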
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>
#include "check_err.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_base.hpp"
#include "device_softmax.hpp"
#include "host_common_util.hpp"
#include "reference_softmax.hpp"
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
using namespace ck;
using namespace ck::tensor_operation::device;
using InDataType = ck::half_t;
using OutDataType = ck::half_t;
using AccDataType = float;
constexpr int Rank = 3;
constexpr int NumReduceDim = 1;
using DeviceInstance = DeviceSoftmax<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
256, // BlockSize
8, // ClusterM
32, // ClusterK
1, // SliceM
8, // SliceK
1, // SrcVecDim (0=M, 1=K)
8, // SrcScalarPerVector
8>; // OutScalarPerVector
static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'},
{"help", no_argument, nullptr, '?'},
{nullptr, 0, nullptr, 0}};
class SimpleAppArgs
{
private:
int option_index = 0;
public:
std::vector<size_t> inLengths = {8, 128, 2048};
std::vector<AccDataType> scales = {2.0f, 2.0f};
bool do_verification = true;
int init_method = 2;
bool time_kernel = true;
public:
void show_usage(const char* cmd)
{
std::cout << "Usage of " << cmd << std::endl;
std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths"
<< std::endl;
std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by "
"comparing with the host-based reduction"
<< std::endl;
std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer "
"value, 3=decimal value)"
<< std::endl;
std::cout << "Arg2 -- time kernel (0=no, 1=yes)" << std::endl;
};
int processArgs(int argc, char* argv[])
{
using ck::host_common::getTypeValuesFromString;
int ch;
while(1)
{
ch = getopt_long(argc, argv, "D:v:l:", long_options, &option_index);
if(ch == -1)
break;
switch(ch)
{
case 'D':
if(!optarg)
throw std::runtime_error("Invalid option format!");
inLengths = getTypeValuesFromString<size_t>(optarg);
break;
case 'v':
if(!optarg)
throw std::runtime_error("Invalid option format!");
do_verification = static_cast<bool>(std::atoi(optarg));
break;
case '?':
if(std::string(long_options[option_index].name) == "help")
{
show_usage(argv[0]);
return (-1);
};
break;
default: show_usage(argv[0]); return (-1);
};
};
if(optind + 2 > argc)
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
init_method = std::atoi(argv[optind++]);
time_kernel = static_cast<bool>(std::atoi(argv[optind]));
if(scales.empty())
{
scales.push_back(1.0f);
scales.push_back(0.0f);
};
return (0);
};
};
int main(int argc, char* argv[])
{
// Example: batched gemm C[G, M, N] applies max/sum reduction along N internally
const std::vector<int> invariantDims{0, 1};
const std::vector<int> reduceDims{2};
SimpleAppArgs args;
if(argc > 1)
{
if(args.processArgs(argc, argv) < 0)
return (-1);
};
Tensor<InDataType> in(args.inLengths);
Tensor<OutDataType> out_ref(args.inLengths);
Tensor<OutDataType> out(args.inLengths);
auto inStrides = in.mDesc.GetStrides();
auto outStrides = out.mDesc.GetStrides();
AccDataType alpha = args.scales[0];
AccDataType beta = args.scales[1];
std::size_t num_thread = 1;
if(args.do_verification)
{
switch(args.init_method)
{
case 0: break;
case 1:
in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread);
if(beta != 0.0f)
out_ref.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1}, num_thread);
break;
case 2:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread);
if(beta != 0.0f)
out_ref.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}, num_thread);
break;
default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0}, num_thread);
if(beta != 0.0f)
out_ref.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-5.0, 5.0}, num_thread);
}
if(beta != 0.0f)
for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++)
out.mData[i] = out_ref.mData[i];
};
// std::cout << "beta = " << beta << std::endl;
// LogRangeAsType<float>(std::cout << "tensor in: " , in.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "tensor prior out: " , out.mData, ",") << std::endl;
// these buffers are usually provided by the user application
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace());
DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace());
in_dev.ToDevice(in.mData.data());
if(beta != 0.0f)
out_dev.ToDevice(out.mData.data());
if(args.do_verification)
{
using ReferenceInstance =
tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;
ReferenceInstance ref;
auto ref_arg = ref.MakeArgument(in, out_ref, alpha, beta, Rank, reduceDims);
auto invoker = ref.MakeInvoker();
invoker.Run(ref_arg);
// LogRangeAsType<float>(std::cout << "tensor out_ref: ", out_ref.mData, ",") << std::endl;
};
std::vector<ck::index_t> i_inLengths;
std::vector<ck::index_t> i_inStrides;
i_inLengths.assign(args.inLengths.begin(), args.inLengths.end());
i_inStrides.assign(inStrides.begin(), inStrides.end());
auto device_instance = DeviceInstance{};
auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths,
i_inStrides,
reduceDims,
alpha,
beta,
in_dev.GetDeviceBuffer(),
out_dev.GetDeviceBuffer());
if(!device_instance.IsSupportedArgument(argument_ptr.get()))
{
std::cout
<< "The runtime parameters seems not supported by the DeviceReduce instance, exiting!"
<< std::endl;
return 1;
};
std::string instance_name = device_instance.GetTypeString();
auto invoker_ptr = device_instance.MakeInvokerPointer();
bool pass = true;
if(args.do_verification)
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
out_dev.FromDevice(out.mData.data());
// LogRangeAsType<float>(std::cout << "tensor out: " , out.mData, ",") << std::endl;
pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
};
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, args.time_kernel});
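// Bytes moved per run: the input tensor is read once; the output tensor is written once,
// and additionally read once when beta != 0 (the prior output is loaded for blending).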
std::size_t num_bytes =
in.mDesc.GetElementSize() * sizeof(InDataType) +
(beta == 0.0f ? 1 : 2) * out.mDesc.GetElementSize() * sizeof(OutDataType);
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << instance_name
<< std::endl;
return (pass ? 0 : 1);
}
@@ -56,3 +56,4 @@ add_subdirectory(19_binary_elementwise)
add_subdirectory(20_convnd_bwd_weight_xdl)
add_subdirectory(21_gemm_layernorm)
add_subdirectory(22_cgemm)
add_subdirectory(23_softmax)
@@ -45,7 +45,9 @@ template <typename AccDataType,
typename ThreadClusterLengths_M_K,
typename ThreadClusterArrangeOrder,
typename OpReduce,
bool PropagateNan,
typename Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
struct PartitionedBlockwiseReduction
{
static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
@@ -62,8 +64,6 @@ struct PartitionedBlockwiseReduction
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
template <typename BufferType>
__device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
{
@@ -113,13 +113,16 @@ struct PartitionedBlockwiseReduction
// 3) in_out_value/in_out_index is the input data in vgpr from each thread
// 4) in_out_value/in_out_index is the over-written reduced output in vgpr for each thread
// clang-format on
template <
typename AccDataType,
typename IndexDataType,
index_t BlockSize,
typename ThreadClusterLengths_M_K,
typename ThreadClusterArrangeOrder,
typename OpReduce,
bool PropagateNan,
typename Accumulation =
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
struct PartitionedBlockwiseReductionWithIndex
{
static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
@@ -136,9 +139,6 @@ struct PartitionedBlockwiseReductionWithIndex
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using Accumulation =
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;
// This interface accumulates on both data values and indices
template <typename BufferType, typename IdxBufferType>
__device__ static void Reduce(BufferType& work_val_buffer,
...
@@ -390,10 +390,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
static bool IsSupportedArgument(const Argument* pArg)
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(use_multiblock)
{
if(static_cast<float>(pArg->beta_) != 0.0f)
@@ -442,11 +440,16 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
else
{
// cases with very small reduce_total_length should be handled by ThreadWise kernel
if(pArg->reduce_total_length / KThreadSliceSize < 2)
    return (false);
// if(pArg->reduce_total_length / KThreadSliceSize < 2)
//     return (false);
};
return (true);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(dynamic_cast<const Argument*>(p_arg));
};
std::unique_ptr<BaseArgument>
...
#ifndef DEVICE_SOFTMAX_HPP
#define DEVICE_SOFTMAX_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_base.hpp"
#include "device_reduce.hpp"
#include "device_reduce_multiblock.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_softmax.hpp"
#include "gridwise_set_buffer_value.hpp"
#include "reduction_operator.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceSoftmax : public BaseOperator
{
using PassThrough = tensor_operation::element_wise::PassThrough;
// Used for freeloading of some handy functions from DeviceReduceMultiBlock
using Reduction = DeviceReduceMultiBlock<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
reduce::Add,
PassThrough, // InElementwiseOperation
PassThrough, // AccElementwiseOperation
InMemoryDataOperationEnum::Set,
false, // PropagateNan
false, // OutputIndex
false, // HaveIndexInputIfOutputIndex
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
1>; // OutDstVectorSize
using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseReduce = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType,
AccDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
struct Argument : public Reduction::Argument
{
Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> reduceDims,
AccDataType alpha,
AccDataType beta,
const InDataType* in_dev,
OutDataType* out_dev)
: Reduction::Argument(inLengths,
inStrides,
{},
{},
reduceDims,
0.0f, // alpha
0.0f, // beta
in_dev,
nullptr,
out_dev,
nullptr,
PassThrough{},
PassThrough{}),
// FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of
// float32 precision. Make it support any data type so the fields can be removed.
alpha_(alpha),
beta_(beta)
{
// std::cout << "blkGroupSize= " << this->blkGroupSize
// << ", numBlockTileIteration= " << this->numBlockTileIteration
// << ", gridSize=" << this->gridSize
// << ", invariant_total_length=" << this->invariant_total_length <<
// std::endl;
}
AccDataType alpha_;
AccDataType beta_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto kernel_main =
kernel_softmax<GridwiseReduce, InDataType, OutDataType, AccDataType, GridDesc_M_K>;
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
kernel_main,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
out_grid_desc_m_k,
arg.blkGroupSize,
arg.numBlockTileIteration,
arg.alpha_,
arg.in_dev_,
arg.beta_,
arg.out_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(!Reduction::IsSupportedArgument(p_arg_))
{
return false;
}
if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0)
{
return false;
}
return true;
};
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
AccDataType alpha,
AccDataType beta,
const void* in_dev,
void* out_dev)
{
return std::make_unique<Argument>(inLengths,
inStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev));
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); };
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceSoftmax<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif // DEVICE_SOFTMAX_HPP
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GRIDWISE_SOFTMAX_HPP
#define GRIDWISE_SOFTMAX_HPP
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"
namespace ck {
template <typename GridwiseReduction,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename GridDesc_M_K>
__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k,
const GridDesc_M_K out_grid_desc_m_k,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType alpha,
const InDataType* const __restrict__ p_in_value_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global)
{
GridwiseReduction::Run(in_grid_desc_m_k,
out_grid_desc_m_k,
block_group_size,
num_k_block_tile_iteration,
alpha,
p_in_value_global,
beta,
p_out_value_global);
};
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename GridDesc_M_K,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct GridwiseSoftmax_mk_to_mk
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(KThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseMaxReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
reduce::Max,
false>; // PropagateNan
using ThreadwiseMaxReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
reduce::Max,
false>; // PropagateNan
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const GridDesc_M_K& in_grid_desc_m_k,
const GridDesc_M_K& out_grid_desc_m_k,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType alpha,
const InDataType* const __restrict__ p_in_value_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global)
{
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m_k.GetElementSpaceSize());
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
out_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> max_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
});
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
GridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
AccDataType,
GridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
out_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(thread_buffer_desc),
GridDesc_M_K,
PassThroughOp,
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m_k,
make_multi_index(
blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
PassThroughOp{});
constexpr auto in_thread_copy_fwd_step = make_multi_index(0, K_BlockTileSize);
constexpr auto in_thread_copy_bwd_step = make_multi_index(0, -K_BlockTileSize);
///
/// max(x)
///
const auto in_global_val_buf_oob_non_zero = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
reduce::Max::template GetIdentityValue<InDataType>());
index_t reducedTiles = 0;
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_val_buf_oob_non_zero,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I)); });
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
///
/// sum(exp(x - max(x)))
///
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
});
// Normally, 0 is an adequate padding value for invalid (out-of-bounds) elements, since 0
// contributes nothing to the accumulated result. In stable softmax, however, every value,
// valid or not, first has value_max subtracted from it; the padding then becomes non-zero,
// effectively letting invalid values slip through and contribute to the accumulated result.
//
// The trick here is to leverage the fact that many math functions (add, sub, exp, ...)
// propagate NaN when any operand is NaN. By initializing invalid elements with NaN, an
// invalid value remains NaN through any sequence of math manipulations and can still be
// identified as invalid. Values that originally failed the bounds check can therefore be
// discarded during accumulation, even after multiple math manipulations.
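// For example (illustrative numbers only): if a row holds the valid values {2, 3} plus one
// padded slot, 0-padding would wrongly add exp(0 - 3) ~= 0.05 to the sum of exponentials,
// whereas NaN-padding gives exp(NaN - 3) == NaN, which AccumulateWithNanIgnore (used below)
// skips, so only exp(2 - 3) + exp(3 - 3) enter the denominator.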
const auto in_global_val_buf_oob_nan =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
NumericLimits<InDataType>::QuietNaN());
using BlockwiseSumReduce = PartitionedBlockwiseReduction<
AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
reduce::Add,
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
using ThreadwiseSumReduce =
ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
reduce::Add,
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
reducedTiles = 0;
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_val_buf_oob_nan,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
// do element-wise pre-reduction operation
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_thread_buf(Number<offset>{}) =
math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
});
});
ThreadwiseSumReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I));
// block_sync_lds();
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
///
/// softmax
///
reducedTiles = 0;
if(float_equal_zero{}(beta))
{
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_val_buf_oob_nan,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
out_thread_buf(Number<offset>{}) =
alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
accu_value_buf(iM);
});
});
threadwise_dst_store.Run(thread_buffer_desc,
make_tuple(I0, I0),
out_thread_buf,
out_grid_desc_m_k,
out_global_val_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
}
else
{
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_val_buf_oob_nan,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
threadwise_dst_load.Run(out_grid_desc_m_k,
out_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
out_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
out_thread_buf(Number<offset>{}) =
alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
accu_value_buf(iM) +
beta * out_thread_buf(Number<offset>{});
});
});
threadwise_dst_store.Run(thread_buffer_desc,
make_tuple(I0, I0),
out_thread_buf,
out_grid_desc_m_k,
out_global_val_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
threadwise_dst_load.MoveSrcSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
}
}
};
} // namespace ck
#endif // GRIDWISE_SOFTMAX_HPP
@@ -39,7 +39,9 @@ template <typename AccDataType,
typename SrcThreadDesc_M_K,
typename DstThreadDesc_M,
typename OpReduce,
bool PropagateNan,
typename Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
struct ThreadwiseReduction
{
static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
@@ -51,8 +53,6 @@ struct ThreadwiseReduction
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
template <typename SrcBufferType, typename DstBufferType>
__device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
{
@@ -73,12 +73,15 @@
// 2) DstDesc is known at compile-time
// 3) SrcBuffer is static buffer
// 4) DstBuffer is static buffer
template <
typename AccDataType,
typename IndexDataType,
typename SrcThreadDesc_M_K,
typename DstThreadDesc_M,
typename OpReduce,
bool PropagateNan,
typename Accumulation =
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
struct ThreadwiseReductionWithIndex
{
static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
@@ -90,9 +93,6 @@ struct ThreadwiseReductionWithIndex
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
using Accumulation =
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;
template <typename SrcValueBufferType,
typename SrcIndexBufferType,
typename DstValueBufferType,
...
@@ -1001,6 +1001,11 @@ struct NumericLimits
__host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
__host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
__host__ __device__ static constexpr T QuietNaN()
{
return std::numeric_limits<T>::quiet_NaN();
}
};
template <>
@@ -1009,12 +1014,15 @@ struct NumericLimits<half_t>
static constexpr unsigned short binary_min = 0x0400;
static constexpr unsigned short binary_max = 0x7BFF;
static constexpr unsigned short binary_lowest = 0xFBFF;
static constexpr unsigned short binary_qnan = 0x7FFF;
__host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }
__host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }
__host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
__host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
};
} // namespace ck
@@ -142,6 +142,22 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
return min(x, min(ys...));
}
// disallow implicit type casting
template <typename T>
__device__ T exp(T x);
template <>
__device__ float exp<float>(float x)
{
return __expf(x);
}
template <>
__device__ double exp<double>(double x)
{
return exp(x);
}
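// Note: since only exact specializations are defined, calling this exp wrapper with an
// unsupported type (for example half_t) should fail at compile/link time rather than
// silently promoting its argument, which is what "disallow implicit type casting" refers to.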
// greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{
...
@@ -35,9 +35,27 @@
namespace ck {
namespace detail {
// Check for NaN; guarantee NaNs are NOT propagated to result (i.e., ignore NaNs)
template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanIgnore
{
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
{
if(!isnan(currVal))
{
ReduceOperation{}(accuVal, currVal);
}
};
};
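// For example, accumulating the values {1, NaN, 2} with ReduceOperation = reduce::Add
// yields 3: the NaN (marking an out-of-bounds element in the softmax kernel) is skipped.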
template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck;
// Does not check for NaN; does not guarantee that NaNs are propagated to the result.
// E.g., given that max(a, b) = a > b ? a : b,
// then max(NaN, 1) returns 1
//      max(1, NaN) returns NaN
// since any comparison involving NaN returns false.
template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
{
@@ -48,6 +66,7 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
};
};
// Check for NaN; guarantees that NaNs are propagated to the result
template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
{
...
@@ -107,6 +107,11 @@ struct HostTensorDescriptor
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}
std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
{
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}
friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);
private:
@@ -212,6 +217,54 @@ struct Tensor
Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}
Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {}
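// ForEach visits every element of the tensor in lexicographic index order and invokes
// f(tensor, multi_index); the ReferenceSoftmax host implementation uses it to walk all entries.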
template <typename F>
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
{
if(rank == mDesc.GetNumOfDimension())
{
f(*this, idx);
return;
}
// else
for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++)
{
idx[rank] = i;
ForEach_impl(std::forward<F>(f), idx, rank + 1);
}
}
template <typename F>
void ForEach(F&& f)
{
std::vector<size_t> idx(mDesc.GetNumOfDimension(), 0);
ForEach_impl(std::forward<F>(f), idx, size_t(0));
}
template <typename F>
void ForEach_impl(const F&& f, std::vector<size_t>& idx, size_t rank) const
{
if(rank == mDesc.GetNumOfDimension())
{
f(*this, idx);
return;
}
// else
for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++)
{
idx[rank] = i;
ForEach_impl(std::forward<const F>(f), idx, rank + 1);
}
}
template <typename F>
void ForEach(const F&& f) const
{
std::vector<size_t> idx(mDesc.GetNumOfDimension(), 0);
ForEach_impl(std::forward<const F>(f), idx, size_t(0));
}
template <typename G>
void GenerateTensorValue(G g, std::size_t num_thread = 1)
{
@@ -272,6 +325,16 @@ struct Tensor
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
T& operator()(std::vector<std::size_t> idx)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
const T& operator()(std::vector<std::size_t> idx) const
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
typename std::vector<T>::iterator begin() { return mData.begin(); }
typename std::vector<T>::iterator end() { return mData.end(); }
@@ -285,7 +348,8 @@ struct Tensor
};
template <typename X>
HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens) : mLens(lens)
HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens)
: mLens(lens.begin(), lens.end())
{
this->CalculateStrides();
}
@@ -293,7 +357,7 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens) : mLens(l
template <typename X, typename Y>
HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
const std::vector<Y>& strides)
: mLens(lens), mStrides(strides)
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
{
}
...
@@ -18,12 +18,12 @@ struct GeneratorTensor_0
template <typename T>
struct GeneratorTensor_1
{
int value = 1;
T value = 1;
template <typename... Is>
T operator()(Is...)
{
return ck::type_convert<T>(value);
return value;
}
};
...
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename InDataType, typename OutDataType, typename AccDataType>
struct ReferenceSoftmax : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
AccDataType alpha,
AccDataType beta,
const index_t rank,
const std::vector<index_t> sm_reduce_dims)
: in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims)
{
// std::cout << "debug: scalar dims: ";
for(int i = 0; i < rank; i++)
{
if(std::find(sm_reduce_dims.begin(), sm_reduce_dims.end(), i) ==
sm_reduce_dims.end())
{
sm_scalar_dims_.push_back(i);
// std::cout << i << ", ";
}
}
// std::cout << std::endl;
}
const Tensor<InDataType>& in_;
Tensor<OutDataType>& out_;
AccDataType alpha_;
AccDataType beta_;
index_t rank_;
std::vector<index_t> sm_reduce_dims_;
std::vector<index_t> sm_scalar_dims_; // dim after internal max/sum reduction
};
// Invoker
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
std::vector<size_t> scalar_lengths;
for(index_t dim : arg.sm_scalar_dims_)
{
scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]);
}
Tensor<AccDataType> reduce_max(scalar_lengths);
reduce_max.GenerateTensorValue(
GeneratorTensor_1<AccDataType>{std::numeric_limits<AccDataType>::lowest()});
Tensor<AccDataType> reduce_sum(scalar_lengths);
reduce_sum.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
auto to_sm_scalar_idx = [&](auto idx) {
std::vector<size_t> sm_scalar_idx;
for(index_t dim : arg.sm_scalar_dims_)
{
sm_scalar_idx.push_back(idx[dim]);
}
return sm_scalar_idx;
};
arg.in_.ForEach([&](auto& self, auto idx) {
reduce_max(to_sm_scalar_idx(idx)) = std::max(reduce_max(to_sm_scalar_idx(idx)),
static_cast<AccDataType>(self(idx)));
});
// LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<
// std::endl;
Tensor<AccDataType> in_stable(arg.in_.mDesc);
in_stable.ForEach([&](auto& self, auto idx) {
// numerator = exp(x - max(x))
self(idx) = std::exp(static_cast<AccDataType>(arg.in_(idx)) -
reduce_max(to_sm_scalar_idx(idx)));
});
// LogRangeAsType<float>(std::cout << "in_stable: ", in_stable.mData, ",") << std::endl;
in_stable.ForEach([&](auto& self, auto idx) {
// denominator = sum(exp(x - max(x)))
reduce_sum(to_sm_scalar_idx(idx)) += self(idx);
});
// LogRangeAsType<float>(std::cout << "reduce_sum: ", reduce_sum.mData, ",") <<
// std::endl;
arg.out_.ForEach([&](auto& self, auto idx) {
self(idx) = arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) +
arg.beta_ * self(idx);
});
// LogRangeAsType<float>(std::cout << "out: ", arg.out_.mData, ",") << std::endl;
// reduction along reduce dims
// LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<
// std::endl; LogRangeAsType<float>(std::cout << "reduce_sum: ", reduce_sum.mData, ",")
// << std::endl;
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
AccDataType alpha,
AccDataType beta,
const index_t rank,
const std::vector<index_t> sm_reduce_dims)
{
return Argument{in, out, alpha, beta, rank, sm_reduce_dims};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceSoftmax"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
@@ -65,4 +65,5 @@ add_subdirectory(reduce)
add_subdirectory(conv2d_bwd_weight)
add_subdirectory(convnd_bwd_data)
add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax)
# DO NOT add client_app, that is tested via CI independently
add_custom_target(test_softmax)
add_gtest_executable(test_softmax_fp32 test_softmax_fp32.cpp)
add_gtest_executable(test_softmax_fp16 test_softmax_fp16.cpp)
target_link_libraries(test_softmax_fp32 PRIVATE host_tensor)
target_link_libraries(test_softmax_fp16 PRIVATE host_tensor)
add_dependencies(test_softmax test_softmax_fp32)
add_dependencies(test_softmax test_softmax_fp16)
\ No newline at end of file
#include "gtest/gtest.h"
#include "test_softmax_util.hpp"
template <ck::index_t N>
using I = ck::Number<N>;
template <typename Tuple>
class TestSoftmaxFP16 : public ck::TestSoftmax<Tuple>
{
};
// clang-format off
using KernelTypes = ::testing::Types<
// InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
std::tuple<ck::half_t, float, ck::half_t, I<3>, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<8>>,
std::tuple<ck::half_t, float, ck::half_t, I<3>, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<8>, I<8>>,
std::tuple<ck::half_t, float, ck::half_t, I<3>, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<8>>,
std::tuple<ck::half_t, float, ck::half_t, I<3>, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<8>>,
std::tuple<ck::half_t, float, ck::half_t, I<3>, I<2>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<8>>,
std::tuple<ck::half_t, float, ck::half_t, I<3>, I<2>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<8>, I<8>>,
std::tuple<ck::half_t, float, ck::half_t, I<3>, I<2>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<8>>,
std::tuple<ck::half_t, float, ck::half_t, I<3>, I<2>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<8>>
>;
// clang-format on
TYPED_TEST_SUITE(TestSoftmaxFP16, KernelTypes);
TYPED_TEST(TestSoftmaxFP16, Test_FP16) { this->Run(); }
#include "gtest/gtest.h"
#include "test_softmax_util.hpp"
template <ck::index_t N>
using I = ck::Number<N>;
template <typename Tuple>
class TestSoftmaxFP32 : public ck::TestSoftmax<Tuple>
{
};
// clang-format off
using KernelTypes = ::testing::Types<
// InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
std::tuple<float, float, float, I<3>, I<1>, I<256>, I<8>, I<32>, I<1>, I<4>, I<1>, I<4>, I<4>>,
std::tuple<float, float, float, I<3>, I<1>, I<256>, I<4>, I<64>, I<1>, I<4>, I<1>, I<4>, I<4>>,
std::tuple<float, float, float, I<3>, I<1>, I<256>, I<2>, I<128>, I<1>, I<4>, I<1>, I<4>, I<4>>,
std::tuple<float, float, float, I<3>, I<1>, I<256>, I<1>, I<256>, I<1>, I<4>, I<1>, I<4>, I<4>>,
std::tuple<float, float, float, I<3>, I<2>, I<256>, I<8>, I<32>, I<1>, I<4>, I<1>, I<4>, I<4>>,
std::tuple<float, float, float, I<3>, I<2>, I<256>, I<4>, I<64>, I<1>, I<4>, I<1>, I<4>, I<4>>,
std::tuple<float, float, float, I<3>, I<2>, I<256>, I<2>, I<128>, I<1>, I<4>, I<1>, I<4>, I<4>>,
std::tuple<float, float, float, I<3>, I<2>, I<256>, I<1>, I<256>, I<1>, I<4>, I<1>, I<4>, I<4>>
>;
// clang-format on
TYPED_TEST_SUITE(TestSoftmaxFP32, KernelTypes);
TYPED_TEST(TestSoftmaxFP32, Test_FP32) { this->Run(); }