gaoqiong / composable_kernel / Commits / 71ffba65

Commit 71ffba65, authored Jun 05, 2020 by Jane.Zhou

    use flag to indicate if a load access can be x4 vectorized

parent 7d09790a
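The idea of this commit, reduced to its essentials: instead of re-checking validity and contiguity on every element of every copy, decide once whether a 4-wide vector load is safe and branch on that flag at copy time. Below is a minimal standalone C++ sketch of that pattern; the names (`can_vectorize_x4`, `copy4`) and the flat float-array model are illustrative only, not part of the repository.

    #include <cstddef>
    #include <cstring>
    #include <iostream>

    // Hypothetical illustration: decide once whether a 4-element load is safe
    // (contiguous and in-bounds), then branch on that cached flag per copy.
    bool can_vectorize_x4(std::size_t offset, std::size_t stride, std::size_t size)
    {
        // Safe only if the 4 elements are contiguous (stride 1) and in-bounds.
        return stride == 1 && offset + 4 <= size;
    }

    void copy4(const float* src, float* dst, std::size_t offset, std::size_t stride,
               std::size_t size, bool vectorize_flag)
    {
        if(vectorize_flag)
        {
            // One 4-wide access (a real kernel would issue e.g. a float4 load).
            std::memcpy(dst, src + offset, 4 * sizeof(float));
        }
        else
        {
            // Fallback: four scalar accesses, each individually checked.
            for(std::size_t i = 0; i < 4; ++i)
            {
                std::size_t pos = offset + i * stride;
                dst[i]          = (pos < size) ? src[pos] : 0.0f;
            }
        }
    }

    int main()
    {
        float src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
        float dst[4] = {};
        bool flag = can_vectorize_x4(2, 1, 8); // computed once, reused for every copy
        copy4(src, dst, 2, 1, 8, flag);
        std::cout << dst[0] << " " << dst[3] << "\n"; // prints "2 5"
    }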
Showing 3 changed files with 239 additions and 8 deletions (+239 −8)

composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp    +7   −0
composable_kernel/include/tensor_operation/gridwise_gemm.hpp                          +4   −1
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp   +228 −7
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
...
...
@@ -144,6 +144,13 @@ struct BlockwiseGenericTensorSliceCopy_v4
    {
        mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
    }

#if CK_VECTORIZE_FLAG
    __device__ void SetVectorizeFlag() { mThreadwiseLoad.SetVectorizeFlag(); }
#endif

    private:
    using ThreadBufferDesc = decltype(make_native_tensor_descriptor_packed(ThreadSliceLengths{}));
...
...
composable_kernel/include/tensor_operation/gridwise_gemm.hpp
...
...
@@ -228,6 +228,9 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.Run(p_a_global, p_a_block_double);

#if CK_VECTORIZE_FLAG
            b_blockwise_copy.SetVectorizeFlag();
#endif
            b_blockwise_copy.Run(p_b_global, p_b_block_double);
        }
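Note the call order in this hunk: SetVectorizeFlag() is invoked on b_blockwise_copy before its Run(), so the flag is derived from the copier's current slice origin before any data moves, and the blockwise object merely forwards the call down to its threadwise load. A schematic sketch of that delegation chain, with hypothetical stand-in types (the real classes carry many template parameters omitted here):

    #include <iostream>

    // Hypothetical stand-ins showing how the flag propagates:
    // gridwise kernel -> blockwise copy -> threadwise copy.
    struct ThreadwiseCopy
    {
        bool vectorize_flag = false;
        void SetVectorizeFlag() { vectorize_flag = true; } // per-thread decision in the real code
        void Run() { std::cout << (vectorize_flag ? "x4 load\n" : "scalar loads\n"); }
    };

    struct BlockwiseCopy
    {
        ThreadwiseCopy mThreadwiseLoad;
        // The blockwise wrapper only forwards; the decision lives in the threadwise copy.
        void SetVectorizeFlag() { mThreadwiseLoad.SetVectorizeFlag(); }
        void Run() { mThreadwiseLoad.Run(); }
    };

    int main()
    {
        BlockwiseCopy b_blockwise_copy;
        b_blockwise_copy.SetVectorizeFlag(); // must precede Run(), as in the hunk above
        b_blockwise_copy.Run();              // prints "x4 load"
    }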
...
...
@@ -285,7 +288,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
            a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
            b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);

            __syncthreads();

            // LDS double buffer: load last data from device mem
...
...
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
...
...
@@ -5,7 +5,6 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_coordinate.hpp"

namespace ck {

// This threadwise copy allows vector access of src and dst.
...
...
@@ -36,6 +35,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                                         const Index& dst_slice_origin)
        : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
    {
#if CK_VECTORIZE_FLAG
        vectorize_flag = 0;
        isVectoriable  = 0;
#endif
        static_assert(nDim == SrcDesc::GetNumOfDimension() &&
                          nDim == DstDesc::GetNumOfDimension() &&
                          nDim == SliceLengths::Size() &&
                          nDim == SrcDstDimAccessOrder::Size(),
...
...
@@ -71,8 +74,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
    __device__ void Run(const SrcData* p_src, DstData* p_dst) const
    {
        constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};

#if (CK_VECTORX4_FLAG || CK_VECTORX2_FLAG)
#else
        constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
#endif
        constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};

        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};
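The long-vector buffer is sized by the least common multiple of the read and write widths, so one pass through the buffer is a whole number of src reads and a whole number of dst writes. A small self-contained check of that arithmetic; the constexpr lcm below is a stand-in for ck's math::lcm, and the widths 4 and 2 are example values:

    #include <iostream>

    // Stand-in for ck's math::lcm, usable in constant expressions.
    constexpr int gcd(int a, int b) { return b == 0 ? a : gcd(b, a % b); }
    constexpr int lcm(int a, int b) { return a / gcd(a, b) * b; }

    int main()
    {
        // e.g. SrcDataPerRead = 4, DstDataPerWrite = 2:
        constexpr int long_vector_size = lcm(4, 2); // 4
        static_assert(long_vector_size % 4 == 0, "whole number of src reads");
        static_assert(long_vector_size % 2 == 0, "whole number of dst writes");
        std::cout << long_vector_size / 4 << " read(s), "
                  << long_vector_size / 2 << " write(s) per long vector\n"; // 1 read, 2 writes
    }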
...
...
@@ -96,16 +101,200 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
            {
                p_src_long_vector[i] = 0;
            }

            auto scalar_id = make_zero_array<index_t, nDim>();
            auto src_coord = mSrcSliceOrigin + long_vector_data_begin_id;
#if CK_VECTORX4_FLAG //vectorloadx4
{
//auto src_coord = mSrcSliceOrigin + long_vector_data_begin_id;
scalar_id
(
vector_access_dim
)
=
3
;
if
((
long_vector_size
==
4
)
&&
src_coord
.
IsOffsetValidAssumingUpperIndexIsValid
()
&&
((
src_coord
.
CalculateOffsetDiff
(
scalar_id
))
==
3
)){
transfer_data
<
SrcData
,
4
,
SrcAddressSpace
,
AddressSpace
::
Vgpr
,
InMemoryDataOperation
::
Set
>
(
p_src
,
src_coord
.
GetOffset
(),
p_src_long_vector
,
0
);
}
else
{
//original code
// load data from src to the long-vector buffer
for
(
index_t
i
=
0
;
i
<
long_vector_size
;
++
i
)
{
scalar_id
(
vector_access_dim
)
=
i
;
src_coord
=
mSrcSliceOrigin
+
(
long_vector_data_begin_id
+
scalar_id
);
// Check src data's valid mapping situation, only check the first data in this src
// vector. It's user's responsiblity to make sure all data in the src vector
// has the valid/invalid mapping situation
if
(
src_coord
.
IsOffsetValidAssumingUpperIndexIsValid
())
{
transfer_data
<
SrcData
,
1
,
SrcAddressSpace
,
AddressSpace
::
Vgpr
,
InMemoryDataOperation
::
Set
>
(
p_src
,
src_coord
.
GetOffset
(),
p_src_long_vector
,
i
);
}
}
}
}
#elif CK_VECTORX2_FLAG // vectorized x2 load
            // auto scalar_id = make_zero_array<index_t, nDim>();
            // auto src_coord = mSrcSliceOrigin + long_vector_data_begin_id;
            scalar_id(vector_access_dim) = 1;

            if((long_vector_size == 2) && src_coord.IsOffsetValidAssumingUpperIndexIsValid() &&
               (src_coord.CalculateOffsetDiff(scalar_id) == 1))
            {
                transfer_data<SrcData,
                              2,
                              SrcAddressSpace,
                              AddressSpace::Vgpr,
                              InMemoryDataOperation::Set>(
                    p_src, src_coord.GetOffset(), p_src_long_vector, 0);
            }
            else
            {
                // original code: load data from src to the long-vector buffer
                for(index_t i = 0; i < long_vector_size; ++i)
                {
                    scalar_id(vector_access_dim) = i;
                    src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);

                    // Check the src data's valid mapping situation; only check the first
                    // data in this src vector. It is the user's responsibility to make sure
                    // all data in the src vector has the same valid/invalid mapping situation
                    if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
                    {
                        transfer_data<SrcData,
                                      1,
                                      SrcAddressSpace,
                                      AddressSpace::Vgpr,
                                      InMemoryDataOperation::Set>(
                            p_src, src_coord.GetOffset(), p_src_long_vector, i);
                    }
                }
            }
#elif CK_VECTORIZE_FLAG
            if(vectorize_flag)
            {
                if(isVectoriable)
                {
                    transfer_data<SrcData,
                                  4,
                                  SrcAddressSpace,
                                  AddressSpace::Vgpr,
                                  InMemoryDataOperation::Set>(
                        p_src, src_coord.GetOffset(), p_src_long_vector, 0);
                }
                else
                {
#if 1 // for loop
                    // load data from src to the long-vector buffer
                    for(index_t i = 0; i < long_vector_size; ++i)
                    {
                        scalar_id(vector_access_dim) = i;
                        src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);

                        // Check the src data's valid mapping situation; only check the first
                        // data in this src vector. It is the user's responsibility to make sure
                        // all data in the src vector has the same valid/invalid mapping situation
                        if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
                        {
                            transfer_data<SrcData,
                                          1,
                                          SrcAddressSpace,
                                          AddressSpace::Vgpr,
                                          InMemoryDataOperation::Set>(
                                p_src, src_coord.GetOffset(), p_src_long_vector, i);
                        }
                    }
#else // manually unrolled: same checked scalar load for elements 0, 1, 2, 3
      // element 0: src_coord is still mSrcSliceOrigin + long_vector_data_begin_id
                    if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
                    {
                        transfer_data<SrcData, 1, SrcAddressSpace, AddressSpace::Vgpr,
                                      InMemoryDataOperation::Set>(
                            p_src, src_coord.GetOffset(), p_src_long_vector, 0);
                    }

                    scalar_id(vector_access_dim) = 1;
                    src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);
                    if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
                    {
                        transfer_data<SrcData, 1, SrcAddressSpace, AddressSpace::Vgpr,
                                      InMemoryDataOperation::Set>(
                            p_src, src_coord.GetOffset(), p_src_long_vector, 1);
                    }

                    scalar_id(vector_access_dim) = 2;
                    src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);
                    if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
                    {
                        transfer_data<SrcData, 1, SrcAddressSpace, AddressSpace::Vgpr,
                                      InMemoryDataOperation::Set>(
                            p_src, src_coord.GetOffset(), p_src_long_vector, 2);
                    }

                    scalar_id(vector_access_dim) = 3;
                    src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);
                    if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
                    {
                        transfer_data<SrcData, 1, SrcAddressSpace, AddressSpace::Vgpr,
                                      InMemoryDataOperation::Set>(
                            p_src, src_coord.GetOffset(), p_src_long_vector, 3);
                    }
#endif
                }
            }
            else
            {
                // original code: load data from src to the long-vector buffer
                for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
                {
                    scalar_id(vector_access_dim) = i * src_data_per_access;

                    const index_t buffer_offset = i * src_data_per_access;

                    src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);

                    // Check the src data's valid mapping situation; only check the first
                    // data in this src vector. It is the user's responsibility to make sure
                    // all data in the src vector has the same valid/invalid mapping situation
                    if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
                    {
                        transfer_data<SrcData,
                                      SrcDataPerRead,
                                      SrcAddressSpace,
                                      AddressSpace::Vgpr,
                                      InMemoryDataOperation::Set>(
                            p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
                    }
                }
            }
#else // original code
            // load data from src to the long-vector buffer
            for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
            {
                // auto scalar_id = make_zero_array<index_t, nDim>();
                scalar_id(vector_access_dim) = i * src_data_per_access;

                const index_t buffer_offset = i * src_data_per_access;

                src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);
                // const auto src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);

                // Check the src data's valid mapping situation; only check the first
                // data in this src vector. It is the user's responsibility to make sure
                // all data in the src vector has the same valid/invalid mapping situation
...
...
@@ -120,7 +309,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                        p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
                }
            }
#endif

            // SrcData to DstData conversion
            DstData p_dst_long_vector[long_vector_size];
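The guard used by the x4 and x2 branches above deserves a gloss: setting scalar_id to 3 (or 1) along the access dimension and requiring CalculateOffsetDiff(scalar_id) == 3 (or 1) asserts that the last element of the would-be vector sits exactly that many scalars past the first in linear memory, i.e. the elements are contiguous. A hypothetical flat-index model of that check; the real CalculateOffsetDiff runs through the tensor's coordinate transforms, which this ignores:

    #include <cassert>

    // Hypothetical model of CalculateOffsetDiff for a strided view: the
    // linear-offset change produced by moving `step` along the access dim.
    int offset_diff_access_dim(int stride, int step) { return step * stride; }

    int main()
    {
        // stride 1 along the access dimension: moving 3 elements changes the
        // linear offset by exactly 3, so an x4 vector load is safe.
        assert(offset_diff_access_dim(/*stride=*/1, /*step=*/3) == 3);

        // stride 2 (e.g. a strided view): the offset jumps by 6, the four
        // elements are not adjacent, and the code falls back to scalar loads.
        assert(offset_diff_access_dim(/*stride=*/2, /*step=*/3) != 3);
        return 0;
    }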
...
...
@@ -132,7 +321,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
            // store data from the long-vector buffer to dst
            for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i)
            {
                // auto scalar_id = make_zero_array<index_t, nDim>();
                scalar_id(vector_access_dim) = i * dst_data_per_access;

                const index_t buffer_offset = i * dst_data_per_access;
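Stepping back from the store loop: what the CK_VECTORIZE_FLAG path buys is that the per-element validity test is hoisted out of Run() entirely; vectorize_flag and isVectoriable are set once and the copy branches on cached integers instead of recomputing coordinate offsets per element. A minimal host-side sketch of that hoisting, with illustrative names only:

    #include <algorithm>
    #include <iostream>
    #include <numeric>
    #include <vector>

    // Illustrative only: hoist an expensive per-element check out of the hot loop.
    bool check_contiguous(const std::vector<int>& strides) // stand-in for coordinate math
    {
        return !strides.empty() && strides.back() == 1;
    }

    float sum_copy(const std::vector<float>& src, bool precomputed_ok)
    {
        std::vector<float> buf(src.size());
        if(precomputed_ok)
        {
            // fast path: no per-element validity test
            std::copy(src.begin(), src.end(), buf.begin());
        }
        else
        {
            for(std::size_t i = 0; i < src.size(); ++i) // checked, element by element
                buf[i] = src[i];
        }
        return std::accumulate(buf.begin(), buf.end(), 0.0f);
    }

    int main()
    {
        std::vector<float> src(16);
        std::iota(src.begin(), src.end(), 0.0f);
        bool ok = check_contiguous({4, 1}); // evaluated once, like SetVectorizeFlag()
        std::cout << sum_copy(src, ok) << "\n"; // 120
    }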
...
...
@@ -242,7 +431,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                const index_t buffer_offset = i * src_data_per_access;

                // move src coordinate along linear dimensions
                const auto src_coord = src_nonlinear_coord + (linear_dim_data_steps + scalar_id);
...
...
@@ -492,9 +681,41 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
        }).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
    }

#if CK_VECTORIZE_FLAG
    __device__ void SetVectorizeFlag()
    {
        vectorize_flag = 1;

        auto scalar_id                   = make_zero_array<index_t, nDim>();
        constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};

        auto vectoriableFlagArray =
            Array<index_t, (SliceLengths::Get(vector_access_dim) / SrcDataPerRead)>{};

        auto mTempSrc_coord = mSrcSliceOrigin;

        for(int i = 0; i < vectoriableFlagArray.Size(); i++)
        {
            scalar_id(vector_access_dim) = (SrcDataPerRead - 1);

            if(mTempSrc_coord.IsOffsetValidAssumingUpperIndexIsValid() &&
               (mTempSrc_coord.CalculateOffsetDiff(scalar_id) == (SrcDataPerRead - 1)))
            {
                vectoriableFlagArray.At(i) = 1;
            }
            else
            {
                vectoriableFlagArray.At(i) = 0;
            }

            scalar_id(vector_access_dim) = SrcDataPerRead;
            mTempSrc_coord               = mTempSrc_coord + scalar_id;
        }

        isVectoriable = 1;
        for(int i = 0; i < vectoriableFlagArray.Size(); i++)
        {
            if(vectoriableFlagArray.At(i) == 0)
            {
                isVectoriable = 0;
            }
        }
    }
#endif

    private:
    SrcCoord mSrcSliceOrigin;
    DstCoord mDstSliceOrigin;
    // SrcCoord mTempSrc_coord;

#if CK_VECTORIZE_FLAG
    int vectorize_flag;
    int isVectoriable;
#endif
};

} // namespace ck
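SetVectorizeFlag above walks the access dimension in SrcDataPerRead-wide chunks, marks a chunk vectorizable only if its last element is exactly SrcDataPerRead − 1 scalars past its first (and the coordinate is valid), and then ANDs the per-chunk results into isVectoriable. The same walk over a hypothetical flat buffer, assuming unit stride is what contiguity means here; names and the flat model are illustrative:

    #include <iostream>

    // Hypothetical flat model of SetVectorizeFlag: slice_length elements along
    // the access dimension, read in chunks of data_per_read, with a given stride.
    bool is_vectorizable(int slice_length, int data_per_read, int stride, int buffer_size)
    {
        int offset  = 0;
        bool all_ok = true;
        for(int chunk = 0; chunk < slice_length / data_per_read; ++chunk)
        {
            // last element of the chunk must be (data_per_read - 1) scalars
            // past the first, and still inside the buffer
            int last      = offset + (data_per_read - 1) * stride;
            bool chunk_ok = (stride == 1) && (last < buffer_size);
            all_ok        = all_ok && chunk_ok; // one bad chunk disables vectorization
            offset += data_per_read * stride;   // advance to the next chunk
        }
        return all_ok;
    }

    int main()
    {
        std::cout << is_vectorizable(8, 4, /*stride=*/1, /*buffer_size=*/8) << "\n";  // 1
        std::cout << is_vectorizable(8, 4, /*stride=*/2, /*buffer_size=*/16) << "\n"; // 0
    }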
...
...