Commit f26fb605 authored by wangshaojie6
Browse files

Merge branch 'develop' into bwd_weight_bf16_splitk

parents 32d06c66 1677cf70
......@@ -24,57 +24,33 @@ template <typename GridwiseGemm,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop,
index_t MaxGroupCount>
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r3(
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_descs,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
kernel_grouped_gemm_xdlops_v2r3(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
#if 1
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_ &&
i < group_count)
{
auto group_id = i;
GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_descs[group_id].a_ptr,
gemm_descs[group_id].b_ptr,
gemm_descs[group_id].c_ptr,
p_shared,
gemm_descs[group_id].a_grid_desc_k0_m_k1_,
gemm_descs[group_id].b_grid_desc_k0_n_k1_,
gemm_descs[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_descs[group_id].grouped_gemm_block_2_ctile_map_);
}
});
#else
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_descs);
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
group_id = (block_id >= gemm_descs[i].BlockStart && block_id < gemm_descs[i].BlockEnd &&
i < group_count)
? i
: group_id;
});
const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
for(index_t i = 0; i < group_count; i++)
{
group_id =
(block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_)
? i
: group_id;
}
GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_desc_ptr[group_id].a_ptr,
......@@ -87,11 +63,9 @@ __global__ void
a_element_op,
b_element_op,
c_element_op,
gemm_desc_ptr[group_id].block_2_ctile_map_,
block_id_grp);
#endif
gemm_desc_ptr[group_id].grouped_gemm_block_2_ctile_map_);
#else
ignore = gemm_descs;
ignore = gemm_descs_const;
ignore = group_count;
ignore = a_element_op;
ignore = b_element_op;
......@@ -388,6 +362,8 @@ struct DeviceGroupedGemmXdl
{
grid_size_ = 0;
gemm_descs_args_workspace_ = nullptr;
group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) &&
......@@ -461,6 +437,8 @@ struct DeviceGroupedGemmXdl
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
void* gemm_descs_args_workspace_;
index_t grid_size_;
};
......@@ -471,49 +449,49 @@ struct DeviceGroupedGemmXdl
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_args;
bool has_main_k_block_loop = true;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(i < arg.gemm_desc_kernel_arg_.size())
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";
std::cout << ", arg.b_grid_desc_k0_n_k1_{"
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";
std::cout << ", arg.c_grid_desc_m_n_{ "
<< arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
if(!GridwiseGemm::CheckValidity(
arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_,
arg.gemm_desc_kernel_arg_[i].grouped_gemm_block_2_ctile_map_))
{
gemm_desc_kernel_args(i) = arg.gemm_desc_kernel_arg_[i];
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";
std::cout << ", arg.b_grid_desc_k0_n_k1_{"
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";
std::cout << ", arg.c_grid_desc_m_n_{ "
<< gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I0) << ", "
<< gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
if(!GridwiseGemm::CheckValidity(
gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_,
gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_,
gemm_desc_kernel_args[i].c_grid_desc_m_n_,
gemm_desc_kernel_args[i].grouped_gemm_block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
const auto K = gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
{
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
}
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
});
const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
{
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
}
}
hipGetErrorString(
hipMemcpy(arg.gemm_descs_args_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
hipMemcpyHostToDevice));
float ave_time = 0;
......@@ -523,23 +501,23 @@ struct DeviceGroupedGemmXdl
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<GemmDescKernelArg>,
GemmDescKernelArg,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
true,
MaxGroupCount>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
gemm_desc_kernel_args,
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
true>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
}
else
{
......@@ -547,23 +525,23 @@ struct DeviceGroupedGemmXdl
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<GemmDescKernelArg>,
GemmDescKernelArg,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
false,
MaxGroupCount>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
gemm_desc_kernel_args,
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
false>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
}
return ave_time;
......@@ -652,6 +630,16 @@ struct DeviceGroupedGemmXdl
return str.str();
}
// Returns the device workspace size in bytes: one GemmDescKernelArg per
// GEMM group, as filled by Run() via hipMemcpy into this workspace.
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
    // dynamic_cast yields nullptr when p_arg is not this op's Argument;
    // fail loudly instead of dereferencing a null pointer.
    const auto* p_arg_ = dynamic_cast<const Argument*>(p_arg);
    if(p_arg_ == nullptr)
    {
        throw std::runtime_error("wrong! p_arg is not an Argument of DeviceGroupedGemmXdl");
    }
    return p_arg_->group_count_ * sizeof(GemmDescKernelArg);
}
// Stores the externally allocated device workspace pointer into the
// argument; Run() later copies the per-group kernel args into it.
void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
{
    // Guard the downcast: dynamic_cast returns nullptr on a type mismatch.
    auto* p_arg_ = dynamic_cast<Argument*>(p_arg);
    if(p_arg_ == nullptr)
    {
        throw std::runtime_error("wrong! p_arg is not an Argument of DeviceGroupedGemmXdl");
    }
    p_arg_->gemm_descs_args_workspace_ = workspace_ptr;
}
};
} // namespace device
......
......@@ -348,8 +348,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
if constexpr(use_multiblock)
{
const auto zeroVal =
ck::reduce::GetReductionZeroValueForInMemoryDataOperation<OutDataType>(
const auto identityVal =
ck::reduce::GetIdentityValueueForInMemoryDataOperation<OutDataType>(
OutMemoryDataOperation);
const auto kernel_pre =
......@@ -362,7 +362,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
0,
out_grid_desc_m_2,
arg.out_dev_,
zeroVal);
identityVal);
};
avg_time += launch_and_time_kernel(stream_config,
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#pragma once
#include "data_type.hpp"
......@@ -5,14 +30,22 @@ namespace ck {
namespace tensor_operation {
namespace binary_element_wise {
struct Add
// Binary addition functor: writes x1 + x2 into y.
// Specialized per (output, input, input) type triple.
template <typename Y, typename X1, typename X2>
struct Add;

// double + double -> double
template <>
struct Add<double, double, double>
{
    __host__ __device__ constexpr void
    operator()(double& y, const double& x1, const double& x2) const
    {
        y = x1 + x2;
    }
};
template <>
struct Add<float, float, float>
{
__host__ __device__ constexpr void
operator()(float& dst, const float& src1, const float& src2) const
{
......@@ -20,6 +53,75 @@ struct Add
}
};
// half_t + half_t -> half_t (direct addition in half precision)
template <>
struct Add<half_t, half_t, half_t>
{
    __host__ __device__ constexpr void
    operator()(half_t& y, const half_t& x1, const half_t& x2) const
    {
        y = x1 + x2;
    }
};
// bhalf_t + bhalf_t -> bhalf_t. bf16 has no native add path here, so the
// operands are widened to float, summed, and narrowed back.
template <>
struct Add<bhalf_t, bhalf_t, bhalf_t>
{
    __host__ __device__ constexpr void
    operator()(bhalf_t& y, const bhalf_t& x1, const bhalf_t& x2) const
    {
        const float a = ck::type_convert<float>(x1);
        const float b = ck::type_convert<float>(x2);
        y = ck::type_convert<bhalf_t>(a + b);
    }
};
// Binary subtraction functor: writes x1 - x2 into y.
// NOTE(review): "Substract" is a misspelling of "Subtract"; kept as-is
// because the name is part of the public interface.
template <typename Y, typename X1, typename X2>
struct Substract;

// double - double -> double
template <>
struct Substract<double, double, double>
{
    __host__ __device__ constexpr void
    operator()(double& y, const double& x1, const double& x2) const
    {
        y = x1 - x2;
    }
};
// float - float -> float
template <>
struct Substract<float, float, float>
{
    __host__ __device__ constexpr void
    operator()(float& y, const float& x1, const float& x2) const
    {
        y = x1 - x2;
    }
};
// half_t - half_t -> half_t (direct subtraction in half precision)
template <>
struct Substract<half_t, half_t, half_t>
{
    __host__ __device__ constexpr void
    operator()(half_t& y, const half_t& x1, const half_t& x2) const
    {
        y = x1 - x2;
    }
};
// bhalf_t - bhalf_t -> bhalf_t: widen to float, subtract, narrow back,
// mirroring the bf16 Add specialization.
template <>
struct Substract<bhalf_t, bhalf_t, bhalf_t>
{
    __host__ __device__ constexpr void
    operator()(bhalf_t& y, const bhalf_t& x1, const bhalf_t& x2) const
    {
        const float a = ck::type_convert<float>(x1);
        const float b = ck::type_convert<float>(x2);
        y = ck::type_convert<bhalf_t>(a - b);
    }
};
} // namespace binary_element_wise
} // namespace tensor_operation
} // namespace ck
#pragma once
#include "data_type.hpp"
#include "math_v2.hpp"
namespace ck {
namespace tensor_operation {
......@@ -143,6 +144,24 @@ struct AddHardswishAdd
}
};
// Layer-normalization style elementwise op:
//   y = (x - mean) / sqrt(variance + epsilon) * gamma + beta
// where variance is reconstructed as E[x^2] - E[x]^2.
struct Normalize
{
    // __host__ __device__ added for consistency with operator(), so the
    // functor can also be constructed in device code. 1e-4f avoids a
    // double -> float narrowing conversion.
    __host__ __device__ Normalize(float epsilon = 1e-4f) : epsilon_(epsilon) {}

    // @param y            normalized output
    // @param x            input value
    // @param mean         E[x] over the normalized dimension
    // @param mean_square  E[x^2] over the normalized dimension
    // @param gamma        scale
    // @param beta         shift
    __host__ __device__ constexpr void operator()(float& y,
                                                  const float& x,
                                                  const float& mean,
                                                  const float& mean_square,
                                                  const float& gamma,
                                                  const float& beta) const
    {
        // var = E[x^2] - E[x]^2; epsilon_ keeps the sqrt argument positive.
        const float variance = mean_square - (mean * mean);
        y = ((x - mean) / sqrtf(variance + epsilon_)) * gamma + beta;
    }

    float epsilon_; // numerical-stability term added to the variance
};
// Unary operators are usually called element-wisely before/after the reduction is executed on the
// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2
......@@ -278,7 +297,7 @@ struct UnaryAbs<float, float>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = abs(x); };
__host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); };
};
template <>
......@@ -286,7 +305,7 @@ struct UnaryAbs<half_t, half_t>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = __habs(x); };
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); };
};
template <>
......@@ -294,7 +313,7 @@ struct UnaryAbs<double, double>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); };
__host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); };
};
template <>
......@@ -302,12 +321,7 @@ struct UnaryAbs<int8_t, int8_t>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const
{
int8_t sgn = x >> (8 - 1);
y = (x ^ sgn) - sgn;
};
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); };
};
template <typename Y, typename X>
......@@ -318,7 +332,7 @@ struct UnarySqrt<float, float>
{
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = sqrtf(x); };
__host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::sqrt(x); };
};
template <>
......@@ -326,7 +340,10 @@ struct UnarySqrt<double, double>
{
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const { y = sqrt(x); };
__host__ __device__ void operator()(double& y, const double& x) const
{
y = ck::math::sqrt(x);
};
};
template <typename Y, typename X>
......
......@@ -171,7 +171,7 @@ struct GridwiseReduction_mk_to_m_multiblock
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global)
{
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto identityVal = ReduceOperation::GetIdentityValue();
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
......@@ -179,7 +179,7 @@ struct GridwiseReduction_mk_to_m_multiblock
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
type_convert<InDataType>(identityVal));
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
......@@ -191,7 +191,7 @@ struct GridwiseReduction_mk_to_m_multiblock
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
......@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
type_convert<InDataType>(identityVal));
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -418,7 +418,7 @@ struct GridwiseReduction_mk_to_m_multiblock
thread_k_cluster_id * KThreadSliceSize));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
accu_value_buf(I) = identityVal;
accu_index_buf(I) = 0;
});
......@@ -459,7 +459,7 @@ struct GridwiseReduction_mk_to_m_multiblock
in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType tmpValue = zeroVal;
AccDataType tmpValue = identityVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
......@@ -512,7 +512,7 @@ struct GridwiseReduction_mk_to_m_multiblock
in_thread_val_buf(Number<offset>{}));
});
AccDataType tmpValue = zeroVal;
AccDataType tmpValue = identityVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
......
......@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise
ReduceOperation,
PropagateNan>;
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
type_convert<InDataType>(identityVal));
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
......@@ -149,7 +149,7 @@ struct GridwiseReduction_mk_to_m_threadwise
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
......@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise
(void)acc_elementwise_op;
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
type_convert<InDataType>(identityVal));
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
......@@ -303,7 +303,7 @@ struct GridwiseReduction_mk_to_m_threadwise
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
accu_value_buf(I) = identityVal;
accu_index_buf(I) = 0;
});
......
#pragma once
#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
// Thin __global__ entry point for the 5-input / 1-output 1-D elementwise
// kernel: forwards the five read-only input pointers (a..e), the output
// pointer (f), their 1-D grid descriptors and the elementwise functor to
// Gridwise5AryEltwise::Run, which does the actual per-thread work.
template <typename Gridwise5AryEltwise,
typename ADataType,
typename BDataType,
typename CDataType,
typename DDataType,
typename EDataType,
typename FDataType,
typename AGridDesc_M,
typename BGridDesc_M,
typename CGridDesc_M,
typename DGridDesc_M,
typename EGridDesc_M,
typename FGridDesc_M,
typename ElementwiseFunctor>
__global__ void kernel_5ary_elementwise_1d(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
const CDataType* __restrict__ p_c_global,
const DDataType* __restrict__ p_d_global,
const EDataType* __restrict__ p_e_global,
FDataType* __restrict__ p_f_global,
const AGridDesc_M a_grid_desc_m,
const BGridDesc_M b_grid_desc_m,
const CGridDesc_M c_grid_desc_m,
const DGridDesc_M d_grid_desc_m,
const EGridDesc_M e_grid_desc_m,
const FGridDesc_M f_grid_desc_m,
const ElementwiseFunctor functor)
{
// Pure forwarding; all logic lives in the gridwise struct.
Gridwise5AryEltwise::Run(p_a_global,
p_b_global,
p_c_global,
p_d_global,
p_e_global,
p_f_global,
a_grid_desc_m,
b_grid_desc_m,
c_grid_desc_m,
d_grid_desc_m,
e_grid_desc_m,
f_grid_desc_m,
functor);
}
// TODO - implement n-ary Elemenetwise_1D, tuple of inputs and tuple of outputs
// Gridwise 5-ary elementwise op over 1-D tensors:
//   f[i] = functor(a[i], b[i], c[i], d[i], e[i])
// Each thread handles MPerThread contiguous elements per iteration; the whole
// grid strides through the problem in steps of grid_size * block_size *
// MPerThread. All five inputs are loaded into VGPR staging buffers in
// ComputeDataType before the functor is applied.
template <typename ADataType,
typename BDataType,
typename CDataType,
typename DDataType,
typename EDataType,
typename FDataType,
typename ComputeDataType,
typename AGridDesc_M,
typename BGridDesc_M,
typename CGridDesc_M,
typename DGridDesc_M,
typename EGridDesc_M,
typename FGridDesc_M,
typename ElementwiseFunctor,
index_t MPerThread,
index_t AScalarPerVector,
index_t BScalarPerVector,
index_t CScalarPerVector,
index_t DScalarPerVector,
index_t EScalarPerVector,
index_t FScalarPerVector>
struct Gridwise5AryElementwise_1D
{
static constexpr auto I0 = Number<0>{};

// Per-thread 1-D descriptor covering MPerThread packed elements.
static constexpr auto thread_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));

using PassThrough = tensor_operation::element_wise::PassThrough;

// Starting global element index for this thread: each of the grid's
// threads owns a contiguous MPerThread-wide slice per iteration.
static __device__ auto CalculateElementwiseIndex()
{
const index_t global_thread_id = get_thread_global_1d_id();
return make_multi_index(global_thread_id * MPerThread);
}

// @param p_a_global..p_e_global  read-only input tensors (global memory)
// @param p_f_global              output tensor (global memory)
// @param *_grid_desc_m           1-D descriptors for the respective tensors
// @param functor                 f = functor(a, b, c, d, e), elementwise
__device__ static void Run(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
const CDataType* __restrict__ p_c_global,
const DDataType* __restrict__ p_d_global,
const EDataType* __restrict__ p_e_global,
FDataType* __restrict__ p_f_global,
const AGridDesc_M a_grid_desc_m,
const BGridDesc_M b_grid_desc_m,
const CGridDesc_M c_grid_desc_m,
const DGridDesc_M d_grid_desc_m,
const EGridDesc_M e_grid_desc_m,
const FGridDesc_M f_grid_desc_m,
const ElementwiseFunctor functor)
{
// Wrap the raw global pointers in dynamic buffers sized by each descriptor.
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_grid_desc_m.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_grid_desc_m.GetElementSpaceSize());
const auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_grid_desc_m.GetElementSpaceSize());
const auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_grid_desc_m.GetElementSpaceSize());
const auto e_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_global, e_grid_desc_m.GetElementSpaceSize());
auto f_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_f_global, f_grid_desc_m.GetElementSpaceSize());

