Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0bd6d2ce
Commit
0bd6d2ce
authored
Jul 07, 2023
by
rocking
Browse files
Refactor reference code for different dimension
parent
da0ddd48
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
141 additions
and
133 deletions
+141
-133
library/include/ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp
.../reference_tensor_operation/cpu/reference_avgpool_bwd.hpp
+141
-133
No files found.
library/include/ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp
View file @
0bd6d2ce
...
...
@@ -58,165 +58,162 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
{
using
Argument
=
ReferenceAvgPoolBwd
::
Argument
;
float
Run
(
const
Argument
&
arg
)
template
<
ck
::
index_t
NDimSpatial_
,
typename
std
::
enable_if
<
NDimSpatial_
==
1
,
bool
>
::
type
=
false
>
float
RunAvgPoolBwd
(
const
Argument
&
arg
)
{
if
(
!
(
arg
.
dinput_
.
GetNumOfDimension
()
==
NDimSpatial
+
2
&&
arg
.
doutput_
.
GetNumOfDimension
()
==
NDimSpatial
+
2
))
{
throw
std
::
runtime_error
(
"wrong! inconsistent dimension"
);
}
auto
f_ncw
=
[
&
](
auto
n
,
auto
c
,
auto
wi
)
{
std
::
size_t
X
=
arg
.
window_spatial_lengths_
[
0
];
std
::
size_t
Wo
=
arg
.
doutput_
.
GetLengths
()[
2
];
if
constexpr
(
NDimSpatial
==
1
)
{
auto
f_ncw
=
[
&
](
auto
n
,
auto
c
,
auto
wi
)
{
std
::
size_t
X
=
arg
.
window_spatial_lengths_
[
0
];
std
::
size_t
Wo
=
arg
.
doutput_
.
GetLengths
()[
2
];
float
v_acc
=
0
;
float
v_acc
=
0
;
for
(
std
::
size_t
x
=
0
;
x
<
X
;
++
x
)
{
auto
w_tmp
=
static_cast
<
ck
::
long_index_t
>
(
wi
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
window_dilations_
[
0
]);
for
(
std
::
size_t
x
=
0
;
x
<
X
;
++
x
)
if
(
w_tmp
%
arg
.
window_strides_
[
0
]
==
0
)
{
auto
w_tmp
=
static_cast
<
ck
::
long_index_t
>
(
wi
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
window_dilations_
[
0
]);
auto
wo
=
static_cast
<
ck
::
long_index_t
>
(
w_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
window_strides_
[
0
]);
if
(
w
_tmp
%
arg
.
window_strides_
[
0
]
==
0
)
if
(
w
o
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wo
)
<
Wo
)
{
auto
wo
=
static_cast
<
ck
::
long_index_t
>
(
w_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
window_strides_
[
0
]);
if
(
wo
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wo
)
<
Wo
)
{
v_acc
+=
ck
::
type_convert
<
float
>
(
arg
.
doutput_
(
n
,
c
,
wo
));
}
v_acc
+=
ck
::
type_convert
<
float
>
(
arg
.
doutput_
(
n
,
c
,
wo
));
}
}
}
v_acc
/=
ck
::
type_convert
<
float
>
(
X
);
arg
.
dinput_
(
n
,
c
,
wi
)
=
ck
::
type_convert
<
DInDataType
>
(
v_acc
);
};
v_acc
/=
ck
::
type_convert
<
float
>
(
X
);
arg
.
dinput_
(
n
,
c
,
wi
)
=
ck
::
type_convert
<
DInDataType
>
(
v_acc
);
};
make_ParallelTensorFunctor
(
f_ncw
,
arg
.
dinput_
.
GetLengths
()[
0
],
arg
.
dinput_
.
GetLengths
()[
1
],
arg
.
dinput_
.
GetLengths
()[
2
])(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
f_ncw
,
arg
.
dinput_
.
GetLengths
()[
0
],
arg
.
dinput_
.
GetLengths
()[
1
],
arg
.
dinput_
.
GetLengths
()[
2
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
}
else
if
constexpr
(
NDimSpatial
==
2
)
{
auto
f_nchw
=
[
&
](
auto
n
,
auto
c
,
auto
hi
,
auto
wi
)
{
std
::
size_t
Y
=
arg
.
window_spatial_lengths_
[
0
];
std
::
size_t
X
=
arg
.
window_spatial_lengths_
[
1
];
return
0
;
}
std
::
size_t
Ho
=
arg
.
doutput_
.
GetLengths
()[
2
];
std
::
size_t
Wo
=
arg
.
doutput_
.
GetLengths
()[
3
];
template
<
ck
::
index_t
NDimSpatial_
,
typename
std
::
enable_if
<
NDimSpatial_
==
2
,
bool
>
::
type
=
false
>
float
RunAvgPoolBwd
(
const
Argument
&
arg
)
{
auto
f_nchw
=
[
&
](
auto
n
,
auto
c
,
auto
hi
,
auto
wi
)
{
std
::
size_t
Y
=
arg
.
window_spatial_lengths_
[
0
];
std
::
size_t
X
=
arg
.
window_spatial_lengths_
[
1
];
std
::
size_t
Ho
=
arg
.
doutput_
.
GetLengths
()[
2
];
std
::
size_t
Wo
=
arg
.
doutput_
.
GetLengths
()[
3
];
float
v_acc
=
0
;
float
v_acc
=
0
;
for
(
std
::
size_t
y
=
0
;
y
<
Y
;
++
y
)
for
(
std
::
size_t
y
=
0
;
y
<
Y
;
++
y
)
{
auto
h_tmp
=
static_cast
<
ck
::
long_index_t
>
(
hi
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
window_dilations_
[
0
]);
if
(
h_tmp
%
arg
.
window_strides_
[
0
]
==
0
)
{
auto
h_tmp
=
static_cast
<
ck
::
long_index_t
>
(
hi
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
window_dilations_
[
0
]);
if
(
h_tmp
%
arg
.
window_strides_
[
0
]
==
0
)
auto
ho
=
static_cast
<
ck
::
long_index_t
>
(
h_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
window_strides_
[
0
]);
if
(
ho
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
ho
)
<
Ho
)
{
auto
ho
=
static_cast
<
ck
::
long_index_t
>
(
h_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
window_strides_
[
0
]);
if
(
ho
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
ho
)
<
Ho
)
for
(
std
::
size_t
x
=
0
;
x
<
X
;
++
x
)
{
for
(
std
::
size_t
x
=
0
;
x
<
X
;
++
x
)
auto
w_tmp
=
static_cast
<
ck
::
long_index_t
>
(
wi
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
window_dilations_
[
1
]);
if
(
w_tmp
%
arg
.
window_strides_
[
1
]
==
0
)
{
auto
w_tmp
=
static_cast
<
ck
::
long_index_t
>
(
wi
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
window_dilations_
[
1
]);
if
(
w_tmp
%
arg
.
window_strides_
[
1
]
==
0
)
auto
wo
=
static_cast
<
ck
::
long_index_t
>
(
w_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
window_strides_
[
1
]);
if
(
wo
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wo
)
<
Wo
)
{
auto
wo
=
static_cast
<
ck
::
long_index_t
>
(
w_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
window_strides_
[
1
]);
if
(
wo
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wo
)
<
Wo
)
{
v_acc
+=
ck
::
type_convert
<
float
>
(
arg
.
doutput_
(
n
,
c
,
ho
,
wo
));
}
v_acc
+=
ck
::
type_convert
<
float
>
(
arg
.
doutput_
(
n
,
c
,
ho
,
wo
));
}
}
}
}
}
}
v_acc
/=
ck
::
type_convert
<
float
>
(
Y
*
X
);
arg
.
dinput_
(
n
,
c
,
hi
,
wi
)
=
ck
::
type_convert
<
DInDataType
>
(
v_acc
);
};
v_acc
/=
ck
::
type_convert
<
float
>
(
Y
*
X
);
arg
.
dinput_
(
n
,
c
,
hi
,
wi
)
=
ck
::
type_convert
<
DInDataType
>
(
v_acc
);
};
make_ParallelTensorFunctor
(
f_nchw
,
arg
.
dinput_
.
GetLengths
()[
0
],
arg
.
dinput_
.
GetLengths
()[
1
],
arg
.
dinput_
.
GetLengths
()[
2
],
arg
.
dinput_
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
f_nchw
,
arg
.
dinput_
.
GetLengths
()[
0
],
arg
.
dinput_
.
GetLengths
()[
1
],
arg
.
dinput_
.
GetLengths
()[
2
],
arg
.
dinput_
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
auto
f_ncdhw
=
[
&
](
auto
n
,
auto
c
,
auto
di
,
auto
hi
,
auto
wi
)
{
std
::
size_t
Z
=
arg
.
window_spatial_lengths_
[
0
];
std
::
size_t
Y
=
arg
.
window_spatial_lengths_
[
1
];
std
::
size_t
X
=
arg
.
window_spatial_lengths_
[
2
];
std
::
size_t
Do
=
arg
.
doutput_
.
GetLengths
()[
2
];
std
::
size_t
Ho
=
arg
.
doutput_
.
GetLengths
()[
3
];
std
::
size_t
Wo
=
arg
.
doutput_
.
GetLengths
()[
4
];
float
v_acc
=
0
;
return
0
;
}
for
(
std
::
size_t
z
=
0
;
z
<
Z
;
++
z
)
template
<
ck
::
index_t
NDimSpatial_
,
typename
std
::
enable_if
<
NDimSpatial_
==
3
,
bool
>
::
type
=
false
>
float
RunAvgPoolBwd
(
const
Argument
&
arg
)
{
auto
f_ncdhw
=
[
&
](
auto
n
,
auto
c
,
auto
di
,
auto
hi
,
auto
wi
)
{
std
::
size_t
Z
=
arg
.
window_spatial_lengths_
[
0
];
std
::
size_t
Y
=
arg
.
window_spatial_lengths_
[
1
];
std
::
size_t
X
=
arg
.
window_spatial_lengths_
[
2
];
std
::
size_t
Do
=
arg
.
doutput_
.
GetLengths
()[
2
];
std
::
size_t
Ho
=
arg
.
doutput_
.
GetLengths
()[
3
];
std
::
size_t
Wo
=
arg
.
doutput_
.
GetLengths
()[
4
];
float
v_acc
=
0
;
for
(
std
::
size_t
z
=
0
;
z
<
Z
;
++
z
)
{
auto
d_tmp
=
static_cast
<
ck
::
long_index_t
>
(
di
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
z
*
arg
.
window_dilations_
[
0
]);
if
(
d_tmp
%
arg
.
window_strides_
[
0
]
==
0
)
{
auto
d_tmp
=
static_cast
<
ck
::
long_index_t
>
(
di
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
z
*
arg
.
window_dilations_
[
0
]);
if
(
d_tmp
%
arg
.
window_strides_
[
0
]
==
0
)
auto
do_
=
static_cast
<
ck
::
long_index_t
>
(
d_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
window_strides_
[
0
]);
if
(
do_
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
do_
)
<
Do
)
{
auto
do_
=
static_cast
<
ck
::
long_index_t
>
(
d_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
window_strides_
[
0
]);
if
(
do_
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
do_
)
<
Do
)
for
(
std
::
size_t
y
=
0
;
y
<
Y
;
++
y
)
{
for
(
std
::
size_t
y
=
0
;
y
<
Y
;
++
y
)
auto
h_tmp
=
static_cast
<
ck
::
long_index_t
>
(
hi
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
window_dilations_
[
1
]);
if
(
h_tmp
%
arg
.
window_strides_
[
1
]
==
0
)
{
auto
h_tmp
=
static_cast
<
ck
::
long_index_t
>
(
hi
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
window_dilations_
[
1
]);
if
(
h_tmp
%
arg
.
window_strides_
[
1
]
==
0
)
auto
ho
=
static_cast
<
ck
::
long_index_t
>
(
h_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
window_strides_
[
1
]);
if
(
ho
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
ho
)
<
Ho
)
{
auto
ho
=
static_cast
<
ck
::
long_index_t
>
(
h_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
window_strides_
[
1
]);
if
(
ho
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
ho
)
<
Ho
)
for
(
std
::
size_t
x
=
0
;
x
<
X
;
++
x
)
{
for
(
std
::
size_t
x
=
0
;
x
<
X
;
++
x
)
{
auto
w_tmp
=
static_cast
<
ck
::
long_index_t
>
(
wi
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
2
])
-
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
window_dilations_
[
2
]);
auto
w_tmp
=
static_cast
<
ck
::
long_index_t
>
(
wi
)
+
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
2
])
-
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
window_dilations_
[
2
]);
if
(
w_tmp
%
arg
.
window_strides_
[
2
]
==
0
)
if
(
w_tmp
%
arg
.
window_strides_
[
2
]
==
0
)
{
auto
wo
=
static_cast
<
ck
::
long_index_t
>
(
w_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
window_strides_
[
2
]);
if
(
wo
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wo
)
<
Wo
)
{
auto
wo
=
static_cast
<
ck
::
long_index_t
>
(
w_tmp
)
/
static_cast
<
ck
::
long_index_t
>
(
arg
.
window_strides_
[
2
]);
if
(
wo
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wo
)
<
Wo
)
{
v_acc
+=
ck
::
type_convert
<
float
>
(
arg
.
doutput_
(
n
,
c
,
do_
,
ho
,
wo
));
}
v_acc
+=
ck
::
type_convert
<
float
>
(
arg
.
doutput_
(
n
,
c
,
do_
,
ho
,
wo
));
}
}
}
...
...
@@ -225,21 +222,32 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
}
}
}
}
v_acc
/=
ck
::
type_convert
<
float
>
(
Z
*
Y
*
X
);
arg
.
dinput_
(
n
,
c
,
di
,
hi
,
wi
)
=
ck
::
type_convert
<
DInDataType
>
(
v_acc
);
};
v_acc
/=
ck
::
type_convert
<
float
>
(
Z
*
Y
*
X
);
arg
.
dinput_
(
n
,
c
,
di
,
hi
,
wi
)
=
ck
::
type_convert
<
DInDataType
>
(
v_acc
);
};
make_ParallelTensorFunctor
(
f_ncdhw
,
arg
.
dinput_
.
GetLengths
()[
0
],
arg
.
dinput_
.
GetLengths
()[
1
],
arg
.
dinput_
.
GetLengths
()[
2
],
arg
.
dinput_
.
GetLengths
()[
3
],
arg
.
dinput_
.
GetLengths
()[
4
])(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
f_ncdhw
,
arg
.
dinput_
.
GetLengths
()[
0
],
arg
.
dinput_
.
GetLengths
()[
1
],
arg
.
dinput_
.
GetLengths
()[
2
],
arg
.
dinput_
.
GetLengths
()[
3
],
arg
.
dinput_
.
GetLengths
()[
4
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
return
0
;
}
float
Run
(
const
Argument
&
arg
)
{
if
(
!
(
arg
.
dinput_
.
GetNumOfDimension
()
==
NDimSpatial
+
2
&&
arg
.
doutput_
.
GetNumOfDimension
()
==
NDimSpatial
+
2
))
{
throw
std
::
runtime_error
(
"wrong! inconsistent dimension"
);
}
return
RunAvgPoolBwd
<
NDimSpatial
>
(
arg
);
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment