yangql / composable_kernel-1 · Commits

Commit f35c64eb
authored Mar 23, 2019 by Chao Liu

    experimenting

parent 22114959
Showing 7 changed files with 967 additions and 974 deletions.
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp  +2 -2
driver/driver.hip.cpp  +2 -2
src/include/blockwise_batched_gemm.hip.hpp  +802 -0
src/include/blockwise_gemm.hip.hpp  +140 -955
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp  +1 -3
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp  +10 -6
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp  +10 -6
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp

@@ -160,7 +160,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     constexpr unsigned WeiBlockCopyDataPerRead = 4;
     constexpr unsigned BlockSize = 128;
-#elif 1
+#elif 0
     // 1x1, 28x28, 256 thread
     constexpr unsigned BPerBlock = 128;
     constexpr unsigned KPerBlock = 128;

@@ -211,7 +211,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     for(unsigned i = 0; i < nrepeat; ++i)
     {
         float time = launch_kernel(
-#if 0
+#if 1
             gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
 #else
             gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
driver/driver.hip.cpp

@@ -661,9 +661,9 @@ int main(int argc, char* argv[])
     device_direct_convolution_2_nchw_kcyx_nkhw
 #elif 0
     device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
-#elif 1
+#elif 0
     device_implicit_gemm_convolution_1_chwn_cyxk_khwn
-#elif 0
+#elif 1
     device_implicit_gemm_convolution_2_chwn_cyxk_khwn
 #endif
         (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
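Note on the hunk above: exactly one of the function names survives preprocessing, and the argument list after #endif attaches to whichever name was kept, so flipping the #elif values changes which implementation the single call site invokes. A minimal standalone sketch of that selection pattern, with placeholder names rather than the real driver signatures:

// Illustrative sketch only: placeholder functions, not the real driver API.
#include <cstdio>

void device_conv_variant_1(int nrepeat) { std::printf("variant 1, nrepeat=%d\n", nrepeat); }
void device_conv_variant_2(int nrepeat) { std::printf("variant 2, nrepeat=%d\n", nrepeat); }

int main()
{
#if 0
    device_conv_variant_1
#elif 1
    device_conv_variant_2
#endif
    (100); // with the flags as in the new revision, variant 2 is the one called

    return 0;
}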
src/include/blockwise_batched_gemm.hip.hpp (new file, mode 100644)
#pragma once
#include "threadwise_gemm.hip.hpp"
template <unsigned BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
          bool TransA,
          bool TransB,
          bool TransC,
          unsigned BlockMatrixStrideA,
          unsigned BlockMatrixStrideB,
          unsigned ThreadMatrixStrideC,
          unsigned BatchSize,
          unsigned BatchPerThread,
          unsigned KPerThreadLoop,
          bool DistributeThreadAlongColumnFirst>
struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
{
    unsigned mMyThreadOffsetA = 0;
    unsigned mMyThreadOffsetB = 0;

    struct MatrixIndex
    {
        unsigned batch;
        unsigned row;
        unsigned col;
    };

    __device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
    {
        constexpr auto a_block_mtx = BlockMatrixA{};
        constexpr auto b_block_mtx = BlockMatrixB{};

        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
                           ((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0)
                                      : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row));

        mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
                           ((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col)
                                      : b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0));

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
            print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
            print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");

            printf("%u %u, %u %u %u, %u %u\n",
                   get_block_1d_id(),
                   get_thread_local_1d_id(),
                   c_thread_mtx_index.batch,
                   c_thread_mtx_index.row,
                   c_thread_mtx_index.col,
                   mMyThreadOffsetA,
                   mMyThreadOffsetB);
        }
#endif
    }

    __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
    {
        if(TransA && (!TransB) && (!TransC))
        {
            constexpr auto a_block_mtx = BlockMatrixA{};
            constexpr auto b_block_mtx = BlockMatrixB{};

            static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                          "wrong! k dimension not consistent!");

            constexpr unsigned MPerBlock = a_block_mtx.NCol();
            constexpr unsigned NPerBlock = b_block_mtx.NCol();

            constexpr auto c_thread_mtx = ThreadMatrixC{};

            // divide thread work
            constexpr unsigned MPerThread = c_thread_mtx.NRow();
            constexpr unsigned NPerThread = c_thread_mtx.NCol();

            static_assert(BatchSize % BatchPerThread == 0, "BatchSize % BatchPerThread != 0");
            static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0");
            static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0");

            constexpr unsigned BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
            constexpr unsigned MThreadWork     = (MPerBlock + MPerThread - 1) / MPerThread;
            constexpr unsigned NThreadWork     = (NPerBlock + NPerThread - 1) / NPerThread;

            static_assert(BlockSize == BatchThreadWork * MThreadWork * NThreadWork,
                          "wrong! wrong BlockSize");

            if(DistributeThreadAlongColumnFirst)
            {
                // num of operations can be reduced
                const unsigned b_work_id = thread_id / (MThreadWork * NThreadWork);
                unsigned itmp            = thread_id - b_work_id * (MThreadWork * NThreadWork);

                const unsigned m_work_id = itmp / NThreadWork;
                const unsigned n_work_id = itmp - m_work_id * NThreadWork;

                return MatrixIndex{
                    b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread};
            }
            else
            {
                // not implemented
                assert(false);
            }
        }
        else
        {
            // not implemented
            assert(false);
        }
    }

    // this should be optimized away if input is known
    __device__ static MatrixIndex
    GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
    {
        return MatrixIndex{batch_in_c, m_in_c, n_in_c};
    }

    template <class FloatA, class FloatB, class FloatC, class Accumulator>
    __device__ void Run(const FloatA* __restrict__ p_a_block,
                        const FloatB* __restrict__ p_b_block,
                        FloatC* __restrict__ p_c_thread,
                        Accumulator f_accum) const
    {
        if(TransA && (!TransB) && (!TransC))
        {
            constexpr auto True  = integral_constant<bool, true>{};
            constexpr auto False = integral_constant<bool, false>{};

            constexpr auto a_block_mtx  = BlockMatrixA{};
            constexpr auto b_block_mtx  = BlockMatrixB{};
            constexpr auto c_thread_mtx = ThreadMatrixC{};

            constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed

            constexpr unsigned MPerThread = c_thread_mtx.NRow();
            constexpr unsigned NPerThread = c_thread_mtx.NCol();

            // a is transposed, b is not
            constexpr auto a_thread_mtx =
                make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
            constexpr auto b_thread_mtx =
                make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

            FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
            FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

            // loop over k
            for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
            {
                // read first batch of a, b
                threadwise_matrix_copy(
                    a_block_mtx,
                    p_a_block + mMyThreadOffsetA + k_begin * a_block_mtx.RowStride(),
                    a_thread_mtx,
                    p_a_thread,
                    a_thread_mtx.GetLengths());

                threadwise_matrix_copy(
                    b_block_mtx,
                    p_b_block + mMyThreadOffsetB + k_begin * b_block_mtx.RowStride(),
                    b_thread_mtx,
                    p_b_thread,
                    b_thread_mtx.GetLengths());

                // loop over batch
                for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
                {
                    // do current batch of gemm
                    threadwise_gemm(a_thread_mtx, True, p_a_thread,
                                    b_thread_mtx, False, p_b_thread,
                                    c_thread_mtx, False, p_c_thread + ib * ThreadMatrixStrideC,
                                    f_accum);

                    // read next batch of a, b
                    if(BlockMatrixStrideA != 0)
                    {
                        threadwise_matrix_copy(
                            a_block_mtx,
                            p_a_block + mMyThreadOffsetA + (ib + 1) * BlockMatrixStrideA +
                                k_begin * a_block_mtx.RowStride(),
                            a_thread_mtx,
                            p_a_thread,
                            a_thread_mtx.GetLengths());
                    }

                    if(BlockMatrixStrideB != 0)
                    {
                        threadwise_matrix_copy(
                            b_block_mtx,
                            p_b_block + mMyThreadOffsetB + (ib + 1) * BlockMatrixStrideB +
                                k_begin * b_block_mtx.RowStride(),
                            b_thread_mtx,
                            p_b_thread,
                            b_thread_mtx.GetLengths());
                    }
                }

                // do last batch of gemm
                threadwise_gemm(a_thread_mtx, True, p_a_thread,
                                b_thread_mtx, False, p_b_thread,
                                c_thread_mtx, False,
                                p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC,
                                f_accum);
            }
        }
    }
};
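As an aside, the block-level work split asserted above (BlockSize == BatchThreadWork * MThreadWork * NThreadWork) and the column-first distribution in GetBeginOfThreadMatrixC can be checked on the host. A minimal sketch with assumed tile sizes; all parameter values below are illustrative and are not taken from this commit:

// Host-side sketch of the thread-work decomposition (DistributeThreadAlongColumnFirst case).
#include <cstdio>

struct MatrixIndex { unsigned batch, row, col; };

constexpr unsigned BatchSize = 2, BatchPerThread = 1; // -> BatchThreadWork = 2
constexpr unsigned MPerBlock = 128, MPerThread = 16;  // -> MThreadWork     = 8
constexpr unsigned NPerBlock = 64, NPerThread = 8;    // -> NThreadWork     = 8

constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
constexpr unsigned MThreadWork     = MPerBlock / MPerThread;
constexpr unsigned NThreadWork     = NPerBlock / NPerThread;

// mirrors the static_assert in the struct above: 2 * 8 * 8 = 128 threads per block
static_assert(BatchThreadWork * MThreadWork * NThreadWork == 128, "BlockSize mismatch");

MatrixIndex begin_of_thread_matrix_c(unsigned thread_id)
{
    const unsigned b_work_id = thread_id / (MThreadWork * NThreadWork);
    const unsigned itmp      = thread_id - b_work_id * (MThreadWork * NThreadWork);
    const unsigned m_work_id = itmp / NThreadWork;
    const unsigned n_work_id = itmp - m_work_id * NThreadWork;
    return {b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread};
}

int main()
{
    // thread 70: b_work_id = 1, itmp = 6, m_work_id = 0, n_work_id = 6 -> (batch 1, row 0, col 48)
    const MatrixIndex idx = begin_of_thread_matrix_c(70);
    std::printf("batch %u, row %u, col %u\n", idx.batch, idx.row, idx.col);
    return 0;
}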
template <unsigned BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
          unsigned BlockMatrixStrideA,
          unsigned BlockMatrixStrideB,
          unsigned ThreadMatrixStrideC,
          unsigned BatchSize,
          unsigned MPerThreadSubC,
          unsigned NPerThreadSubC,
          unsigned MLevel0Cluster,
          unsigned NLevel0Cluster,
          unsigned MLevel1Cluster,
          unsigned NLevel1Cluster,
          unsigned KPerThreadLoop,
          unsigned BatchPerThread>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
    unsigned mMyThreadOffsetA = 0;
    unsigned mMyThreadOffsetB = 0;

    struct MatrixIndex
    {
        unsigned batch;
        unsigned row;
        unsigned col;
    };

    __device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
    {
        static_assert(BatchSize % BatchPerThread == 0,
                      "wrong! BatchSize is not dividable by BatchPerThread");

        constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;

        constexpr unsigned ThreadPerLevel1Cluster =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
                      "wrong! wrong blocksize\n");

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                      "wrong! K dimension not consistent\n");

        constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
        constexpr unsigned N = b_block_mtx.NCol();
        constexpr unsigned K = a_block_mtx.NRow();

        constexpr unsigned MPerThread = c_thread_mtx.NRow();
        constexpr unsigned NPerThread = c_thread_mtx.NCol();

        static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
                      "wrong! Cannot evenly divide thread work among repeat\n");

        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;

        static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
                      "wrong! Cannot evenly divide work among repeat\n");

        constexpr unsigned MPerLevel1Cluster = M / MRepeat;
        constexpr unsigned NPerLevel1Cluster = N / NRepeat;

        static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
                          (NPerLevel1Cluster % NLevel1Cluster == 0),
                      "wrong! Cannot evenly divide work among Level1Cluster\n");

        constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
        constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;

        static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
                          (NPerLevel0Cluster % NLevel0Cluster == 0),
                      "wrong! Cannot evenly divide work among Level0Cluster\n");

        static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
                          (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
                      "wrong! thread work size is wrong\n");

        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
                           a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);

        mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
                           b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
            print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
            print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");

            printf("%u %u, %u %u %u, %u %u\n",
                   get_block_1d_id(),
                   get_thread_local_1d_id(),
                   c_thread_mtx_index.batch,
                   c_thread_mtx_index.row,
                   c_thread_mtx_index.col,
                   mMyThreadOffsetA,
                   mMyThreadOffsetB);
        }
#endif
    }
    __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
    {
        constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;

        constexpr unsigned ThreadPerLevel1Cluster =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;

        unsigned batch_work_id = thread_id / ThreadPerLevel1Cluster;
        unsigned cluster_id    = thread_id - batch_work_id * ThreadPerLevel1Cluster;

        unsigned level1_id   = cluster_id / ThreadPerLevel0Cluster;
        unsigned level1_m_id = level1_id / NLevel1Cluster;
        unsigned level1_n_id = level1_id % NLevel1Cluster;

        unsigned level0_id   = cluster_id % ThreadPerLevel0Cluster;
        unsigned level0_m_id = level0_id / NLevel0Cluster;
        unsigned level0_n_id = level0_id % NLevel0Cluster;

        constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
        constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

        return MatrixIndex{batch_work_id * BatchPerThread,
                           level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }
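The two-level cluster decomposition above can likewise be exercised on the host. A minimal sketch with assumed cluster parameters; all values below are illustrative and are not taken from this commit:

// Host-side sketch of the level-1 / level-0 cluster index math in GetBeginOfThreadMatrixC.
#include <cstdio>

struct MatrixIndex { unsigned batch, row, col; };

constexpr unsigned MPerThreadSubC = 4, NPerThreadSubC = 4;
constexpr unsigned MLevel0Cluster = 4, NLevel0Cluster = 4;
constexpr unsigned MLevel1Cluster = 4, NLevel1Cluster = 4;
constexpr unsigned BatchPerThread = 1;

constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;              // 16
constexpr unsigned ThreadPerLevel1Cluster =
    ThreadPerLevel0Cluster * MLevel1Cluster * NLevel1Cluster;                             // 256
constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;                   // 16
constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;                   // 16

MatrixIndex begin_of_thread_matrix_c(unsigned thread_id)
{
    const unsigned batch_work_id = thread_id / ThreadPerLevel1Cluster;
    const unsigned cluster_id    = thread_id % ThreadPerLevel1Cluster;
    const unsigned level1_id     = cluster_id / ThreadPerLevel0Cluster;
    const unsigned level0_id     = cluster_id % ThreadPerLevel0Cluster;
    const unsigned level1_m_id   = level1_id / NLevel1Cluster;
    const unsigned level1_n_id   = level1_id % NLevel1Cluster;
    const unsigned level0_m_id   = level0_id / NLevel0Cluster;
    const unsigned level0_n_id   = level0_id % NLevel0Cluster;
    return {batch_work_id * BatchPerThread,
            level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
            level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}

int main()
{
    // thread 37: cluster_id 37, level1_id 2 (m 0, n 2), level0_id 5 (m 1, n 1)
    // -> row 0*16 + 1*4 = 4, col 2*16 + 1*4 = 36
    const MatrixIndex idx = begin_of_thread_matrix_c(37);
    std::printf("batch %u, row %u, col %u\n", idx.batch, idx.row, idx.col);
    return 0;
}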
    // this should be optimized away if input is known
    __device__ static MatrixIndex
    GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
    {
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr unsigned MPerThread = c_thread_mtx.NRow();
        constexpr unsigned NPerThread = c_thread_mtx.NCol();

        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;

        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        unsigned m_repeat = m_in_c / MPerThreadSubC;
        unsigned n_repeat = n_in_c / NPerThreadSubC;

        unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
        unsigned n_in_sub_c = n_in_c % NPerThreadSubC;

        return MatrixIndex{batch_in_c,
                           m_repeat * MPerLevel1Cluster + m_in_sub_c,
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
    }
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
,
class
Accumulator
>
__device__
void
Run
(
const
FloatA
*
__restrict__
p_a_block
,
const
FloatB
*
__restrict__
p_b_block
,
FloatC
*
__restrict__
p_c_thread
,
Accumulator
f_accum
)
const
{
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
constexpr
unsigned
KPerBlock
=
a_block_mtx
.
NRow
();
// A is transposed
constexpr
unsigned
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
unsigned
NPerThread
=
c_thread_mtx
.
NCol
();
// thread A, B for GEMM
// A is transposed, b is not
constexpr
auto
a_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThread
>
{});
// thread A-sub, B-sub for copy
constexpr
auto
a_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThreadSubC
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThreadSubC
>
{},
Number
<
NPerThread
>
{});
FloatA
p_a_thread
[
a_thread_mtx
.
GetElementSpace
()];
FloatB
p_b_thread
[
b_thread_mtx
.
GetElementSpace
()];
constexpr
unsigned
MPerLevel1Cluster
=
MPerThreadSubC
*
MLevel0Cluster
*
MLevel1Cluster
;
constexpr
unsigned
NPerLevel1Cluster
=
NPerThreadSubC
*
NLevel0Cluster
*
NLevel1Cluster
;
constexpr
unsigned
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
unsigned
NRepeat
=
NPerThread
/
NPerThreadSubC
;
// loop over k
#pragma unroll
for
(
unsigned
k_begin
=
0
;
k_begin
<
KPerBlock
;
k_begin
+=
KPerThreadLoop
)
{
// read first batch of A, B
// copy A-sub to form A
#pragma unroll
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
threadwise_matrix_copy
(
a_block_mtx
,
p_a_block
+
a_block_mtx
.
Get1dIndex
(
k_begin
,
m_repeat
*
MPerLevel1Cluster
)
+
mMyThreadOffsetA
,
a_thread_mtx
,
p_a_thread
+
a_thread_mtx
.
Get1dIndex
(
0
,
m_repeat
*
MPerThreadSubC
),
a_thread_sub_mtx
.
GetLengths
());
}
// copy B-sub to form B
#pragma unroll
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
threadwise_matrix_copy
(
b_block_mtx
,
p_b_block
+
b_block_mtx
.
Get1dIndex
(
k_begin
,
n_repeat
*
NPerLevel1Cluster
)
+
mMyThreadOffsetB
,
b_thread_mtx
,
p_b_thread
+
b_thread_mtx
.
Get1dIndex
(
0
,
n_repeat
*
NPerThreadSubC
),
b_thread_sub_mtx
.
GetLengths
());
}
// loop over batch
#pragma unroll
for
(
unsigned
ib
=
0
;
ib
+
1
<
BatchPerThread
;
++
ib
)
{
// do current batch of gemm
threadwise_gemm
(
a_thread_mtx
,
True
,
p_a_thread
,
b_thread_mtx
,
False
,
p_b_thread
,
c_thread_mtx
,
False
,
p_c_thread
+
ib
*
ThreadMatrixStrideC
,
f_accum
);
// read next batch of a, b
if
(
BlockMatrixStrideA
!=
0
)
{
#pragma unroll
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
threadwise_matrix_copy
(
a_block_mtx
,
p_a_block
+
a_block_mtx
.
Get1dIndex
(
k_begin
,
m_repeat
*
MPerLevel1Cluster
)
+
(
ib
+
1
)
*
BlockMatrixStrideA
+
mMyThreadOffsetA
,
a_thread_mtx
,
p_a_thread
+
a_thread_mtx
.
Get1dIndex
(
0
,
m_repeat
*
MPerThreadSubC
),
a_thread_sub_mtx
.
GetLengths
());
}
}
if
(
BlockMatrixStrideB
!=
0
)
{
#pragma unroll
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
threadwise_matrix_copy
(
b_block_mtx
,
p_b_block
+
b_block_mtx
.
Get1dIndex
(
k_begin
,
n_repeat
*
NPerLevel1Cluster
)
+
(
ib
+
1
)
*
BlockMatrixStrideB
+
mMyThreadOffsetB
,
b_thread_mtx
,
p_b_thread
+
b_thread_mtx
.
Get1dIndex
(
0
,
n_repeat
*
NPerThreadSubC
),
b_thread_sub_mtx
.
GetLengths
());
}
}
}
// do last batch of gemm
threadwise_gemm
(
a_thread_mtx
,
True
,
p_a_thread
,
b_thread_mtx
,
False
,
p_b_thread
,
c_thread_mtx
,
False
,
p_c_thread
+
(
BatchPerThread
-
1
)
*
ThreadMatrixStrideC
,
f_accum
);
}
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
,
class
Accumulator
>
__device__
void
Run_v3
(
const
FloatA
*
__restrict__
p_a_block
,
const
FloatB
*
__restrict__
p_b_block
,
FloatC
*
__restrict__
p_c_thread
,
Accumulator
f_accum
)
const
{
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
constexpr
unsigned
KPerBlock
=
a_block_mtx
.
NRow
();
// A is transposed
constexpr
unsigned
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
unsigned
NPerThread
=
c_thread_mtx
.
NCol
();
// thread A, B for GEMM
// A is transposed, b is not
constexpr
auto
a_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThread
>
{});
// thread A-sub, B-sub for copy
constexpr
auto
a_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThreadSubC
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThreadSubC
>
{},
Number
<
NPerThread
>
{});
FloatA
p_a_thread
[
a_thread_mtx
.
GetElementSpace
()];
FloatB
p_b_thread
[
b_thread_mtx
.
GetElementSpace
()];
constexpr
unsigned
MPerLevel1Cluster
=
MPerThreadSubC
*
MLevel0Cluster
*
MLevel1Cluster
;
constexpr
unsigned
NPerLevel1Cluster
=
NPerThreadSubC
*
NLevel0Cluster
*
NLevel1Cluster
;
constexpr
unsigned
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
unsigned
NRepeat
=
NPerThread
/
NPerThreadSubC
;
// loop over k
//#pragma unroll
for
(
unsigned
k_begin
=
0
;
k_begin
<
KPerBlock
;
k_begin
+=
KPerThreadLoop
)
{
// read first batch of A, B
// copy A-sub to form A
//#pragma unroll
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
for
(
unsigned
i
=
0
;
i
<
a_thread_sub_mtx
.
NRow
();
++
i
)
{
#if 1
for
(
unsigned
j
=
0
;
j
<
a_thread_sub_mtx
.
NCol
();
++
j
)
{
p_a_thread
[
a_thread_mtx
.
Get1dIndex
(
i
,
m_repeat
*
MPerThreadSubC
+
j
)]
=
p_a_block
[
a_block_mtx
.
Get1dIndex
(
k_begin
+
i
,
m_repeat
*
MPerLevel1Cluster
+
j
)
+
mMyThreadOffsetA
];
}
#else
static_assert
(
a_thread_sub_mtx
.
NCol
()
==
4
,
"asm only read 4xfp32"
);
#endif
}
}
// copy B-sub to form B
//#pragma unroll
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
for
(
unsigned
i
=
0
;
i
<
b_thread_sub_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
b_thread_sub_mtx
.
NCol
();
++
j
)
{
p_b_thread
[
b_thread_mtx
.
Get1dIndex
(
i
,
n_repeat
*
NPerThreadSubC
+
j
)]
=
p_b_block
[
b_block_mtx
.
Get1dIndex
(
k_begin
+
i
,
n_repeat
*
MPerLevel1Cluster
+
j
)
+
mMyThreadOffsetB
];
}
}
}
// loop over batch
//#pragma unroll
for
(
unsigned
ib
=
0
;
ib
+
1
<
BatchPerThread
;
++
ib
)
{
// do current batch of gemm
for
(
unsigned
k
=
0
;
k
<
a_thread_mtx
.
NRow
();
++
k
)
{
#if 0
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex =
a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex =
c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
}
}
#elif
1
static_assert
(
c_thread_mtx
.
NRow
()
==
16
&&
c_thread_mtx
.
NCol
()
==
4
,
"asm is only for 16x4"
);
const
unsigned
bindex
=
b_thread_mtx
.
Get1dIndex
(
k
,
0
);
for
(
unsigned
i
=
0
;
i
<
c_thread_mtx
.
NRow
();
++
i
)
{
const
unsigned
aindex
=
a_thread_mtx
.
Get1dIndex
(
k
,
i
);
// A is transposed
const
unsigned
cindex
=
c_thread_mtx
.
Get1dIndex
(
i
,
0
);
asm
volatile
(
"
\n
\
v_mac_f32 %0, %4, %5
\n
\
v_mac_f32 %1, %4, %6
\n
\
v_mac_f32 %2, %4, %7
\n
\
v_mac_f32 %3, %4, %8
\n
\
"
:
"=v"
(
p_c_thread
[
cindex
+
0
]),
"=v"
(
p_c_thread
[
cindex
+
1
]),
"=v"
(
p_c_thread
[
cindex
+
2
]),
"=v"
(
p_c_thread
[
cindex
+
3
])
:
"v"
(
p_a_thread
[
aindex
]),
"v"
(
p_b_thread
[
bindex
+
0
]),
"v"
(
p_b_thread
[
bindex
+
1
]),
"v"
(
p_b_thread
[
bindex
+
2
]),
"v"
(
p_b_thread
[
bindex
+
3
]),
"0"
(
p_c_thread
[
cindex
+
0
]),
"1"
(
p_c_thread
[
cindex
+
1
]),
"2"
(
p_c_thread
[
cindex
+
2
]),
"3"
(
p_c_thread
[
cindex
+
3
]));
}
#endif
}
// read next batch of a, b
if
(
BlockMatrixStrideA
!=
0
)
{
//#pragma unroll
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
for
(
unsigned
i
=
0
;
i
<
a_thread_sub_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
a_thread_sub_mtx
.
NCol
();
++
j
)
{
p_a_thread
[
a_thread_mtx
.
Get1dIndex
(
i
,
m_repeat
*
MPerThreadSubC
+
j
)]
=
p_a_block
[
a_block_mtx
.
Get1dIndex
(
k_begin
+
i
,
m_repeat
*
MPerLevel1Cluster
+
j
)
+
(
ib
+
1
)
*
BlockMatrixStrideA
+
mMyThreadOffsetA
];
}
}
}
}
if
(
BlockMatrixStrideB
!=
0
)
{
//#pragma unroll
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
for
(
unsigned
i
=
0
;
i
<
b_thread_sub_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
b_thread_sub_mtx
.
NCol
();
++
j
)
{
p_b_thread
[
b_thread_mtx
.
Get1dIndex
(
i
,
n_repeat
*
NPerThreadSubC
+
j
)]
=
p_b_block
[
b_block_mtx
.
Get1dIndex
(
k_begin
+
i
,
n_repeat
*
MPerLevel1Cluster
+
j
)
+
(
ib
+
1
)
*
BlockMatrixStrideB
+
mMyThreadOffsetB
];
}
}
}
}
}
// do last batch of gemm
for
(
unsigned
k
=
0
;
k
<
a_thread_mtx
.
NRow
();
++
k
)
{
#if 0
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex = c_thread_mtx.Get1dIndex(i, j) +
(BatchPerThread - 1) * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
}
}
#elif
1
static_assert
(
c_thread_mtx
.
NRow
()
==
16
&&
c_thread_mtx
.
NCol
()
==
4
,
"asm is only for 16x4"
);
const
unsigned
bindex
=
b_thread_mtx
.
Get1dIndex
(
k
,
0
);
for
(
unsigned
i
=
0
;
i
<
c_thread_mtx
.
NRow
();
++
i
)
{
const
unsigned
aindex
=
a_thread_mtx
.
Get1dIndex
(
k
,
i
);
// A is transposed
const
unsigned
cindex
=
c_thread_mtx
.
Get1dIndex
(
i
,
0
)
+
(
BatchPerThread
-
1
)
*
ThreadMatrixStrideC
;
asm
volatile
(
"
\n
\
v_mac_f32 %0, %4, %5
\n
\
v_mac_f32 %1, %4, %6
\n
\
v_mac_f32 %2, %4, %7
\n
\
v_mac_f32 %3, %4, %8
\n
\
"
:
"=v"
(
p_c_thread
[
cindex
+
0
]),
"=v"
(
p_c_thread
[
cindex
+
1
]),
"=v"
(
p_c_thread
[
cindex
+
2
]),
"=v"
(
p_c_thread
[
cindex
+
3
])
:
"v"
(
p_a_thread
[
aindex
]),
"v"
(
p_b_thread
[
bindex
+
0
]),
"v"
(
p_b_thread
[
bindex
+
1
]),
"v"
(
p_b_thread
[
bindex
+
2
]),
"v"
(
p_b_thread
[
bindex
+
3
]),
"0"
(
p_c_thread
[
cindex
+
0
]),
"1"
(
p_c_thread
[
cindex
+
1
]),
"2"
(
p_c_thread
[
cindex
+
2
]),
"3"
(
p_c_thread
[
cindex
+
3
]));
}
#endif
}
}
}
template
<
class
BlockMatrixC
,
unsigned
BlockMatrixStrideC
,
class
FloatC
>
__device__
void
CopyThreadMatrixCToBlockMatrixC
(
const
FloatC
*
__restrict__
p_c_thread
,
FloatC
*
__restrict__
p_c_block
)
const
{
constexpr
auto
c_block_mtx
=
BlockMatrixC
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
constexpr
unsigned
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
unsigned
NPerThread
=
c_thread_mtx
.
NCol
();
constexpr
auto
c_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
MPerThreadSubC
>
{},
Number
<
NPerThreadSubC
>
{},
Number
<
NPerThread
>
{});
constexpr
unsigned
MPerLevel1Cluster
=
MPerThreadSubC
*
MLevel0Cluster
*
MLevel1Cluster
;
constexpr
unsigned
NPerLevel1Cluster
=
NPerThreadSubC
*
NLevel0Cluster
*
NLevel1Cluster
;
constexpr
unsigned
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
unsigned
NRepeat
=
NPerThread
/
NPerThreadSubC
;
const
auto
c_thread_mtx_begin
=
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
const
unsigned
c_thread_offset
=
c_thread_mtx_begin
.
batch
*
BlockMatrixStrideC
+
c_block_mtx
.
Get1dIndex
(
c_thread_mtx_begin
.
row
,
c_thread_mtx_begin
.
col
);
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
threadwise_matrix_copy
(
c_thread_sub_mtx
,
p_c_thread
+
c_thread_sub_mtx
.
Get1dIndex
(
m_repeat
*
MPerLevel1Cluster
,
n_repeat
*
NPerLevel1Cluster
),
c_block_mtx
,
p_c_block
+
c_block_mtx
.
Get1dIndex
(
m_repeat
*
MPerLevel1Cluster
,
n_repeat
*
NPerLevel1Cluster
)
+
c_thread_offset
,
c_thread_sub_mtx
.
GetLengths
());
}
}
}
};
src/include/blockwise_gemm.hip.hpp

#pragma once
#include "threadwise_gemm.hip.hpp"
template
<
unsigned
BlockSize
,
class
BlockMatrixA
,
class
BlockMatrixB
,
class
ThreadMatrixC
,
bool
TransA
,
bool
TransB
,
bool
TransC
,
unsigned
BlockMatrixStrideA
,
unsigned
BlockMatrixStrideB
,
unsigned
ThreadMatrixStrideC
,
unsigned
BatchSize
,
unsigned
BatchPerThread
,
unsigned
KPerThreadLoop
,
bool
DistributeThreadAlongColumnFirst
>
struct
Blockwise1dStridedBatchedGemmBlockABlockBThreadC
{
unsigned
mMyThreadOffsetA
=
0
;
unsigned
mMyThreadOffsetB
=
0
;
struct
MatrixIndex
{
unsigned
batch
;
unsigned
row
;
unsigned
col
;
};
__device__
Blockwise1dStridedBatchedGemmBlockABlockBThreadC
()
{
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
const
auto
c_thread_mtx_index
=
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
mMyThreadOffsetA
=
c_thread_mtx_index
.
batch
*
BlockMatrixStrideA
+
((
!
TransA
)
?
a_block_mtx
.
Get1dIndex
(
c_thread_mtx_index
.
row
,
0
)
:
a_block_mtx
.
Get1dIndex
(
0
,
c_thread_mtx_index
.
row
));
mMyThreadOffsetB
=
c_thread_mtx_index
.
batch
*
BlockMatrixStrideB
+
((
!
TransB
)
?
b_block_mtx
.
Get1dIndex
(
0
,
c_thread_mtx_index
.
col
)
:
b_block_mtx
.
Get1dIndex
(
c_thread_mtx_index
.
col
,
0
));
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");
printf("%u %u, %u %u %u, %u %u\n",
get_block_1d_id(),
get_thread_local_1d_id(),
c_thread_mtx_index.batch,
c_thread_mtx_index.row,
c_thread_mtx_index.col,
mMyThreadOffsetA,
mMyThreadOffsetB);
}
#endif
}
__device__
MatrixIndex
GetBeginOfThreadMatrixC
(
unsigned
thread_id
)
const
{
if
(
TransA
&&
(
!
TransB
)
&&
(
!
TransC
))
{
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
static_assert
(
a_block_mtx
.
NRow
()
==
b_block_mtx
.
NRow
(),
"wrong! k dimension not consistent!"
);
constexpr
unsigned
MPerBlock
=
a_block_mtx
.
NCol
();
constexpr
unsigned
NPerBlock
=
b_block_mtx
.
NCol
();
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
// divide thread work
constexpr
unsigned
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
unsigned
NPerThread
=
c_thread_mtx
.
NCol
();
static_assert
(
BatchSize
%
BatchPerThread
==
0
,
"BatchSize % BatchPerThread != 0"
);
static_assert
(
MPerBlock
%
MPerThread
==
0
,
"MPerBlock % MPerThread != 0"
);
static_assert
(
NPerBlock
%
NPerThread
==
0
,
"NPerBlock % NPerThread != 0"
);
constexpr
unsigned
BatchThreadWork
=
(
BatchSize
+
BatchPerThread
-
1
)
/
BatchPerThread
;
constexpr
unsigned
MThreadWork
=
(
MPerBlock
+
MPerThread
-
1
)
/
MPerThread
;
constexpr
unsigned
NThreadWork
=
(
NPerBlock
+
NPerThread
-
1
)
/
NPerThread
;
static_assert
(
BlockSize
==
BatchThreadWork
*
MThreadWork
*
NThreadWork
,
"wrong! wrong BlockSize"
);
if
(
DistributeThreadAlongColumnFirst
)
{
// num of operations can be reduced
const
unsigned
b_work_id
=
thread_id
/
(
MThreadWork
*
NThreadWork
);
unsigned
itmp
=
thread_id
-
b_work_id
*
(
MThreadWork
*
NThreadWork
);
const
unsigned
m_work_id
=
itmp
/
NThreadWork
;
const
unsigned
n_work_id
=
itmp
-
m_work_id
*
NThreadWork
;
return
MatrixIndex
{
b_work_id
*
BatchPerThread
,
m_work_id
*
MPerThread
,
n_work_id
*
NPerThread
};
}
else
{
// not implemented
assert
(
false
);
}
}
else
{
// not implemented
assert
(
false
);
}
}
// this should be optimized away if input is known
__device__
static
MatrixIndex
GetDistanceFromBeginOfThreadMatrixC
(
unsigned
batch_in_c
,
unsigned
m_in_c
,
unsigned
n_in_c
)
{
return
MatrixIndex
{
batch_in_c
,
m_in_c
,
n_in_c
};
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
,
class
Accumulator
>
__device__
void
Run
(
const
FloatA
*
__restrict__
p_a_block
,
const
FloatB
*
__restrict__
p_b_block
,
FloatC
*
__restrict__
p_c_thread
,
Accumulator
f_accum
)
const
{
if
(
TransA
&&
(
!
TransB
)
&&
(
!
TransC
))
{
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
constexpr
unsigned
KPerBlock
=
a_block_mtx
.
NRow
();
// A is transposed
constexpr
unsigned
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
unsigned
NPerThread
=
c_thread_mtx
.
NCol
();
// a is transposed, b is not
constexpr
auto
a_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThread
>
{});
FloatA
p_a_thread
[
a_thread_mtx
.
GetElementSpace
()];
FloatB
p_b_thread
[
b_thread_mtx
.
GetElementSpace
()];
// loop over k
for
(
unsigned
k_begin
=
0
;
k_begin
<
KPerBlock
;
k_begin
+=
KPerThreadLoop
)
{
// read first batch of a, b
threadwise_matrix_copy
(
a_block_mtx
,
p_a_block
+
mMyThreadOffsetA
+
k_begin
*
a_block_mtx
.
RowStride
(),
a_thread_mtx
,
p_a_thread
,
a_thread_mtx
.
GetLengths
());
threadwise_matrix_copy
(
b_block_mtx
,
p_b_block
+
mMyThreadOffsetB
+
k_begin
*
b_block_mtx
.
RowStride
(),
b_thread_mtx
,
p_b_thread
,
b_thread_mtx
.
GetLengths
());
// loop over batch
for
(
unsigned
ib
=
0
;
ib
+
1
<
BatchPerThread
;
++
ib
)
{
// do current batch of gemm
threadwise_gemm
(
a_thread_mtx
,
True
,
p_a_thread
,
b_thread_mtx
,
False
,
p_b_thread
,
c_thread_mtx
,
False
,
p_c_thread
+
ib
*
ThreadMatrixStrideC
,
f_accum
);
// read next batch of a, b
if
(
BlockMatrixStrideA
!=
0
)
{
threadwise_matrix_copy
(
a_block_mtx
,
p_a_block
+
mMyThreadOffsetA
+
(
ib
+
1
)
*
BlockMatrixStrideA
+
+
k_begin
*
a_block_mtx
.
RowStride
(),
a_thread_mtx
,
p_a_thread
,
a_thread_mtx
.
GetLengths
());
}
if
(
BlockMatrixStrideB
!=
0
)
{
threadwise_matrix_copy
(
b_block_mtx
,
p_b_block
+
mMyThreadOffsetB
+
(
ib
+
1
)
*
BlockMatrixStrideB
+
k_begin
*
b_block_mtx
.
RowStride
(),
b_thread_mtx
,
p_b_thread
,
b_thread_mtx
.
GetLengths
());
}
}
// do last batch of gemm
threadwise_gemm
(
a_thread_mtx
,
True
,
p_a_thread
,
b_thread_mtx
,
False
,
p_b_thread
,
c_thread_mtx
,
False
,
p_c_thread
+
(
BatchPerThread
-
1
)
*
ThreadMatrixStrideC
,
f_accum
);
}
}
}
};
template
<
unsigned
BlockSize
,
class
BlockMatrixA
,
class
BlockMatrixB
,
class
ThreadMatrixC
,
unsigned
BlockMatrixStrideA
,
unsigned
BlockMatrixStrideB
,
unsigned
ThreadMatrixStrideC
,
unsigned
BatchSize
,
unsigned
MPerThreadSubC
,
unsigned
NPerThreadSubC
,
unsigned
MLevel0Cluster
,
unsigned
NLevel0Cluster
,
unsigned
MLevel1Cluster
,
unsigned
NLevel1Cluster
,
unsigned
KPerThreadLoop
,
unsigned
BatchPerThread
>
struct
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
unsigned
mMyThreadOffsetA
=
0
;
unsigned
mMyThreadOffsetB
=
0
;
struct
MatrixIndex
{
unsigned
batch
;
unsigned
row
;
unsigned
col
;
};
__device__
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
()
{
static_assert
(
BatchSize
%
BatchPerThread
==
0
,
"wrong! BatchSize is not dividable by BatchPerThread"
);
constexpr
unsigned
BatchThreadWork
=
BatchSize
/
BatchPerThread
;
constexpr
unsigned
ThreadPerLevel1Cluster
=
MLevel0Cluster
*
NLevel0Cluster
*
MLevel1Cluster
*
NLevel1Cluster
;
static_assert
(
BlockSize
==
BatchThreadWork
*
ThreadPerLevel1Cluster
,
"wrong! wrong blocksize
\n
"
);
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
static_assert
(
a_block_mtx
.
NRow
()
==
b_block_mtx
.
NRow
(),
"wrong! K dimension not consistent
\n
"
);
constexpr
unsigned
M
=
a_block_mtx
.
NCol
();
// A is transposed
constexpr
unsigned
N
=
b_block_mtx
.
NCol
();
constexpr
unsigned
K
=
a_block_mtx
.
NRow
();
constexpr
unsigned
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
unsigned
NPerThread
=
c_thread_mtx
.
NCol
();
static_assert
((
MPerThread
%
MPerThreadSubC
==
0
)
&&
(
NPerThread
%
NPerThreadSubC
==
0
),
"wrong! Cannot evenly divide thread work among repeat
\n
"
);
constexpr
unsigned
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
unsigned
NRepeat
=
NPerThread
/
NPerThreadSubC
;
static_assert
((
M
%
MRepeat
==
0
)
&&
(
N
%
NRepeat
==
0
),
"wrong! Cannot evenly divide work among repeat
\n
"
);
constexpr
unsigned
MPerLevel1Cluster
=
M
/
MRepeat
;
constexpr
unsigned
NPerLevel1Cluster
=
N
/
NRepeat
;
static_assert
((
MPerLevel1Cluster
%
MLevel1Cluster
==
0
)
&&
(
NPerLevel1Cluster
%
NLevel1Cluster
==
0
),
"wrong! Cannot evenly divide work among Level1Cluster
\n
"
);
constexpr
unsigned
MPerLevel0Cluster
=
MPerLevel1Cluster
/
MLevel1Cluster
;
constexpr
unsigned
NPerLevel0Cluster
=
NPerLevel1Cluster
/
NLevel1Cluster
;
static_assert
((
MPerLevel0Cluster
%
MLevel0Cluster
==
0
)
&&
(
NPerLevel0Cluster
%
NLevel0Cluster
==
0
),
"wrong! Cannot evenly divide work among Level0Cluster
\n
"
);
static_assert
((
MPerThreadSubC
==
MPerLevel0Cluster
/
MLevel0Cluster
)
&&
(
NPerThreadSubC
==
NPerLevel0Cluster
/
NLevel0Cluster
),
"wrong! thread work size is wrong
\n
"
);
const
auto
c_thread_mtx_index
=
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
mMyThreadOffsetA
=
c_thread_mtx_index
.
batch
*
BlockMatrixStrideA
+
a_block_mtx
.
Get1dIndex
(
0
,
c_thread_mtx_index
.
row
);
mMyThreadOffsetB
=
c_thread_mtx_index
.
batch
*
BlockMatrixStrideB
+
b_block_mtx
.
Get1dIndex
(
0
,
c_thread_mtx_index
.
col
);
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");
printf("%u %u, %u %u %u, %u %u\n",
get_block_1d_id(),
get_thread_local_1d_id(),
c_thread_mtx_index.batch,
c_thread_mtx_index.row,
c_thread_mtx_index.col,
mMyThreadOffsetA,
mMyThreadOffsetB);
}
#endif
}
__device__
MatrixIndex
GetBeginOfThreadMatrixC
(
unsigned
thread_id
)
const
{
constexpr
unsigned
BatchThreadWork
=
BatchSize
/
BatchPerThread
;
constexpr
unsigned
ThreadPerLevel1Cluster
=
MLevel0Cluster
*
NLevel0Cluster
*
MLevel1Cluster
*
NLevel1Cluster
;
constexpr
unsigned
ThreadPerLevel0Cluster
=
MLevel0Cluster
*
NLevel0Cluster
;
unsigned
batch_work_id
=
thread_id
/
ThreadPerLevel1Cluster
;
unsigned
cluster_id
=
thread_id
-
batch_work_id
*
ThreadPerLevel1Cluster
;
unsigned
level1_id
=
cluster_id
/
ThreadPerLevel0Cluster
;
unsigned
level1_m_id
=
level1_id
/
NLevel1Cluster
;
unsigned
level1_n_id
=
level1_id
%
NLevel1Cluster
;
unsigned
level0_id
=
cluster_id
%
ThreadPerLevel0Cluster
;
unsigned
level0_m_id
=
level0_id
/
NLevel0Cluster
;
unsigned
level0_n_id
=
level0_id
%
NLevel0Cluster
;
constexpr
unsigned
MPerLevel0Cluster
=
MPerThreadSubC
*
MLevel0Cluster
;
constexpr
unsigned
NPerLevel0Cluster
=
NPerThreadSubC
*
NLevel0Cluster
;
return
MatrixIndex
{
batch_work_id
*
BatchPerThread
,
level1_m_id
*
MPerLevel0Cluster
+
level0_m_id
*
MPerThreadSubC
,
level1_n_id
*
NPerLevel0Cluster
+
level0_n_id
*
NPerThreadSubC
};
}
// this should be optimized away if input is known
__device__
static
MatrixIndex
GetDistanceFromBeginOfThreadMatrixC
(
unsigned
batch_in_c
,
unsigned
m_in_c
,
unsigned
n_in_c
)
{
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
constexpr
unsigned
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
unsigned
NPerThread
=
c_thread_mtx
.
NCol
();
constexpr
unsigned
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
unsigned
NRepeat
=
NPerThread
/
NPerThreadSubC
;
constexpr
unsigned
MPerLevel1Cluster
=
MPerThreadSubC
*
MLevel0Cluster
*
MLevel1Cluster
;
constexpr
unsigned
NPerLevel1Cluster
=
NPerThreadSubC
*
NLevel0Cluster
*
NLevel1Cluster
;
unsigned
m_repeat
=
m_in_c
/
MPerThreadSubC
;
unsigned
n_repeat
=
n_in_c
/
NPerThreadSubC
;
unsigned
m_in_sub_c
=
m_in_c
%
MPerThreadSubC
;
unsigned
n_in_sub_c
=
n_in_c
%
NPerThreadSubC
;
return
MatrixIndex
{
batch_in_c
,
m_repeat
*
MPerLevel1Cluster
+
m_in_sub_c
,
n_repeat
*
NPerLevel1Cluster
+
n_in_sub_c
};
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
,
class
Accumulator
>
__device__
void
Run
(
const
FloatA
*
__restrict__
p_a_block
,
const
FloatB
*
__restrict__
p_b_block
,
FloatC
*
__restrict__
p_c_thread
,
Accumulator
f_accum
)
const
{
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
constexpr
unsigned
KPerBlock
=
a_block_mtx
.
NRow
();
// A is transposed
constexpr
unsigned
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
unsigned
NPerThread
=
c_thread_mtx
.
NCol
();
// thread A, B for GEMM
// A is transposed, b is not
constexpr
auto
a_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThread
>
{});
// thread A-sub, B-sub for copy
constexpr
auto
a_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThreadSubC
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThreadSubC
>
{},
Number
<
NPerThread
>
{});
FloatA
p_a_thread
[
a_thread_mtx
.
GetElementSpace
()];
FloatB
p_b_thread
[
b_thread_mtx
.
GetElementSpace
()];
constexpr
unsigned
MPerLevel1Cluster
=
MPerThreadSubC
*
MLevel0Cluster
*
MLevel1Cluster
;
constexpr
unsigned
NPerLevel1Cluster
=
NPerThreadSubC
*
NLevel0Cluster
*
NLevel1Cluster
;
constexpr
unsigned
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
unsigned
NRepeat
=
NPerThread
/
NPerThreadSubC
;
// loop over k
#pragma unroll
for
(
unsigned
k_begin
=
0
;
k_begin
<
KPerBlock
;
k_begin
+=
KPerThreadLoop
)
{
// read first batch of A, B
// copy A-sub to form A
#pragma unroll
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
threadwise_matrix_copy
(
a_block_mtx
,
p_a_block
+
a_block_mtx
.
Get1dIndex
(
k_begin
,
m_repeat
*
MPerLevel1Cluster
)
+
mMyThreadOffsetA
,
a_thread_mtx
,
p_a_thread
+
a_thread_mtx
.
Get1dIndex
(
0
,
m_repeat
*
MPerThreadSubC
),
a_thread_sub_mtx
.
GetLengths
());
}
// copy B-sub to form B
#pragma unroll
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
threadwise_matrix_copy
(
b_block_mtx
,
p_b_block
+
b_block_mtx
.
Get1dIndex
(
k_begin
,
n_repeat
*
NPerLevel1Cluster
)
+
mMyThreadOffsetB
,
b_thread_mtx
,
p_b_thread
+
b_thread_mtx
.
Get1dIndex
(
0
,
n_repeat
*
NPerThreadSubC
),
b_thread_sub_mtx
.
GetLengths
());
}
// loop over batch
#pragma unroll
for
(
unsigned
ib
=
0
;
ib
+
1
<
BatchPerThread
;
++
ib
)
{
// do current batch of gemm
threadwise_gemm
(
a_thread_mtx
,
True
,
p_a_thread
,
b_thread_mtx
,
False
,
p_b_thread
,
c_thread_mtx
,
False
,
p_c_thread
+
ib
*
ThreadMatrixStrideC
,
f_accum
);
// read next batch of a, b
if
(
BlockMatrixStrideA
!=
0
)
{
#pragma unroll
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
threadwise_matrix_copy
(
a_block_mtx
,
p_a_block
+
a_block_mtx
.
Get1dIndex
(
k_begin
,
m_repeat
*
MPerLevel1Cluster
)
+
(
ib
+
1
)
*
BlockMatrixStrideA
+
mMyThreadOffsetA
,
a_thread_mtx
,
p_a_thread
+
a_thread_mtx
.
Get1dIndex
(
0
,
m_repeat
*
MPerThreadSubC
),
a_thread_sub_mtx
.
GetLengths
());
}
}
if
(
BlockMatrixStrideB
!=
0
)
{
#pragma unroll
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
threadwise_matrix_copy
(
b_block_mtx
,
p_b_block
+
b_block_mtx
.
Get1dIndex
(
k_begin
,
n_repeat
*
NPerLevel1Cluster
)
+
(
ib
+
1
)
*
BlockMatrixStrideB
+
mMyThreadOffsetB
,
b_thread_mtx
,
p_b_thread
+
b_thread_mtx
.
Get1dIndex
(
0
,
n_repeat
*
NPerThreadSubC
),
b_thread_sub_mtx
.
GetLengths
());
}
}
}
// do last batch of gemm
threadwise_gemm
(
a_thread_mtx
,
True
,
p_a_thread
,
b_thread_mtx
,
False
,
p_b_thread
,
c_thread_mtx
,
False
,
p_c_thread
+
(
BatchPerThread
-
1
)
*
ThreadMatrixStrideC
,
f_accum
);
}
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
,
class
Accumulator
>
__device__
void
Run_v2
(
const
FloatA
*
__restrict__
p_a_block
,
const
FloatB
*
__restrict__
p_b_block
,
FloatC
*
__restrict__
p_c_thread
,
Accumulator
f_accum
)
const
{
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
constexpr
unsigned
KPerBlock
=
a_block_mtx
.
NRow
();
// A is transposed
constexpr
unsigned
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
unsigned
NPerThread
=
c_thread_mtx
.
NCol
();
// thread A, B for GEMM
// A is transposed, b is not
constexpr
auto
a_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThread
>
{});
// thread A-sub, B-sub for copy
constexpr
auto
a_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThreadSubC
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThreadSubC
>
{},
Number
<
NPerThread
>
{});
FloatA
p_a_thread
[
a_thread_mtx
.
GetElementSpace
()];
FloatB
p_b_thread
[
b_thread_mtx
.
GetElementSpace
()];
constexpr
unsigned
MPerLevel1Cluster
=
MPerThreadSubC
*
MLevel0Cluster
*
MLevel1Cluster
;
constexpr
unsigned
NPerLevel1Cluster
=
NPerThreadSubC
*
NLevel0Cluster
*
NLevel1Cluster
;
constexpr
unsigned
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
unsigned
NRepeat
=
NPerThread
/
NPerThreadSubC
;
// loop over k
//#pragma unroll
for
(
unsigned
k_begin
=
0
;
k_begin
<
KPerBlock
;
k_begin
+=
KPerThreadLoop
)
{
// read first batch of A, B
// copy A-sub to form A
//#pragma unroll
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
for
(
unsigned
i
=
0
;
i
<
a_thread_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
a_thread_mtx
.
NCol
();
++
j
)
{
p_a_thread
[
a_thread_mtx
.
Get1dIndex
(
i
,
m_repeat
*
MPerThreadSubC
+
j
)]
=
p_a_block
[
a_block_mtx
.
Get1dIndex
(
k_begin
+
i
,
m_repeat
*
MPerLevel1Cluster
+
j
)
+
mMyThreadOffsetA
];
}
}
}
// copy B-sub to form B
//#pragma unroll
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
for
(
unsigned
i
=
0
;
i
<
b_thread_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
b_thread_mtx
.
NCol
();
++
j
)
{
p_b_thread
[
b_thread_mtx
.
Get1dIndex
(
i
,
n_repeat
*
NPerThreadSubC
+
j
)]
=
p_b_block
[
b_block_mtx
.
Get1dIndex
(
k_begin
+
i
,
n_repeat
*
MPerLevel1Cluster
+
j
)
+
mMyThreadOffsetB
];
}
}
}
// loop over batch
//#pragma unroll
for
(
unsigned
ib
=
0
;
ib
+
1
<
BatchPerThread
;
++
ib
)
{
// do current batch of gemm
for
(
unsigned
k
=
0
;
k
<
a_thread_mtx
.
NRow
();
++
k
)
{
for
(
unsigned
i
=
0
;
i
<
c_thread_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
c_thread_mtx
.
NCol
();
++
j
)
{
const
unsigned
aindex
=
a_thread_mtx
.
Get1dIndex
(
k
,
i
);
// A is transposed
const
unsigned
bindex
=
b_thread_mtx
.
Get1dIndex
(
k
,
j
);
const
unsigned
cindex
=
c_thread_mtx
.
Get1dIndex
(
i
,
j
)
+
ib
*
ThreadMatrixStrideC
;
f_accum
(
p_c_thread
[
cindex
],
p_a_thread
[
aindex
]
*
p_b_thread
[
bindex
]);
}
}
}
// read next batch of a, b
if
(
BlockMatrixStrideA
!=
0
)
{
//#pragma unroll
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
for
(
unsigned
i
=
0
;
i
<
a_thread_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
a_thread_mtx
.
NCol
();
++
j
)
{
p_a_thread
[
a_thread_mtx
.
Get1dIndex
(
i
,
m_repeat
*
MPerThreadSubC
+
j
)]
=
p_a_block
[
a_block_mtx
.
Get1dIndex
(
k_begin
+
i
,
m_repeat
*
MPerLevel1Cluster
+
j
)
+
(
ib
+
1
)
*
BlockMatrixStrideA
+
mMyThreadOffsetA
];
}
}
}
}
if
(
BlockMatrixStrideB
!=
0
)
{
//#pragma unroll
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
for
(
unsigned
i
=
0
;
i
<
b_thread_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
b_thread_mtx
.
NCol
();
++
j
)
{
p_b_thread
[
b_thread_mtx
.
Get1dIndex
(
i
,
n_repeat
*
NPerThreadSubC
+
j
)]
=
p_b_block
[
b_block_mtx
.
Get1dIndex
(
k_begin
+
i
,
n_repeat
*
MPerLevel1Cluster
+
j
)
+
(
ib
+
1
)
*
BlockMatrixStrideB
+
mMyThreadOffsetB
];
}
}
}
}
}
// do last batch of gemm
for
(
unsigned
k
=
0
;
k
<
a_thread_mtx
.
NRow
();
++
k
)
{
for
(
unsigned
i
=
0
;
i
<
c_thread_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
c_thread_mtx
.
NCol
();
++
j
)
{
const
unsigned
aindex
=
a_thread_mtx
.
Get1dIndex
(
k
,
i
);
// A is transposed
const
unsigned
bindex
=
b_thread_mtx
.
Get1dIndex
(
k
,
j
);
const
unsigned
cindex
=
c_thread_mtx
.
Get1dIndex
(
i
,
j
)
+
(
BatchPerThread
-
1
)
*
ThreadMatrixStrideC
;
f_accum
(
p_c_thread
[
cindex
],
p_a_thread
[
aindex
]
*
p_b_thread
[
bindex
]);
}
}
}
}
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
,
class
Accumulator
>
__device__
void
Run_v3
(
const
FloatA
*
__restrict__
p_a_block
,
const
FloatB
*
__restrict__
p_b_block
,
FloatC
*
__restrict__
p_c_thread
,
Accumulator
f_accum
)
const
{
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
False
=
integral_constant
<
bool
,
false
>
{};
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
constexpr
unsigned
KPerBlock
=
a_block_mtx
.
NRow
();
// A is transposed
constexpr
unsigned
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
unsigned
NPerThread
=
c_thread_mtx
.
NCol
();
// thread A, B for GEMM
// A is transposed, b is not
constexpr
auto
a_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThread
>
{});
// thread A-sub, B-sub for copy
constexpr
auto
a_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThreadSubC
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThreadSubC
>
{},
Number
<
NPerThread
>
{});
FloatA
p_a_thread
[
a_thread_mtx
.
GetElementSpace
()];
FloatB
p_b_thread
[
b_thread_mtx
.
GetElementSpace
()];
constexpr
unsigned
MPerLevel1Cluster
=
MPerThreadSubC
*
MLevel0Cluster
*
MLevel1Cluster
;
constexpr
unsigned
NPerLevel1Cluster
=
NPerThreadSubC
*
NLevel0Cluster
*
NLevel1Cluster
;
constexpr
unsigned
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
unsigned
NRepeat
=
NPerThread
/
NPerThreadSubC
;
// loop over k
//#pragma unroll
for
(
unsigned
k_begin
=
0
;
k_begin
<
KPerBlock
;
k_begin
+=
KPerThreadLoop
)
{
// read first batch of A, B
// copy A-sub to form A
//#pragma unroll
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
for
(
unsigned
i
=
0
;
i
<
a_thread_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
a_thread_mtx
.
NCol
();
++
j
)
{
p_a_thread
[
a_thread_mtx
.
Get1dIndex
(
i
,
m_repeat
*
MPerThreadSubC
+
j
)]
=
p_a_block
[
a_block_mtx
.
Get1dIndex
(
k_begin
+
i
,
m_repeat
*
MPerLevel1Cluster
+
j
)
+
mMyThreadOffsetA
];
}
}
}
// copy B-sub to form B
//#pragma unroll
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
for
(
unsigned
i
=
0
;
i
<
b_thread_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
b_thread_mtx
.
NCol
();
++
j
)
{
p_b_thread
[
b_thread_mtx
.
Get1dIndex
(
i
,
n_repeat
*
NPerThreadSubC
+
j
)]
=
p_b_block
[
b_block_mtx
.
Get1dIndex
(
k_begin
+
i
,
n_repeat
*
MPerLevel1Cluster
+
j
)
+
mMyThreadOffsetB
];
}
}
}
// loop over batch
//#pragma unroll
for
(
unsigned
ib
=
0
;
ib
+
1
<
BatchPerThread
;
++
ib
)
{
// do current batch of gemm
for
(
unsigned
k
=
0
;
k
<
a_thread_mtx
.
NRow
();
++
k
)
{
#if 0
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex =
a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex =
c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
}
}
#elif
1
static_assert
(
c_thread_mtx
.
NRow
()
==
16
&&
c_thread_mtx
.
NCol
()
==
4
,
"asm is only for 16x4"
);
const
unsigned
bindex
=
b_thread_mtx
.
Get1dIndex
(
k
,
0
);
for
(
unsigned
i
=
0
;
i
<
c_thread_mtx
.
NRow
();
++
i
)
{
const
unsigned
aindex
=
a_thread_mtx
.
Get1dIndex
(
k
,
i
);
// A is transposed
const
unsigned
cindex
=
c_thread_mtx
.
Get1dIndex
(
i
,
0
);
asm
volatile
(
"
\n
\
v_mac_f32 %0, %4, %5
\n
\
v_mac_f32 %1, %4, %6
\n
\
v_mac_f32 %2, %4, %7
\n
\
v_mac_f32 %3, %4, %8
\n
\
"
:
"=v"
(
p_c_thread
[
cindex
+
0
]),
"=v"
(
p_c_thread
[
cindex
+
1
]),
"=v"
(
p_c_thread
[
cindex
+
2
]),
"=v"
(
p_c_thread
[
cindex
+
3
])
:
"v"
(
p_a_thread
[
aindex
]),
"v"
(
p_b_thread
[
bindex
+
0
]),
"v"
(
p_b_thread
[
bindex
+
1
]),
"v"
(
p_b_thread
[
bindex
+
2
]),
"v"
(
p_b_thread
[
bindex
+
3
]),
"0"
(
p_c_thread
[
cindex
+
0
]),
"1"
(
p_c_thread
[
cindex
+
1
]),
"2"
(
p_c_thread
[
cindex
+
2
]),
"3"
(
p_c_thread
[
cindex
+
3
]));
}
#endif
}
// read next batch of a, b
if
(
BlockMatrixStrideA
!=
0
)
{
//#pragma unroll
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
for
(
unsigned
i
=
0
;
i
<
a_thread_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
a_thread_mtx
.
NCol
();
++
j
)
{
p_a_thread
[
a_thread_mtx
.
Get1dIndex
(
i
,
m_repeat
*
MPerThreadSubC
+
j
)]
=
p_a_block
[
a_block_mtx
.
Get1dIndex
(
k_begin
+
i
,
m_repeat
*
MPerLevel1Cluster
+
j
)
+
(
ib
+
1
)
*
BlockMatrixStrideA
+
mMyThreadOffsetA
];
}
}
}
}
if
(
BlockMatrixStrideB
!=
0
)
{
//#pragma unroll
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
for
(
unsigned
i
=
0
;
i
<
b_thread_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
b_thread_mtx
.
NCol
();
++
j
)
{
p_b_thread
[
b_thread_mtx
.
Get1dIndex
(
i
,
n_repeat
*
NPerThreadSubC
+
j
)]
=
p_b_block
[
b_block_mtx
.
Get1dIndex
(
k_begin
+
i
,
n_repeat
*
MPerLevel1Cluster
+
j
)
+
(
ib
+
1
)
*
BlockMatrixStrideB
+
mMyThreadOffsetB
];
}
}
}
}
}
// do last batch of gemm
for
(
unsigned
k
=
0
;
k
<
a_thread_mtx
.
NRow
();
++
k
)
{
#if 0
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex = c_thread_mtx.Get1dIndex(i, j) +
(BatchPerThread - 1) * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
}
}
#elif
1
static_assert
(
c_thread_mtx
.
NRow
()
==
16
&&
c_thread_mtx
.
NCol
()
==
4
,
"asm is only for 16x4"
);
const
unsigned
bindex
=
b_thread_mtx
.
Get1dIndex
(
k
,
0
);
for
(
unsigned
i
=
0
;
i
<
c_thread_mtx
.
NRow
();
++
i
)
{
const
unsigned
aindex
=
a_thread_mtx
.
Get1dIndex
(
k
,
i
);
// A is transposed
const
unsigned
cindex
=
c_thread_mtx
.
Get1dIndex
(
i
,
0
)
+
(
BatchPerThread
-
1
)
*
ThreadMatrixStrideC
;
asm
volatile
(
"
\n
\
v_mac_f32 %0, %4, %5
\n
\
v_mac_f32 %1, %4, %6
\n
\
v_mac_f32 %2, %4, %7
\n
\
v_mac_f32 %3, %4, %8
\n
\
"
:
"=v"
(
p_c_thread
[
cindex
+
0
]),
"=v"
(
p_c_thread
[
cindex
+
1
]),
"=v"
(
p_c_thread
[
cindex
+
2
]),
"=v"
(
p_c_thread
[
cindex
+
3
])
:
"v"
(
p_a_thread
[
aindex
]),
"v"
(
p_b_thread
[
bindex
+
0
]),
"v"
(
p_b_thread
[
bindex
+
1
]),
"v"
(
p_b_thread
[
bindex
+
2
]),
"v"
(
p_b_thread
[
bindex
+
3
]),
"0"
(
p_c_thread
[
cindex
+
0
]),
"1"
(
p_c_thread
[
cindex
+
1
]),
"2"
(
p_c_thread
[
cindex
+
2
]),
"3"
(
p_c_thread
[
cindex
+
3
]));
}
#endif
}
}
}
template
<
class
BlockMatrixC
,
unsigned
BlockMatrixStrideC
,
class
FloatC
>
__device__
void
CopyThreadMatrixCToBlockMatrixC
(
const
FloatC
*
__restrict__
p_c_thread
,
FloatC
*
__restrict__
p_c_block
)
const
{
constexpr
auto
c_block_mtx
=
BlockMatrixC
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
constexpr
unsigned
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
unsigned
NPerThread
=
c_thread_mtx
.
NCol
();
constexpr
auto
c_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
MPerThreadSubC
>
{},
Number
<
NPerThreadSubC
>
{},
Number
<
NPerThread
>
{});
constexpr
unsigned
MPerLevel1Cluster
=
MPerThreadSubC
*
MLevel0Cluster
*
MLevel1Cluster
;
constexpr
unsigned
NPerLevel1Cluster
=
NPerThreadSubC
*
NLevel0Cluster
*
NLevel1Cluster
;
constexpr
unsigned
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
unsigned
NRepeat
=
NPerThread
/
NPerThreadSubC
;
const
auto
c_thread_mtx_begin
=
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
const
unsigned
c_thread_offset
=
c_thread_mtx_begin
.
batch
*
BlockMatrixStrideC
+
c_block_mtx
.
Get1dIndex
(
c_thread_mtx_begin
.
row
,
c_thread_mtx_begin
.
col
);
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
threadwise_matrix_copy
(
c_thread_sub_mtx
,
p_c_thread
+
c_thread_sub_mtx
.
Get1dIndex
(
m_repeat
*
MPerLevel1Cluster
,
n_repeat
*
NPerLevel1Cluster
),
c_block_mtx
,
p_c_block
+
c_block_mtx
.
Get1dIndex
(
m_repeat
*
MPerLevel1Cluster
,
n_repeat
*
NPerLevel1Cluster
)
+
c_thread_offset
,
c_thread_sub_mtx
.
GetLengths
());
}
}
}
};
src/include/blockwise_gemm.hip.hpp
View file @ f35c64eb
template <unsigned BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
...
@@ -1375,6 +420,146 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
        }
    }
    template <class FloatA, class FloatB, class FloatC, class Accumulator>
    __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
                            const FloatB* __restrict__ p_b_block,
                            FloatC* __restrict__ p_c_thread,
                            Accumulator f_accum) const
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr unsigned M = a_block_mtx.NCol();
        constexpr unsigned N = b_block_mtx.NCol();
        constexpr unsigned K = a_block_mtx.NRow();

        constexpr unsigned MPerThread = c_thread_mtx.NRow();
        constexpr unsigned NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;

#pragma unroll
        // loop over k
        for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
        {
#pragma unroll
            // copy A-sub to form A
            for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
            {
                threadwise_matrix_copy(
                    a_block_mtx,
                    p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
                        mMyThreadOffsetA,
                    a_thread_mtx,
                    p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
                    a_thread_sub_mtx.GetLengths());
            }

#pragma unroll
            // copy B-sub to form B
            for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            {
                threadwise_matrix_copy(
                    b_block_mtx,
                    p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
                        mMyThreadOffsetB,
                    b_thread_mtx,
                    p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
                    b_thread_sub_mtx.GetLengths());
            }

            // C = A * B
#if 1
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread,
                            b_thread_mtx,
                            False,
                            p_b_thread,
                            c_thread_mtx,
                            False,
                            p_c_thread,
                            f_accum);
#else
            // inline asm
            static_assert(c_thread_mtx.NRow() == 8 && c_thread_mtx.NCol() == 8,
                          "asm is only for 8x8");

            for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
            {
                const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);

                for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
                {
                    const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed

                    const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0);

                    asm volatile("\n \
                        v_mac_f32 %0, %8, %9 \n \
                        v_mac_f32 %1, %8, %10 \n \
                        v_mac_f32 %2, %8, %11 \n \
                        v_mac_f32 %3, %8, %12 \n \
                        v_mac_f32 %4, %8, %13 \n \
                        v_mac_f32 %5, %8, %14 \n \
                        v_mac_f32 %6, %8, %15 \n \
                        v_mac_f32 %7, %8, %16 \n \
                        "
                                 : "=v"(p_c_thread[cindex + 0]),
                                   "=v"(p_c_thread[cindex + 1]),
                                   "=v"(p_c_thread[cindex + 2]),
                                   "=v"(p_c_thread[cindex + 3]),
                                   "=v"(p_c_thread[cindex + 4]),
                                   "=v"(p_c_thread[cindex + 5]),
                                   "=v"(p_c_thread[cindex + 6]),
                                   "=v"(p_c_thread[cindex + 7])
                                 : "v"(p_a_thread[aindex]),
                                   "v"(p_b_thread[bindex + 0]),
                                   "v"(p_b_thread[bindex + 1]),
                                   "v"(p_b_thread[bindex + 2]),
                                   "v"(p_b_thread[bindex + 3]),
                                   "v"(p_b_thread[bindex + 4]),
                                   "v"(p_b_thread[bindex + 5]),
                                   "v"(p_b_thread[bindex + 6]),
                                   "v"(p_b_thread[bindex + 7]),
                                   "0"(p_c_thread[cindex + 0]),
                                   "1"(p_c_thread[cindex + 1]),
                                   "2"(p_c_thread[cindex + 2]),
                                   "3"(p_c_thread[cindex + 3]),
                                   "4"(p_c_thread[cindex + 4]),
                                   "5"(p_c_thread[cindex + 5]),
                                   "6"(p_c_thread[cindex + 6]),
                                   "7"(p_c_thread[cindex + 7]));
                }
            }
#endif
        }
    }
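In Run_asm above, the per-thread A fragment is stored K x M (transposed), so whichever branch is compiled, the contraction being performed is C[i][j] += A[k][i] * B[k][j]. A plain-loop sketch of that contraction, stripped of the matrix-descriptor machinery (the fixed-size array shapes and the KPT/MPT/NPT names are assumptions for illustration only):

    // Sketch: per-thread GEMM with a transposed A operand.
    template <unsigned KPT, unsigned MPT, unsigned NPT>
    void thread_gemm_trans_a(const float (&a)[KPT][MPT], // A is K x M (transposed)
                             const float (&b)[KPT][NPT], // B is K x N
                             float (&c)[MPT][NPT])       // C is M x N
    {
        for(unsigned k = 0; k < KPT; ++k)
            for(unsigned i = 0; i < MPT; ++i)
                for(unsigned j = 0; j < NPT; ++j)
                    c[i][j] += a[k][i] * b[k][j];
    }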
    template <class FloatA, class FloatB, class FloatC, class Accumulator>
    __device__ void Run_RegisterDoubleBuffer(FloatA* const p_a_block,
                                             FloatB* const p_b_block,
...
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp
View file @ f35c64eb
...
@@ -6,7 +6,7 @@
 #include "blockwise_2d_tensor_op.hip.hpp"
 #include "threadwise_nd_tensor_op.hip.hpp"
 #include "threadwise_4d_tensor_op.hip.hpp"
-#include "blockwise_gemm.hip.hpp"
+#include "blockwise_batched_gemm.hip.hpp"
 template <unsigned GridSize,
           unsigned BlockSize,
...
@@ -211,8 +211,6 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
     {
 #if 0
         blockwise_batch_gemm.Run
-#elif 0
-        blockwise_batch_gemm.Run_v2
 #elif 1
         blockwise_batch_gemm.Run_v3
 #endif
...
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
View file @ f35c64eb
...
@@ -236,15 +236,19 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
         for(unsigned x = 0; x < X; ++x)
         {
             auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
-#if 1
+#if 0
             blockwise_gemm.Run
-#else
+#elif 1
+            blockwise_gemm.Run_asm
+#elif 0
+            blockwise_gemm.Run_v2
+#elif 0
             blockwise_gemm.Run_RegisterDoubleBuffer
 #endif
             (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
              p_in_block + y * Wi + x,
              p_out_thread,
              f_accum);
         }
     }
 }
...
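A note on the dispatch pattern in the hunk above: only the member-function name sits inside the #if/#elif chain; the argument list that follows is written once and completes whichever call was selected, which works because Run, Run_asm, Run_v2 and Run_RegisterDoubleBuffer share the same parameters here. A stripped-down sketch of the same idiom (the Foo/bar names and member functions are placeholders, not from the commit):

    // Sketch: compile-time selection among same-signature member functions
    // with a single shared argument list.
    struct Foo
    {
        void run_plain(int x) const { (void)x; }
        void run_asm_like(int x) const { (void)x; }
    };

    void bar(const Foo& foo, int x)
    {
    #if 0
        foo.run_plain
    #else
        foo.run_asm_like
    #endif
            (x); // the call is completed here, whichever name was chosen above
    }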
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
View file @ f35c64eb
...
@@ -34,7 +34,11 @@ template <unsigned GridSize,
           unsigned WeiBlockCopyThreadPerDim1,
           unsigned InBlockCopyDataPerRead,
           unsigned WeiBlockCopyDataPerRead>
-__global__ void gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
+__global__ void
+#if 0
+__launch_bounds__(256,2)
+#endif
+gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
     const Float* const __restrict__ p_in_global,
     const Float* const __restrict__ p_wei_global,
     Float* const __restrict__ p_out_global)
...
@@ -280,15 +284,15 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_b
         for(unsigned x = 0; x < X; ++x)
         {
             auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
-#if 1
+#if 0
             blockwise_gemm.Run
 #else
             blockwise_gemm.Run_RegisterDoubleBuffer
 #endif
             (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
              p_in_block_now + y * Wi + x,
              p_out_thread,
              f_accum);
         }
     }
...
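The signature hunk above adds a __launch_bounds__(256,2) qualifier but leaves it disabled behind "#if 0". In HIP/CUDA this hint tells the compiler to assume at most 256 threads per block and to aim for at least 2 resident blocks per compute unit, which constrains per-thread register usage in exchange for occupancy. A minimal sketch of the enabled form, with a placeholder kernel name and parameters (not from the commit):

    // Sketch: attaching the occupancy hint directly to a kernel definition.
    // 256 = maximum threads per block, 2 = desired minimum blocks per compute unit.
    __global__ void __launch_bounds__(256, 2) some_kernel(const float* __restrict__ in,
                                                          float* __restrict__ out)
    {
        // kernel body elided
    }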