Commit aeb05cc4 authored by Chao Liu

use vector type for holding C thread matrix data, but it causes register over-allocation

parent d990eff6
......@@ -2,6 +2,7 @@
#define CK_BLOCKWISE_GEMM_V2_HPP
#include "common_header.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_gemm_v2.hpp"
namespace ck {
......
......@@ -5,9 +5,10 @@
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "blockwise_gemm_v2.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
namespace ck {
......@@ -730,12 +731,22 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
#if 0
FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
auto c_thread_buf = make_dynamic_buffer<FloatAcc>(p_c_thread);
// zero out threadwise output
threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
#else
auto c_thread_buf =
make_static_buffer<FloatAcc>(c_m0m1_n0n1_thread_desc.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_m0m1_n0n1_thread_desc),
Sequence<MRepeat * MPerThread, NRepeat * NPerThread>>{}
.Run(c_m0m1_n0n1_thread_desc, make_tuple(I0, I0), c_thread_buf, FloatAcc{0});
#endif
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
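The change above replaces a raw FloatAcc array (zeroed via threadwise_matrix_set_zero_v2) with a static buffer whose elements are all addressed by compile-time offsets, which is what lets the compiler promote the whole C thread tile into registers. Below is a minimal, self-contained sketch of that idea, not the ck implementation; StaticBufferSketch and fill are hypothetical names.

```cpp
// Minimal sketch (not ck code): a statically indexed buffer whose fill loop
// is fully unrolled at compile time, so every store has a constant offset
// and the buffer can live entirely in registers.
#include <cstddef>
#include <utility>

template <typename T, std::size_t N>
struct StaticBufferSketch
{
    T data[N];

    template <std::size_t I>
    constexpr T& At()
    {
        static_assert(I < N, "out of range");
        return data[I];
    }
};

// analogous in spirit to ThreadwiseDynamicTensorSliceSet_v1::Run(..., 0)
template <typename T, std::size_t N, std::size_t... Is>
constexpr void fill_impl(StaticBufferSketch<T, N>& buf, T value, std::index_sequence<Is...>)
{
    ((buf.template At<Is>() = value), ...); // unrolled fold over all indices
}

template <typename T, std::size_t N>
constexpr void fill(StaticBufferSketch<T, N>& buf, T value)
{
    fill_impl(buf, value, std::make_index_sequence<N>{});
}

int main()
{
    StaticBufferSketch<float, 8> acc{};
    fill(acc, 0.0f);
    return acc.At<0>() == 0.0f ? 0 : 1;
}
```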
......@@ -916,7 +927,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
n_thread_data_on_global % N1))
.Run(c_m0_m1_n0_n1_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
......
#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
#define CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
namespace ck {
// Assume:
// 1. Desc is known at compile-time
// 2. Buffer is StaticBuffer
// 3. OriginIdx is known at compile-time
// 4. use #-iterator
template <typename Data,
typename Desc,
typename SliceLengths,
typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceSet_v1
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
template <typename OriginIdx, typename Buffer>
__device__ void Run(const Desc&, const OriginIdx&, Buffer& buf, const Data& initial_value) const
{
static_assert(Desc::IsKnownAtCompileTime(),
"wrong! Desc needs to be known at compile-time");
static_assert(Buffer::IsStaticBuffer(), "wrong! Buffer needs to be StaticBuffer");
static_assert(is_known_at_compile_time<remove_cv_t<remove_reference_t<OriginIdx>>>::value,
"wrong! OriginIdx needs to be known at compile-time");
// Desc is known at compile-time
constexpr auto desc = remove_cv_t<remove_reference_t<Desc>>{};
// OriginIdx is known at compile-time
constexpr auto origin_idx = to_multi_index(OriginIdx{});
static_ford<SliceLengths>{}([&](auto access_idx) {
constexpr auto coord = make_dynamic_tensor_coordinate(desc, origin_idx + access_idx);
constexpr bool is_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);
constexpr index_t offset = coord.GetOffset();
if constexpr(is_valid)
{
buf.template AsType<Data>()(Number<offset>{}) = initial_value;
}
});
}
};
} // namespace ck
#endif
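Run() above walks the whole slice with static_ford, guarding each store with a compile-time validity check. As a rough, self-contained illustration of that traversal (ck's static_ford passes compile-time Sequence indices and is more general; static_ford_sketch here is a hypothetical name):

```cpp
// Sketch of a static_ford-style traversal: both loop bounds are template
// parameters, so the nested loop fully unrolls; the real Run() additionally
// skips offsets that fail the compile-time validity check.
#include <cstddef>
#include <utility>

template <std::size_t N, typename F, std::size_t... Js>
constexpr void row_sketch(F& f, std::size_t i, std::index_sequence<Js...>)
{
    (f(i, Js), ...);
}

template <std::size_t M, std::size_t N, typename F, std::size_t... Is>
constexpr void ford_impl(F& f, std::index_sequence<Is...>)
{
    (row_sketch<N>(f, Is, std::make_index_sequence<N>{}), ...);
}

template <std::size_t M, std::size_t N, typename F>
constexpr void static_ford_sketch(F f)
{
    ford_impl<M, N>(f, std::make_index_sequence<M>{});
}

int main()
{
    float buf[2][3];
    static_ford_sketch<2, 3>([&](std::size_t i, std::size_t j) { buf[i][j] = 0.0f; });
    return buf[1][2] == 0.0f ? 0 : 1;
}
```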
......@@ -38,11 +38,14 @@ struct lambda_scalar_step_in_vector
} // namespace detail
// Assume:
// 1. src:
//    1. SrcDesc is known at compile-time
//    2. SrcBuffer is StaticBuffer
//    3. SrcSliceOriginIdx is known at compile-time
// 2. dst:
//    1. DstDesc is not known at compile-time
//    2. DstBuffer is DynamicBuffer
//    3. DstSliceOriginIdx is not known at compile-time
template <typename SrcData,
typename DstData,
typename SrcDesc,
......@@ -80,10 +83,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstIteratorHacks>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstData* p_dst,
const DstIteratorHacks& dst_iterator_hacks)
......@@ -97,7 +100,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -189,12 +192,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);
dst_vector.template AsType<DstData>()(i) = type_convert<DstData>{}(
src_buf.template AsType<SrcData>()[Number<src_offset>{}]);
});
const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
......@@ -1489,7 +1491,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
// TODO: if SrcData and DstData are vector types, then static_cast may not compile
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
dst_tmp_buf.template AsType<DstData>()(i) =
type_convert<DstData>{}(src_tmp_buf.template AsType<SrcData>()[i]);
});
// copy data from dst_tmp_buf into dst_buf
......
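The static_cast-to-type_convert change matters because a plain static_cast may not compile when SrcData or DstData is a vector or other non-scalar type, as the TODO above notes; a functor can be specialized per type pair. A hypothetical simplified sketch (ck's actual type_convert may be defined differently):

```cpp
// Sketch (not ck's definition): a conversion functor with a scalar default
// and a specialization for a vector-like type that static_cast cannot handle.
#include <cstdint>

template <typename Dst>
struct type_convert_sketch
{
    template <typename Src>
    constexpr Dst operator()(Src v) const { return static_cast<Dst>(v); }
};

struct float2_sketch { float x, y; };        // hypothetical 2-wide types
struct int2_sketch   { std::int32_t x, y; };

template <>
struct type_convert_sketch<int2_sketch>
{
    constexpr int2_sketch operator()(float2_sketch v) const
    {
        // element-wise conversion; a single static_cast would be ill-formed
        return {static_cast<std::int32_t>(v.x), static_cast<std::int32_t>(v.y)};
    }
};

int main()
{
    int2_sketch r = type_convert_sketch<int2_sketch>{}(float2_sketch{1.5f, 2.5f});
    return (r.x == 1 && r.y == 2) ? 0 : 1;
}
```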
......@@ -403,32 +403,249 @@ struct vector_type<T, 16>
}
};
template <typename T>
struct vector_type<T, 32>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
using type = d32_t;
union
{
d32_t d32_;
StaticallyIndexedArray<d1_t, 32> d1x32_;
StaticallyIndexedArray<d2_t, 16> d2x16_;
StaticallyIndexedArray<d4_t, 8> d4x8_;
StaticallyIndexedArray<d8_t, 4> d8x4_;
StaticallyIndexedArray<d16_t, 2> d16x2_;
StaticallyIndexedArray<d32_t, 1> d32x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
}
};
template <typename T>
struct vector_type<T, 64>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
using type = d64_t;
union
{
d64_t d64_;
StaticallyIndexedArray<d1_t, 64> d1x64_;
StaticallyIndexedArray<d2_t, 32> d2x32_;
StaticallyIndexedArray<d4_t, 16> d4x16_;
StaticallyIndexedArray<d8_t, 8> d8x8_;
StaticallyIndexedArray<d16_t, 4> d16x4_;
StaticallyIndexedArray<d32_t, 2> d32x2_;
StaticallyIndexedArray<d64_t, 1> d64x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
}
};
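Both new specializations rely on the same union-of-views trick: the widest ext_vector_type shares storage with arrays of narrower vectors, so AsType<X>() hands back whichever granularity the caller asks for without bit-casting. A minimal sketch of the pattern (requires a compiler supporting ext_vector_type, e.g. clang/hipcc; it leans on union type punning the same way the code above does):

```cpp
// Minimal sketch of the union-of-views pattern: one 4-wide vector exposed
// both as a whole d4_t and as two d2_t halves sharing the same storage.
typedef float d2_t __attribute__((ext_vector_type(2)));
typedef float d4_t __attribute__((ext_vector_type(4)));

struct vector4_sketch
{
    union
    {
        d4_t d4_;
        d2_t d2x2_[2]; // plain array standing in for StaticallyIndexedArray
    } data_;
};

int main()
{
    vector4_sketch v;
    v.data_.d4_ = d4_t{1.f, 2.f, 3.f, 4.f};

    d2_t hi = v.data_.d2x2_[1]; // read the upper half through the 2-wide view
    return (hi[0] == 3.f && hi[1] == 4.f) ? 0 : 1;
}
```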
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
using float16_t = typename vector_type<float, 16>::type;
using float32_t = typename vector_type<float, 32>::type;
using float64_t = typename vector_type<float, 64>::type;
// fp16
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type;
using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;
// bfp16
using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type;
using ushort16_t = typename vector_type<ushort, 16>::type;
using ushort32_t = typename vector_type<ushort, 32>::type;
using ushort64_t = typename vector_type<ushort, 64>::type;
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
using int32x16_t = typename vector_type<int32_t, 16>::type;
using int32x32_t = typename vector_type<int32_t, 32>::type;
using int32x64_t = typename vector_type<int32_t, 64>::type;
// i8
using int8x2_t = typename vector_type<int8_t, 2>::type;
using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// data type conversion
template <typename T>
......