gaoqiong / composable_kernel · Commits

Commit 05d38218
Authored Jul 02, 2022 by carlushuang
Parent: e8f639d2

fix a bug in direct conv 4G size
Showing 2 changed files with 76 additions and 75 deletions:

  include/ck/tensor_operation/cpu/grid/gridwise_direct_conv_avx2.hpp  (+71 -70)
  include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp  (+5 -5)
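The fix widens 32-bit ck::index_t offsets and loop counters to pointer-width intptr_t along the AVX2 direct-convolution path, so element offsets computed from N, Ho, Wo and K stay valid once they cross the 4G boundary the commit title refers to. The sketch below is not from the commit; it uses hypothetical tensor sizes and assumes ck::index_t is a 32-bit signed integer and intptr_t is 64 bits wide, as on the x86-64 targets of this AVX2 code.

#include <cstdint>
#include <cstdio>
#include <limits>

int main()
{
    // Hypothetical NHWC output of a direct convolution: N=4, Ho=1080, Wo=1920, K=512.
    const std::intptr_t N = 4, Ho = 1080, Wo = 1920, K = 512;

    // Linear offset of the last element, computed in pointer-width arithmetic.
    const std::intptr_t last = N * Ho * Wo * K - 1; // 4'246'732'799

    // This value no longer fits in a 32-bit signed index, so offsets built with a
    // 32-bit type (as GetCBlockStartOffset and friends did before) would wrap around.
    std::printf("last offset = %lld, fits in int32_t: %s\n",
                static_cast<long long>(last),
                last <= std::numeric_limits<std::int32_t>::max() ? "yes" : "no");
    return 0;
}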
include/ck/tensor_operation/cpu/grid/gridwise_direct_conv_avx2.hpp  (view file @ 05d38218)
@@ -286,8 +286,8 @@ struct GridwiseDirectConvNHWCAvx2
         return is_valid;
     }
 
-    static ck::index_t
-    GetBBlockStartOffset(const BGridDesc& b_grid_desc, const index_t i_k, const index_t i_n)
+    static intptr_t
+    GetBBlockStartOffset(const BGridDesc& b_grid_desc, const intptr_t i_k, const intptr_t i_n)
     {
         if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
                                   ck::tensor_layout::gemm::RowMajor>::value)
@@ -303,13 +303,13 @@ struct GridwiseDirectConvNHWCAvx2
         }
     }
 
-    static ck::index_t
-    GetCBlockStartOffset(const CGridDesc& c_grid_desc, const index_t i_m, const index_t i_n)
+    static intptr_t
+    GetCBlockStartOffset(const CGridDesc& c_grid_desc, const intptr_t i_m, const intptr_t i_n)
     {
         return i_m * c_grid_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] + i_n;
     }
 
-    static ck::index_t GetBLeadingElement(const BGridDesc& b_grid_desc)
+    static intptr_t GetBLeadingElement(const BGridDesc& b_grid_desc)
     {
         if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
                                   ck::tensor_layout::gemm::RowMajor>::value)
@@ -324,7 +324,7 @@ struct GridwiseDirectConvNHWCAvx2
         }
     }
 
-    static ck::index_t GetCLeadingElement(const CGridDesc& c_grid_desc)
+    static intptr_t GetCLeadingElement(const CGridDesc& c_grid_desc)
     {
         return c_grid_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
     }
@@ -357,25 +357,25 @@ struct GridwiseDirectConvNHWCAvx2
         const auto GemmN = c_grid_desc.GetLength(I1);
         const auto GemmK = a_grid_desc.GetLength(I1);
 
-        const index_t Hi = input_spatial_lengths[0];
-        const index_t Wi = input_spatial_lengths[1];
-        const index_t Ho = output_spatial_lengths[0];
-        const index_t Wo = output_spatial_lengths[1];
-        const index_t Y  = filter_spatial_lengths[0];
-        const index_t X  = filter_spatial_lengths[1];
-        const index_t Sy = conv_filter_strides[0];
-        const index_t Sx = conv_filter_strides[1];
-        const index_t Dy = conv_filter_dilations[0];
-        const index_t Dx = conv_filter_dilations[1];
-        const index_t Py = input_left_pads[0];
-        const index_t Px = input_left_pads[1];
-        const index_t X_Dx = X * Dx;
+        const intptr_t Hi = input_spatial_lengths[0];
+        const intptr_t Wi = input_spatial_lengths[1];
+        const intptr_t Ho = output_spatial_lengths[0];
+        const intptr_t Wo = output_spatial_lengths[1];
+        const intptr_t Y  = filter_spatial_lengths[0];
+        const intptr_t X  = filter_spatial_lengths[1];
+        const intptr_t Sy = conv_filter_strides[0];
+        const intptr_t Sx = conv_filter_strides[1];
+        const intptr_t Dy = conv_filter_dilations[0];
+        const intptr_t Dx = conv_filter_dilations[1];
+        const intptr_t Py = input_left_pads[0];
+        const intptr_t Px = input_left_pads[1];
+        const intptr_t X_Dx = X * Dx;
         // const index_t Y_Dy = Y * Dy;
         // const index_t InRightPadH = input_right_pads[0];
@@ -421,11 +421,11 @@ struct GridwiseDirectConvNHWCAvx2
             }
             return t_;
         };
 
-        const ck::index_t num_works_n  = N;
-        const ck::index_t num_works_ho = Ho;
-        // const ck::index_t num_works_nho = N * Ho;
-        const ck::index_t num_works_wo = math::integer_divide_ceil(Wo, m_per_thread);
-        const ck::index_t num_works_k  = math::integer_divide_ceil(K, n_per_thread);
+        const intptr_t num_works_n  = N;
+        const intptr_t num_works_ho = Ho;
+        // const intptr_t num_works_nho = N * Ho;
+        const intptr_t num_works_wo = math::integer_divide_ceil(Wo, m_per_thread);
+        const intptr_t num_works_k  = math::integer_divide_ceil(K, n_per_thread);
 
         auto distribute_num_threads_n_ho_wo_k = [&](ck::index_t& num_threads_n_,
                                                     ck::index_t& num_threads_ho_,
@@ -545,40 +545,41 @@ struct GridwiseDirectConvNHWCAvx2
         // }
         // };
 
-        for(ck::index_t i_n = tid_n * num_works_n_per_thread;
+        for(intptr_t i_n = tid_n * num_works_n_per_thread;
             (i_n < (tid_n + 1) * num_works_n_per_thread) && i_n < num_works_n;
             i_n += 1)
         {
-            for(ck::index_t i_ho = tid_ho * num_works_ho_per_thread;
+            for(intptr_t i_ho = tid_ho * num_works_ho_per_thread;
                 (i_ho < (tid_ho + 1) * num_works_ho_per_thread) && i_ho < num_works_ho;
                 i_ho += 1)
            {
                 // for input
-                ck::index_t i_hi_no_y = i_ho * Sy - Py;
-                for(ck::index_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
+                intptr_t i_hi_no_y = i_ho * Sy - Py;
+                for(intptr_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
                     i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread &&
                         i_wo < Wo;
                     i_wo += m_per_thread)
                 {
-                    ck::index_t current_wo_size_no_dx =
-                        ck::math::min(Wo - i_wo, m_per_thread);
-                    ck::index_t i_wi_no_x = i_wo * Sx - Px;
+                    intptr_t current_wo_size_no_dx =
+                        ck::math::min(Wo - i_wo, (intptr_t)m_per_thread);
+                    intptr_t i_wi_no_x = i_wo * Sx - Px;
 
                     // printf("-- i_nho:%d, i_wo:%d, num_works_nho:%d,
                     // num_threads_nho:%d(Hi:%d,nWi:%d)\n",
                     // i_nho, i_wo, num_works_nho, num_threads_nho, Hi,
                     // Wi);fflush(stdout);
 
-                    for(ck::index_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
+                    for(intptr_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
                         i_k < (tid_k + 1) * num_works_k_per_thread * n_per_thread;
                         i_k += n_per_thread)
                     {
-                        ck::index_t i_dx = 0;
-                        ck::index_t i_dy = 0;
+                        intptr_t i_dx = 0;
+                        intptr_t i_dy = 0;
                         bool accmulate_c = false;
-                        ck::index_t current_k_size = ck::math::min(K - i_k, n_per_thread);
+                        intptr_t current_k_size =
+                            ck::math::min(K - i_k, (intptr_t)n_per_thread);
 
                         auto accumulate_dy_dx = [&]() {
                             i_dx += Dx;
@@ -589,25 +590,25 @@ struct GridwiseDirectConvNHWCAvx2
                             }
                         };
 
-                        for(ck::index_t i_yxc = 0; i_yxc < (Y * X * C);
+                        for(intptr_t i_yxc = 0; i_yxc < (Y * X * C);
                             i_yxc += C, accumulate_dy_dx())
                         {
-                            ck::index_t current_i_wo = i_wo;
-                            ck::index_t i_hi         = i_hi_no_y + i_dy;
+                            intptr_t current_i_wo = i_wo;
+                            intptr_t i_hi         = i_hi_no_y + i_dy;
                             if(i_hi < 0 || i_hi >= Hi)
                                 continue;
-                            ck::index_t i_wi            = i_wi_no_x + i_dx;
-                            ck::index_t current_wo_size = current_wo_size_no_dx;
-                            ck::index_t pad_wo_size = 0; // when left pad, we may never have
+                            intptr_t i_wi            = i_wi_no_x + i_dx;
+                            intptr_t current_wo_size = current_wo_size_no_dx;
+                            intptr_t pad_wo_size = 0; // when left pad, we may never have
                                                       // a chance to clear zero (like
                                                       // padding) we need to manually clear that
                             if(i_wi < 0)
                             {
-                                ck::index_t wi_to_zero_length =
-                                    -i_wi; // keep this a possitive number
-                                ck::index_t steps_wo_turn_possitive =
-                                    (wi_to_zero_length + Sx - 1) /
-                                    Sx; // how many steps need to move wo, to let wi to be
-                                        // possitive
+                                intptr_t wi_to_zero_length =
+                                    -i_wi; // keep this a possitive number
+                                intptr_t steps_wo_turn_possitive =
+                                    (wi_to_zero_length + Sx - 1) /
+                                    Sx; // how many steps need to move wo, to let wi to be
+                                        // possitive
@@ -647,7 +648,7 @@ struct GridwiseDirectConvNHWCAvx2
                             if(pad_wo_size != 0)
                             {
-                                for(ck::index_t i_wo_pad = 0; i_wo_pad < pad_wo_size;
+                                for(intptr_t i_wo_pad = 0; i_wo_pad < pad_wo_size;
                                     i_wo_pad++)
                                 {
                                     const intptr_t offset_c = GetCBlockStartOffset(
@@ -747,28 +748,28 @@ struct GridwiseDirectConvNHWCAvx2
         tid /= num_threads_wo;
         const ck::index_t tid_k = tid;
 
-        for(ck::index_t i_n = tid_n * num_works_n_per_thread;
+        for(intptr_t i_n = tid_n * num_works_n_per_thread;
             (i_n < (tid_n + 1) * num_works_n_per_thread) && i_n < num_works_n;
             i_n += 1)
         {
-            for(ck::index_t i_ho = tid_ho * num_works_ho_per_thread;
+            for(intptr_t i_ho = tid_ho * num_works_ho_per_thread;
                 (i_ho < (tid_ho + 1) * num_works_ho_per_thread) && i_ho < num_works_ho;
                 i_ho += 1)
             {
                 // for input
-                ck::index_t i_hi_no_y = i_ho * Sy - Py;
-                for(ck::index_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
+                intptr_t i_hi_no_y = i_ho * Sy - Py;
+                for(intptr_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
                     i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread &&
                         i_wo < Wo;
                     i_wo += m_per_thread)
                 {
-                    ck::index_t current_wo_size_no_dx =
-                        ck::math::min(Wo - i_wo, m_per_thread);
-                    ck::index_t i_wi_no_x = i_wo * Sx - Px;
-                    ck::index_t i_dx      = 0;
-                    ck::index_t i_dy      = 0;
+                    intptr_t current_wo_size_no_dx =
+                        ck::math::min(Wo - i_wo, (intptr_t)m_per_thread);
+                    intptr_t i_wi_no_x = i_wo * Sx - Px;
+                    intptr_t i_dx      = 0;
+                    intptr_t i_dy      = 0;
                     bool accmulate_c = false;
 
                     // printf("-- [%d] i_n:%d, i_ho:%d, i_wo:%d, num_works_n:%d,
                     // num_threads_n:%d(Hi:%d, Wi:%d), current_wo_size_no_dx:%d,
@@ -785,19 +786,19 @@ struct GridwiseDirectConvNHWCAvx2
                         }
                     };
 
-                    for(ck::index_t i_yxc = 0; i_yxc < (Y * X * C);
+                    for(intptr_t i_yxc = 0; i_yxc < (Y * X * C);
                         i_yxc += C, accumulate_dy_dx())
                     {
-                        ck::index_t current_i_wo = i_wo;
-                        ck::index_t i_hi         = i_hi_no_y + i_dy;
+                        intptr_t current_i_wo = i_wo;
+                        intptr_t i_hi         = i_hi_no_y + i_dy;
                         bool run_pad_only = false;
                         if(i_hi < 0 || i_hi >= Hi)
                             continue;
-                        ck::index_t i_wi            = i_wi_no_x + i_dx;
-                        ck::index_t current_wo_size = current_wo_size_no_dx;
-                        ck::index_t pad_wo_size = 0; // when left pad, we may never have a
+                        intptr_t i_wi            = i_wi_no_x + i_dx;
+                        intptr_t current_wo_size = current_wo_size_no_dx;
+                        intptr_t pad_wo_size = 0; // when left pad, we may never have a
                                                   // chance to clear zero (like
                                                   // padding) we need to manually clear that
                         /* left corner shift
@@ -812,9 +813,9 @@ struct GridwiseDirectConvNHWCAvx2
                         */
                         if(i_wi < 0)
                         {
-                            ck::index_t wi_to_zero_length =
-                                -i_wi; // keep this a possitive number
-                            ck::index_t steps_wo_turn_possitive =
-                                (wi_to_zero_length + Sx - 1) /
-                                Sx; // how many steps need to move wo, to let wi to be
-                                    // possitive
+                            intptr_t wi_to_zero_length =
+                                -i_wi; // keep this a possitive number
+                            intptr_t steps_wo_turn_possitive =
+                                (wi_to_zero_length + Sx - 1) /
+                                Sx; // how many steps need to move wo, to let wi to be
+                                    // possitive
@@ -859,9 +860,9 @@ struct GridwiseDirectConvNHWCAvx2
                         {
                             // manually clear zero. this may and only may need once along
                             // the gemm_k reduction
-                            ck::index_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
-                            ck::index_t current_k_block_size = ck::math::min(
-                                K - i_k, num_works_k_per_thread * n_per_thread);
+                            intptr_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
+                            intptr_t current_k_block_size = ck::math::min(
+                                K - i_k, (intptr_t)num_works_k_per_thread * n_per_thread);
 
                             const intptr_t offset_c = GetCBlockStartOffset(
                                 c_grid_desc, (i_n * Ho + i_ho) * Wo, i_k);
@@ -879,12 +880,12 @@ struct GridwiseDirectConvNHWCAvx2
                         if(run_pad_only)
                             continue;
 
-                        for(ck::index_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
+                        for(intptr_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
                             i_k < (tid_k + 1) * num_works_k_per_thread * n_per_thread;
                             i_k += n_per_thread)
                         {
-                            ck::index_t current_k_size =
-                                ck::math::min(K - i_k, n_per_thread);
+                            intptr_t current_k_size =
+                                ck::math::min(K - i_k, (intptr_t)n_per_thread);
 
                             const intptr_t offset_a = current_input_offset;
                             const intptr_t offset_b =
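Besides widening the declarations, several hunks above add an explicit cast, e.g. ck::math::min(Wo - i_wo, (intptr_t)m_per_thread). A likely reason, sketched below with a hypothetical stand-in helper rather than the real ck::math::min, is that a single-type min template can no longer deduce its type once one operand is intptr_t while the other is still a 32-bit value:

#include <cstdint>

// Hypothetical stand-in for a single-type min template such as ck::math::min.
template <typename T>
constexpr T my_min(T a, T b) { return a < b ? a : b; }

int main()
{
    std::intptr_t wo_remaining = 3; // widened by the fix
    int m_per_thread           = 8; // still a plain 32-bit value

    // my_min(wo_remaining, m_per_thread);  // error: conflicting deduction for T
    auto clipped = my_min(wo_remaining, static_cast<std::intptr_t>(m_per_thread));
    return static_cast<int>(clipped); // 3
}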
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp  (view file @ 05d38218)
@@ -1054,7 +1054,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC
         float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
 
         // n * k -> n0 * k * n1, n1 = 8, n0 = n/8
-        for(index_t i_n_itr = 0; i_n_itr < n_per_block; i_n_itr += 8)
+        for(intptr_t i_n_itr = 0; i_n_itr < n_per_block; i_n_itr += 8)
         {
             intptr_t current_n_8 = ck::math::min(GemmN - (i_n_itr + i_gemm_n), (intptr_t)8);
             intptr_t i_k_itr     = k_per_block;
@@ -1150,9 +1150,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC
             const float* p_src_k = p_src;
             float* p_dst_k       = p_dst;
 
-            for(index_t i_sub_n = 0; i_sub_n < 8; i_sub_n++)
+            for(intptr_t i_sub_n = 0; i_sub_n < 8; i_sub_n++)
             {
-                for(index_t i_sub_k = 0; i_sub_k < k_per_block; i_sub_k++)
+                for(intptr_t i_sub_k = 0; i_sub_k < k_per_block; i_sub_k++)
                 {
                     intptr_t i_current_n_itr = i_n_itr + i_sub_n + i_gemm_n;
@@ -1269,7 +1269,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
         float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
 
         // n0 * k * n1
-        index_t i_n0_itr = n0_per_block;
+        intptr_t i_n0_itr = n0_per_block;
         while(i_n0_itr >= 8)
         {
             avx2_util::memcpy32_avx2(p_dst + 0 * k_n1_per_block,
@@ -1440,7 +1440,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK
         float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
 
         // k * n
-        index_t i_k_itr = k_per_block;
+        intptr_t i_k_itr = k_per_block;
         while(i_k_itr >= 8)
         {
             avx2_util::memcpy32_avx2(