"tests/vscode:/vscode.git/clone" did not exist on "0c39f53cbb2724b9706a5a8397e8c5ace414aa96"
Commit f26fb605 authored by wangshaojie6

Merge branch 'develop' into bwd_weight_bf16_splitk

parents 32d06c66 1677cf70
@@ -24,57 +24,33 @@ template <typename GridwiseGemm,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          bool HasMainKBlockLoop,
-          index_t MaxGroupCount>
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_gemm_xdlops_v2r3(
-        const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_descs,
-        const index_t group_count,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
-        const CElementwiseOperation c_element_op)
+    kernel_grouped_gemm_xdlops_v2r3(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
+                                    const index_t group_count,
+                                    const AElementwiseOperation a_element_op,
+                                    const BElementwiseOperation b_element_op,
+                                    const CElementwiseOperation c_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();

-#if 1
-    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        if(block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_ &&
-           i < group_count)
-        {
-            auto group_id = i;
-            GridwiseGemm::template Run<HasMainKBlockLoop>(
-                gemm_descs[group_id].a_ptr,
-                gemm_descs[group_id].b_ptr,
-                gemm_descs[group_id].c_ptr,
-                p_shared,
-                gemm_descs[group_id].a_grid_desc_k0_m_k1_,
-                gemm_descs[group_id].b_grid_desc_k0_n_k1_,
-                gemm_descs[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                a_element_op,
-                b_element_op,
-                c_element_op,
-                gemm_descs[group_id].grouped_gemm_block_2_ctile_map_);
-        }
-    });
-#else
-    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_descs);
+    const auto gemm_desc_ptr =
+        reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));

     index_t group_id = 0;

-    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        group_id = (block_id >= gemm_descs[i].BlockStart && block_id < gemm_descs[i].BlockEnd &&
-                    i < group_count)
-                       ? i
-                       : group_id;
-    });
+    for(index_t i = 0; i < group_count; i++)
+    {
+        group_id =
+            (block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_)
+                ? i
+                : group_id;
+    }

-    const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
     GridwiseGemm::template Run<HasMainKBlockLoop>(
         gemm_desc_ptr[group_id].a_ptr,
@@ -87,11 +63,9 @@ __global__ void
         a_element_op,
         b_element_op,
         c_element_op,
-        gemm_desc_ptr[group_id].block_2_ctile_map_,
-        block_id_grp);
-#endif
+        gemm_desc_ptr[group_id].grouped_gemm_block_2_ctile_map_);
 #else
-    ignore = gemm_descs;
+    ignore = gemm_descs_const;
     ignore = group_count;
     ignore = a_element_op;
     ignore = b_element_op;
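The rewritten kernel receives its per-group descriptors through a pointer into constant address space rather than a by-value StaticallyIndexedArray, so the group count is no longer capped by a compile-time MaxGroupCount and the group lookup becomes a plain runtime loop. Below is a minimal stand-alone HIP sketch of the same host/device handshake; the trimmed GemmDesc (block-range fields only), the kernel name, and the plain hipMemcpy/hipLaunchKernelGGL path are illustrative stand-ins, and the CK address-space casts are elided.

#include <hip/hip_runtime.h>
#include <vector>

// Hypothetical, trimmed-down descriptor: only the block-range fields used for group lookup.
struct GemmDesc
{
    int BlockStart_;
    int BlockEnd_;
};

__global__ void grouped_kernel(const GemmDesc* gemm_descs, int group_count)
{
    const int block_id = blockIdx.x;

    // Same linear search as the kernel above: find the group that owns this block.
    int group_id = 0;
    for(int i = 0; i < group_count; i++)
        group_id = (block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_)
                       ? i
                       : group_id;
    // ... per-group GEMM work would go here ...
}

int main()
{
    std::vector<GemmDesc> h_descs{{0, 4}, {4, 10}}; // two groups covering 10 blocks

    GemmDesc* d_descs = nullptr;
    hipMalloc(&d_descs, h_descs.size() * sizeof(GemmDesc));
    hipMemcpy(d_descs, h_descs.data(), h_descs.size() * sizeof(GemmDesc), hipMemcpyHostToDevice);

    hipLaunchKernelGGL(grouped_kernel, dim3(10), dim3(64), 0, 0, d_descs, int(h_descs.size()));
    hipDeviceSynchronize();
    hipFree(d_descs);
    return 0;
}

The search does not introduce divergence: every thread in a block computes the same group_id, so the loop amounts to a few uniform compares per block.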
@@ -388,6 +362,8 @@ struct DeviceGroupedGemmXdl
         {
             grid_size_ = 0;

+            gemm_descs_args_workspace_ = nullptr;
+
             group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());

             if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) &&
@@ -461,6 +437,8 @@ struct DeviceGroupedGemmXdl
         std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;

+        void* gemm_descs_args_workspace_;
+
         index_t grid_size_;
     };
@@ -471,49 +449,49 @@ struct DeviceGroupedGemmXdl
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_args;
-
             bool has_main_k_block_loop = true;

-            static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-                if(i < arg.gemm_desc_kernel_arg_.size())
-                {
-                    gemm_desc_kernel_args(i) = arg.gemm_desc_kernel_arg_[i];
-
-                    std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
-                              << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
-                              << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
-                              << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";
-                    std::cout << ", arg.b_grid_desc_k0_n_k1_{"
-                              << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
-                              << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
-                              << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";
-                    std::cout << ", arg.c_grid_desc_m_n_{ "
-                              << gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I0) << ", "
-                              << gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I1) << "}"
-                              << std::endl;
-
-                    if(!GridwiseGemm::CheckValidity(
-                           gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_,
-                           gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_,
-                           gemm_desc_kernel_args[i].c_grid_desc_m_n_,
-                           gemm_desc_kernel_args[i].grouped_gemm_block_2_ctile_map_))
-                    {
-                        throw std::runtime_error(
-                            "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
-                    }
-
-                    const auto K = gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
-                                   gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2);
-
-                    if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
-                    {
-                        throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
-                    }
-                }
-            });
+            for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
+            {
+                std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
+                          << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
+                          << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
+                          << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";
+                std::cout << ", arg.b_grid_desc_k0_n_k1_{"
+                          << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
+                          << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
+                          << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";
+                std::cout << ", arg.c_grid_desc_m_n_{ "
+                          << arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I0) << ", "
+                          << arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I1) << "}"
+                          << std::endl;
+
+                if(!GridwiseGemm::CheckValidity(
+                       arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
+                       arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
+                       arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_,
+                       arg.gemm_desc_kernel_arg_[i].grouped_gemm_block_2_ctile_map_))
+                {
+                    throw std::runtime_error(
+                        "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
+                }
+
+                const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
+                               arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2);
+
+                if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
+                {
+                    throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
+                }
+            }
+
+            hipGetErrorString(
+                hipMemcpy(arg.gemm_descs_args_workspace_,
+                          arg.gemm_desc_kernel_arg_.data(),
+                          arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
+                          hipMemcpyHostToDevice));

             float ave_time = 0;
@@ -523,23 +501,23 @@ struct DeviceGroupedGemmXdl
                     kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
                                                     ADataType, // TODO: distiguish A/B datatype
                                                     CDataType,
-                                                    remove_reference_t<GemmDescKernelArg>,
+                                                    GemmDescKernelArg,
                                                     AElementwiseOperation,
                                                     BElementwiseOperation,
                                                     CElementwiseOperation,
-                                                    true,
-                                                    MaxGroupCount>;
+                                                    true>;

-                ave_time = launch_and_time_kernel(stream_config,
-                                                  kernel,
-                                                  dim3(arg.grid_size_),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  gemm_desc_kernel_args,
-                                                  arg.gemm_desc_kernel_arg_.size(),
-                                                  arg.a_element_op_,
-                                                  arg.b_element_op_,
-                                                  arg.c_element_op_);
+                ave_time = launch_and_time_kernel(
+                    stream_config,
+                    kernel,
+                    dim3(arg.grid_size_),
+                    dim3(BlockSize),
+                    0,
+                    cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
+                    arg.gemm_desc_kernel_arg_.size(),
+                    arg.a_element_op_,
+                    arg.b_element_op_,
+                    arg.c_element_op_);
             }
             else
             {
@@ -547,23 +525,23 @@ struct DeviceGroupedGemmXdl
                     kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
                                                     ADataType, // TODO: distiguish A/B datatype
                                                     CDataType,
-                                                    remove_reference_t<GemmDescKernelArg>,
+                                                    GemmDescKernelArg,
                                                     AElementwiseOperation,
                                                     BElementwiseOperation,
                                                     CElementwiseOperation,
-                                                    false,
-                                                    MaxGroupCount>;
+                                                    false>;

-                ave_time = launch_and_time_kernel(stream_config,
-                                                  kernel,
-                                                  dim3(arg.grid_size_),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  gemm_desc_kernel_args,
-                                                  arg.gemm_desc_kernel_arg_.size(),
-                                                  arg.a_element_op_,
-                                                  arg.b_element_op_,
-                                                  arg.c_element_op_);
+                ave_time = launch_and_time_kernel(
+                    stream_config,
+                    kernel,
+                    dim3(arg.grid_size_),
+                    dim3(BlockSize),
+                    0,
+                    cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
+                    arg.gemm_desc_kernel_arg_.size(),
+                    arg.a_element_op_,
+                    arg.b_element_op_,
+                    arg.c_element_op_);
             }

             return ave_time;
@@ -652,6 +630,16 @@ struct DeviceGroupedGemmXdl
         return str.str();
     }

+    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
+    {
+        return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmDescKernelArg);
+    }
+
+    void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
+    {
+        dynamic_cast<Argument*>(p_arg)->gemm_descs_args_workspace_ = workspace_ptr;
+    }
 };

 } // namespace device
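GetWorkSpaceSize and SetWorkSpacePointer are the new contract for handing the op its device buffer before Run copies the kernel arguments into it. A hedged usage sketch follows; DeviceOp, Arg and Invoker are stand-ins for the concrete CK types, and error handling is omitted.

#include <hip/hip_runtime.h>

// Sketch only, not the real CK driver code: shows the intended call order for
// the new workspace hooks.
template <typename DeviceOp, typename Arg, typename Invoker>
float run_with_workspace(DeviceOp& op, Arg* p_arg, Invoker& invoker)
{
    void* workspace = nullptr;

    hipMalloc(&workspace, op.GetWorkSpaceSize(p_arg)); // group_count_ * sizeof(GemmDescKernelArg)
    op.SetWorkSpacePointer(p_arg, workspace);          // stored in gemm_descs_args_workspace_

    const float ave_time = invoker.Run(*p_arg); // Run() copies the descs into the workspace

    hipFree(workspace);
    return ave_time;
}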
...
@@ -348,8 +348,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
             if constexpr(use_multiblock)
             {
-                const auto zeroVal =
-                    ck::reduce::GetReductionZeroValueForInMemoryDataOperation<OutDataType>(
+                const auto identityVal =
+                    ck::reduce::GetIdentityValueueForInMemoryDataOperation<OutDataType>(
                         OutMemoryDataOperation);

                 const auto kernel_pre =
@@ -362,7 +362,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
                     0,
                     out_grid_desc_m_2,
                     arg.out_dev_,
-                    zeroVal);
+                    identityVal);
             };

         avg_time += launch_and_time_kernel(stream_config,
...
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
 #pragma once

 #include "data_type.hpp"

@@ -5,14 +30,22 @@ namespace ck {
 namespace tensor_operation {
 namespace binary_element_wise {

-struct Add
+template <typename Y, typename X1, typename X2>
+struct Add;
+
+template <>
+struct Add<double, double, double>
 {
     __host__ __device__ constexpr void
     operator()(double& dst, const double& src1, const double& src2) const
     {
         dst = src1 + src2;
     }
+};
+
+template <>
+struct Add<float, float, float>
+{
     __host__ __device__ constexpr void
     operator()(float& dst, const float& src1, const float& src2) const
     {
@@ -20,6 +53,75 @@ struct Add
     }
 };
template <>
struct Add<half_t, half_t, half_t>
{
__host__ __device__ constexpr void
operator()(half_t& dst, const half_t& src1, const half_t& src2) const
{
dst = src1 + src2;
}
};
template <>
struct Add<bhalf_t, bhalf_t, bhalf_t>
{
__host__ __device__ constexpr void
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
{
const float x1 = ck::type_convert<float>(src1);
const float x2 = ck::type_convert<float>(src2);
const float y = x1 + x2;
dst = ck::type_convert<bhalf_t>(y);
}
};
template <typename Y, typename X1, typename X2>
struct Substract;
template <>
struct Substract<double, double, double>
{
__host__ __device__ constexpr void
operator()(double& dst, const double& src1, const double& src2) const
{
dst = src1 - src2;
}
};
template <>
struct Substract<float, float, float>
{
__host__ __device__ constexpr void
operator()(float& dst, const float& src1, const float& src2) const
{
dst = src1 - src2;
}
};
template <>
struct Substract<half_t, half_t, half_t>
{
__host__ __device__ constexpr void
operator()(half_t& dst, const half_t& src1, const half_t& src2) const
{
dst = src1 - src2;
}
};
template <>
struct Substract<bhalf_t, bhalf_t, bhalf_t>
{
__host__ __device__ constexpr void
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
{
const float x1 = ck::type_convert<float>(src1);
const float x2 = ck::type_convert<float>(src2);
const float y = x1 - x2;
dst = ck::type_convert<bhalf_t>(y);
}
};
} // namespace binary_element_wise
} // namespace tensor_operation
} // namespace ck
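Each binary functor is now a primary template specialized per (Y, X1, X2) type triple, with the bhalf_t specializations converting through float before operating. A small host-only instance of the same pattern, limited to float and double so it compiles without the CK type aliases:

#include <cassert>

// Primary template plus per-type specializations, mirroring the shape above.
template <typename Y, typename X1, typename X2>
struct Add;

template <>
struct Add<float, float, float>
{
    constexpr void operator()(float& dst, const float& src1, const float& src2) const
    {
        dst = src1 + src2;
    }
};

template <>
struct Add<double, double, double>
{
    constexpr void operator()(double& dst, const double& src1, const double& src2) const
    {
        dst = src1 + src2;
    }
};

int main()
{
    float yf = 0.f;
    Add<float, float, float>{}(yf, 1.5f, 2.5f);
    assert(yf == 4.f);

    double yd = 0.0;
    Add<double, double, double>{}(yd, 1.25, 2.25);
    assert(yd == 3.5);
    return 0;
}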
 #pragma once

 #include "data_type.hpp"
+#include "math_v2.hpp"

 namespace ck {
 namespace tensor_operation {

@@ -143,6 +144,24 @@ struct AddHardswishAdd
     }
 };
struct Normalize
{
Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {}
__host__ __device__ constexpr void operator()(float& y,
const float& x,
const float& mean,
const float& mean_square,
const float& gamma,
const float& beta) const
{
float variance = mean_square - (mean * mean);
y = ((x - mean) / sqrtf(variance + epsilon_)) * gamma + beta;
}
float epsilon_;
};
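Normalize recovers the variance from the two running moments as E[x^2] - E[x]^2, so a fused layernorm only needs the mean and the mean-of-squares rather than a second pass over x. A host-only restatement with a worked value (for x drawn from {1, 3}: mean = 2, mean_square = 5, variance = 1):

#include <cmath>
#include <cstdio>

// Host restatement of the Normalize functor's arithmetic.
float normalize(float x, float mean, float mean_square, float gamma, float beta,
                float epsilon = 1e-4f)
{
    const float variance = mean_square - mean * mean; // E[x^2] - E[x]^2
    return ((x - mean) / std::sqrt(variance + epsilon)) * gamma + beta;
}

int main()
{
    std::printf("%f\n", normalize(3.f, 2.f, 5.f, /*gamma=*/1.f, /*beta=*/0.f)); // ~0.99995
    return 0;
}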
 // Unary operators are usually called element-wisely before/after the reduction is executed on the
 // elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2

@@ -278,7 +297,7 @@ struct UnaryAbs<float, float>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(float& y, const float& x) const { y = abs(x); };
+    __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); };
 };

 template <>
@@ -286,7 +305,7 @@ struct UnaryAbs<half_t, half_t>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = __habs(x); };
+    __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); };
 };

 template <>
@@ -294,7 +313,7 @@ struct UnaryAbs<double, double>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); };
+    __host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); };
 };

 template <>
@@ -302,12 +321,7 @@ struct UnaryAbs<int8_t, int8_t>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const
-    {
-        int8_t sgn = x >> (8 - 1);
-
-        y = (x ^ sgn) - sgn;
-    };
+    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); };
 };

 template <typename Y, typename X>
@@ -318,7 +332,7 @@ struct UnarySqrt<float, float>
 {
     __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(float& y, const float& x) const { y = sqrtf(x); };
+    __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::sqrt(x); };
 };

 template <>
@@ -326,7 +340,10 @@ struct UnarySqrt<double, double>
 {
     __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(double& y, const double& x) const { y = sqrt(x); };
+    __host__ __device__ void operator()(double& y, const double& x) const
+    {
+        y = ck::math::sqrt(x);
+    };
 };

 template <typename Y, typename X>
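Routing every UnaryAbs/UnarySqrt body through ck::math means the functor no longer chooses between abs, __habs, sqrt and sqrtf itself; overload resolution inside one namespace picks the right host or device implementation per type. A miniature of the dispatch idea in plain C++, with the HIP intrinsics left out:

#include <cassert>
#include <cmath>

// One namespace owns the per-type overloads, so a templated functor can write
// math::abs(x) without knowing the concrete type.
namespace math {
inline float abs(float x) { return std::abs(x); }
inline double abs(double x) { return std::abs(x); }
inline int abs(int x) { return x < 0 ? -x : x; }
} // namespace math

template <typename T>
struct UnaryAbs
{
    void operator()(T& y, const T& x) const { y = math::abs(x); }
};

int main()
{
    float yf = 0.f;
    UnaryAbs<float>{}(yf, -2.5f);
    assert(yf == 2.5f);

    int yi = 0;
    UnaryAbs<int>{}(yi, -7);
    assert(yi == 7);
    return 0;
}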
...
@@ -171,7 +171,7 @@ struct GridwiseReduction_mk_to_m_multiblock
                                   AccDataType beta,
                                   OutDataType* const __restrict__ p_out_value_global)
     {
-        const auto zeroVal = ReduceOperation::GetReductionZeroVal();
+        const auto identityVal = ReduceOperation::GetIdentityValue();

         // LDS
         __shared__ AccDataType p_reduce_work_buffer[BlockSize];
@@ -179,7 +179,7 @@ struct GridwiseReduction_mk_to_m_multiblock
         const auto in_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                           in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(zeroVal));
+                                                          type_convert<InDataType>(identityVal));
         auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
@@ -191,7 +191,7 @@ struct GridwiseReduction_mk_to_m_multiblock
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });

         const index_t thread_local_id = get_thread_local_1d_id();
         const index_t block_global_id = get_block_1d_id();
@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock
         __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
         __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];

-        const auto zeroVal = ReduceOperation::GetReductionZeroVal();
+        const auto identityVal = ReduceOperation::GetIdentityValue();

         const auto in_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                           in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(zeroVal));
+                                                          type_convert<InDataType>(identityVal));
         const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());

         auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -418,7 +418,7 @@ struct GridwiseReduction_mk_to_m_multiblock
                                                thread_k_cluster_id * KThreadSliceSize));

         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            accu_value_buf(I) = zeroVal;
+            accu_value_buf(I) = identityVal;
             accu_index_buf(I) = 0;
         });
@@ -459,7 +459,7 @@ struct GridwiseReduction_mk_to_m_multiblock
                                                    in_thread_idx_buf);

             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                AccDataType tmpValue = zeroVal;
+                AccDataType tmpValue = identityVal;
                 IndexDataType tmpIndex = 0;

                 static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
@@ -512,7 +512,7 @@ struct GridwiseReduction_mk_to_m_multiblock
                         in_thread_val_buf(Number<offset>{}));
                 });

-                AccDataType tmpValue = zeroVal;
+                AccDataType tmpValue = identityVal;
                 IndexDataType tmpIndex = 0;

                 static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
...
@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise
                                                            ReduceOperation,
                                                            PropagateNan>;

-        const auto zeroVal = ReduceOperation::GetReductionZeroVal();
+        const auto identityVal = ReduceOperation::GetIdentityValue();

         const auto in_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                           in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(zeroVal));
+                                                          type_convert<InDataType>(identityVal));
         auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
@@ -149,7 +149,7 @@ struct GridwiseReduction_mk_to_m_threadwise
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });

         const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise
         (void)acc_elementwise_op;

-        const auto zeroVal = ReduceOperation::GetReductionZeroVal();
+        const auto identityVal = ReduceOperation::GetIdentityValue();

         const auto in_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                           in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(zeroVal));
+                                                          type_convert<InDataType>(identityVal));
         const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
@@ -303,7 +303,7 @@ struct GridwiseReduction_mk_to_m_threadwise
         StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;

         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            accu_value_buf(I) = zeroVal;
+            accu_value_buf(I) = identityVal;
             accu_index_buf(I) = 0;
         });
...
#pragma once
#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <typename Gridwise5AryEltwise,
typename ADataType,
typename BDataType,
typename CDataType,
typename DDataType,
typename EDataType,
typename FDataType,
typename AGridDesc_M,
typename BGridDesc_M,
typename CGridDesc_M,
typename DGridDesc_M,
typename EGridDesc_M,
typename FGridDesc_M,
typename ElementwiseFunctor>
__global__ void kernel_5ary_elementwise_1d(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
const CDataType* __restrict__ p_c_global,
const DDataType* __restrict__ p_d_global,
const EDataType* __restrict__ p_e_global,
FDataType* __restrict__ p_f_global,
const AGridDesc_M a_grid_desc_m,
const BGridDesc_M b_grid_desc_m,
const CGridDesc_M c_grid_desc_m,
const DGridDesc_M d_grid_desc_m,
const EGridDesc_M e_grid_desc_m,
const FGridDesc_M f_grid_desc_m,
const ElementwiseFunctor functor)
{
Gridwise5AryEltwise::Run(p_a_global,
p_b_global,
p_c_global,
p_d_global,
p_e_global,
p_f_global,
a_grid_desc_m,
b_grid_desc_m,
c_grid_desc_m,
d_grid_desc_m,
e_grid_desc_m,
f_grid_desc_m,
functor);
}
// TODO - implement n-ary Elementwise_1D, tuple of inputs and tuple of outputs
template <typename ADataType,
typename BDataType,
typename CDataType,
typename DDataType,
typename EDataType,
typename FDataType,
typename ComputeDataType,
typename AGridDesc_M,
typename BGridDesc_M,
typename CGridDesc_M,
typename DGridDesc_M,
typename EGridDesc_M,
typename FGridDesc_M,
typename ElementwiseFunctor,
index_t MPerThread,
index_t AScalarPerVector,
index_t BScalarPerVector,
index_t CScalarPerVector,
index_t DScalarPerVector,
index_t EScalarPerVector,
index_t FScalarPerVector>
struct Gridwise5AryElementwise_1D
{
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));
using PassThrough = tensor_operation::element_wise::PassThrough;
static __device__ auto CalculateElementwiseIndex()
{
const index_t global_thread_id = get_thread_global_1d_id();
return make_multi_index(global_thread_id * MPerThread);
}
__device__ static void Run(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
const CDataType* __restrict__ p_c_global,
const DDataType* __restrict__ p_d_global,
const EDataType* __restrict__ p_e_global,
FDataType* __restrict__ p_f_global,
const AGridDesc_M a_grid_desc_m,
const BGridDesc_M b_grid_desc_m,
const CGridDesc_M c_grid_desc_m,
const DGridDesc_M d_grid_desc_m,
const EGridDesc_M e_grid_desc_m,
const FGridDesc_M f_grid_desc_m,
const ElementwiseFunctor functor)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_grid_desc_m.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_grid_desc_m.GetElementSpaceSize());
const auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_grid_desc_m.GetElementSpaceSize());
const auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_grid_desc_m.GetElementSpaceSize());
const auto e_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_global, e_grid_desc_m.GetElementSpaceSize());
auto f_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_f_global, f_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> d_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> e_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> f_thread_buf;
const auto thread_store_global_offset = CalculateElementwiseIndex();
auto a_global_load =
ThreadwiseTensorSliceTransfer_v2<ADataType,
ComputeDataType,
AGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
AScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{a_grid_desc_m, thread_store_global_offset};
auto b_global_load =
ThreadwiseTensorSliceTransfer_v2<BDataType,
ComputeDataType,
BGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
BScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{b_grid_desc_m, thread_store_global_offset};
auto c_global_load =
ThreadwiseTensorSliceTransfer_v2<CDataType,
ComputeDataType,
CGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
CScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{c_grid_desc_m, thread_store_global_offset};
auto d_global_load =
ThreadwiseTensorSliceTransfer_v2<DDataType,
ComputeDataType,
DGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
DScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{d_grid_desc_m, thread_store_global_offset};
auto e_global_load =
ThreadwiseTensorSliceTransfer_v2<EDataType,
ComputeDataType,
EGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
EScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{e_grid_desc_m, thread_store_global_offset};
auto f_global_write =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
FDataType,
decltype(thread_desc_m),
FGridDesc_M,
PassThrough,
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
FScalarPerVector, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
false>{
f_grid_desc_m, thread_store_global_offset, PassThrough{}};
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto M = c_grid_desc_m.GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * MPerThread;
const auto loop_step_index = make_multi_index(loop_step);
index_t num_iter = M / (loop_step);
do
{
// read and process MPerThread elements
a_global_load.Run(
a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);
b_global_load.Run(
b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);
c_global_load.Run(
c_grid_desc_m, c_global_buf, thread_desc_m, make_tuple(I0), c_thread_buf);
d_global_load.Run(
d_grid_desc_m, d_global_buf, thread_desc_m, make_tuple(I0), d_thread_buf);
e_global_load.Run(
e_grid_desc_m, e_global_buf, thread_desc_m, make_tuple(I0), e_thread_buf);
static_for<0, MPerThread, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));
functor(f_thread_buf(Number<offset>{}),
a_thread_buf(Number<offset>{}),
b_thread_buf(Number<offset>{}),
c_thread_buf(Number<offset>{}),
d_thread_buf(Number<offset>{}),
e_thread_buf(Number<offset>{}));
});
f_global_write.Run(thread_desc_m,
make_tuple(I0), // SrcSliceOriginIdx
f_thread_buf,
f_grid_desc_m,
f_global_buf);
a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
c_global_load.MoveSrcSliceWindow(c_grid_desc_m, loop_step_index);
d_global_load.MoveSrcSliceWindow(d_grid_desc_m, loop_step_index);
e_global_load.MoveSrcSliceWindow(e_grid_desc_m, loop_step_index);
f_global_write.MoveDstSliceWindow(f_grid_desc_m, loop_step_index);
} while(--num_iter);
}
};
} // namespace ck
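This gridwise kernel exists to fuse five-input elementwise ops such as the Normalize functor above (x, mean, mean_square, gamma, beta in, y out). Stripped of the CK descriptor and slice-transfer machinery, the access pattern reduces to the sketch below; the kernel name is hypothetical, and unlike the do-while loop above it adds bounds checks instead of assuming M divides evenly by the loop step.

#include <hip/hip_runtime.h>

// Hedged sketch of what Gridwise5AryElementwise_1D computes: each thread walks
// the flat arrays in grid-sized strides of MPerThread elements and applies a
// five-input functor per element.
template <int MPerThread, typename Functor>
__global__ void naive_5ary_elementwise(const float* a, const float* b, const float* c,
                                       const float* d, const float* e, float* f,
                                       int m, Functor functor)
{
    const int loop_step = gridDim.x * blockDim.x * MPerThread;

    for(int base = (blockIdx.x * blockDim.x + threadIdx.x) * MPerThread; base < m;
        base += loop_step)
    {
        for(int j = 0; j < MPerThread && base + j < m; ++j)
        {
            const int i = base + j;
            functor(f[i], a[i], b[i], c[i], d[i], e[i]);
        }
    }
}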
@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename DxsInElementwiseOperation,
-          typename DxsOutElementwiseOperation,
+          typename DxsAccElementwiseOperation,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -41,7 +41,7 @@ __global__ void
             const BElementwiseOperation b_element_op,
             const CElementwiseOperation c_element_op,
             const DxsInElementwiseOperation dxs_in_element_op,
-            const DxsOutElementwiseOperation dxs_out_element_op,
+            const DxsAccElementwiseOperation dxs_out_element_op,
             const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
             const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
             const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -96,7 +96,7 @@ template <typename FloatAB,
           typename CElementwiseOperation,
           typename DxsReduceOperation,
           typename DxsInElementwiseOperation,
-          typename DxsOutElementwiseOperation,
+          typename DxsAccElementwiseOperation,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename DGlobalMemoryDataOperation,
           typename AGridDesc_AK0_M_AK1,
@@ -329,7 +329,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                                const BElementwiseOperation& b_element_op,
                                const CElementwiseOperation& c_element_op,
                                const DxsInElementwiseOperation& dxs_in_element_op,
-                               const DxsOutElementwiseOperation& dxs_out_element_op,
+                               const DxsAccElementwiseOperation& dxs_out_element_op,
                                const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                                const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                                const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
@@ -816,10 +816,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                         false>;

                 // Global write Gemm shuffle + reduction
-                const auto d_zeroVal = DReduceOperation::GetReductionZeroVal();
+                const auto d_identityVal = DReduceOperation::GetIdentityValue();

                 static_for<0, mreduce_per_thread, 1>{}(
-                    [&](auto I) { d_thread_buf(I) = d_zeroVal; });
+                    [&](auto I) { d_thread_buf(I) = d_identityVal; });

                 // reduce in VGPR
                 static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
...
@@ -340,7 +340,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

-    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap>
     __device__ static void
     Run(const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
...
@@ -3,11 +3,13 @@
 #include <cmath>

 #include "data_type.hpp"
-#include "half.hpp"
+#include "type.hpp"

 namespace ck {
 namespace math {

+// math functions for the host, some are implemented by calling C++ std functions
 static inline __host__ float abs(float x) { return std::abs(x); };

 static inline __host__ double abs(double x) { return std::abs(x); };
@@ -28,26 +30,26 @@ static inline __host__ int32_t abs(int32_t x)
 static inline __host__ half_t abs(half_t x)
 {
-    half_float::half xx = *reinterpret_cast<half_float::half*>(&x);
-
-    half_float::half abs_xx = half_float::abs(xx);
-
-    half_t abs_x = *reinterpret_cast<half_t*>(&abs_xx);
+    uint16_t xx = ck::bit_cast<uint16_t>(x);
+
+    uint16_t abs_xx = xx & 0x7fff;
+
+    half_t abs_x = ck::bit_cast<half_t>(abs_xx);

     return abs_x;
 };

-static inline __host__ float isnan(float x) { return std::isnan(x); };
+static inline __host__ bool isnan(float x) { return std::isnan(x); };

-static inline __host__ double isnan(double x) { return std::isnan(x); };
+static inline __host__ bool isnan(double x) { return std::isnan(x); };

-static inline __host__ int8_t isnan(int8_t x)
+static inline __host__ bool isnan(int8_t x)
 {
     (void)x;
     return false;
 };

-static inline __host__ int32_t isnan(int32_t x)
+static inline __host__ bool isnan(int32_t x)
 {
     (void)x;
     return false;
@@ -55,11 +57,59 @@ static inline __host__ int32_t isnan(int32_t x)
 static inline __host__ bool isnan(half_t x)
 {
-    half_float::half xx = *reinterpret_cast<half_float::half*>(&x);
-
-    return half_float::isnan(xx);
+    uint16_t xx = ck::bit_cast<uint16_t>(x);
+
+    return (xx & 0x7FFF) > 0x7C00;
+};
+
+static inline __host__ float sqrt(float x) { return std::sqrt(x); };
+
+static inline __host__ double sqrt(double x) { return std::sqrt(x); };
+
+// math functions for the HIP kernel, some are implemented by calling hip builtin functions
+static inline __device__ float abs(float x) { return ::abs(x); };
+
+static inline __device__ double abs(double x) { return ::abs(x); };
+
+static inline __device__ int8_t abs(int8_t x)
+{
+    int8_t sgn = x >> (8 - 1);
+
+    return (x ^ sgn) - sgn;
+};
+
+static inline __device__ int32_t abs(int32_t x)
+{
+    int32_t sgn = x >> (32 - 1);
+
+    return (x ^ sgn) - sgn;
+};
+
+static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
+
+static inline __device__ bool isnan(float x) { return ::isnan(x); };
+
+static inline __device__ bool isnan(double x) { return ::isnan(x); };
+
+static inline __device__ bool isnan(int8_t x)
+{
+    (void)x;
+    return false;
+};
+
+static inline __device__ bool isnan(int32_t x)
+{
+    (void)x;
+    return false;
 };

+static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
+
+static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
+
+static inline __device__ double sqrt(double x) { return ::sqrt(x); };
+
 } // namespace math
 } // namespace ck
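Dropping the half.hpp dependency turns the binary16 helpers into two bit tricks: clearing the sign bit yields abs, and a value is NaN exactly when its magnitude bits exceed those of +inf (exponent all ones, mantissa non-zero). A host-only check with uint16_t standing in for the raw half_t bits:

#include <cassert>
#include <cstdint>

inline uint16_t half_bits_abs(uint16_t bits) { return bits & 0x7fff; } // clear the sign bit

inline bool half_bits_isnan(uint16_t bits)
{
    // NaN: exponent all ones and mantissa non-zero, i.e. the magnitude bits
    // compare strictly greater than +inf's 0x7c00.
    return (bits & 0x7fff) > 0x7c00;
}

int main()
{
    assert(half_bits_abs(0xC000) == 0x4000); // -2.0h -> 2.0h
    assert(!half_bits_isnan(0x7C00));        // +inf is not NaN
    assert(half_bits_isnan(0x7E00));         // canonical quiet NaN
    return 0;
}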
...
@@ -27,6 +27,7 @@
 #define CK_REDUCTION_FUNCTIONS_BINOP_HPP

 #include "data_type.hpp"
+#include "math_v2.hpp"
 #include "reduction_common.hpp"
 #include "reduction_operator.hpp"
@@ -34,18 +35,6 @@
 namespace ck {
 namespace detail {

-template <typename T>
-static inline __device__ bool is_nan(T x)
-{
-    return (isnan(x));
-};
-
-template <>
-inline __device__ bool is_nan<half_t>(half_t x)
-{
-    return (__hisnan(x));
-};
-
 template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
 struct AccumulateWithNanCheck;
@@ -53,7 +42,7 @@ template <typename ReduceOperation, typename AccDataType>
 struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
 {
     // cppcheck-suppress constParameter
-    __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
+    __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
     {
         ReduceOperation{}(accuVal, currVal);
     };
@@ -62,9 +51,11 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
 template <typename ReduceOperation, typename AccDataType>
 struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
 {
-    __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
+    __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
     {
-        if(is_nan(currVal))
+        using ck::math::isnan;
+
+        if(isnan(currVal))
         {
             accuVal = currVal;
         }
@@ -81,7 +72,7 @@ struct AccumulateWithIndexAndNanCheck;
 template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
 struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType>
 {
-    __device__ static inline void
+    __host__ __device__ static inline void
     // cppcheck-suppress constParameter
     Calculate(AccDataType& accuVal,
               AccDataType currVal,
@@ -101,12 +92,14 @@ template <typename ReduceOperation, typename AccDataType, typename IndexDataType
 struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType>
 {
     // The method is called when the ReduceOperation is indexable and the user asked for indices
-    __device__ static inline void Calculate(AccDataType& accuVal,
-                                            AccDataType currVal,
-                                            IndexDataType& accuIndex,
-                                            IndexDataType currIndex)
+    __host__ __device__ static inline void Calculate(AccDataType& accuVal,
+                                                     AccDataType currVal,
+                                                     IndexDataType& accuIndex,
+                                                     IndexDataType currIndex)
     {
-        if(is_nan(currVal))
+        using ck::math::isnan;
+
+        if(isnan(currVal))
         {
             accuVal = currVal;
             accuIndex = currIndex;
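With the local is_nan helper gone and the Calculate methods marked __host__ __device__, host reference code and kernels share one accumulate-with-NaN-check path through ck::math::isnan. A host-side restatement of the NaN-propagating branch, with a stand-in max functor:

#include <cassert>
#include <cmath>

struct MaxOp
{
    void operator()(float& a, float b) const { a = a < b ? b : a; }
};

// Mirrors AccumulateWithNanCheck<true, ReduceOperation, float>::Calculate.
template <typename ReduceOperation>
void accumulate_with_nan_check(float& accuVal, float currVal)
{
    if(std::isnan(currVal))
        accuVal = currVal; // propagate the NaN instead of comparing against it
    else
        ReduceOperation{}(accuVal, currVal);
}

int main()
{
    float acc = 1.0f;
    accumulate_with_nan_check<MaxOp>(acc, 5.0f);
    assert(acc == 5.0f);

    accumulate_with_nan_check<MaxOp>(acc, std::nanf(""));
    assert(std::isnan(acc)); // once seen, NaN sticks
    return 0;
}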
...
@@ -36,7 +36,7 @@ namespace reduce {
 // Every binary operator used in reduction is represented by a templated functor class. Each functor
 // class must provide at least
 // three members:
-// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary
+// 1) GetIdentityValue() -- the interface to return the "identity element" for the binary
 // operator, "identity element" is the unique
 // element in the algebraic space that doesn't affect the value of other elements
 // when operated against them, and the concept is similar to zero vector in
@@ -59,7 +59,7 @@ struct Add
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
+    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };

     __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
@@ -76,7 +76,7 @@ struct Mul
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); };
+    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(1.0f); };

     __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
@@ -92,7 +92,7 @@ struct Max
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal()
+    __host__ __device__ static constexpr T GetIdentityValue()
     {
         return NumericLimits<T>::Lowest();
     };
@@ -125,10 +125,7 @@ struct Min
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal()
-    {
-        return NumericLimits<T>::Max();
-    };
+    __host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits<T>::Max(); };

     __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
@@ -158,7 +155,7 @@ struct AMax
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
+    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };

     __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
@@ -184,7 +181,7 @@ struct AMax
 };

 template <typename T>
-T GetReductionZeroValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
+T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
 {
     T result = ck::type_convert<T>(0.0f);
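The rename makes the contract explicit: the returned value is the identity element of the reduction's monoid, so folding it into any accumulator is a no-op. A compact check of the four identities used above:

#include <cassert>
#include <limits>

// Stand-ins for the ck::reduce functors, reduced to their identity values.
template <typename T> struct Add { static constexpr T GetIdentityValue() { return T(0); } };
template <typename T> struct Mul { static constexpr T GetIdentityValue() { return T(1); } };
template <typename T> struct Max
{
    static constexpr T GetIdentityValue() { return std::numeric_limits<T>::lowest(); }
};
template <typename T> struct Min
{
    static constexpr T GetIdentityValue() { return std::numeric_limits<T>::max(); }
};

int main()
{
    const float x = 3.5f;
    assert(x + Add<float>::GetIdentityValue() == x); // 0 is ADD's identity
    assert(x * Mul<float>::GetIdentityValue() == x); // 1 is MUL's identity
    assert(x > Max<float>::GetIdentityValue());      // lowest() never wins a max
    assert(x < Min<float>::GetIdentityValue());      // max() never wins a min
    return 0;
}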
...
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GUARD_HOST_REDUCE_UTIL_HPP
#define GUARD_HOST_REDUCE_UTIL_HPP
#include <limits>
#include <cmath>
#include <functional>
#include "reduction_enums.hpp"
#include "data_type.hpp"
#include "math_v2.hpp"
namespace ck {
namespace host_reduce {
using ck::NanPropagation;
using ck::ReduceTensorOp;
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PreUnaryOpFn(int)
{
using ck::math::abs;
if constexpr(ReduceOpId == ReduceTensorOp::NORM1)
{
return ([&](AccDataType& a_) { a_ = abs(a_); });
}
else if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
{
return ([&](AccDataType& a_) { a_ = a_ * a_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp::AMAX)
{
return ([&](AccDataType& a_) { a_ = abs(a_); });
}
else
{
// ReduceTensorOp::AVG:
// ReduceTensorOp::ADD:
// ReduceTensorOp::MUL:
// ReduceTensorOp::MIN:
// ReduceTensorOp::MAX:
return ([&](AccDataType&) {});
};
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PosUnaryOpFn(int32_t divider)
{
using std::sqrt;
if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
{
return ([&](AccDataType& a_) { a_ = sqrt(a_); });
}
else if constexpr(ReduceOpId == ReduceTensorOp::AVG)
{
return ([&, divider](AccDataType& a_) {
a_ = a_ / static_cast<AccDataType>(static_cast<float>(divider));
});
}
else
{
// ReduceTensorOp::ADD:
// ReduceTensorOp::NORM1:
// ReduceTensorOp::MUL:
// ReduceTensorOp::MIN:
// ReduceTensorOp::MAX:
// ReduceTensorOp::AMAX:
return ([&](AccDataType&) {});
}
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType)> ReduceOpFn()
{
if constexpr(ReduceOpId == ReduceTensorOp::ADD || ReduceOpId == ReduceTensorOp::AVG ||
ReduceOpId == ReduceTensorOp::NORM1 || ReduceOpId == ReduceTensorOp::NORM2)
{
return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ + b_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp::MUL)
{
return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ * b_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
{
return ([&](AccDataType& a_, AccDataType b_) {
if(a_ > b_)
a_ = b_;
});
}
else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
{
return ([&](AccDataType& a_, AccDataType b_) {
if(a_ < b_)
a_ = b_;
});
}
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType, bool& changed)> ReduceOpFn2()
{
if constexpr(ReduceOpId == ReduceTensorOp::MIN)
{
return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
if(a_ > b_)
{
a_ = b_;
changed = true;
}
else
changed = false;
});
}
else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
{
return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
if(a_ < b_)
{
a_ = b_;
changed = true;
}
else
changed = false;
});
}
else
{
// ReduceTensorOp::ADD:
// ReduceTensorOp::MUL:
// ReduceTensorOp::AVG:
// ReduceTensorOp::NORM1:
// ReduceTensorOp::NORM2:
return (std::function<void(AccDataType&, AccDataType, bool&)>{});
};
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline AccDataType ReduceOpZeroVal()
{
if constexpr(ReduceOpId == ReduceTensorOp::MUL)
{
return (static_cast<AccDataType>(1.0f));
}
else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
{
return (ck::NumericLimits<AccDataType>::Max());
}
else if constexpr(ReduceOpId == ReduceTensorOp::MAX)
{
return (ck::NumericLimits<AccDataType>::Lowest());
}
else if constexpr(ReduceOpId == ReduceTensorOp::AMAX)
{
return (static_cast<AccDataType>(0.0f));
}
else
{
// ReduceTensorOp::ADD
// ReduceTensorOp::AVG
// ReduceTensorOp::NORM1
// ReduceTensorOp::NORM2
return (static_cast<AccDataType>(0.0f));
};
};
template <typename AccDataType, bool PropagateNan>
__host__ static inline void
binop_with_nan_check(std::function<void(AccDataType&, AccDataType)> opReduce,
AccDataType& accuVal,
AccDataType currVal)
{
using ck::math::isnan;
if constexpr(!PropagateNan)
{
opReduce(accuVal, currVal);
}
else
{
if(isnan(currVal))
accuVal = currVal;
else
opReduce(accuVal, currVal);
};
};
template <typename AccDataType, typename IndexDataType, bool PropagateNan>
__host__ static inline void
binop_with_index_and_nan_check(std::function<void(AccDataType&, AccDataType, bool&)> opReduce,
AccDataType& accuVal,
AccDataType currVal,
IndexDataType& accuIndex,
IndexDataType currIndex)
{
using ck::math::isnan;
if constexpr(!PropagateNan)
{
bool changed;
opReduce(accuVal, currVal, changed);
if(changed)
accuIndex = currIndex;
}
else
{
if(isnan(currVal))
{
accuVal = currVal;
accuIndex = currIndex;
}
else
{
bool changed;
opReduce(accuVal, currVal, changed);
if(changed)
accuIndex = currIndex;
};
};
};
}; // namespace host_reduce
}; // namespace ck
#endif
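The helpers above compose as pre-op, reduce-op and post-op stages; for NORM2 that is square, sum, then square root. A self-contained host sketch of the composition, with plain lambdas standing in for the templated factories:

#include <cassert>
#include <cmath>
#include <functional>
#include <vector>

// NORM2 as pre/reduce/post stages, mirroring PreUnaryOpFn / ReduceOpFn / PosUnaryOpFn.
float host_norm2(const std::vector<float>& data)
{
    std::function<void(float&)> pre = [](float& a) { a = a * a; };
    std::function<void(float&, float)> reduce = [](float& a, float b) { a = a + b; };
    std::function<void(float&)> post = [](float& a) { a = std::sqrt(a); };

    float acc = 0.0f; // identity of the ADD-style reduction
    for(float v : data)
    {
        pre(v);
        reduce(acc, v);
    }
    post(acc);
    return acc;
}

int main()
{
    assert(host_norm2({3.0f, 4.0f}) == 5.0f); // sqrt(9 + 16)
    return 0;
}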
@@ -33,10 +33,10 @@
 #include "reduction_enums.hpp"
 #include "reduction_common.hpp"
-#include "host_reduce_util.hpp"
 #include "host_common_util.hpp"
 #include "host_tensor.hpp"
 #include "data_type.hpp"
+#include "reduction_functions_accumulate.hpp"

 template <int NDim>
 static void get_all_indexes(const std::array<size_t, NDim>& dimLengths,
@@ -106,11 +106,13 @@ static size_t get_offset_from_index(const std::vector<size_t>& strides,
 template <typename InDataType,
           typename AccDataType,
           typename OutDataType,
-          ck::ReduceTensorOp ReduceOpId,
+          typename ReduceOperation,
+          typename InElementwiseOperation,
+          typename AccElementwiseOperation,
           int Rank,
           int NumReduceDim,
           bool PropagateNan,
-          bool NeedIndices>
+          bool OutputIndex>
 struct ReductionHost
 {
     using IndexDataType = int32_t;
@@ -122,8 +124,6 @@ struct ReductionHost
     std::vector<int> reduceDims;
     IndexDataType divider;

-    std::function<void(AccDataType&)> preUnaryOp;
-    std::function<void(AccDataType&)> posUnaryOp;
     std::array<size_t, NumReduceDim> reduceLengths;
     std::array<size_t, NumReduceDim> reduceStrides;
     std::array<size_t, NumInvariantDim> invariantLengths;
@@ -137,9 +137,6 @@ struct ReductionHost
                   const std::vector<int>& invariantDims_,
                   const std::vector<int>& reduceDims_)
     {
-        using ck::host_reduce::PosUnaryOpFn;
-        using ck::host_reduce::PreUnaryOpFn;
-
         // this->outLengths = to_int_vector(outDesc.GetLengths());
         this->outStrides = outDesc.GetStrides();
@@ -171,9 +168,6 @@ struct ReductionHost
            invariant_dim_indexes.clear();
            get_all_indexes<NumInvariantDim>(invariantLengths, invariant_dim_indexes);
        };
-
-        preUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
-        posUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
    };

    void Run(float alpha,
@@ -182,7 +176,7 @@ struct ReductionHost
             OutDataType* out_data,
             IndexDataType* out_indices)
    {
-        if constexpr(NeedIndices)
+        if constexpr(OutputIndex)
        {
            RunImpl_with_index(alpha, in_data, beta, out_data, out_indices);
        }
@@ -201,15 +195,17 @@ struct ReductionHost
        using ck::float_equal_one;
        using ck::float_equal_zero;
        using ck::type_convert;
-        using ck::host_reduce::binop_with_index_and_nan_check;
-        using ck::host_reduce::ReduceOpFn2;
-        using ck::host_reduce::ReduceOpZeroVal;

-        auto opReduce2 = ReduceOpFn2<AccDataType, ReduceOpId>();
+        using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
+                                                                        ReduceOperation,
+                                                                        AccDataType,
+                                                                        IndexDataType>;
+        InElementwiseOperation in_elementwise_op(divider);
+        AccElementwiseOperation acc_elementwise_op(divider);

        if constexpr(NumInvariantDim == 0)
        {
-            AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+            AccDataType accuVal = ReduceOperation::GetIdentityValue();
            IndexDataType accuIndex = 0;

            for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
@@ -219,15 +215,14 @@ struct ReductionHost
                auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);

-                preUnaryOp(currVal);
+                in_elementwise_op(currVal, currVal);

                auto currIndex = static_cast<IndexDataType>(i);

-                binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
-                    opReduce2, accuVal, currVal, accuIndex, currIndex);
+                Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
            };

-            posUnaryOp(accuVal);
+            acc_elementwise_op(accuVal, accuVal);

            if(!float_equal_one{}(alpha))
                accuVal *= type_convert<AccDataType>(alpha);
@@ -241,7 +236,7 @@ struct ReductionHost
        else
        {
            auto thread_reduce_func = [&](auto invariant_index) {
-                AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+                AccDataType accuVal = ReduceOperation::GetIdentityValue();
                IndexDataType accuIndex = 0;

                auto offset_invariant =
@@ -255,15 +250,14 @@ struct ReductionHost
                auto currVal =
                    type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);

-                preUnaryOp(currVal);
+                in_elementwise_op(currVal, currVal);

                auto currIndex = static_cast<IndexDataType>(i);

-                binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
-                    opReduce2, accuVal, currVal, accuIndex, currIndex);
+                Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
            };

-            posUnaryOp(accuVal);
+            acc_elementwise_op(accuVal, accuVal);

            if(!float_equal_one{}(alpha))
                accuVal *= type_convert<AccDataType>(alpha);
@@ -308,15 +302,16 @@ struct ReductionHost
        using ck::float_equal_one;
        using ck::float_equal_zero;
        using ck::type_convert;
-        using ck::host_reduce::binop_with_nan_check;
-        using ck::host_reduce::ReduceOpFn;
-        using ck::host_reduce::ReduceOpZeroVal;

-        auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>();
+        using Accumulation =
+            ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
+        InElementwiseOperation in_elementwise_op(divider);
+        AccElementwiseOperation acc_elementwise_op(divider);

        if constexpr(NumInvariantDim == 0)
        {
-            AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+            AccDataType accuVal = ReduceOperation::GetIdentityValue();

            for(const auto& reduce_index : reduce_dim_indexes)
            {
@@ -325,12 +320,12 @@ struct ReductionHost
                auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);

-                preUnaryOp(currVal);
+                in_elementwise_op(currVal, currVal);

-                binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
+                Accumulation::Calculate(accuVal, currVal);
            };

-            posUnaryOp(accuVal);
+            acc_elementwise_op(accuVal, accuVal);

            if(!float_equal_one{}(alpha))
                accuVal *= type_convert<AccDataType>(alpha);
@@ -343,7 +338,7 @@ struct ReductionHost
        else
        {
            auto thread_reduce_func = [&](auto invariant_index) {
-                AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+                AccDataType accuVal = ReduceOperation::GetIdentityValue();

                auto offset_invariant =
                    get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
@@ -356,12 +351,12 @@ struct ReductionHost
                auto currVal =
                    type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);

-                preUnaryOp(currVal);
+                in_elementwise_op(currVal, currVal);

-                binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
+                Accumulation::Calculate(accuVal, currVal);
            };

-            posUnaryOp(accuVal);
+            acc_elementwise_op(accuVal, accuVal);

            if(!float_equal_one{}(alpha))
                accuVal *= type_convert<AccDataType>(alpha);
...
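// ---------------------------------------------------------------------------
// Editorial sketch (not part of the commit): the accumulation pattern the
// ReductionHost diff above migrates to. It relies only on the interfaces
// visible in the new code (ReduceOperation::GetIdentityValue(),
// ck::detail::AccumulateWithNanCheck<...>::Calculate(), and elementwise ops
// constructed from the divider); the function name and raw-pointer interface
// are assumptions made for illustration.
template <bool PropagateNan,
          typename ReduceOperation,
          typename AccDataType,
          typename InElementwiseOperation,
          typename AccElementwiseOperation>
AccDataType reduce_1d_sketch(const AccDataType* in, int num_elements, int divider)
{
    using Accumulation =
        ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

    InElementwiseOperation in_elementwise_op(divider);   // pre-op, e.g. square for NORM2
    AccElementwiseOperation acc_elementwise_op(divider); // post-op, e.g. divide for AVG

    AccDataType accuVal = ReduceOperation::GetIdentityValue();

    for(int i = 0; i < num_elements; i++)
    {
        AccDataType currVal = in[i];
        in_elementwise_op(currVal, currVal);       // transform each element first
        Accumulation::Calculate(accuVal, currVal); // then fold it into the accumulator
    }

    acc_elementwise_op(accuVal, accuVal); // finalize the accumulator
    return accuVal;
}
// ---------------------------------------------------------------------------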
@@ -38,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
 using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances =
     std::tuple<
         // clang-format off
-        //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+        //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
         //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
         //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
         //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
...
@@ -38,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
 using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances =
     std::tuple<
         // clang-format off
-        //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+        //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
         //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
         //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
         //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
...
@@ -38,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
 using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances =
     std::tuple<
         // clang-format off
-        //##################################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+        //##################################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
         //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
         //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
         //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
...
@@ -38,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
 using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances =
     std::tuple<
         // clang-format off
-        //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+        //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
         //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
         //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
         //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
...
@@ -24,10 +24,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using ReduceSum = ck::reduce::Add<F32>;
 using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;

+using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
 using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
 using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;

 using DInElementOps = ck::Tuple<Identity, Square>;
-using DOutElementOps = ck::Tuple<Identity, Identity>;
+using DOutElementOps = ck::Tuple<Div, Div>;

 using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
                                                           ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -37,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
 // c[m, n] = a[k, m] * b[k, n]
 using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = std::tuple<
     // clang-format off
-    //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+    //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
     //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
     //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
     //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
...
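// Editorial note (not part of the commit): with DInElementOps = {Identity,
// Square} and the per-reduction post-ops changed from {Identity, Identity} to
// {Div, Div}, the two auxiliary outputs become the mean and mean-square of
// each reduced row; the DxsAccEleOp column rename above reflects that the op
// is applied to the accumulated value. A scalar sketch, assuming
// UnaryIdentic<F32, F32, true> divides by the length passed at construction:
//
//   float mean = 0.f, meansq = 0.f;
//   for(int k = 0; k < N; ++k) { mean += c[k]; meansq += c[k] * c[k]; }
//   Div div(N);            // UnaryIdentic<F32, F32, true>{N}
//   div(mean, mean);       // mean   = mean   / N
//   div(meansq, meansq);   // meansq = meansq / N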