Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2090160a
Commit
2090160a
authored
Aug 05, 2019
by
Jehandad Khan
Browse files
tuning params and driver changes for wrw
parent
2eeeb176
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
86 additions
and
21 deletions
+86
-21
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
...n_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
+18
-10
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+63
-8
driver/src/driver.cpp
driver/src/driver.cpp
+5
-3
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
View file @
2090160a
...
@@ -175,8 +175,11 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -175,8 +175,11 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// weight tensor
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
// tensor descriptor in device memory, src of blockwise copy
constexpr
auto
wei_e_k_global_desc
=
// JD: we can no longer unfold because the tensor will not be contiguus for wrw
wei_k_c_y_x_global_desc
.
Unfold
(
I1
,
I3
).
ReorderGivenNew2Old
(
Sequence
<
1
,
0
>
{});
// constexpr auto wei_e_k_global_desc =
// wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
constexpr
auto
wei_e_k_global_merged_desc
=
make_ConstantMergedTensorDescriptor
(
wei_k_c_y_x_global_desc
,
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
0
>
{});
// tensor descriptor in LDS, dst of blockwise copy
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
...
@@ -190,7 +193,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -190,7 +193,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
auto
blockwise_wei_copy
=
auto
blockwise_wei_copy
=
BlockwiseGenericTensorSliceCopy_v1
<
BlockSize
,
BlockwiseGenericTensorSliceCopy_v1
<
BlockSize
,
Float
,
Float
,
decltype
(
wei_e_k_global_desc
),
decltype
(
wei_e_k_global_
merged_
desc
),
decltype
(
wei_e_k_block_desc
),
decltype
(
wei_e_k_block_desc
),
decltype
(
wei_e_k_block_desc
.
GetLengths
()),
decltype
(
wei_e_k_block_desc
.
GetLengths
()),
WeiBlockCopySubLengths_E_K
,
WeiBlockCopySubLengths_E_K
,
...
@@ -295,7 +298,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -295,7 +298,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
Float
p_wei_register_clipboard
[
blockwise_wei_copy
.
GetRegisterClipboardSize
()];
Float
p_wei_register_clipboard
[
blockwise_wei_copy
.
GetRegisterClipboardSize
()];
blockwise_in_copy
.
MoveSlicingWindowOnSourceTensor
(
I0
,
Number
<
EPerBlock
>
{},
True
);
blockwise_in_copy
.
MoveSlicingWindowOnSourceTensor
(
I0
,
Number
<
EPerBlock
>
{},
True
);
p_wei_block_on_global
+=
EPerBlock
*
wei_e_k_global_desc
.
GetStride
(
I0
);
blockwise_wei_copy
.
MoveSlicingWindowOnSourceTensor
(
I0
,
Number
<
EPerBlock
>
{},
True
);
__syncthreads
();
__syncthreads
();
...
@@ -309,9 +312,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -309,9 +312,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
p_in_block_next
);
p_in_block_next
);
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
p_wei_block_next
);
p_wei_block_next
);
}
}
}
}
...
@@ -322,13 +325,15 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -322,13 +325,15 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// even iteration
// even iteration
blockwise_in_copy
.
MoveSlicingWindowOnSourceTensor
(
I0
,
Number
<
EPerBlock
>
{},
True
);
blockwise_in_copy
.
MoveSlicingWindowOnSourceTensor
(
I0
,
Number
<
EPerBlock
>
{},
True
);
p_wei_block_on_global
+=
EPerBlock
*
wei_e_k_global_desc
.
GetStride
(
I0
);
blockwise_wei_copy
.
MoveSlicingWindowOnSourceTensor
(
I0
,
Number
<
EPerBlock
>
{},
True
);
// JD, the following no longer applies since the tensor is a merged tesnor ( not contiguous)
// p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
__syncthreads
();
__syncthreads
();
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
blockwise_in_copy
.
RunLoadRegisterClipboard
(
p_in_global
,
p_in_register_clipboard
);
blockwise_in_copy
.
RunLoadRegisterClipboard
(
p_in_global
,
p_in_register_clipboard
);
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_block_on_global
,
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_block_on_global
,
p_wei_register_clipboard
);
p_wei_register_clipboard
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
...
@@ -338,7 +343,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -338,7 +343,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
p_in_block_double
+
in_block_space
);
p_in_block_double
+
in_block_space
);
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
p_wei_block_double
+
wei_block_space
);
p_wei_block_double
+
wei_block_space
);
// odd iteration
// odd iteration
__syncthreads
();
__syncthreads
();
...
@@ -395,7 +400,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -395,7 +400,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
p_out_global
+
p_out_global
+
out_k_n1_b_n2_global_merged_desc
.
GetOffsetFromMultiIndex
(
out_k_n1_b_n2_global_merged_desc
.
GetOffsetFromMultiIndex
(
k_thread_data_on_global
,
0
,
b_thread_data_on_global
,
0
);
k_thread_data_on_global
,
0
,
b_thread_data_on_global
,
0
);
#if 1
threadwise_generic_tensor_slice_copy_v1
(
threadwise_generic_tensor_slice_copy_v1
(
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
,
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
,
p_out_thread
,
p_out_thread
,
...
@@ -406,6 +411,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -406,6 +411,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
.
GetLengths
(),
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc
.
GetLengths
(),
arithmetic_sequence_gen
<
0
,
8
,
1
>::
type
{},
arithmetic_sequence_gen
<
0
,
8
,
1
>::
type
{},
Number
<
1
>
{});
Number
<
1
>
{});
#elif 0
p_out_global
[
0
]
=
p_out_thread
[
0
];
#endif
}
}
}
}
};
};
...
...
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
2090160a
...
@@ -13,6 +13,9 @@ template <class T,
...
@@ -13,6 +13,9 @@ template <class T,
class
ConvStrides
,
class
ConvStrides
,
class
ConvDilations
>
class
ConvDilations
>
void
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw
(
InDesc
,
void
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw
(
InDesc
,
// the input desc needs to be reordered for wrw : cnhw would be the new order
// the forward kernel always assumes red on the second dim and this would make it reduce on the n dimension due to the switchibng we did
const
Tensor
<
T
>&
in_nchw
,
const
Tensor
<
T
>&
in_nchw
,
WeiDesc
,
WeiDesc
,
const
Tensor
<
T
>&
wei_kcyx
,
const
Tensor
<
T
>&
wei_kcyx
,
...
@@ -54,12 +57,60 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -54,12 +57,60 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
constexpr
index_t
N1
=
2
;
constexpr
index_t
N1
=
2
;
constexpr
index_t
N2
=
4
;
constexpr
index_t
N2
=
4
;
constexpr
index_t
B
=
(
N
*
Ho
*
Wo
)
/
(
N1
*
N2
);
constexpr
index_t
B
=
(
N
*
Ho
*
Wo
)
/
(
N1
*
N2
);
#if 1
#if 1
// JD: New params for wrw
// each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BPerBlock
=
16
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
EPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
//~ Must be equal to N2 above for implicit_gemm_v4r1_lds
// The fgollowign is not a tuning partam, thisnis related to the block size and the size of the gemm
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
// how much E do we read for each thread, since in GEMM K is the reduction dimensikon not the output channels in convolution
constexpr
index_t
GemmKPerThreadLoop
=
1
;
// The width of the vector load for GEMM, for mat A and B
constexpr
index_t
GemmDataPerReadA
=
1
;
constexpr
index_t
GemmDataPerReadB
=
1
;
// changed from 4
// copy related (input)
// this is a general tensor API for any number of dims and in this case thje tensor has 4 dims
// sublength times cluster length shoudl equal blockwise slice length
using
InBlockCopySubLengths_E_N1_B_N2
=
Sequence
<
1
,
1
,
1
,
4
>
;
using
InBlockCopyClusterLengths_E_N1_B_N2
=
Sequence
<
8
,
2
,
16
,
1
>
;
// the arrange order describes which dim changes the thread id quickest
// also important for accessing contiguous memory
using
InBlockCopyThreadClusterArrangeOrder
=
Sequence
<
0
,
1
,
3
,
2
>
;
// [E, N1, N2, B]
// the implication of this is only performance not correctness
using
InBlockCopySrcAccessOrder
=
Sequence
<
0
,
1
,
3
,
2
>
;
// [E, N1, N2, B]
using
InBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
,
2
,
3
>
;
// [E, N1, B, N2]
// width of the vector read/write if we wish to read float1 ,float2/4
constexpr
index_t
InBlockCopySrcDataPerRead_B
=
1
;
constexpr
index_t
InBlockCopyDstDataPerWrite_N2
=
1
;
// copy related (weights)
using
WeiBlockCopySubLengths_E_K
=
Sequence
<
4
,
1
>
;
using
WeiBlockCopyClusterLengths_E_K
=
Sequence
<
2
,
128
>
;
using
WeiBlockCopyThreadClusterArrangeOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopySrcAccessOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, K]
// width of the vector read/write if we wish to read float1 ,float2/4
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
1
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
#elif 0
//JD: Original params from chao
// each thread hold 64 data
// each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -147,11 +198,15 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -147,11 +198,15 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
<
GridSize
,
<
GridSize
,
BlockSize
,
BlockSize
,
T
,
T
,
decltype
(
in_nchw_desc
),
// pass in rthe reordered desc
decltype
(
in_nchw_desc
.
ReorderGivenNew2Old
(
Sequence
<
1
,
0
,
2
,
3
>
{})),
decltype
(
out_nkhw_desc
.
ReorderGivenNew2Old
(
Sequence
<
1
,
0
,
2
,
3
>
{})),
// pass in the output instead of the weight, also reordered to knhw
decltype
(
wei_kcyx_desc
),
decltype
(
wei_kcyx_desc
),
decltype
(
out_nkhw_desc
),
// the output would be the weights, which would not be reordered
ConvStrides
,
// as discussed in the morning for wrw strides and dilation switch positions
ConvDilations
,
ConvDilations
,
// wrw: becomes stride
ConvStrides
,
// wrw: becomes dilation
BPerBlock
,
BPerBlock
,
KPerBlock
,
KPerBlock
,
EPerBlock
,
EPerBlock
,
...
@@ -186,8 +241,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -186,8 +241,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
printf
(
"Elapsed time : %f ms, %f TFlop/s
\n
"
,
printf
(
"Elapsed time : %f ms, %f TFlop/s
\n
"
,
time
,
time
,
...
...
driver/src/driver.cpp
View file @
2090160a
...
@@ -71,8 +71,8 @@ int main(int argc, char* argv[])
...
@@ -71,8 +71,8 @@ int main(int argc, char* argv[])
{
{
using
namespace
ck
;
using
namespace
ck
;
#if
0
#if
1
constexpr index_t N =
64
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1536
;
constexpr
index_t
C
=
1536
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
WI
=
8
;
constexpr
index_t
WI
=
8
;
...
@@ -533,6 +533,8 @@ int main(int argc, char* argv[])
...
@@ -533,6 +533,8 @@ int main(int argc, char* argv[])
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
(
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
(
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
#elif 1
#elif 1
// this is the same as MIOpen
// I should modify this one
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
@@ -562,7 +564,7 @@ int main(int argc, char* argv[])
...
@@ -562,7 +564,7 @@ int main(int argc, char* argv[])
ConvStrides
{},
ConvStrides
{},
ConvDilations
{},
ConvDilations
{},
nrepeat
);
nrepeat
);
#elif
1
#elif
0
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment