Commit 821ec5ae authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed namming

parent aaf3d81d
...@@ -103,10 +103,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops ...@@ -103,10 +103,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
const FloatB* __restrict__ p_b_block, const FloatB* __restrict__ p_b_block,
FloatC p_c_thread) FloatC p_c_thread)
{ {
p_c_thread.At(Number<64>{})(Number<0>{}) = XdlopsGemm.template Run<M, N, K>( p_c_thread.GetVector(Number<64>{})(Number<0>{}) = XdlopsGemm.template Run<M, N, K>(
p_a_block, p_b_block, p_c_thread.At(Number<64>{})[Number<0>{}]); p_a_block, p_b_block, p_c_thread.GetVector(Number<64>{})[Number<0>{}]);
p_c_thread.At(Number<64>{})(Number<1>{}) = XdlopsGemm.template Run<M, N, K>( p_c_thread.GetVector(Number<64>{})(Number<1>{}) = XdlopsGemm.template Run<M, N, K>(
p_a_block + MPerXdlops, p_b_block, p_c_thread.At(Number<64>{})[Number<1>{}]); p_a_block + MPerXdlops, p_b_block, p_c_thread.GetVector(Number<64>{})[Number<1>{}]);
return p_c_thread; return p_c_thread;
} }
......
...@@ -326,7 +326,7 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2 ...@@ -326,7 +326,7 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
m_thread_data_on_global % (M2 * M1) / M2, m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2, m_thread_data_on_global % M2,
n_thread_data_on_global)) n_thread_data_on_global))
.Store(c_thread_vec.At(Number<M0 * M2>{})[Number<blk_id>{}], p_c_global); .Store(c_thread_vec.GetVector(Number<M0 * M2>{})[Number<blk_id>{}], p_c_global);
}); });
} }
} }
......
...@@ -194,7 +194,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -194,7 +194,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) / ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) /
long_vector_size; long_vector_size;
thread_buff.At(Number<SrcDataPerRead>{})(Number<buff_off>{}) = src_buff; thread_buff.GetVector(Number<SrcDataPerRead>{})(Number<buff_off>{}) = src_buff;
}); });
} }
...@@ -222,7 +222,8 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -222,7 +222,8 @@ struct ThreadwiseGenericTensorSliceCopy_v5
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) / ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) /
long_vector_size; long_vector_size;
auto src_buff = thread_buff.At(Number<DstDataPerWrite>{})[Number<buff_off>{}]; auto src_buff =
thread_buff.GetVector(Number<DstDataPerWrite>{})[Number<buff_off>{}];
const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id); const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id);
...@@ -254,7 +255,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -254,7 +255,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) / ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) /
long_vector_size; long_vector_size;
auto src_buff = src.At(Number<DstDataPerWrite>{})[Number<buff_off>{}]; auto src_buff = src.GetVector(Number<DstDataPerWrite>{})[Number<buff_off>{}];
const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id); const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id);
......
...@@ -39,22 +39,22 @@ union float_vec4_t ...@@ -39,22 +39,22 @@ union float_vec4_t
__host__ __device__ constexpr float_vec4_t() {} __host__ __device__ constexpr float_vec4_t() {}
template <index_t vs> template <index_t vs>
__host__ __device__ auto& At(Number<vs>); __host__ __device__ auto& GetVector(Number<vs>);
template <> template <>
__host__ __device__ auto& At(Number<1>) __host__ __device__ auto& GetVector(Number<1>)
{ {
return s1; return s1;
} }
template <> template <>
__host__ __device__ auto& At(Number<2>) __host__ __device__ auto& GetVector(Number<2>)
{ {
return s2; return s2;
} }
template <> template <>
__host__ __device__ auto& At(Number<4>) __host__ __device__ auto& GetVector(Number<4>)
{ {
return s4; return s4;
} }
...@@ -69,28 +69,28 @@ union float_vec8_t ...@@ -69,28 +69,28 @@ union float_vec8_t
__host__ __device__ constexpr float_vec8_t() {} __host__ __device__ constexpr float_vec8_t() {}
template <index_t vs> template <index_t vs>
__host__ __device__ auto& At(Number<vs>); __host__ __device__ auto& GetVector(Number<vs>);
template <> template <>
__host__ __device__ auto& At(Number<1>) __host__ __device__ auto& GetVector(Number<1>)
{ {
return s1; return s1;
} }
template <> template <>
__host__ __device__ auto& At(Number<2>) __host__ __device__ auto& GetVector(Number<2>)
{ {
return s2; return s2;
} }
template <> template <>
__host__ __device__ auto& At(Number<4>) __host__ __device__ auto& GetVector(Number<4>)
{ {
return s4; return s4;
} }
template <> template <>
__host__ __device__ auto& At(Number<8>) __host__ __device__ auto& GetVector(Number<8>)
{ {
return s8; return s8;
} }
...@@ -106,34 +106,34 @@ union float_vec16_t ...@@ -106,34 +106,34 @@ union float_vec16_t
__host__ __device__ constexpr float_vec16_t() {} __host__ __device__ constexpr float_vec16_t() {}
template <index_t vs> template <index_t vs>
__host__ __device__ auto& At(Number<vs>); __host__ __device__ auto& GetVector(Number<vs>);
template <> template <>
__host__ __device__ auto& At(Number<1>) __host__ __device__ auto& GetVector(Number<1>)
{ {
return s1; return s1;
} }
template <> template <>
__host__ __device__ auto& At(Number<2>) __host__ __device__ auto& GetVector(Number<2>)
{ {
return s2; return s2;
} }
template <> template <>
__host__ __device__ auto& At(Number<4>) __host__ __device__ auto& GetVector(Number<4>)
{ {
return s4; return s4;
} }
template <> template <>
__host__ __device__ auto& At(Number<8>) __host__ __device__ auto& GetVector(Number<8>)
{ {
return s8; return s8;
} }
template <> template <>
__host__ __device__ auto& At(Number<16>) __host__ __device__ auto& GetVector(Number<16>)
{ {
return s16; return s16;
} }
...@@ -150,34 +150,34 @@ union float_vec32_t ...@@ -150,34 +150,34 @@ union float_vec32_t
__host__ __device__ constexpr float_vec32_t() {} __host__ __device__ constexpr float_vec32_t() {}
template <index_t vs> template <index_t vs>
__host__ __device__ auto& At(Number<vs>); __host__ __device__ auto& GetVector(Number<vs>);
template <> template <>
__host__ __device__ auto& At(Number<1>) __host__ __device__ auto& GetVector(Number<1>)
{ {
return s1; return s1;
} }
template <> template <>
__host__ __device__ auto& At(Number<2>) __host__ __device__ auto& GetVector(Number<2>)
{ {
return s2; return s2;
} }
template <> template <>
__host__ __device__ auto& At(Number<4>) __host__ __device__ auto& GetVector(Number<4>)
{ {
return s4; return s4;
} }
template <> template <>
__host__ __device__ auto& At(Number<8>) __host__ __device__ auto& GetVector(Number<8>)
{ {
return s8; return s8;
} }
template <> template <>
__host__ __device__ auto& At(Number<16>) __host__ __device__ auto& GetVector(Number<16>)
{ {
return s16; return s16;
} }
...@@ -192,16 +192,16 @@ union float_vec64_t ...@@ -192,16 +192,16 @@ union float_vec64_t
__host__ __device__ constexpr float_vec64_t() {} __host__ __device__ constexpr float_vec64_t() {}
template <index_t vs> template <index_t vs>
__host__ __device__ auto& At(Number<vs>); __host__ __device__ auto& GetVector(Number<vs>);
template <> template <>
__host__ __device__ auto& At(Number<1>) __host__ __device__ auto& GetVector(Number<1>)
{ {
return s1; return s1;
} }
template <> template <>
__host__ __device__ auto& At(Number<32>) __host__ __device__ auto& GetVector(Number<32>)
{ {
return s32; return s32;
} }
...@@ -217,28 +217,28 @@ union float_vec128_t ...@@ -217,28 +217,28 @@ union float_vec128_t
__host__ __device__ constexpr float_vec128_t() {} __host__ __device__ constexpr float_vec128_t() {}
template <index_t vs> template <index_t vs>
__host__ __device__ auto& At(Number<vs>); __host__ __device__ auto& GetVector(Number<vs>);
template <> template <>
__host__ __device__ auto& At(Number<1>) __host__ __device__ auto& GetVector(Number<1>)
{ {
return s1; return s1;
} }
template <> template <>
__host__ __device__ auto& At(Number<16>) __host__ __device__ auto& GetVector(Number<16>)
{ {
return s16; return s16;
} }
template <> template <>
__host__ __device__ auto& At(Number<32>) __host__ __device__ auto& GetVector(Number<32>)
{ {
return s32; return s32;
} }
template <> template <>
__host__ __device__ auto& At(Number<64>) __host__ __device__ auto& GetVector(Number<64>)
{ {
return s64; return s64;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment