Commit ca8ba252 authored by Chao Liu

refactor

parent edc08fe6
@@ -305,7 +305,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
"Run_amd_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == "
"1 for now\n");
- using Float4 = vector_type<float, 4>::MemoryType;
+ using Float4 = vector_type<float, 4>::type;
Float4* reg_a = (Float4*)(p_a_thread);
Float4* reg_b = (Float4*)(p_b_thread);
......
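The renamed alias is typically consumed exactly as the two casts above do: a thread-private scalar array is reinterpreted as a packed vector. A minimal sketch of that pattern, assuming clang's ext_vector_type extension; float4_t stands in for vector_type<float, 4>::type and the function names are illustrative, not from the patch:

    // Reinterpret four consecutive thread-private floats as one packed vector,
    // so the compiler can keep them in a contiguous VGPR quad.
    typedef float float4_t __attribute__((ext_vector_type(4)));

    __device__ void regs_as_vectors(float* p_a_thread, float* p_b_thread)
    {
        using Float4 = float4_t;
        Float4* reg_a = reinterpret_cast<Float4*>(p_a_thread);
        Float4* reg_b = reinterpret_cast<Float4*>(p_b_thread);
        reg_a[0] += reg_b[0]; // elementwise add on the packed value
    }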
@@ -175,7 +175,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
// copy data
vector_type<DstData, DstScalarPerVector> dst_vector;
- using dst_vector_t = typename vector_type<DstData, DstScalarPerVector>::MemoryType;
+ using dst_vector_t = typename vector_type<DstData, DstScalarPerVector>::type;
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset =
@@ -504,7 +504,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
vector_type<SrcData, SrcScalarPerVector> src_vector;
- using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::MemoryType;
+ using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_coord_);
@@ -838,7 +838,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// copy data
vector_type<SrcData, SrcScalarPerVector> src_vector;
- using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::MemoryType;
+ using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_coord_);
@@ -1031,7 +1031,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
dst_vector.Scalars()(i) = buffer_[Number<buffer_offset>{}];
});
- using DstVectorType = typename vector_type<DstData, DstScalarPerVector>::MemoryType;
+ using DstVectorType = typename vector_type<DstData, DstScalarPerVector>::type;
*reinterpret_cast<DstVectorType*>(p_dst + dst_slice_origin_coord_.GetOffset()) =
dst_vector.Vector();
......
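The hunk above shows the write side of the transfer: scalars are gathered into a DstScalarPerVector-wide value in registers, then committed with one vector store through a reinterpret_cast. A minimal sketch of the same pattern for a fixed width of 4 (names are illustrative, not the library's API):

    typedef float float4_t __attribute__((ext_vector_type(4)));

    __device__ void commit_dst_vector(float* p_dst, index_t dst_offset, const float (&buffer)[4])
    {
        float4_t dst_vector;
        for(int i = 0; i < 4; ++i)
            dst_vector[i] = buffer[i]; // gather scalars, as the static_for above does
        // one wide store instead of four scalar stores
        *reinterpret_cast<float4_t*>(p_dst + dst_offset) = dst_vector;
    }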
@@ -39,7 +39,7 @@ struct ThreadwiseMatrixSliceCopy
template <typename Data>
__device__ static void Run(const Data* p_src, Data* p_dst)
{
- using vector_t = typename vector_type<Data, DataPerAccess>::MemoryType;
+ using vector_t = typename vector_type<Data, DataPerAccess>::type;
for(index_t i = 0; i < NSliceRow; ++i)
{
......
@@ -41,7 +41,7 @@ struct ThreadwiseMatrixSliceCopy_v2
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
- using vector_t = typename vector_type<Data, DataPerAccess>::MemoryType;
+ using vector_t = typename vector_type<Data, DataPerAccess>::type;
static_for<0, NSliceRow, 1>{}([&](auto i) {
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
......
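The compile-time loop nest above visits the slice one DataPerAccess-wide chunk at a time. A sketch of the body with the descriptor offset math replaced by plain row-major indexing (RowStride is an assumed placeholder; the real code computes offsets from SrcDesc and DstDesc):

    template <typename Data, index_t NSliceRow, index_t NSliceCol,
              index_t DataPerAccess, index_t RowStride>
    __device__ void slice_copy_sketch(const Data* p_src, Data* p_dst)
    {
        using vector_t = typename vector_type<Data, DataPerAccess>::type;

        static_for<0, NSliceRow, 1>{}([&](auto i) {
            static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
                const index_t offset = i * RowStride + j; // row-major stand-in
                *reinterpret_cast<vector_t*>(p_dst + offset) =
                    *reinterpret_cast<const vector_t*>(p_src + offset);
            });
        });
    }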
@@ -91,11 +91,10 @@ __llvm_amdgcn_buffer_atomic_add_f32(float vdata,
// 2) p_src_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t VectorSize>
- __device__ typename vector_type<T, VectorSize>::MemoryType
- amd_buffer_load(const T* p_src_wave,
-                 index_t src_thread_data_offset,
-                 bool src_thread_data_valid,
-                 index_t src_elemenst_space);
+ __device__ typename vector_type<T, VectorSize>::type amd_buffer_load(const T* p_src_wave,
+                                                                      index_t src_thread_data_offset,
+                                                                      bool src_thread_data_valid,
+                                                                      index_t src_elemenst_space);
// buffer_store requires:
// 1) p_src_thread must be in vgpr space, p_dst_thread must be global memory
......
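As the comments stress, amd_buffer_load trusts the caller: the base pointer must be wave-uniform and the validity flag replaces any branch inside the load. A hedged usage sketch (the surrounding kernel context and names are assumed; only the template signature comes from this header):

    __device__ void load_demo(const float* p_src_wave, // wave-uniform base pointer
                              index_t thread_offset,   // per-thread element offset
                              index_t element_space)   // buffer size in elements
    {
        // caller-side bounds check, passed in as the validity predicate
        const bool valid = thread_offset + 4 <= element_space;

        auto v = amd_buffer_load<float, 4>(p_src_wave, thread_offset, valid, element_space);
        (void)v; // consume the loaded float4 here
    }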
@@ -60,7 +60,7 @@ __llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
// 2) p_src_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t VectorSize>
- __device__ typename vector_type<T, VectorSize>::MemoryType
+ __device__ typename vector_type<T, VectorSize>::type
amd_buffer_load_v2(const T* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
@@ -71,12 +71,11 @@ amd_buffer_load_v2(const T* p_src_wave,
// 2) p_dst_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t VectorSize>
- __device__ void
- amd_buffer_store_v2(const typename vector_type<T, VectorSize>::MemoryType src_thread_data,
-                     T* p_dst_wave,
-                     const index_t dst_thread_data_offset,
-                     const bool dst_thread_data_valid,
-                     const index_t dst_data_range);
+ __device__ void amd_buffer_store_v2(const typename vector_type<T, VectorSize>::type src_thread_data,
+                                     T* p_dst_wave,
+                                     const index_t dst_thread_data_offset,
+                                     const bool dst_thread_data_valid,
+                                     const index_t dst_data_range);
template <>
__device__ float amd_buffer_load_v2<float, 1>(const float* p_src_wave,
......
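The store side mirrors the load: a vector value held in VGPRs goes back to global memory through the buffer-store path, again with a caller-supplied predicate. A hedged sketch under the same assumptions:

    __device__ void store_demo(typename vector_type<float, 4>::type v,
                               float* p_dst_wave,
                               index_t thread_offset,
                               index_t data_range)
    {
        const bool valid = thread_offset + 4 <= data_range;
        amd_buffer_store_v2<float, 4>(v, p_dst_wave, thread_offset, valid, data_range);
    }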
@@ -175,7 +175,7 @@ struct vector_type;
template <typename T>
struct vector_type<T, 1>
{
- using MemoryType = T;
+ using type = T;
union
{
@@ -206,7 +206,7 @@ struct vector_type<T, 2>
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
- using MemoryType = d2_t;
+ using type = d2_t;
union
{
@@ -243,7 +243,7 @@ struct vector_type<T, 4>
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
- using MemoryType = d4_t;
+ using type = d4_t;
union
{
@@ -286,7 +286,7 @@ struct vector_type<T, 8>
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
- using MemoryType = d8_t;
+ using type = d8_t;
union
{
......
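Every specialization in this family has the same shape: scalar typedefs, the type alias this commit renames from MemoryType, and an anonymous union giving both whole-vector and per-scalar views of the same storage. A reconstructed sketch of the 2-wide case (the diff elides the union members, so their names here are assumed):

    template <typename T>
    struct vector_type_sketch
    {
        using d1_t = T;
        typedef T d2_t __attribute__((ext_vector_type(2)));

        using type = d2_t; // renamed from MemoryType in this commit

        union
        {
            d2_t d2_;    // whole-vector view
            d1_t d1_[2]; // per-scalar view (member names assumed)
        };
    };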
@@ -32,16 +32,16 @@ struct vector_type
typedef struct
{
T scalar[N];
- } MemoryType;
+ } type;
};
template <>
struct vector_type<float, 1>
{
- using MemoryType = float;
+ using type = float;
template <index_t I>
- __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
+ __host__ __device__ static void SetScalar(type& v, float s, Number<I>)
{
static_assert(I < 1, "wrong");
*(reinterpret_cast<float*>(&v) + I) = s;
@@ -51,22 +51,22 @@ struct vector_type<float, 1>
template <>
struct vector_type<float, 2>
{
- using MemoryType = float2_t;
+ using type = float2_t;
union DataType
{
- MemoryType vector;
+ type vector;
float scalar[2];
};
template <index_t I>
- __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
+ __host__ __device__ static void SetScalar(type& v, float s, Number<I>)
{
static_assert(I < 2, "wrong");
*(reinterpret_cast<float*>(&v) + I) = s;
}
- __host__ __device__ static MemoryType Pack(float s0, float s1)
+ __host__ __device__ static type Pack(float s0, float s1)
{
DataType data;
data.scalar[0] = s0;
@@ -78,12 +78,12 @@ struct vector_type<float, 2>
template <>
struct vector_type<float, 4>
{
- using MemoryType = float4_t;
+ using type = float4_t;
__host__ __device__ static constexpr index_t GetSize() { return 4; }
template <index_t I>
- __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
+ __host__ __device__ static void SetScalar(type& v, float s, Number<I>)
{
static_assert(I < 4, "wrong");
*(reinterpret_cast<float*>(&v) + I) = s;
@@ -93,10 +93,10 @@ struct vector_type<float, 4>
template <>
struct vector_type<half_t, 1>
{
- using MemoryType = half_t;
+ using type = half_t;
template <index_t I>
- __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
+ __host__ __device__ static void SetScalar(type& v, half_t s, Number<I>)
{
static_assert(I < 1, "wrong");
*(reinterpret_cast<half_t*>(&v) + I) = s;
@@ -106,22 +106,22 @@ struct vector_type<half_t, 1>
template <>
struct vector_type<half_t, 2>
{
- using MemoryType = half2_t;
+ using type = half2_t;
union DataType
{
- MemoryType vector;
+ type vector;
half_t scalar[2];
};
template <index_t I>
- __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
+ __host__ __device__ static void SetScalar(type& v, half_t s, Number<I>)
{
static_assert(I < 2, "wrong");
*(reinterpret_cast<half_t*>(&v) + I) = s;
}
- __host__ __device__ static MemoryType Pack(half_t s0, half_t s1)
+ __host__ __device__ static type Pack(half_t s0, half_t s1)
{
DataType data;
data.scalar[0] = s0;
......
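Pack builds a vector value scalar-by-scalar through the union and returns the vector view. A minimal usage sketch, assuming the half_t scalar type and the specialization shown above:

    __device__ void pack_demo(half_t a, half_t b)
    {
        // one packed register value holding both scalars
        vector_type<half_t, 2>::type v = vector_type<half_t, 2>::Pack(a, b);
        (void)v;
    }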
@@ -44,7 +44,7 @@ __device__ void atomic_add_impl<float4_t>(float4_t* p_dst, float4_t src)
template <typename T, index_t DataPerAccess>
struct SetData
{
- using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;
+ using vector_t = typename vector_type<T, DataPerAccess>::type;
// This version is only for compatibility, don't use this version if possible
template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
@@ -122,7 +122,7 @@ struct SetData
template <typename T, index_t DataPerAccess>
struct AtomicAddData
{
- using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;
+ using vector_t = typename vector_type<T, DataPerAccess>::type;
// This version is only for compatibility, don't use this version if possible
template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
......
@@ -37,7 +37,7 @@ __device__ void atomic_add_impl<float4_t>(float4_t* p_dst, float4_t src)
template <typename T, index_t DataPerAccess>
struct SetData
{
- using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;
+ using vector_t = typename vector_type<T, DataPerAccess>::type;
template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
__device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const
@@ -50,7 +50,7 @@ struct SetData
template <typename T, index_t DataPerAccess>
struct AtomicAddData
{
- using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;
+ using vector_t = typename vector_type<T, DataPerAccess>::type;
template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
__device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const
......
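Both functors use vector_t the same way inside Run: one reinterpret_cast on each side, then a single vector-wide move (or atomic add). A hedged sketch of the Set path with the address-space dispatch omitted:

    template <typename T, index_t DataPerAccess>
    __device__ void set_data_sketch(const T* p_src, index_t src_offset,
                                    T* p_dst, index_t dst_offset)
    {
        using vector_t = typename vector_type<T, DataPerAccess>::type;

        // one vector load plus one vector store per call
        *reinterpret_cast<vector_t*>(p_dst + dst_offset) =
            *reinterpret_cast<const vector_t*>(p_src + src_offset);
    }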