Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
08c00140
Commit
08c00140
authored
Apr 13, 2021
by
Jing Zhang
Browse files
int8
parent
e273d4d3
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
24 deletions
+25
-24
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp
...tion_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp
+4
-4
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
...nel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
+10
-10
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
+8
-7
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+3
-3
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp
View file @
08c00140
...
@@ -47,7 +47,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -47,7 +47,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const
InRightPads
&
in_right_pads
,
const
InRightPads
&
in_right_pads
,
const
FloatAB
*
__restrict__
p_wei_global
,
const
FloatAB
*
__restrict__
p_wei_global
,
const
FloatAB
*
__restrict__
p_in_global
,
const
FloatAB
*
__restrict__
p_in_global
,
Float
AB
*
__restrict__
p_d_global
,
Float
C
*
__restrict__
p_d_global
,
FloatC
*
__restrict__
p_out_global
)
const
FloatC
*
__restrict__
p_out_global
)
const
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -151,12 +151,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -151,12 +151,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
// add tensor
// add tensor
const
auto
add_k_n_hopx2_wopx2_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
add_k_n_hopx2_wopx2_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K0
,
Hox2
,
Wox2
,
1
)),
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K0
,
Hox2
,
Wox2
)),
make_tuple
(
make_
merge_transform
(
make_tuple
(
K0
,
1
)
),
make_tuple
(
make_
pass_through_transform
(
K0
),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
make_pad_transform
(
Hox2
,
0
,
AddRightPadH
),
make_pad_transform
(
Hox2
,
0
,
AddRightPadH
),
make_pad_transform
(
Wox2
,
0
,
AddRightPadW
)),
make_pad_transform
(
Wox2
,
0
,
AddRightPadW
)),
make_tuple
(
Sequence
<
1
,
4
>
{},
Sequence
<
0
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
E
=
C
*
Y
*
X
;
const
auto
E
=
C
*
Y
*
X
;
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
View file @
08c00140
...
@@ -382,12 +382,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -382,12 +382,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
FloatAB
p_d_thread
[
d_k_n_hox2_wox2_thread_desc
.
GetElementSpaceSize
()];
FloatAB
p_d_thread
[
d_k_n_hox2_wox2_thread_desc
.
GetElementSpaceSize
()];
constexpr
auto
vector_len
=
sizeof
(
FloatAB
)
/
sizeof
(
FloatC
)
;
constexpr
auto
vector_len
=
CThreadTransferDstScalarPerVector
;
static_assert
(
vector_len
==
CThreadTransferDstScalarPerVector
);
static_assert
(
vector_len
==
16
);
constexpr
auto
c_k_n_ho_wo_global_tensor_iterator_hacks
=
CGlobalIteratorHacks
{};
constexpr
auto
c_k_n_ho_wo_global_tensor_iterator_hacks
=
CGlobalIteratorHacks
{};
#if
0
#if
1
ThreadwiseDynamicTensorSliceTransfer_v2
<
ThreadwiseDynamicTensorSliceTransfer_v2
<
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
...
@@ -423,15 +423,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -423,15 +423,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
{
{
for
(
index_t
w_i
=
0
;
w_i
<
WoPerThreadx2
;
++
w_i
)
for
(
index_t
w_i
=
0
;
w_i
<
WoPerThreadx2
;
++
w_i
)
{
{
vector_type
<
FloatC
,
vector_len
>
d_vec
;
vector_type
<
int8_t
,
vector_len
>
d_vec
;
d_vec
.
Vector
()
=
p_d_thread
[
d_k_n_hox2_wox2_thread_desc
.
CalculateOffset
(
d_vec
.
Vector
()
=
p_d_thread
[
d_k_n_hox2_wox2_thread_desc
.
CalculateOffset
(
make_tuple
(
k_i
,
0
,
h_i
,
w_i
))];
make_tuple
(
k_i
,
0
,
h_i
,
w_i
))];
static_for
<
0
,
vector_len
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
vector_len
,
1
>
{}([
&
](
auto
i
)
{
d_vec
.
Scalars
()(
i
)
=
0
;
d_vec
.
Scalars
()(
i
)
+
=
1
;
//p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset(
//
p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset(
//make_tuple(k_i * vector_len + i, 0, h_i / 2, w_i / 2))];
//
make_tuple(k_i * vector_len + i, 0, h_i / 2, w_i / 2))];
});
});
p_d_thread
[
d_k_n_hox2_wox2_thread_desc
.
CalculateOffset
(
p_d_thread
[
d_k_n_hox2_wox2_thread_desc
.
CalculateOffset
(
...
@@ -465,7 +465,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
...
@@ -465,7 +465,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
p_d_thread
,
p_d_thread
,
d_k_n_hox2_wox2_global_desc
,
d_k_n_hox2_wox2_global_desc
,
p_
d
_global
,
p_
c
_global
,
c_k_n_ho_wo_global_tensor_iterator_hacks
);
c_k_n_ho_wo_global_tensor_iterator_hacks
);
#endif
#endif
}
}
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
View file @
08c00140
...
@@ -91,8 +91,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -91,8 +91,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
K
,
C0
,
Y
,
X
));
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
K
,
C0
,
Y
,
X
));
const
auto
out_n_k0_ho_wo_k1_desc
=
const
auto
out_n_k0_ho_wo_k1_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
));
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
));
const
auto
add_n_k0_hox2_wox2_
k1_
desc
=
const
auto
add_n_k0_hox2_wox2_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K0
,
Hox2
,
Wox2
,
1
));
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K0
,
Hox2
,
Wox2
));
const
auto
conv_strides
=
sequence_to_tuple_of_number
(
ConvStrides
{});
const
auto
conv_strides
=
sequence_to_tuple_of_number
(
ConvStrides
{});
const
auto
conv_dilations
=
sequence_to_tuple_of_number
(
ConvDilations
{});
const
auto
conv_dilations
=
sequence_to_tuple_of_number
(
ConvDilations
{});
...
@@ -156,7 +156,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -156,7 +156,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr
index_t
CThreadTransferDstScalarPerVector_W
=
K1
;
constexpr
index_t
CThreadTransferDstScalarPerVector_W
=
K1
;
static_assert
(
KPerThread
%
CThreadTransferDstScalarPerVector_W
==
0
,
""
);
//
static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
#else
#else
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
BlockSize
=
64
;
...
@@ -192,7 +192,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -192,7 +192,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
<
BlockSize
,
<
BlockSize
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
TAcc
,
TAcc
,
TOut
,
typename
vector_type
<
TOut
,
InWeiVectorSize
>::
type
,
KPerBlock
,
KPerBlock
,
HoPerBlock
,
HoPerBlock
,
WoPerBlock
,
WoPerBlock
,
...
@@ -210,7 +210,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -210,7 +210,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
conv_driver
.
Run
(
wei_k_c0_y_x_desc
,
conv_driver
.
Run
(
wei_k_c0_y_x_desc
,
in_n_c0_hi_wi_desc
,
in_n_c0_hi_wi_desc
,
add_n_k0_hox2_wox2_
k1_
desc
,
add_n_k0_hox2_wox2_desc
,
out_n_k0_ho_wo_k1_desc
,
out_n_k0_ho_wo_k1_desc
,
conv_strides
,
conv_strides
,
conv_dilations
,
conv_dilations
,
...
@@ -220,9 +220,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -220,9 +220,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
typename
vector_type
<
T
InWei
,
InWeiVectorSize
>::
type
*>
(
static_cast
<
typename
vector_type
<
T
Out
,
InWeiVectorSize
>::
type
*>
(
add_n_k_hox2_wox2_device_buf
.
GetDeviceBuffer
()),
add_n_k_hox2_wox2_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_hox2_wox2_device_buf
.
GetDeviceBuffer
()));
static_cast
<
typename
vector_type
<
TOut
,
InWeiVectorSize
>::
type
*>
(
out_n_k_hox2_wox2_device_buf
.
GetDeviceBuffer
()));
out_n_k_hox2_wox2_device_buf
.
FromDevice
(
out_n_k0_hox2_wox2_k1
.
mData
.
data
());
out_n_k_hox2_wox2_device_buf
.
FromDevice
(
out_n_k0_hox2_wox2_k1
.
mData
.
data
());
...
...
driver/src/conv_driver.cpp
View file @
08c00140
...
@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
...
@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif
1
#elif
0
constexpr
index_t
N
=
1
;
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
4
;
constexpr
index_t
C
=
4
;
constexpr
index_t
HI
=
64
;
constexpr
index_t
HI
=
64
;
...
@@ -637,7 +637,7 @@ int main(int argc, char* argv[])
...
@@ -637,7 +637,7 @@ int main(int argc, char* argv[])
print_array
(
"ConvStrides"
,
to_multi_index
(
ConvStrides
{}));
print_array
(
"ConvStrides"
,
to_multi_index
(
ConvStrides
{}));
print_array
(
"ConvDilations"
,
to_multi_index
(
ConvDilations
{}));
print_array
(
"ConvDilations"
,
to_multi_index
(
ConvDilations
{}));
#if
1
#if
0
using in_data_t = float;
using in_data_t = float;
constexpr index_t in_vector_size = 1;
constexpr index_t in_vector_size = 1;
using acc_data_t = float;
using acc_data_t = float;
...
@@ -654,7 +654,7 @@ int main(int argc, char* argv[])
...
@@ -654,7 +654,7 @@ int main(int argc, char* argv[])
using
out_data_t
=
int8_t
;
using
out_data_t
=
int8_t
;
#elif 1
#elif 1
using
in_data_t
=
int8_t
;
using
in_data_t
=
int8_t
;
constexpr
index_t
in_vector_size
=
4
;
constexpr
index_t
in_vector_size
=
16
;
using
acc_data_t
=
int32_t
;
using
acc_data_t
=
int32_t
;
using
out_data_t
=
int8_t
;
using
out_data_t
=
int8_t
;
#endif
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment