Commit defa2071 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge branch 'develop' into aosewski/ggemm_multi_d2

parents 28a68428 f2398f61
......@@ -7,7 +7,7 @@
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp"
......@@ -46,14 +46,14 @@ template <typename XDataType,
index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize,
bool UseWelford = true>
struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
YElementwiseOperation,
Rank,
NumReduceDim>
struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
YElementwiseOperation,
Rank,
NumReduceDim>
{
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
static_assert(
......@@ -461,7 +461,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
auto str = std::stringstream();
// clang-format off
str << "DeviceNormalizationImpl<" << BlockSize << ",";
str << "DeviceNormalizationFwdImpl<" << BlockSize << ",";
str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
......
......@@ -8,7 +8,7 @@
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp"
......@@ -134,14 +134,14 @@ template <typename XDataType,
index_t BetaSrcVectorSize,
index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize>
struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
YElementwiseOperation,
Rank,
NumReduceDim>
struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
YElementwiseOperation,
Rank,
NumReduceDim>
{
using WorkspaceMeanVarDataType = SaveMeanInvStdDataType;
......@@ -732,7 +732,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
auto str = std::stringstream();
// clang-format off
str << "DeviceNormalizationSplitKImpl<" << BlockSize << ",";
str << "DeviceNormalizationFwdSplitKImpl<" << BlockSize << ",";
str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
str << "XYSrcVectorDim_" << XYVectorDim << ",";
......
......@@ -85,10 +85,13 @@ struct Add
struct ScaleAdd
{
__host__ __device__ ScaleAdd(float scale) : scale_(scale) {}
__host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
{
y = ck::type_convert<Y>(scale_ * ck::type_convert<float>(x0) + ck::type_convert<float>(x1));
}
template <>
__host__ __device__ void
......@@ -186,6 +189,25 @@ struct Bilinear
y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
};
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
{
const float x0_tmp = type_convert<float>(x0);
const float x1_tmp = type_convert<float>(x1);
const float y_tmp = alpha_ * x0_tmp + beta_ * x1_tmp;
y = type_convert<bhalf_t>(y_tmp);
};
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x1);
const float y_tmp = alpha_ * x0 + beta_ * x1_tmp;
y = y_tmp;
};
template <>
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
......
......@@ -311,6 +311,71 @@ struct AddAddFastGelu
}
};
// E = Relu(alpha1 * C + alpha2 * D0 + D1)
struct ScaleAddScaleAddRelu
{
ScaleAddScaleAddRelu(const float alpha1 = 1.f, const float alpha2 = 1.f)
: alpha1_(alpha1), alpha2_(alpha2)
{
}
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
const float& c,
const float& d0,
const float& d1) const
{
const float x = c * alpha1_ + alpha2_ * d0 + d1;
Relu{}.template operator()<float>(e, x);
}
template <>
__host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
{
const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
type_convert<float>(d1);
float result = 0;
Relu{}.template operator()<float>(result, x);
e = type_convert<half_t>(result);
}
template <>
__host__ __device__ constexpr void operator()<bhalf_t, bhalf_t, bhalf_t, bhalf_t>(
bhalf_t& e, const bhalf_t& c, const bhalf_t& d0, const bhalf_t& d1) const
{
const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
type_convert<float>(d1);
float result = 0;
Relu{}.template operator()<float>(result, x);
e = type_convert<bhalf_t>(result);
}
template <>
__host__ __device__ constexpr void operator()<int8_t, int8_t, float, float>(
int8_t& e, const int8_t& c, const float& d0, const float& d1) const
{
const float x = type_convert<float>(c) * alpha1_ + alpha2_ * d0 + d1;
float result = 0;
Relu{}.template operator()<float>(result, x);
e = type_convert<int8_t>(result);
}
const float alpha1_;
const float alpha2_;
};
struct Normalize
{
// FIXME: is double absolutely necessary?
......
......@@ -16,6 +16,57 @@ namespace element_wise {
extern "C" __device__ float __ocml_native_recip_f32(float);
#endif
struct PassThroughPack2
{
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
__host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::half2_t& x) const
{
// fake conversion
uint16_t t = ck::bit_cast<uint32_t>(x);
y = ck::bit_cast<ck::f8x2_t>(t);
}
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const
{
auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t);
}
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::half2_t& x) const
{
y = x;
}
__host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::f8x2_t& x) const
{
y = x;
}
__host__ __device__ constexpr void operator()(ck::float2_t& y, const ck::float2_t& x) const
{
y = x;
}
__host__ __device__ constexpr void operator()(ck::int8x2_t& y, const ck::int8x2_t& x) const
{
y = x;
}
__host__ __device__ constexpr void operator()(ck::bhalf2_t& y, const ck::bhalf2_t& x) const
{
y = x;
}
__host__ __device__ constexpr void operator()(ck::double2_t& y, const ck::double2_t& x) const
{
y = x;
}
constexpr const static bool is_pack2_invocable = true;
};
struct PassThrough
{
template <typename Y, typename X>
......@@ -33,6 +84,12 @@ struct PassThrough
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<double, float>(double& y, const float& x) const
{
y = type_convert<double>(x);
}
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
......@@ -69,6 +126,12 @@ struct PassThrough
y = type_convert<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<float, bhalf_t>(float& y, const bhalf_t& x) const
{
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
{
......@@ -207,7 +270,8 @@ struct ConvertF8SR
__host__ __device__ void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(is_same<Y, f8_t>::value, "Data type is not supported by this operation!");
static_assert(is_same<Y, f8_t>::value || is_same<Y, bf8_t>::value,
"Data type is not supported by this operation!");
// check X datatype
static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
......@@ -224,6 +288,20 @@ struct Scale
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
y = ck::type_convert<half_t>(scale_) * x;
};
template <>
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
const float x_tmp = ck::type_convert<float>(x);
const float y_tmp = scale_ * x_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
};
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
......@@ -277,8 +355,8 @@ struct UnarySquare
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, int32_t> ||
is_same_v<T, int8_t>
static_assert(is_same_v<T, float> || is_same_v<T, half_t> || is_same_v<T, double> ||
is_same_v<T, int32_t> || is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| is_same_v<T, int4_t>
#endif
......@@ -442,10 +520,11 @@ struct Sigmoid
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = 1 / (ck::type_convert<T>(1) + exp(-x));
constexpr T one = type_convert<T>(1);
y = one / (one + ck::math::exp(-x));
};
};
......@@ -455,7 +534,8 @@ struct TanH
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::tanh(x);
......@@ -481,7 +561,101 @@ struct Swish
y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
};
float beta_ = 1.0f;
const float beta_;
};
struct SoftRelu
{
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
const float alpha_;
};
struct Power
{
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
T casted_gamma = type_convert<T>(gamma_);
T shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
const float alpha_;
const float beta_;
const float gamma_;
};
struct ClippedRelu
{
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
const float alpha_;
const float beta_;
};
struct LeakyRelu
{
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
const float alpha_;
};
struct Elu
{
Elu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
const float alpha_;
};
} // namespace element_wise
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseElementwise1dFunctor,
typename InGrid1dDescTuple,
typename OutGrid1dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
typename UnaryOperation,
typename Scale>
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple,
const OutGrid1dDescTuple out_grid_1d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const UnaryOperation unary_op,
const Scale scale_op)
{
GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple,
out_grid_1d_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
elementwise_op,
unary_op,
scale_op);
}
template <typename InGrid1dDescTuple,
typename OutGrid1dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
typename UnaryOperation,
typename Scale,
index_t MPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct GridwiseElementwise_1D
{
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size() &&
NumInput == InGrid1dDescTuple::Size() &&
NumOutput == OutGrid1dDescTuple::Size(),
"Tuple size is inconsistent with the number of in/out!");
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
__device__ static void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple,
const OutGrid1dDescTuple out_grid_1d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const UnaryOperation unary_op,
const Scale scale_op)
{
const index_t thread_global_id = get_thread_global_1d_id();
auto in_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
},
Number<NumInput>{});
auto out_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
},
Number<NumOutput>{});
auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumInput>{});
auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumOutput>{});
const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread);
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto M = in_grid_1d_desc_tuple[I0].GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * MPerThread;
const auto loop_step_index = make_multi_index(loop_step);
auto in_global_load_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return ThreadwiseTensorSliceTransfer_v2<DataType,
DataType,
decltype(in_grid_1d_desc_tuple[I]),
decltype(thread_buffer_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
InScalarPerVectorSeq::At(
I), // ScalarPerVector
1, // SrcScalarStrideInVector
false>{in_grid_1d_desc_tuple[I],
thread_global_offset};
},
Number<NumInput>{});
auto out_global_store_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return ThreadwiseTensorSliceTransfer_v1r3<DataType,
DataType,
decltype(thread_buffer_desc_m),
decltype(out_grid_1d_desc_tuple[I]),
PassThroughOp,
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
OutScalarPerVectorSeq::At(I),
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_1d_desc_tuple[I], thread_global_offset, PassThroughOp{});
},
Number<NumOutput>{});
index_t num_iter = M / (loop_step);
do
{
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I],
in_global_buf_tuple[I],
thread_buffer_desc_m,
make_tuple(I0),
in_thread_buf_tuple(I));
in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I],
loop_step_index);
});
static_for<0, MPerThread, 1>{}([&](auto iM) {
// get reference to in data
auto uop_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
// get reference to dst data
auto out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return out_thread_buf_tuple(I)(iM); },
Number<NumOutput>{});
unpack2(unary_op, uop_data_refs, uop_data_refs);
auto sop_in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
auto sop_out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
unpack2(scale_op, sop_out_data_refs, sop_in_data_refs);
const auto in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
unpack2(elementwise_op, out_data_refs, in_data_refs);
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).Run(thread_buffer_desc_m,
make_tuple(I0),
out_thread_buf_tuple[I],
out_grid_1d_desc_tuple[I],
out_global_buf_tuple(I));
out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I],
loop_step_index);
});
} while(--num_iter);
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
//
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseElementwise3dFunctor,
typename InGrid3dDescTuple,
typename OutGrid3dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation>
__global__ void kernel_elementwise_3d(const InGrid3dDescTuple in_grid_3d_desc_tuple,
const OutGrid3dDescTuple out_grid_3d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n,
const index_t num_threads_k)
{
GridwiseElementwise3dFunctor::Run(in_grid_3d_desc_tuple,
out_grid_3d_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
elementwise_op,
num_threads_m,
num_threads_n,
num_threads_k);
}
template <typename InGrid3dDescTuple,
typename OutGrid3dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct GridwiseElementwise_3D
{
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size() &&
NumInput == InGrid3dDescTuple::Size() &&
NumOutput == OutGrid3dDescTuple::Size(),
"Tuple size is inconsistent with the number of in/out!");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto thread_buffer_desc_mnk = make_naive_tensor_descriptor_packed(
make_tuple(Number<MPerThread>{}, Number<NPerThread>{}, Number<KPerThread>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
__device__ static void Run(const InGrid3dDescTuple in_grid_3d_desc_tuple,
const OutGrid3dDescTuple out_grid_3d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n,
const index_t num_threads_k)
{
auto in_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
MPerThread * NPerThread * KPerThread,
true>{};
},
Number<NumInput>{});
auto out_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
MPerThread * NPerThread * KPerThread,
true>{};
},
Number<NumOutput>{});
auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_3d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumInput>{});
auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_3d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumOutput>{});
const auto M = in_grid_3d_desc_tuple[I0].GetLength(I0);
const auto N = in_grid_3d_desc_tuple[I0].GetLength(I1);
const auto K = in_grid_3d_desc_tuple[I0].GetLength(I2);
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const index_t loop_step_k = num_threads_k * KPerThread;
const index_t thread_1d_id = get_thread_global_1d_id();
const index_t tid_m = thread_1d_id / (num_threads_n * num_threads_k);
const index_t tid_nk = thread_1d_id % (num_threads_n * num_threads_k);
const index_t tid_n = tid_nk / num_threads_k;
const index_t tid_k = tid_nk % num_threads_k;
const auto thread_global_offset =
make_multi_index(tid_m * MPerThread, tid_n * NPerThread, tid_k * KPerThread);
auto in_global_load_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return ThreadwiseTensorSliceTransfer_v2<
DataType,
DataType,
decltype(in_grid_3d_desc_tuple[I]),
decltype(thread_buffer_desc_mnk),
Sequence<MPerThread, NPerThread, KPerThread>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
01, // SrcVectorDim
InScalarPerVectorSeq::At(I), // InScalarPerVectorSeq::At(I), //
// ScalarPerVector
1, // SrcScalarStrideInVector
true>{in_grid_3d_desc_tuple[I], thread_global_offset};
},
Number<NumInput>{});
auto out_global_store_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return ThreadwiseTensorSliceTransfer_v1r3<
DataType,
DataType,
decltype(thread_buffer_desc_mnk),
decltype(out_grid_3d_desc_tuple[I]),
PassThroughOp,
Sequence<MPerThread, NPerThread, KPerThread>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
OutScalarPerVectorSeq::At(I), // OutScalarPerVectorSeq::At(I),
InMemoryDataOperationEnum::Set,
1,
true>(out_grid_3d_desc_tuple[I], thread_global_offset, PassThroughOp{});
},
Number<NumOutput>{});
index_t num_iter_m = M / (loop_step_m);
do
{
index_t num_iter_n = N / (loop_step_n);
do
{
index_t num_iter_k = K / (loop_step_k);
do
{
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).Run(in_grid_3d_desc_tuple[I],
in_global_buf_tuple[I],
thread_buffer_desc_mnk,
make_tuple(I0, I0, I0),
in_thread_buf_tuple(I));
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_3d_desc_tuple[I], make_multi_index(0, 0, loop_step_k));
});
static_for<0, MPerThread, 1>{}([&](auto iM) {
static_for<0, NPerThread, 1>{}([&](auto iN) {
static_for<0, KPerThread, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_mnk.CalculateOffset(make_tuple(iM, iN, iK));
// get reference to in data
const auto in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& {
return in_thread_buf_tuple(I)(Number<offset>{});
},
Number<NumInput>{});
// get referenec to dst data
auto out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& {
return out_thread_buf_tuple(I)(Number<offset>{});
},
Number<NumOutput>{});
unpack2(elementwise_op, out_data_refs, in_data_refs);
});
});
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).Run(thread_buffer_desc_mnk,
make_tuple(I0, I0, I0),
out_thread_buf_tuple[I],
out_grid_3d_desc_tuple[I],
out_global_buf_tuple(I));
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_3d_desc_tuple[I], make_multi_index(0, 0, loop_step_k));
});
} while(--num_iter_k);
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_3d_desc_tuple[I],
make_multi_index(0, loop_step_n, -(K / loop_step_k) * loop_step_k));
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_3d_desc_tuple[I],
make_multi_index(0, loop_step_n, -(K / loop_step_k) * loop_step_k));
});
} while(--num_iter_n);
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_3d_desc_tuple[I],
make_multi_index(loop_step_m,
-(N / loop_step_n) * loop_step_n,
-(K / loop_step_k) * loop_step_k));
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_3d_desc_tuple[I],
make_multi_index(loop_step_m,
-(N / loop_step_n) * loop_step_n,
-(K / loop_step_k) * loop_step_k));
});
} while(--num_iter_m);
}
};
} // namespace ck
......@@ -203,7 +203,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// A desc for source in blockwise copy
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
......@@ -219,17 +219,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template <typename AsGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
MakeDefaultAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
{
return generate_tuple(
[&](auto i) { return MakeAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); },
[&](auto i) { return MakeDefaultAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); },
Number<NumATensor>{});
}
// B desc for source in blockwise copy
template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
......@@ -245,10 +245,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template <typename BsGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
MakeDefaultBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
{
return generate_tuple(
[&](auto i) { return MakeBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); },
[&](auto i) { return MakeDefaultBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); },
Number<NumBTensor>{});
}
......@@ -288,7 +288,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// return block_id to E matrix tile idx (m0, n0) mapping
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
......@@ -591,6 +591,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{});
static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
"Src and Dst ScalarPerVector must be the same");
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
AsDataType,
......@@ -619,6 +622,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{});
static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
"Src and Dst ScalarPerVector must be the same");
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
BsDataType,
......@@ -1005,9 +1011,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
// tensor descriptors for block/thread-wise copy
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k);
const auto as_grid_desc_ak0_m_ak1 = MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k);
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k);
const auto bs_grid_desc_bk0_n_bk1 = MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n);
......
......@@ -136,7 +136,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t MPadded;
index_t NPadded;
index_t KPadded;
index_t K0;
index_t K0Padded;
index_t k_batch;
Argument(const FloatA* p_a_grid_,
......@@ -151,7 +151,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t MPadded_,
index_t NPadded_,
index_t KPadded_,
index_t K0_,
index_t K0Padded_,
index_t k_batch_)
: p_a_grid(p_a_grid_),
p_b_grid(p_b_grid_),
......@@ -165,7 +165,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
MPadded(MPadded_),
NPadded(NPadded_),
KPadded(KPadded_),
K0(K0_),
K0Padded(K0Padded_),
k_batch(k_batch_)
{
}
......@@ -182,7 +182,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KP:" << KPadded << ", "
<< "K0:" << K0 << ", "
<< "K0Padded:" << K0Padded << ", "
<< "KB:" << k_batch << "}" << std::endl;
}
};
......@@ -205,7 +205,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
return math::integer_least_multiple(N, NPerBlock);
}
__host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
__host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1)
{
// k_batch * k0 * k0_per_block * k1
auto K_t = K_Batch * K0PerBlock * K1;
......@@ -214,8 +214,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
auto K0 = CalculateK0(K, K_Batch);
return K_Batch * K0 * K1;
auto K0Padded = CalculateK0Padded(K, K_Batch);
return K_Batch * K0Padded * K1;
}
__host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M,
......@@ -223,7 +223,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t K,
index_t StrideA,
index_t KBatch,
index_t K0,
index_t K0Padded,
index_t KPad)
{
const auto a_grid_desc_m_k = [&]() {
......@@ -237,21 +237,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
}
}();
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
......@@ -259,8 +271,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
else
{
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
......@@ -272,7 +284,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t N,
index_t StrideB,
index_t KBatch,
index_t K0,
index_t K0Padded,
index_t KPad)
{
const auto b_grid_desc_k_n = [&]() {
......@@ -286,21 +298,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
}
}();
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
......@@ -308,8 +332,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
else
{
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
......@@ -398,6 +422,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
......@@ -410,6 +435,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
auto K_t = karg.k_batch * K0PerBlock * K1;
if(!(karg.K % K_t == 0))
{
#if DEBUG_LOG
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<< karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
......@@ -478,11 +522,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
{
#if DEBUG_LOG
std::cout
<< "Arg N (" << karg.N
<< ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL ("
<< CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of "
"CBlockTransferScalarPerVector_NWaveNPerXDL ("
<< CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
......@@ -493,25 +537,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
{
#if DEBUG_LOG
std::cout
<< "Arg M (" << karg.M
<< ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL ("
<< CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of "
"CBlockTransferScalarPerVector_NWaveNPerXDL ("
<< CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
const auto num_k_loop = karg.K0 / K0PerBlock;
const auto num_k_loop = karg.K0Padded / K0PerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
#if DEBUG_LOG
std::cout << "The number of k loops (" << num_k_loop
<< ") value is not supported by GridwiseGemm Pipeline."
<< " K0: " << karg.K0 << ", K0PerBlock: " << K0PerBlock << " " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
<< " K0Padded: " << karg.K0Padded << ", K0PerBlock: " << K0PerBlock << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
......@@ -521,14 +565,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
{
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0 * K1;
const index_t K0Padded =
math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0Padded * K1;
return KPad;
}
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded)
{
const index_t num_loop = K0 / K0PerBlock;
const index_t num_loop = K0Padded / K0PerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
......@@ -595,9 +640,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const FloatB* p_b_grid = karg.p_b_grid;
FloatC* p_c_grid = karg.p_c_grid;
const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded);
karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded);
const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded);
karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
......
......@@ -21,6 +21,7 @@ template <typename InputGridDesc,
typename OutputGridDesc,
typename OutputDataType,
typename Block2ETileMap,
typename ComputePtrOffsetOfStridedBatch,
typename GridwiseTensorRearrangeKernel>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
......@@ -30,13 +31,20 @@ __global__ void
const InputDataType* __restrict__ p_in_global,
const OutputGridDesc out_grid_desc,
OutputDataType* __restrict__ p_out_global,
const Block2ETileMap block_2_tile_map)
const index_t batch_count,
const Block2ETileMap block_2_tile_map,
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
GridwiseTensorRearrangeKernel::Run(
in_grid_desc, p_in_global, out_grid_desc, p_out_global, block_2_tile_map);
GridwiseTensorRearrangeKernel::Run(in_grid_desc,
p_in_global,
out_grid_desc,
p_out_global,
batch_count,
block_2_tile_map,
compute_ptr_offset_of_batch);
#else
ignore = in_grid_desc;
ignore = p_in_global;
......@@ -56,7 +64,8 @@ template <typename InputGridDesc,
typename ThreadClusterLengths,
index_t ScalarPerVector,
InMemoryDataOperationEnum DstInMemOp,
typename Block2ETileMap>
typename Block2ETileMap,
typename ComputePtrOffsetOfStridedBatch>
struct GridwiseTensorRearrange
{
......@@ -69,7 +78,9 @@ struct GridwiseTensorRearrange
const InputDataType* __restrict__ p_in_global,
const OutputGridDesc& out_grid_desc,
OutputDataType* __restrict__ p_out_global,
const Block2ETileMap& block_2_tile_map)
const index_t batch_count,
const Block2ETileMap& block_2_tile_map,
const ComputePtrOffsetOfStridedBatch& compute_ptr_offset_of_batch)
{
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
......@@ -80,12 +91,6 @@ struct GridwiseTensorRearrange
const index_t k_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock);
// Global Memory
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc.GetElementSpaceSize());
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc.GetElementSpaceSize());
auto copy_global_to_global =
ThreadGroupTensorSliceTransfer_v7<ThisThreadBlock,
Tuple<InputDataType>,
......@@ -108,6 +113,22 @@ struct GridwiseTensorRearrange
make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
tensor_operation::element_wise::PassThrough{}};
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx =
__builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
// Global Memory
const index_t a_batch_offset =
__builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const index_t c_batch_offset =
__builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize());
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global + c_batch_offset, out_grid_desc.GetElementSpaceSize());
copy_global_to_global.Run(
tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf));
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
namespace ck {
// dgamma = reduce_sum(dy * (x - mean) * inv_std)
// dbeta = reduce_sum(dy)
template <typename DYDataType,
typename XDataType,
typename MeanInvStdDataType,
typename ComputeDataType,
typename DGammaDataType,
typename DBetaDataType,
typename GridDesc_M_K,
typename GridDesc_M,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t DYSrcVectorDim,
index_t DYSrcVectorSize,
index_t XSrcVectorDim,
index_t XSrcVectorSize,
index_t MeanInvStdSrcVectorDim,
index_t MeanInvStdSrcVectorSize,
index_t DGammaDstVectorSize,
index_t DBetaDstVectorSize>
struct GridwiseNormalizationBwdGammaBeta_mk_to_k
{
// if we just check ThreadSliceSize & VectorSize == 0, the performance may be poor
static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
(DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) ||
(XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!");
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using DYThreadBufferDimAccessOrder =
typename conditional<DYSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using XThreadBufferDimAccessOrder =
typename conditional<XSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using MeanInvStdThreadBufferDimAccessOrder =
typename conditional<MeanInvStdSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
using BlockwiseSumReduce = PartitionedBlockwiseReduction<ComputeDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
reduce::Add,
true>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k,
const GridDesc_M_K& x_grid_desc_m_k,
const GridDesc_M_K& mean_grid_desc_m_k,
const GridDesc_M_K& inv_std_grid_desc_m_k,
const GridDesc_M& dgamma_grid_desc_m,
const GridDesc_M& dbeta_grid_desc_m,
index_t num_k_block_tile_iteration,
const DYDataType* const __restrict__ p_dy_global,
const XDataType* const __restrict__ p_x_global,
const MeanInvStdDataType* const __restrict__ p_mean_global,
const MeanInvStdDataType* const __restrict__ p_inv_std_global,
DGammaDataType* const __restrict__ p_dgamma_global,
DBetaDataType* const __restrict__ p_dbeta_global)
{
// LDS
__shared__ ComputeDataType p_reduce_work_buffer[BlockSize];
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
// Global
const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize());
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize());
const auto inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize());
auto dgamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dgamma_global, dgamma_grid_desc_m.GetElementSpaceSize());
auto dbeta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbeta_global, dbeta_grid_desc_m.GetElementSpaceSize());
// VGPR
auto dy_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto x_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto mean_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto inv_std_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto dgamma_thread_buf =
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};
auto dbeta_thread_buf =
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
// IO
auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DYDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
DYThreadBufferDimAccessOrder,
DYSrcVectorDim,
DYSrcVectorSize,
1,
true>(
dy_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
XThreadBufferDimAccessOrder,
XSrcVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_mean_load =
ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
MeanInvStdThreadBufferDimAccessOrder,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
1,
true>(
mean_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_inv_std_load =
ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
MeanInvStdThreadBufferDimAccessOrder,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
1,
true>(
inv_std_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dgamma_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
DGammaDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
DGammaDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
dgamma_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dbeta_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
DBetaDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
DBetaDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
dbeta_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
dgamma_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
dbeta_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_mean_load.Run(mean_grid_desc_m_k,
mean_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
mean_thread_buf);
threadwise_inv_std_load.Run(inv_std_grid_desc_m_k,
inv_std_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
inv_std_thread_buf);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k,
thread_copy_fwd_step_m_k);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
constexpr auto offset_m =
Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
dgamma_thread_buf(offset_m) +=
dy_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
(x_thread_buf[offset_m_k] - mean_thread_buf[offset_m_k]);
dbeta_thread_buf(offset_m) += dy_thread_buf[offset_m_k];
});
});
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, dbeta_thread_buf(I));
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, dgamma_thread_buf(I));
});
if(thread_k_cluster_id == 0)
{
threadwise_dgamma_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dgamma_thread_buf,
dgamma_grid_desc_m,
dgamma_global_val_buf);
threadwise_dbeta_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dbeta_thread_buf,
dbeta_grid_desc_m,
dbeta_global_val_buf);
}
}
};
} // namespace ck
......@@ -1075,6 +1075,7 @@ struct NumericUtils<float>
{
static constexpr int exp = 8;
static constexpr int mant = 23;
static constexpr int bias = 127;
static constexpr uint32_t nan_mask = 0x7F800000;
static constexpr uint32_t head_mask = 0xFF800000;
static constexpr uint32_t mant_mask = 0x7FFFFF;
......@@ -1091,6 +1092,7 @@ struct NumericUtils<half_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr int bias = 15;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
......@@ -1107,6 +1109,8 @@ struct NumericUtils<f8_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
static constexpr int bias = 8; // negative zero nan mode
// static constexpr int bias = 7; // ieee mode
};
template <>
......@@ -1114,6 +1118,7 @@ struct NumericUtils<bf8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
static constexpr int bias = 16; // negative zero nan mode
// static constexpr int bias = 15; // ieee mode
};
//
} // namespace ck
......@@ -5,7 +5,6 @@
#include "ck/utility/data_type.hpp"
// these conversions are disabled if native conversions available
namespace ck {
// fp8 rounding modes
......@@ -17,6 +16,9 @@ enum class f8_rounding_mode
stochastic
};
__host__ inline int clz(uint32_t x) { return __builtin_clz(x); }
__device__ inline int clz(uint32_t x) { return __clz(x); }
} // namespace ck
namespace ck::utils {
......@@ -34,7 +36,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
constexpr int in_exp = NumericUtils<X>::exp;
constexpr int in_mant = NumericUtils<X>::mant;
int exponent;
int exponent, bias;
uint32_t head, mantissa, sign;
// nan code is same for float and half
constexpr Y nan_code = 0x80;
......@@ -49,12 +51,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
mantissa = x_bitwise & NumericUtils<X>::mant_mask;
exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
sign = head >> (in_exp + in_mant);
bias = NumericUtils<X>::bias;
uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
constexpr int exp_low_cutoff =
(1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
if constexpr(negative_zero_nan)
{
......@@ -67,56 +68,107 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
return signed_inf + (mantissa != 0 ? 1 : 0);
}
// if input is half and output is bf8
if((NumericUtils<X>::mant == 10) && (NumericUtils<Y>::mant == 2) && negative_zero_nan &&
exponent == 0)
{
exponent += 1;
while(mantissa < (1 << in_mant))
{
mantissa <<= 1;
exponent -= 1;
}
mantissa &= ~(1 << in_mant);
}
// check if x is 0.0
if(x_bitwise == 0)
return 0;
exponent -= exp_low_cutoff - 1;
if(exponent <= 0)
drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1;
mantissa += 1 << in_mant;
// apply random number if needed
mantissa += (stoch ? rng : mantissa) & drop_mask;
if(mantissa >= (2 << in_mant))
{
mantissa >>= 1;
exponent++;
// First need to check if it is normal or denorm as there is a difference of implict 1
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
// The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
// RNE, no need to add rng. Then probably need to check whether there is carry and adjust
// exponent and mantissa again3
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0);
const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// out_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int act_exponent, out_exponent, exponent_diff;
if(exponent == 0)
{ // fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
In this case, the fp16 mantissa should be shift left by 1 */
act_exponent = exponent - bias + 1;
exponent_diff = out_denormal_act_exponent -
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
}
else
{ // fp32/fp16 is normal with implicit 1
act_exponent = exponent - bias;
if(act_exponent <= out_denormal_act_exponent)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff = out_denormal_act_exponent - act_exponent;
}
else
{ // both fp32/fp16 and f8 are in normal range
exponent_diff =
0; // exponent_diff=0 does not mean there is no difference for this case,
// act_exponent could be larger. Just that it does not need shift mantissa
}
mantissa += (1 << in_mant); // Add the implicit 1 into mantissa
}
mantissa >>= (in_mant - out_mant);
// check negative exponent
if(exponent <= 0)
bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) ==
(1 << (in_mant - out_mant + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
midpoint, but after shift right by 4 bits, it would look like midpoint. */
if(exponent_diff > 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << in_mant);
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
out_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
bool odd =
mantissa &
(1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
// Now we deal with overflow
if(out_exponent == 0)
{
if(x_bitwise == 0)
return 0;
else
if((1 << in_mant) & mantissa)
{
// subnormal range; represented by a subnormal float8 (exponent 0)
// and involves loss of accuracy
mantissa >>= 1 - exponent;
exponent = 0;
out_exponent = 1; // denormal overflow to become normal, promote exponent
// No need to make 1 implicit now as it will be addressed later
}
}
// above range: quantize to maximum possible float of the same sign
else if(exponent > max_exp)
else
{
if((1 << (in_mant + 1)) & mantissa)
{
mantissa >>= 1;
out_exponent++;
// No need to make 1 implicit now as it will be addressed later
}
}
mantissa >>= (in_mant - out_mant);
if(out_exponent > max_exp)
{
if(clip)
{
mantissa = (1 << out_mant) - 1;
exponent = max_exp;
mantissa = (1 << out_mant) - 1;
out_exponent = max_exp;
}
else
{
......@@ -125,10 +177,10 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
}
// check if x is 0.0 or -0.0
if(exponent == 0 && mantissa == 0)
if(out_exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
mantissa &= (1 << out_mant) - 1;
return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
}
template <typename X, typename Y, bool negative_zero_nan>
......@@ -194,12 +246,9 @@ __host__ __device__ Y run_cast_from_f8(X x)
if(exponent == 0)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
exponent++;
while(mantissa < (1 << in_mant))
{
mantissa <<= 1;
exponent--;
}
int sh = 1 + clz(mantissa) - (32 - in_mant);
mantissa <<= sh;
exponent += 1 - sh;
mantissa &= ((1 << in_mant) - 1);
}
exponent += exp_low_cutoff - 1;
......
......@@ -192,6 +192,8 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
#else
c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
#endif
#elif defined(CK_USE_AMD_V_DOT4_I32_I8_GFX11)
c = __builtin_amdgcn_sudot4(true, bit_cast<int32_t>(a), true, bit_cast<int32_t>(b), c, false);
#else
const vector_type<int8_t, 4> a_vector{a};
const vector_type<int8_t, 4> b_vector{b};
......
......@@ -150,28 +150,6 @@ __host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T&
return min(max(x, lowerbound), upperbound);
}
// disallow implicit type casting
template <typename T>
__device__ T exp(T x);
// TODO: add f16 support using v_exp_f16
template <>
__device__ float exp<float>(float x)
{
return __expf(x);
}
template <>
__device__ double exp<double>(double x)
{
return exp(x);
}
static inline __host__ float exp(float x) { return std::expf(x); }
static inline __host__ double exp(double x) { return std::exp(x); }
// greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{
......
......@@ -9,6 +9,7 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/type_convert.hpp"
namespace ck {
namespace math {
......@@ -92,14 +93,96 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); };
static inline __host__ half_t tanh(half_t x)
template <typename T>
inline __host__ T tanh(T x)
{
return static_cast<half_t>(std::tanh(static_cast<float>(x)));
return ck::type_convert<T>(std::tanhf(ck::type_convert<float>(x)));
};
static inline __host__ float tanh(float x) { return std::tanh(x); };
template <>
inline __host__ float tanh<float>(float x)
{
return std::tanhf(x);
};
template <>
inline __host__ double tanh<double>(double x)
{
return std::tanh(x);
};
template <typename T>
inline __host__ T exp(T x)
{
return ck::type_convert<T>(std::expf(ck::type_convert<float>(x)));
}
template <>
inline __host__ float exp<float>(float x)
{
return std::expf(x);
}
static inline __host__ double tanh(double x) { return std::tanh(x); };
template <>
inline __host__ double exp<double>(double x)
{
return std::exp(x);
}
template <typename T>
inline __host__ T log(T x)
{
return ck::type_convert<T>(std::logf(ck::type_convert<float>(x)));
}
template <>
inline __host__ float log<float>(float x)
{
return std::logf(x);
}
template <>
inline __host__ double log<double>(double x)
{
return std::log(x);
}
template <typename T>
inline __host__ T pow(T x, T gamma)
{
return ck::type_convert<T>(
std::powf(ck::type_convert<float>(x), ck::type_convert<float>(gamma)));
}
template <>
inline __host__ float pow<float>(float x, float gamma)
{
return std::powf(x, gamma);
}
template <>
inline __host__ double pow<double>(double x, double gamma)
{
return std::pow(x, gamma);
}
template <typename T>
inline __host__ T expm1(T x)
{
return ck::type_convert<T>(std::expm1f(ck::type_convert<float>(x)));
}
template <>
inline __host__ float expm1<float>(float x)
{
return std::expm1f(x);
}
template <>
inline __host__ double expm1<double>(double x)
{
return std::expm1(x);
}
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
......@@ -181,14 +264,107 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x);
static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
static inline __device__ half_t tanh(half_t x)
template <typename T>
inline __device__ T tanh(T x)
{
return ck::type_convert<T>(::tanhf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float tanh<float>(float x)
{
return static_cast<half_t>(::tanhf(static_cast<float>(x)));
return ::tanhf(x);
};
static inline __device__ float tanh(float x) { return ::tanhf(x); };
template <>
inline __device__ double tanh<double>(double x)
{
return ::tanh(x);
};
template <typename T>
inline __device__ T exp(T x)
{
return ck::type_convert<T>(__expf(ck::type_convert<float>(x)));
};
template <>
inline __device__ half_t exp<half_t>(half_t x)
{
return hexp(x);
};
template <>
inline __device__ float exp<float>(float x)
{
return __expf(x);
};
static inline __device__ double tanh(double x) { return ::tanh(x); };
template <>
inline __device__ double exp<double>(double x)
{
return exp(x);
};
template <typename T>
inline __device__ T log(T x)
{
return ck::type_convert<T>(__logf(ck::type_convert<float>(x)));
};
template <>
inline __device__ half_t log<half_t>(half_t x)
{
return hlog(x);
};
template <>
inline __device__ float log<float>(float x)
{
return __logf(x);
};
template <>
inline __device__ double log<double>(double x)
{
return log(x);
};
template <typename T>
inline __device__ T pow(T x, T gamma)
{
return ck::type_convert<T>(powf(ck::type_convert<float>(x), ck::type_convert<float>(gamma)));
};
template <>
inline __device__ float pow<float>(float x, float gamma)
{
return powf(x, gamma);
};
template <>
inline __device__ double pow<double>(double x, double gamma)
{
return pow(x, gamma);
};
template <typename T>
inline __device__ T expm1(T x)
{
return ck::type_convert<T>(expm1f(ck::type_convert<float>(x)));
};
template <>
inline __device__ float expm1<float>(float x)
{
return expm1f(x);
};
template <>
inline __device__ double expm1<double>(double x)
{
return expm1(x);
};
} // namespace math
} // namespace ck
......@@ -5,6 +5,7 @@
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#include "common_header.hpp"
#include "ck/utility/math_v2.hpp"
namespace ck {
......
......@@ -100,6 +100,8 @@ template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union
{
float fval;
......@@ -138,6 +140,36 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
#endif
}
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
const auto i16val = bit_cast<uint16_t>(x);
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else
constexpr bool negative_zero_nan = true;
const auto f8x2_v = vector_type<f8_t, 2>(x);
vector_type<float, 2> f32x2_v;
f32x2_v.template AsType<float>()(Number<0>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<0>{}]);
f32x2_v.template AsType<float>()(Number<1>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<1>{}]);
return f32x2_v.template AsType<float2_t>()[Number<0>{}];
#endif
}
template <>
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
{
const vector_type<float, 2> f32x2_v(x);
const auto y = __builtin_amdgcn_cvt_pkrtz(f32x2_v.template AsType<float>()[Number<0>{}],
f32x2_v.template AsType<float>()[Number<1>{}]);
return bit_cast<half2_t>(y);
}
// convert fp16 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
......@@ -145,7 +177,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<f8_t>(type_convert<float>(x));
#elif 0
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
......@@ -153,8 +185,6 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#else
return type_convert<f8_t>(type_convert<float>(x));
#endif
}
......@@ -165,11 +195,9 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#elif 0
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
#else
return type_convert<half_t>(type_convert<float>(x));
#endif
}
......@@ -223,7 +251,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<bf8_t>(type_convert<float>(x));
#elif 0
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
......@@ -231,8 +259,6 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#else
return type_convert<bf8_t>(type_convert<float>(x));
#endif
}
......@@ -243,11 +269,9 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#elif 0
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
#else
return type_convert<half_t>(type_convert<float>(x));
#endif
}
......@@ -347,7 +371,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
#elif 0
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
......@@ -356,8 +380,6 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#else
return f8_convert_sr<f8_t>(type_convert<float>(x));
#endif
}
......@@ -396,7 +418,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
#elif 0
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
......@@ -406,8 +428,6 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#else
return f8_convert_sr<bf8_t>(type_convert<float>(x));
#endif
}
......
......@@ -19,9 +19,7 @@ namespace host {
* \brief Reference implementation for column to image.
*
* Input tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
* Memory layout is the same.
* Output tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C].
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ImageLayout Image Layout.
......@@ -95,18 +93,19 @@ struct ReferenceColumnToImage : public device::BaseOperator
float Run(const Argument& arg)
{
if(!(arg.output_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.input_.GetNumOfDimension() == 2))
arg.input_.GetNumOfDimension() == 3))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
const index_t G = arg.output_.GetLengths()[0];
const index_t N = arg.output_.GetLengths()[1];
const index_t C = arg.output_.GetLengths()[2];
if constexpr(NDimSpatial == 1)
{
const index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto n) {
auto func = [&](auto g, auto n) {
for(index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Wo + wo;
......@@ -123,9 +122,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
{
float v_in = ck::type_convert<float>(arg.input_(row, column));
float v_out = ck::type_convert<float>(arg.output_(0, n, c, wi));
arg.output_(0, n, c, wi) =
float v_in =
ck::type_convert<float>(arg.input_(g, row, column));
float v_out = ck::type_convert<float>(arg.output_(g, n, c, wi));
arg.output_(g, n, c, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
......@@ -134,7 +134,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
};
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
return 0;
}
......@@ -143,7 +143,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
const index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto n) {
auto func = [&](auto g, auto n) {
for(index_t ho = 0; ho < Ho; ++ho)
{
for(index_t wo = 0; wo < Wo; ++wo)
......@@ -176,10 +176,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
arg.output_.GetLengths()[4])
{
float v_in =
ck::type_convert<float>(arg.input_(row, column));
ck::type_convert<float>(arg.input_(g, row, column));
float v_out = ck::type_convert<float>(
arg.output_(0, n, c, hi, wi));
arg.output_(0, n, c, hi, wi) =
arg.output_(g, n, c, hi, wi));
arg.output_(g, n, c, hi, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
......@@ -190,7 +190,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
};
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
return 0;
}
......@@ -200,7 +200,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
const index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto n) {
auto func = [&](auto g, auto n) {
for(index_t d_o = 0; d_o < Do; ++d_o)
{
for(index_t ho = 0; ho < Ho; ++ho)
......@@ -245,10 +245,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
arg.output_.GetLengths()[5])
{
float v_in = ck::type_convert<float>(
arg.input_(row, column));
arg.input_(g, row, column));
float v_out = ck::type_convert<float>(
arg.output_(0, n, c, di, hi, wi));
arg.output_(0, n, c, di, hi, wi) =
arg.output_(g, n, c, di, hi, wi));
arg.output_(g, n, c, di, hi, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
......@@ -261,7 +261,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
};
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
return 0;
}
......@@ -303,8 +303,9 @@ struct ReferenceColumnToImage : public device::BaseOperator
C * ck::accumulate_n<index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(NDoHoWo) &&
arg.input_.GetLengths()[1] == static_cast<std::size_t>(CZYX)))
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) &&
arg.input_.GetLengths()[1] == static_cast<std::size_t>(NDoHoWo) &&
arg.input_.GetLengths()[2] == static_cast<std::size_t>(CZYX)))
{
return false;
}
......
......@@ -23,6 +23,7 @@ template <ck::index_t NumDimM,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ComputeDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
......@@ -69,19 +70,24 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
{
for(ck::index_t k1 = 0; k1 < K1; ++k1)
{
// Simulate the possible casting when ComputeDataType is different than the
// A/B data types
ComputeDataType v_a_compute_input =
ck::type_convert<ComputeDataType>(arg.a_ms_ks_(m0, m1, k0, k1));
ComputeDataType v_b_compute_input =
ck::type_convert<ComputeDataType>(arg.b_ns_ks_(n0, n1, k0, k1));
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
arg.a_element_op_(v_a, ck::type_convert<AccDataType>(v_a_compute_input));
arg.b_element_op_(v_b, ck::type_convert<AccDataType>(v_b_compute_input));
v_acc += v_a * v_b;
}
}
arg.c_ms_ns_(m0, m1, n0, n1) = v_acc;
arg.c_ms_ns_(m0, m1, n0, n1) = ck::type_convert<CDataType>(v_acc);
};
make_ParallelTensorFunctor(f_ms_ns,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment