Commit aeb05cc4 authored by Chao Liu

use vector type for holding C thread matrix data, but it causes register over-allocation

parent d990eff6
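The diff below replaces the plain `FloatAcc` array (wrapped in `make_dynamic_buffer`) with a vector-typed static buffer for the per-thread C accumulator tile. As a rough illustration of that idea, here is a minimal standalone sketch, not CK code: `c_thread_tile`, `float16_v`, and the 4x4 tile size are hypothetical stand-ins for CK's `StaticBuffer` / `vector_type` machinery, and clang's `ext_vector_type` extension is assumed.

```cpp
// Minimal standalone sketch (not CK code): keep the per-thread C accumulator tile
// in one vector-typed register buffer instead of a plain scalar array wrapped in a
// dynamic buffer. Union type punning mirrors what the CK vector_type code does.
#include <cstdio>

typedef float float16_v __attribute__((ext_vector_type(16))); // clang/HIP vector extension

struct c_thread_tile
{
    union
    {
        float16_v vec;     // the whole 4x4 tile as one 16-wide vector
        float scalar[16];  // the same storage addressed element-wise
    } data;
};

int main()
{
    c_thread_tile c{};

    c.data.vec = float16_v{0}; // zero the whole accumulator tile at once

    // accumulate one FMA result into element (row 1, col 1) of the 4x4 tile
    c.data.scalar[1 * 4 + 1] += 2.0f * 3.0f;

    std::printf("c[1][1] = %f\n", c.data.scalar[5]);
    return 0;
}
```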
@@ -2,6 +2,7 @@
#define CK_BLOCKWISE_GEMM_V2_HPP

#include "common_header.hpp"
+#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_gemm_v2.hpp"

namespace ck {
...
@@ -5,9 +5,10 @@
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
+#include "blockwise_gemm_v2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
-#include "blockwise_gemm_v2.hpp"
+#include "threadwise_dynamic_tensor_slice_set.hpp"

namespace ck {
@@ -730,12 +731,22 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;

        // register allocation for output
+#if 0
        FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];

        auto c_thread_buf = make_dynamic_buffer<FloatAcc>(p_c_thread);

        // zero out threadwise output
        threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
+#else
+        auto c_thread_buf =
+            make_static_buffer<FloatAcc>(c_m0m1_n0n1_thread_desc.GetElementSpaceSize());
+
+        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
+                                           decltype(c_m0m1_n0n1_thread_desc),
+                                           Sequence<MRepeat * MPerThread, NRepeat * NPerThread>>{}
+            .Run(c_m0m1_n0n1_thread_desc, make_tuple(I0, I0), c_thread_buf, FloatAcc{0});
+#endif

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
...@@ -916,7 +927,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -916,7 +927,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
n_thread_data_on_global % N1)) n_thread_data_on_global % N1))
.Run(c_m0_m1_n0_n1_thread_desc, .Run(c_m0_m1_n0_n1_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
p_c_thread, c_thread_buf,
c_m0_m1_n0_n1_global_desc, c_m0_m1_n0_n1_global_desc,
p_c_global, p_c_global,
c_m0_m1_n0_n1_global_tensor_iterator_hacks); c_m0_m1_n0_n1_global_tensor_iterator_hacks);
......
#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
#define CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"

namespace ck {

// Assume:
//   1. Desc is known at compile-time
//   2. Buffer is StaticBuffer
//   3. OriginIdx is known at compile-time
//   4. use #-iterator
template <typename Data,
          typename Desc,
          typename SliceLengths,
          typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceSet_v1
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    template <typename OriginIdx, typename Buffer>
    __device__ void Run(const Desc&, const OriginIdx&, Buffer& buf, const Data& initial_value) const
    {
        static_assert(Desc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to known at compile-time");

        static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");

        static_assert(is_known_at_compile_time<remove_cv_t<remove_reference_t<OriginIdx>>>::value,
                      "wrong! OriginIdx need to be known at compile-time");

        // Desc is known at compile-time
        constexpr auto desc = remove_cv_t<remove_reference_t<Desc>>{};

        // OriginIdx is known at compile-time
        constexpr auto origin_idx = to_multi_index(OriginIdx{});

        static_ford<SliceLengths>{}([&](auto access_idx) {
            constexpr auto coord = make_dynamic_tensor_coordinate(desc, origin_idx + access_idx);

            constexpr bool is_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);

            constexpr index_t offset = coord.GetOffset();

            if constexpr(is_valid)
            {
                buf.template AsType<Data>()(Number<offset>{}) = initial_value;
            }
        });
    }
};
} // namespace ck
#endif
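For orientation, here is a conceptual sketch of what ThreadwiseDynamicTensorSliceSet_v1 amounts to, written against the standard library instead of CK's StaticBuffer / static_ford: visit every compile-time index of a per-thread buffer and store an initial value, with the loop unrolled at compile time. The names `fill_static_buffer` and the 8-element tile are hypothetical; this is not the CK implementation.

```cpp
// Conceptual sketch (not CK code) of a compile-time unrolled "set slice to value".
// C++17: a fold expression plays the role of static_ford, one store per index.
#include <array>
#include <cstddef>
#include <cstdio>
#include <utility>

template <typename T, std::size_t N, std::size_t... Is>
constexpr void fill_static_buffer_impl(std::array<T, N>& buf, const T& value, std::index_sequence<Is...>)
{
    ((buf[Is] = value), ...); // one assignment per compile-time index
}

template <typename T, std::size_t N>
constexpr void fill_static_buffer(std::array<T, N>& buf, const T& value)
{
    fill_static_buffer_impl(buf, value, std::make_index_sequence<N>{});
}

int main()
{
    std::array<float, 8> c_thread{}; // stand-in for the per-thread C accumulator tile
    fill_static_buffer(c_thread, 0.0f);
    std::printf("%f\n", c_thread[0]);
    return 0;
}
```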
@@ -38,11 +38,14 @@ struct lambda_scalar_step_in_vector
} // namespace detail

// Assume:
-// 1. src_desc is known at compile-time
-// 2. dst_desc is not known at compile-time
-// 3. src_slice_origin_idx is known at compile-time and it's 0
-// 4. dst_slice_origin_idx is not known at compile time
-// TODO: support non-zero src_slice_origin_idx
+// 1. src:
+//     1. SrcDesc is known at compile-time
+//     2. SrcBuffer is StaticBuffer
+//     3. SrcSliceOriginIdx is known at compile-time
+// 2. dst:
+//     1. DstDesc is not known at compile-time
+//     2. DstBuffer is DynamicBuffer
+//     3. DstSliceOriginIdx is not known at compile time
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
@@ -80,10 +83,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
        dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }

-    template <typename SrcSliceOriginIdx, typename DstIteratorHacks>
+    template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstIteratorHacks>
    __device__ void Run(const SrcDesc&,
                        const SrcSliceOriginIdx&,
-                        const SrcData* p_src,
+                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstData* p_dst,
                        const DstIteratorHacks& dst_iterator_hacks)
@@ -97,7 +100,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
        // SrcDesc and src_slice_origin_idx are known at compile-time
        constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
-        constexpr auto src_slice_origin_idx = SrcSliceOriginIdx{};
+        constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
@@ -189,12 +192,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
                typename vector_type_maker<DstData, DstScalarPerVector>::type::type;

            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
-                constexpr index_t src_offset =
-                    src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
-                                             i * dst_scalar_step_in_vector);
+                constexpr index_t src_offset = src_desc.CalculateOffset(
+                    src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);

-                dst_vector.template AsType<DstData>()(i) =
-                    type_convert<DstData>{}(p_src[Number<src_offset>{}]);
+                dst_vector.template AsType<DstData>()(i) = type_convert<DstData>{}(
+                    src_buf.template AsType<SrcData>()[Number<src_offset>{}]);
            });

            const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
@@ -1489,7 +1491,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
            // TODO: if SrcData and DstData are vector type, then static_cast may not compile
            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                dst_tmp_buf.template AsType<DstData>()(i) =
-                    static_cast<DstData>(src_tmp_buf.template AsType<SrcData>()[i]);
+                    type_convert<DstData>{}(src_tmp_buf.template AsType<SrcData>()[i]);
            });

            // copy data from dst_tmp_buf into dst_buf
...
@@ -403,32 +403,249 @@ struct vector_type<T, 16>
    }
};
+template <typename T>
+struct vector_type<T, 32>
+{
+    using d1_t = T;
+    typedef T d2_t __attribute__((ext_vector_type(2)));
+    typedef T d4_t __attribute__((ext_vector_type(4)));
+    typedef T d8_t __attribute__((ext_vector_type(8)));
+    typedef T d16_t __attribute__((ext_vector_type(16)));
+    typedef T d32_t __attribute__((ext_vector_type(32)));
+
+    using type = d32_t;
+
+    union
+    {
+        d32_t d32_;
+        StaticallyIndexedArray<d1_t, 32> d1x32_;
+        StaticallyIndexedArray<d2_t, 16> d2x16_;
+        StaticallyIndexedArray<d4_t, 8> d4x8_;
+        StaticallyIndexedArray<d8_t, 4> d8x4_;
+        StaticallyIndexedArray<d16_t, 2> d16x2_;
+        StaticallyIndexedArray<d32_t, 1> d32x1_;
+    } data_;
+
+    __host__ __device__ constexpr vector_type() : data_{type{0}} {}
+
+    __host__ __device__ constexpr vector_type(type v) : data_{v} {}
+
+    template <typename X>
+    __host__ __device__ constexpr const auto& AsType() const
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
+                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
+                          is_same<X, d16_t>::value || is_same<X, d32_t>::value,
+                      "wrong!");
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x32_;
+        }
+        else if constexpr(is_same<X, d2_t>::value)
+        {
+            return data_.d2x16_;
+        }
+        else if constexpr(is_same<X, d4_t>::value)
+        {
+            return data_.d4x8_;
+        }
+        else if constexpr(is_same<X, d8_t>::value)
+        {
+            return data_.d8x4_;
+        }
+        else if constexpr(is_same<X, d16_t>::value)
+        {
+            return data_.d16x2_;
+        }
+        else if constexpr(is_same<X, d32_t>::value)
+        {
+            return data_.d32x1_;
+        }
+    }
+
+    template <typename X>
+    __host__ __device__ constexpr auto& AsType()
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
+                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
+                          is_same<X, d16_t>::value || is_same<X, d32_t>::value,
+                      "wrong!");
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x32_;
+        }
+        else if constexpr(is_same<X, d2_t>::value)
+        {
+            return data_.d2x16_;
+        }
+        else if constexpr(is_same<X, d4_t>::value)
+        {
+            return data_.d4x8_;
+        }
+        else if constexpr(is_same<X, d8_t>::value)
+        {
+            return data_.d8x4_;
+        }
+        else if constexpr(is_same<X, d16_t>::value)
+        {
+            return data_.d16x2_;
+        }
+        else if constexpr(is_same<X, d32_t>::value)
+        {
+            return data_.d32x1_;
+        }
+    }
+};
+
+template <typename T>
+struct vector_type<T, 64>
+{
+    using d1_t = T;
+    typedef T d2_t __attribute__((ext_vector_type(2)));
+    typedef T d4_t __attribute__((ext_vector_type(4)));
+    typedef T d8_t __attribute__((ext_vector_type(8)));
+    typedef T d16_t __attribute__((ext_vector_type(16)));
+    typedef T d32_t __attribute__((ext_vector_type(32)));
+    typedef T d64_t __attribute__((ext_vector_type(64)));
+
+    using type = d64_t;
+
+    union
+    {
+        d64_t d64_;
+        StaticallyIndexedArray<d1_t, 64> d1x64_;
+        StaticallyIndexedArray<d2_t, 32> d2x32_;
+        StaticallyIndexedArray<d4_t, 16> d4x16_;
+        StaticallyIndexedArray<d8_t, 8> d8x8_;
+        StaticallyIndexedArray<d16_t, 4> d16x4_;
+        StaticallyIndexedArray<d32_t, 2> d32x2_;
+        StaticallyIndexedArray<d64_t, 1> d64x1_;
+    } data_;
+
+    __host__ __device__ constexpr vector_type() : data_{type{0}} {}
+
+    __host__ __device__ constexpr vector_type(type v) : data_{v} {}
+
+    template <typename X>
+    __host__ __device__ constexpr const auto& AsType() const
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
+                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
+                          is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
+                          is_same<X, d64_t>::value,
+                      "wrong!");
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x64_;
+        }
+        else if constexpr(is_same<X, d2_t>::value)
+        {
+            return data_.d2x32_;
+        }
+        else if constexpr(is_same<X, d4_t>::value)
+        {
+            return data_.d4x16_;
+        }
+        else if constexpr(is_same<X, d8_t>::value)
+        {
+            return data_.d8x8_;
+        }
+        else if constexpr(is_same<X, d16_t>::value)
+        {
+            return data_.d16x4_;
+        }
+        else if constexpr(is_same<X, d32_t>::value)
+        {
+            return data_.d32x2_;
+        }
+        else if constexpr(is_same<X, d64_t>::value)
+        {
+            return data_.d64x1_;
+        }
+    }
+
+    template <typename X>
+    __host__ __device__ constexpr auto& AsType()
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
+                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
+                          is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
+                          is_same<X, d64_t>::value,
+                      "wrong!");
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x64_;
+        }
+        else if constexpr(is_same<X, d2_t>::value)
+        {
+            return data_.d2x32_;
+        }
+        else if constexpr(is_same<X, d4_t>::value)
+        {
+            return data_.d4x16_;
+        }
+        else if constexpr(is_same<X, d8_t>::value)
+        {
+            return data_.d8x8_;
+        }
+        else if constexpr(is_same<X, d16_t>::value)
+        {
+            return data_.d16x4_;
+        }
+        else if constexpr(is_same<X, d32_t>::value)
+        {
+            return data_.d32x2_;
+        }
+        else if constexpr(is_same<X, d64_t>::value)
+        {
+            return data_.d64x1_;
+        }
+    }
+};
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
+using float16_t = typename vector_type<float, 16>::type;
+using float32_t = typename vector_type<float, 32>::type;
+using float64_t = typename vector_type<float, 64>::type;

// fp16
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type;
+using half32_t = typename vector_type<half_t, 32>::type;
+using half64_t = typename vector_type<half_t, 64>::type;

// bfp16
using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type;
+using ushort16_t = typename vector_type<ushort, 16>::type;
+using ushort32_t = typename vector_type<ushort, 32>::type;
+using ushort64_t = typename vector_type<ushort, 64>::type;

// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
+using int32x16_t = typename vector_type<int32_t, 16>::type;
+using int32x32_t = typename vector_type<int32_t, 32>::type;
+using int32x64_t = typename vector_type<int32_t, 64>::type;

// i8
using int8x2_t = typename vector_type<int8_t, 2>::type;
using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
+using int8x32_t = typename vector_type<int8_t, 32>::type;
+using int8x64_t = typename vector_type<int8_t, 64>::type;

// data type conversion
template <typename T>
...
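The vector_type<T, 32> and vector_type<T, 64> specializations above expose the same registers either as one wide vector or as a StaticallyIndexedArray of narrower vectors through a union; holding the whole C tile in such wide vectors is what the commit message flags as causing register over-allocation. Below is a minimal standalone sketch of that aliasing trick, assuming clang's ext_vector_type and hypothetical names (`wide_view`, `quarter`); it is not the CK implementation.

```cpp
// Minimal standalone sketch (not CK code) of the union aliasing used by
// vector_type<T, 32>/<T, 64>: one storage object viewed either as a single wide
// vector or as an array of narrower vectors (union type punning, as CK does).
#include <cstdio>

typedef float float4_v __attribute__((ext_vector_type(4)));   // clang/HIP vector extension
typedef float float16_v __attribute__((ext_vector_type(16)));

union wide_view
{
    float16_v whole;      // the full 16-wide vector
    float4_v quarter[4];  // same bytes viewed as four 4-wide vectors (like the d4x8_ member above, at smaller width)
};

int main()
{
    wide_view v;
    v.whole = float16_v{0};                  // zero-initialize through the wide view

    v.quarter[2][1] = 3.0f;                  // write element 1 of the third 4-wide piece...
    std::printf("%f\n", v.whole[2 * 4 + 1]); // ...and read it back as element 9 of the wide view

    return 0;
}
```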