yangql / composable_kernel-1

Commit 913afaeb, authored Jan 16, 2019 by Chao Liu
Parent: e7b8705b

    adding implicit gemm

Showing 6 changed files with 328 additions and 139 deletions
driver/conv.cu                                                 +57  -35
driver/device_implicit_gemm_convolution.cuh                    +12  -12
src/include/ConstantMatrixDescriptor.cuh                       +11   -9
src/include/blockwise_tensor_op.cuh                            +14   -0
src/include/gemm.cuh                                          +217  -74
src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh   +17   -9
driver/conv.cu

@@ -16,21 +16,7 @@ struct GeneratorTensor_1
    template <class... Is>
    double operator()(Is... is)
    {
-#if 0
-        return double(std::rand()) / double(RAND_MAX);
-#elif 1
        return 1;
-#elif 0
-        std::initializer_list<std::size_t> ls = {static_cast<std::size_t>(is)...};
-        return std::accumulate(ls.begin(), ls.end(), std::size_t(0));
-#else
-        assert(sizeof...(Is) > 0);
-        std::initializer_list<std::size_t> ids = {static_cast<std::size_t>(is)...};
-        std::vector<std::size_t> lens(sizeof...(Is), 100);
-        std::vector<std::size_t> strides(sizeof...(Is), 1);
-        std::partial_sum(lens.rbegin(), lens.rbegin() + (sizeof...(Is) - 1), strides.rbegin() + 1);
-        return std::inner_product(ids.begin(), ids.end(), strides.begin(), std::size_t(0)) + 1;
-#endif
    }
};
@@ -46,6 +32,25 @@ struct GeneratorTensor_2
    }
};

+struct GeneratorTensor_3
+{
+    template <class... Is>
+    double operator()(Is... is)
+    {
+#if 0
+        std::initializer_list<std::size_t> ls = {static_cast<std::size_t>(is)...};
+        return std::accumulate(ls.begin(), ls.end(), std::size_t(0));
+#elif 1
+        assert(sizeof...(Is) > 0);
+        std::initializer_list<std::size_t> ids = {static_cast<std::size_t>(is)...};
+        std::vector<std::size_t> lens(sizeof...(Is), 100);
+        std::vector<std::size_t> strides(sizeof...(Is), 1);
+        std::partial_sum(lens.rbegin(), lens.rbegin() + (sizeof...(Is) - 1), strides.rbegin() + 1);
+        return std::inner_product(ids.begin(), ids.end(), strides.begin(), std::size_t(0)) + 1;
+#endif
+    }
+};
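GeneratorTensor_3 derives each element's value from its multi-index: every dimension is assumed to have length 100, per-dimension weights are built with std::partial_sum (running sums, so the last dimension gets weight 1, the next 100, then 200, and so on), and the result is the inner product of the index with those weights, plus one. A minimal host-only C++ sketch of the same formula; the free-function form and the 3-D example index are illustrative assumptions, not part of the commit:

#include <cassert>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

// Value GeneratorTensor_3 would assign to a given multi-index, assuming every
// dimension has the hard-coded length of 100.
double generator_tensor_3_value(const std::vector<std::size_t>& ids)
{
    assert(!ids.empty());
    std::vector<std::size_t> lens(ids.size(), 100);
    std::vector<std::size_t> strides(ids.size(), 1);
    // running sums of lens, written in reverse: strides end up as {..., 200, 100, 1}
    std::partial_sum(lens.rbegin(), lens.rbegin() + (ids.size() - 1), strides.rbegin() + 1);
    return std::inner_product(ids.begin(), ids.end(), strides.begin(), std::size_t(0)) + 1;
}

int main()
{
    // index (1, 2, 3): 1*200 + 2*100 + 3*1 + 1 = 404
    std::cout << generator_tensor_3_value({1, 2, 3}) << '\n';
}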
// this is ugly, only for 4d
template <class TConstTensorDesc>
void ostream_ConstantTensorDescriptor(TConstTensorDesc, std::ostream& os = std::cout)
@@ -338,7 +343,7 @@ int main()
    constexpr unsigned K = 1;
    constexpr unsigned S = 3;
    constexpr unsigned R = 3;
-#elif 0
+#elif 1
    constexpr unsigned N  = 1;
    constexpr unsigned C  = 1;
    constexpr unsigned HI = 34;
@@ -347,21 +352,21 @@ int main()
    constexpr unsigned S = 3;
    constexpr unsigned R = 3;
#elif 1
    constexpr unsigned N  = 64;
    constexpr unsigned C  = 256;
    constexpr unsigned HI = 34;
    constexpr unsigned WI = 34;
    constexpr unsigned K  = 64;
    constexpr unsigned S  = 3;
    constexpr unsigned R  = 3;
#elif 0
    constexpr unsigned N  = 64;
    constexpr unsigned C  = 64;
    constexpr unsigned HI = 56;
    constexpr unsigned WI = 56;
    constexpr unsigned K  = 64;
    constexpr unsigned S  = 3;
    constexpr unsigned R  = 3;
#elif 0
    constexpr unsigned N  = 64;
    constexpr unsigned C  = 64;
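For a stride-1, no-padding convolution (what the default output-descriptor helper appears to assume), these sizes fix both the output shape and the dimensions of the GEMM that the implicit-GEMM formulation performs. A small sketch of that arithmetic for the N=64, C=256, HI=WI=34, K=64, S=R=3 branch enabled above; the im2col-style GEMM mapping is stated here as a standard assumption, not something spelled out in this commit:

#include <cstdio>

int main()
{
    // problem size enabled by the "#elif 1" branch above
    const unsigned N = 64, C = 256, HI = 34, WI = 34, K = 64, S = 3, R = 3;

    // output spatial size for stride 1, no padding
    const unsigned HO = HI - S + 1; // 32
    const unsigned WO = WI - R + 1; // 32

    // implicit GEMM (im2col) view: Out[M x Ngemm] = W[M x Kgemm] * X[Kgemm x Ngemm]
    const unsigned long long M     = K;                   // output channels
    const unsigned long long Ngemm = 1ULL * N * HO * WO;  // output pixels
    const unsigned long long Kgemm = 1ULL * C * S * R;    // reduction length

    std::printf("output: %u x %u x %u x %u\n", N, K, HO, WO);
    std::printf("GEMM:   M=%llu N=%llu K=%llu\n", M, Ngemm, Kgemm);
}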
@@ -374,34 +379,51 @@ int main()
    auto in_nchw_desc  = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{});
    auto wei_kcsr_desc = make_ConstantTensorDescriptor(Sequence<K, C, S, R>{});
-    auto wei_srck_desc = make_ConstantTensorDescriptor(Sequence<S, R, C, K>{});
    auto out_nkhw_desc =
        get_convolution_output_default_4d_tensor_descriptor(in_nchw_desc, wei_kcsr_desc);

    ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
    ostream_ConstantTensorDescriptor(wei_kcsr_desc, std::cout << "wei_kcsr_desc: ");
    ostream_ConstantTensorDescriptor(wei_srck_desc, std::cout << "wei_srck_desc: ");
    ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");

    Tensor<float> in_nchw(make_TensorDescriptor(in_nchw_desc));
    Tensor<float> wei_kcsr(make_TensorDescriptor(wei_kcsr_desc));
-    Tensor<float> wei_srck(make_TensorDescriptor(wei_srck_desc));
    Tensor<float> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
    Tensor<float> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc));

    std::size_t num_thread = std::thread::hardware_concurrency();

#if 0
    in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
    wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
    wei_srck.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 0
    in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
    wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
    wei_srck.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 0
    in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
    wei_kcsr.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
+#elif 1
+    in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
+    wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#endif

+#if 1
+    auto wei_srck_desc = make_ConstantTensorDescriptor(Sequence<S, R, C, K>{});
+    Tensor<float> wei_srck(make_TensorDescriptor(wei_srck_desc));
+
+    auto f_reorder_kcsr2srck = [&](auto k, auto c, auto s, auto r) {
+        wei_srck(s, r, c, k) = wei_kcsr(k, c, s, r);
+    };
+
+    make_ParallelTensorFunctor(f_reorder_kcsr2srck, K, C, S, R)(num_thread);
+#endif
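The f_reorder_kcsr2srck lambda transposes the weights from [K, C, S, R] to [S, R, C, K] so that, for a fixed filter tap (s, r), the C x K slice is contiguous for the batched GEMM. A minimal serial host sketch of the same reorder on flat row-major buffers; the function name and flat-array layout are illustrative assumptions rather than the repo's Tensor API:

#include <cstddef>
#include <vector>

// Reorder weights from KCSR to SRCK layout, both stored row-major in flat vectors.
// This mirrors what f_reorder_kcsr2srck does element by element through the Tensor wrapper.
void reorder_kcsr_to_srck(const std::vector<float>& wei_kcsr,
                          std::vector<float>& wei_srck,
                          std::size_t K, std::size_t C, std::size_t S, std::size_t R)
{
    for(std::size_t k = 0; k < K; ++k)
        for(std::size_t c = 0; c < C; ++c)
            for(std::size_t s = 0; s < S; ++s)
                for(std::size_t r = 0; r < R; ++r)
                {
                    const std::size_t src = ((k * C + c) * S + s) * R + r; // [K][C][S][R]
                    const std::size_t dst = ((s * R + r) * C + c) * K + k; // [S][R][C][K]
                    wei_srck[dst] = wei_kcsr[src];
                }
}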
-    for(int i = 0; i < 40; ++i)
#if 0
    wei_srck.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
    out_nkhw_device.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#endif

+    for(int i = 0; i < 1; ++i)
    {
#if 0
        device_direct_convolution_1(in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);

@@ -428,7 +450,7 @@ int main()
    check_error(out_nkhw_host, out_nkhw_device);
#endif

-#if 0
+#if 1
    LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
    LogRange(std::cout << "wei_kcsr: ", wei_kcsr.mData, ",") << std::endl;
    LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
driver/device_implicit_gemm_convolution.cuh

#pragma once
-#include "gridwise_implicit_gemm_convolution_nchw_kcsr.cuh"
+//#include "gridwise_implicit_gemm_convolution_nchw_kcsr.cuh"
+#include "gridwise_implicit_gemm_convolution_nchw_srck.cuh"

template <class T, class InDesc, class WeiDesc, class OutDesc>

@@ -26,20 +26,20 @@ void device_implicit_gemm_convolution(
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

-#if 0
-    constexpr unsigned NPerBlock  = 2;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 4;
+#if 1
+    constexpr unsigned NPerBlock  = 1;
+    constexpr unsigned KPerBlock  = 1;
+    constexpr unsigned CPerBlock  = 1;
    constexpr unsigned HoPerBlock = 2;
    constexpr unsigned WoPerBlock = 32;

-    constexpr unsigned NPerThread  = 2;
-    constexpr unsigned KPerThread  = 8;
-    constexpr unsigned CPerThread  = 2;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 4;
+    constexpr unsigned NPerThread  = 1;
+    constexpr unsigned KPerThread  = 1;
+    constexpr unsigned CPerThread  = 1;
+    constexpr unsigned HoPerThread = 2;
+    constexpr unsigned WoPerThread = 2;

-    constexpr unsigned BlockSize = 256;
+    constexpr unsigned BlockSize = 16;
#elif 1
    constexpr unsigned NPerBlock = 2;
    constexpr unsigned KPerBlock = 32;

@@ -50,7 +50,7 @@ void device_implicit_gemm_convolution(
    constexpr unsigned NPerThread = 2;
    constexpr unsigned KPerThread = 4;
    constexpr unsigned CPerThread = 2;
-    constexpr unsigned HoPerThread = 1;
+    constexpr unsigned HoPerThread = 2;
    constexpr unsigned WoPerThread = 2;

    constexpr unsigned BlockSize = 128;
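The per-block tile sizes above determine how many thread blocks the driver would need to cover the output. The launch itself is not visible in this hunk, so the arithmetic below is an illustrative assumption: one block per NPerBlock x KPerBlock x HoPerBlock x WoPerBlock output tile, each with BlockSize threads.

#include <cstdio>

int main()
{
    const unsigned N = 64, K = 64, Ho = 32, Wo = 32; // assumed problem size
    const unsigned NPerBlock = 2, KPerBlock = 32, HoPerBlock = 2, WoPerBlock = 32;
    const unsigned BlockSize = 128;

    // one thread block per output tile (assumes the sizes divide evenly)
    const unsigned grid_size = (N / NPerBlock) * (K / KPerBlock) *
                               (Ho / HoPerBlock) * (Wo / WoPerBlock);

    std::printf("grid: %u blocks of %u threads\n", grid_size, BlockSize);
}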
src/include/ConstantMatrixDescriptor.cuh

#pragma once
#include "common.cuh"

-template <unsigned NRow, unsigned NCol, unsigned RowStride>
+template <unsigned NRow_, unsigned NCol_, unsigned RowStride_>
struct ConstantMatrixDescriptor
{
    __host__ __device__ ConstantMatrixDescriptor()
    {
-        static_assert(NCol <= RowStride, "wrong! NCol > RowStride!");
+        static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
    }

-    __host__ __device__ constexpr unsigned GetNumberOfRow() const { return NRow; }
+    __host__ __device__ constexpr unsigned NRow() const { return NRow_; }

-    __host__ __device__ constexpr unsigned GetNumberOfColumn() const { return NCol; }
+    __host__ __device__ constexpr unsigned NCol() const { return NCol_; }

-    __host__ __device__ constexpr unsigned GetRowStride() const { return RowStride; }
+    __host__ __device__ constexpr unsigned RowStride() const { return RowStride_; }

+    __host__ __device__ constexpr auto GetLengths() const { return Sequence<NRow_, NCol_>{}; }

-    __host__ __device__ constexpr unsigned GetElementSize() const { return NRow * NCol; }
+    __host__ __device__ constexpr unsigned GetElementSize() const { return NRow_ * NCol_; }

-    __host__ __device__ constexpr unsigned GetElementSpace() const { return NRow * RowStride; }
+    __host__ __device__ constexpr unsigned GetElementSpace() const { return NRow_ * RowStride_; }

    __host__ __device__ unsigned Get1dIndex(unsigned irow, unsigned icol) const
    {
-        return irow * RowStride + icol;
+        return irow * RowStride_ + icol;
    }

    template <unsigned SubNRow, unsigned SubNCol>
    __host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>, Number<SubNCol>) const
    {
-        return ConstantMatrixDescriptor<SubNRow, SubNCol, RowStride>{};
+        return ConstantMatrixDescriptor<SubNRow, SubNCol, RowStride_>{};
    }
};
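After the rename, the descriptor keeps its shape entirely in template parameters and exposes it through NRow()/NCol()/RowStride(); Get1dIndex flattens a (row, col) pair as row * RowStride_ + col, and GetElementSpace() reports NRow_ * RowStride_ allocated elements versus NRow_ * NCol_ logical ones. A host-only sketch of that contract, with the CUDA qualifiers and the repo's Sequence/Number helpers omitted:

#include <cassert>
#include <cstdio>

// Host-only sketch of the descriptor idea after this commit's rename: the matrix
// shape lives in template parameters (NRow_, NCol_, RowStride_) and the accessors
// simply return them.
template <unsigned NRow_, unsigned NCol_, unsigned RowStride_>
struct MatrixDescriptorSketch
{
    static_assert(NCol_ <= RowStride_, "NCol must not exceed RowStride");

    constexpr unsigned NRow() const { return NRow_; }
    constexpr unsigned NCol() const { return NCol_; }
    constexpr unsigned RowStride() const { return RowStride_; }
    constexpr unsigned GetElementSize() const { return NRow_ * NCol_; }       // logical elements
    constexpr unsigned GetElementSpace() const { return NRow_ * RowStride_; } // allocated span
    unsigned Get1dIndex(unsigned irow, unsigned icol) const { return irow * RowStride_ + icol; }
};

int main()
{
    constexpr MatrixDescriptorSketch<4, 3, 8> desc{}; // 4x3 tile inside rows of stride 8
    assert(desc.Get1dIndex(2, 1) == 17);
    std::printf("size=%u space=%u\n", desc.GetElementSize(), desc.GetElementSpace());
}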
src/include/blockwise_tensor_op.cuh

@@ -135,6 +135,20 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
    const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);

+#if 1
+    printf("did %u %u %u %u, did_IR %u %u %u %u, index %u %u\n",
+           did[0],
+           did[1],
+           did[2],
+           did[3],
+           did[IR0],
+           did[IR1],
+           did[IR2],
+           did[IR3],
+           aindex,
+           bindex);
+#endif

    f(p_src[aindex], p_dst[bindex]);
}
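The new print traces how the destination index is formed: the 4-D coordinate did[] is permuted through the compile-time mapping IR0..IR3 before the destination descriptor flattens it against its strides. A host-only sketch of that computation with explicit strides; the permutation, shapes, and helper name are example assumptions:

#include <array>
#include <cstdio>

// Permute a 4-D coordinate by a compile-time reorder (IR0..IR3), then flatten it
// with the destination strides, mirroring dst_desc.Get1dIndex(did[IR0], ..., did[IR3]).
template <unsigned IR0, unsigned IR1, unsigned IR2, unsigned IR3>
unsigned reordered_1d_index(const std::array<unsigned, 4>& did,
                            const std::array<unsigned, 4>& dst_strides)
{
    const std::array<unsigned, 4> permuted{did[IR0], did[IR1], did[IR2], did[IR3]};
    return permuted[0] * dst_strides[0] + permuted[1] * dst_strides[1] +
           permuted[2] * dst_strides[2] + permuted[3] * dst_strides[3];
}

int main()
{
    const std::array<unsigned, 4> did{1, 2, 0, 3};
    const std::array<unsigned, 4> dst_strides{24, 12, 4, 1}; // e.g. a 2x2x3x4 tensor
    std::printf("bindex = %u\n", reordered_1d_index<3, 2, 0, 1>(did, dst_strides));
}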
src/include/gemm.cuh

#pragma once

-template <class ThreadMatrixA,
-          class ThreadMatrixB,
-          class ThreadMatrixC,
+template <class Float, class SrcMatrix, class DstMatrix, unsigned NRow, unsigned NCol>
+__device__ void threadwise_matrix_copy(
+    SrcMatrix, Float* const p_src, DstMatrix, Float* p_dst, Sequence<NRow, NCol>)
+{
+    const auto src_mtx = SrcMatrix{}; // constexpr doesn't compile
+    const auto dst_mtx = DstMatrix{}; // constexpr doesn't compile
+
+    for(unsigned i = 0; i < NRow; ++i)
+    {
+        for(unsigned j = 0; j < NCol; ++j)
+        {
+            const unsigned src_index = src_mtx.Get1dIndex(i, j);
+            const unsigned dst_index = dst_mtx.Get1dIndex(i, j);
+
+            p_dst[dst_index] = p_src[src_index];
+        }
+    }
+}
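threadwise_matrix_copy moves an NRow x NCol tile element by element between two buffers whose row strides come from their matrix descriptors. A host-only sketch of the same access pattern using plain strides instead of descriptor types (names here are illustrative):

#include <cstdio>
#include <vector>

// Copy an nrow x ncol tile from a source buffer with row stride src_stride to a
// destination buffer with row stride dst_stride (both row-major).
void tile_copy(const float* p_src, unsigned src_stride,
               float* p_dst, unsigned dst_stride,
               unsigned nrow, unsigned ncol)
{
    for(unsigned i = 0; i < nrow; ++i)
        for(unsigned j = 0; j < ncol; ++j)
            p_dst[i * dst_stride + j] = p_src[i * src_stride + j];
}

int main()
{
    std::vector<float> src(4 * 8, 1.0f), dst(4 * 4, 0.0f);
    tile_copy(src.data(), 8, dst.data(), 4, 4, 4); // copy a 4x4 tile
    std::printf("%f\n", dst[5]);
}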
+template <class MatrixA,
+          class MatrixB,
+          class MatrixC,
+          bool TransA,
+          bool TransB,
+          bool TransC,

@@ -10,18 +29,47 @@ template <class ThreadMatrixA,
          class FloatB,
          class FloatC,
          class Accumulator>
-__device__ void threadwise_gemm(ThreadMatrixA,
+__device__ void threadwise_gemm(MatrixA,
+                                Constant<bool, TransA>,
                                 FloatA* const p_a_thread,
-                                ThreadMatrixB,
+                                MatrixB,
+                                Constant<bool, TransB>,
                                 FloatB* const p_b_thread,
-                                ThreadMatrixC,
+                                MatrixC,
+                                Constant<bool, TransC>,
                                 FloatC* p_c_thread,
-                                Accumulator)
+                                Accumulator f_accum)
{
    // do something
    if(TransA && (!TransB) && (!TransC))
    {
        const auto a_mtx = MatrixA{}; // constexpr doesn't compile
        const auto b_mtx = MatrixB{}; // constexpr doesn't compile
        const auto c_mtx = MatrixC{}; // constexpr doesn't compile

        constexpr unsigned M = c_mtx.NRow();
        constexpr unsigned N = c_mtx.NCol();
        constexpr unsigned K = a_mtx.NRow(); // A is transposed

        for(unsigned i = 0; i < M; ++i)
        {
            for(unsigned j = 0; j < N; ++j)
            {
                for(unsigned k = 0; k < K; ++k)
                {
                    const unsigned aindex = a_mtx.Get1dIndex(k, i); // A is transposed
                    const unsigned bindex = b_mtx.Get1dIndex(k, j);
                    const unsigned cindex = c_mtx.Get1dIndex(i, j);

                    f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
                }
            }
        }
    }
    else
    {
        // not implemented
        assert(false);
    }
}
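In the TransA && !TransB && !TransC case, A is read as K x M (transposed), B as K x N, and the products accumulate into an M x N C through the caller-supplied accumulator. A host-only reference of that loop nest on plain row-major arrays:

#include <cstdio>
#include <vector>

// C[M x N] += A^T * B, where A is stored K x M and B is stored K x N, both
// row-major; f_accum mirrors the Accumulator parameter (typically c += ab).
template <class FAccum>
void threadwise_gemm_ref(const float* a, const float* b, float* c,
                         unsigned M, unsigned N, unsigned K, FAccum f_accum)
{
    for(unsigned i = 0; i < M; ++i)
        for(unsigned j = 0; j < N; ++j)
            for(unsigned k = 0; k < K; ++k)
                f_accum(c[i * N + j], a[k * M + i] * b[k * N + j]);
}

int main()
{
    const unsigned M = 2, N = 2, K = 3;
    std::vector<float> a(K * M, 1.0f), b(K * N, 2.0f), c(M * N, 0.0f);
    threadwise_gemm_ref(a.data(), b.data(), c.data(), M, N, K,
                        [](float& acc, float ab) { acc += ab; });
    std::printf("c[0] = %f\n", c[0]); // 3 * 1 * 2 = 6
}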
template <unsigned BlockSize,

@@ -36,8 +84,8 @@ template <unsigned BlockSize,
          unsigned ThreadMatrixStrideC,
          unsigned BatchSize,
          unsigned BatchPerThread,
-          unsigned KPerLoop,
-          class Accumulator>
+          unsigned KPerThreadLoop,
+          bool DistributeThreadAlongColumnFirst>
struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
{
    unsigned mMyThreadOffsetA = 0;
@@ -52,82 +100,177 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
    __device__ blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c()
    {
        static_assert(ThreadMatrixStrideC > 0, "wrong! ThreadMatrixStrideC == 0!");

        const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
        const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile

        const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id());

        mMyThreadOffsetA = c_thread_mtx_index.batch_begin * a_block_mtx.GetElementSpace() +
                           ((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0)
                                      : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin));

        mMyThreadOffsetB = c_thread_mtx_index.batch_begin * b_block_mtx.GetElementSpace() +
                           ((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin)
                                      : b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0));
    }

-#if 0
-        constexpr auto a_block_desc = BlockMatrixA{};
-        constexpr auto b_block_desc = BlockMatrixB{};
-
-        constexpr unsigned a_thread_row = (!TransA) ? MPerThread : KPerThread;
-        constexpr unsigned a_thread_col = (!TransA) ? KPerThread : MPerThread;
-        constexpr unsigned b_thread_row = (!TransB) ? KPerThread : NPerThread;
-        constexpr unsigned b_thread_col = (!TransB) ? NPerThread : KPerThread;
-
-        constexpr auto a_thread_desc = ConstantMatrixDescriptor<a_thread_row, a_thread_col>{};
-        constexpr auto b_thread_desc = ConstantMatrixDescriptor<b_thread_row, b_thread_col>{};
-        constexpr auto c_thread_desc = ConstantMatrixDescriptor<MPerThread, NPerThread>{};
-
-        constexpr unsigned m_block = (!TransA) ? a_block_desc.NRow() : a_block_desc.NCol();
-        constexpr unsigned n_block = (!TransB) ? b_block_desc.NCol() : b_block_desc.NRow();
-
-        constexpr unsigned m_thread = (!TransA) ? a_thread_desc.NRow() : a_thread_desc.NCol();
-        constexpr unsigned n_thread = (!TransB) ? b_thread_desc.NCol() : b_thread_desc.NRow();
-
-        constexpr unsigned num_threads_per_row = (m_block + m_thread - 1) / m_thread;
-        constexpr unsigned num_threads_per_col = (n_block + n_thread - 1) / n_thread;
-        constexpr unsigned num_threads_per_batch = num_threads_per_row * num_threads_per_col;
-
-        static_assert(BlockSize >= ((BatchSize + BatchPerThread - 1) / BatchPerThread) *
-                          num_threads_per_batch,
-                      "not enough thread!");
-
-        const auto mtx_c_idnex = CalculateThreadMatrixCIndex(get_thread_local_id());
-
-        // mMyThreadOffsetA = xxx;
-        // mMyThreadoffSetB = xxx;
-#else
-        mMyThreadOffsetA = 0;
-        mMyThreadOffsetB = 0;
-#endif
-    }

    __device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
    {
-#if 0
-        constexpr auto a_block = BlockMatrixA{};
-        constexpr auto b_block = BlockMatrixB{};
-        constexpr auto c_block = BlockMatrixC{};
-
-        constexpr auto a_thread = ThreadMatrixA{};
-        constexpr auto b_thread = ThreadMatrixB{};
-        constexpr auto c_thread = ThreadMatrixC{};
-
-        constexpr unsigned m_block = (!TransA) ? a_block.NRow() : a_block.NCol();
-        constexpr unsigned n_block = (!TransB) ? b_block.NCol() : b_block.NRow();
-
-        constexpr unsigned m_thread = (!TransA) ? a_thread.NRow() : a_thread.NCol();
-        constexpr unsigned n_thread = (!TransB) ? b_thread.NCol() : b_thread.NRow();
-
-        constexpr unsigned num_threads_per_row = (m_block + m_thread - 1) / m_thread;
-        constexpr unsigned num_threads_per_col = (n_block + n_thread - 1) / n_thread;
-        constexpr unsigned num_threads_per_batch = num_threads_per_row * num_threads_per_col;
-
-        // this is wrong, need fix
-        const unsigned batch_begin = thread_id / (num_threads_per_batch)*BatchPerThread;
-        const unsigned tmp = thread_id - batch_id * (num_threads_per_row * num_threads_per_col);
-
-        const unsigned thread_matrix_row_id = tmp / num_threads_per_row;
-        const unsigned thread_matrix_col_id = tmp - thread_matrix_row_id * num_threads_per_row;
-
-        return MatrixIndex{
-            batch_begin, thread_matrix_row_id * m_thread, thread_matrix_col_id * n_thread};
-#else
-        return MatrixIndex{0, 0, 0};
-#endif
        if(TransA && (!TransB) && (!TransC))
        {
            const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
            const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile

            static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                          "wrong! k dimension not consistent!");

            constexpr unsigned MPerBlock = a_block_mtx.NCol();
            constexpr unsigned NPerBlock = b_block_mtx.NCol();

            const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile

            // divide thread work
            constexpr unsigned MPerThread = c_thread_mtx.NRow();
            constexpr unsigned NPerThread = c_thread_mtx.NCol();

            static_assert(BatchSize % BatchPerThread == 0, "BatchSize % BatchPerThread != 0");
            static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0");
            static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0");

            constexpr unsigned BThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
            constexpr unsigned MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
            constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;

            static_assert(BlockSize == BThreadWork * MThreadWork * NThreadWork,
                          "wrong! wrong BlockSize");

            // printf("%u %u, %u %u\n", get_block_1d_id(), get_thread_local_1d_id(), MThreadWork,
            // NThreadWork);

            if(DistributeThreadAlongColumnFirst)
            {
                // num of operations can be reduced
                const unsigned b_work_id = thread_id / (MThreadWork * NThreadWork);
                unsigned itmp            = thread_id - b_work_id * (MThreadWork * NThreadWork);

                const unsigned m_work_id = itmp / NThreadWork;
                const unsigned n_work_id = itmp - m_work_id * NThreadWork;

                return MatrixIndex{
                    b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread};
            }
            else
            {
                // not implemented
                assert(false);
            }
        }
        else
        {
            // not implemented
            assert(false);
        }
    }
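CalculateThreadMatrixCIndex carves the block's work into BThreadWork x MThreadWork x NThreadWork pieces and maps a linear thread id to a (batch, row, column) origin with the column index varying fastest. A host-only sketch of that mapping; the struct name and the concrete sizes in main are example assumptions:

#include <cstdio>

struct MatrixIndexSketch
{
    unsigned batch_begin, row_begin, col_begin;
};

// Map a linear thread id to its per-thread tile origin, mirroring the
// DistributeThreadAlongColumnFirst branch above.
MatrixIndexSketch calculate_thread_matrix_c_index(unsigned thread_id,
                                                  unsigned MThreadWork, unsigned NThreadWork,
                                                  unsigned BatchPerThread,
                                                  unsigned MPerThread, unsigned NPerThread)
{
    const unsigned b_work_id = thread_id / (MThreadWork * NThreadWork);
    unsigned itmp            = thread_id - b_work_id * (MThreadWork * NThreadWork);
    const unsigned m_work_id = itmp / NThreadWork;
    const unsigned n_work_id = itmp - m_work_id * NThreadWork;

    return {b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread};
}

int main()
{
    // e.g. BlockSize = 2 * 4 * 2 = 16 threads; 4x2 thread tiles of 2x2 elements each
    for(unsigned tid = 0; tid < 16; ++tid)
    {
        const auto idx = calculate_thread_matrix_c_index(tid, 4, 2, 1, 2, 2);
        std::printf("tid %2u -> batch %u, row %u, col %u\n",
                    tid, idx.batch_begin, idx.row_begin, idx.col_begin);
    }
}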
-    template <class FloatA, class FloatB, class FloatC>
-    __device__ void run(FloatA* const p_a_block, FloatB* const p_b_block, FloatC* p_c_thread) const
+    template <class FloatA, class FloatB, class FloatC, class Accumulator>
+    __device__ void run(FloatA* const p_a_block,
+                        FloatB* const p_b_block,
+                        FloatC* p_c_thread,
+                        Accumulator f_accum) const
    {
        // do something
        if(TransA && (!TransB) && (!TransC))
        {
            constexpr auto True  = Constant<bool, true>{};
            constexpr auto False = Constant<bool, false>{};

            const auto a_block_mtx  = BlockMatrixA{};  // constexpr doesn't compile
            const auto b_block_mtx  = BlockMatrixB{};  // constexpr doesn't compile
            const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile

            constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed

            constexpr unsigned MPerThread = c_thread_mtx.NRow();
            constexpr unsigned NPerThread = c_thread_mtx.NCol();

            // a is transposed, b is not
            const auto a_thread_mtx = make_ConstantMatrixDescriptor(
                Number<KPerThreadLoop>{}, Number<MPerThread>{}); // constexpr doesn't compile
            const auto b_thread_mtx = make_ConstantMatrixDescriptor(
                Number<KPerThreadLoop>{}, Number<NPerThread>{}); // constexpr doesn't compile

            FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
            FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

            // loop over k
            for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
            {
                // read first batch of a, b
                threadwise_matrix_copy(a_block_mtx,
                                       p_a_block + mMyThreadOffsetA + k_begin * a_block_mtx.RowStride(),
                                       a_thread_mtx, p_a_thread, a_thread_mtx.GetLengths());

                threadwise_matrix_copy(b_block_mtx,
                                       p_b_block + mMyThreadOffsetB + k_begin * b_block_mtx.RowStride(),
                                       b_thread_mtx, p_b_thread, b_thread_mtx.GetLengths());

                // loop over batch
                for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
                {
                    // do current batch of gemm
                    threadwise_gemm(a_thread_mtx, True, p_a_thread,
                                    b_thread_mtx, False, p_b_thread,
                                    c_thread_mtx, False, p_c_thread + ib * ThreadMatrixStrideC,
                                    f_accum);

                    // read next batch of a, b
                    if(BlockMatrixStrideA != 0)
                    {
                        threadwise_matrix_copy(a_block_mtx,
                                               p_a_block + mMyThreadOffsetA +
                                                   (ib + 1) * BlockMatrixStrideA +
                                                   +k_begin * a_block_mtx.RowStride(),
                                               a_thread_mtx, p_a_thread, a_thread_mtx.GetLengths());
                    }

                    if(BlockMatrixStrideB != 0)
                    {
                        threadwise_matrix_copy(b_block_mtx,
                                               p_b_block + mMyThreadOffsetB +
                                                   (ib + 1) * BlockMatrixStrideB +
                                                   k_begin * b_block_mtx.RowStride(),
                                               b_thread_mtx, p_b_thread, b_thread_mtx.GetLengths());
                    }
                }

                // do last batch of gemm
                threadwise_gemm(a_thread_mtx, True, p_a_thread,
                                b_thread_mtx, False, p_b_thread,
                                c_thread_mtx, False,
                                p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC,
                                f_accum);
            }
        }
    }
};
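run() walks the K dimension in KPerThreadLoop-sized steps: each step stages a K-slice of the block's A and B tiles into thread-private storage and feeds it to threadwise_gemm. A serial host sketch of that control flow with the accumulator fixed to +=; the batch loop, block strides, and descriptor plumbing are omitted, and the flat layouts are assumptions for illustration:

#include <vector>

// Serial sketch of run(): C_thread[M x N] accumulates A_block^T * B_block, with the
// K dimension processed KPerThreadLoop rows at a time through small staging tiles
// (standing in for the register arrays p_a_thread / p_b_thread).
void blockwise_gemm_sketch(const std::vector<float>& a_block, // KPerBlock x M, row-major
                           const std::vector<float>& b_block, // KPerBlock x N, row-major
                           std::vector<float>& c_thread,      // M x N, row-major
                           unsigned M, unsigned N, unsigned KPerBlock, unsigned KPerThreadLoop)
{
    std::vector<float> a_thread(KPerThreadLoop * M);
    std::vector<float> b_thread(KPerThreadLoop * N);

    for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
    {
        // stage the next K-slice of A and B (threadwise_matrix_copy in the kernel)
        for(unsigned k = 0; k < KPerThreadLoop; ++k)
            for(unsigned i = 0; i < M; ++i)
                a_thread[k * M + i] = a_block[(k_begin + k) * M + i];
        for(unsigned k = 0; k < KPerThreadLoop; ++k)
            for(unsigned j = 0; j < N; ++j)
                b_thread[k * N + j] = b_block[(k_begin + k) * N + j];

        // accumulate C += A_slice^T * B_slice (threadwise_gemm in the kernel)
        for(unsigned i = 0; i < M; ++i)
            for(unsigned j = 0; j < N; ++j)
                for(unsigned k = 0; k < KPerThreadLoop; ++k)
                    c_thread[i * N + j] += a_thread[k * M + i] * b_thread[k * N + j];
    }
}

int main()
{
    std::vector<float> a(8 * 4, 1.0f), b(8 * 2, 1.0f), c(4 * 2, 0.0f);
    blockwise_gemm_sketch(a, b, c, 4, 2, 8, 2);
    return c[0] == 8.0f ? 0 : 1;
}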
src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh

@@ -90,13 +90,12 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
    constexpr auto out_hkwn_thread_desc = make_ConstantTensorDescriptor(
        Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});

-#if 0
+#if 1
    if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc");
        print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
-        print_ConstantTensorDescriptor(wei_kcsr_block_desc, "wei_kcsr_block_desc");
        print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc");
        print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");

@@ -120,8 +119,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
    const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor(
        Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile

-    auto f_accum = [](auto& c, auto& ab) { c += ab; };
-
    const auto blockwise_batch_gemm =
        blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
                                                                   decltype(a_cxk_block_mtx_desc),

@@ -133,11 +130,11 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
                                                                   0,
                                                                   in_chwn_block_desc.GetStride(I1),
-                                                                  out_hkwn_thread_desc.GetStride(I1),
+                                                                  out_hkwn_thread_desc.GetStride(I0),
                                                                   HoPerBlock,
                                                                   HoPerThread,
                                                                   CPerThread,
-                                                                  decltype(f_accum)>{};
+                                                                  true>{};

    // LDS
    constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();

@@ -183,24 +180,29 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
        __syncthreads();

#if 1
        // a series of batched GEMM
        for(unsigned s = 0; s < S; ++s)
        {
            for(unsigned r = 0; r < R; ++r)
            {
+                auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
+
                blockwise_batch_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
                                         p_in_block + in_chwn_block_desc.Get1dIndex(0, 0, r, 0),
-                                         p_out_thread);
+                                         p_out_thread,
+                                         f_accum);
            }
        }
#endif
    }

    const auto matrix_c_index =
        blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());

    const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
-    const unsigned k_thread_data_begin  = matrix_c_index.col_begin;
-    const unsigned wo_thread_data_begin = matrix_c_index.row_begin / NPerThread;
+    const unsigned k_thread_data_begin  = matrix_c_index.row_begin;
+    const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread;

    // output: register to global mem,
    // convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]

@@ -216,4 +218,10 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
                                              wo_block_data_begin + wo_thread_data_begin),
                              out_hkwn_thread_desc.GetLengths(),
                              reorder_nkhw_from_hkwn);

+    // printf("%f %f %f %f\n", p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
+    // printf("%u %u, %u %u %u\n", get_block_1d_id(), get_thread_local_1d_id(),
+    // matrix_c_index.batch_begin, matrix_c_index.row_begin, matrix_c_index.col_begin); printf("%u
+    // %u, %u %u %u\n", get_block_1d_id(), get_thread_local_1d_id(), ho_thread_data_begin,
+    // k_thread_data_begin, wo_thread_data_begin);
}
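The kernel realizes the convolution as a series of batched GEMMs: for every filter tap (s, r), the S-R-C-K weight slice W[s, r, :, :] multiplies a correspondingly shifted window of the input block, and the partial products accumulate into the output tile. A serial host reference of that decomposition for a stride-1, no-padding convolution; the loop order and flat layouts are chosen for clarity here, not copied from the kernel:

#include <vector>

// Reference for the implicit-GEMM decomposition: an NCHW convolution with SRCK
// weights is a sum over filter taps (s, r) of per-tap reductions over C,
// Out[n][k][ho][wo] += sum_c wei[s][r][c][k] * in[n][c][ho + s][wo + r].
void implicit_gemm_conv_ref(const std::vector<float>& in,  // [N][C][HI][WI]
                            const std::vector<float>& wei, // [S][R][C][K]
                            std::vector<float>& out,       // [N][K][HO][WO], zero-initialised
                            unsigned N, unsigned C, unsigned HI, unsigned WI,
                            unsigned K, unsigned S, unsigned R)
{
    const unsigned HO = HI - S + 1, WO = WI - R + 1;

    for(unsigned s = 0; s < S; ++s)
        for(unsigned r = 0; r < R; ++r)
            // one small GEMM-like reduction per filter tap (s, r)
            for(unsigned k = 0; k < K; ++k)
                for(unsigned n = 0; n < N; ++n)
                    for(unsigned ho = 0; ho < HO; ++ho)
                        for(unsigned wo = 0; wo < WO; ++wo)
                        {
                            float acc = 0;
                            for(unsigned c = 0; c < C; ++c)
                            {
                                const float a = wei[((s * R + r) * C + c) * K + k];
                                const float b = in[((n * C + c) * HI + (ho + s)) * WI + (wo + r)];
                                acc += a * b;
                            }
                            out[((n * K + k) * HO + ho) * WO + wo] += acc;
                        }
}

int main()
{
    const unsigned N = 1, C = 2, HI = 4, WI = 4, K = 2, S = 3, R = 3;
    std::vector<float> in(N * C * HI * WI, 1.0f), wei(S * R * C * K, 1.0f);
    std::vector<float> out(N * K * (HI - S + 1) * (WI - R + 1), 0.0f);
    implicit_gemm_conv_ref(in, wei, out, N, C, HI, WI, K, S, R);
    return out[0] == float(C * S * R) ? 0 : 1; // every output = 2 * 3 * 3 = 18
}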