gaoqiong / composable_kernel, commit 7d0a5412

Commit 7d0a5412, authored Mar 13, 2021 by root
Parent: b3a012bc

    threadwise transfer

Showing 6 changed files with 252 additions and 486 deletions (+252 -486)
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp                          +105  -277
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp                   +140  -176
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp     +2    -1
composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp                           +3   -30
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp     +1    -1
driver/src/conv_driver.cpp                                                                  +1    -1
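Taken together, the files below re-cast the implicit-GEMM convolution so that the reduction dimension is the flattened C*Y*X of the weight: the A tile is staged as [CYX, K] and the B/C tiles keep their N, H, W extents. As a standalone sketch of the index arithmetic such packed descriptors imply (the helper functions are illustrative stand-ins, not the CK API; the per-block sizes come from the new gridwise code):

#include <cstdio>

// Offset into a packed A[CYX][K] tile and a packed B[CYX][N][H][W] tile,
// mirroring the [cyx, k] / [cyx, n, h, w] descriptors introduced by this commit.
// These helpers are stand-ins for illustration, not CK descriptor calls.
int offset_a(int cyx, int k, int K) { return cyx * K + k; }

int offset_b(int cyx, int n, int h, int w, int N, int H, int W)
{
    return ((cyx * N + n) * H + h) * W + w;
}

int main()
{
    const int K = 16, N = 1, H = 8, W = 8; // per-block tile sizes used by the new gridwise code
    // C[k][h][w] += sum over cyx of A[cyx][k] * B[cyx][0][h][w]
    std::printf("A(3,5) -> %d, B(3,0,2,7) -> %d\n",
                offset_a(3, 5, K), offset_b(3, 0, 2, 7, N, H, W));
    return 0;
}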
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp

@@ -9,27 +9,27 @@ namespace ck {
 // blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N]
 // A and B are visable to the whole block, C is distributed among each thread
 // If following number are power of 2, index calculation shall be greatly reduced:
-//    MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
+//    KPerThread, HPerThread, MLevel0ThreadCluster, NLevel0ThreadCluster,
 //    MLevel1ThreadCluster, NLevel1ThreadCluster
 template <index_t BlockSize,
           typename BlockMatrixA,
           typename BlockMatrixB,
           typename ThreadMatrixC,
-          index_t MPerThreadSubC,
-          index_t NPerThreadSubC,
-          index_t KPerThreadLoop,
-          index_t MLevel0ThreadCluster,
-          index_t NLevel0ThreadCluster,
-          index_t MLevel1ThreadCluster,
-          index_t NLevel1ThreadCluster,
-          index_t ThreadGemmADataPerRead_M,
-          index_t ThreadGemmBDataPerRead_N>
+          index_t KPerThread,
+          index_t HPerThread,
+          index_t WPerThread,
+          index_t CYXPerThreadLoop,
+          index_t HThreadCluster,
+          index_t WThreadCluster,
+          index_t ThreadGemmADataPerRead_K,
+          index_t ThreadGemmBDataPerRead_W>
 struct BlockwiseGemm_km_kn_m0m1n0n1_v3
 {
     struct MatrixIndex
     {
-        index_t row;
-        index_t col;
+        index_t k;
+        index_t h;
+        index_t w;
     };
 
     index_t mMyThreadOffsetA;
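The new parameters describe a per-thread [KPerThread, HPerThread, WPerThread] output tile distributed over an HThreadCluster x WThreadCluster grid of threads, rather than the old M/N level-0/level-1 clusters. One plausible thread-id to (k, h, w) mapping under the new requirement BlockSize == HThreadCluster * WThreadCluster is sketched below; this is purely an illustrative assumption, not what this commit's GetBeginOfThreadMatrixC computes:

#include <cstdio>

// Hypothetical decomposition of a 1D thread id into (h, w) cluster coordinates,
// assuming BlockSize == HThreadCluster * WThreadCluster.
struct MatrixIndex { int k, h, w; };

MatrixIndex begin_of_thread_matrix_c(int thread_id, int WThreadCluster,
                                     int HPerThread, int WPerThread)
{
    const int h_id = thread_id / WThreadCluster; // row of the thread cluster
    const int w_id = thread_id % WThreadCluster; // column of the thread cluster
    return MatrixIndex{0, h_id * HPerThread, w_id * WPerThread};
}

int main()
{
    for(int tid = 0; tid < 4; ++tid)
    {
        const MatrixIndex idx =
            begin_of_thread_matrix_c(tid, /*WThreadCluster=*/8, /*HPerThread=*/1, /*WPerThread=*/1);
        std::printf("tid %d -> (k=%d, h=%d, w=%d)\n", tid, idx.k, idx.h, idx.w);
    }
    return 0;
}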
@@ -44,325 +44,153 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        // constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
        //     MLevel1ThreadCluster * NLevel1ThreadCluster;

        static_assert(BlockSize == HThreadCluster * WThreadCluster, "wrong! wrong blocksize\n");

        static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
                      "wrong! K dimension not consistent\n");

        constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed
        constexpr index_t N = BlockMatrixB{}.GetLength(I1);
        constexpr index_t H = BlockMatrixB{}.GetLength(I2);
        constexpr index_t W = BlockMatrixB{}.GetLength(I3);

        static_assert(K % (KPerThread) == 0 &&
                          (N * H * W) % (HPerThread * WPerThread * HThreadCluster * WThreadCluster) == 0,
                      "wrong! Cannot evenly divide work among\n");

        auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.k));
        mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(
            make_tuple(0, 0, c_thread_mtx_index.h, c_thread_mtx_index.w));
    }

    __device__ static constexpr auto GetThreadMatrixCLengths()
    {
        return Sequence<KPerThread, 1, HPerThread, WPerThread>{};
    }

    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
    {
        return MatrixIndex{1, 8, 8};
    }

    template <typename SrcDesc,
              typename DstDesc,
              index_t NSliceRow,
              index_t NSliceCol,
              index_t DataPerAccess>
    struct ThreadwiseSliceCopy_a
    {
        template <typename Data>
        __device__ static void Run(const Data* p_src, Data* p_dst)
        {
            static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                          "wrong! Desc should be known at compile-time");

            using vector_t = typename vector_type<Data, DataPerAccess>::type;

            static_for<0, NSliceRow, 1>{}([&](auto i) {
                static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
                    constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j));
                    constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j));

                    *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
                        *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
                });
            });
        }
    };

    template <typename SrcDesc,
              typename DstDesc,
              index_t NSliceCYX,
              index_t NSliceH,
              index_t NSliceW,
              index_t DataPerAccess>
    struct ThreadwiseSliceCopy_b
    {
        template <typename Data>
        __device__ static void Run(const Data* p_src, Data* p_dst)
        {
            static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                          "wrong! Desc should be known at compile-time");

            using vector_t = typename vector_type<Data, DataPerAccess>::type;

            static_for<0, NSliceCYX, 1>{}([&](auto i) {
                static_for<0, NSliceH, 1>{}([&](auto j) {
                    static_for<0, NSliceW, 1>{}([&](auto k) {
                        constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, 0, j, k));
                        constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, 0, j, k));

                        *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
                            *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
                    });
                });
            });
        }
    };

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ void
    Run_naive(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto a_block_mtx = BlockMatrixA{};
        constexpr auto b_block_mtx = BlockMatrixB{};

        constexpr auto CYXPerBlock = a_block_mtx.GetLength(I0);

        // thread A, B for GEMM
        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
            make_tuple(Number<CYXPerThreadLoop>{}, Number<KPerThread>{}));

        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
            make_tuple(Number<CYXPerThreadLoop>{}, Number<1>{}, Number<1>{}, Number<1>{}));

        constexpr auto c_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
            make_tuple(Number<KPerThread>{}, Number<1>{}));

        FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];

        constexpr auto a_thread_copy = ThreadwiseSliceCopy_a<BlockMatrixA,
                                                             decltype(a_thread_mtx),
                                                             CYXPerThreadLoop,
                                                             KPerThread,
                                                             ThreadGemmADataPerRead_K>{};

        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<decltype(a_thread_mtx),
                                                                    decltype(b_thread_mtx),
                                                                    decltype(c_thread_mtx)>{};

        // loop over k
        for(index_t cyx_begin = 0; cyx_begin < CYXPerBlock; cyx_begin += CYXPerThreadLoop)
        {
            a_thread_copy.Run(p_a_block + a_block_mtx.CalculateOffset(make_tuple(cyx_begin, 0)) +
                                  mMyThreadOffsetA,
                              p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, 0)));

            // threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
        }
    }

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
    {
#if 0
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0);
        constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1);

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        if constexpr(MRepeat == 2 && NRepeat == 2)
        {
            Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
        }
        else
        {
            Run_naive(p_a_block, p_b_block, p_c_thread);
        }
#else
        Run_naive(p_a_block, p_b_block, p_c_thread);
#endif
    }
};
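The new ThreadwiseSliceCopy_a/_b helpers above copy a slice into registers by reinterpreting DataPerAccess scalars at a time as a single vector. A standalone sketch of the same pattern on plain host arrays (float4_t stands in for vector_type<Data, DataPerAccess>::type; the kernel relies on the same reinterpret_cast idiom):

#include <cstdio>

// Vectorized slice copy: each iteration moves DataPerAccess contiguous floats
// as one float4-sized load/store, as the threadwise copy structs do.
struct float4_t { float x, y, z, w; };

void slice_copy(const float* p_src, float* p_dst, int rows, int cols)
{
    constexpr int DataPerAccess = 4;
    for(int i = 0; i < rows; ++i)
        for(int j = 0; j < cols; j += DataPerAccess)
        {
            const int off = i * cols + j; // packed row-major offset on both sides
            *reinterpret_cast<float4_t*>(&p_dst[off]) =
                *reinterpret_cast<const float4_t*>(&p_src[off]);
        }
}

int main()
{
    float src[2 * 8], dst[2 * 8] = {};
    for(int i = 0; i < 16; ++i) src[i] = float(i);
    slice_copy(src, dst, 2, 8);
    std::printf("dst[9] = %g\n", dst[9]);
    return 0;
}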
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp

@@ -18,12 +18,12 @@ template <index_t BlockSize,
           typename AGlobalDesc,
           typename BGlobalDesc,
           typename CGlobalDesc,
-          index_t MPerBlock,
-          index_t NPerBlock,
           index_t KPerBlock,
-          index_t MPerThread,
-          index_t NPerThread,
+          index_t HWPerBlock,
+          index_t CYXPerBlock,
           index_t KPerThread,
+          index_t HWPerThread,
+          index_t CYXPerThread,
           index_t MLevel0Cluster,
           index_t NLevel0Cluster,
           index_t MLevel1Cluster,
@@ -58,31 +58,34 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
     {
         constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
                                                  Number<BBlockTransferDstScalarPerVector_N>{},
-                                                 Number<MPerThread>{},
-                                                 Number<NPerThread>{});
+                                                 Number<KPerThread>{},
+                                                 Number<HWPerThread>{});
+
+        static_assert(CYXPerBlock == 4 && HWPerBlock == 64 && KPerBlock == 16, "");
 
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
+        constexpr auto a_cyx_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<CYXPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
 
         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto b_cyx_n_h_w_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<1>{}, Number<8>{}, Number<8>{}), max_lds_align);
+            make_tuple(Number<CYXPerBlock>{}, Number<1>{}, Number<8>{}, Number<8>{}), max_lds_align);
 
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
-            math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(a_cyx_k_block_desc.GetElementSpaceSize(), max_lds_align);
 
         constexpr auto b_block_space_size =
             math::integer_least_multiple(b_cyx_n_h_w_block_desc.GetElementSpaceSize(), max_lds_align);
 
         return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
     }
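GetSharedMemoryNumberOfByte() sizes the LDS as two copies of the A and B tiles, each rounded up to max_lds_align. A standalone sketch of that arithmetic, with the tile sizes fixed by the new static_assert (round_up here stands in for math::integer_least_multiple):

#include <cstdio>

// Round each tile's element count up to the LDS alignment, then double for the
// two LDS buffers, matching the return expression above.
long round_up(long x, long align) { return ((x + align - 1) / align) * align; }

int main()
{
    const long max_lds_align = 4;        // e.g. the lcm of the per-vector widths
    const long a_elems = 4 * 16;         // CYXPerBlock * KPerBlock
    const long b_elems = 4 * 1 * 8 * 8;  // CYXPerBlock * 1 * 8 * 8
    const long a_space = round_up(a_elems, max_lds_align);
    const long b_space = round_up(b_elems, max_lds_align);
    const long bytes   = 2 * (a_space + b_space) * sizeof(float); // double buffered
    std::printf("LDS bytes = %ld\n", bytes);
    return 0;
}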
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
+    __device__ void Run(const AGlobalDesc& a_cyx_k_global_desc,
                         const Float* __restrict__ p_a_global,
                         const BGlobalDesc& b_cyx_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,

@@ -94,62 +97,70 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        const auto CYX = a_cyx_k_global_desc.GetLength(I0);
        const auto K   = a_cyx_k_global_desc.GetLength(I1);

        static_assert(CYX == 4 * 3 * 3 && K == 16, "");

        const auto N = b_cyx_n_h_w_global_desc.GetLength(I1);
        const auto H = b_cyx_n_h_w_global_desc.GetLength(I2);
        const auto W = b_cyx_n_h_w_global_desc.GetLength(I3);

        // divide block work by [M, N]
#if 1
        const auto m_block_work_num   = K / Number<KPerBlock>{};
        const auto nhw_block_work_num = (N * H * W) / Number<HWPerBlock>{};

        const index_t k_block_work_id   = get_block_1d_id() / nhw_block_work_num;
        const index_t nhw_block_work_id = get_block_1d_id() - k_block_work_id * nhw_block_work_num;
#else
        // Hack: this force result into SGPR
        const index_t m_block_work_num   = __builtin_amdgcn_readfirstlane(K / KPerBlock);
        const index_t nhw_block_work_num = __builtin_amdgcn_readfirstlane(N / HWPerBlock);

        const index_t k_block_work_id =
            __builtin_amdgcn_readfirstlane(get_block_1d_id() / nhw_block_work_num);
        const index_t nhw_block_work_id =
            get_block_1d_id() - k_block_work_id * nhw_block_work_num;
#endif

        const index_t m_block_data_on_global = k_block_work_id * KPerBlock;
        const index_t h_block_data_on_global = nhw_block_work_id * 8;
        const index_t w_block_data_on_global = nhw_block_work_id * 8;

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
                                                 Number<BBlockTransferDstScalarPerVector_N>{},
                                                 Number<KPerThread>{},
                                                 Number<HWPerThread>{});

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_cyx_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<CYXPerBlock>{}, Number<KPerBlock>{}), max_lds_align);

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_cyx_n_h_w_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<CYXPerBlock>{}, Number<1>{}, Number<8>{}, Number<8>{}), max_lds_align);

        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                   InMemoryDataOperation::Set,
                                                   Sequence<CYXPerBlock, KPerBlock>,
                                                   ABlockTransferThreadSliceLengths_K_M,
                                                   ABlockTransferThreadClusterLengths_K_M,
                                                   ABlockTransferThreadClusterArrangeOrder,
                                                   Float,
                                                   Float,
                                                   decltype(a_cyx_k_global_desc),
                                                   decltype(a_cyx_k_block_desc),
                                                   ABlockTransferSrcAccessOrder,
                                                   Sequence<0, 1>,
                                                   ABlockTransferSrcVectorDim,
@@ -162,101 +173,65 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                                                   1,
                                                   AThreadTransferSrcResetCoordinateAfterRun,
                                                   true>(a_cyx_k_global_desc,
                                                         make_multi_index(0, m_block_data_on_global),
                                                         a_cyx_k_block_desc,
                                                         make_multi_index(0, 0));

#if 1
        constexpr auto b_cyx_n_h_w_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
            make_tuple(Number<CYXPerThread>{}, Number<1>{}, Number<1>{}, Number<1>{}));

        const index_t h_thread_id = get_thread_local_1d_id() / 8;
        const index_t w_thread_id = get_thread_local_1d_id() % 8;

        auto b_threadwise_transfer =
            ThreadwiseDynamicTensorSliceTransfer_v2<Float,
                                                    Float,
                                                    decltype(b_cyx_n_h_w_global_desc), // SrcDesc
                                                    decltype(b_cyx_n_h_w_thread_desc), // DstDesc
                                                    Sequence<CYXPerThread, 1, 1, 1>,
                                                    Sequence<3, 2, 0, 1>, // SrcDimAccessOrder
                                                    Sequence<3, 2, 0, 1>, // DstDimAccessOrder
                                                    3,                    // SrcVectorDim
                                                    3,                    // DstVectorDim
                                                    1,                    // SrcScalarPerVector
                                                    1,                    // DstScalarPerVector
                                                    AddressSpace::Global,
                                                    AddressSpace::Vgpr,
                                                    InMemoryDataOperation::Set,
                                                    1,
                                                    BThreadTransferSrcResetCoordinateAfterRun,
                                                    true>(
                b_cyx_n_h_w_global_desc,
                make_multi_index(
                    0, 0, h_block_data_on_global + h_thread_id, w_block_data_on_global + w_thread_id));

#if 0
        constexpr auto b_cyx_n_h_w_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
            make_tuple(Number<KPerThread>{}, Number<NPerThread>{}));

        using BThreadwiseTransfer =
            ThreadwiseDynamicTensorSliceTransfer_v2<Float,
                                                    Float,
                                                    decltype(b_cyx_n_h_w_global_desc),
                                                    decltype(b_cyx_n_h_w_thread_desc),
                                                    Sequence<KPerThread, NPerThread>,
                                                    BBlockTransferSrcAccessOrder,
                                                    BBlockTransferSrcVectorDim,
                                                    BBlockTransferSrcScalarPerVector,
                                                    AddressSpace::Global,
                                                    AddressSpace::Vgpr,
                                                    InMemoryDataOperation::Set,
                                                    1,
                                                    true>;
#endif
#endif

        // sanity check
        // static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
        //                   NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
        //               "wrong!");

        // constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
        // constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO:: more elegent way of defining c_thread_mtx
        constexpr auto c_k_n_h_w_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
            make_tuple(Number<KPerThread>{}, Number<1>{}, Number<1>{}, Number<1>{}));

#if 1
        const auto blockwise_gemm =
            BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
                                            decltype(a_cyx_k_block_desc),
                                            decltype(b_cyx_n_h_w_block_desc),
                                            decltype(c_k_n_h_w_thread_desc),
                                            16, // KPerThreadSubC
                                            1,  // HPerThreadSubC
                                            1,  // WPerThreadSubC
                                            1,  // CYXPerThreadLoop
                                            8,  // HThreadCluster
                                            8,  // WThreadCluster
                                            1,  // ThreadGemmADataPerRead_K
                                            1   // ThreadGemmBDataPerRead_W
                                            >{};
#endif

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
            math::integer_least_multiple(a_cyx_k_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size =
            math::integer_least_multiple(b_cyx_n_h_w_block_desc.GetElementSpaceSize(), max_lds_align);

        Float* p_a_block_double = p_shared_block;
        Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
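The new Run() above splits the grid along K and along the flattened N*H*W instead of along M and N. A standalone sketch of that block-work decomposition (problem sizes are illustrative; __builtin_amdgcn_readfirstlane in the kernel only pins the results into scalar registers, the arithmetic is identical):

#include <cstdio>

// Decompose a 1D block id into a K-tile id and an (N*H*W)-tile id, then scale
// the K-tile id back to a data offset, mirroring k_block_work_id /
// nhw_block_work_id / m_block_data_on_global above.
int main()
{
    const int K = 16, N = 1, H = 16, W = 16;
    const int KPerBlock = 16, HWPerBlock = 64;

    const int k_block_work_num  = K / KPerBlock;
    const int hw_block_work_num = (N * H * W) / HWPerBlock;

    for(int block_id = 0; block_id < k_block_work_num * hw_block_work_num; ++block_id)
    {
        const int k_block_work_id  = block_id / hw_block_work_num;
        const int hw_block_work_id = block_id - k_block_work_id * hw_block_work_num;
        std::printf("block %d -> k offset %d, hw tile %d\n",
                    block_id, k_block_work_id * KPerBlock, hw_block_work_id);
    }
    return 0;
}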
@@ -272,11 +247,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         // zero out threadwise output
         // threadwise_matrix_set_zero_v2(c_k_n_h_w_thread_desc, p_c_thread);
 
-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(CYXPerBlock, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(CYXPerBlock, 0, 0, 0);
 
         // hack to control index calculation when iterating over A and B matrix for threadwise copy
         constexpr auto a_k_m_global_iterator_hacks       = AGlobalIteratorHacks{};
         constexpr auto b_cyx_n_h_w_global_iterator_hacks = BGlobalIteratorHacks{};
 
         // hack to control index calculation when move slice window for A and B matrix for

@@ -288,13 +263,25 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(
-                b_cyx_n_h_w_global_desc, p_b_global, b_cyx_n_h_w_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
+
+            constexpr auto b_thread_mtx = b_cyx_n_h_w_thread_desc;
+
+            Float p_b_thread[b_thread_mtx.GetElementSpaceSize()];
+
+            b_threadwise_transfer.Run(b_cyx_n_h_w_global_desc,
+                                      p_b_global,
+                                      b_cyx_n_h_w_thread_desc,
+                                      make_tuple(I0, I0, I0, I0),
+                                      p_b_thread,
+                                      b_cyx_n_h_w_global_iterator_hacks);
 
-            a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
-            b_blockwise_copy.RunWrite(b_cyx_n_h_w_block_desc, p_b_block_double);
+            a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_double);
+
+            __syncthreads();
         }
 
+#if 0
         if constexpr(HasMainKBlockLoop)
         {
             Float* p_a_block_even = p_a_block_double;
@@ -303,104 +290,82 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
            Float* p_a_block_odd = p_a_block_double + a_block_space_size;
            Float* p_b_block_odd = p_b_block_double + b_block_space_size;

            index_t b_block_data_begin = 0;

            // LDS double buffer: main body
            // use Do-While loop instead of For loop to simplify control flow
            do
            {
                // even iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_cyx_k_global_desc,
                                                    a_block_slice_copy_step,
                                                    a_k_m_global_move_slice_window_iterator_hack);
                // b_blockwise_copy.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
                //                                     b_block_slice_copy_step,
                //                                     b_cyx_n_h_w_global_move_slice_window_iterator_hack);

                __syncthreads();

                // LDS doubel buffer: load next data from device mem
                a_blockwise_copy.RunRead(
                    a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);

                // LDS double buffer: GEMM on current data
                // blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_odd);

                // odd iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_cyx_k_global_desc,
                                                    a_block_slice_copy_step,
                                                    a_k_m_global_move_slice_window_iterator_hack);

                __syncthreads();

                // LDS doubel buffer: load next data from device mem
                a_blockwise_copy.RunRead(
                    a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);

                // LDS double buffer: GEMM on current data
                // blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_even);

                b_block_data_begin += 2 * CYXPerBlock;
            } while(b_block_data_begin < CYX - 2 * CYXPerBlock);
        }

        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_cyx_k_global_desc,
                                                a_block_slice_copy_step,
                                                a_k_m_global_move_slice_window_iterator_hack);

            __syncthreads();

            // LDS double buffer: load last data from device mem
            a_blockwise_copy.RunRead(a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);

            // LDS double buffer: GEMM on 2nd-last data
            // blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);

            // LDS double buffer: store last data to LDS
            a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_double + a_block_space_size);

            __syncthreads();

            // LDS double buffer: GEMM on last data
            // blockwise_gemm.Run(p_a_block_double + a_block_space_size,
            //                    p_b_block_double + b_block_space_size,
            //                    p_c_thread);
        }
        else // if has 1 iteration left
        {
            __syncthreads();

            // LDS double buffer: GEMM on last data
            // blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
        }
#endif
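The main loop disabled above is the usual LDS double-buffer pipeline: compute on one buffer while the blockwise copy fills the other, separated by a barrier. A minimal host-side sketch of that control flow (the loads and GEMMs are stand-ins):

#include <cstdio>

// Shape of the double-buffer pipeline: while the GEMM consumes buffer `cur`,
// the next tile is loaded into buffer `1 - cur`, then the roles swap.
int main()
{
    const int num_tiles = 6;
    int cur = 0;
    std::printf("preload tile 0 into buf%d\n", cur);
    for(int tile = 0; tile + 1 < num_tiles; ++tile)
    {
        std::printf("load tile %d into buf%d while computing on buf%d\n",
                    tile + 1, 1 - cur, cur);
        cur = 1 - cur; // on real hardware a barrier separates load and compute
    }
    std::printf("compute last tile on buf%d\n", cur);
    return 0;
}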
#if 1
        // output: register to global memory

@@ -408,7 +373,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         // define input tensor descriptor for threadwise copy
 
         // thread input tensor, src of threadwise copy
         constexpr auto c_k_n_h_w_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<MPerThread>{}, Number<1>{}, Number<1>{}, Number<1>{}));
+            make_tuple(Number<KPerThread>{}, Number<1>{}, Number<1>{}, Number<1>{}));
 
         // calculate origin of thread input tensor on global memory
         // blockwise GEMM c matrix starting index

@@ -432,15 +397,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         // hack to control index calculation when iterating over c_k_n_h_w_global tensor
         constexpr auto c_k_n_h_w_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
 
-        // constexpr auto tmp = make_unmerge_transform(make_tuple(
-        //     Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
-
         ThreadwiseDynamicTensorSliceTransfer_v1r3<AccFloat,
                                                   Float,
                                                   decltype(c_k_n_h_w_thread_desc),
                                                   decltype(c_k_n_h_w_global_desc),
-                                                  Sequence<MPerThread, 1, 1, 1>,
+                                                  Sequence<KPerThread, 1, 1, 1>,
                                                   Sequence<3, 2, 0, 1>, // CThreadTransferSrcDstAccessOrder
                                                   3,                    // CThreadTransferSrcDstVectorDim
                                                   1,                    // CThreadTransferDstScalarPerVector,
@@ -464,7 +426,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
...
@@ -464,7 +426,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// pass tensor descriptor by reference
// pass tensor descriptor by reference
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
AGlobalDesc
&
a_
k_m
_global_desc
,
__device__
void
Run
(
const
AGlobalDesc
&
a_
cyx_k
_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
Float
*
__restrict__
p_a_global
,
const
BGlobalDesc
&
b_cyx_n_h_w_global_desc
,
const
BGlobalDesc
&
b_cyx_n_h_w_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
Float
*
__restrict__
p_b_global
,
...
@@ -477,7 +439,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
...
@@ -477,7 +439,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
__shared__
Float
p_shared_block
[
shared_block_size
];
__shared__
Float
p_shared_block
[
shared_block_size
];
Run
(
a_
k_m
_global_desc
,
Run
(
a_
cyx_k
_global_desc
,
p_a_global
,
p_a_global
,
b_cyx_n_h_w_global_desc
,
b_cyx_n_h_w_global_desc
,
p_b_global
,
p_b_global
,
...
@@ -490,7 +452,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
...
@@ -490,7 +452,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// pass tensor descriptors by their pointers
// pass tensor descriptors by their pointers
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
AGlobalDesc
*
p_a_
k_m
_global_desc
,
__device__
void
Run
(
const
AGlobalDesc
*
p_a_
cyx_k
_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
Float
*
__restrict__
p_a_global
,
const
BGlobalDesc
*
p_b_cyx_n_h_w_global_desc
,
const
BGlobalDesc
*
p_b_cyx_n_h_w_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
Float
*
__restrict__
p_b_global
,
...
@@ -499,11 +461,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
...
@@ -499,11 +461,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
{
{
const
auto
a_
k_m
_global_desc
=
*
p_a_
k_m
_global_desc
;
const
auto
a_
cyx_k
_global_desc
=
*
p_a_
cyx_k
_global_desc
;
const
auto
b_cyx_n_h_w_global_desc
=
*
p_b_cyx_n_h_w_global_desc
;
const
auto
b_cyx_n_h_w_global_desc
=
*
p_b_cyx_n_h_w_global_desc
;
const
auto
c_k_n_h_w_global_desc
=
*
p_c_k_n_h_w_global_desc
;
const
auto
c_k_n_h_w_global_desc
=
*
p_c_k_n_h_w_global_desc
;
Run
(
a_
k_m
_global_desc
,
Run
(
a_
cyx_k
_global_desc
,
p_a_global
,
p_a_global
,
b_cyx_n_h_w_global_desc
,
b_cyx_n_h_w_global_desc
,
p_b_global
,
p_b_global
,
...
@@ -515,7 +477,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
...
@@ -515,7 +477,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// pass tensor descriptors by void*
// pass tensor descriptors by void*
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
void
*
p_a_
k_m
_global_desc
,
__device__
void
Run
(
const
void
*
p_a_
cyx_k
_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
Float
*
__restrict__
p_a_global
,
const
void
*
p_b_cyx_n_h_w_global_desc
,
const
void
*
p_b_cyx_n_h_w_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
Float
*
__restrict__
p_b_global
,
...
@@ -524,12 +486,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
...
@@ -524,12 +486,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
{
{
const
auto
a_k_m_global_desc
=
*
reinterpret_cast
<
const
AGlobalDesc
*>
(
p_a_k_m_global_desc
);
const
auto
a_cyx_k_global_desc
=
const
auto
b_cyx_n_h_w_global_desc
=
*
reinterpret_cast
<
const
BGlobalDesc
*>
(
p_b_cyx_n_h_w_global_desc
);
*
reinterpret_cast
<
const
AGlobalDesc
*>
(
p_a_cyx_k_global_desc
);
const
auto
b_cyx_n_h_w_global_desc
=
*
reinterpret_cast
<
const
BGlobalDesc
*>
(
p_b_cyx_n_h_w_global_desc
);
const
auto
c_k_n_h_w_global_desc
=
const
auto
c_k_n_h_w_global_desc
=
*
reinterpret_cast
<
const
CGlobalDesc
*>
(
p_c_k_n_h_w_global_desc
);
*
reinterpret_cast
<
const
CGlobalDesc
*>
(
p_c_k_n_h_w_global_desc
);
Run
(
a_
k_m
_global_desc
,
Run
(
a_
cyx_k
_global_desc
,
p_a_global
,
p_a_global
,
b_cyx_n_h_w_global_desc
,
b_cyx_n_h_w_global_desc
,
p_b_global
,
p_b_global
,
...
...
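The overloads above forward tensor descriptors as raw pointers or void* and reinterpret_cast them back to the concrete descriptor type inside the kernel. A standalone sketch of that pattern (Desc4D is an illustrative stand-in for the real descriptor types):

#include <cstdio>

// Receive an opaque pointer, recover the typed descriptor, then use it,
// mirroring the Run(const void*...) overload above.
struct Desc4D { int lengths[4]; };

void run(const void* p_desc)
{
    const auto desc = *reinterpret_cast<const Desc4D*>(p_desc);
    std::printf("lengths = %d %d %d %d\n",
                desc.lengths[0], desc.lengths[1], desc.lengths[2], desc.lengths[3]);
}

int main()
{
    Desc4D desc{{4, 1, 8, 8}};
    run(&desc);
    return 0;
}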
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp

@@ -535,7 +535,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
                 dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
                                          i * src_scalar_step_in_vector);
 
-            p_dst[Number<dst_offset>{}] = src_vector[i];
+            // p_dst[Number<dst_offset>{}] = src_vector[i];
+            p_dst[Number<dst_offset>{}] = src_vector.Scalars()(i);
         });
 
         constexpr auto move_on_dim = [&]() constexpr
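The hunk above switches the per-scalar store from src_vector[i] to src_vector.Scalars()(i). A minimal model of that kind of accessor is sketched below; it is an illustrative stand-in, not CK's vector_type implementation:

#include <cstdio>

// A vector register viewed either as a whole or through a scalar accessor,
// illustrating the Scalars()(i) call shape used above.
template <typename T, int N>
struct vector_wrapper
{
    T data[N];

    struct scalar_view
    {
        const T* p;
        T operator()(int i) const { return p[i]; }
    };

    scalar_view Scalars() const { return scalar_view{data}; }
};

int main()
{
    vector_wrapper<float, 4> v{{1.f, 2.f, 3.f, 4.f}};
    std::printf("scalar 2 = %g\n", v.Scalars()(2));
    return 0;
}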
composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp

@@ -28,33 +28,6 @@ __device__ void threadwise_matrix_set_zero_v3(Desc, Float* __restrict__ p_thread
     });
 }
 
-template <typename SrcDesc,
-          typename DstDesc,
-          index_t NSliceRow,
-          index_t NSliceCol,
-          index_t DataPerAccess>
-struct ThreadwiseMatrixSliceCopy_v3
-{
-    template <typename Data>
-    __device__ static void Run(const Data* p_src, Data* p_dst)
-    {
-        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
-                      "wrong! Desc should be known at compile-time");
-
-        using vector_t = typename vector_type<Data, DataPerAccess>::type;
-
-        static_for<0, NSliceRow, 1>{}([&](auto i) {
-            static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
-                constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j));
-                constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j));
-
-                *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
-                    *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
-            });
-        });
-    }
-};
-
 // C[M, N] += transpose(A[K, M]) * B[K, N]
 //   Element of matrix can be vectorized data
 template <typename ADesc,
@@ -75,9 +48,9 @@ struct ThreadwiseGemm_km_kn_mn_v3
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
 
-        constexpr auto M = CDesc{}[I0];
-        constexpr auto N = CDesc{}[I1];
-        constexpr auto K = ADesc{}[I0];
+        constexpr auto M = CDesc{}.GetLength(I0);
+        constexpr auto N = CDesc{}.GetLength(I1);
+        constexpr auto K = ADesc{}.GetLength(I0);
 
         static_for<0, K, 1>{}([&](auto k) {
             static_for<0, M, 1>{}([&](auto m) {
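ThreadwiseGemm_km_kn_mn_v3 accumulates C[M, N] += transpose(A[K, M]) * B[K, N] with the K loop outermost. A standalone reference version of that loop nest on packed row-major arrays, for checking the index math:

#include <cstdio>

// C[m][n] += sum over k of A[k][m] * B[k][n], all tensors packed row-major.
void threadwise_gemm(const float* a, const float* b, float* c, int K, int M, int N)
{
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c[m * N + n] += a[k * M + m] * b[k * N + n];
}

int main()
{
    const int K = 2, M = 2, N = 2;
    const float a[K * M] = {1, 2, 3, 4}; // A[k][m]
    const float b[K * N] = {5, 6, 7, 8}; // B[k][n]
    float c[M * N]       = {};
    threadwise_gemm(a, b, c, K, M, N);
    std::printf("C = [%g %g; %g %g]\n", c[0], c[1], c[2], c[3]);
    return 0;
}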
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp

@@ -76,7 +76,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
     constexpr index_t GemmMPerThread = 16;
     constexpr index_t GemmNPerThread = 1;
-    constexpr index_t GemmKPerThread = 1;
+    constexpr index_t GemmKPerThread = 4;
 
     constexpr index_t GemmMLevel0Cluster = 1;
     constexpr index_t GemmNLevel0Cluster = 1;
driver/src/conv_driver.cpp

@@ -779,7 +779,7 @@ int main(int argc, char* argv[])
 #if 1
     // LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
     // LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
-    LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
+    // LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
     LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
 #endif
 }