// Per-thread register staging buffers, all in ComputeDataType.
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> d_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> e_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> f_thread_buf;

const auto thread_store_global_offset = CalculateElementwiseIndex();

// One global->VGPR loader per input; vector widths differ per tensor
// (AScalarPerVector, BScalarPerVector, ...).
auto a_global_load =
ThreadwiseTensorSliceTransfer_v2<ADataType,
ComputeDataType,
AGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
AScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{a_grid_desc_m, thread_store_global_offset};
auto b_global_load =
ThreadwiseTensorSliceTransfer_v2<BDataType,
ComputeDataType,
BGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
BScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{b_grid_desc_m, thread_store_global_offset};
auto c_global_load =
ThreadwiseTensorSliceTransfer_v2<CDataType,
ComputeDataType,
CGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
CScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{c_grid_desc_m, thread_store_global_offset};
auto d_global_load =
ThreadwiseTensorSliceTransfer_v2<DDataType,
ComputeDataType,
DGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
DScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{d_grid_desc_m, thread_store_global_offset};
auto e_global_load =
ThreadwiseTensorSliceTransfer_v2<EDataType,
ComputeDataType,
EGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
EScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{e_grid_desc_m, thread_store_global_offset};
// VGPR->global writer for the output; PassThrough means no elementwise
// transform is applied on the way out.
auto f_global_write =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
FDataType,
decltype(thread_desc_m),
FGridDesc_M,
PassThrough,
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
FScalarPerVector, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
false>{
f_grid_desc_m, thread_store_global_offset, PassThrough{}};

