Commit 71ffba65 authored by Jane.Zhou

use a flag to indicate whether a load access can be x4-vectorized
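For context, a minimal usage sketch of the new hook (not itself part of this diff): CK_VECTORIZE_FLAG is a plain preprocessor define, so it is assumed to be enabled at compile time (e.g. with -DCK_VECTORIZE_FLAG=1); the names below are the ones used in the gridwise GEMM hunk.

    #if CK_VECTORIZE_FLAG
        // scan the source slice window once and cache whether every
        // SrcDataPerRead-wide group of B is contiguous
        b_blockwise_copy.SetVectorizeFlag();
    #endif
        // Run() then takes the single wide-load path when the cached check
        // passed, and falls back to element-wise loads otherwise
        b_blockwise_copy.Run(p_b_global, p_b_block_double);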

parent 7d09790a
@@ -144,6 +144,13 @@ struct BlockwiseGenericTensorSliceCopy_v4
    {
        mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
    }
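    // New hook: forwards to the threadwise loader, which inspects its own source
    // slice window and caches whether wide loads are safe.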
#if CK_VECTORIZE_FLAG
    __device__ void SetVectorizeFlag()
    {
        mThreadwiseLoad.SetVectorizeFlag();
    }
#endif
    private:
    using ThreadBufferDesc = decltype(make_native_tensor_descriptor_packed(ThreadSliceLengths{}));
...
@@ -228,6 +228,9 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.Run(p_a_global, p_a_block_double);
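            // Decide once, before the first load of B, whether this thread's source
            // window supports x4 loads; Run() reuses the cached verdict on every pass.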
#if CK_VECTORIZE_FLAG
            b_blockwise_copy.SetVectorizeFlag();
#endif
            b_blockwise_copy.Run(p_b_global, p_b_block_double);
        }
@@ -285,7 +288,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
            a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
            b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);

            __syncthreads();

            // LDS double buffer: load last data from device mem
...
@@ -5,7 +5,6 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_coordinate.hpp"

namespace ck {
// This threadwise copy allows vector access of src and dst.
@@ -36,6 +35,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                                                const Index& dst_slice_origin)
        : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
    {
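        // Both flags start cleared; a caller must invoke SetVectorizeFlag() before
        // Run() for the vectorized path to be considered at all.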
#if CK_VECTORIZE_FLAG
        vectorize_flag = 0;
        isVectoriable  = 0;
#endif
        static_assert(nDim == SrcDesc::GetNumOfDimension() &&
                          nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::Size() &&
                          nDim == SrcDstDimAccessOrder::Size(),
@@ -71,8 +74,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
    __device__ void Run(const SrcData* p_src, DstData* p_dst) const
    {
        constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};
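// src_data_per_access is only consumed by the generic fallback loops below; the
// x2/x4 builds hard-code their access width, so the definition is skipped there
// (presumably to avoid an unused-variable warning).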
#if !(CK_VECTORX4_FLAG || CK_VECTORX2_FLAG)
        constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
#endif
        constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};
@@ -96,16 +101,200 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
            {
                p_src_long_vector[i] = 0;
            }
            auto scalar_id = make_zero_array<index_t, nDim>();
            auto src_coord = mSrcSliceOrigin + long_vector_data_begin_id;
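            // Fast-path probe: point scalar_id at the last element of the candidate
            // vector. If that element's offset is exactly (width - 1) past the base,
            // the whole group is contiguous in memory and one wide load is safe.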
#if CK_VECTORX4_FLAG // vector load x4
            {
                scalar_id(vector_access_dim) = 3;
                if((long_vector_size == 4) && src_coord.IsOffsetValidAssumingUpperIndexIsValid() &&
                   (src_coord.CalculateOffsetDiff(scalar_id) == 3))
                {
                    transfer_data<SrcData,
                                  4,
                                  SrcAddressSpace,
                                  AddressSpace::Vgpr,
                                  InMemoryDataOperation::Set>(
                        p_src, src_coord.GetOffset(), p_src_long_vector, 0);
                }
                else
                {
                    // original code: load data from src to the long-vector buffer,
                    // one element at a time, checking each element's mapping
                    for(index_t i = 0; i < long_vector_size; ++i)
                    {
                        scalar_id(vector_access_dim) = i;
                        src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);

                        if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
                        {
                            transfer_data<SrcData,
                                          1,
                                          SrcAddressSpace,
                                          AddressSpace::Vgpr,
                                          InMemoryDataOperation::Set>(
                                p_src, src_coord.GetOffset(), p_src_long_vector, i);
                        }
                    }
                }
            }
#elif CK_VECTORX2_FLAG // vector load x2
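            // Same probe as the x4 path, but for a 2-wide load: the last element
            // must sit exactly one past the base.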
            scalar_id(vector_access_dim) = 1;
            if((long_vector_size == 2) && src_coord.IsOffsetValidAssumingUpperIndexIsValid() &&
               (src_coord.CalculateOffsetDiff(scalar_id) == 1))
            {
                transfer_data<SrcData,
                              2,
                              SrcAddressSpace,
                              AddressSpace::Vgpr,
                              InMemoryDataOperation::Set>(
                    p_src, src_coord.GetOffset(), p_src_long_vector, 0);
            }
            else
            {
                // original code: load data from src to the long-vector buffer,
                // one element at a time, checking each element's mapping
                for(index_t i = 0; i < long_vector_size; ++i)
                {
                    scalar_id(vector_access_dim) = i;
                    src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);

                    if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
                    {
                        transfer_data<SrcData,
                                      1,
                                      SrcAddressSpace,
                                      AddressSpace::Vgpr,
                                      InMemoryDataOperation::Set>(
                            p_src, src_coord.GetOffset(), p_src_long_vector, i);
                    }
                }
            }
#elif CK_VECTORIZE_FLAG
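            // vectorize_flag records that SetVectorizeFlag() ran; isVectoriable
            // records its verdict (every SrcDataPerRead-wide group in the slice is
            // contiguous). Only when both hold is the unchecked wide load issued;
            // note this path assumes long_vector_size == 4.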
            if(vectorize_flag)
            {
                if(isVectoriable)
                {
                    transfer_data<SrcData,
                                  4,
                                  SrcAddressSpace,
                                  AddressSpace::Vgpr,
                                  InMemoryDataOperation::Set>(
                        p_src, src_coord.GetOffset(), p_src_long_vector, 0);
                }
                else
                {
                    // original code: load data from src to the long-vector buffer,
                    // one element at a time, checking each element's mapping
                    for(index_t i = 0; i < long_vector_size; ++i)
                    {
                        scalar_id(vector_access_dim) = i;
                        src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);

                        if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
                        {
                            transfer_data<SrcData,
                                          1,
                                          SrcAddressSpace,
                                          AddressSpace::Vgpr,
                                          InMemoryDataOperation::Set>(
                                p_src, src_coord.GetOffset(), p_src_long_vector, i);
                        }
                    }
                }
            }
            else
            {
                // original code: load data from src to the long-vector buffer
                for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
                {
                    scalar_id(vector_access_dim) = i * src_data_per_access;
                    const index_t buffer_offset = i * src_data_per_access;
                    src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);

                    // Check src data's valid mapping situation, only check the first data in
                    // this src vector. It's user's responsibility to make sure all data in the
                    // src vector has the valid/invalid mapping situation
                    if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
                    {
                        transfer_data<SrcData,
                                      SrcDataPerRead,
                                      SrcAddressSpace,
                                      AddressSpace::Vgpr,
                                      InMemoryDataOperation::Set>(
                            p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
                    }
                }
            }
#else // original code
            // load data from src to the long-vector buffer
            for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
            {
                scalar_id(vector_access_dim) = i * src_data_per_access;
                const index_t buffer_offset = i * src_data_per_access;
                src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);

                // Check src data's valid mapping situation, only check the first data in this src
                // vector. It's user's responsibility to make sure all data in the src vector
@@ -120,7 +309,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                        p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
                }
            }
#endif
            // SrcData to DstData conversion
            DstData p_dst_long_vector[long_vector_size];
@@ -132,7 +321,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
            // store data from the long-vector buffer to dst
            for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i)
            {
                scalar_id(vector_access_dim) = i * dst_data_per_access;
                const index_t buffer_offset = i * dst_data_per_access;
@@ -242,7 +431,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                const index_t buffer_offset = i * src_data_per_access;

                // move src coordinate along linear dimensions
                const auto src_coord =
                    src_nonlinear_coord + (linear_dim_data_steps + scalar_id);
@@ -492,9 +681,41 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
        }).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
    }
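    // Walk the source slice along the vector access dimension in SrcDataPerRead-wide
    // groups. For each group, probe its last element: the group is vectorizable when
    // the base offset is valid and the last element sits exactly SrcDataPerRead - 1
    // past it (i.e. the group is contiguous). isVectoriable is raised only if every
    // group passes, so a single cached answer covers the whole window.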
#if CK_VECTORIZE_FLAG
    __device__ void SetVectorizeFlag()
    {
        vectorize_flag = 1;

        auto scalar_id                   = make_zero_array<index_t, nDim>();
        constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};

        auto vectoriableFlagArray =
            Array<index_t, (SliceLengths::Get(vector_access_dim) / SrcDataPerRead)>{};

        auto mTempSrc_coord = mSrcSliceOrigin;

        for(index_t i = 0; i < vectoriableFlagArray.Size(); ++i)
        {
            scalar_id(vector_access_dim) = SrcDataPerRead - 1;

            if(mTempSrc_coord.IsOffsetValidAssumingUpperIndexIsValid() &&
               (mTempSrc_coord.CalculateOffsetDiff(scalar_id) == (SrcDataPerRead - 1)))
            {
                vectoriableFlagArray.At(i) = 1;
            }
            else
            {
                vectoriableFlagArray.At(i) = 0;
            }

            scalar_id(vector_access_dim) = SrcDataPerRead;
            mTempSrc_coord               = mTempSrc_coord + scalar_id;
        }

        isVectoriable = 1;
        for(index_t i = 0; i < vectoriableFlagArray.Size(); ++i)
        {
            if(vectoriableFlagArray.At(i) == 0)
            {
                isVectoriable = 0;
            }
        }
    }
#endif
    private:
    SrcCoord mSrcSliceOrigin;
    DstCoord mDstSliceOrigin;
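    // Written by SetVectorizeFlag(): vectorize_flag marks that the check ran,
    // isVectoriable marks its result (1 = every group contiguous, safe for x4).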
#if CK_VECTORIZE_FLAG
    int vectorize_flag;
    int isVectoriable;
#endif
};
} // namespace ck
...