Commit 841b1480 authored by Chao Liu

replacing array with vector for tensor data

parent e4790c25
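In short, pieced together from the hunks below: thread-local tensor data moves from a StaticallyIndexedArray to the new StaticBuffer, which derives from a vector type, so element access changes from direct indexing to the vector_type AsType<T>() view. A before/after sketch using the names from ThreadwiseDynamicTensorSliceTransfer_v3 ('value' is illustrative, not from the diff):

// before: statically indexed array, indexed directly
StaticallyIndexedArray<SrcData, buffer_size_> buffer_;
buffer_(Number<buffer_offset>{}) = value;

// after: vector-backed StaticBuffer, accessed through the AsType<T>() view
StaticBuffer<SrcData, buffer_size_> buffer_;
buffer_.template AsType<SrcData>()(Number<buffer_offset>{}) = value;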
@@ -281,10 +281,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
             p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
             p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
-#pragma unroll
         // loop over rest of k
-        for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
-        {
+        static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
             // read A_sub_0
             a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, 0)),
                               p_a_thread);
@@ -324,7 +322,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
                 p_a_thread,
                 p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
                 p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
-        }
+        });
         // C_sub_10 += transpose(A_sub_1) * B_sub_0
         threadwise_gemm.Run(
......
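Aside: the hunks above also swap a runtime for loop (guarded by #pragma unroll) for CK's static_for, so the loop index k becomes a compile-time Number<> and the offset math in CalculateOffset(make_tuple(k, 0)) stays constexpr. A minimal illustrative sketch of such a compile-time loop, not CK's actual definition:

#include <cstdint>

using index_t = int32_t;

template <index_t I>
struct Number
{
    static constexpr index_t value = I;
};

// Illustrative recursive compile-time loop in the spirit of CK's
// static_for<Begin, End, Step>: invokes f(Number<I>{}) for
// I = Begin, Begin + Step, ... while I < End (Step > 0 assumed).
// Each iteration is a distinct instantiation, so the body is fully
// unrolled by construction and needs no #pragma unroll.
template <index_t Begin, index_t End, index_t Step>
struct static_for_sketch
{
    template <typename F>
    constexpr void operator()(F f) const
    {
        if constexpr(Begin < End)
        {
            f(Number<Begin>{});
            static_for_sketch<Begin + Step, End, Step>{}(f);
        }
    }
};

// Usage in the shape of the hunk above:
//   static_for_sketch<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
//       // k is a Number<I>, so it can feed constexpr offset computation
//   });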
@@ -265,7 +265,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         index_t b_block_data_begin = 0;
-#if 1
         if constexpr(HasMainKBlockLoop)
         {
             FloatAB* p_b_thread_even = p_b_thread_double;
@@ -350,9 +349,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                               p_b_thread_double,
                               p_c_thread);
         }
-#endif
-#if 1
         // output: register to global memory
         {
             // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
@@ -385,7 +382,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                 p_c_global,
                 c_k_n_ho_wo_global_tensor_iterator_hacks);
         }
-#endif
     }
     // pass tensor descriptor by reference
......
@@ -875,7 +875,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
             constexpr index_t buffer_offset =
                 buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);
-            buffer_(Number<buffer_offset>{}) = src_vector.template AsType<SrcData>()[i];
+            buffer_.template AsType<SrcData>()(Number<buffer_offset>{}) =
+                src_vector.template AsType<SrcData>()[i];
         });
         constexpr auto move_on_dim = [&]() constexpr
@@ -1032,7 +1033,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
             constexpr index_t buffer_offset =
                 buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
-            dst_vector.template AsType<DstData>()(i) = buffer_[Number<buffer_offset>{}];
+            dst_vector.template AsType<DstData>()(i) =
+                buffer_.template AsType<DstData>()[Number<buffer_offset>{}];
         });
         using DstVectorType =
@@ -1297,7 +1299,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
     static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
-    StaticallyIndexedArray<SrcData, buffer_size_> buffer_;
+    StaticBuffer<SrcData, buffer_size_> buffer_;
     SrcCoord src_slice_origin_coord_;
     DstCoord dst_slice_origin_coord_;
......
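The hunks above show the access-pattern change this commit is about: instead of indexing the buffer directly, reads go through AsType<T>()[...] and writes through AsType<T>()(...). A simplified sketch of that idea, assuming the storage is a native vector re-viewed as scalars; this is illustrative, not CK's actual vector_type:

#include <cstdint>

using index_t = int32_t;

template <index_t I>
struct Number
{
};

template <typename T, index_t N>
struct vector_type_sketch
{
    T data_[N]; // stand-in for the underlying native vector storage

    // A lightweight scalar view over the vector storage. Mirroring the CK
    // convention visible in the hunks: operator[] is the const (read)
    // accessor, operator() is the mutable (write) accessor.
    template <typename X>
    struct view
    {
        X* p_;

        template <index_t I>
        const X& operator[](Number<I>) const
        {
            return p_[I];
        }

        template <index_t I>
        X& operator()(Number<I>) const
        {
            return p_[I];
        }
    };

    template <typename X>
    view<X> AsType()
    {
        static_assert(sizeof(X) == sizeof(T), "sketch: same-size views only");
        return view<X>{reinterpret_cast<X*>(data_)};
    }
};

// Usage in the shape of the new code above:
//   buf.AsType<float>()(Number<0>{}) = x;       // write
//   float y = buf.AsType<float>()[Number<0>{}]; // read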
new file buffer.hpp:
+#ifndef CK_BUFFER_HPP
+#define CK_BUFFER_HPP
+
+#include "float_type.hpp"
+
+namespace ck {
+
+template <typename T, index_t N>
+struct StaticBuffer : public vector_type_maker<T, N>::type
+{
+    using base = typename vector_type_maker<T, N>::type;
+
+    __host__ __device__ constexpr StaticBuffer() : base{} {}
+};
+
+template <typename T, index_t N>
+__host__ __device__ constexpr auto make_static_buffer(Number<N>)
+{
+    return StaticBuffer<T, N>{};
+}
+
+} // namespace ck
+#endif
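A hypothetical usage sketch of the new header, assuming only what the transfer code above demonstrates (AsType<T>() with operator() for writes and operator[] for reads); the example function is made up for illustration:

#include "buffer.hpp"

using namespace ck;

// hypothetical example, not part of the commit
__host__ __device__ void static_buffer_example()
{
    // StaticBuffer<float, 8>: storage comes from vector_type_maker, i.e.
    // a vector type rather than a plain scalar array
    auto buf = make_static_buffer<float>(Number<8>{});

    // element access goes through the vector_type AsType<T>() view,
    // matching the ThreadwiseDynamicTensorSliceTransfer_v3 hunks above
    buf.template AsType<float>()(Number<0>{}) = 1.0f;     // write
    float x = buf.template AsType<float>()[Number<0>{}];  // read
    (void)x;
}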
@@ -7,6 +7,7 @@
 #include "statically_indexed_array.hpp"
 #include "container_element_picker.hpp"
 #include "float_type.hpp"
+#include "buffer.hpp"
 #include "functional.hpp"
 #include "functional2.hpp"
 #include "functional3.hpp"
......