// Grid-wide advance between iterations: every thread in every block
// processes MPerThread elements per pass.
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto M = c_grid_desc_m.GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * MPerThread;
const auto loop_step_index = make_multi_index(loop_step);
// NOTE(review): num_iter = M / loop_step with a do/while(--num_iter)
// assumes M is a positive multiple of loop_step; if M < loop_step,
// num_iter starts at 0 and the post-decrement loop misbehaves — TODO
// confirm callers guarantee this divisibility.
index_t num_iter = M / (loop_step);
do
{
// read and process MPerThread elements
a_global_load.Run(
a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);
b_global_load.Run(
b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);
c_global_load.Run(
c_grid_desc_m, c_global_buf, thread_desc_m, make_tuple(I0), c_thread_buf);
d_global_load.Run(
d_grid_desc_m, d_global_buf, thread_desc_m, make_tuple(I0), d_thread_buf);
e_global_load.Run(
e_grid_desc_m, e_global_buf, thread_desc_m, make_tuple(I0), e_thread_buf);

// Apply the functor element-by-element in registers.
static_for<0, MPerThread, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));
functor(f_thread_buf(Number<offset>{}),
a_thread_buf(Number<offset>{}),
b_thread_buf(Number<offset>{}),
c_thread_buf(Number<offset>{}),
d_thread_buf(Number<offset>{}),
e_thread_buf(Number<offset>{}));
});

// Write the results back to global memory.
f_global_write.Run(thread_desc_m,
make_tuple(I0), // SrcSliceOriginIdx
f_thread_buf,
f_grid_desc_m,
f_global_buf);

// Advance all windows by one grid-wide stride.
a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
c_global_load.MoveSrcSliceWindow(c_grid_desc_m, loop_step_index);
d_global_load.MoveSrcSliceWindow(d_grid_desc_m, loop_step_index);
e_global_load.MoveSrcSliceWindow(e_grid_desc_m, loop_step_index);
f_global_write.MoveDstSliceWindow(f_grid_desc_m, loop_step_index);
} while(--num_iter);
}
};
} // namespace ck
......@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation,
typename DxsAccElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -41,7 +41,7 @@ __global__ void
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const DxsInElementwiseOperation dxs_in_element_op,
const DxsOutElementwiseOperation dxs_out_element_op,
const DxsAccElementwiseOperation dxs_out_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -96,7 +96,7 @@ template <typename FloatAB,
typename CElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation,
typename DxsAccElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename DGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
......@@ -329,7 +329,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const DxsInElementwiseOperation& dxs_in_element_op,
const DxsOutElementwiseOperation& dxs_out_element_op,
const DxsAccElementwiseOperation& dxs_out_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
......@@ -816,10 +816,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
false>;
// Global write Gemm shuffle + reduction
const auto d_zeroVal = DReduceOperation::GetReductionZeroVal();
const auto d_identityVal = DReduceOperation::GetIdentityValue();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_zeroVal; });
[&](auto I) { d_thread_buf(I) = d_identityVal; });
// reduce in VGPR
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
......
......@@ -340,7 +340,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
template <bool HasMainKBlockLoop, typename Block2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
......
......@@ -3,11 +3,13 @@
#include <cmath>
#include "data_type.hpp"
#include "half.hpp"
#include "type.hpp"
namespace ck {
namespace math {
// math functions for the host, some are implemented by calling C++ std functions
// Host-side absolute value for float/double; forwards to the C++ stdlib.
static inline __host__ float abs(float x) { return std::fabs(x); };

static inline __host__ double abs(double x) { return std::fabs(x); };
......@@ -28,26 +30,26 @@ static inline __host__ int32_t abs(int32_t x)
static inline __host__ half_t abs(half_t x)
{
half_float::half xx = *reinterpret_cast<half_float::half*>(&x);
uint16_t xx = ck::bit_cast<uint16_t>(x);
half_float::half abs_xx = half_float::abs(xx);
uint16_t abs_xx = xx & 0x7fff;
half_t abs_x = *reinterpret_cast<half_t*>(&abs_xx);
half_t abs_x = ck::bit_cast<half_t>(abs_xx);
return abs_x;
};
static inline __host__ float isnan(float x) { return std::isnan(x); };
static inline __host__ bool isnan(float x) { return std::isnan(x); };
static inline __host__ double isnan(double x) { return std::isnan(x); };
static inline __host__ bool isnan(double x) { return std::isnan(x); };
static inline __host__ int8_t isnan(int8_t x)
static inline __host__ bool isnan(int8_t x)
{
(void)x;
return false;
};
static inline __host__ int32_t isnan(int32_t x)
static inline __host__ bool isnan(int32_t x)
{
(void)x;
return false;
......@@ -55,11 +57,59 @@ static inline __host__ int32_t isnan(int32_t x)
static inline __host__ bool isnan(half_t x)
{
half_float::half xx = *reinterpret_cast<half_float::half*>(&x);
uint16_t xx = ck::bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00;
};
// Host-side square root; std::sqrt selects the float/double overload.
static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); };
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static inline __device__ float abs(float x) { return ::abs(x); };
static inline __device__ double abs(double x) { return ::abs(x); };
static inline __device__ int8_t abs(int8_t x)
{
int8_t sgn = x >> (8 - 1);
return (x ^ sgn) - sgn;
};
static inline __device__ int32_t abs(int32_t x)
{
int32_t sgn = x >> (32 - 1);
return (x ^ sgn) - sgn;
};
static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
static inline __device__ bool isnan(float x) { return ::isnan(x); };
static inline __device__ bool isnan(double x) { return ::isnan(x); };
static inline __device__ bool isnan(int8_t x)
{
(void)x;
return false;
};
return half_float::isnan(xx);
static inline __device__ bool isnan(int32_t x)
{
(void)x;
return false;
};
static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
static inline __device__ double sqrt(double x) { return ::sqrt(x); };
} // namespace math
} // namespace ck
......
......@@ -27,6 +27,7 @@
#define CK_REDUCTION_FUNCTIONS_BINOP_HPP
#include "data_type.hpp"
#include "math_v2.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
......@@ -34,18 +35,6 @@
namespace ck {
namespace detail {
template <typename T>
static inline __device__ bool is_nan(T x)
{
return (isnan(x));
};
template <>
inline __device__ bool is_nan<half_t>(half_t x)
{
return (__hisnan(x));
};
template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck;
......@@ -53,7 +42,7 @@ template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
{
// cppcheck-suppress constParameter
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
__host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
{
ReduceOperation{}(accuVal, currVal);
};
......@@ -62,9 +51,11 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
{
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
__host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
{
if(is_nan(currVal))
using ck::math::isnan;
if(isnan(currVal))
{
accuVal = currVal;
}
......@@ -81,7 +72,7 @@ struct AccumulateWithIndexAndNanCheck;
template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType>
{
__device__ static inline void
__host__ __device__ static inline void
// cppcheck-suppress constParameter
Calculate(AccDataType& accuVal,
AccDataType currVal,
......@@ -101,12 +92,14 @@ template <typename ReduceOperation, typename AccDataType, typename IndexDataType
struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType>
{
// The method is called when the ReduceOperation is indexable and the user asked for indices
__device__ static inline void Calculate(AccDataType& accuVal,
AccDataType currVal,
IndexDataType& accuIndex,
IndexDataType currIndex)
__host__ __device__ static inline void Calculate(AccDataType& accuVal,
AccDataType currVal,
IndexDataType& accuIndex,
IndexDataType currIndex)
{
if(is_nan(currVal))
using ck::math::isnan;
if(isnan(currVal))
{
accuVal = currVal;
accuIndex = currIndex;
......
......@@ -36,7 +36,7 @@ namespace reduce {
// Every binary operator used in reduction is represented by a templated functor class. Each functor
// class must provide at least
// three members:
// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary
// 1) GetIdentityValue() -- the interface to return the "identity element" for the binary
// operator, "identity element" is the unique
// element in the algebraic space that doesn't affect the value of other elements
// when operated against them, and the concept is similar to zero vector in
......@@ -59,7 +59,7 @@ struct Add
{
using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
__host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
__device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
......@@ -76,7 +76,7 @@ struct Mul
{
using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); };
__host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(1.0f); };
__device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
......@@ -92,7 +92,7 @@ struct Max
{
using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal()
__host__ __device__ static constexpr T GetIdentityValue()
{
return NumericLimits<T>::Lowest();
};
......@@ -125,10 +125,7 @@ struct Min
{
using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal()
{
return NumericLimits<T>::Max();
};
__host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits<T>::Max(); };
__device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
......@@ -158,7 +155,7 @@ struct AMax
{
using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
__host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
__device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
......@@ -184,7 +181,7 @@ struct AMax
};
template <typename T>
T GetReductionZeroValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
T result = ck::type_convert<T>(0.0f);
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GUARD_HOST_REDUCE_UTIL_HPP
#define GUARD_HOST_REDUCE_UTIL_HPP
#include <limits>
#include <cmath>
#include <functional>
#include "reduction_enums.hpp"
#include "data_type.hpp"
#include "math_v2.hpp"
namespace ck {
namespace host_reduce {
using ck::NanPropagation;
using ck::ReduceTensorOp;
// Element-wise transform applied to each input value before it enters the
// reduction. NORM1 and AMAX reduce over |x|, NORM2 reduces over x*x; every
// other op consumes the raw value unchanged. The int parameter (divider) is
// unused here; it exists to mirror PosUnaryOpFn's signature.
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PreUnaryOpFn(int)
{
    using ck::math::abs;

    if constexpr(ReduceOpId == ReduceTensorOp::NORM1 || ReduceOpId == ReduceTensorOp::AMAX)
    {
        return [](AccDataType& v) { v = abs(v); };
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
    {
        return [](AccDataType& v) { v = v * v; };
    }
    else
    {
        // ADD / AVG / MUL / MIN / MAX: identity, nothing to do per element.
        return [](AccDataType&) {};
    }
};
// Transform applied once to the final accumulated value. NORM2 takes the
// square root of the accumulated sum of squares; AVG divides the accumulated
// sum by the element count (divider). All other ops are finished as-is.
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PosUnaryOpFn(int32_t divider)
{
    if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
    {
        return [](AccDataType& v) {
            using std::sqrt;
            v = sqrt(v);
        };
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::AVG)
    {
        // Same arithmetic as the original: divider -> float -> AccDataType.
        const auto scale = static_cast<AccDataType>(static_cast<float>(divider));
        return [scale](AccDataType& v) { v = v / scale; };
    }
    else
    {
        // ADD / NORM1 / MUL / MIN / MAX / AMAX: identity.
        return [](AccDataType&) {};
    }
};
// Returns the binary accumulation functor for ReduceOpId (no index tracking).
// - ADD/AVG/NORM1/NORM2 accumulate by addition (AVG divides afterwards in
//   PosUnaryOpFn; NORM1/NORM2 transform their inputs in PreUnaryOpFn).
// - MUL multiplies; MIN/MAX keep the smaller/larger value; AMAX behaves like
//   MAX over magnitudes prepared by PreUnaryOpFn.
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType)> ReduceOpFn()
{
    if constexpr(ReduceOpId == ReduceTensorOp::ADD || ReduceOpId == ReduceTensorOp::AVG ||
                 ReduceOpId == ReduceTensorOp::NORM1 || ReduceOpId == ReduceTensorOp::NORM2)
    {
        return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ + b_; });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MUL)
    {
        return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ * b_; });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
    {
        return ([&](AccDataType& a_, AccDataType b_) {
            if(a_ > b_)
                a_ = b_;
        });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
    {
        return ([&](AccDataType& a_, AccDataType b_) {
            if(a_ < b_)
                a_ = b_;
        });
    }
    else
    {
        // Fix: previously control could fall off the end of this non-void
        // function for an unhandled ReduceOpId (undefined behavior). Return an
        // empty functor instead, mirroring ReduceOpFn2's fallback branch.
        return (std::function<void(AccDataType&, AccDataType)>{});
    };
};
// Returns the binary accumulation functor for index-tracking reductions.
// The `changed` out-flag tells the caller whether the accumulator was
// replaced, so it can record the winning element's index. Only MIN, MAX and
// AMAX are indexable; every other op yields an empty std::function.
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType, bool& changed)> ReduceOpFn2()
{
    if constexpr(ReduceOpId == ReduceTensorOp::MIN)
    {
        return [](AccDataType& acc, AccDataType cur, bool& changed) {
            changed = (acc > cur);
            if(changed)
                acc = cur;
        };
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
    {
        return [](AccDataType& acc, AccDataType cur, bool& changed) {
            changed = (acc < cur);
            if(changed)
                acc = cur;
        };
    }
    else
    {
        // ADD / MUL / AVG / NORM1 / NORM2 are not indexable.
        return (std::function<void(AccDataType&, AccDataType, bool&)>{});
    }
};
// Identity (initial) value for the reduction: the element that leaves any
// other value unchanged under the op. MUL starts at 1, MIN at the type's
// maximum, MAX at the type's lowest; everything else (ADD, AVG, NORM1,
// NORM2, AMAX) starts at 0.
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline AccDataType ReduceOpZeroVal()
{
    if constexpr(ReduceOpId == ReduceTensorOp::MIN)
    {
        return ck::NumericLimits<AccDataType>::Max();
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MAX)
    {
        return ck::NumericLimits<AccDataType>::Lowest();
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MUL)
    {
        return static_cast<AccDataType>(1.0f);
    }
    else
    {
        // ADD / AVG / NORM1 / NORM2 / AMAX all accumulate from zero.
        return static_cast<AccDataType>(0.0f);
    }
};
// Applies opReduce(accuVal, currVal), optionally propagating NaN: when
// PropagateNan is set and currVal is NaN, the accumulator becomes NaN
// immediately instead of being folded through the reduction op.
template <typename AccDataType, bool PropagateNan>
__host__ static inline void
binop_with_nan_check(std::function<void(AccDataType&, AccDataType)> opReduce,
                     AccDataType& accuVal,
                     AccDataType currVal)
{
    if constexpr(PropagateNan)
    {
        using ck::math::isnan;

        // A NaN input wins the reduction outright.
        if(isnan(currVal))
        {
            accuVal = currVal;
            return;
        }
    }
    opReduce(accuVal, currVal);
};
// Index-tracking variant of binop_with_nan_check: opReduce reports via its
// bool out-parameter whether the accumulator was replaced, and when it was,
// the current element's index is recorded. With PropagateNan, a NaN input
// immediately takes over both the value and the index.
template <typename AccDataType, typename IndexDataType, bool PropagateNan>
__host__ static inline void
binop_with_index_and_nan_check(std::function<void(AccDataType&, AccDataType, bool&)> opReduce,
                               AccDataType& accuVal,
                               AccDataType currVal,
                               IndexDataType& accuIndex,
                               IndexDataType currIndex)
{
    if constexpr(PropagateNan)
    {
        using ck::math::isnan;

        if(isnan(currVal))
        {
            // NaN propagates together with its index.
            accuVal   = currVal;
            accuIndex = currIndex;
            return;
        }
    }

    bool updated = false;
    opReduce(accuVal, currVal, updated);
    if(updated)
        accuIndex = currIndex;
};
}; // namespace host_reduce
}; // namespace ck
#endif
......@@ -33,10 +33,10 @@
#include "reduction_enums.hpp"
#include "reduction_common.hpp"
#include "host_reduce_util.hpp"
#include "host_common_util.hpp"
#include "host_tensor.hpp"
#include "data_type.hpp"
#include "reduction_functions_accumulate.hpp"
template <int NDim>
static void get_all_indexes(const std::array<size_t, NDim>& dimLengths,
......@@ -106,11 +106,13 @@ static size_t get_offset_from_index(const std::vector<size_t>& strides,
template <typename InDataType,
typename AccDataType,
typename OutDataType,
ck::ReduceTensorOp ReduceOpId,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
int Rank,
int NumReduceDim,
bool PropagateNan,
bool NeedIndices>
bool OutputIndex>
struct ReductionHost
{
using IndexDataType = int32_t;
......@@ -122,8 +124,6 @@ struct ReductionHost
std::vector<int> reduceDims;
IndexDataType divider;
std::function<void(AccDataType&)> preUnaryOp;
std::function<void(AccDataType&)> posUnaryOp;
std::array<size_t, NumReduceDim> reduceLengths;
std::array<size_t, NumReduceDim> reduceStrides;
std::array<size_t, NumInvariantDim> invariantLengths;
......@@ -137,9 +137,6 @@ struct ReductionHost
const std::vector<int>& invariantDims_,
const std::vector<int>& reduceDims_)
{
using ck::host_reduce::PosUnaryOpFn;
using ck::host_reduce::PreUnaryOpFn;
// this->outLengths = to_int_vector(outDesc.GetLengths());
this->outStrides = outDesc.GetStrides();
......@@ -171,9 +168,6 @@ struct ReductionHost
invariant_dim_indexes.clear();
get_all_indexes<NumInvariantDim>(invariantLengths, invariant_dim_indexes);
};
preUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
posUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
};
void Run(float alpha,
......@@ -182,7 +176,7 @@ struct ReductionHost
OutDataType* out_data,
IndexDataType* out_indices)
{
if constexpr(NeedIndices)
if constexpr(OutputIndex)
{
RunImpl_with_index(alpha, in_data, beta, out_data, out_indices);
}
......@@ -201,15 +195,17 @@ struct ReductionHost
using ck::float_equal_one;
using ck::float_equal_zero;
using ck::type_convert;
using ck::host_reduce::binop_with_index_and_nan_check;
using ck::host_reduce::ReduceOpFn2;
using ck::host_reduce::ReduceOpZeroVal;
auto opReduce2 = ReduceOpFn2<AccDataType, ReduceOpId>();
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
InElementwiseOperation in_elementwise_op(divider);
AccElementwiseOperation acc_elementwise_op(divider);
if constexpr(NumInvariantDim == 0)
{
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
AccDataType accuVal = ReduceOperation::GetIdentityValue();
IndexDataType accuIndex = 0;
for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
......@@ -219,15 +215,14 @@ struct ReductionHost
auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);
preUnaryOp(currVal);
in_elementwise_op(currVal, currVal);
auto currIndex = static_cast<IndexDataType>(i);
binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
opReduce2, accuVal, currVal, accuIndex, currIndex);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
};
posUnaryOp(accuVal);
acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha);
......@@ -241,7 +236,7 @@ struct ReductionHost
else
{
auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
AccDataType accuVal = ReduceOperation::GetIdentityValue();
IndexDataType accuIndex = 0;
auto offset_invariant =
......@@ -255,15 +250,14 @@ struct ReductionHost
auto currVal =
type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);
preUnaryOp(currVal);
in_elementwise_op(currVal, currVal);
auto currIndex = static_cast<IndexDataType>(i);
binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
opReduce2, accuVal, currVal, accuIndex, currIndex);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
};
posUnaryOp(accuVal);
acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha);
......@@ -308,15 +302,16 @@ struct ReductionHost
using ck::float_equal_one;
using ck::float_equal_zero;
using ck::type_convert;
using ck::host_reduce::binop_with_nan_check;
using ck::host_reduce::ReduceOpFn;
using ck::host_reduce::ReduceOpZeroVal;
auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>();
using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
InElementwiseOperation in_elementwise_op(divider);
AccElementwiseOperation acc_elementwise_op(divider);
if constexpr(NumInvariantDim == 0)
{
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
AccDataType accuVal = ReduceOperation::GetIdentityValue();
for(const auto& reduce_index : reduce_dim_indexes)
{
......@@ -325,12 +320,12 @@ struct ReductionHost
auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);
preUnaryOp(currVal);
in_elementwise_op(currVal, currVal);
binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
Accumulation::Calculate(accuVal, currVal);
};
posUnaryOp(accuVal);
acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha);
......@@ -343,7 +338,7 @@ struct ReductionHost
else
{
auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
AccDataType accuVal = ReduceOperation::GetIdentityValue();
auto offset_invariant =
get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
......@@ -356,12 +351,12 @@ struct ReductionHost
auto currVal =
type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);
preUnaryOp(currVal);
in_elementwise_op(currVal, currVal);
binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
Accumulation::Calculate(accuVal, currVal);
};
posUnaryOp(accuVal);
acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha);
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#pragma once
#include <iostream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// FIXME: support arbitrary elementwise operation for A/B/C
// Host-side (CPU) reference implementation of complex GEMM (CGEMM):
//   C = A * B, with A, B, C complex matrices stored as separate real/imag
//   tensors (A is MxK, B is KxN, C is MxN — layouts taken from the tensor
//   descriptors). Only PassThrough elementwise operations are supported,
//   enforced at compile time by the enable_if constraint below.
template <
    typename ADataType,
    typename BDataType,
    typename CDataType,
    typename AElementwiseOperation,
    typename BElementwiseOperation,
    typename CElementwiseOperation,
    enable_if_t<
        is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
        is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
        is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
        bool> = false>
struct ReferenceCGemm : public device::BaseOperator
{
    // Argument: bundles (by reference — callers must keep the tensors alive
    // for the lifetime of the Argument) the real/imag parts of A, B and C,
    // plus the elementwise operators (all PassThrough, so effectively unused
    // by the Invoker below).
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_m_k_real,
                 const Tensor<ADataType>& a_m_k_imag,
                 const Tensor<BDataType>& b_k_n_real,
                 const Tensor<BDataType>& b_k_n_imag,
                 Tensor<CDataType>& c_m_n_real,
                 Tensor<CDataType>& c_m_n_imag,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : a_m_k_real_{a_m_k_real},
              a_m_k_imag_{a_m_k_imag},
              b_k_n_real_{b_k_n_real},
              b_k_n_imag_{b_k_n_imag},
              c_m_n_real_{c_m_n_real},
              c_m_n_imag_{c_m_n_imag},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
        }

        const Tensor<ADataType>& a_m_k_real_;
        const Tensor<ADataType>& a_m_k_imag_;
        const Tensor<BDataType>& b_k_n_real_;
        const Tensor<BDataType>& b_k_n_imag_;
        Tensor<CDataType>& c_m_n_real_;
        Tensor<CDataType>& c_m_n_imag_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };

    // Invoker: runs the naive triple-loop complex GEMM on the host, with the
    // (m, n) output space parallelized across hardware threads. Accumulation
    // is done in float regardless of the input/output data types.
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceCGemm::Argument;

        float Run(const Argument& arg)
        {
            const std::size_t K = arg.a_m_k_real_.mDesc.GetLengths()[1];

            // Only the K extents of A's real and imag parts are validated
            // here; the remaining dimensions are trusted to be consistent.
            if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1])
            {
                throw std::runtime_error("wrong! Incompatible real and imag sizes in CGEMM");
            }

            // Re(C) = Re(A)*Re(B) - Im(A)*Im(B)
            auto f_mk_kn_mn_real = [&](auto m, auto n) {
                float v_c_real = 0;

                for(std::size_t k = 0; k < K; ++k)
                {
                    float v_a_real = ck::type_convert<float>(arg.a_m_k_real_(m, k));
                    float v_a_imag = ck::type_convert<float>(arg.a_m_k_imag_(m, k));
                    float v_b_real = ck::type_convert<float>(arg.b_k_n_real_(k, n));
                    float v_b_imag = ck::type_convert<float>(arg.b_k_n_imag_(k, n));

                    v_c_real += v_a_real * v_b_real - v_a_imag * v_b_imag;
                }

                // NOTE(review): written via implicit float -> CDataType
                // conversion rather than ck::type_convert as used on the
                // inputs — confirm this is intended for narrow CDataTypes.
                arg.c_m_n_real_(m, n) = v_c_real;
            };

            // Im(C) = Re(A)*Im(B) + Im(A)*Re(B)
            auto f_mk_kn_mn_imag = [&](auto m, auto n) {
                float v_c_imag = 0;

                for(std::size_t k = 0; k < K; ++k)
                {
                    float v_a_real = ck::type_convert<float>(arg.a_m_k_real_(m, k));
                    float v_a_imag = ck::type_convert<float>(arg.a_m_k_imag_(m, k));
                    float v_b_real = ck::type_convert<float>(arg.b_k_n_real_(k, n));
                    float v_b_imag = ck::type_convert<float>(arg.b_k_n_imag_(k, n));

                    v_c_imag += v_a_real * v_b_imag + v_a_imag * v_b_real;
                }

                arg.c_m_n_imag_(m, n) = v_c_imag;
            };

            // Real and imag planes are computed in two separate parallel
            // passes over the MxN output.
            make_ParallelTensorFunctor(f_mk_kn_mn_real,
                                       arg.c_m_n_real_.mDesc.GetLengths()[0],
                                       arg.c_m_n_real_.mDesc.GetLengths()[1])(
                std::thread::hardware_concurrency());

            make_ParallelTensorFunctor(f_mk_kn_mn_imag,
                                       arg.c_m_n_imag_.mDesc.GetLengths()[0],
                                       arg.c_m_n_imag_.mDesc.GetLengths()[1])(
                std::thread::hardware_concurrency());

            // Reference path reports no elapsed time.
            return 0;
        }

        // Type-erased entry point used by the generic device-op interface.
        // NOTE(review): dynamic_cast result is dereferenced unchecked — a
        // mismatched argument type would be UB rather than a clean error.
        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    // Reference implementation accepts any argument unconditionally.
    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    // Convenience factory mirroring the device-op MakeArgument interface.
    static auto MakeArgument(const Tensor<ADataType>& a_m_k_real,
                             const Tensor<ADataType>& a_m_k_imag,
                             const Tensor<BDataType>& b_k_n_real,
                             const Tensor<BDataType>& b_k_n_imag,
                             Tensor<CDataType>& c_m_n_real,
                             Tensor<CDataType>& c_m_n_imag,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{a_m_k_real,
                        a_m_k_imag,
                        b_k_n_real,
                        b_k_n_imag,
                        c_m_n_real,
                        c_m_n_imag,
                        a_element_op,
                        b_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // NOTE(review): declared `virtual` without `override` — presumably
    // overrides a BaseOperator virtual; verify against device_base.hpp.
    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // Human-readable identifier used by profiling/logging utilities.
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceCGemm"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
......@@ -38,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
......
......@@ -38,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
......
......@@ -38,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
......
......@@ -38,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
......
......@@ -24,10 +24,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
using DOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
......@@ -37,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// c[m, n] = a[k, m] * b[k, n]
using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = std::tuple<
// clang-format off
//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment