Commit dcee43fe authored by Chao Liu

using tuple (instead of vector) for holding C thread matrix data to solve register over-allocation issue
parent aeb05cc4
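To make the intent of the commit concrete before the diff itself, here is a minimal, host-only C++ sketch of the idea: holding the per-thread C accumulators in a tuple-like container that is only ever indexed with compile-time constants lets the compiler treat every element as an independent scalar, which is friendlier to register allocation than a raw array indexed at run time. This is an illustrative sketch only; names such as `StaticBufferSketch` and `tuple_of_n_impl` are not the repo's actual `StaticBuffer` / `StaticallyIndexedArray` types.

```cpp
// Illustrative sketch only (assumed names, host-only), not the repo's code:
// tuple-backed storage accessed exclusively through compile-time indices.
#include <cstdio>
#include <tuple>
#include <utility>

// Map every index of the pack to T, so tuple_of_n_impl<T, ...4...> yields tuple<T, T, T, T>.
template <std::size_t, typename T>
using repeat_t = T;

template <typename T, typename Seq>
struct tuple_of_n_impl;

template <typename T, std::size_t... Is>
struct tuple_of_n_impl<T, std::index_sequence<Is...>>
{
    using type = std::tuple<repeat_t<Is, T>...>;
};

template <typename T, std::size_t N>
struct StaticBufferSketch
{
    typename tuple_of_n_impl<T, std::make_index_sequence<N>>::type data_{};

    // Elements can only be reached with a compile-time index.
    template <std::size_t I>
    T& At() { return std::get<I>(data_); }

    template <std::size_t I>
    const T& At() const { return std::get<I>(data_); }
};

int main()
{
    StaticBufferSketch<float, 4> acc{}; // value-initialized: all zeros

    // Fully unrolled, compile-time-indexed accumulation.
    acc.At<0>() += 1.0f;
    acc.At<1>() += 2.0f;
    acc.At<2>() += 3.0f;
    acc.At<3>() += 4.0f;

    std::printf("%g %g %g %g\n", acc.At<0>(), acc.At<1>(), acc.At<2>(), acc.At<3>());
    return 0;
}
```

Because no element of `data_` is ever addressed through a runtime index, the compiler never has to materialize the whole buffer as addressable storage, which is the over-allocation problem the commit message refers to.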
@@ -731,14 +731,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;

         // register allocation for output
-#if 0
-        FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
-
-        auto c_thread_buf = make_dynamic_buffer<FloatAcc>(p_c_thread);
-
-        // zero out threadwise output
-        threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
-#else
         auto c_thread_buf =
             make_static_buffer<FloatAcc>(c_m0m1_n0n1_thread_desc.GetElementSpaceSize());
@@ -746,7 +738,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                                                decltype(c_m0m1_n0n1_thread_desc),
                                                Sequence<MRepeat * MPerThread, NRepeat * NPerThread>>{}
             .Run(c_m0m1_n0n1_thread_desc, make_tuple(I0, I0), c_thread_buf, FloatAcc{0});
-#endif

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
...
@@ -49,7 +49,7 @@ struct ThreadwiseDynamicTensorSliceSet_v1
             if constexpr(is_valid)
             {
-                buf.template AsType<Data>()(Number<offset>{}) = initial_value;
+                buf(Number<offset>{}) = initial_value;
             }
         });
     }
...
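The one-line change above works because the buffer's `operator()` now takes a compile-time `Number<offset>` directly. A self-contained, host-only sketch of the same pattern (assumed names, standard C++ only, not the repo's `static_for` / `Number` types) shows how a "set everything to `initial_value`" loop becomes a fully unrolled compile-time loop over tuple-backed storage:

```cpp
// Sketch of a compile-time "set" loop over tuple-backed storage (assumed names).
#include <cstdio>
#include <tuple>
#include <type_traits>
#include <utility>

// Compile-time loop: calls f(integral_constant<0>{}), ..., f(integral_constant<N-1>{}).
template <typename F, std::size_t... Is>
void static_for_impl(F&& f, std::index_sequence<Is...>)
{
    (f(std::integral_constant<std::size_t, Is>{}), ...);
}

template <std::size_t N, typename F>
void static_for_sketch(F&& f)
{
    static_for_impl(f, std::make_index_sequence<N>{});
}

int main()
{
    std::tuple<float, float, float> buf{1.0f, 2.0f, 3.0f};

    // Zero out every element using only compile-time offsets.
    static_for_sketch<3>([&](auto offset) { std::get<decltype(offset)::value>(buf) = 0.0f; });

    std::printf("%g %g %g\n", std::get<0>(buf), std::get<1>(buf), std::get<2>(buf));
    return 0;
}
```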
@@ -98,6 +98,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
             is_known_at_compile_time<remove_cv_t<remove_reference_t<SrcSliceOriginIdx>>>::value,
             "wrong! SrcSliceOrigin need to known at compile-time");

+        static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
+                              remove_cv_t<remove_reference_t<SrcData>>>::value,
+                      "wrong! SrcBuffer data type is wrong");
+
         // SrcDesc and src_slice_origin_idx are known at compile-time
         constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
         constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
@@ -195,8 +199,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
                 constexpr index_t src_offset = src_desc.CalculateOffset(
                     src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);

-                dst_vector.template AsType<DstData>()(i) = type_convert<DstData>{}(
-                    src_buf.template AsType<SrcData>()[Number<src_offset>{}]);
+                dst_vector.template AsType<DstData>()(i) =
+                    type_convert<DstData>{}(src_buf[Number<src_offset>{}]);
             });

             const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
@@ -1315,7 +1319,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
     static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();

-    StaticBuffer<SrcData, buffer_size_> buffer_;
+    typename vector_type_maker<SrcData, buffer_size_>::type buffer_;

     SrcCoord src_slice_origin_coord_;
     DstCoord dst_slice_origin_coord_;
@@ -1381,6 +1385,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
         static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc and DstDesc need to known at compile-time");

+        static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
+                              remove_cv_t<remove_reference_t<SrcData>>>::value &&
+                          is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
+                                  remove_cv_t<remove_reference_t<DstData>>>::value,
+                      "wrong! SrcBuffer or DstBuffer data type is wrong");
+
         static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");

         static_assert(is_known_at_compile_time<
@@ -1462,46 +1472,45 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
                     src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);

                 // copy data from src_buf into src_tmp_buffer
-                auto src_tmp_buf = make_static_buffer<SrcData>(Number<SrcScalarPerVector>{});
-                using src_vector_t =
-                    typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
+                vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
+                using src_vector_t = typename decltype(src_tmp_vector)::type;

                 const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
                     src_desc, src_data_coord);

 #if 0
                 // TODO: this is slooooooooow due to VGPR over-allocation
-                src_tmp_buf.template AsType<src_vector_t>()(Number<0>{}) =
+                src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                     is_src_valid ? src_buf.template AsType<src_vector_t>()[src_data_coord.GetOffset() /
                                                                            SrcScalarPerVector]
                                  : src_vector_t{0};
 #else
                 // TODO: this is workaround. this has normal performance but it's hacky
-                src_tmp_buf.template AsType<src_vector_t>()(Number<0>{}) =
+                src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                     is_src_valid
                         ? *reinterpret_cast<const src_vector_t*>(&(reinterpret_cast<const SrcData*>(
-                              src_buf.p_scalar_)[src_data_coord.GetOffset()]))
+                              src_buf.p_data_)[src_data_coord.GetOffset()]))
                         : src_vector_t{0};
 #endif

-                // copy data from src_tmp_buf to dst_tmp_buf (data cast data from SrcData to DstData)
-                auto dst_tmp_buf = make_static_buffer<DstData>(Number<SrcScalarPerVector>{});
+                // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
+                // DstData)
+                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;

                 // TODO: if SrcData and DstData are vetor type, then static_cast may not compile
                 static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
-                    dst_tmp_buf.template AsType<DstData>()(i) =
-                        type_convert<DstData>{}(src_tmp_buf.template AsType<SrcData>()[i]);
+                    dst_tmp_vector.template AsType<DstData>()(i) =
+                        type_convert<DstData>{}(src_tmp_vector.template AsType<SrcData>()[i]);
                 });

-                // copy data from dst_tmp_buf into dst_buf
+                // copy data from dst_tmp_vector into dst_buf
                 static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                     constexpr index_t dst_offset = dst_desc.CalculateOffset(
                         to_multi_index(dst_ref_to_origin_disp_idx) + data_to_origin_disp_idx +
                         i * src_scalar_step_in_vector);

-                    dst_buf.template AsType<DstData>()(Number<dst_offset>{}) =
-                        dst_tmp_buf.template AsType<DstData>()[i];
+                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
                 });
             });
     }
...
@@ -6,6 +6,7 @@
 namespace ck {

+#if 1
 template <typename Float, typename Desc>
 __device__ void threadwise_matrix_set_zero_v2(Desc, Float* __restrict__ p_thread)
 {
@@ -167,6 +168,7 @@ struct ThreadwiseGemm_km_kn_mn_v1
 #endif
     }
 };
+#endif

 // C[M, N] += transpose(A[K, M]) * B[K, N]
 // Element of matrix can be vectorized data
@@ -231,14 +233,12 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
                     CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));

 #if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
-                amd_assembly_inner_product(a_buf.template AsType<FloatA>()[Number<a_offset>{}],
-                                           b_buf.template AsType<FloatB>()[Number<b_offset>{}],
-                                           c_buf.template AsType<FloatC>()(Number<c_offset>{}));
+                amd_assembly_inner_product(a_buf[Number<a_offset>{}],
+                                           b_buf[Number<b_offset>{}],
+                                           c_buf(Number<c_offset>{}));
 #else
-                c_buf.template AsType<FloatC>()(Number<c_offset>{}) +=
-                    inner_product_with_conversion<FloatC>{}(
-                        a_buf.template AsType<FloatA>()[Number<a_offset>{}],
-                        b_buf.template AsType<FloatB>()[Number<b_offset>{}]);
+                c_buf(Number<c_offset>{}) += inner_product_with_conversion<FloatC>{}(
+                    a_buf[Number<a_offset>{}], b_buf[Number<b_offset>{}]);
 #endif
             });
         });
...
@@ -23,6 +23,20 @@ __device__ void amd_assembly_inner_product(const float& a, const float& b, float
 #endif
 }

+__device__ void amd_assembly_inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c)
+{
+#if 1
+    asm volatile("\n \
+    v_dot4_i32_i8 %0, %1, %2, %0\n \
+    "
+                 : "=v"(c)
+                 : "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
+#else
+    c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
+#endif
+}
+
+#if 0
 // c0 += inner_product(a, b0)
 // c1 += inner_product(a, b1)
 __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
@@ -386,6 +400,7 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
                                                c2,
                                                c3);
 }
+#endif

 } // namespace ck
 #endif
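For reference, both branches of the `amd_assembly_inner_product` overload added above (the `v_dot4_i32_i8` inline asm and the `__builtin_amdgcn_sdot4` builtin) compute a 4-way signed 8-bit dot product accumulated into a 32-bit integer. A plain scalar model of that arithmetic, as a host-only sketch under that assumption (not the repo's device code):

```cpp
// Scalar reference sketch of the dot4-i8-into-i32 semantics used above.
#include <cstdint>
#include <cstdio>

// Returns c + a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3],
// with int8 operands widened to int32 before multiplying.
std::int32_t dot4_i32_i8_reference(const std::int8_t a[4], const std::int8_t b[4], std::int32_t c)
{
    for (int i = 0; i < 4; ++i)
    {
        c += static_cast<std::int32_t>(a[i]) * static_cast<std::int32_t>(b[i]);
    }
    return c;
}

int main()
{
    const std::int8_t a[4] = {1, -2, 3, -4};
    const std::int8_t b[4] = {5, 6, -7, 8};

    // 5 - 12 - 21 - 32 plus accumulator 10 = -50
    std::printf("%d\n", dot4_i32_i8_reference(a, b, 10));
    return 0;
}
```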
 #ifndef CK_BUFFER_HPP
 #define CK_BUFFER_HPP

-#include "float_type.hpp"
+#include "statically_indexed_array.hpp"

 namespace ck {

-template <
-    typename ScalarType,
-    index_t N,
-    typename std::enable_if<is_same<typename scalar_type<ScalarType>::type, ScalarType>::value,
-                            bool>::type = false>
-struct StaticBuffer : public vector_type<ScalarType, N>
+template <typename T, index_t N>
+struct StaticBuffer : public StaticallyIndexedArray<T, N>
 {
-    using base = vector_type<ScalarType, N>;
+    using type = T;
+    using base = StaticallyIndexedArray<T, N>;

     __host__ __device__ constexpr StaticBuffer() : base{} {}
@@ -24,50 +21,46 @@ struct StaticBuffer : public vector_type<ScalarType, N>
 template <typename T, index_t N>
 __host__ __device__ constexpr auto make_static_buffer(Number<N>)
 {
-    using scalar_t = typename scalar_type<T>::type;
-    constexpr index_t scalar_per_vector = scalar_type<T>::vector_size;
-    return StaticBuffer<scalar_t, N * scalar_per_vector>{};
+    return StaticBuffer<T, N>{};
 }

-template <
-    typename ScalarType,
-    typename std::enable_if<is_same<typename scalar_type<ScalarType>::type, ScalarType>::value,
-                            bool>::type = false>
+template <typename T>
 struct DynamicBuffer
 {
-    template <typename T>
+    using type = T;
+
+    template <typename X>
     struct PointerWrapper
     {
-        T* p_;
+        X* p_;

-        __host__ __device__ constexpr const T& operator[](index_t i) const { return p_[i]; }
+        __host__ __device__ constexpr const X& operator[](index_t i) const { return p_[i]; }

-        __host__ __device__ constexpr T& operator()(index_t i) { return p_[i]; }
+        __host__ __device__ constexpr X& operator()(index_t i) { return p_[i]; }
     };

-    ScalarType* p_scalar_;
+    T* p_data_;

-    __host__ __device__ constexpr DynamicBuffer(ScalarType* p_scalar) : p_scalar_{p_scalar} {}
+    __host__ __device__ constexpr DynamicBuffer(T* p_data) : p_data_{p_data} {}

     template <typename X,
               typename std::enable_if<
                   is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
-                          ScalarType>::value,
+                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                   bool>::type = false>
     __host__ __device__ constexpr const auto AsType() const
     {
-        return PointerWrapper<X>{reinterpret_cast<X*>(p_scalar_)};
+        return PointerWrapper<X>{reinterpret_cast<X*>(p_data_)};
     }

     template <typename X,
               typename std::enable_if<
                   is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
-                          ScalarType>::value,
+                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                   bool>::type = false>
     __host__ __device__ constexpr auto AsType()
     {
-        return PointerWrapper<X>{reinterpret_cast<X*>(p_scalar_)};
+        return PointerWrapper<X>{reinterpret_cast<X*>(p_data_)};
     }

     __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
@@ -78,10 +71,7 @@ struct DynamicBuffer
 template <typename T>
 __host__ __device__ constexpr auto make_dynamic_buffer(T* p)
 {
-    using scalar_t = typename scalar_type<T>::type;
-    constexpr index_t scalar_per_vector = scalar_type<T>::vector_size;
-    return DynamicBuffer<scalar_t>{p};
+    return DynamicBuffer<T>{p};
 }

 } // namespace ck
...
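The `DynamicBuffer` rewrite above keeps the pointer-wrapper pattern: the buffer stores a plain `T*` and `AsType<X>()` hands back the same storage reinterpreted as a wider vector type so callers can move several scalars per access. A simplified, host-only sketch of that pattern (assumed names, plain standard C++, not the library's types; the real code additionally constrains `X` via `scalar_type<>` in an `enable_if`, while this sketch just assumes compatible layout and alignment):

```cpp
// Simplified sketch of a dynamic buffer exposing its storage as a vector type.
#include <cstdio>

template <typename T>
struct DynamicBufferSketch
{
    T* p_data_;

    template <typename X>
    X* AsType() const
    {
        // Mirrors the library's reinterpretation; assumes X is a vector of T
        // with suitable alignment (an assumption of this sketch).
        return reinterpret_cast<X*>(p_data_);
    }
};

// Stand-in for a 4-wide float vector type such as float4 / ext_vector_type(4).
struct float4_sketch
{
    float x, y, z, w;
};

int main()
{
    float storage[8] = {0, 1, 2, 3, 4, 5, 6, 7};

    DynamicBufferSketch<float> buf{storage};

    // One "vectorized" access covering scalars 4..7.
    const float4_sketch v = buf.AsType<float4_sketch>()[1];

    std::printf("%g %g %g %g\n", v.x, v.y, v.z, v.w); // prints: 4 5 6 7
    return 0;
}
```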
 #ifndef CK_FLOAT_TYPE_AMD_HPP
 #define CK_FLOAT_TYPE_AMD_HPP

+#include "statically_indexed_array.hpp"

 namespace ck {

 using half_t = _Float16;
@@ -43,6 +45,15 @@ struct vector_type_maker<vector_type<T, N1>, N0>
     using type = vector_type<T, N0 * N1>;
 };

+template <typename T, index_t N>
+using vector_type_maker_t = typename vector_type_maker<T, N>::type;
+
+template <typename T, index_t N>
+__host__ __device__ constexpr auto make_vector_type(Number<N>)
+{
+    return typename vector_type_maker<T, N>::type{};
+}

 // scalar_type
 template <typename TV>
 struct scalar_type;
...