gaoqiong / composable_kernel · commit 0f6cb787

Commit 0f6cb787, authored Mar 02, 2020 by Chao Liu
Parent: 7d09790a

    double index buffer
Showing 3 changed files, with 378 additions and 52 deletions:

    composable_kernel/include/tensor_operation/gridwise_gemm.hpp              +362 -37
    driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp     +5  -4
    driver/src/conv_driver.cpp                                                 +11 -11
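In summary, the commit renames the original Run body to Run_single_slice_window and adds a Run_double_slice_window variant. In both, the single blockwise-copy object per matrix becomes a copy *type* plus two instances, the "double index buffer": each instance keeps its own source-slice-window index and always advances by the constant 2 * KPerBlock, so moving one window never depends on the state of the other. The host-side dry run below is a minimal sketch of the resulting load schedule, assuming K = 32 and KPerBlock = 8; Window and the printout are illustrative stand-ins, not kernel code (the kernel uses two BlockwiseGenericTensorSliceCopy_v4 instances and LDS ping-pong buffers).

    #include <cstdio>

    // Illustrative stand-in for a blockwise copy object: it owns its source
    // window offset along K and only ever moves by a constant step.
    struct Window
    {
        int k;
        void MoveSrcSliceWindow(int step) { k += step; }
    };

    int main()
    {
        constexpr int KPerBlock = 8, K = 32, Step = 2 * KPerBlock;

        Window copy_0{0};         // serves even K-slices: k = 0, 16, ...
        Window copy_1{KPerBlock}; // serves odd K-slices:  k = 8, 24, ...

        std::printf("preload: copy_0 loads k=%d\n", copy_0.k);
        copy_0.MoveSrcSliceWindow(Step);

        for(int k = 0; k + Step < K; k += Step)    // main body: 2 slices/trip
            for(int iloop = 0; iloop < 2; ++iloop)
            {
                Window& next = (iloop == 0) ? copy_1 : copy_0;
                std::printf("main:    prefetch k=%d\n", next.k);
                next.MoveSrcSliceWindow(Step);     // each window steps alone
            }

        if(K % Step == 0) // tail: two slices left; copy_1 fetches the last
            std::printf("tail:    copy_1 loads k=%d\n", copy_1.k);
    }

The printed offsets are 0, 8, 16, 24: every K-slice exactly once, with each MoveSrcSliceWindow depending only on its own object.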
composable_kernel/include/tensor_operation/gridwise_gemm.hpp
@@ -77,10 +77,10 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
         return 2 * (a_block_space + b_block_space) * sizeof(Float);
     }
 
-    __device__ void Run(const Float* __restrict__ p_a_global,
-                        const Float* __restrict__ p_b_global,
-                        Float* __restrict__ p_c_global,
-                        Float* __restrict__ p_shared_block) const
+    __device__ void Run_single_slice_window(const Float* __restrict__ p_a_global,
+                                            const Float* __restrict__ p_b_global,
+                                            Float* __restrict__ p_c_global,
+                                            Float* __restrict__ p_shared_block) const
     {
         constexpr auto True = integral_constant<bool, true>{};
@@ -125,7 +125,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
             Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
 
         // A matrix blockwise copy
-        auto a_blockwise_copy =
+        using a_block_copy_type =
             BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                                decltype(a_k_m_global_desc),
                                                decltype(a_k_m_block_desc),
@@ -142,8 +142,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
                                                AddressSpace::Global,
                                                AddressSpace::Vgpr,
                                                AddressSpace::Lds,
-                                               InMemoryDataOperation::Set>(
-                {0, m_block_data_on_global}, {0, 0});
+                                               InMemoryDataOperation::Set>;
 
         // B matrix in LDS memory, dst of blockwise copy
         //   be careful of LDS alignment
@@ -151,7 +150,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
             Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
 
         // B matrix blockwise copy
-        auto b_blockwise_copy =
+        using b_block_copy_type =
             BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                                decltype(b_k_n_global_desc),
                                                decltype(b_k_n_block_desc),
@@ -168,8 +167,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
                                                AddressSpace::Global,
                                                AddressSpace::Vgpr,
                                                AddressSpace::Lds,
-                                               InMemoryDataOperation::Set>(
-                {0, n_block_data_on_global}, {0, 0});
+                                               InMemoryDataOperation::Set>;
 
         // GEMM definition
         //   c_mtx += transpose(a_mtx) * b_mtx
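The two hunks above turn constructed copy objects into type aliases; construction is deferred to the point where the pipeline decides how many window instances it needs. In miniature, with Copier as an illustrative stand-in for BlockwiseGenericTensorSliceCopy_v4 (a sketch, not the real interface):

    // Before: one instance, its source-window origin fixed at construction.
    //     auto a_blockwise_copy = Copier({0, m_block_data_on_global}, {0, 0});
    // After: a named type, instantiated later as often as needed.
    struct Copier
    {
        int src_k, src_m; // source-window origin
        Copier(int k, int m) : src_k(k), src_m(m) {}
    };

    using a_block_copy_type = Copier;

    int main()
    {
        auto a_block_copy_0 = a_block_copy_type(0, 0); // window at k = 0
        auto a_block_copy_1 = a_block_copy_type(8, 0); // one K-slice ahead
        (void)a_block_copy_0;
        (void)a_block_copy_1;
    }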
@@ -225,14 +223,27 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
         // zero out threadwise output
         threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
 
+        // prepare blockwise copy slicing window
+        auto a_block_copy_0 = a_block_copy_type({0, m_block_data_on_global}, {0, 0});
+        auto b_block_copy_0 = b_block_copy_type({0, n_block_data_on_global}, {0, 0});
+
+        auto a_block_copy_1 = a_block_copy_0;
+        auto b_block_copy_1 = b_block_copy_0;
+
+        a_block_copy_1.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
+        b_block_copy_1.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
+
+        constexpr auto a_block_slice_copy_steps = Sequence<2 * KPerBlock, 0>{};
+        constexpr auto b_block_slice_copy_steps = Sequence<2 * KPerBlock, 0>{};
+
         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.Run(p_a_global, p_a_block_double);
-            b_blockwise_copy.Run(p_b_global, p_b_block_double);
-        }
+            a_block_copy_0.Run(p_a_global, p_a_block_double);
+            b_block_copy_0.Run(p_b_global, p_b_block_double);
 
-        constexpr auto a_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
-        constexpr auto b_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
+            a_block_copy_0.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
+            b_block_copy_0.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
+        }
 
         // LDS double buffer: main body
         for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
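The preparation above staggers copy_1 one KPerBlock ahead of copy_0 and doubles each window's step to 2 * KPerBlock. A quick dry run under assumed sizes (K = 48, KPerBlock = 8) confirms that the two windows tile the K dimension with no gap and no overlap:

    #include <cstdio>

    // copy_0 visits k = 0, 16, 32 and copy_1 visits k = 8, 24, 40:
    // together, every slice origin exactly once.
    int main()
    {
        constexpr int KPerBlock = 8, K = 48;
        for(int k0 = 0, k1 = KPerBlock; k0 < K;
            k0 += 2 * KPerBlock, k1 += 2 * KPerBlock)
            std::printf("copy_0 -> k=%2d   copy_1 -> k=%2d\n", k0, k1);
    }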
@@ -253,24 +264,30 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
                 Float* p_b_block_next =
                     even_loop ? p_b_block_double + b_block_space : p_b_block_double;
 
-                Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
-                Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
+                auto& a_block_copy_now = even_loop ? a_block_copy_0 : a_block_copy_1;
+                auto& b_block_copy_now = even_loop ? b_block_copy_0 : b_block_copy_1;
+
+                auto& a_block_copy_next = even_loop ? a_block_copy_1 : a_block_copy_0;
+                auto& b_block_copy_next = even_loop ? b_block_copy_1 : b_block_copy_0;
 
-                a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
-                b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
+                Float p_a_thread_buffer[a_block_copy_type::GetThreadBufferSize()];
+                Float p_b_thread_buffer[b_block_copy_type::GetThreadBufferSize()];
 
                 __syncthreads();
 
                 // LDS double buffer: load next data from device mem
-                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
-                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
+                a_block_copy_next.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
+                b_block_copy_next.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
+
+                a_block_copy_next.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
+                b_block_copy_next.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
 
                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
 
                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
-                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
+                a_block_copy_next.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
+                b_block_copy_next.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
             }
         }
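Within the unrolled two-pass loop, even_loop selects which LDS half the GEMM consumes and which copy object prefetches. A sketch of the alternation, with illustrative strings in place of LDS pointers (the real selection is the chain of ternaries above):

    #include <cstdio>

    int main()
    {
        const char* lds[2] = {"lds half 0", "lds half 1"};
        for(int iloop = 0; iloop < 4; ++iloop) // two main-body trips
        {
            const bool even_loop = (iloop % 2 == 0);
            const char* now  = even_loop ? lds[0] : lds[1];
            const char* next = even_loop ? lds[1] : lds[0];
            std::printf("GEMM on %s, prefetch into %s via copy_%d\n",
                        now, next, even_loop ? 1 : 0);
        }
    }

Because the inner loop is unrolled and runs exactly twice, even_loop is a compile-time constant in each copy of the body, so the ternaries should fold away rather than cost registers or branches.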
@@ -280,26 +297,23 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
             if(has_two_iteration_left) // if has 2 iterations left
             {
-                Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
-                Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
-
-                a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
-                b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
+                Float p_a_thread_buffer[a_block_copy_type::GetThreadBufferSize()];
+                Float p_b_thread_buffer[b_block_copy_type::GetThreadBufferSize()];
 
                 __syncthreads();
 
                 // LDS double buffer: load last data from device mem
-                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
-                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
+                a_block_copy_1.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
+                b_block_copy_1.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
 
-                // LDS double buffer: GEMM on 2nd-last data
+                // LDS double buffer: GEMM on even (2nd-last) data
                 blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
 
-                // LDS double buffer: store last data to LDS
-                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
-                                                      p_a_block_double + a_block_space);
-                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
-                                                      p_b_block_double + b_block_space);
+                // LDS double buffer: store odd (last) data to LDS
+                a_block_copy_1.RunStoreThreadBuffer(p_a_thread_buffer,
+                                                    p_a_block_double + a_block_space);
+                b_block_copy_1.RunStoreThreadBuffer(p_b_thread_buffer,
+                                                    p_b_block_double + b_block_space);
 
                 __syncthreads();
@@ -311,7 +325,314 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
             {
                 __syncthreads();
 
+                // LDS double buffer: GEMM on even (last) data
+                blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+            }
+        }
+
+        // input: register to global memory
+        {
+            constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
+            constexpr index_t M0 = M / M1;
+
+            constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
+            constexpr index_t N0 = N / N1;
+
+            // define input tensor descriptor for threadwise copy
+            //     thread input tensor, src of threadwise copy
+            constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
+                Sequence<GemmMRepeat, MPerThread, GemmNRepeat, NPerThread>{});
+
+            constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
+                c_m_n_global_desc,
+                make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
+
+            // calculate origin of thread input tensor on global memory
+            //     blockwise GEMM c matrix starting index
+            const auto c_thread_mtx_on_block =
+                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+
+            const index_t m_thread_data_on_global =
+                m_block_data_on_global + c_thread_mtx_on_block.row;
+
+            const index_t n_thread_data_on_global =
+                n_block_data_on_global + c_thread_mtx_on_block.col;
+
+            ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
+                                                  decltype(c_m0_m1_n0_n1_global_desc),
+                                                  decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
+                                                  CThreadCopySrcDstAccessOrder,
+                                                  CThreadCopySrcDstVectorReadWriteDim,
+                                                  1,
+                                                  CThreadCopyDstDataPerWrite,
+                                                  AddressSpace::Vgpr,
+                                                  AddressSpace::Global,
+                                                  CGlobalMemoryDataOperation>(
+                {0, 0, 0, 0},
+                {m_thread_data_on_global / M1,
+                 m_thread_data_on_global % M1,
+                 n_thread_data_on_global / N1,
+                 n_thread_data_on_global % N1})
+                .Run(p_c_thread, p_c_global);
+        }
+    }
+
+    __device__ void Run_double_slice_window(const Float* __restrict__ p_a_global,
+                                            const Float* __restrict__ p_b_global,
+                                            Float* __restrict__ p_c_global,
+                                            Float* __restrict__ p_shared_block) const
+    {
+        constexpr auto True = integral_constant<bool, true>{};
+
+        constexpr auto a_k_m_global_desc = AGlobalDesc{};
+        constexpr auto b_k_n_global_desc = BGlobalDesc{};
+        constexpr auto c_m_n_global_desc = CGlobalDesc{};
+
+        constexpr auto K = a_k_m_global_desc.GetLengths()[0];
+        constexpr auto M = a_k_m_global_desc.GetLengths()[1];
+        constexpr auto N = b_k_n_global_desc.GetLengths()[1];
+
+        // don't do anything if K == 0
+        if(K == 0)
+        {
+            return;
+        }
+
+        // lds max alignment
+        constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
+                                                    BBlockCopyDstDataPerWrite_N,
+                                                    ThreadGemmAThreadCopySrcDataPerRead_M,
+                                                    ThreadGemmBThreadCopySrcDataPerRead_N);
+
+        // divide block work by [M, N]
+        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
+                      "wrong! cannot divide work evenly among block");
+
+        constexpr index_t MBlockWork = M / MPerBlock;
+        constexpr index_t NBlockWork = N / NPerBlock;
+
+        constexpr auto block_work_desc =
+            make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});
+
+        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
+
+        const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
+        const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;
+
+        // A matrix in LDS memory, dst of blockwise copy
+        //   be careful of LDS alignment
+        constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
+            Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
+
+        // A matrix blockwise copy
+        using a_block_copy_type =
+            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
+                                               decltype(a_k_m_global_desc),
+                                               decltype(a_k_m_block_desc),
+                                               decltype(a_k_m_block_desc.GetLengths()),
+                                               ABlockCopyThreadSliceLengths_K_M,
+                                               ABlockCopyThreadClusterLengths_K_M,
+                                               ABlockCopyThreadClusterArrangeOrder,
+                                               ABlockCopySrcAccessOrder,
+                                               Sequence<0, 1>,
+                                               ABlockCopySrcVectorReadDim,
+                                               1,
+                                               ABlockCopySrcDataPerRead,
+                                               ABlockCopyDstDataPerWrite_M,
+                                               AddressSpace::Global,
+                                               AddressSpace::Vgpr,
+                                               AddressSpace::Lds,
+                                               InMemoryDataOperation::Set>;
+
+        // B matrix in LDS memory, dst of blockwise copy
+        //   be careful of LDS alignment
+        constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
+            Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
+
+        // B matrix blockwise copy
+        using b_block_copy_type =
+            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
+                                               decltype(b_k_n_global_desc),
+                                               decltype(b_k_n_block_desc),
+                                               decltype(b_k_n_block_desc.GetLengths()),
+                                               BBlockCopyThreadSliceLengths_K_N,
+                                               BBlockCopyThreadClusterLengths_K_N,
+                                               BBlockCopyThreadClusterArrangeOrder,
+                                               BBlockCopySrcAccessOrder,
+                                               Sequence<0, 1>,
+                                               BBlockCopySrcVectorReadDim,
+                                               1,
+                                               BBlockCopySrcDataPerRead,
+                                               BBlockCopyDstDataPerWrite_N,
+                                               AddressSpace::Global,
+                                               AddressSpace::Vgpr,
+                                               AddressSpace::Lds,
+                                               InMemoryDataOperation::Set>;
+
+        // GEMM definition
+        //   c_mtx += transpose(a_mtx) * b_mtx
+        //     a_mtx[KPerBlock, MPerBlock] is in LDS
+        //     b_mtx[KPerBlock, NPerBlock] is in LDS
+        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
+        //       register
+        constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
+        constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);
+
+        // sanity check
+        static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
+                          NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
+                      "wrong!");
+
+        constexpr index_t GemmMRepeat =
+            MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
+
+        constexpr index_t GemmNRepeat =
+            NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
+
+        // c_thread_mtx definition: this is a mess
+        // TODO:: more elegant way of defining c_thread_mtx
+        constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
+            Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});
+
+        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
+            BlockSize,
+            decltype(a_k_m_block_mtx_desc),
+            decltype(b_k_n_block_mtx_desc),
+            decltype(c_m0m1_n0n1_thread_mtx_desc),
+            MPerThread,
+            NPerThread,
+            MLevel0Cluster,
+            NLevel0Cluster,
+            MLevel1Cluster,
+            NLevel1Cluster,
+            KPerThread,
+            ThreadGemmAThreadCopySrcDataPerRead_M,
+            ThreadGemmBThreadCopySrcDataPerRead_N>{};
+
+        // LDS allocation for A and B: be careful of alignment
+        constexpr index_t a_block_space =
+            math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
+
+        constexpr index_t b_block_space =
+            math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
+
+        Float* p_a_block_double = p_shared_block;
+        Float* p_b_block_double = p_shared_block + 2 * a_block_space;
+
+        // register allocation for output
+        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
+
+        // zero out threadwise output
+        threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
+
+        // prepare blockwise copy slicing window
+        auto a_block_copy_0 = a_block_copy_type({0, m_block_data_on_global}, {0, 0});
+        auto b_block_copy_0 = b_block_copy_type({0, n_block_data_on_global}, {0, 0});
+
+#if 0
+        auto a_block_copy_1 = a_block_copy_0;
+        auto b_block_copy_1 = b_block_copy_0;
+
+        a_block_copy_1.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
+        b_block_copy_1.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
+#else
+        auto a_block_copy_1 = a_block_copy_type({KPerBlock, m_block_data_on_global}, {0, 0});
+        auto b_block_copy_1 = b_block_copy_type({KPerBlock, n_block_data_on_global}, {0, 0});
+#endif
+
+        constexpr auto a_block_slice_copy_steps = Sequence<2 * KPerBlock, 0>{};
+        constexpr auto b_block_slice_copy_steps = Sequence<2 * KPerBlock, 0>{};
+
+        // LDS double buffer: preload data into LDS
+        {
+            a_block_copy_0.Run(p_a_global, p_a_block_double);
+            b_block_copy_0.Run(p_b_global, p_b_block_double);
+
+            a_block_copy_0.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
+            b_block_copy_0.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
+        }
+
+        // LDS double buffer: main body
+        for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
+            k_block_data_begin += 2 * KPerBlock)
+        {
+#pragma unroll
+            for(index_t iloop = 0; iloop < 2; ++iloop)
+            {
+                const bool even_loop = (iloop % 2 == 0);
+
+                Float* p_a_block_now =
+                    even_loop ? p_a_block_double : p_a_block_double + a_block_space;
+                Float* p_b_block_now =
+                    even_loop ? p_b_block_double : p_b_block_double + b_block_space;
+
+                Float* p_a_block_next =
+                    even_loop ? p_a_block_double + a_block_space : p_a_block_double;
+                Float* p_b_block_next =
+                    even_loop ? p_b_block_double + b_block_space : p_b_block_double;
+
+                auto& a_block_copy_now = even_loop ? a_block_copy_0 : a_block_copy_1;
+                auto& b_block_copy_now = even_loop ? b_block_copy_0 : b_block_copy_1;
+
+                auto& a_block_copy_next = even_loop ? a_block_copy_1 : a_block_copy_0;
+                auto& b_block_copy_next = even_loop ? b_block_copy_1 : b_block_copy_0;
+
+                Float p_a_thread_buffer[a_block_copy_type::GetThreadBufferSize()];
+                Float p_b_thread_buffer[b_block_copy_type::GetThreadBufferSize()];
+
+                __syncthreads();
+
+                // LDS double buffer: load next data from device mem
+                a_block_copy_next.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
+                b_block_copy_next.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
+
+                a_block_copy_next.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
+                b_block_copy_next.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
+
+                // LDS double buffer: GEMM on current data
+                blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
+
+                // LDS double buffer: store next data to LDS
+                a_block_copy_next.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
+                b_block_copy_next.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
+            }
+        }
+
+        // LDS double buffer: tail
+        {
+            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
+
+            if(has_two_iteration_left) // if has 2 iterations left
+            {
+                Float p_a_thread_buffer[a_block_copy_type::GetThreadBufferSize()];
+                Float p_b_thread_buffer[b_block_copy_type::GetThreadBufferSize()];
+
+                __syncthreads();
+
+                // LDS double buffer: load last data from device mem
+                a_block_copy_1.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
+                b_block_copy_1.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
+
+                // LDS double buffer: GEMM on even (2nd-last) data
+                blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+
+                // LDS double buffer: store odd (last) data to LDS
+                a_block_copy_1.RunStoreThreadBuffer(p_a_thread_buffer,
+                                                    p_a_block_double + a_block_space);
+                b_block_copy_1.RunStoreThreadBuffer(p_b_thread_buffer,
+                                                    p_b_block_double + b_block_space);
+
+                __syncthreads();
+
                 // LDS double buffer: GEMM on last data
+                blockwise_gemm.Run(p_a_block_double + a_block_space,
+                                   p_b_block_double + b_block_space,
+                                   p_c_thread);
+            }
+            else // if has 1 iteration left
+            {
+                __syncthreads();
+
+                // LDS double buffer: GEMM on even (last) data
                 blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
             }
         }
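The tail condition is worth a sanity check: the main body stops while at least one slice is still unconsumed, and K % (2 * KPerBlock) == 0 is exactly the case where two slices remain. A small host-side verification, assuming (as the kernel's static_assert requires) that K is a multiple of KPerBlock:

    #include <cassert>

    int main()
    {
        constexpr int KPerBlock = 8;
        for(int K = KPerBlock; K <= 128; K += KPerBlock)
        {
            int k = 0;
            while(k + 2 * KPerBlock < K) // mirrors the main-body condition
                k += 2 * KPerBlock;
            const int slices_left = (K - k) / KPerBlock;
            assert(slices_left == ((K % (2 * KPerBlock) == 0) ? 2 : 1));
        }
        return 0; // two slices left iff K % (2 * KPerBlock) == 0
    }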
@@ -373,7 +694,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
         __shared__ Float p_shared_block[shared_block_size];
 
-        Run(p_a_global, p_b_global, p_c_global, p_shared_block);
+#if 1
+        Run_single_slice_window(p_a_global, p_b_global, p_c_global, p_shared_block);
+#else
+        Run_double_slice_window(p_a_global, p_b_global, p_c_global, p_shared_block);
+#endif
     }
 };
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp

@@ -187,13 +187,14 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 8;
 
     constexpr index_t GemmMPerThreadSubC = 4;
     constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
+
+    constexpr index_t GemmKPerThreadLoop = 1;
 
     constexpr index_t ThreadGemmDataPerReadM = 4;
     constexpr index_t ThreadGemmDataPerReadN = 4;
@@ -237,11 +238,11 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
         GemmKPerBlock,
         GemmMPerThreadSubC,
         GemmNPerThreadSubC,
-        GemmKPerThreadLoop,
         GemmMLevel0Cluster,
         GemmNLevel0Cluster,
         GemmMLevel1Cluster,
         GemmNLevel1Cluster,
+        GemmKPerThreadLoop,
         ThreadGemmDataPerReadM,
         ThreadGemmDataPerReadN,
         GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
driver/src/conv_driver.cpp

@@ -29,13 +29,13 @@ int main(int argc, char* argv[])
 {
     using namespace ck;
 
-#if 1
+#if 0
     // 1x1
     constexpr index_t N = 64;
-    constexpr index_t C = 64;
-    constexpr index_t HI = 56;
-    constexpr index_t WI = 56;
-    constexpr index_t K = 256;
+    constexpr index_t C = 512;
+    constexpr index_t HI = 28;
+    constexpr index_t WI = 28;
+    constexpr index_t K = 1024;
     constexpr index_t Y = 1;
     constexpr index_t X = 1;
@@ -44,13 +44,13 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
     // 1x7
     constexpr index_t N = 128;
-    constexpr index_t C = 1024;
+    constexpr index_t C = 128;
     constexpr index_t HI = 17;
     constexpr index_t WI = 17;
-    constexpr index_t K = 1024;
+    constexpr index_t K = 128;
     constexpr index_t Y = 1;
    constexpr index_t X = 7;
@@ -281,7 +281,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
     // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
     // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
     constexpr index_t N = 128;
@@ -327,7 +327,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 3>;
     using RightPads = Sequence<0, 3>;
-#elif 1
+#elif 0
     // 7x1 filter, 3x0 pad, 17x17 input
     constexpr index_t N = 128;
     constexpr index_t C = 128;
@@ -434,7 +434,7 @@ int main(int argc, char* argv[])
                                                          ConvStrides{},
                                                          ConvDilations{},
                                                          nrepeat);
-#elif 0
+#elif 1
     device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
                                                          in_nchw,
                                                          wei_kcyx_desc,