Commit 712babe4 authored by Chao Liu

replacing array with vector for tensor data

parent 03f7892a
@@ -546,6 +546,9 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
 FloatA p_a_thread[a_thread_mtx_desc_.GetElementSpaceSize()];
 FloatB p_b_thread[b_thread_mtx_desc_.GetElementSpaceSize()];
+auto a_thread_buf = make_dynamic_buffer<FloatA>(p_a_thread);
+auto b_thread_buf = make_dynamic_buffer<FloatB>(p_b_thread);
 constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<FloatA,
 FloatB,
 FloatC,
@@ -559,7 +562,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
 p_a_block,
 a_thread_mtx_desc_,
 make_tuple(Number<0>{}, Number<0>{}),
-p_a_thread);
+a_thread_buf);
 // read B_sub_0
 b_thread_copy_.Run(BlockMatrixB{},
@@ -567,7 +570,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
 p_b_block,
 b_thread_mtx_desc_,
 make_tuple(Number<0>{}, Number<0>{}),
-p_b_thread);
+b_thread_buf);
 // read B_sub_1
 b_thread_copy_.Run(BlockMatrixB{},
@@ -575,7 +578,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
 p_b_block,
 b_thread_mtx_desc_,
 make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
-p_b_thread);
+b_thread_buf);
 // read A_sub_1
 a_thread_copy_.Run(BlockMatrixA{},
@@ -583,7 +586,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
 p_a_block,
 a_thread_mtx_desc_,
 make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
-p_a_thread);
+a_thread_buf);
 // C_sub_00 += transpose(A_sub_0) * B_sub_0
 threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
@@ -602,7 +605,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
 p_a_block,
 a_thread_mtx_desc_,
 make_tuple(Number<0>{}, Number<0>{}),
-p_a_thread);
+a_thread_buf);
 // C_sub_10 += transpose(A_sub_1) * B_sub_0
 threadwise_gemm.Run(
@@ -616,7 +619,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
 p_b_block,
 b_thread_mtx_desc_,
 make_tuple(Number<0>{}, Number<0>{}),
-p_b_thread);
+b_thread_buf);
 // C_sub_11 += transpose(A_sub_1) * B_sub_1
 threadwise_gemm.Run(
@@ -631,7 +634,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
 p_b_block,
 b_thread_mtx_desc_,
 make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
-p_b_thread);
+b_thread_buf);
 // read A_sub_1
 a_thread_copy_.Run(BlockMatrixA{},
@@ -639,7 +642,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
 p_a_block,
 a_thread_mtx_desc_,
 make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
-p_a_thread);
+a_thread_buf);
 // C_sub_00 += transpose(A_sub_0) * B_sub_0
 threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
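The blockwise GEMM change above is one pattern applied repeatedly: the thread-local arrays remain as backing storage, but each threadwise copy now receives a buffer object wrapping that storage instead of the raw pointer. A minimal sketch of the idea, with simplified stand-in types (`DynamicBufferSketch` and `make_dynamic_buffer_sketch` are illustrative names, not the library's API):

```cpp
#include <cstdio>

// Simplified stand-in for the library's DynamicBuffer: it owns no memory,
// just wraps a raw pointer and exposes element access through the wrapper.
template <typename Scalar>
struct DynamicBufferSketch
{
    Scalar* p_scalar_;

    Scalar& operator()(int i) { return p_scalar_[i]; }
};

template <typename T>
constexpr auto make_dynamic_buffer_sketch(T* p)
{
    return DynamicBufferSketch<T>{p};
}

int main()
{
    // before: copies wrote through the raw array p_a_thread directly
    float p_a_thread[4] = {};

    // after: the same storage is wrapped once, and the wrapper is what
    // gets passed to Run() in place of the bare pointer
    auto a_thread_buf = make_dynamic_buffer_sketch<float>(p_a_thread);
    a_thread_buf(0)   = 1.0f;

    std::printf("%f\n", p_a_thread[0]); // prints 1.000000: same storage
    return 0;
}
```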
@@ -52,7 +52,7 @@ __global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_gl
 }
 #endif
-#if 1
+#if 0
 template <index_t BlockSize,
 typename FloatAB,
 typename FloatAcc,
@@ -1362,13 +1362,15 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
 "wrong! SrcDesc and DstDesc need to be known at compile-time");
 }
-template <typename SrcRefToOriginDisplacement, typename DstRefToOriginDisplacement>
+template <typename SrcRefToOriginDisplacement,
+          typename DstRefToOriginDisplacement,
+          typename DstBuffer>
 __device__ void Run(const SrcDesc&,
 const SrcRefToOriginDisplacement&,
 const SrcData* p_src,
 const DstDesc&,
 const DstRefToOriginDisplacement&,
-DstData* p_dst) const
+DstBuffer dst_buf) const
 {
 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
 "wrong! SrcDesc and DstDesc need to be known at compile-time");
@@ -1450,8 +1452,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
 move_dynamic_tensor_coordinate(
 src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
-// copy data from src into buffer
-StaticBuffer<SrcData, SrcScalarPerVector> src_buf;
+// copy data from src into src_tmp_buf
+auto src_tmp_buf = make_static_buffer<SrcData>(Number<SrcScalarPerVector>{});
 using src_vector_t =
 typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
@@ -1459,18 +1461,28 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
 const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
 src_desc, src_data_coord);
-src_buf.template AsType<src_vector_t>()(Number<0>{}) =
+src_tmp_buf.template AsType<src_vector_t>()(Number<0>{}) =
 is_src_valid
 ? *reinterpret_cast<const src_vector_t*>(&p_src[src_data_coord.GetOffset()])
 : src_vector_t{0};
-// copy data from buffer into dst
+// copy data from src_tmp_buf to dst_tmp_buf (cast data from SrcData to DstData)
+auto dst_tmp_buf = make_static_buffer<DstData>(Number<SrcScalarPerVector>{});
+// TODO: if SrcData and DstData are vector types, then static_cast may not compile
+static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+    dst_tmp_buf.template AsType<DstData>()(i) =
+        static_cast<DstData>(src_tmp_buf.template AsType<SrcData>()[i]);
+});
+// copy data from dst_tmp_buf into dst_buf
 static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
 constexpr index_t dst_offset = dst_desc.CalculateOffset(
 to_multi_index(dst_ref_to_origin_disp_idx) + data_to_origin_disp_idx +
 i * src_scalar_step_in_vector);
-p_dst[Number<dst_offset>{}] = src_buf.template AsType<SrcData>()[i];
+dst_buf.template AsType<DstData>()(Number<dst_offset>{}) =
+    dst_tmp_buf.template AsType<DstData>()[i];
 });
 });
 }
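The reworked Run() above now moves data in three stages: a guarded vector load from p_src into src_tmp_buf, an elementwise static_cast from SrcData to DstData into dst_tmp_buf, and a final store from dst_tmp_buf into dst_buf. A minimal sketch of that pipeline, assuming plain arrays in place of StaticBuffer and a runtime loop in place of static_for (names and layout are simplified for illustration):

```cpp
#include <cstddef>

// Simplified three-stage copy: load, cast, store. The real code computes a
// tensor-coordinate offset per element; here dst is assumed contiguous.
template <typename SrcData, typename DstData, std::size_t ScalarPerVector>
void copy_cast_sketch(const SrcData* p_src, DstData* p_dst)
{
    SrcData src_tmp_buf[ScalarPerVector]; // stage 1 target
    DstData dst_tmp_buf[ScalarPerVector]; // stage 2 target

    // stage 1: copy from src into src_tmp_buf (a vector load in the library)
    for(std::size_t i = 0; i < ScalarPerVector; ++i)
        src_tmp_buf[i] = p_src[i];

    // stage 2: elementwise cast from SrcData to DstData
    for(std::size_t i = 0; i < ScalarPerVector; ++i)
        dst_tmp_buf[i] = static_cast<DstData>(src_tmp_buf[i]);

    // stage 3: copy dst_tmp_buf into the destination buffer
    for(std::size_t i = 0; i < ScalarPerVector; ++i)
        p_dst[i] = dst_tmp_buf[i];
}
```

The intermediate dst_tmp_buf is what lets a transfer between different element types (e.g. float source, half destination) stay a plain per-element cast instead of a reinterpretation of raw bytes.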
@@ -20,7 +20,7 @@ struct StaticBuffer : public vector_type<ScalarType, N>
 template <typename T, index_t N>
 __host__ __device__ constexpr auto make_static_buffer(Number<N>)
 {
-using scalar_t = scalar_type<T>;
+using scalar_t = typename scalar_type<T>::type;
 constexpr index_t scalar_per_vector = scalar_type<T>::vector_size;
 return StaticBuffer<scalar_t, N * scalar_per_vector>{};
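The one-line fix in make_static_buffer (and in make_dynamic_buffer below) is the same: scalar_type<T> is a trait class, not a type, so the scalar element type has to be pulled out of its nested ::type, and typename is required because the name is dependent on T. A hedged sketch of such a trait (an illustrative definition, not the library's exact one):

```cpp
// scalar_type-style trait: plain scalars map to themselves with one lane; a
// vector specialization maps a packed type to its element type and lane count.
template <typename T>
struct scalar_type_sketch
{
    using type                       = T;
    static constexpr int vector_size = 1;
};

// hypothetical 4-wide float vector, to show what a specialization covers
struct float4_sketch
{
    float data[4];
};

template <>
struct scalar_type_sketch<float4_sketch>
{
    using type                       = float;
    static constexpr int vector_size = 4;
};

template <typename T>
constexpr int total_scalars(int n)
{
    // 'typename ...::type' is the pattern the commit fixes; writing
    // 'scalar_type_sketch<T>' alone would name the trait, not the scalar.
    using scalar_t = typename scalar_type_sketch<T>::type;
    static_assert(sizeof(scalar_t) > 0, "scalar type must be complete");
    return n * scalar_type_sketch<T>::vector_size;
}
```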
@@ -51,7 +51,7 @@ struct DynamicBuffer
 is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
 ScalarType>::value,
 bool>::type = false>
-__host__ __device__ constexpr const auto& AsType() const
+__host__ __device__ constexpr const auto AsType() const
 {
 return PointerWrapper<X>{reinterpret_cast<X*>(p_scalar_)};
 }
@@ -61,7 +61,7 @@ struct DynamicBuffer
 is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
 ScalarType>::value,
 bool>::type = false>
-__host__ __device__ constexpr auto& AsType()
+__host__ __device__ constexpr auto AsType()
 {
 return PointerWrapper<X>{reinterpret_cast<X*>(p_scalar_)};
 }
@@ -70,7 +70,7 @@
 template <typename T>
 __host__ __device__ constexpr auto make_dynamic_buffer(T* p)
 {
-using scalar_t = scalar_type<T>;
+using scalar_t = typename scalar_type<T>::type;
 constexpr index_t scalar_per_vector = scalar_type<T>::vector_size;
 return DynamicBuffer<scalar_t>{p};
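Dropping the & from AsType's auto& / const auto& return types is a correctness fix, not a style one: PointerWrapper<X>{...} is a temporary, so returning a reference to it leaves the caller with a dangling reference (and the non-const auto& form does not even compile, since a non-const lvalue reference cannot bind to a prvalue). A minimal reproduction of the bug pattern with a trivial wrapper type:

```cpp
struct Wrapper
{
    float* p;
};

struct BufferSketch
{
    float* p_scalar_;

    // broken (the old pattern): the Wrapper temporary dies at the end of the
    // return statement, so the returned reference dangles immediately
    // const auto& AsTypeBad() const { return Wrapper{p_scalar_}; }

    // fixed (what this commit does): return the small wrapper by value
    constexpr auto AsType() const { return Wrapper{p_scalar_}; }
};
```

Since the wrapper holds only a pointer, returning it by value costs nothing while removing the undefined behavior.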