Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ce87bcc7
Commit
ce87bcc7
authored
Sep 08, 2023
by
Bartlomiej Kocot
Browse files
tmp
parent
c8a8385f
Changes
123
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
94 additions
and
124 deletions
+94
-124
profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp
...ude/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp
+4
-9
profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp
...er/include/profiler/profile_gemm_bias_add_reduce_impl.hpp
+6
-3
profiler/include/profiler/profile_gemm_bilinear_impl.hpp
profiler/include/profiler/profile_gemm_bilinear_impl.hpp
+3
-3
profiler/include/profiler/profile_gemm_fastgelu_impl.hpp
profiler/include/profiler/profile_gemm_fastgelu_impl.hpp
+3
-3
profiler/include/profiler/profile_gemm_impl.hpp
profiler/include/profiler/profile_gemm_impl.hpp
+3
-3
profiler/include/profiler/profile_gemm_reduce_impl.hpp
profiler/include/profiler/profile_gemm_reduce_impl.hpp
+7
-4
profiler/include/profiler/profile_gemm_splitk_impl.hpp
profiler/include/profiler/profile_gemm_splitk_impl.hpp
+5
-5
profiler/include/profiler/profile_gemm_streamk_impl.hpp
profiler/include/profiler/profile_gemm_streamk_impl.hpp
+3
-3
profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp
...r/include/profiler/profile_grouped_conv_bwd_data_impl.hpp
+3
-3
profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
...include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
+4
-5
profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp
profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp
+3
-3
profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp
...r/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp
+4
-5
profiler/include/profiler/profile_grouped_gemm_impl.hpp
profiler/include/profiler/profile_grouped_gemm_impl.hpp
+4
-6
profiler/include/profiler/profile_groupnorm_impl.hpp
profiler/include/profiler/profile_groupnorm_impl.hpp
+3
-11
profiler/include/profiler/profile_layernorm_impl.hpp
profiler/include/profiler/profile_layernorm_impl.hpp
+3
-10
profiler/include/profiler/profile_pool3d_fwd_impl.hpp
profiler/include/profiler/profile_pool3d_fwd_impl.hpp
+4
-11
profiler/include/profiler/profile_reduce_impl.hpp
profiler/include/profiler/profile_reduce_impl.hpp
+4
-11
profiler/include/profiler/profile_softmax_impl.hpp
profiler/include/profiler/profile_softmax_impl.hpp
+4
-5
test/conv_util/conv_util.cpp
test/conv_util/conv_util.cpp
+18
-15
test/gemm/gemm_util.hpp
test/gemm/gemm_util.hpp
+6
-6
No files found.
profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp
View file @
ce87bcc7
...
...
@@ -250,7 +250,7 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
float
best_ave_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
int
num_kernel
=
0
;
// profile device operation instances
...
...
@@ -316,7 +316,7 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
{
h_device_buf
.
FromDevice
(
h_m_n
.
mData
.
data
());
pass
=
pass
&&
ck
::
utils
::
check_err
(
validator
.
check_err
(
h_m_n
,
h_m_n_host
,
"Error: Incorrect results h_m_n"
,
1e-2
,
1e-2
);
}
}
...
...
@@ -327,19 +327,14 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
}
}
if
(
num_kernel
==
0
)
{
std
::
cout
<<
"Error: No kernel is applicable"
<<
std
::
endl
;
pass
=
false
;
}
else
if
(
num_kernel
!=
0
)
{
if
(
time_kernel
)
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
}
return
pass
;
return
validator
.
is_success
()
;
}
}
// namespace profiler
...
...
profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp
View file @
ce87bcc7
...
...
@@ -281,6 +281,8 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
ck
::
utils
::
CorrectnessValidator
validator
;
// profile device GEMM instances
for
(
auto
&
gemm_ptr
:
gemm_ptrs
)
{
...
...
@@ -343,9 +345,9 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
reduce0_device_buf
.
FromDevice
(
reduce0_m_device_result
.
mData
.
data
());
reduce1_device_buf
.
FromDevice
(
reduce1_m_device_result
.
mData
.
data
());
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
ck
::
utils
::
check_err
(
reduce0_m_device_result
,
reduce0_m_host_result
);
ck
::
utils
::
check_err
(
reduce1_m_device_result
,
reduce1_m_host_result
);
validator
.
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
validator
.
check_err
(
reduce0_m_device_result
,
reduce0_m_host_result
);
validator
.
check_err
(
reduce1_m_device_result
,
reduce1_m_host_result
);
if
(
do_log
)
{
...
...
@@ -376,6 +378,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
}
}
validator
.
is_success
();
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_gemm_name
<<
std
::
endl
;
}
...
...
profiler/include/profiler/profile_gemm_bilinear_impl.hpp
View file @
ce87bcc7
...
...
@@ -158,7 +158,7 @@ bool profile_gemm_bilinear_impl(int do_verification,
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
// profile device operation instances
for
(
auto
&
op_ptr
:
op_ptrs
)
...
...
@@ -215,7 +215,7 @@ bool profile_gemm_bilinear_impl(int do_verification,
{
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
pass
=
pass
&&
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
);
validator
.
check_err
(
e_m_n_device_result
,
e_m_n_host_result
);
}
}
else
...
...
@@ -227,7 +227,7 @@ bool profile_gemm_bilinear_impl(int do_verification,
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
// namespace profiler
...
...
profiler/include/profiler/profile_gemm_fastgelu_impl.hpp
View file @
ce87bcc7
...
...
@@ -147,7 +147,7 @@ bool profile_gemm_fastgelu_impl(int do_verification,
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
// profile device operation instances
for
(
auto
&
op_ptr
:
op_ptrs
)
...
...
@@ -203,7 +203,7 @@ bool profile_gemm_fastgelu_impl(int do_verification,
{
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
pass
=
pass
&&
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
);
validator
.
check_err
(
e_m_n_device_result
,
e_m_n_host_result
);
}
}
else
...
...
@@ -215,7 +215,7 @@ bool profile_gemm_fastgelu_impl(int do_verification,
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
// namespace profiler
...
...
profiler/include/profiler/profile_gemm_impl.hpp
View file @
ce87bcc7
...
...
@@ -42,7 +42,7 @@ int profile_gemm_impl(int do_verification,
int
StrideB
,
int
StrideC
)
{
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
...
...
@@ -188,7 +188,7 @@ int profile_gemm_impl(int do_verification,
{
c_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
validator
.
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
if
(
do_log
)
{
...
...
@@ -247,7 +247,7 @@ int profile_gemm_impl(int do_verification,
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
?
0
:
1
;
return
!
validator
.
is_success
()
;
}
}
// namespace profiler
...
...
profiler/include/profiler/profile_gemm_reduce_impl.hpp
View file @
ce87bcc7
...
...
@@ -250,6 +250,8 @@ bool profile_gemm_reduce_impl(int do_verification,
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
ck
::
utils
::
CorrectnessValidator
validator
;
// profile device GEMM instances
for
(
auto
&
gemm_ptr
:
gemm_ptrs
)
{
...
...
@@ -310,9 +312,10 @@ bool profile_gemm_reduce_impl(int do_verification,
reduce0_device_buf
.
FromDevice
(
reduce0_m_device_result
.
mData
.
data
());
reduce1_device_buf
.
FromDevice
(
reduce1_m_device_result
.
mData
.
data
());
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
ck
::
utils
::
check_err
(
reduce0_m_device_result
,
reduce0_m_host_result
);
ck
::
utils
::
check_err
(
reduce1_m_device_result
,
reduce1_m_host_result
);
validator
.
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
validator
.
check_err
(
reduce0_m_device_result
,
reduce0_m_host_result
);
validator
.
check_err
(
reduce1_m_device_result
,
reduce1_m_host_result
);
validator
.
is_success
();
if
(
do_log
)
{
...
...
@@ -346,7 +349,7 @@ bool profile_gemm_reduce_impl(int do_verification,
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_gemm_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
// namespace profiler
...
...
profiler/include/profiler/profile_gemm_splitk_impl.hpp
View file @
ce87bcc7
...
...
@@ -43,7 +43,7 @@ bool profile_gemm_splitk_impl(int do_verification,
int
StrideC
,
int
KBatch
)
{
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
...
...
@@ -181,7 +181,7 @@ bool profile_gemm_splitk_impl(int do_verification,
{
c_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
validator
.
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
if
(
do_log
)
{
...
...
@@ -221,12 +221,12 @@ bool profile_gemm_splitk_impl(int do_verification,
std
::
string
msg
=
"Error: Incorrect results!"
;
double
rtol
=
1e-1
;
double
atol
=
1e-1
;
pass
=
pass
&
ck
::
utils
::
check_err
(
validator
.
check_err
(
c_m_n_device_result
,
c_m_n_host_result
,
msg
,
rtol
,
atol
);
}
else
{
pass
=
pass
&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
validator
.
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
}
if
(
tflops
>
best_tflops
)
...
...
@@ -286,7 +286,7 @@ bool profile_gemm_splitk_impl(int do_verification,
<<
" : "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
// namespace profiler
...
...
profiler/include/profiler/profile_gemm_streamk_impl.hpp
View file @
ce87bcc7
...
...
@@ -43,7 +43,7 @@ bool profile_gemm_streamk_impl(int do_verification,
int
StrideC
,
uint32_t
NumSKBlocks
=
0xffffffff
)
{
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
...
...
@@ -176,7 +176,7 @@ bool profile_gemm_streamk_impl(int do_verification,
{
c_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
validator
.
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
if
(
do_log
)
{
...
...
@@ -260,7 +260,7 @@ bool profile_gemm_streamk_impl(int do_verification,
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
// namespace profiler
...
...
profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp
View file @
ce87bcc7
...
...
@@ -122,7 +122,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
float
best_gb_per_sec
=
0
;
// profile device op instances
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
auto
run_impl
=
[
&
](
auto
&
op_ptr
,
auto
&
argument_ptr
)
{
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
...
...
@@ -159,7 +159,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
{
in_device_buf
.
FromDevice
(
in_device
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
in_device
,
in_host
);
validator
.
check_err
(
in_device
,
in_host
);
if
(
do_log
)
{
...
...
@@ -250,7 +250,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
<<
"
\n
name: "
<<
best_op_name
<<
"
\n
avg_time: "
<<
best_avg_time
<<
"
\n
tflops: "
<<
best_tflops
<<
"
\n
GB/s: "
<<
best_gb_per_sec
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
// namespace profiler
...
...
profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
View file @
ce87bcc7
...
...
@@ -160,6 +160,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
range_copy
(
conv_param
.
input_left_pads_
,
begin
(
input_left_pads
));
range_copy
(
conv_param
.
input_right_pads_
,
begin
(
input_right_pads
));
ck
::
utils
::
CorrectnessValidator
validator
;
for
(
auto
&
op_ptr
:
op_ptrs
)
{
auto
argument_ptr
=
...
...
@@ -214,15 +215,13 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
{
wei_device_buf
.
FromDevice
(
weight_device_result
.
mData
.
data
());
bool
pass
=
ck
::
utils
::
check_err
(
weight_device_result
,
weight_host_result
);
validator
.
check_err
(
weight_device_result
,
weight_host_result
);
if
(
!
pass
)
if
(
!
validator
.
is_success
()
)
{
std
::
cout
<<
"Fail info: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
}
all_pass
&=
pass
;
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"output : "
,
output
.
mData
,
","
)
<<
std
::
endl
;
...
...
@@ -250,7 +249,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
<<
"
\n
name: "
<<
best_op_name
<<
"
\n
avg_time: "
<<
best_avg_time
<<
"
\n
tflops: "
<<
best_tflops
<<
"
\n
GB/s: "
<<
best_gb_per_sec
<<
std
::
endl
;
return
al
l_pass
;
return
v
al
idator
.
is_success
()
;
}
}
// namespace profiler
...
...
profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp
View file @
ce87bcc7
...
...
@@ -142,7 +142,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
float
best_gb_per_sec
=
0
;
// profile device op instances
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
auto
run_impl
=
[
&
](
auto
&
op_ptr
,
auto
&
argument_ptr
)
{
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
...
...
@@ -179,7 +179,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
{
out_device_buf
.
FromDevice
(
device_output
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
device_output
,
host_output
);
validator
.
check_err
(
device_output
,
host_output
);
if
(
do_log
)
{
...
...
@@ -246,7 +246,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
<<
"
\n
name: "
<<
best_op_name
<<
"
\n
avg_time: "
<<
best_avg_time
<<
"
\n
tflops: "
<<
best_tflops
<<
"
\n
GB/s: "
<<
best_gb_per_sec
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
// namespace profiler
...
...
profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp
View file @
ce87bcc7
...
...
@@ -39,7 +39,7 @@ bool profile_grouped_gemm_fastgelu_impl(int do_verification,
const
std
::
vector
<
int
>&
StrideCs
)
{
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
...
...
@@ -238,8 +238,7 @@ bool profile_grouped_gemm_fastgelu_impl(int do_verification,
ref_invoker
.
Run
(
ref_argument
);
bool
group_pass
=
ck
::
utils
::
check_err
(
c_m_n_device_results
[
i
],
c_m_n_host_result
);
pass
=
pass
&&
group_pass
;
validator
.
check_err
(
c_m_n_device_results
[
i
],
c_m_n_host_result
);
std
::
cout
<<
"group: "
<<
i
<<
" verification result: "
<<
std
::
boolalpha
<<
group_pass
<<
std
::
endl
;
...
...
@@ -267,13 +266,13 @@ bool profile_grouped_gemm_fastgelu_impl(int do_verification,
if
(
do_verification
)
{
std
::
cout
<<
"Verification: "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
std
::
cout
<<
"Verification: "
<<
(
validator
.
is_success
()
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_gemm_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
// namespace profiler
...
...
profiler/include/profiler/profile_grouped_gemm_impl.hpp
View file @
ce87bcc7
...
...
@@ -44,7 +44,7 @@ bool profile_grouped_gemm_impl(int do_verification,
const
std
::
vector
<
int
>&
StrideCs
,
int
kbatch
=
1
)
{
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
...
...
@@ -274,7 +274,7 @@ bool profile_grouped_gemm_impl(int do_verification,
if
(
std
::
is_same_v
<
CDataType
,
ck
::
half_t
>
&&
kbatch_curr
>
1
)
{
instance_pass
=
instance_pass
&&
ck
::
utils
::
check_err
(
c_m_n_device_results
[
i
],
instance_pass
&&
validator
.
check_err
(
c_m_n_device_results
[
i
],
c_m_n_host_results
[
i
],
"Error: Incorrect results!"
,
0.06
);
...
...
@@ -282,7 +282,7 @@ bool profile_grouped_gemm_impl(int do_verification,
else
{
instance_pass
=
instance_pass
&&
ck
::
utils
::
check_err
(
c_m_n_device_results
[
i
],
instance_pass
&&
validator
.
check_err
(
c_m_n_device_results
[
i
],
c_m_n_host_results
[
i
]);
}
...
...
@@ -303,8 +303,6 @@ bool profile_grouped_gemm_impl(int do_verification,
std
::
cout
<<
"Instance: "
<<
gemm_name
<<
" verification "
<<
(
instance_pass
?
"SUCCEED"
:
"FAILED"
)
<<
std
::
endl
;
pass
=
pass
&&
instance_pass
;
}
float
ave_time
=
...
...
@@ -354,7 +352,7 @@ bool profile_grouped_gemm_impl(int do_verification,
<<
std
::
endl
;
}
return
pass
;
return
validator
.
is_success
()
;
}
}
// namespace profiler
...
...
profiler/include/profiler/profile_groupnorm_impl.hpp
View file @
ce87bcc7
...
...
@@ -110,7 +110,7 @@ bool profile_groupnorm_impl(int do_verification,
ref_invoker
.
Run
(
ref_argument
);
}
int
num_kernel
=
0
;
ck
::
utils
::
CorrectnessValidator
validator
;
for
(
auto
&
inst_ptr
:
instance_ptrs
)
{
...
...
@@ -169,7 +169,7 @@ bool profile_groupnorm_impl(int do_verification,
{
y_dev
.
FromDevice
(
y
.
mData
.
data
());
bool
pass
=
ck
::
utils
::
check_err
(
y
,
host_y
,
"Error: Incorrect results"
,
1e-3
,
1e-3
);
bool
pass
=
validator
.
check_err
(
y
,
host_y
,
"Error: Incorrect results"
,
1e-3
,
1e-3
);
if
(
do_log
)
{
...
...
@@ -182,7 +182,6 @@ bool profile_groupnorm_impl(int do_verification,
{
std
::
cout
<<
inst_ptr
->
GetTypeString
()
<<
" failed verification: "
;
LogRange
(
std
::
cout
<<
"lengths = ["
,
length
,
", "
)
<<
"]."
<<
std
::
endl
;
return
false
;
}
else
{
...
...
@@ -198,14 +197,7 @@ bool profile_groupnorm_impl(int do_verification,
std
::
cout
<<
"best perf = "
<<
best_avg_time
<<
" ms, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_instance_name
<<
std
::
endl
;
}
if
(
num_kernel
==
0
)
{
std
::
cout
<<
"Error: No kernel is applicable"
<<
std
::
endl
;
return
false
;
}
return
true
;
return
validator
.
is_success
();
}
}
// namespace profiler
...
...
profiler/include/profiler/profile_layernorm_impl.hpp
View file @
ce87bcc7
...
...
@@ -121,7 +121,7 @@ bool profile_layernorm_impl(int do_verification,
ref_invoker
.
Run
(
ref_argument
);
}
int
num_kernel
=
0
;
ck
::
utils
::
CorrectnessValidator
validator
;
for
(
auto
&
inst_ptr
:
instance_ptrs
)
{
...
...
@@ -186,7 +186,7 @@ bool profile_layernorm_impl(int do_verification,
y_dev
.
FromDevice
(
y
.
mData
.
data
());
bool
pass
=
ck
::
utils
::
check_err
(
y
.
mData
,
host_y
.
mData
,
"Error: Incorrect results"
,
1e-3
,
1e-3
);
validator
.
check_err
(
y
.
mData
,
host_y
.
mData
,
"Error: Incorrect results"
,
1e-3
,
1e-3
);
if
(
do_log
)
{
...
...
@@ -199,7 +199,6 @@ bool profile_layernorm_impl(int do_verification,
{
std
::
cout
<<
inst_ptr
->
GetTypeString
()
<<
" failed verification: "
;
LogRange
(
std
::
cout
<<
"lengths = ["
,
length
,
", "
)
<<
"]."
<<
std
::
endl
;
return
false
;
}
else
{
...
...
@@ -218,13 +217,7 @@ bool profile_layernorm_impl(int do_verification,
<<
best_instance_name
<<
std
::
endl
;
}
if
(
num_kernel
==
0
)
{
std
::
cout
<<
"Error: No kernel is applicable"
<<
std
::
endl
;
return
false
;
}
return
true
;
return
validator
.
is_success
();
}
}
// namespace profiler
...
...
profiler/include/profiler/profile_pool3d_fwd_impl.hpp
View file @
ce87bcc7
...
...
@@ -150,7 +150,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
ref_invoker
.
Run
(
ref_argument
);
}
int
num_kernel
=
0
;
ck
::
utils
::
CorrectnessValidator
validator
;
for
(
auto
&
inst_ptr
:
instance_ptrs
)
{
...
...
@@ -213,7 +213,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
{
out_device_buf
.
FromDevice
(
out_n_c_do_ho_wo_device
.
mData
.
data
());
bool
pass
=
ck
::
utils
::
check_err
(
out_n_c_do_ho_wo_device
.
mData
,
bool
pass
=
validator
.
check_err
(
out_n_c_do_ho_wo_device
.
mData
,
out_n_c_do_ho_wo_host
.
mData
,
"Error: Incorrect results"
,
1e-3
,
...
...
@@ -223,7 +223,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
{
out_indices_device_buf
.
FromDevice
(
out_indices_n_c_do_ho_wo_device
.
mData
.
data
());
pass
=
pass
&&
ck
::
utils
::
check_err
(
out_indices_n_c_do_ho_wo_device
,
pass
=
pass
&&
validator
.
check_err
(
out_indices_n_c_do_ho_wo_device
,
out_indices_n_c_do_ho_wo_host
);
}
...
...
@@ -250,7 +250,6 @@ bool profile_pool3d_fwd_impl(int do_verification,
{
std
::
cout
<<
inst_ptr
->
GetTypeString
()
<<
" failed verification: "
;
LogRange
(
std
::
cout
<<
"lengths = ["
,
in_length
,
", "
)
<<
"]."
<<
std
::
endl
;
return
false
;
}
else
{
...
...
@@ -267,13 +266,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
<<
best_instance_name
<<
std
::
endl
;
}
if
(
num_kernel
==
0
)
{
std
::
cout
<<
"Error: No kernel is applicable"
<<
std
::
endl
;
return
false
;
}
return
true
;
return
validator
.
is_success
();
}
}
// namespace profiler
...
...
profiler/include/profiler/profile_reduce_impl.hpp
View file @
ce87bcc7
...
...
@@ -196,7 +196,7 @@ bool profile_reduce_impl_impl(bool do_verification,
invalid_reduce_4
||
invalid_reduce_5
||
invalid_reduce_6
);
int
num_kernel
=
0
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
if
constexpr
(
!
invalid_reduce
)
{
...
...
@@ -403,12 +403,12 @@ bool profile_reduce_impl_impl(bool do_verification,
bool
single_pass
;
out_dev
.
FromDevice
(
out
.
mData
.
data
());
single_pass
=
ck
::
utils
::
check_err
(
out
,
out_ref
);
single_pass
=
validator
.
check_err
(
out
,
out_ref
);
if
(
OutputIndex
)
{
out_indices_dev
.
FromDevice
(
out_indices
.
mData
.
data
());
single_pass
=
single_pass
&&
ck
::
utils
::
check_err
(
out_indices
,
out_indices_ref
);
single_pass
=
single_pass
&&
validator
.
check_err
(
out_indices
,
out_indices_ref
);
};
if
(
!
single_pass
)
...
...
@@ -416,7 +416,6 @@ bool profile_reduce_impl_impl(bool do_verification,
std
::
cout
<<
"Fail Info: "
<<
reduce_ptr
->
GetTypeString
()
<<
std
::
endl
;
}
pass
=
pass
&&
single_pass
;
};
if
(
do_dumpout
)
...
...
@@ -447,13 +446,7 @@ bool profile_reduce_impl_impl(bool do_verification,
"The requested reduction operation is not supported, please check!"
);
};
if
(
num_kernel
==
0
)
{
std
::
cout
<<
"Error: No kernel is applicable"
<<
std
::
endl
;
return
false
;
};
return
pass
;
return
validator
.
is_success
();
};
template
<
typename
InDataType
,
typename
AccDataType
,
typename
OutDataType
>
...
...
profiler/include/profiler/profile_softmax_impl.hpp
View file @
ce87bcc7
...
...
@@ -123,7 +123,7 @@ bool profile_softmax_impl(int do_verification,
std
::
string
best_instance_name
;
float
best_avg_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
std
::
vector
<
bool
>
instance_pass
;
ck
::
utils
::
CorrectnessValidator
validator
;
for
(
auto
&
inst_ptr
:
instances
)
{
...
...
@@ -176,7 +176,7 @@ bool profile_softmax_impl(int do_verification,
bool
pass
=
true
;
if
(
std
::
is_same
<
InDataType
,
int8_t
>::
value
)
{
pass
=
pass
&&
ck
::
utils
::
check_err
(
pass
=
pass
&&
validator
.
check_err
(
out
.
mData
,
out_ref
.
mData
,
"Error: Incorrect results!"
,
0
,
1
);
if
(
do_log
)
{
...
...
@@ -188,7 +188,7 @@ bool profile_softmax_impl(int do_verification,
}
else
{
pass
=
pass
&&
ck
::
utils
::
check_err
(
out
.
mData
,
out_ref
.
mData
);
pass
=
pass
&&
validator
.
check_err
(
out
.
mData
,
out_ref
.
mData
);
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"in : "
,
in
.
mData
,
","
)
<<
std
::
endl
;
...
...
@@ -219,8 +219,7 @@ bool profile_softmax_impl(int do_verification,
<<
"beta = "
<<
beta
<<
", "
<<
best_avg_time
<<
" ms, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_instance_name
<<
std
::
endl
;
}
return
std
::
all_of
(
std
::
begin
(
instance_pass
),
std
::
end
(
instance_pass
),
[](
bool
p
)
{
return
p
;
});
return
validator
.
is_success
();
}
}
// namespace profiler
...
...
test/conv_util/conv_util.cpp
View file @
ce87bcc7
...
...
@@ -46,110 +46,113 @@ class TestConvUtil : public ::testing::Test
TEST_F
(
TestConvUtil
,
ConvParamsGetOutputSpatialLengths1D
)
{
ck
::
utils
::
CorrectnessValidator
validator
;
// stride 2, dilation 1, pad 1
SetNDParams
(
1
,
2
,
1
,
1
);
std
::
vector
<
ck
::
index_t
>
out_spatial_len
=
conv_params
.
GetOutputSpatialLengths
();
EXPECT_TRUE
(
ck
::
utils
::
check_err
(
EXPECT_TRUE
(
validator
.
check_err
(
out_spatial_len
,
std
::
vector
<
ck
::
index_t
>
{
36
},
"Error: ConvParams 1D."
));
// stride 1, dilation 1, pad 1
SetNDParams
(
1
,
1
,
1
,
1
);
out_spatial_len
=
conv_params
.
GetOutputSpatialLengths
();
EXPECT_TRUE
(
ck
::
utils
::
check_err
(
EXPECT_TRUE
(
validator
.
check_err
(
out_spatial_len
,
std
::
vector
<
ck
::
index_t
>
{
71
},
"Error: ConvParams 1D stride {1}."
));
// stride 2, dilation 1, pad 2
SetNDParams
(
1
,
2
,
1
,
2
);
out_spatial_len
=
conv_params
.
GetOutputSpatialLengths
();
EXPECT_TRUE
(
ck
::
utils
::
check_err
(
out_spatial_len
,
EXPECT_TRUE
(
validator
.
check_err
(
out_spatial_len
,
std
::
vector
<
ck
::
index_t
>
{
37
},
"Error: ConvParams 1D padding left/right {2}."
));
// stride 2, dilation 2, pad 2
SetNDParams
(
1
,
2
,
2
,
2
);
out_spatial_len
=
conv_params
.
GetOutputSpatialLengths
();
EXPECT_TRUE
(
ck
::
utils
::
check_err
(
EXPECT_TRUE
(
validator
.
check_err
(
out_spatial_len
,
std
::
vector
<
ck
::
index_t
>
{
36
},
"Error: ConvParams 1D dilation {2}."
));
// stride 3, dilation 2, pad 1
SetNDParams
(
1
,
3
,
2
,
1
);
out_spatial_len
=
conv_params
.
GetOutputSpatialLengths
();
EXPECT_TRUE
(
ck
::
utils
::
check_err
(
out_spatial_len
,
validator
.
check_err
(
out_spatial_len
,
std
::
vector
<
ck
::
index_t
>
{
23
},
"Error: ConvParams 1D strides{3}, padding {1}, dilations {2}."
));
}
TEST_F
(
TestConvUtil
,
ConvParamsGetOutputSpatialLengths2D
)
{
ck
::
utils
::
CorrectnessValidator
validator
;
// stride 2, dilation 1, pad 1
SetNDParams
(
2
,
2
,
1
,
1
);
std
::
vector
<
ck
::
index_t
>
out_spatial_len
=
conv_params
.
GetOutputSpatialLengths
();
EXPECT_TRUE
(
ck
::
utils
::
check_err
(
out_spatial_len
,
EXPECT_TRUE
(
validator
.
check_err
(
out_spatial_len
,
std
::
vector
<
ck
::
index_t
>
{
36
,
36
},
"Error: ConvParams 2D default constructor."
));
// stride 1, dilation 1, pad 1
SetNDParams
(
2
,
1
,
1
,
1
);
out_spatial_len
=
conv_params
.
GetOutputSpatialLengths
();
EXPECT_TRUE
(
ck
::
utils
::
check_err
(
EXPECT_TRUE
(
validator
.
check_err
(
out_spatial_len
,
std
::
vector
<
ck
::
index_t
>
{
71
,
71
},
"Error: ConvParams 2D stride {1,1}."
));
// stride 2, dilation 1, pad 2
SetNDParams
(
2
,
2
,
1
,
2
);
out_spatial_len
=
conv_params
.
GetOutputSpatialLengths
();
EXPECT_TRUE
(
ck
::
utils
::
check_err
(
out_spatial_len
,
EXPECT_TRUE
(
validator
.
check_err
(
out_spatial_len
,
std
::
vector
<
ck
::
index_t
>
{
37
,
37
},
"Error: ConvParams 2D padding left/right {2,2}."
));
// stride 2, dilation 2, pad 2
SetNDParams
(
2
,
2
,
2
,
2
);
out_spatial_len
=
conv_params
.
GetOutputSpatialLengths
();
EXPECT_TRUE
(
ck
::
utils
::
check_err
(
EXPECT_TRUE
(
validator
.
check_err
(
out_spatial_len
,
std
::
vector
<
ck
::
index_t
>
{
36
,
36
},
"Error: ConvParams 2D dilation {2,2}."
));
// stride 3, dilation 2, pad 1
SetNDParams
(
2
,
3
,
2
,
1
);
out_spatial_len
=
conv_params
.
GetOutputSpatialLengths
();
EXPECT_TRUE
(
ck
::
utils
::
check_err
(
out_spatial_len
,
validator
.
check_err
(
out_spatial_len
,
std
::
vector
<
ck
::
index_t
>
{
23
,
23
},
"Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}."
));
}
TEST_F
(
TestConvUtil
,
ConvParamsGetOutputSpatialLengths3D
)
{
ck
::
utils
::
CorrectnessValidator
validator
;
// stride 2, dilation 1, pad 1
SetNDParams
(
3
,
2
,
1
,
1
);
std
::
vector
<
ck
::
index_t
>
out_spatial_len
=
conv_params
.
GetOutputSpatialLengths
();
EXPECT_TRUE
(
ck
::
utils
::
check_err
(
EXPECT_TRUE
(
validator
.
check_err
(
out_spatial_len
,
std
::
vector
<
ck
::
index_t
>
{
36
,
36
,
36
},
"Error: ConvParams 3D."
));
// stride 1, dilation 1, pad 1
SetNDParams
(
3
,
1
,
1
,
1
);
out_spatial_len
=
conv_params
.
GetOutputSpatialLengths
();
EXPECT_TRUE
(
ck
::
utils
::
check_err
(
out_spatial_len
,
EXPECT_TRUE
(
validator
.
check_err
(
out_spatial_len
,
std
::
vector
<
ck
::
index_t
>
{
71
,
71
,
71
},
"Error: ConvParams 3D stride {1, 1, 1}."
));
// stride 2, dilation 1, pad 2
SetNDParams
(
3
,
2
,
1
,
2
);
out_spatial_len
=
conv_params
.
GetOutputSpatialLengths
();
EXPECT_TRUE
(
ck
::
utils
::
check_err
(
out_spatial_len
,
EXPECT_TRUE
(
validator
.
check_err
(
out_spatial_len
,
std
::
vector
<
ck
::
index_t
>
{
37
,
37
,
37
},
"Error: ConvParams 3D padding left/right {2, 2, 2}."
));
// stride 2, dilation 2, pad 2
SetNDParams
(
3
,
2
,
2
,
2
);
out_spatial_len
=
conv_params
.
GetOutputSpatialLengths
();
EXPECT_TRUE
(
ck
::
utils
::
check_err
(
out_spatial_len
,
EXPECT_TRUE
(
validator
.
check_err
(
out_spatial_len
,
std
::
vector
<
ck
::
index_t
>
{
36
,
36
,
36
},
"Error: ConvParams 3D dilation {2, 2, 2}."
));
// stride 3, dilation 2, pad 1
SetNDParams
(
3
,
3
,
2
,
1
);
out_spatial_len
=
conv_params
.
GetOutputSpatialLengths
();
EXPECT_TRUE
(
ck
::
utils
::
check_err
(
EXPECT_TRUE
(
validator
.
check_err
(
out_spatial_len
,
std
::
vector
<
ck
::
index_t
>
{
23
,
23
,
23
},
"Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}."
));
...
...
test/gemm/gemm_util.hpp
View file @
ce87bcc7
...
...
@@ -227,30 +227,30 @@ struct TestGemm
if
(
is_supported
&&
do_verification
)
{
// Assert
bool
res
=
false
;
ck
::
utils
::
CorrectnessValidator
validator
;
if
(
std
::
is_same
<
CDataType
,
float
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
,
c_host
);
res
=
validator
.
check_err
(
c_device
,
c_host
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
else
if
(
std
::
is_same
<
CDataType
,
ck
::
half_t
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
,
c_host
);
res
=
validator
.
check_err
(
c_device
,
c_host
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
else
if
(
std
::
is_same
<
CDataType
,
ck
::
bhalf_t
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
,
c_host
);
res
=
validator
.
check_err
(
c_device
,
c_host
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
else
if
(
std
::
is_same
<
CDataType
,
int8_t
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
,
c_host
);
res
=
validator
.
check_err
(
c_device
,
c_host
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
else
if
(
std
::
is_same
<
CDataType
,
double
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
,
c_host
);
res
=
validator
.
check_err
(
c_device
,
c_host
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment