Commit 10962ae6 authored by rocking's avatar rocking
Browse files

Merge commit '1677cf70' into gemm_add_bias_reduction

parents 300ac4e4 1677cf70
...@@ -212,30 +212,50 @@ def runCKProfiler(Map conf=[:]){ ...@@ -212,30 +212,50 @@ def runCKProfiler(Map conf=[:]){
{ {
cmake_build(conf) cmake_build(conf)
dir("script"){ dir("script"){
def perf_log = "perf_gemm_${gpu_arch}.log" //run gemm performance tests
sh "rm -f ${perf_log}" def gemm_log = "perf_gemm_${gpu_arch}.log"
sh "echo Branch name: ${env.BRANCH_NAME} > ${perf_log}" sh "rm -f ${gemm_log}"
sh "./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${perf_log}" sh "echo Branch name: ${env.BRANCH_NAME} > ${gemm_log}"
sh "./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a ${perf_log}" sh "echo Node name: ${NODE_NAME} >> ${gemm_log}"
sh "./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a ${perf_log}" sh "echo GPU_arch: ${gpu_arch} >> ${gemm_log}"
sh "./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a ${perf_log}" sh "hipcc --version | grep -e 'HIP version' >> ${gemm_log}"
sh "./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${perf_log}" sh "/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${gemm_log}"
sh "./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a ${perf_log}" sh "./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a ${gemm_log}"
sh "./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a ${gemm_log}"
//results will be parsed, stored, and analyzed within the python script //results will be parsed, stored, and analyzed within the python script
//the script will return 0 if the performance criteria are met //the script will return 0 if the performance criteria are met
//or return 1 if the criteria are not met //or return 1 if the criteria are not met
archiveArtifacts "${perf_log}" archiveArtifacts "${gemm_log}"
sh "python3 parse_perf_data.py ${perf_log} " sh "python3 parse_perf_data.py ${gemm_log} "
//run resnet50 test
def resnet_log = "perf_resnet50_${gpu_arch}.log"
sh "rm -f ${resnet_log}"
sh "echo Branch name: ${env.BRANCH_NAME} > ${resnet_log}"
sh "echo Node name: ${NODE_NAME} >> ${resnet_log}"
sh "echo GPU_arch: ${gpu_arch} >> ${resnet_log}"
sh "hipcc --version | grep -e 'HIP version' >> ${resnet_log}"
sh "/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${resnet_log}"
//first run tests with N=256
sh "./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 256 | tee -a ${resnet_log}"
//then run with N=4
sh "./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 4 | tee -a ${resnet_log}"
archiveArtifacts "${resnet_log}"
//the script will put the results from N=256 and N=4 runs into separate tables
sh "python3 parse_perf_data.py ${resnet_log} "
} }
} }
} }
......
...@@ -224,10 +224,10 @@ int main(int argc, char* argv[]) ...@@ -224,10 +224,10 @@ int main(int argc, char* argv[])
{ {
case 0: break; case 0: break;
case 1: case 1:
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}); bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
residual.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}); residual.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
break; break;
default: default:
input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0}); input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
......
...@@ -147,8 +147,6 @@ class SimpleAppArgs ...@@ -147,8 +147,6 @@ class SimpleAppArgs
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
using namespace ck::host_reduce;
const std::vector<int> reduceDims{0, 1, 2}; const std::vector<int> reduceDims{0, 1, 2};
const std::vector<int> invariantDims{3}; const std::vector<int> invariantDims{3};
...@@ -254,7 +252,9 @@ int main(int argc, char* argv[]) ...@@ -254,7 +252,9 @@ int main(int argc, char* argv[])
ReductionHost<InDataType, ReductionHost<InDataType,
AccDataType, AccDataType,
OutDataType, OutDataType,
ReduceOpId, ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
Rank, Rank,
NumReduceDim, NumReduceDim,
PropagateNan, PropagateNan,
......
...@@ -108,8 +108,6 @@ int main(int argc, char* argv[]) ...@@ -108,8 +108,6 @@ int main(int argc, char* argv[])
const std::vector<size_t> outLengths = {64, 320, 80}; const std::vector<size_t> outLengths = {64, 320, 80};
using namespace ck::host_reduce;
if(argc == 1) if(argc == 1)
{ {
do_verify = true; do_verify = true;
...@@ -191,7 +189,9 @@ int main(int argc, char* argv[]) ...@@ -191,7 +189,9 @@ int main(int argc, char* argv[])
ReductionHost<InOutDataType, ReductionHost<InOutDataType,
AccDataType, AccDataType,
InOutDataType, InOutDataType,
ReduceOpId, ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
5, // Rank 5, // Rank
2, // NumReduceDim 2, // NumReduceDim
PropagateNan, PropagateNan,
......
...@@ -8,10 +8,12 @@ ...@@ -8,10 +8,12 @@
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "host_reduce_util.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "reduction_enums.hpp" #include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "reduction_functions_accumulate.hpp"
#include "device_pool2d_fwd_nhwc_nhwc.hpp" #include "device_pool2d_fwd_nhwc_nhwc.hpp"
template <typename InDataType, template <typename InDataType,
...@@ -29,19 +31,24 @@ static void pool_host_verify(const Tensor<InDataType>& in, ...@@ -29,19 +31,24 @@ static void pool_host_verify(const Tensor<InDataType>& in,
const std::array<ck::index_t, 2>& in_left_pads, const std::array<ck::index_t, 2>& in_left_pads,
const std::array<ck::index_t, 2>& /*in_right_pads*/) const std::array<ck::index_t, 2>& /*in_right_pads*/)
{ {
using namespace ck::host_reduce;
const int32_t divider = window_spatial_lengths[0] * window_spatial_lengths[1]; const int32_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];
const auto PreUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider); using ReduceOperation = typename ck::reduce_binary_operator<AccDataType, ReduceOpId>::opType;
const auto PosUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider); using InElementwiseOperation = typename ck::
reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation = typename ck::
reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
const InElementwiseOperation in_elementwise_op(divider);
const AccElementwiseOperation acc_elementwise_op(divider);
if constexpr(!OutputIndex) if constexpr(!OutputIndex)
{ {
auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>(); using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); auto accuVal = ReduceOperation::GetIdentityValue();
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
{ {
...@@ -54,14 +61,14 @@ static void pool_host_verify(const Tensor<InDataType>& in, ...@@ -54,14 +61,14 @@ static void pool_host_verify(const Tensor<InDataType>& in,
{ {
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi)); AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
PreUnaryOp(currVal); in_elementwise_op(currVal, currVal);
binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal); Accumulation::Calculate(accuVal, currVal);
} }
} }
} }
PosUnaryOp(accuVal); acc_elementwise_op(accuVal, accuVal);
out(n, c, ho, wo) = accuVal; out(n, c, ho, wo) = accuVal;
}; };
...@@ -74,10 +81,12 @@ static void pool_host_verify(const Tensor<InDataType>& in, ...@@ -74,10 +81,12 @@ static void pool_host_verify(const Tensor<InDataType>& in,
} }
else else
{ {
auto opReduce = ReduceOpFn2<AccDataType, ReduceOpId>(); using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); auto accuVal = ReduceOperation::GetIdentityValue();
IndexDataType accuIndex = 0; IndexDataType accuIndex = 0;
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
...@@ -92,15 +101,14 @@ static void pool_host_verify(const Tensor<InDataType>& in, ...@@ -92,15 +101,14 @@ static void pool_host_verify(const Tensor<InDataType>& in,
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi)); AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
IndexDataType currIndex = y * window_spatial_lengths[1] + x; IndexDataType currIndex = y * window_spatial_lengths[1] + x;
PreUnaryOp(currVal); in_elementwise_op(currVal, currVal);
binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>( Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
opReduce, accuVal, currVal, accuIndex, currIndex);
} }
} }
} }
PosUnaryOp(accuVal); acc_elementwise_op(accuVal, accuVal);
out(n, c, ho, wo) = accuVal; out(n, c, ho, wo) = accuVal;
out_indices(n, c, ho, wo) = accuIndex; out_indices(n, c, ho, wo) = accuIndex;
...@@ -139,8 +147,6 @@ bool pool_test(bool do_verification, ...@@ -139,8 +147,6 @@ bool pool_test(bool do_verification,
ck::index_t in_right_pad_h, ck::index_t in_right_pad_h,
ck::index_t in_right_pad_w) ck::index_t in_right_pad_w)
{ {
using namespace ck::host_reduce;
using DevicePoolFwdInstance = using DevicePoolFwdInstance =
ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C< ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<
InDataType, // InDataType InDataType, // InDataType
......
...@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false; ...@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
using namespace ck::host_reduce;
bool do_verification; bool do_verification;
int init_method; int init_method;
bool time_kernel; bool time_kernel;
......
...@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false; ...@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
using namespace ck::host_reduce;
bool do_verification; bool do_verification;
int init_method; int init_method;
bool time_kernel; bool time_kernel;
......
...@@ -236,7 +236,7 @@ int main(int argc, char* argv[]) ...@@ -236,7 +236,7 @@ int main(int argc, char* argv[])
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
ReduceAccDataType d_acc = d_reduce_op.GetReductionZeroVal(); ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
d_reduce_op(d_acc, c_m_n_host_result(m, n)); d_reduce_op(d_acc, c_m_n_host_result(m, n));
......
...@@ -261,8 +261,8 @@ int main(int argc, char* argv[]) ...@@ -261,8 +261,8 @@ int main(int argc, char* argv[])
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
float d0_acc = d0_reduce_op.GetReductionZeroVal(); float d0_acc = d0_reduce_op.GetIdentityValue();
float d1_acc = d1_reduce_op.GetReductionZeroVal(); float d1_acc = d1_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
......
...@@ -259,8 +259,8 @@ int main(int argc, char* argv[]) ...@@ -259,8 +259,8 @@ int main(int argc, char* argv[])
{ {
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
float d0_acc = d0_reduce_op.GetReductionZeroVal(); float d0_acc = d0_reduce_op.GetIdentityValue();
float d1_acc = d1_reduce_op.GetReductionZeroVal(); float d1_acc = d1_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
......
...@@ -157,8 +157,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, ...@@ -157,8 +157,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
auto reduceSumOpInst = ReduceSumOp{}; auto reduceSumOpInst = ReduceSumOp{};
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
float mean_acc = reduceSumOpInst.GetReductionZeroVal(); float mean_acc = reduceSumOpInst.GetIdentityValue();
float square_mean_acc = reduceSumOpInst.GetReductionZeroVal(); float square_mean_acc = reduceSumOpInst.GetIdentityValue();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
......
...@@ -460,6 +460,8 @@ struct ...@@ -460,6 +460,8 @@ struct
using C0GridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I3])>; using C0GridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I3])>;
using C1GridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I4])>; using C1GridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I4])>;
using Block2CTileMap = BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, CGridDesc_M_N>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3<
BlockSize, BlockSize,
...@@ -522,8 +524,6 @@ struct ...@@ -522,8 +524,6 @@ struct
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
ck::index_t M01,
ck::index_t N01,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) OutElementwiseOperation out_element_op)
...@@ -540,10 +540,7 @@ struct ...@@ -540,10 +540,7 @@ struct
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
block_2_ctile_map_{ block_2_ctile_map_{},
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
M01_{M01},
N01_{N01},
in_element_op_{in_element_op}, in_element_op_{in_element_op},
wei_element_op_{wei_element_op}, wei_element_op_{wei_element_op},
out_element_op_{out_element_op}, out_element_op_{out_element_op},
...@@ -576,6 +573,8 @@ struct ...@@ -576,6 +573,8 @@ struct
c0_grid_desc_m_n_ = descs[I3]; c0_grid_desc_m_n_ = descs[I3];
c1_grid_desc_m_n_ = descs[I4]; c1_grid_desc_m_n_ = descs[I4];
block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_};
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_, b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_, c_grid_desc_m_n_,
...@@ -618,9 +617,7 @@ struct ...@@ -618,9 +617,7 @@ struct
typename GridwiseGemm:: typename GridwiseGemm::
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_; WeiElementwiseOperation wei_element_op_;
OutElementwiseOperation out_element_op_; OutElementwiseOperation out_element_op_;
...@@ -723,7 +720,7 @@ struct ...@@ -723,7 +720,7 @@ struct
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, Block2CTileMap,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
...@@ -767,7 +764,7 @@ struct ...@@ -767,7 +764,7 @@ struct
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, Block2CTileMap,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
...@@ -894,8 +891,6 @@ struct ...@@ -894,8 +891,6 @@ struct
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
1,
1,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op}; out_element_op};
...@@ -938,8 +933,6 @@ struct ...@@ -938,8 +933,6 @@ struct
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
1,
1,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op); out_element_op);
......
...@@ -348,8 +348,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -348,8 +348,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
if constexpr(use_multiblock) if constexpr(use_multiblock)
{ {
const auto zeroVal = const auto identityVal =
ck::reduce::GetReductionZeroValueForInMemoryDataOperation<OutDataType>( ck::reduce::GetIdentityValueueForInMemoryDataOperation<OutDataType>(
OutMemoryDataOperation); OutMemoryDataOperation);
const auto kernel_pre = const auto kernel_pre =
...@@ -362,7 +362,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -362,7 +362,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
0, 0,
out_grid_desc_m_2, out_grid_desc_m_2,
arg.out_dev_, arg.out_dev_,
zeroVal); identityVal);
}; };
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(stream_config,
......
#pragma once #pragma once
#include "data_type.hpp" #include "data_type.hpp"
#include "math_v2.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -311,7 +312,7 @@ struct UnaryAbs<float, float> ...@@ -311,7 +312,7 @@ struct UnaryAbs<float, float>
{ {
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = abs(x); }; __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); };
}; };
template <> template <>
...@@ -319,7 +320,7 @@ struct UnaryAbs<half_t, half_t> ...@@ -319,7 +320,7 @@ struct UnaryAbs<half_t, half_t>
{ {
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = __habs(x); }; __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); };
}; };
template <> template <>
...@@ -327,7 +328,7 @@ struct UnaryAbs<double, double> ...@@ -327,7 +328,7 @@ struct UnaryAbs<double, double>
{ {
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); }; __host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); };
}; };
template <> template <>
...@@ -335,12 +336,7 @@ struct UnaryAbs<int8_t, int8_t> ...@@ -335,12 +336,7 @@ struct UnaryAbs<int8_t, int8_t>
{ {
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); };
{
int8_t sgn = x >> (8 - 1);
y = (x ^ sgn) - sgn;
};
}; };
template <typename Y, typename X> template <typename Y, typename X>
...@@ -351,7 +347,7 @@ struct UnarySqrt<float, float> ...@@ -351,7 +347,7 @@ struct UnarySqrt<float, float>
{ {
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = sqrtf(x); }; __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::sqrt(x); };
}; };
template <> template <>
...@@ -359,7 +355,10 @@ struct UnarySqrt<double, double> ...@@ -359,7 +355,10 @@ struct UnarySqrt<double, double>
{ {
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const { y = sqrt(x); }; __host__ __device__ void operator()(double& y, const double& x) const
{
y = ck::math::sqrt(x);
};
}; };
} // namespace element_wise } // namespace element_wise
......
...@@ -171,7 +171,7 @@ struct GridwiseReduction_mk_to_m_multiblock ...@@ -171,7 +171,7 @@ struct GridwiseReduction_mk_to_m_multiblock
AccDataType beta, AccDataType beta,
OutDataType* const __restrict__ p_out_value_global) OutDataType* const __restrict__ p_out_value_global)
{ {
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto identityVal = ReduceOperation::GetIdentityValue();
// LDS // LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_buffer[BlockSize];
...@@ -179,7 +179,7 @@ struct GridwiseReduction_mk_to_m_multiblock ...@@ -179,7 +179,7 @@ struct GridwiseReduction_mk_to_m_multiblock
const auto in_global_val_buf = const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global, make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(), in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal)); type_convert<InDataType>(identityVal));
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
...@@ -191,7 +191,7 @@ struct GridwiseReduction_mk_to_m_multiblock ...@@ -191,7 +191,7 @@ struct GridwiseReduction_mk_to_m_multiblock
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; }); static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id(); const index_t block_global_id = get_block_1d_id();
...@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock ...@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize]; __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto identityVal = ReduceOperation::GetIdentityValue();
const auto in_global_val_buf = const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global, make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(), in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal)); type_convert<InDataType>(identityVal));
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize()); p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -418,7 +418,7 @@ struct GridwiseReduction_mk_to_m_multiblock ...@@ -418,7 +418,7 @@ struct GridwiseReduction_mk_to_m_multiblock
thread_k_cluster_id * KThreadSliceSize)); thread_k_cluster_id * KThreadSliceSize));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal; accu_value_buf(I) = identityVal;
accu_index_buf(I) = 0; accu_index_buf(I) = 0;
}); });
...@@ -459,7 +459,7 @@ struct GridwiseReduction_mk_to_m_multiblock ...@@ -459,7 +459,7 @@ struct GridwiseReduction_mk_to_m_multiblock
in_thread_idx_buf); in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType tmpValue = zeroVal; AccDataType tmpValue = identityVal;
IndexDataType tmpIndex = 0; IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
...@@ -512,7 +512,7 @@ struct GridwiseReduction_mk_to_m_multiblock ...@@ -512,7 +512,7 @@ struct GridwiseReduction_mk_to_m_multiblock
in_thread_val_buf(Number<offset>{})); in_thread_val_buf(Number<offset>{}));
}); });
AccDataType tmpValue = zeroVal; AccDataType tmpValue = identityVal;
IndexDataType tmpIndex = 0; IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
......
...@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto identityVal = ReduceOperation::GetIdentityValue();
const auto in_global_val_buf = const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global, make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(), in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal)); type_convert<InDataType>(identityVal));
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
...@@ -149,7 +149,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -149,7 +149,7 @@ struct GridwiseReduction_mk_to_m_threadwise
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; }); static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
...@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise
(void)acc_elementwise_op; (void)acc_elementwise_op;
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto identityVal = ReduceOperation::GetIdentityValue();
const auto in_global_val_buf = const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global, make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(), in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal)); type_convert<InDataType>(identityVal));
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize()); p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
...@@ -303,7 +303,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -303,7 +303,7 @@ struct GridwiseReduction_mk_to_m_threadwise
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf; StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal; accu_value_buf(I) = identityVal;
accu_index_buf(I) = 0; accu_index_buf(I) = 0;
}); });
......
...@@ -816,10 +816,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -816,10 +816,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
false>; false>;
// Global write Gemm shuffle + reduction // Global write Gemm shuffle + reduction
const auto d_zeroVal = DReduceOperation::GetReductionZeroVal(); const auto d_identityVal = DReduceOperation::GetIdentityValue();
static_for<0, mreduce_per_thread, 1>{}( static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_zeroVal; }); [&](auto I) { d_thread_buf(I) = d_identityVal; });
// reduce in VGPR // reduce in VGPR
static_for<0, mreduce_per_thread, 1>{}([&](auto im) { static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
......
...@@ -340,7 +340,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -340,7 +340,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap> template <bool HasMainKBlockLoop, typename Block2CTileMap>
__device__ static void __device__ static void
Run(const FloatAB* __restrict__ p_a_grid, Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
......
...@@ -3,11 +3,13 @@ ...@@ -3,11 +3,13 @@
#include <cmath> #include <cmath>
#include "data_type.hpp" #include "data_type.hpp"
#include "half.hpp" #include "type.hpp"
namespace ck { namespace ck {
namespace math { namespace math {
// math functions for the host, some are implemented by calling C++ std functions
static inline __host__ float abs(float x) { return std::abs(x); }; static inline __host__ float abs(float x) { return std::abs(x); };
static inline __host__ double abs(double x) { return std::abs(x); }; static inline __host__ double abs(double x) { return std::abs(x); };
...@@ -28,26 +30,26 @@ static inline __host__ int32_t abs(int32_t x) ...@@ -28,26 +30,26 @@ static inline __host__ int32_t abs(int32_t x)
static inline __host__ half_t abs(half_t x) static inline __host__ half_t abs(half_t x)
{ {
half_float::half xx = *reinterpret_cast<half_float::half*>(&x); uint16_t xx = ck::bit_cast<uint16_t>(x);
half_float::half abs_xx = half_float::abs(xx); uint16_t abs_xx = xx & 0x7fff;
half_t abs_x = *reinterpret_cast<half_t*>(&abs_xx); half_t abs_x = ck::bit_cast<half_t>(abs_xx);
return abs_x; return abs_x;
}; };
static inline __host__ float isnan(float x) { return std::isnan(x); }; static inline __host__ bool isnan(float x) { return std::isnan(x); };
static inline __host__ double isnan(double x) { return std::isnan(x); }; static inline __host__ bool isnan(double x) { return std::isnan(x); };
static inline __host__ int8_t isnan(int8_t x) static inline __host__ bool isnan(int8_t x)
{ {
(void)x; (void)x;
return false; return false;
}; };
static inline __host__ int32_t isnan(int32_t x) static inline __host__ bool isnan(int32_t x)
{ {
(void)x; (void)x;
return false; return false;
...@@ -55,11 +57,59 @@ static inline __host__ int32_t isnan(int32_t x) ...@@ -55,11 +57,59 @@ static inline __host__ int32_t isnan(int32_t x)
static inline __host__ bool isnan(half_t x) static inline __host__ bool isnan(half_t x)
{ {
half_float::half xx = *reinterpret_cast<half_float::half*>(&x); uint16_t xx = ck::bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00;
};
static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); };
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static inline __device__ float abs(float x) { return ::abs(x); };
static inline __device__ double abs(double x) { return ::abs(x); };
static inline __device__ int8_t abs(int8_t x)
{
int8_t sgn = x >> (8 - 1);
return (x ^ sgn) - sgn;
};
static inline __device__ int32_t abs(int32_t x)
{
int32_t sgn = x >> (32 - 1);
return (x ^ sgn) - sgn;
};
static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
static inline __device__ bool isnan(float x) { return ::isnan(x); };
static inline __device__ bool isnan(double x) { return ::isnan(x); };
static inline __device__ bool isnan(int8_t x)
{
(void)x;
return false;
};
return half_float::isnan(xx); static inline __device__ bool isnan(int32_t x)
{
(void)x;
return false;
}; };
static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
static inline __device__ double sqrt(double x) { return ::sqrt(x); };
} // namespace math } // namespace math
} // namespace ck } // namespace ck
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#define CK_REDUCTION_FUNCTIONS_BINOP_HPP #define CK_REDUCTION_FUNCTIONS_BINOP_HPP
#include "data_type.hpp" #include "data_type.hpp"
#include "math_v2.hpp"
#include "reduction_common.hpp" #include "reduction_common.hpp"
#include "reduction_operator.hpp" #include "reduction_operator.hpp"
...@@ -34,18 +35,6 @@ ...@@ -34,18 +35,6 @@
namespace ck { namespace ck {
namespace detail { namespace detail {
template <typename T>
static inline __device__ bool is_nan(T x)
{
return (isnan(x));
};
template <>
inline __device__ bool is_nan<half_t>(half_t x)
{
return (__hisnan(x));
};
template <bool PropagateNan, typename ReduceOperation, typename AccDataType> template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck; struct AccumulateWithNanCheck;
...@@ -53,7 +42,7 @@ template <typename ReduceOperation, typename AccDataType> ...@@ -53,7 +42,7 @@ template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType> struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
{ {
// cppcheck-suppress constParameter // cppcheck-suppress constParameter
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
{ {
ReduceOperation{}(accuVal, currVal); ReduceOperation{}(accuVal, currVal);
}; };
...@@ -62,9 +51,11 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType> ...@@ -62,9 +51,11 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
template <typename ReduceOperation, typename AccDataType> template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType> struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
{ {
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
{ {
if(is_nan(currVal)) using ck::math::isnan;
if(isnan(currVal))
{ {
accuVal = currVal; accuVal = currVal;
} }
...@@ -81,7 +72,7 @@ struct AccumulateWithIndexAndNanCheck; ...@@ -81,7 +72,7 @@ struct AccumulateWithIndexAndNanCheck;
template <typename ReduceOperation, typename AccDataType, typename IndexDataType> template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType> struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType>
{ {
__device__ static inline void __host__ __device__ static inline void
// cppcheck-suppress constParameter // cppcheck-suppress constParameter
Calculate(AccDataType& accuVal, Calculate(AccDataType& accuVal,
AccDataType currVal, AccDataType currVal,
...@@ -101,12 +92,14 @@ template <typename ReduceOperation, typename AccDataType, typename IndexDataType ...@@ -101,12 +92,14 @@ template <typename ReduceOperation, typename AccDataType, typename IndexDataType
struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType> struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType>
{ {
// The method is called when the ReduceOperation is indexable and the user asked for indices // The method is called when the ReduceOperation is indexable and the user asked for indices
__device__ static inline void Calculate(AccDataType& accuVal, __host__ __device__ static inline void Calculate(AccDataType& accuVal,
AccDataType currVal, AccDataType currVal,
IndexDataType& accuIndex, IndexDataType& accuIndex,
IndexDataType currIndex) IndexDataType currIndex)
{ {
if(is_nan(currVal)) using ck::math::isnan;
if(isnan(currVal))
{ {
accuVal = currVal; accuVal = currVal;
accuIndex = currIndex; accuIndex = currIndex;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment