Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
bb37eb69
"...text-generation-inference.git" did not exist on "e279b38aca90cddc0ab654e18c369d9c462ebc0d"
Commit
bb37eb69
authored
Jun 11, 2019
by
Jing Zhang
Browse files
merge forw and back
parent
81ec25b4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
63 additions
and
20 deletions
+63
-20
driver/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
...er/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
+3
-2
driver/driver.hip.cpp
driver/driver.hip.cpp
+3
-3
src/include/gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp
...implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp
+57
-15
No files found.
driver/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
View file @
bb37eb69
...
@@ -56,8 +56,8 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
...
@@ -56,8 +56,8 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
N1
=
2
;
constexpr
index_t
N1
=
2
;
constexpr
index_t
N2
=
4
;
constexpr
index_t
N2
=
4
;
//constexpr index_t B = N * mod_conv::integer_divide_ceil(Ho, Strides::Get(I0)) *
//
constexpr index_t B = N * mod_conv::integer_divide_ceil(Ho, Strides::Get(I0)) *
//mod_conv::integer_divide_ceil(Wo, Strides::Get(I1)) / (N1 * N2);
//
mod_conv::integer_divide_ceil(Wo, Strides::Get(I1)) / (N1 * N2);
constexpr
index_t
B
=
(
N
*
Ho
*
Wo
)
/
(
N1
*
N2
);
constexpr
index_t
B
=
(
N
*
Ho
*
Wo
)
/
(
N1
*
N2
);
#if 1
#if 1
...
@@ -113,6 +113,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
...
@@ -113,6 +113,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
BlockSize
,
BlockSize
,
Strides
,
Strides
,
Dilations
,
Dilations
,
Direction
,
T
,
T
,
decltype
(
in_nchw_desc
),
decltype
(
in_nchw_desc
),
decltype
(
wei_kcyx_desc
),
decltype
(
wei_kcyx_desc
),
...
...
driver/driver.hip.cpp
View file @
bb37eb69
...
@@ -493,13 +493,13 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
...
@@ -493,13 +493,13 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
constexpr
index_t
HStride
=
2
;
constexpr
index_t
HStride
=
1
;
constexpr
index_t
WStride
=
2
;
constexpr
index_t
WStride
=
1
;
constexpr
index_t
HDilation
=
1
;
constexpr
index_t
HDilation
=
1
;
constexpr
index_t
WDilation
=
1
;
constexpr
index_t
WDilation
=
1
;
constexpr
index_t
Direction
=
2
;
// 1: Forward; 2:Backward
constexpr
index_t
Direction
=
1
;
// 1: Forward; 2:Backward
#if 0
#if 0
constexpr index_t N = 32;
constexpr index_t N = 32;
constexpr index_t C = 128;
constexpr index_t C = 128;
...
...
src/include/gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp
View file @
bb37eb69
...
@@ -7,11 +7,14 @@
...
@@ -7,11 +7,14 @@
#include "blockwise_gemm.hip.hpp"
#include "blockwise_gemm.hip.hpp"
#include "threadwise_generic_tensor_slice_op.hip.hpp"
#include "threadwise_generic_tensor_slice_op.hip.hpp"
#define FORW 1
// define B = merge(N0, Ho, Wo)
// define B = merge(N0, Ho, Wo)
template
<
index_t
GridSize
,
template
<
index_t
GridSize
,
index_t
BlockSize
,
index_t
BlockSize
,
class
Strides
,
class
Strides
,
class
Dilations
,
class
Dilations
,
index_t
Direction
,
class
Float
,
class
Float
,
class
InGlobalDesc
,
class
InGlobalDesc
,
class
WeiGlobalDesc
,
class
WeiGlobalDesc
,
...
@@ -50,8 +53,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
...
@@ -50,8 +53,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
const
Float
*
const
__restrict__
p_wei_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_conv_out_global
)
const
Float
*
const
__restrict__
p_conv_out_global
)
const
{
{
auto
p_in_global
=
p_conv_out_global
;
auto
p_in_global
=
Direction
==
1
?
p_conv_in_global
:
p_conv_out_global
;
auto
p_out_global
=
p_conv_in_global
;
auto
p_out_global
=
Direction
==
1
?
p_conv_out_global
:
p_conv_in_global
;
// this is a mess
// this is a mess
// TODO: find more elegent way of specifying (or calculating) performance parameters
// TODO: find more elegent way of specifying (or calculating) performance parameters
static_assert
(
N2
==
GemmNPerThreadSubC
,
"wrong!"
);
static_assert
(
N2
==
GemmNPerThreadSubC
,
"wrong!"
);
...
@@ -71,10 +74,16 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
...
@@ -71,10 +74,16 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
in_n_c_h_w_global_desc
=
OutGlobalDesc
{};
#if FORW
constexpr
auto
in_n_c_h_w_global_desc
=
InGlobalDesc
{};
constexpr
auto
out_n_k_h_w_global_desc
=
OutGlobalDesc
{};
#else
constexpr
auto
in_n_c_h_w_global_desc
=
OutGlobalDesc
{};
constexpr
auto
out_n_k_h_w_global_desc
=
InGlobalDesc
{};
#endif
// to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3 out = ek
// to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3 out = ek
constexpr
auto
wei_k_c_1_1_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
wei_k_c_1_1_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_n_k_h_w_global_desc
=
InGlobalDesc
{};
constexpr
index_t
N
=
in_n_c_h_w_global_desc
.
GetLength
(
I0
);
constexpr
index_t
N
=
in_n_c_h_w_global_desc
.
GetLength
(
I0
);
constexpr
index_t
C
=
in_n_c_h_w_global_desc
.
GetLength
(
I1
);
constexpr
index_t
C
=
in_n_c_h_w_global_desc
.
GetLength
(
I1
);
...
@@ -112,34 +121,62 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
...
@@ -112,34 +121,62 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
const
index_t
k_block_data_on_global
=
block_work_multi_id
[
0
]
*
KPerBlock
;
const
index_t
k_block_data_on_global
=
block_work_multi_id
[
0
]
*
KPerBlock
;
const
index_t
b_block_data_on_global
=
block_work_multi_id
[
1
]
*
BPerBlock
;
const
index_t
b_block_data_on_global
=
block_work_multi_id
[
1
]
*
BPerBlock
;
// input tensor
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
#if FORW
constexpr
auto
in_n0_n1_n2_h_w_global_desc
=
in_n_c_h_w_global_desc
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{})
.
Extract
(
Sequence
<
0
,
1
,
2
,
4
,
5
>
{});
#else
constexpr
auto
in_n0_n1_n2_h_w_global_desc
=
constexpr
auto
in_n0_n1_n2_h_w_global_desc
=
in_n_c_h_w_global_desc
in_n_c_h_w_global_desc
.
Slice
(
I2
,
Number
<
mod_conv
::
integer_divide_ceil
(
Ho
,
Strides
::
Get
(
I0
))
>
{})
.
Slice
(
I2
,
Number
<
mod_conv
::
integer_divide_ceil
(
Ho
,
Strides
::
Get
(
I0
))
>
{})
.
Slice
(
I3
,
Number
<
mod_conv
::
integer_divide_ceil
(
Wo
,
Strides
::
Get
(
I1
))
>
{})
.
Slice
(
I3
,
Number
<
mod_conv
::
integer_divide_ceil
(
Wo
,
Strides
::
Get
(
I1
))
>
{})
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{})
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{})
.
Extract
(
Sequence
<
0
,
1
,
2
,
4
,
5
>
{});
.
Extract
(
Sequence
<
0
,
1
,
2
,
4
,
5
>
{});
#endif
// constexpr auto in_n0_n1_n2_h_w_global_desc =
#if FORW
// in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{})
constexpr
auto
in_lengths_new
=
Sequence
<
N0
,
N1
,
N2
,
Ho
,
Wo
>
{};
//.Extract(Sequence<0, 1, 2, 4, 5>{});
constexpr
auto
in_strides_new
=
Sequence
<
in_n0_n1_n2_h_w_global_desc
.
GetStride
(
I0
),
in_n0_n1_n2_h_w_global_desc
.
GetStride
(
I1
),
in_n0_n1_n2_h_w_global_desc
.
GetStride
(
I2
),
in_n0_n1_n2_h_w_global_desc
.
GetStride
(
I3
)
*
Strides
{}.
Get
(
I0
),
in_n0_n1_n2_h_w_global_desc
.
GetStride
(
I4
)
*
Strides
{}.
Get
(
I1
)
>
{};
// constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{};
constexpr
auto
in_n0_n1_n2_h_w_new_global_desc
=
make_ConstantTensorDescriptor
(
in_lengths_new
,
in_strides_new
);
#else
constexpr
auto
in_n0_n1_n2_h_w_new_global_desc
=
in_n0_n1_n2_h_w_global_desc
;
constexpr
auto
in_n0_n1_n2_h_w_new_global_desc
=
in_n0_n1_n2_h_w_global_desc
;
#endif
// batch descritpor for device memory
// batch descritpor for device memory
// to-do: add dilation: keep lengths, modify strides
// to-do: add dilation: keep lengths, modify strides
constexpr
auto
in_c_y_x_global_desc
=
in_n_c_h_w_global_desc
.
Slice
(
I2
,
Number
<
Y
>
{})
constexpr
auto
in_c_y_x_global_desc
=
in_n_c_h_w_global_desc
.
Slice
(
I2
,
Number
<
Y
>
{})
.
Slice
(
I3
,
Number
<
X
>
{})
.
Slice
(
I3
,
Number
<
X
>
{})
.
Extract
(
Sequence
<
1
,
2
,
3
>
{});
.
Extract
(
Sequence
<
1
,
2
,
3
>
{});
#if FORW
constexpr
auto
in_win_lengths_new
=
Sequence
<
in_c_y_x_global_desc
.
GetLength
(
I0
),
in_c_y_x_global_desc
.
GetLength
(
I1
),
in_c_y_x_global_desc
.
GetLength
(
I2
)
>
{};
constexpr
auto
in_win_strides_new
=
Sequence
<
in_c_y_x_global_desc
.
GetStride
(
I0
),
in_c_y_x_global_desc
.
GetStride
(
I1
)
*
Dilations
{}.
Get
(
I0
),
in_c_y_x_global_desc
.
GetStride
(
I2
)
*
Dilations
{}.
Get
(
I1
)
>
{};
constexpr
auto
in_c_y_x_new_global_desc
=
make_ConstantTensorDescriptor
(
in_win_lengths_new
,
in_win_strides_new
);
#else
constexpr
auto
in_c_y_x_new_global_desc
=
in_c_y_x_global_desc
;
#endif
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr
auto
in_e_n1_b_n2_global_merged_desc
=
make_ConstantMergedTensorDescriptor
(
constexpr
auto
in_e_n1_b_n2_global_merged_desc
=
make_ConstantMergedTensorDescriptor
(
in_c_y_x_global_desc
.
Embed
(
in_n0_n1_n2_h_w_new_global_desc
),
in_c_y_x_
new_
global_desc
.
Embed
(
in_n0_n1_n2_h_w_new_global_desc
),
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
4
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
6
,
7
>
{},
Sequence
<
3
,
6
,
7
>
{},
...
@@ -174,8 +211,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
...
@@ -174,8 +211,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
InBlockCopyDstDataPerWrite_N2
>
(
InBlockCopyDstDataPerWrite_N2
>
(
{
0
,
0
,
b_block_data_on_global
,
0
},
{
0
,
0
,
0
,
0
});
{
0
,
0
,
b_block_data_on_global
,
0
},
{
0
,
0
,
0
,
0
});
// weight tensor
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
// tensor descriptor in device memory, src of blockwise copy
constexpr
auto
wei_e_k_global_desc
=
wei_k_c_1_1_global_desc
.
Unfold
(
I1
,
I3
);
constexpr
auto
wei_e_k_global_desc
=
wei_k_c_1_1_global_desc
.
Unfold
(
I1
,
I3
);
// tensor descriptor in LDS, dst of blockwise copy
// tensor descriptor in LDS, dst of blockwise copy
...
@@ -396,6 +433,10 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
...
@@ -396,6 +433,10 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
out_n_k_h_w_global_desc
.
Fold
(
I1
,
Number
<
K1
>
{},
Number
<
K2
>
{})
out_n_k_h_w_global_desc
.
Fold
(
I1
,
Number
<
K1
>
{},
Number
<
K2
>
{})
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{});
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{});
#if FORW
constexpr
auto
out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc
=
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc
;
#else
constexpr
auto
out_lengths_new
=
Sequence
<
constexpr
auto
out_lengths_new
=
Sequence
<
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc
.
GetLength
(
I0
),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc
.
GetLength
(
I0
),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc
.
GetLength
(
I1
),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc
.
GetLength
(
I1
),
...
@@ -420,6 +461,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
...
@@ -420,6 +461,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr
auto
out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc
=
constexpr
auto
out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc
=
make_ConstantTensorDescriptor
(
out_lengths_new
,
out_strides_new
);
make_ConstantTensorDescriptor
(
out_lengths_new
,
out_strides_new
);
#endif
// calculate origin of thread output tensor on global memory
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
// blockwise GEMM c matrix starting index
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment