Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7d0a5412
Commit
7d0a5412
authored
Mar 13, 2021
by
root
Browse files
threadwise transfer
parent
b3a012bc
Changes
6
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
252 additions
and
486 deletions
+252
-486
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp
...ble_kernel/include/tensor_operation/blockwise_gemm_v3.hpp
+105
-277
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
...nel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
+140
-176
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
...or_operation/threadwise_dynamic_tensor_slice_transfer.hpp
+2
-1
composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp
...le_kernel/include/tensor_operation/threadwise_gemm_v3.hpp
+3
-30
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
+1
-1
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+1
-1
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp
View file @
7d0a5412
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
View file @
7d0a5412
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
View file @
7d0a5412
...
...
@@ -535,7 +535,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
dst_desc
.
CalculateOffset
(
to_multi_index
(
dst_slice_origin_idx
)
+
src_data_idx
+
i
*
src_scalar_step_in_vector
);
p_dst
[
Number
<
dst_offset
>
{}]
=
src_vector
[
i
];
// p_dst[Number<dst_offset>{}] = src_vector[i];
p_dst
[
Number
<
dst_offset
>
{}]
=
src_vector
.
Scalars
()(
i
);
});
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
...
...
composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp
View file @
7d0a5412
...
...
@@ -28,33 +28,6 @@ __device__ void threadwise_matrix_set_zero_v3(Desc, Float* __restrict__ p_thread
});
}
template
<
typename
SrcDesc
,
typename
DstDesc
,
index_t
NSliceRow
,
index_t
NSliceCol
,
index_t
DataPerAccess
>
struct
ThreadwiseMatrixSliceCopy_v3
{
template
<
typename
Data
>
__device__
static
void
Run
(
const
Data
*
p_src
,
Data
*
p_dst
)
{
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
using
vector_t
=
typename
vector_type
<
Data
,
DataPerAccess
>::
type
;
static_for
<
0
,
NSliceRow
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NSliceCol
,
DataPerAccess
>
{}([
&
](
auto
j
)
{
constexpr
auto
src_offset
=
SrcDesc
{}.
CalculateOffset
(
make_tuple
(
i
,
j
));
constexpr
auto
dst_offset
=
DstDesc
{}.
CalculateOffset
(
make_tuple
(
i
,
j
));
*
reinterpret_cast
<
vector_t
*>
(
&
p_dst
[
dst_offset
])
=
*
reinterpret_cast
<
const
vector_t
*>
(
&
p_src
[
src_offset
]);
});
});
}
};
// C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data
template
<
typename
ADesc
,
...
...
@@ -75,9 +48,9 @@ struct ThreadwiseGemm_km_kn_mn_v3
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
M
=
CDesc
{}
[
I0
]
;
constexpr
auto
N
=
CDesc
{}
[
I1
]
;
constexpr
auto
K
=
ADesc
{}
[
I0
]
;
constexpr
auto
M
=
CDesc
{}
.
GetLength
(
I0
)
;
constexpr
auto
N
=
CDesc
{}
.
GetLength
(
I1
)
;
constexpr
auto
K
=
ADesc
{}
.
GetLength
(
I0
)
;
static_for
<
0
,
K
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
M
,
1
>
{}([
&
](
auto
m
)
{
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
View file @
7d0a5412
...
...
@@ -76,7 +76,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr
index_t
GemmMPerThread
=
16
;
constexpr
index_t
GemmNPerThread
=
1
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmKPerThread
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
1
;
constexpr
index_t
GemmNLevel0Cluster
=
1
;
...
...
driver/src/conv_driver.cpp
View file @
7d0a5412
...
...
@@ -779,7 +779,7 @@ int main(int argc, char* argv[])
#if 1
// LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
// LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
LogRange
(
std
::
cout
<<
"out_nkhw_host : "
,
out_nkhw_host
.
mData
,
","
)
<<
std
::
endl
;
//
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange
(
std
::
cout
<<
"out_nkhw_device: "
,
out_nkhw_device
.
mData
,
","
)
<<
std
::
endl
;
#endif
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment