Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f05b210a
Commit
f05b210a
authored
Jun 15, 2019
by
Jing Zhang
Browse files
finished merge
parent
b3108646
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
67 additions
and
61 deletions
+67
-61
driver/driver.hip.cpp
driver/driver.hip.cpp
+39
-52
src/include/gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp
...implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp
+28
-9
No files found.
driver/driver.hip.cpp
View file @
f05b210a
...
...
@@ -499,7 +499,7 @@ int main(int argc, char* argv[])
constexpr
index_t
HDilation
=
1
;
constexpr
index_t
WDilation
=
1
;
constexpr
index_t
Direction
=
2
;
// 1: Forward; 0:Backward
constexpr
index_t
Direction
=
1
;
// 1: Forward; 0:Backward
#if 0
constexpr index_t N = 32;
constexpr index_t C = 128;
...
...
@@ -550,10 +550,10 @@ int main(int argc, char* argv[])
#elif 1
// 1x1 filter, 28x28 image
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
56
;
constexpr
index_t
WI
=
56
;
constexpr
index_t
K
=
256
;
constexpr
index_t
C
=
512
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
...
...
@@ -721,7 +721,7 @@ int main(int argc, char* argv[])
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_3
{},
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
#elif 1
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_
0
{
},
num_thread
);
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_
2
{
-
5
,
5
},
num_thread
);
out_nkhw
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
#elif 0
...
...
@@ -734,49 +734,33 @@ int main(int argc, char* argv[])
#endif
}
#if 1
#if 0
device_direct_convolution_1
#elif
0
device_convolution_direct_v2_nchw_kcyx_nkhw
#elif 0
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 0
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
#elif 0
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
#elif 1
device_convolution_implicit_gemm_v4_nchw_kc1x1_nkhw
#endif
(
in_nchw_desc
,
in_nchw_device
,
if
(
Direction
==
1
)
{
device_convolution_implicit_gemm_v4_nchw_kc1x1_nkhw
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
strides
,
dilations
,
Number
<
Direction
>
{},
out_nkhw
,
out_nkhw
_device
,
nrepeat
);
#elif 1
device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded
(
in_nchw_desc
,
in_nchw
,
}
else
{
device_convolution_implicit_gemm_v4_nchw_kc1x1_nkhw
(
in_nchw_desc
,
in_nchw_device
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
lower_pads
,
upper_pads
,
strides
,
dilations
,
Number
<
Direction
>
{},
out_nkhw
,
nrepeat
);
#endif
}
if
(
do_verification
)
{
#if 0
...
...
@@ -800,10 +784,13 @@ int main(int argc, char* argv[])
}
#if 0
LogRange(std::cout << "out_nkhw: ", out_nkhw.mData, ",") << std::endl;
LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
LogRange(std::cout << "in_nchw_host : ", in_nchw.mData, ",") << std::endl;
LogRange(std::cout << "in_nchw_device: ", in_nchw_device.mData, ",") << std::endl;
//LogRange(std::cout << "out_nkhw: ", out_nkhw.mData, ",") << std::endl;
//LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
//LogRange(std::cout << "in_nchw_host : ", in_nchw.mData, ",") << std::endl;
//LogRange(std::cout << "in_nchw_device: ", in_nchw_device.mData, ",") << std::endl;
//LogRange(std::cout << "out_nkhw_host : ", out_nkhw.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
#endif
}
}
src/include/gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp
View file @
f05b210a
...
...
@@ -27,7 +27,7 @@ template <bool isForw,
index_t
Wo
,
class
Strides
,
class
Dilations
>
struct
GetInGlobal
Merge
Desc
;
struct
GetInGlobal
Final
Desc
;
template
<
class
InType
,
index_t
N1
,
...
...
@@ -36,7 +36,7 @@ template <class InType,
index_t
Wo
,
class
Strides
,
class
Dilations
>
struct
GetInGlobal
Merge
Desc
<
true
,
InType
,
N1
,
N2
,
Ho
,
Wo
,
Strides
,
Dilations
>
struct
GetInGlobal
Final
Desc
<
true
,
InType
,
N1
,
N2
,
Ho
,
Wo
,
Strides
,
Dilations
>
{
__host__
__device__
constexpr
auto
GetDesc
()
{
...
...
@@ -102,7 +102,7 @@ template <class InType,
index_t
Wo
,
class
Strides
,
class
Dilations
>
struct
GetInGlobal
Merge
Desc
<
false
,
InType
,
N1
,
N2
,
Ho
,
Wo
,
Strides
,
Dilations
>
struct
GetInGlobal
Final
Desc
<
false
,
InType
,
N1
,
N2
,
Ho
,
Wo
,
Strides
,
Dilations
>
{
__host__
__device__
constexpr
auto
GetDesc
()
{
...
...
@@ -141,17 +141,17 @@ struct GetInGlobalMergeDesc<false, InType, N1, N2, Ho, Wo, Strides, Dilations>
};
template
<
bool
isForw
,
class
OutType
,
class
Strides
>
struct
GetOutGlobal
Merge
Desc
;
struct
GetOutGlobal
Final
Desc
;
template
<
class
OutType
,
class
Strides
>
struct
GetOutGlobal
Merge
Desc
<
true
,
OutType
,
Strides
>
struct
GetOutGlobal
Final
Desc
<
true
,
OutType
,
Strides
>
{
__host__
__device__
constexpr
auto
GetDesc
()
{
return
OutType
{};
}
};
template
<
class
OutType
,
class
Strides
>
struct
GetOutGlobal
Merge
Desc
<
false
,
OutType
,
Strides
>
struct
GetOutGlobal
Final
Desc
<
false
,
OutType
,
Strides
>
{
__host__
__device__
constexpr
auto
GetDesc
()
{
...
...
@@ -195,6 +195,24 @@ struct GetOutGlobalMergeDesc<false, OutType, Strides>
}
};
template
<
bool
isForw
,
class
WeiType
>
struct
GetWeiFinalDesc
;
template
<
class
WeiType
>
struct
GetWeiFinalDesc
<
true
,
WeiType
>
{
__host__
__device__
constexpr
auto
GetDesc
()
{
return
WeiType
{}.
ReorderGivenNew2Old
(
Sequence
<
1
,
0
>
{});
}
};
template
<
class
WeiType
>
struct
GetWeiFinalDesc
<
false
,
WeiType
>
{
__host__
__device__
constexpr
auto
GetDesc
()
{
return
WeiType
{};
}
};
// define B = merge(N0, Ho, Wo)
template
<
index_t
GridSize
,
index_t
BlockSize
,
...
...
@@ -311,7 +329,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kc1x1_nkhw
.
Extract
(
Sequence
<
1
,
2
,
3
>
{});
constexpr
bool
fwd
=
Direction
==
1
;
constexpr
auto
in_e_n1_b_n2_global_merged_desc
=
GetInGlobal
Merge
Desc
<
fwd
,
GetInGlobal
Final
Desc
<
fwd
,
decltype
(
in_n_c_h_w_global_desc
),
N1
,
N2
,
...
...
@@ -352,7 +370,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kc1x1_nkhw
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr
auto
wei_e_k_global_desc
=
wei_k_c_global_desc
;
constexpr
auto
wei_e_k_global_desc
=
GetWeiFinalDesc
<
fwd
,
decltype
(
wei_k_c_global_desc
)
>
{}.
GetDesc
();
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
...
...
@@ -576,7 +595,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kc1x1_nkhw
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{});
constexpr
auto
out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc
=
GetOutGlobal
Merge
Desc
<
fwd
,
GetOutGlobal
Final
Desc
<
fwd
,
decltype
(
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc
),
Strides
>
{}
.
GetDesc
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment