Unverified commit 0d911822 authored by Qianfeng, committed by GitHub

Padded Generic Kernel Instance (#730)



* Add a NumReduceDim template parameter to DeviceSoftmax and the Softmax client API to simplify instance collection

* Make the generic kernel instance the first entry in the instance list for the elementwise normalization op

* Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax

* Add testing of GetGenericInstance() in client_example of Softmax

* Revert "Add testing of GetGenericInstance() in client_example of Softmax"

This reverts commit f629cd9a93ce38dfed4886d849f3c38d2e5379c8.

* Revert "Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax"

This reverts commit a9f0d000eb9fd240404112a526ef125429a351df.

* Return the generic kernel instance as the first instance from GetInstances() for GroupNorm

* Move the generic kernel instance into a separate tuple for the elementwise normalization op

* Remove unused softmax instance files

* Store the generic kernel instance in a separate tuple for softmax

* Add an IsSupportedArgument() check for the generic instance to the softmax client example

* Replace get_device_normalize_from_mean_meansquare_instances() with the DeviceOperationInstanceFactory class for elementwise normalization

* clang-format fix

* Remove int8 from softmax instances

---------
Co-authored-by: zjing14 <zhangjing14@gmail.com>
parent d140bdc9
@@ -172,18 +172,19 @@ int main()
                                        BLayout,
                                        CLayout>();
 
-    const auto normalize_ptrs =
-        ck::tensor_operation::device::instance::get_device_normalize_from_mean_meansquare_instances<
-            CDataType,
-            ReduceDataType,
-            ReduceDataType,
-            GammaDataType,
-            BetaDataType,
-            LayerNormOutDataType>();
-
     std::cout << "found " << gemm_reduce_ptrs.size()
               << " gemm_reduceMean_reduceSquareMean instances" << std::endl;
 
+    using NormalizeDeviceOp = ck::tensor_operation::device::DeviceElementwise<
+        ck::Tuple<CDataType, ReduceDataType, ReduceDataType, GammaDataType, BetaDataType>,
+        ck::Tuple<LayerNormOutDataType>,
+        ck::tensor_operation::element_wise::Normalize,
+        2>;
+    const auto normalize_ptrs =
+        ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+            NormalizeDeviceOp>::GetInstances();
+
     std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl;
 
     auto f_matrix_space_size =
......
@@ -53,12 +53,35 @@ int main(int argc, char* argv[])
     SimpleDeviceMem in(sizeof(InDataType) * num_elements);
     SimpleDeviceMem out(sizeof(OutDataType) * num_elements);
 
-    using DeviceOp = ck::tensor_operation::device::
-        DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>;
+    using DeviceOp = ck::tensor_operation::device::DeviceSoftmax<InDataType,
+                                                                 AccDataType,
+                                                                 OutDataType,
+                                                                 PassThrough,
+                                                                 PassThrough,
+                                                                 Rank,
+                                                                 NumReduceDim>;
 
     // get device op instances
     const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
         DeviceOp>::GetInstances();
 
+    auto& generic_op_ptr = op_ptrs[0];
+    auto generic_argument_ptr = generic_op_ptr->MakeArgumentPointer(in_lengths,
+                                                                    in_strides,
+                                                                    reduce_dims,
+                                                                    alpha,
+                                                                    beta,
+                                                                    in.GetDeviceBuffer(),
+                                                                    out.GetDeviceBuffer(),
+                                                                    PassThrough{},
+                                                                    PassThrough{});
+
+    if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get()))
+    {
+        throw std::runtime_error(
+            "The generic kernel instance should be able to support any input shapes");
+    };
+
     std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
 
     std::string best_op_name;
@@ -74,11 +97,6 @@ int main(int argc, char* argv[])
     {
         auto& op_ptr = op_ptrs[i];
 
-        if(op_ptr->GetRank() != Rank || op_ptr->GetNumReduceDim() != NumReduceDim)
-        {
-            continue;
-        }
-
         auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
                                                         in_strides,
                                                         reduce_dims,
......
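Note: because each adder registers the padded generic instance ahead of the tuned ones, op_ptrs[0] doubles as a universal fallback. A minimal selection sketch built on that ordering (untested; reuses the variables from the client example above, and picks the first supported tuned instance instead of timing them all as the example does):

```cpp
// Sketch: op_ptrs[0] is assumed to be the padded generic kernel (as this PR
// arranges); prefer the first tuned instance that supports the shape and
// fall back to the generic one otherwise.
std::size_t chosen = 0; // generic fallback
for(std::size_t i = 1; i < op_ptrs.size(); ++i)
{
    auto argument_ptr = op_ptrs[i]->MakeArgumentPointer(in_lengths,
                                                        in_strides,
                                                        reduce_dims,
                                                        alpha,
                                                        beta,
                                                        in.GetDeviceBuffer(),
                                                        out.GetDeviceBuffer(),
                                                        PassThrough{},
                                                        PassThrough{});
    if(op_ptrs[i]->IsSupportedArgument(argument_ptr.get()))
    {
        chosen = i; // first tuned instance that accepts this problem
        break;
    }
}
```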
@@ -72,6 +72,30 @@ int main(int argc, char* argv[])
     std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
 
+    const auto& generic_op_ptr = op_ptrs[0];
+    auto generic_argument_ptr =
+        generic_op_ptr->MakeArgumentPointer({N, H, W, G, C},    // lengths
+                                            xy_strides,         // xStrides
+                                            gamma_beta_strides, // gammaStrides
+                                            gamma_beta_strides, // betaStrides
+                                            xy_strides,         // yStrides
+                                            {1, 2, 4},          // reduceDims
+                                            1e-6,
+                                            x_device_buf.GetDeviceBuffer(),
+                                            gamma_device_buf.GetDeviceBuffer(),
+                                            beta_device_buf.GetDeviceBuffer(),
+                                            y_device_buf.GetDeviceBuffer(),
+                                            nullptr,
+                                            nullptr,
+                                            Swish{});
+
+    if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get()))
+    {
+        throw std::runtime_error(
+            "The generic kernel instance should be able to support any input shapes");
+    };
+
     std::string best_op_name;
     bool found = false;
     int best_op_id = -1;
......
@@ -18,7 +18,8 @@ template <typename InDataType,
           typename OutDataType,
           typename InElementwiseOp,
           typename AccElementwiseOp,
-          index_t Rank>
+          index_t Rank,
+          index_t NumReduceDim>
 struct DeviceSoftmax : public BaseOperator
 {
     //
@@ -49,8 +50,6 @@ struct DeviceSoftmax : public BaseOperator
                        AccElementwiseOp acc_elementwise_op) = 0;
 
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
-
-    virtual index_t GetRank() const = 0;
-    virtual index_t GetNumReduceDim() const = 0;
 };
 
 template <typename InDataType,
@@ -58,9 +57,15 @@ template <typename InDataType,
           typename OutDataType,
           typename InElementwiseOp,
           typename AccElementwiseOp,
-          index_t Rank>
-using DeviceSoftmaxPtr = std::unique_ptr<
-    DeviceSoftmax<InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank>>;
+          index_t Rank,
+          index_t NumReduceDim>
+using DeviceSoftmaxPtr = std::unique_ptr<DeviceSoftmax<InDataType,
+                                                       AccDataType,
+                                                       OutDataType,
+                                                       InElementwiseOp,
+                                                       AccElementwiseOp,
+                                                       Rank,
+                                                       NumReduceDim>>;
 
 } // namespace device
 } // namespace tensor_operation
......
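For context, with NumReduceDim folded into the operator type, a mismatched instance can no longer be returned, which is why the GetRank()/GetNumReduceDim() runtime filters above could be deleted. A small illustration of the new seven-parameter alias (a sketch, not part of this diff):

```cpp
// Sketch: the factory key now encodes both Rank and NumReduceDim, so a
// rank-4 / reduce-2 query can only ever see rank-4 / reduce-2 instances.
using Softmax4x2Ptr = ck::tensor_operation::device::DeviceSoftmaxPtr<
    ck::half_t,                                      // InDataType
    float,                                           // AccDataType
    ck::half_t,                                      // OutDataType
    ck::tensor_operation::element_wise::PassThrough, // InElementwiseOp
    ck::tensor_operation::element_wise::PassThrough, // AccElementwiseOp
    4,                                               // Rank
    2>;                                              // NumReduceDim
```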
@@ -38,16 +38,9 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
                                                 OutDataType,
                                                 InElementwiseOp,
                                                 AccElementwiseOp,
-                                                Rank>
+                                                Rank,
+                                                NumReduceDim>
 {
-    static constexpr index_t kRank         = Rank;
-    static constexpr index_t kNumReduceDim = NumReduceDim;
-
-    static constexpr index_t kNumInvariantDim = Rank - NumReduceDim;
-
-    virtual index_t GetRank() const override { return kRank; }
-
-    virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }
-
     static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
 
     static constexpr index_t NumSrcDim = Rank;
@@ -287,13 +280,13 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
         {
             if constexpr(InSrcVectorDim == 0)
             {
-                if constexpr(kNumInvariantDim == 0)
+                if constexpr(NumInvariantDim == 0)
                 {
                     return false;
                 }
                 else
                 {
-                    if(arg.inStrides_[kNumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
+                    if(arg.inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
                     {
                         return false;
                     }
@@ -316,7 +309,7 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
         }
 
         // To improve
-        if(kNumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0)
+        if(NumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0)
         {
             return false;
         }
......
@@ -5,11 +5,10 @@
 #include <vector>
 
 #include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
 
 namespace ck {
 namespace tensor_operation {
@@ -29,12 +28,25 @@ template <typename InputType,
           typename GammaDataType,
           typename BetaDataType,
           typename OutputType>
-auto get_device_normalize_from_mean_meansquare_instances()
+struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElementwise<
+    ck::Tuple<InputType, MeanType, MeanSquareType, GammaDataType, BetaDataType>,
+    ck::Tuple<OutputType>,
+    Normalize,
+    2>>
 {
-    std::vector<DeviceNormalizeFromMeanMeanSquarePtr> op_ptrs;
+    using DeviceOp = DeviceElementwise<
+        ck::Tuple<InputType, MeanType, MeanSquareType, GammaDataType, BetaDataType>,
+        ck::Tuple<OutputType>,
+        Normalize,
+        2>;
+
+    static auto GetInstances()
+    {
+        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
 
     if constexpr(is_same<InputType, half_t>::value && is_same<MeanType, float>::value &&
-                 is_same<MeanSquareType, float>::value && is_same<GammaDataType, half_t>::value &&
+                 is_same<MeanSquareType, float>::value &&
+                 is_same<GammaDataType, half_t>::value &&
                  is_same<BetaDataType, half_t>::value && is_same<OutputType, half_t>::value)
     {
         ck::tensor_operation::device::instance::
@@ -42,7 +54,8 @@ auto get_device_normalize_from_mean_meansquare_instances()
     }
 
     return op_ptrs;
-}
+    }
+};
 
 } // namespace instance
 } // namespace device
......
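A usage sketch of the new specialization (the half_t/float combination below is the only one the factory currently populates; any other combination yields an empty vector):

```cpp
// Sketch: resolve the elementwise-normalize instances through the common
// factory; types mirror the gemm+layernorm client example earlier in this
// diff (half_t inputs/outputs, float mean/mean-square).
using NormalizeOp = ck::tensor_operation::device::DeviceElementwise<
    ck::Tuple<ck::half_t, float, float, ck::half_t, ck::half_t>,
    ck::Tuple<ck::half_t>,
    ck::tensor_operation::element_wise::Normalize,
    2>;

const auto normalize_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<NormalizeOp>::GetInstances();
```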
@@ -9,34 +9,33 @@
 #include "ck/ck.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
 #include "ck/tensor_operation/gpu/device/device_softmax.hpp"
+#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp"
 
 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {
 
-void add_device_softmax_f16_f16_rank3_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>&);
-void add_device_softmax_f16_f16_rank4_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>&);
-
-void add_device_softmax_f32_f32_rank3_instances(
-    std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3>>&);
-void add_device_softmax_f32_f32_rank4_instances(
-    std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4>>&);
-
-void add_device_softmax_i8_i8_rank3_instances(
-    std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 3>>&);
-void add_device_softmax_i8_i8_rank4_instances(
-    std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4>>&);
-
-template <typename InDataType, typename AccDataType, typename OutDataType, index_t Rank>
-struct DeviceOperationInstanceFactory<
-    ck::tensor_operation::device::
-        DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>>
+template <typename InDataType,
+          typename AccDataType,
+          typename OutDataType,
+          index_t Rank,
+          index_t NumReduceDim>
+struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftmax<InDataType,
+                                                                                  AccDataType,
+                                                                                  OutDataType,
+                                                                                  PassThrough,
+                                                                                  PassThrough,
+                                                                                  Rank,
+                                                                                  NumReduceDim>>
 {
-    using DeviceOp =
-        DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>;
+    using DeviceOp = DeviceSoftmax<InDataType,
+                                   AccDataType,
+                                   OutDataType,
+                                   PassThrough,
+                                   PassThrough,
+                                   Rank,
+                                   NumReduceDim>;
 
     static auto GetInstances()
     {
@@ -46,25 +45,49 @@ struct DeviceOperationInstanceFactory<
                      std::is_same_v<OutDataType, F16>)
         {
             if constexpr(Rank == 3)
-                add_device_softmax_f16_f16_rank3_instances(op_ptrs);
+            {
+                if constexpr(NumReduceDim == 1)
+                    add_device_softmax_f16_f16_rank3_reduce1_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 2)
+                    add_device_softmax_f16_f16_rank3_reduce2_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 3)
+                    add_device_softmax_f16_f16_rank3_reduce3_instances(op_ptrs);
+            }
             else if constexpr(Rank == 4)
-                add_device_softmax_f16_f16_rank4_instances(op_ptrs);
+            {
+                if constexpr(NumReduceDim == 1)
+                    add_device_softmax_f16_f16_rank4_reduce1_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 2)
+                    add_device_softmax_f16_f16_rank4_reduce2_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 3)
+                    add_device_softmax_f16_f16_rank4_reduce3_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 4)
+                    add_device_softmax_f16_f16_rank4_reduce4_instances(op_ptrs);
+            }
         }
         else if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> &&
                           std::is_same_v<OutDataType, F32>)
         {
             if constexpr(Rank == 3)
-                add_device_softmax_f32_f32_rank3_instances(op_ptrs);
-            else if constexpr(Rank == 4)
-                add_device_softmax_f32_f32_rank4_instances(op_ptrs);
-        }
-        else if constexpr(std::is_same_v<InDataType, I8> && std::is_same_v<AccDataType, F32> &&
-                          std::is_same_v<OutDataType, I8>)
-        {
-            if constexpr(Rank == 3)
-                add_device_softmax_i8_i8_rank3_instances(op_ptrs);
+            {
+                if constexpr(NumReduceDim == 1)
+                    add_device_softmax_f32_f32_rank3_reduce1_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 2)
+                    add_device_softmax_f32_f32_rank3_reduce2_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 3)
+                    add_device_softmax_f32_f32_rank3_reduce3_instances(op_ptrs);
+            }
             else if constexpr(Rank == 4)
-                add_device_softmax_i8_i8_rank4_instances(op_ptrs);
+            {
+                if constexpr(NumReduceDim == 1)
+                    add_device_softmax_f32_f32_rank4_reduce1_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 2)
+                    add_device_softmax_f32_f32_rank4_reduce2_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 3)
+                    add_device_softmax_f32_f32_rank4_reduce3_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 4)
+                    add_device_softmax_f32_f32_rank4_reduce4_instances(op_ptrs);
+            }
         }
 
         return op_ptrs;
......
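One behavioral note on removing the int8 branch: an I8 query still matches the factory specialization and compiles, but no adder fires, so GetInstances() returns an empty vector. A sketch of what a caller now sees (hypothetical caller code, not part of this diff):

```cpp
// Sketch: an int8 softmax query now yields zero instances, so callers
// should handle an empty result instead of assuming op_ptrs[0] exists.
using I8Softmax = ck::tensor_operation::device::DeviceSoftmax<
    int8_t,
    float,
    int8_t,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    3,  // Rank
    1>; // NumReduceDim

const auto i8_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<I8Softmax>::GetInstances();
if(i8_ptrs.empty()) { /* no int8 instances are registered any more */ }
```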
Deleted file (legacy f16 declaration header, superseded by the per-NumReduceDim headers below):
-// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
-
-#pragma once
-
-#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
-#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
-
-namespace ck {
-namespace tensor_operation {
-namespace device {
-namespace instance {
-
-void add_device_softmax_f16_f16_rank3_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances);
-void add_device_softmax_f16_f16_rank4_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);
-
-} // namespace instance
-} // namespace device
-} // namespace tensor_operation
-} // namespace ck
@@ -14,7 +14,7 @@ namespace device {
 namespace instance {
 
 void add_device_softmax_f16_f16_rank3_reduce1_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances);
+    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 1>>& instances);
 
 } // namespace instance
 } // namespace device
......
@@ -14,7 +14,7 @@ namespace device {
 namespace instance {
 
 void add_device_softmax_f16_f16_rank3_reduce2_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances);
+    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 2>>& instances);
 
 } // namespace instance
 } // namespace device
......
@@ -14,7 +14,7 @@ namespace device {
 namespace instance {
 
 void add_device_softmax_f16_f16_rank3_reduce3_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances);
+    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 3>>& instances);
 
 } // namespace instance
 } // namespace device
......
@@ -14,7 +14,7 @@ namespace device {
 namespace instance {
 
 void add_device_softmax_f16_f16_rank4_reduce1_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);
+    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 1>>& instances);
 
 } // namespace instance
 } // namespace device
......
@@ -14,7 +14,7 @@ namespace device {
 namespace instance {
 
 void add_device_softmax_f16_f16_rank4_reduce2_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);
+    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 2>>& instances);
 
 } // namespace instance
 } // namespace device
......
@@ -14,7 +14,7 @@ namespace device {
 namespace instance {
 
 void add_device_softmax_f16_f16_rank4_reduce3_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);
+    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 3>>& instances);
 
 } // namespace instance
 } // namespace device
......
@@ -14,7 +14,7 @@ namespace device {
 namespace instance {
 
 void add_device_softmax_f16_f16_rank4_reduce4_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);
+    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 4>>& instances);
 
 } // namespace instance
 } // namespace device
......
@@ -16,7 +16,6 @@ template <index_t Rank, index_t Reduce>
 using device_softmax_f16_f16_instances = std::tuple<
     // clang-format off
     // InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
-    // fallback kernel
     DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1>,
     DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 8, 8>,
     DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 4, 64, 1, 8, 1, 8, 8>,
@@ -33,6 +32,13 @@ using device_softmax_f16_f16_instances = std::tuple<
     // clang-format on
     >;
 
+template <index_t Rank, index_t Reduce>
+using device_softmax_f16_f16_generic_instance = std::tuple<
+    // clang-format off
+    DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 64, 8, 8, 1, 1, 1, 1, 1>
+    // clang-format on
+    >;
+
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
......
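The generic instance sits in its own tuple so the per-rank adders can register it ahead of the tuned list. A sketch of what the corresponding .cpp adder likely looks like (the .cpp files are not shown in this diff, so the body below is an assumption; add_device_operation_instances is the existing CK helper):

```cpp
// Assumed adder body (the .cpp is not part of this diff): registering the
// generic tuple first is what makes op_ptrs[0] the padded fallback that the
// client examples above rely on.
void add_device_softmax_f16_f16_rank3_reduce1_instances(
    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 1>>& instances)
{
    add_device_operation_instances(instances,
                                   device_softmax_f16_f16_generic_instance<3, 1>{});
    add_device_operation_instances(instances, device_softmax_f16_f16_instances<3, 1>{});
}
```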
Deleted file (legacy f32 declaration header, superseded by the per-NumReduceDim headers below):
-// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
-
-#pragma once
-
-#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
-#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
-
-namespace ck {
-namespace tensor_operation {
-namespace device {
-namespace instance {
-
-void add_device_softmax_f32_f32_rank3_instances(
-    std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3>>& instances);
-void add_device_softmax_f32_f32_rank4_instances(
-    std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4>>& instances);
-
-} // namespace instance
-} // namespace device
-} // namespace tensor_operation
-} // namespace ck
@@ -14,7 +14,7 @@ namespace device {
 namespace instance {
 
 void add_device_softmax_f32_f32_rank3_reduce1_instances(
-    std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3>>& instances);
+    std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3, 1>>& instances);
 
 } // namespace instance
 } // namespace device
......
@@ -14,7 +14,7 @@ namespace device {
 namespace instance {
 
 void add_device_softmax_f32_f32_rank3_reduce2_instances(
-    std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3>>& instances);
+    std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3, 2>>& instances);
 
 } // namespace instance
 } // namespace device
......
@@ -14,7 +14,7 @@ namespace device {
 namespace instance {
 
 void add_device_softmax_f32_f32_rank3_reduce3_instances(
-    std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3>>& instances);
+    std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3, 3>>& instances);
 
 } // namespace instance
 } // namespace device
......