Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5742d293
You need to sign in or sign up before continuing.
Commit
5742d293
authored
Jun 20, 2022
by
carlushuang
Browse files
add another type of direct
parent
974348d6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
230 additions
and
14 deletions
+230
-14
include/ck/tensor_operation/cpu/grid/gridwise_direct_conv_avx2.hpp
...k/tensor_operation/cpu/grid/gridwise_direct_conv_avx2.hpp
+228
-13
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_direct_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
...vice_conv2d_direct_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
+2
-1
No files found.
include/ck/tensor_operation/cpu/grid/gridwise_direct_conv_avx2.hpp
View file @
5742d293
...
@@ -448,8 +448,8 @@ struct GridwiseDirectConvNHWCAvx2
...
@@ -448,8 +448,8 @@ struct GridwiseDirectConvNHWCAvx2
distribute_num_threads_nho_wo_k
(
num_threads_nho
,
num_threads_wo
,
num_threads_k
);
distribute_num_threads_nho_wo_k
(
num_threads_nho
,
num_threads_wo
,
num_threads_k
);
//
const ck::index_t num_works_nho_per_thread =
math::integer_divide_ceil( num_works_nho,
const
ck
::
index_t
num_works_nho_per_thread
=
//
num_threads_nho);
math
::
integer_divide_ceil
(
num_works_nho
,
num_threads_nho
);
const
ck
::
index_t
num_works_wo_per_thread
=
const
ck
::
index_t
num_works_wo_per_thread
=
math
::
integer_divide_ceil
(
num_works_wo
,
num_threads_wo
);
math
::
integer_divide_ceil
(
num_works_wo
,
num_threads_wo
);
const
ck
::
index_t
num_works_k_per_thread
=
const
ck
::
index_t
num_works_k_per_thread
=
...
@@ -463,16 +463,220 @@ struct GridwiseDirectConvNHWCAvx2
...
@@ -463,16 +463,220 @@ struct GridwiseDirectConvNHWCAvx2
if
(
dynamic_tunable
.
loop_over_spec
==
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MNK
)
LoopOver_MNK
)
{}
{
// only parallel in gemm m dim
#pragma omp parallel
{
DeviceAlignedMemCPU
a_block_mem
(
UseALocalBuffer
?
m_per_thread
*
k_per_thread
*
sizeof
(
FloatA
)
:
0
,
MemAlignmentByte
);
DeviceAlignedMemCPU
b_block_mem
(
UseBLocalBuffer
?
k_per_thread
*
n_per_thread
*
sizeof
(
FloatB
)
:
0
,
MemAlignmentByte
);
DeviceAlignedMemCPU
c_block_mem
(
UseCLocalBuffer
?
(
m_per_thread
*
n_per_thread
*
sizeof
(
FloatC
))
:
0
,
MemAlignmentByte
);
auto
a_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
UseALocalBuffer
?
reinterpret_cast
<
FloatA
*>
(
a_block_mem
.
mpDeviceBuf
)
:
const_cast
<
FloatA
*>
(
p_a_grid
),
UseALocalBuffer
?
a_block_mem
.
mMemSize
/
sizeof
(
FloatA
)
:
a_grid_desc
.
GetElementSpaceSize
());
auto
b_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
UseBLocalBuffer
?
reinterpret_cast
<
FloatB
*>
(
b_block_mem
.
mpDeviceBuf
)
:
const_cast
<
FloatB
*>
(
p_b_grid
),
UseBLocalBuffer
?
b_block_mem
.
mMemSize
/
sizeof
(
FloatB
)
:
b_grid_desc
.
GetElementSpaceSize
());
auto
c_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
UseCLocalBuffer
?
reinterpret_cast
<
FloatC
*>
(
c_block_mem
.
mpDeviceBuf
)
:
reinterpret_cast
<
FloatC
*>
(
p_c_grid
),
UseCLocalBuffer
?
c_block_mem
.
mMemSize
/
sizeof
(
FloatC
)
:
c_grid_desc
.
GetElementSpaceSize
());
const
ck
::
index_t
tid
=
omp_get_thread_num
();
const
ck
::
index_t
tid_k
=
tid
%
num_threads_k
;
const
ck
::
index_t
tid_wo
=
(
tid
/
num_threads_k
)
%
num_threads_wo
;
const
ck
::
index_t
tid_nho
=
tid
/
(
num_threads_k
*
num_threads_wo
);
ck
::
cpu
::
ThreadwiseGemmParam
param
;
// param.Kr = k_per_block;
param
.
lda
=
Sx
*
C
*
sizeof
(
FloatA
);
param
.
ldb
=
GetBLeadingElement
(
b_grid_desc
)
*
sizeof
(
FloatB
);
param
.
ldc
=
GetCLeadingElement
(
c_grid_desc
)
*
sizeof
(
FloatC
);
param
.
alpha
=
1.0
f
;
// TODO
param
.
Kr
=
C
;
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
ck
::
index_t
i_nho
=
tid_nho
*
num_works_nho_per_thread
;
ck
::
index_t
i_ho
=
i_nho
%
Ho
;
ck
::
index_t
i_n
=
i_nho
/
Ho
;
auto
accumulate_n_ho
=
[
&
]()
{
i_ho
++
;
if
(
i_ho
>=
Wo
)
{
i_ho
=
0
;
i_n
++
;
}
};
for
(;
i_nho
<
(
tid_nho
+
1
)
*
num_works_nho_per_thread
;
i_nho
+=
1
,
accumulate_n_ho
())
{
// for input
ck
::
index_t
i_hi_no_y
=
i_ho
*
Sy
-
Py
;
for
(
ck
::
index_t
i_wo
=
tid_wo
*
num_works_wo_per_thread
*
m_per_thread
;
i_wo
<
(
tid_wo
+
1
)
*
num_works_wo_per_thread
*
m_per_thread
;
i_wo
+=
m_per_thread
)
{
ck
::
index_t
current_wo_size_no_dx
=
ck
::
math
::
min
(
Wo
-
i_wo
,
m_per_thread
);
ck
::
index_t
i_wi_no_x
=
i_wo
*
Sx
-
Px
;
// printf("-- i_nho:%d, i_wo:%d, num_works_nho:%d,
// num_threads_nho:%d(Hi:%d,nWi:%d)\n",
// i_nho, i_wo, num_works_nho, num_threads_nho, Hi, Wi);fflush(stdout);
for
(
ck
::
index_t
i_k
=
tid_k
*
num_works_k_per_thread
*
n_per_thread
;
i_k
<
(
tid_k
+
1
)
*
num_works_k_per_thread
*
n_per_thread
;
i_k
+=
n_per_thread
)
{
ck
::
index_t
i_dx
=
0
;
ck
::
index_t
i_dy
=
0
;
bool
accmulate_c
=
false
;
ck
::
index_t
current_k_size
=
ck
::
math
::
min
(
K
-
i_k
,
n_per_thread
);
auto
accumulate_dy_dx
=
[
&
]()
{
i_dx
+=
Dx
;
if
(
i_dx
>=
X_Dx
)
{
i_dx
=
0
;
i_dy
+=
Dy
;
}
};
for
(
ck
::
index_t
i_yxc
=
0
;
i_yxc
<
(
Y
*
X
*
C
);
i_yxc
+=
C
,
accumulate_dy_dx
())
{
ck
::
index_t
current_i_wo
=
i_wo
;
ck
::
index_t
i_hi
=
i_hi_no_y
+
i_dy
;
if
(
i_hi
<
0
||
i_hi
>=
Hi
)
continue
;
ck
::
index_t
i_wi
=
i_wi_no_x
+
i_dx
;
ck
::
index_t
current_wo_size
=
current_wo_size_no_dx
;
ck
::
index_t
pad_wo_size
=
0
;
// when left pad, we may never have a
// chance to clear zero (like
// padding) we need to manually clear that
/* left corner shift
* when i_wi is negative, need shift i_wo to right to make i_wi
* possitive sx px i_wi steps_wo_turn_possitive 1 0
* 0, 1, 2.... 0 2 0 0, 2, 4... 0 1 1 -1,
* 0, 1.... 1 2 1 -1, 1, 3.... 1 2 2 -2, 0, 2... 1 2
* 3 -3, -1, 1... 2 3 1 -1, 2, 5... 1 3 2 -2,
* 1, 4.... 1 3 3 -3, 0, 3 1 3 4 -4,
* -1, 2... 2
*/
if
(
i_wi
<
0
)
{
ck
::
index_t
wi_to_zero_length
=
-
i_wi
;
// keep this a possitive number
ck
::
index_t
steps_wo_turn_possitive
=
(
wi_to_zero_length
+
Sx
-
1
)
/
Sx
;
// how many steps need to move wo, to let wi to be
// possitive
current_wo_size
-=
steps_wo_turn_possitive
;
if
(
current_wo_size
<=
0
)
continue
;
current_i_wo
+=
steps_wo_turn_possitive
;
if
(
!
accmulate_c
)
pad_wo_size
=
steps_wo_turn_possitive
;
// if already accumulating, no
// need to manually set
i_wi
+=
steps_wo_turn_possitive
*
Sx
;
// now i_wi will be a possitive number
}
// shrink right wi/wo
if
((
i_wi
+
((
current_wo_size
-
1
)
*
Sx
))
>=
Wi
)
{
// printf(" ->[r] i_wi:%d, r:%d(%d), ", i_wi, i_wi +
// ((current_wo_size - 1) * Sx), current_wo_size);
current_wo_size
=
(
Wi
-
1
-
i_wi
)
/
Sx
+
1
;
// NOTE: this be careful why here
// should be compute like this.
if
(
current_wo_size
<=
0
)
continue
;
}
param
.
accmulate_c
=
accmulate_c
?
1
:
0
;
accmulate_c
=
true
;
intptr_t
current_input_offset
=
i_n
*
Hi
*
Wi
*
C
+
i_hi
*
Wi
*
C
+
i_wi
*
C
;
if
(
pad_wo_size
!=
0
)
{
// manually clear zero. this may and only may need once along
// the gemm_k reduction
// ck::index_t i_k = tid_k * num_works_k_per_thread *
// n_per_thread; ck::index_t current_k_block_size =
// ck::math::min(K - i_k, num_works_k_per_thread *
// n_per_thread);
const
intptr_t
offset_c
=
GetCBlockStartOffset
(
c_grid_desc
,
i_nho
*
Wo
,
i_k
);
// printf("pad_wo_size:%d, current_k_block_size:%d, clear
// offset_c:%d\n",
// pad_wo_size, current_k_size * pad_wo_size,
// offset_c);fflush(stdout);
ck
::
cpu
::
avx2_util
::
memset32_avx2
(
&
c_block_buf
.
p_data_
[
offset_c
],
0
,
current_k_size
*
pad_wo_size
);
}
const
intptr_t
offset_a
=
current_input_offset
;
const
intptr_t
offset_b
=
GetBBlockStartOffset
(
b_grid_desc
,
i_yxc
,
i_k
);
const
intptr_t
offset_c
=
GetCBlockStartOffset
(
c_grid_desc
,
i_nho
*
Wo
+
current_i_wo
,
i_k
);
// printf("offset_a:%lu, offset_b:%lu, offset_c:%lu, i_n:%d,
// i_hi:%d, i_wi:%d, i_dx:%d, i_dy:%d, i_k:%d, i_ho:%d, i_wo:%d,
// current_wo_size:%d, current_k_size:%d, i_nho:%d, lda:%d, ldb:%d,
// ldc:%d, acc:%d",
// offset_a, offset_b, offset_c, i_n, i_hi, i_wi, i_dx, i_dy,
// i_k, i_ho, current_i_wo, current_wo_size, current_k_size,
// i_nho, param.lda / sizeof(FloatA), param.ldb /
// sizeof(FloatB), param.ldc / sizeof(FloatC),
// param.accmulate_c); fflush(stdout);
param
.
p_a
=
&
a_block_buf
.
p_data_
[
offset_a
];
param
.
p_b
=
&
b_block_buf
.
p_data_
[
offset_b
];
param
.
p_c
=
&
c_block_buf
.
p_data_
[
offset_c
];
ThreadwiseGemm_Dispatch
::
Run
(
&
param
,
current_wo_size
,
current_k_size
);
// printf(" ------ \n");fflush(stdout);
}
}
}
}
}
}
else
if
(
dynamic_tunable
.
loop_over_spec
==
else
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MKN
)
LoopOver_MKN
)
{
{
// always parallel on N*Ho. single thread will deal with whole Wo. Hence we split this
// problem
auto
a_move_k_step
=
GetAIndex
(
0
,
k_per_thread
);
auto
b_move_k_step
=
GetBIndex
(
0
,
n_per_thread
);
// only parallel in gemm m dim
// only parallel in gemm m dim
#pragma omp parallel
#pragma omp parallel
{
{
...
@@ -519,17 +723,28 @@ struct GridwiseDirectConvNHWCAvx2
...
@@ -519,17 +723,28 @@ struct GridwiseDirectConvNHWCAvx2
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
ck
::
index_t
i_nho
=
tid_nho
*
num_works_nho_per_thread
;
ck
::
index_t
i_ho
=
i_nho
%
Ho
;
ck
::
index_t
i_n
=
i_nho
/
Ho
;
auto
accumulate_n_ho
=
[
&
]()
{
i_ho
++
;
if
(
i_ho
>=
Wo
)
{
i_ho
=
0
;
i_n
++
;
}
};
for
(
ck
::
index_t
i_nho
=
tid_nho
;
i_nho
<
num_works_nho
;
i_nho
+=
num_threads_nho
)
for
(;
i_nho
<
(
tid_nho
+
1
)
*
num_works_nho_per_thread
;
i_nho
+=
1
,
accumulate_n_ho
())
{
{
// for input
// for input
ck
::
index_t
i_ho
=
i_nho
%
Ho
;
ck
::
index_t
i_n
=
i_nho
/
Ho
;
ck
::
index_t
i_hi_no_y
=
i_ho
*
Sy
-
Py
;
ck
::
index_t
i_hi_no_y
=
i_ho
*
Sy
-
Py
;
for
(
ck
::
index_t
i_wo
=
tid_wo
*
num_works_wo_per_thread
*
m_per_thread
;
for
(
ck
::
index_t
i_wo
=
tid_wo
*
num_works_wo_per_thread
*
m_per_thread
;
i_wo
<
Wo
;
i_wo
<
(
tid_wo
+
1
)
*
num_works_wo_per_thread
*
m_per_thread
;
i_wo
+=
m_per_thread
)
i_wo
+=
m_per_thread
)
{
{
ck
::
index_t
current_wo_size_no_dx
=
ck
::
math
::
min
(
Wo
-
i_wo
,
m_per_thread
);
ck
::
index_t
current_wo_size_no_dx
=
ck
::
math
::
min
(
Wo
-
i_wo
,
m_per_thread
);
...
@@ -631,7 +846,7 @@ struct GridwiseDirectConvNHWCAvx2
...
@@ -631,7 +846,7 @@ struct GridwiseDirectConvNHWCAvx2
ck
::
cpu
::
avx2_util
::
memset32_avx2
(
&
c_block_buf
.
p_data_
[
offset_c
],
ck
::
cpu
::
avx2_util
::
memset32_avx2
(
&
c_block_buf
.
p_data_
[
offset_c
],
0
,
0
,
current_k_block_size
*
current_k_block_size
*
pad_wo_size
*
sizeof
(
FloatC
)
);
pad_wo_size
);
}
}
for
(
ck
::
index_t
i_k
=
tid_k
*
num_works_k_per_thread
*
n_per_thread
;
for
(
ck
::
index_t
i_k
=
tid_k
*
num_works_k_per_thread
*
n_per_thread
;
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_direct_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
View file @
5742d293
...
@@ -62,7 +62,8 @@ void add_device_conv2d_direct_fwd_avx2_nhwc_kyxck8_nhwk(
...
@@ -62,7 +62,8 @@ void add_device_conv2d_direct_fwd_avx2_nhwc_kyxck8_nhwk(
instances
,
instances
,
std
::
make_tuple
(
std
::
make_tuple
(
// clang-format off
// clang-format off
DeviceConvNDDirectFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
PT
,
PT
,
PT
,
ConvFwdDefault
,
2
,
4
,
24
,
false
,
false
,
false
>
({
0
,
0
,
0
,
DefaultGemmKLoop
,
LoopOver_MKN
})
DeviceConvNDDirectFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
PT
,
PT
,
PT
,
ConvFwdDefault
,
2
,
6
,
16
,
false
,
false
,
false
>
({
0
,
0
,
0
,
DefaultGemmKLoop
,
LoopOver_MKN
}),
DeviceConvNDDirectFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
PT
,
PT
,
PT
,
ConvFwdDefault
,
2
,
6
,
16
,
false
,
false
,
false
>
({
0
,
0
,
0
,
DefaultGemmKLoop
,
LoopOver_MNK
})
// clang-format on
// clang-format on
));
));
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment