gaoqiong / composable_kernel

Commit 712babe4, authored Apr 21, 2021 by Chao Liu

    replacing array with vector for tensor data

parent 03f7892a
Showing 4 changed files with 35 additions and 20 deletions.
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp                          +11 -8
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp                       +1 -1
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp   +19 -7
composable_kernel/include/utility/buffer.hpp                                               +4 -4
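Across all four files the theme is the commit message: raw, thread-private C arrays are replaced by small buffer wrappers, a DynamicBuffer view created with make_dynamic_buffer over existing storage and a StaticBuffer created with make_static_buffer for compile-time-sized staging. The sketch below is a simplified stand-in for those wrappers, not the composable_kernel definitions; it only shows the shape of the pattern.

// Simplified stand-ins for the wrappers used at the new call sites; names
// mirror the diff, bodies are illustrative only (not the library's code).
#include <cstddef>

template <typename Scalar>
struct DynamicBuffer
{
    Scalar* p_scalar_; // non-owning view over storage that lives elsewhere

    Scalar& operator()(std::size_t i) { return p_scalar_[i]; }
    const Scalar& operator[](std::size_t i) const { return p_scalar_[i]; }
};

template <typename T>
constexpr DynamicBuffer<T> make_dynamic_buffer(T* p)
{
    return DynamicBuffer<T>{p};
}

template <typename Scalar, std::size_t N>
struct StaticBuffer
{
    Scalar data_[N]; // small, compile-time-sized staging storage

    Scalar& operator()(std::size_t i) { return data_[i]; }
    const Scalar& operator[](std::size_t i) const { return data_[i]; }
};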
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
@@ -546,6 +546,9 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         FloatA p_a_thread[a_thread_mtx_desc_.GetElementSpaceSize()];
         FloatB p_b_thread[b_thread_mtx_desc_.GetElementSpaceSize()];
 
+        auto a_thread_buf = make_dynamic_buffer<FloatA>(p_a_thread);
+        auto b_thread_buf = make_dynamic_buffer<FloatB>(p_b_thread);
+
         constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<FloatA,
                                                                     FloatB,
                                                                     FloatC,
@@ -559,7 +562,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            p_a_block,
                            a_thread_mtx_desc_,
                            make_tuple(Number<0>{}, Number<0>{}),
-                           p_a_thread);
+                           a_thread_buf);
 
         // read B_sub_0
         b_thread_copy_.Run(BlockMatrixB{},
@@ -567,7 +570,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            p_b_block,
                            b_thread_mtx_desc_,
                            make_tuple(Number<0>{}, Number<0>{}),
-                           p_b_thread);
+                           b_thread_buf);
 
         // read B_sub_1
         b_thread_copy_.Run(BlockMatrixB{},
@@ -575,7 +578,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            p_b_block,
                            b_thread_mtx_desc_,
                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
-                           p_b_thread);
+                           b_thread_buf);
 
         // read A_sub_1
         a_thread_copy_.Run(BlockMatrixA{},
@@ -583,7 +586,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            p_a_block,
                            a_thread_mtx_desc_,
                            make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
-                           p_a_thread);
+                           a_thread_buf);
 
         // C_sub_00 += transpose(A_sub_0) * B_sub_0
         threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
@@ -602,7 +605,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            p_a_block,
                            a_thread_mtx_desc_,
                            make_tuple(Number<0>{}, Number<0>{}),
-                           p_a_thread);
+                           a_thread_buf);
 
         // C_sub_10 += transpose(A_sub_1) * B_sub_0
         threadwise_gemm.Run(
@@ -616,7 +619,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            p_b_block,
                            b_thread_mtx_desc_,
                            make_tuple(Number<0>{}, Number<0>{}),
-                           p_b_thread);
+                           b_thread_buf);
 
         // C_sub_11 += transpose(A_sub_1) * B_sub_1
         threadwise_gemm.Run(
@@ -631,7 +634,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            p_b_block,
                            b_thread_mtx_desc_,
                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
-                           p_b_thread);
+                           b_thread_buf);
 
         // read A_sub_1
         a_thread_copy_.Run(BlockMatrixA{},
@@ -639,7 +642,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            p_a_block,
                            a_thread_mtx_desc_,
                            make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
-                           p_a_thread);
+                           a_thread_buf);
 
         // C_sub_00 += transpose(A_sub_0) * B_sub_0
         threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
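Note that in the hunks above only the destinations of a_thread_copy_ and b_thread_copy_ switch to a_thread_buf/b_thread_buf; the threadwise_gemm.Run calls still take the raw p_a_thread/p_b_thread pointers. That is consistent because the view returned by make_dynamic_buffer aliases the same storage, as the simplified sketch below (hypothetical types, not the library's) illustrates.

#include <cassert>

template <typename T>
struct View // stand-in for the DynamicBuffer returned by make_dynamic_buffer
{
    T* p_;
    T& operator()(int i) { return p_[i]; }
};

int main()
{
    float p_a_thread[8] = {};
    View<float> a_thread_buf{p_a_thread}; // like make_dynamic_buffer<FloatA>(p_a_thread)

    a_thread_buf(3) = 2.5f;        // the thread copy writes through the view
    assert(p_a_thread[3] == 2.5f); // the GEMM still reads the same element via the raw pointer
    return 0;
}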
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
@@ -52,7 +52,7 @@ __global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_gl
 }
 #endif
 
-#if 1
+#if 0
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
@@ -1362,13 +1362,15 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
                       "wrong! SrcDesc and DstDesc need to known at compile-time");
     }
 
-    template <typename SrcRefToOriginDisplacement, typename DstRefToOriginDisplacement>
+    template <typename SrcRefToOriginDisplacement,
+              typename DstRefToOriginDisplacement,
+              typename DstBuffer>
     __device__ void Run(const SrcDesc&,
                         const SrcRefToOriginDisplacement&,
                         const SrcData* p_src,
                         const DstDesc&,
                         const DstRefToOriginDisplacement&,
-                        DstData* p_dst) const
+                        DstBuffer dst_buf) const
     {
         static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc and DstDesc need to known at compile-time");
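The net effect of this hunk is that the destination becomes a template parameter: Run no longer insists on a raw DstData* and instead writes through whatever buffer-like object the caller hands it. A minimal sketch of that idea follows; the names (DstView, run_copy) are hypothetical and only stand in for the transfer class's new DstBuffer parameter.

#include <cstddef>

template <typename T>
struct DstView // hypothetical pointer-style destination wrapper, cheap to pass by value
{
    T* p_;
    T& operator()(std::size_t i) { return p_[i]; }
};

template <typename SrcData, typename DstBuffer>
void run_copy(const SrcData* p_src, DstBuffer dst_buf, std::size_t n)
{
    // the copy routine only needs operator() access; the concrete buffer type is opaque
    for(std::size_t i = 0; i < n; ++i)
        dst_buf(i) = p_src[i];
}

int main()
{
    float src[4] = {1.f, 2.f, 3.f, 4.f};
    float dst[4] = {};
    run_copy(src, DstView<float>{dst}, 4); // writes land in dst[] through the view
    return dst[3] == 4.f ? 0 : 1;
}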
@@ -1450,8 +1452,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
             move_dynamic_tensor_coordinate(
                 src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
 
-            // copy data from src into buffer
-            StaticBuffer<SrcData, SrcScalarPerVector> src_buf;
+            // copy data from src_buf into src_tmp_buffer
+            auto src_tmp_buf = make_static_buffer<SrcData>(Number<SrcScalarPerVector>{});
 
             using src_vector_t =
                 typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
@@ -1459,18 +1461,28 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
             const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
                 src_desc, src_data_coord);
 
-            src_buf.template AsType<src_vector_t>()(Number<0>{}) =
+            src_tmp_buf.template AsType<src_vector_t>()(Number<0>{}) =
                 is_src_valid
                     ? *reinterpret_cast<const src_vector_t*>(&p_src[src_data_coord.GetOffset()])
                     : src_vector_t{0};
 
-            // copy data from buffer into dst
+            // copy data from src_tmp_buf to dst_tmp_buf (data cast data from SrcData to DstData)
+            auto dst_tmp_buf = make_static_buffer<DstData>(Number<SrcScalarPerVector>{});
+
+            // TODO: if SrcData and DstData are vetor type, then static_cast may not compile
+            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                dst_tmp_buf.template AsType<DstData>()(i) =
+                    static_cast<DstData>(src_tmp_buf.template AsType<SrcData>()[i]);
+            });
+
+            // copy data from dst_tmp_buf into dst_buf
             static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                 constexpr index_t dst_offset = dst_desc.CalculateOffset(
                     to_multi_index(dst_ref_to_origin_disp_idx) + data_to_origin_disp_idx +
                     i * src_scalar_step_in_vector);
 
-                p_dst[Number<dst_offset>{}] = src_buf.template AsType<SrcData>()[i];
+                dst_buf.template AsType<DstData>()(Number<dst_offset>{}) =
+                    dst_tmp_buf.template AsType<DstData>()[i];
             });
         });
     }
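The rewritten body follows a three-step flow: stage one vector-sized chunk in src_tmp_buf, convert it element by element into dst_tmp_buf, then scatter into dst_buf at computed offsets. The plain-loop paraphrase below uses scalar float/double in place of SrcData/DstData and a raw array in place of dst_buf; it is an illustration of the flow, not the actual device code.

#include <cstddef>

constexpr std::size_t SrcScalarPerVector = 4;

void copy_chunk(const float* p_src, bool is_src_valid, double* dst_storage, std::size_t dst_offset_base)
{
    // 1) load one vector-sized chunk from src into a small staging buffer,
    //    substituting zeros when the source coordinate is not valid
    float src_tmp_buf[SrcScalarPerVector];
    for(std::size_t i = 0; i < SrcScalarPerVector; ++i)
        src_tmp_buf[i] = is_src_valid ? p_src[i] : 0.0f;

    // 2) cast element-wise from the source type to the destination type
    double dst_tmp_buf[SrcScalarPerVector];
    for(std::size_t i = 0; i < SrcScalarPerVector; ++i)
        dst_tmp_buf[i] = static_cast<double>(src_tmp_buf[i]);

    // 3) write the converted values into the destination buffer at the computed offsets
    for(std::size_t i = 0; i < SrcScalarPerVector; ++i)
        dst_storage[dst_offset_base + i] = dst_tmp_buf[i];
}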
composable_kernel/include/utility/buffer.hpp
@@ -20,7 +20,7 @@ struct StaticBuffer : public vector_type<ScalarType, N>
 template <typename T, index_t N>
 __host__ __device__ constexpr auto make_static_buffer(Number<N>)
 {
-    using scalar_t = scalar_type<T>;
+    using scalar_t = typename scalar_type<T>::type;
 
     constexpr index_t scalar_per_vector = scalar_type<T>::vector_size;
 
     return StaticBuffer<scalar_t, N * scalar_per_vector>{};
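The single-line fix here matters because scalar_type<T> names the traits class itself, not the scalar it describes; the buffer should be parameterized on the trait's nested ::type. The sketch below uses a hypothetical scalar_type specialization purely to show the distinction; the library's real trait is defined elsewhere.

#include <type_traits>

template <typename T>
struct scalar_type // hypothetical trait: a plain scalar describes itself
{
    using type = T;
    static constexpr int vector_size = 1;
};

struct float4_like { float x, y, z, w; }; // hypothetical 4-wide vector type

template <>
struct scalar_type<float4_like> // hypothetical trait specialization for the vector type
{
    using type = float;
    static constexpr int vector_size = 4;
};

// before: using scalar_t = scalar_type<float4_like>;                 // the trait class itself
// after:  using scalar_t = typename scalar_type<float4_like>::type;  // float, as intended
static_assert(std::is_same<scalar_type<float4_like>::type, float>::value, "scalar is float");
static_assert(scalar_type<float4_like>::vector_size == 4, "four scalars per vector");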
@@ -51,7 +51,7 @@ struct DynamicBuffer
                   is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                           ScalarType>::value,
                   bool>::type = false>
-    __host__ __device__ constexpr const auto& AsType() const
+    __host__ __device__ constexpr const auto AsType() const
     {
         return PointerWrapper<X>{reinterpret_cast<X*>(p_scalar_)};
     }
@@ -61,7 +61,7 @@ struct DynamicBuffer
                   is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                           ScalarType>::value,
                   bool>::type = false>
-    __host__ __device__ constexpr auto& AsType()
+    __host__ __device__ constexpr auto AsType()
     {
         return PointerWrapper<X>{reinterpret_cast<X*>(p_scalar_)};
     }
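Dropping the & from both AsType overloads (the const one above and this mutable one) avoids returning a reference to the temporary PointerWrapper built in the return statement, which would dangle as soon as the call expression ended. Returning the wrapper by value is cheap since it only holds a pointer, and writes through it still reach the underlying storage. A minimal illustration with simplified, hypothetical types:

template <typename X>
struct PointerWrapper // simplified stand-in for the library's wrapper
{
    X* p_;
    X& operator()(int i) { return p_[i]; }
};

struct Buf // simplified stand-in for DynamicBuffer
{
    float* p_scalar_;

    // before (problematic): "auto&" would bind to the temporary PointerWrapper
    // created in the return statement and dangle after the call.
    // after: return the small wrapper by value.
    auto AsType() { return PointerWrapper<float>{p_scalar_}; }
};

int main()
{
    float storage[4] = {};
    Buf buf{storage};
    buf.AsType()(2) = 7.0f; // the by-value wrapper still writes through to storage
    return storage[2] == 7.0f ? 0 : 1;
}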
@@ -70,7 +70,7 @@ struct DynamicBuffer
 template <typename T>
 __host__ __device__ constexpr auto make_dynamic_buffer(T* p)
 {
-    using scalar_t = scalar_type<T>;
+    using scalar_t = typename scalar_type<T>::type;
 
     constexpr index_t scalar_per_vector = scalar_type<T>::vector_size;
 
     return DynamicBuffer<scalar_t>{p};