Commit 5f0c56d0 authored by Chao Liu's avatar Chao Liu
Browse files

refactor vector_type

parent 20fa988f
......@@ -186,7 +186,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
i * dst_scalar_step_in_vector);
dst_vector(i) = p_src[Number<src_offset>{}];
dst_vector.Scalars()(i) = p_src[Number<src_offset>{}];
});
amd_buffer_store_v2<DstData, DstScalarPerVector>(
......@@ -837,7 +837,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
constexpr index_t buffer_offset =
buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);
buffer_(Number<buffer_offset>{}) = src_vector[i];
buffer_(Number<buffer_offset>{}) = src_vector.Scalars()[i];
});
constexpr auto move_on_dim = [&]() constexpr
......@@ -995,7 +995,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
constexpr index_t buffer_offset =
buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
dst_vector(i) = buffer_[Number<buffer_offset>{}];
dst_vector.Scalars()(i) = buffer_[Number<buffer_offset>{}];
});
using DstVectorType = typename vector_type<DstData, DstScalarPerVector>::MemoryType;
......
......@@ -188,28 +188,28 @@ __device__ float8_t amd_buffer_load_v2<float, 8>(const float* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
vector_type<float, 8> vector;
vector_type<float, 8> tmp;
vector.Set(Number<4>{}, Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource.data, src_addr_shift + src_thread_addr_offset, 0, 0);
vector.Set(Number<4>{}, Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource.data,
src_addr_shift + src_thread_addr_offset + 4 * sizeof(float),
0,
0);
return vector.Get(Number<8>{}, Number<0>{});
return tmp.Vector();
#else
vector_type<float, 8> vector;
vector_type<float, 8> tmp;
vector.Set(Number<4>{}, Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource.data, src_thread_addr_offset, 0, 0);
vector.Set(Number<4>{}, Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource.data, src_thread_addr_offset + 4 * sizeof(float), 0, 0);
return src_thread_data_valid ? vector.Get(Number<8>{}, Number<0>{}) : float8_t(0);
return src_thread_data_valid ? tmp.Vector() : float8_t(0);
#endif
}
......
......@@ -177,47 +177,27 @@ struct vector_type<T, 1>
{
using MemoryType = T;
T data_;
union
{
T d1_;
StaticallyIndexedArray<T, 1> d1x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{T{0}} {}
__host__ __device__ static constexpr index_t Size() { return 1; }
__host__ __device__ constexpr const auto& Vector() const { return data_; }
__host__ __device__ constexpr auto& Vector() { return data_; }
template <index_t I>
__host__ __device__ constexpr const auto& Get(Number<1>, Number<I>) const
{
static_assert(I == 0, "wrong!");
__host__ __device__ constexpr const auto& Vector() const { return data_.d1_; }
return data_;
}
template <index_t I>
__host__ __device__ constexpr auto& Set(Number<1>, Number<I>)
{
static_assert(I == 0, "wrong!");
return data_;
}
__host__ __device__ constexpr auto& Vector() { return data_.d1_; }
template <index_t I>
__host__ __device__ constexpr const auto& operator[](Number<I>) const
{
static_assert(I == 0, "wrong!");
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x1_; }
return data_;
}
__host__ __device__ constexpr auto& Scalars() { return data_.d1x1_; }
template <index_t I>
__host__ __device__ constexpr auto& operator()(Number<I>)
{
static_assert(I == 0, "wrong!");
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x1_; }
return data_;
}
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x1_; }
};
template <typename T>
......@@ -232,6 +212,7 @@ struct vector_type<T, 2>
{
d2_t d2_;
StaticallyIndexedArray<d1_t, 2> d1x2_;
StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{d2_t{0}} {}
......@@ -242,53 +223,17 @@ struct vector_type<T, 2>
__host__ __device__ constexpr auto& Vector() { return data_.d2_; }
template <index_t I>
__host__ __device__ constexpr const auto& Get(Number<1>, Number<I> i) const
{
static_assert(I >= 0 && I < 2, "wrong!");
return data_.d1x2_[i];
}
template <index_t I>
__host__ __device__ constexpr const auto& Get(Number<2>, Number<I>) const
{
static_assert(I == 0, "wrong!");
return data_.d2_;
}
template <index_t I>
__host__ __device__ constexpr auto& Set(Number<1>, Number<I> i)
{
static_assert(I >= 0 && I < 2, "wrong!");
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x2_; }
return data_.d1x2_(i);
}
__host__ __device__ constexpr auto& Scalars() { return data_.d1x2_; }
template <index_t I>
__host__ __device__ constexpr auto& Set(Number<2>, Number<I>)
{
static_assert(I == 0, "wrong!");
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x2_; }
return data_.d2_;
}
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x1_; }
template <index_t I>
__host__ __device__ constexpr const auto& operator[](Number<I>) const
{
static_assert(I >= 0 && I < 2, "wrong!");
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x2_; }
return data_.d1x2_[Number<I>{}];
}
template <index_t I>
__host__ __device__ constexpr auto& operator()(Number<I>)
{
static_assert(I >= 0 && I < 2, "wrong!");
return data_.d1x2_(Number<I>{});
}
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x1_; }
};
template <typename T>
......@@ -305,6 +250,7 @@ struct vector_type<T, 4>
d4_t d4_;
StaticallyIndexedArray<d1_t, 4> d1x4_;
StaticallyIndexedArray<d2_t, 2> d2x2_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{d4_t{0}} {}
......@@ -315,69 +261,21 @@ struct vector_type<T, 4>
__host__ __device__ constexpr auto& Vector() { return data_.d4_; }
template <index_t I>
__host__ __device__ constexpr const auto& Get(Number<1>, Number<I> i) const
{
static_assert(I >= 0 && I < 4, "wrong!");
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x4_; }
return data_.d1x4_[i];
}
__host__ __device__ constexpr auto& Scalars() { return data_.d1x4_; }
template <index_t I>
__host__ __device__ constexpr const auto& Get(Number<2>, Number<I> i) const
{
static_assert(I >= 0 && I < 2, "wrong!");
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x4_; }
return data_.d2x2_[i];
}
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x2_; }
template <index_t I>
__host__ __device__ constexpr const auto& Get(Number<4>, Number<I>) const
{
static_assert(I == 0, "wrong!");
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x1_; }
return data_.d4_;
}
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x4_; }
template <index_t I>
__host__ __device__ constexpr auto& Set(Number<1>, Number<I> i)
{
static_assert(I >= 0 && I < 4, "wrong!");
return data_.d1x4_(i);
}
template <index_t I>
__host__ __device__ constexpr auto& Set(Number<2>, Number<I> i)
{
static_assert(I >= 0 && I < 3, "wrong!");
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x2_; }
return data_.d2x2_(i);
}
template <index_t I>
__host__ __device__ constexpr auto& Set(Number<4>, Number<I>)
{
static_assert(I == 0, "wrong!");
return data_.d4_;
}
template <index_t I>
__host__ __device__ constexpr const auto& operator[](Number<I>) const
{
static_assert(I >= 0 && I < 4, "wrong!");
return data_.d1x4_[Number<I>{}];
}
template <index_t I>
__host__ __device__ constexpr auto& operator()(Number<I>)
{
static_assert(I >= 0 && I < 4, "wrong!");
return data_.d1x4_(Number<I>{});
}
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; }
};
template <typename T>
......@@ -396,6 +294,7 @@ struct vector_type<T, 8>
StaticallyIndexedArray<d1_t, 8> d1x8_;
StaticallyIndexedArray<d2_t, 4> d2x4_;
StaticallyIndexedArray<d4_t, 2> d4x2_;
StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{d8_t{0}} {}
......@@ -406,85 +305,25 @@ struct vector_type<T, 8>
__host__ __device__ constexpr auto& Vector() { return data_.d8_; }
template <index_t I>
__host__ __device__ constexpr const auto& Get(Number<1>, Number<I> i) const
{
static_assert(I >= 0 && I < 8, "wrong!");
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x8_; }
return data_.d1x8_[i];
}
__host__ __device__ constexpr auto& Scalars() { return data_.d1x8_; }
template <index_t I>
__host__ __device__ constexpr const auto& Get(Number<2>, Number<I> i) const
{
static_assert(I >= 0 && I < 4, "wrong!");
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x8_; }
return data_.d2x4_[i];
}
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x4_; }
template <index_t I>
__host__ __device__ constexpr const auto& Get(Number<4>, Number<I> i) const
{
static_assert(I >= 0 && I < 2, "wrong!");
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x2_; }
return data_.d4x2_[i];
}
__host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x1_; }
template <index_t I>
__host__ __device__ constexpr const auto& Get(Number<8>, Number<I>) const
{
static_assert(I == 0, "wrong!");
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x8_; }
return data_.d8_;
}
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x4_; }
template <index_t I>
__host__ __device__ constexpr auto& Set(Number<1>, Number<I> i)
{
static_assert(I >= 0 && I < 8, "wrong!");
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x2_; }
return data_.d1x8_(i);
}
template <index_t I>
__host__ __device__ constexpr auto& Set(Number<2>, Number<I> i)
{
static_assert(I >= 0 && I < 4, "wrong!");
return data_.d2x4_(i);
}
template <index_t I>
__host__ __device__ constexpr auto& Set(Number<4>, Number<I> i)
{
static_assert(I >= 0 && I < 2, "wrong!");
return data_.d4x2_(i);
}
template <index_t I>
__host__ __device__ constexpr auto& Set(Number<8>, Number<I> i)
{
static_assert(I == 0, "wrong!");
return data_.d8_;
}
template <index_t I>
__host__ __device__ constexpr const auto& operator[](Number<I>) const
{
static_assert(I >= 0 && I < 8, "wrong!");
return data_.d1x8_[Number<I>{}];
}
template <index_t I>
__host__ __device__ constexpr auto& operator()(Number<I>)
{
static_assert(I >= 0 && I < 8, "wrong!");
return data_.d1x8_(Number<I>{});
}
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
};
// data type conversion
......
......@@ -25,6 +25,8 @@ void device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
RightPads,
ck::index_t nrepeat)
{
std::cout << "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw" << std::endl;
using namespace ck;
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
......
......@@ -24,6 +24,8 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
InRightPads,
ck::index_t nrepeat)
{
std::cout << "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw" << std::endl;
using namespace ck;
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
......
......@@ -110,39 +110,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
out_nhwk_device_buf.ToDevice(out_nhwk.mData.data());
#if 0
// cdata = 64, BlockSize = 128, 32x256x8
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 0
// cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256;
......
......@@ -616,7 +616,7 @@ int main(int argc, char* argv[])
LeftPads{},
RightPads{},
nrepeat);
#elif 0
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment