Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5742d293
Commit
5742d293
authored
Jun 20, 2022
by
carlushuang
Browse files
add another type of direct
parent
974348d6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
230 additions
and
14 deletions
+230
-14
include/ck/tensor_operation/cpu/grid/gridwise_direct_conv_avx2.hpp
...k/tensor_operation/cpu/grid/gridwise_direct_conv_avx2.hpp
+228
-13
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_direct_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
...vice_conv2d_direct_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
+2
-1
No files found.
include/ck/tensor_operation/cpu/grid/gridwise_direct_conv_avx2.hpp
View file @
5742d293
...
...
@@ -448,8 +448,8 @@ struct GridwiseDirectConvNHWCAvx2
distribute_num_threads_nho_wo_k
(
num_threads_nho
,
num_threads_wo
,
num_threads_k
);
//
const ck::index_t num_works_nho_per_thread =
math::integer_divide_ceil( num_works_nho,
//
num_threads_nho);
const
ck
::
index_t
num_works_nho_per_thread
=
math
::
integer_divide_ceil
(
num_works_nho
,
num_threads_nho
);
const
ck
::
index_t
num_works_wo_per_thread
=
math
::
integer_divide_ceil
(
num_works_wo
,
num_threads_wo
);
const
ck
::
index_t
num_works_k_per_thread
=
...
...
@@ -463,16 +463,220 @@ struct GridwiseDirectConvNHWCAvx2
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MNK
)
{}
{
// only parallel in gemm m dim
#pragma omp parallel
{
DeviceAlignedMemCPU
a_block_mem
(
UseALocalBuffer
?
m_per_thread
*
k_per_thread
*
sizeof
(
FloatA
)
:
0
,
MemAlignmentByte
);
DeviceAlignedMemCPU
b_block_mem
(
UseBLocalBuffer
?
k_per_thread
*
n_per_thread
*
sizeof
(
FloatB
)
:
0
,
MemAlignmentByte
);
DeviceAlignedMemCPU
c_block_mem
(
UseCLocalBuffer
?
(
m_per_thread
*
n_per_thread
*
sizeof
(
FloatC
))
:
0
,
MemAlignmentByte
);
auto
a_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
UseALocalBuffer
?
reinterpret_cast
<
FloatA
*>
(
a_block_mem
.
mpDeviceBuf
)
:
const_cast
<
FloatA
*>
(
p_a_grid
),
UseALocalBuffer
?
a_block_mem
.
mMemSize
/
sizeof
(
FloatA
)
:
a_grid_desc
.
GetElementSpaceSize
());
auto
b_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
UseBLocalBuffer
?
reinterpret_cast
<
FloatB
*>
(
b_block_mem
.
mpDeviceBuf
)
:
const_cast
<
FloatB
*>
(
p_b_grid
),
UseBLocalBuffer
?
b_block_mem
.
mMemSize
/
sizeof
(
FloatB
)
:
b_grid_desc
.
GetElementSpaceSize
());
auto
c_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
UseCLocalBuffer
?
reinterpret_cast
<
FloatC
*>
(
c_block_mem
.
mpDeviceBuf
)
:
reinterpret_cast
<
FloatC
*>
(
p_c_grid
),
UseCLocalBuffer
?
c_block_mem
.
mMemSize
/
sizeof
(
FloatC
)
:
c_grid_desc
.
GetElementSpaceSize
());
const
ck
::
index_t
tid
=
omp_get_thread_num
();
const
ck
::
index_t
tid_k
=
tid
%
num_threads_k
;
const
ck
::
index_t
tid_wo
=
(
tid
/
num_threads_k
)
%
num_threads_wo
;
const
ck
::
index_t
tid_nho
=
tid
/
(
num_threads_k
*
num_threads_wo
);
ck
::
cpu
::
ThreadwiseGemmParam
param
;
// param.Kr = k_per_block;
param
.
lda
=
Sx
*
C
*
sizeof
(
FloatA
);
param
.
ldb
=
GetBLeadingElement
(
b_grid_desc
)
*
sizeof
(
FloatB
);
param
.
ldc
=
GetCLeadingElement
(
c_grid_desc
)
*
sizeof
(
FloatC
);
param
.
alpha
=
1.0
f
;
// TODO
param
.
Kr
=
C
;
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
ck
::
index_t
i_nho
=
tid_nho
*
num_works_nho_per_thread
;
ck
::
index_t
i_ho
=
i_nho
%
Ho
;
ck
::
index_t
i_n
=
i_nho
/
Ho
;
auto
accumulate_n_ho
=
[
&
]()
{
i_ho
++
;
if
(
i_ho
>=
Wo
)
{
i_ho
=
0
;
i_n
++
;
}
};
for
(;
i_nho
<
(
tid_nho
+
1
)
*
num_works_nho_per_thread
;
i_nho
+=
1
,
accumulate_n_ho
())
{
// for input
ck
::
index_t
i_hi_no_y
=
i_ho
*
Sy
-
Py
;
for
(
ck
::
index_t
i_wo
=
tid_wo
*
num_works_wo_per_thread
*
m_per_thread
;
i_wo
<
(
tid_wo
+
1
)
*
num_works_wo_per_thread
*
m_per_thread
;
i_wo
+=
m_per_thread
)
{
ck
::
index_t
current_wo_size_no_dx
=
ck
::
math
::
min
(
Wo
-
i_wo
,
m_per_thread
);
ck
::
index_t
i_wi_no_x
=
i_wo
*
Sx
-
Px
;
// printf("-- i_nho:%d, i_wo:%d, num_works_nho:%d,
// num_threads_nho:%d(Hi:%d,nWi:%d)\n",
// i_nho, i_wo, num_works_nho, num_threads_nho, Hi, Wi);fflush(stdout);
for
(
ck
::
index_t
i_k
=
tid_k
*
num_works_k_per_thread
*
n_per_thread
;
i_k
<
(
tid_k
+
1
)
*
num_works_k_per_thread
*
n_per_thread
;
i_k
+=
n_per_thread
)
{
ck
::
index_t
i_dx
=
0
;
ck
::
index_t
i_dy
=
0
;
bool
accmulate_c
=
false
;
ck
::
index_t
current_k_size
=
ck
::
math
::
min
(
K
-
i_k
,
n_per_thread
);
auto
accumulate_dy_dx
=
[
&
]()
{
i_dx
+=
Dx
;
if
(
i_dx
>=
X_Dx
)
{
i_dx
=
0
;
i_dy
+=
Dy
;
}
};
for
(
ck
::
index_t
i_yxc
=
0
;
i_yxc
<
(
Y
*
X
*
C
);
i_yxc
+=
C
,
accumulate_dy_dx
())
{
ck
::
index_t
current_i_wo
=
i_wo
;
ck
::
index_t
i_hi
=
i_hi_no_y
+
i_dy
;
if
(
i_hi
<
0
||
i_hi
>=
Hi
)
continue
;
ck
::
index_t
i_wi
=
i_wi_no_x
+
i_dx
;
ck
::
index_t
current_wo_size
=
current_wo_size_no_dx
;
ck
::
index_t
pad_wo_size
=
0
;
// when left pad, we may never have a
// chance to clear zero (like
// padding) we need to manually clear that
/* left corner shift
* when i_wi is negative, need shift i_wo to right to make i_wi
* possitive sx px i_wi steps_wo_turn_possitive 1 0
* 0, 1, 2.... 0 2 0 0, 2, 4... 0 1 1 -1,
* 0, 1.... 1 2 1 -1, 1, 3.... 1 2 2 -2, 0, 2... 1 2
* 3 -3, -1, 1... 2 3 1 -1, 2, 5... 1 3 2 -2,
* 1, 4.... 1 3 3 -3, 0, 3 1 3 4 -4,
* -1, 2... 2
*/
if
(
i_wi
<
0
)
{
ck
::
index_t
wi_to_zero_length
=
-
i_wi
;
// keep this a possitive number
ck
::
index_t
steps_wo_turn_possitive
=
(
wi_to_zero_length
+
Sx
-
1
)
/
Sx
;
// how many steps need to move wo, to let wi to be
// possitive
current_wo_size
-=
steps_wo_turn_possitive
;
if
(
current_wo_size
<=
0
)
continue
;
current_i_wo
+=
steps_wo_turn_possitive
;
if
(
!
accmulate_c
)
pad_wo_size
=
steps_wo_turn_possitive
;
// if already accumulating, no
// need to manually set
i_wi
+=
steps_wo_turn_possitive
*
Sx
;
// now i_wi will be a possitive number
}
// shrink right wi/wo
if
((
i_wi
+
((
current_wo_size
-
1
)
*
Sx
))
>=
Wi
)
{
// printf(" ->[r] i_wi:%d, r:%d(%d), ", i_wi, i_wi +
// ((current_wo_size - 1) * Sx), current_wo_size);
current_wo_size
=
(
Wi
-
1
-
i_wi
)
/
Sx
+
1
;
// NOTE: this be careful why here
// should be compute like this.
if
(
current_wo_size
<=
0
)
continue
;
}
param
.
accmulate_c
=
accmulate_c
?
1
:
0
;
accmulate_c
=
true
;
intptr_t
current_input_offset
=
i_n
*
Hi
*
Wi
*
C
+
i_hi
*
Wi
*
C
+
i_wi
*
C
;
if
(
pad_wo_size
!=
0
)
{
// manually clear zero. this may and only may need once along
// the gemm_k reduction
// ck::index_t i_k = tid_k * num_works_k_per_thread *
// n_per_thread; ck::index_t current_k_block_size =
// ck::math::min(K - i_k, num_works_k_per_thread *
// n_per_thread);
const
intptr_t
offset_c
=
GetCBlockStartOffset
(
c_grid_desc
,
i_nho
*
Wo
,
i_k
);
// printf("pad_wo_size:%d, current_k_block_size:%d, clear
// offset_c:%d\n",
// pad_wo_size, current_k_size * pad_wo_size,
// offset_c);fflush(stdout);
ck
::
cpu
::
avx2_util
::
memset32_avx2
(
&
c_block_buf
.
p_data_
[
offset_c
],
0
,
current_k_size
*
pad_wo_size
);
}
const
intptr_t
offset_a
=
current_input_offset
;
const
intptr_t
offset_b
=
GetBBlockStartOffset
(
b_grid_desc
,
i_yxc
,
i_k
);
const
intptr_t
offset_c
=
GetCBlockStartOffset
(
c_grid_desc
,
i_nho
*
Wo
+
current_i_wo
,
i_k
);
// printf("offset_a:%lu, offset_b:%lu, offset_c:%lu, i_n:%d,
// i_hi:%d, i_wi:%d, i_dx:%d, i_dy:%d, i_k:%d, i_ho:%d, i_wo:%d,
// current_wo_size:%d, current_k_size:%d, i_nho:%d, lda:%d, ldb:%d,
// ldc:%d, acc:%d",
// offset_a, offset_b, offset_c, i_n, i_hi, i_wi, i_dx, i_dy,
// i_k, i_ho, current_i_wo, current_wo_size, current_k_size,
// i_nho, param.lda / sizeof(FloatA), param.ldb /
// sizeof(FloatB), param.ldc / sizeof(FloatC),
// param.accmulate_c); fflush(stdout);
param
.
p_a
=
&
a_block_buf
.
p_data_
[
offset_a
];
param
.
p_b
=
&
b_block_buf
.
p_data_
[
offset_b
];
param
.
p_c
=
&
c_block_buf
.
p_data_
[
offset_c
];
ThreadwiseGemm_Dispatch
::
Run
(
&
param
,
current_wo_size
,
current_k_size
);
// printf(" ------ \n");fflush(stdout);
}
}
}
}
}
}
else
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MKN
)
{
// always parallel on N*Ho. single thread will deal with whole Wo. Hence we split this
// problem
auto
a_move_k_step
=
GetAIndex
(
0
,
k_per_thread
);
auto
b_move_k_step
=
GetBIndex
(
0
,
n_per_thread
);
// only parallel in gemm m dim
#pragma omp parallel
{
...
...
@@ -519,17 +723,28 @@ struct GridwiseDirectConvNHWCAvx2
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
ck
::
index_t
i_nho
=
tid_nho
*
num_works_nho_per_thread
;
ck
::
index_t
i_ho
=
i_nho
%
Ho
;
ck
::
index_t
i_n
=
i_nho
/
Ho
;
auto
accumulate_n_ho
=
[
&
]()
{
i_ho
++
;
if
(
i_ho
>=
Wo
)
{
i_ho
=
0
;
i_n
++
;
}
};
for
(
ck
::
index_t
i_nho
=
tid_nho
;
i_nho
<
num_works_nho
;
i_nho
+=
num_threads_nho
)
for
(;
i_nho
<
(
tid_nho
+
1
)
*
num_works_nho_per_thread
;
i_nho
+=
1
,
accumulate_n_ho
())
{
// for input
ck
::
index_t
i_ho
=
i_nho
%
Ho
;
ck
::
index_t
i_n
=
i_nho
/
Ho
;
ck
::
index_t
i_hi_no_y
=
i_ho
*
Sy
-
Py
;
for
(
ck
::
index_t
i_wo
=
tid_wo
*
num_works_wo_per_thread
*
m_per_thread
;
i_wo
<
Wo
;
i_wo
<
(
tid_wo
+
1
)
*
num_works_wo_per_thread
*
m_per_thread
;
i_wo
+=
m_per_thread
)
{
ck
::
index_t
current_wo_size_no_dx
=
ck
::
math
::
min
(
Wo
-
i_wo
,
m_per_thread
);
...
...
@@ -631,7 +846,7 @@ struct GridwiseDirectConvNHWCAvx2
ck
::
cpu
::
avx2_util
::
memset32_avx2
(
&
c_block_buf
.
p_data_
[
offset_c
],
0
,
current_k_block_size
*
pad_wo_size
*
sizeof
(
FloatC
)
);
pad_wo_size
);
}
for
(
ck
::
index_t
i_k
=
tid_k
*
num_works_k_per_thread
*
n_per_thread
;
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_direct_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
View file @
5742d293
...
...
@@ -62,7 +62,8 @@ void add_device_conv2d_direct_fwd_avx2_nhwc_kyxck8_nhwk(
instances
,
std
::
make_tuple
(
// clang-format off
DeviceConvNDDirectFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
PT
,
PT
,
PT
,
ConvFwdDefault
,
2
,
4
,
24
,
false
,
false
,
false
>
({
0
,
0
,
0
,
DefaultGemmKLoop
,
LoopOver_MKN
})
DeviceConvNDDirectFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
PT
,
PT
,
PT
,
ConvFwdDefault
,
2
,
6
,
16
,
false
,
false
,
false
>
({
0
,
0
,
0
,
DefaultGemmKLoop
,
LoopOver_MKN
}),
DeviceConvNDDirectFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
PT
,
PT
,
PT
,
ConvFwdDefault
,
2
,
6
,
16
,
false
,
false
,
false
>
({
0
,
0
,
0
,
DefaultGemmKLoop
,
LoopOver_MNK
})
// clang-format on
));
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment