Project: gaoqiong/composable_kernel

Commit e38c1b73
authored Apr 17, 2021 by Chao Liu
parent 841b1480

    replacing array with vector for tensor data
Showing 2 changed files with 35 additions and 24 deletions (+35 -24)
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp  (+10 -16)
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp  (+25 -8)
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
@@ -130,13 +130,13 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
         // thread A, B for GEMM
-        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            Number<KPerThreadLoop>{}, Number<MPerThread>{});
+        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
+            make_tuple(Number<KPerThreadLoop>{}, Number<MPerThread>{}));

-        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            Number<KPerThreadLoop>{}, Number<NPerThread>{});
+        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
+            make_tuple(Number<KPerThreadLoop>{}, Number<NPerThread>{}));

-        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
+        FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
+        FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()];

         constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
                                                                     decltype(a_thread_mtx),
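Note: the hunk above sizes the thread-private buffers from the packed descriptor's GetElementSpaceSize(). A minimal stand-alone sketch of that idea, with PackedDescSketch as a made-up stand-in for the library descriptor (not the library's API):

    #include <cstdio>

    // Illustrative only: a toy packed descriptor whose element-space size is the product
    // of compile-time lengths, used to size a thread-private buffer the way p_a_thread is
    // sized from a_thread_mtx.GetElementSpaceSize() in the hunk above.
    template <int... Lengths>
    struct PackedDescSketch
    {
        static constexpr int GetElementSpaceSize() { return (Lengths * ... * 1); }
    };

    int main()
    {
        // stands in for make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(...))
        constexpr PackedDescSketch<4, 8> a_thread_desc{}; // e.g. KPerThreadLoop = 4, MPerThread = 8

        float p_a_thread[a_thread_desc.GetElementSpaceSize()]; // 32 elements, bound known at compile time

        std::printf("%zu\n", sizeof(p_a_thread) / sizeof(float));
    }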
@@ -153,37 +153,31 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
         constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_mtx),
                                                                     decltype(b_thread_mtx),
                                                                     decltype(c_thread_mtx)>{};

-#pragma unroll
         // loop over k
-        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
-        {
+        static_for<0, K, KPerThreadLoop>{}([&](auto k_begin) {
-#pragma unroll
             // read A
-            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
-            {
+            static_for<0, MRepeat, 1>{}([&](auto m_repeat) {
                 a_thread_copy.Run(p_a_block +
                                       a_block_mtx.CalculateOffset(
                                           make_tuple(k_begin, m_repeat * MPerLevel1Cluster)) +
                                       mMyThreadOffsetA,
                                   p_a_thread +
                                       a_thread_mtx.CalculateOffset(
                                           make_tuple(0, m_repeat * MPerThreadSubC)));
-            }
+            });

-#pragma unroll
             // read B
-            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
-            {
+            static_for<0, NRepeat, 1>{}([&](auto n_repeat) {
                 b_thread_copy.Run(p_b_block +
                                       b_block_mtx.CalculateOffset(
                                           make_tuple(k_begin, n_repeat * NPerLevel1Cluster)) +
                                       mMyThreadOffsetB,
                                   p_b_thread +
                                       b_thread_mtx.CalculateOffset(
                                           make_tuple(0, n_repeat * NPerThreadSubC)));
-            }
+            });

             // C += A * B
             threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
-        }
+        });
     }

     template <typename FloatA, typename FloatB, typename FloatC>
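Note: the second hunk replaces the runtime for loops with static_for, so each loop index is a compile-time constant and the make_tuple(...) arguments fed to CalculateOffset can be constant-folded. A rough, illustrative sketch of such a compile-time loop (static_for_sketch below is not the library's implementation):

    #include <cstdio>
    #include <type_traits>

    // Illustrative only: a simplified compile-time loop in the spirit of the library's
    // static_for<begin, end, step> used in the hunk above.
    template <int Begin, int End, int Step>
    struct static_for_sketch
    {
        template <typename F>
        void operator()(F f) const
        {
            if constexpr(Begin < End)
            {
                f(std::integral_constant<int, Begin>{});         // index passed as a type
                static_for_sketch<Begin + Step, End, Step>{}(f); // unrolled at compile time
            }
        }
    };

    int main()
    {
        // each k_begin is a compile-time constant, so offset expressions like
        // CalculateOffset(make_tuple(k_begin, ...)) in the diff can fold away
        static_for_sketch<0, 8, 2>{}([](auto k_begin) {
            constexpr int k = decltype(k_begin)::value;
            std::printf("k_begin = %d\n", k);
        });
    }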
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
@@ -27,10 +27,14 @@ struct lambda_scalar_step_in_vector
     }
 };

+// Assume:
+// 1. src_desc is known at compile-time
+// 2. dst_desc is not known at compile-time
+// 3. src_slice_origin_idx is known at compile-time and it's 0
+// 4. dst_slice_origin_idx is not-known at compile time
 // this version is less likely to have scratch memory issue, due to:
 // 1. It does not keep reference to tensor descriptor
 // 2. It does not construct new tensor coordinate for this->Run()
-// Assume src_slice_origin_idx is 0
 // TODO: support non-zero src_slice_oring_idx
 template <typename SrcData,
           typename DstData,
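Note: the comments added above state that this transfer class does not keep a reference to the tensor descriptor; the descriptor is supplied where it is needed instead, which (per the comment) makes scratch-memory issues less likely. A toy, non-library sketch of the two designs, with Desc1D and the Transfer* types made up for illustration:

    #include <cstdio>
    #include <cstddef>

    struct Desc1D { std::size_t length; };   // toy stand-in for a tensor descriptor

    struct TransferKeepsRef                  // stores a reference to the descriptor
    {
        const Desc1D& desc_;
        void Run(const float* p_src, float* p_dst) const
        {
            for(std::size_t i = 0; i < desc_.length; ++i)
                p_dst[i] = p_src[i];
        }
    };

    struct TransferTakesDesc                 // descriptor is passed on every Run() call
    {
        void Run(const Desc1D& desc, const float* p_src, float* p_dst) const
        {
            for(std::size_t i = 0; i < desc.length; ++i)
                p_dst[i] = p_src[i];
        }
    };

    int main()
    {
        float src[4] = {1, 2, 3, 4}, dst[4] = {};
        Desc1D desc{4};

        TransferKeepsRef{desc}.Run(src, dst);
        TransferTakesDesc{}.Run(desc, src, dst);

        std::printf("%g\n", dst[3]);
    }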
@@ -359,10 +363,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
     DstCoord dst_slice_origin_coord_;
 }; // namespace ck

+// Assume:
+// 1. src_desc is not known at compile-time
+// 2. dst_desc is known at compile-time
+// 3. src_slice_origin_idx is not known at compile-time
+// 4. dst_slice_origin_idx is known at compile-time and it's 0
 // this version is less likely to have scratch memory issue, due to:
 // 1. It does not keep reference to tensor descriptor
 // 2. It does not construct new tensor coordinate for this->Run()
-// Assume dst_slice_origin_idx is 0
 template <typename SrcData,
           typename DstData,
           typename SrcDesc,
@@ -590,7 +598,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
             }
         }

-    __device__ void Run(const SrcDesc& src_desc, const SrcData* p_src, DstData* p_dst)
+    template <typename DstSliceOriginIdx>
+    __device__ void Run(const SrcDesc& src_desc,
+                        const SrcData* p_src,
+                        const DstDesc&,
+                        const DstSliceOriginIdx&,
+                        DstData* p_dst)
     {
         constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
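Note: Run() now also receives the destination descriptor and destination slice origin as unnamed arguments whose contents are encoded entirely in their types, so they carry no runtime state. A hypothetical sketch of that tag-argument pattern (ToyDesc, ToyIdx, and run_sketch are made-up names, not the library's API):

    #include <cstdio>
    #include <type_traits>

    template <int Len>
    struct ToyDesc
    {
        static constexpr int GetElementSpaceSize() { return Len; }
    };

    template <int I>
    using ToyIdx = std::integral_constant<int, I>;

    template <typename DstDesc, typename DstOriginIdx>
    void run_sketch(const float* p_src, const DstDesc&, const DstOriginIdx&, float* p_dst)
    {
        // nothing is read from these arguments at run time; both are recovered from the types
        constexpr int size   = DstDesc::GetElementSpaceSize();
        constexpr int origin = DstOriginIdx::value;

        for(int i = 0; i < size; ++i)
            p_dst[origin + i] = p_src[i];
    }

    int main()
    {
        float src[4] = {1.0f, 2.0f, 3.0f, 4.0f};
        float dst[4] = {};

        run_sketch(src, ToyDesc<4>{}, ToyIdx<0>{}, dst);

        std::printf("%g %g\n", dst[0], dst[3]);
    }

The forwarding call in the next hunk, Run(src_desc, p_src, DstDesc{}, DstSliceOriginIdx{}, p_dst, src_iterator_hacks), passes exactly such empty objects.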
@@ -600,7 +613,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
             make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                        generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

-        Run(src_desc, p_src, p_dst, src_iterator_hacks);
+        Run(src_desc, p_src, DstDesc{}, DstSliceOriginIdx{}, p_dst, src_iterator_hacks);
     }

     __device__ static constexpr auto GetSrcCoordinateResetStep()
@@ -685,12 +698,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
     SrcCoord src_slice_origin_coord_;
 }; // namespace ck

+// Assume:
+// 1. src_desc and dst_desc are not known at compile-time
+// 2. src_slice_origin and dst_slice_origin are not known at compile-time,
+// 3. Use thread buffer
 // this version does following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
 // and sometimes useless instructions
-// 1. It does not keep reference to tensor descriptor
-// 2. It does not construct new tensor coordinate for this->Run()
-// 3. It does not use pointer for VGPR thread buffer
-// 4. It calculate offset for thread buffer directly, instead of moving the coordinate
+//    1. It does not keep reference to tensor descriptor
+//    2. It does not construct new tensor coordinate for this->Run()
+//    3. It does not use pointer for VGPR thread buffer
+//    4. It calculate offset for thread buffer directly, instead of moving the coordinate
 template <typename SliceLengths,
           InMemoryDataOperation DstInMemOp,
           typename SrcData,
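Note: the comments above describe keeping the VGPR thread buffer out of scratch memory by not taking pointers to it and by computing buffer offsets directly. As a plain-C++ illustration (not the library's code) of why dynamic addressing tends to force a private array into memory while constant indexing lets the compiler keep it in registers:

    #include <cstdio>

    // buffer made addressable through a runtime pointer and read with a runtime index:
    // the compiler generally has to materialize it in memory (the "alloca"/scratch case)
    float sum_dynamic(const float* p, int i)   // assumes 0 <= i < 4
    {
        float buf[4];
        float* q = buf;
        for(int k = 0; k < 4; ++k)
            q[k] = p[k] * 2.0f;
        return q[i];
    }

    // constant trip count and constant indices only: the buffer can stay in registers
    float sum_static(const float* p)
    {
        float buf[4];
        for(int k = 0; k < 4; ++k)
            buf[k] = p[k] * 2.0f;
        return buf[0] + buf[1] + buf[2] + buf[3];
    }

    int main()
    {
        float x[4] = {1, 2, 3, 4};
        std::printf("%g %g\n", sum_dynamic(x, 2), sum_static(x));
    }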