Commit 821ec5ae authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed namming

parent aaf3d81d
......@@ -103,10 +103,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
const FloatB* __restrict__ p_b_block,
FloatC p_c_thread)
{
p_c_thread.At(Number<64>{})(Number<0>{}) = XdlopsGemm.template Run<M, N, K>(
p_a_block, p_b_block, p_c_thread.At(Number<64>{})[Number<0>{}]);
p_c_thread.At(Number<64>{})(Number<1>{}) = XdlopsGemm.template Run<M, N, K>(
p_a_block + MPerXdlops, p_b_block, p_c_thread.At(Number<64>{})[Number<1>{}]);
p_c_thread.GetVector(Number<64>{})(Number<0>{}) = XdlopsGemm.template Run<M, N, K>(
p_a_block, p_b_block, p_c_thread.GetVector(Number<64>{})[Number<0>{}]);
p_c_thread.GetVector(Number<64>{})(Number<1>{}) = XdlopsGemm.template Run<M, N, K>(
p_a_block + MPerXdlops, p_b_block, p_c_thread.GetVector(Number<64>{})[Number<1>{}]);
return p_c_thread;
}
......
......@@ -326,7 +326,7 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global))
.Store(c_thread_vec.At(Number<M0 * M2>{})[Number<blk_id>{}], p_c_global);
.Store(c_thread_vec.GetVector(Number<M0 * M2>{})[Number<blk_id>{}], p_c_global);
});
}
}
......
......@@ -194,7 +194,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) /
long_vector_size;
thread_buff.At(Number<SrcDataPerRead>{})(Number<buff_off>{}) = src_buff;
thread_buff.GetVector(Number<SrcDataPerRead>{})(Number<buff_off>{}) = src_buff;
});
}
......@@ -222,7 +222,8 @@ struct ThreadwiseGenericTensorSliceCopy_v5
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) /
long_vector_size;
auto src_buff = thread_buff.At(Number<DstDataPerWrite>{})[Number<buff_off>{}];
auto src_buff =
thread_buff.GetVector(Number<DstDataPerWrite>{})[Number<buff_off>{}];
const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id);
......@@ -254,7 +255,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) /
long_vector_size;
auto src_buff = src.At(Number<DstDataPerWrite>{})[Number<buff_off>{}];
auto src_buff = src.GetVector(Number<DstDataPerWrite>{})[Number<buff_off>{}];
const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id);
......
......@@ -39,22 +39,22 @@ union float_vec4_t
__host__ __device__ constexpr float_vec4_t() {}
template <index_t vs>
__host__ __device__ auto& At(Number<vs>);
__host__ __device__ auto& GetVector(Number<vs>);
template <>
__host__ __device__ auto& At(Number<1>)
__host__ __device__ auto& GetVector(Number<1>)
{
return s1;
}
template <>
__host__ __device__ auto& At(Number<2>)
__host__ __device__ auto& GetVector(Number<2>)
{
return s2;
}
template <>
__host__ __device__ auto& At(Number<4>)
__host__ __device__ auto& GetVector(Number<4>)
{
return s4;
}
......@@ -69,28 +69,28 @@ union float_vec8_t
__host__ __device__ constexpr float_vec8_t() {}
template <index_t vs>
__host__ __device__ auto& At(Number<vs>);
__host__ __device__ auto& GetVector(Number<vs>);
template <>
__host__ __device__ auto& At(Number<1>)
__host__ __device__ auto& GetVector(Number<1>)
{
return s1;
}
template <>
__host__ __device__ auto& At(Number<2>)
__host__ __device__ auto& GetVector(Number<2>)
{
return s2;
}
template <>
__host__ __device__ auto& At(Number<4>)
__host__ __device__ auto& GetVector(Number<4>)
{
return s4;
}
template <>
__host__ __device__ auto& At(Number<8>)
__host__ __device__ auto& GetVector(Number<8>)
{
return s8;
}
......@@ -106,34 +106,34 @@ union float_vec16_t
__host__ __device__ constexpr float_vec16_t() {}
template <index_t vs>
__host__ __device__ auto& At(Number<vs>);
__host__ __device__ auto& GetVector(Number<vs>);
template <>
__host__ __device__ auto& At(Number<1>)
__host__ __device__ auto& GetVector(Number<1>)
{
return s1;
}
template <>
__host__ __device__ auto& At(Number<2>)
__host__ __device__ auto& GetVector(Number<2>)
{
return s2;
}
template <>
__host__ __device__ auto& At(Number<4>)
__host__ __device__ auto& GetVector(Number<4>)
{
return s4;
}
template <>
__host__ __device__ auto& At(Number<8>)
__host__ __device__ auto& GetVector(Number<8>)
{
return s8;
}
template <>
__host__ __device__ auto& At(Number<16>)
__host__ __device__ auto& GetVector(Number<16>)
{
return s16;
}
......@@ -150,34 +150,34 @@ union float_vec32_t
__host__ __device__ constexpr float_vec32_t() {}
template <index_t vs>
__host__ __device__ auto& At(Number<vs>);
__host__ __device__ auto& GetVector(Number<vs>);
template <>
__host__ __device__ auto& At(Number<1>)
__host__ __device__ auto& GetVector(Number<1>)
{
return s1;
}
template <>
__host__ __device__ auto& At(Number<2>)
__host__ __device__ auto& GetVector(Number<2>)
{
return s2;
}
template <>
__host__ __device__ auto& At(Number<4>)
__host__ __device__ auto& GetVector(Number<4>)
{
return s4;
}
template <>
__host__ __device__ auto& At(Number<8>)
__host__ __device__ auto& GetVector(Number<8>)
{
return s8;
}
template <>
__host__ __device__ auto& At(Number<16>)
__host__ __device__ auto& GetVector(Number<16>)
{
return s16;
}
......@@ -192,16 +192,16 @@ union float_vec64_t
__host__ __device__ constexpr float_vec64_t() {}
template <index_t vs>
__host__ __device__ auto& At(Number<vs>);
__host__ __device__ auto& GetVector(Number<vs>);
template <>
__host__ __device__ auto& At(Number<1>)
__host__ __device__ auto& GetVector(Number<1>)
{
return s1;
}
template <>
__host__ __device__ auto& At(Number<32>)
__host__ __device__ auto& GetVector(Number<32>)
{
return s32;
}
......@@ -217,28 +217,28 @@ union float_vec128_t
__host__ __device__ constexpr float_vec128_t() {}
template <index_t vs>
__host__ __device__ auto& At(Number<vs>);
__host__ __device__ auto& GetVector(Number<vs>);
template <>
__host__ __device__ auto& At(Number<1>)
__host__ __device__ auto& GetVector(Number<1>)
{
return s1;
}
template <>
__host__ __device__ auto& At(Number<16>)
__host__ __device__ auto& GetVector(Number<16>)
{
return s16;
}
template <>
__host__ __device__ auto& At(Number<32>)
__host__ __device__ auto& GetVector(Number<32>)
{
return s32;
}
template <>
__host__ __device__ auto& At(Number<64>)
__host__ __device__ auto& GetVector(Number<64>)
{
return s64;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment