Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ce87bcc7
Commit
ce87bcc7
authored
Sep 08, 2023
by
Bartlomiej Kocot
Browse files
tmp
parent
c8a8385f
Changes
123
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
255 additions
and
241 deletions
+255
-241
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+182
-157
profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp
.../profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp
+3
-3
profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp
...r/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp
+3
-3
profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp
profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp
+3
-3
profiler/include/profiler/profile_batched_gemm_impl.hpp
profiler/include/profiler/profile_batched_gemm_impl.hpp
+3
-3
profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp
...ler/include/profiler/profile_batched_gemm_reduce_impl.hpp
+5
-9
profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
...clude/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
+3
-3
profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
...ofiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
+3
-3
profiler/include/profiler/profile_batchnorm_backward_impl.hpp
...iler/include/profiler/profile_batchnorm_backward_impl.hpp
+6
-9
profiler/include/profiler/profile_batchnorm_forward_impl.hpp
profiler/include/profiler/profile_batchnorm_forward_impl.hpp
+9
-10
profiler/include/profiler/profile_batchnorm_infer_impl.hpp
profiler/include/profiler/profile_batchnorm_infer_impl.hpp
+5
-7
profiler/include/profiler/profile_contraction_impl.hpp
profiler/include/profiler/profile_contraction_impl.hpp
+3
-3
profiler/include/profiler/profile_conv_bwd_data_impl.hpp
profiler/include/profiler/profile_conv_bwd_data_impl.hpp
+3
-3
profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp
.../include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp
+4
-1
profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp
...iler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp
+4
-1
profiler/include/profiler/profile_conv_fwd_impl.hpp
profiler/include/profiler/profile_conv_fwd_impl.hpp
+3
-3
profiler/include/profiler/profile_elementwise_layernorm_impl.hpp
...r/include/profiler/profile_elementwise_layernorm_impl.hpp
+4
-10
profiler/include/profiler/profile_gemm_add_add_fastgelu_impl.hpp
...r/include/profiler/profile_gemm_add_add_fastgelu_impl.hpp
+3
-3
profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp
profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp
+3
-3
profiler/include/profiler/profile_gemm_add_multiply_impl.hpp
profiler/include/profiler/profile_gemm_add_multiply_impl.hpp
+3
-4
No files found.
library/include/ck/library/utility/check_err.hpp
View file @
ce87bcc7
...
@@ -23,18 +23,31 @@
...
@@ -23,18 +23,31 @@
namespace
ck
{
namespace
ck
{
namespace
utils
{
namespace
utils
{
template
<
typename
Range
,
typename
RefRange
>
struct
CorrectnessValidator
{
typename
std
::
enable_if
<
public:
CorrectnessValidator
(
bool
pass_if_no_instance
=
false
)
:
pass_if_no_instance_
{
pass_if_no_instance
},
found_supporting_instance_
{
false
},
correct_results_
{
true
}
{
}
bool
is_success
()
{
return
(
pass_if_no_instance_
||
found_supporting_instance_
)
&&
correct_results_
;
}
template
<
typename
Range
,
typename
RefRange
>
typename
std
::
enable_if
<
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
is_floating_point_v
<
ranges
::
range_value_t
<
Range
>>
&&
std
::
is_floating_point_v
<
ranges
::
range_value_t
<
Range
>>
&&
!
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
half_t
>
,
!
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
half_t
>
,
bool
>::
type
bool
>::
type
check_err
(
const
Range
&
out
,
check_err
(
const
Range
&
out
,
const
RefRange
&
ref
,
const
RefRange
&
ref
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
double
rtol
=
1e-5
,
double
rtol
=
1e-5
,
double
atol
=
3e-6
)
double
atol
=
3e-6
)
{
{
found_supporting_instance_
=
true
;
if
(
out
.
size
()
!=
ref
.
size
())
if
(
out
.
size
()
!=
ref
.
size
())
{
{
std
::
cerr
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
std
::
cerr
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
...
@@ -67,20 +80,22 @@ check_err(const Range& out,
...
@@ -67,20 +80,22 @@ check_err(const Range& out,
{
{
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
}
}
correct_results_
=
correct_results_
&&
res
;
return
res
;
return
res
;
}
}
template
<
typename
Range
,
typename
RefRange
>
template
<
typename
Range
,
typename
RefRange
>
typename
std
::
enable_if
<
typename
std
::
enable_if
<
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
bhalf_t
>
,
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
bhalf_t
>
,
bool
>::
type
bool
>::
type
check_err
(
const
Range
&
out
,
check_err
(
const
Range
&
out
,
const
RefRange
&
ref
,
const
RefRange
&
ref
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
double
rtol
=
1e-3
,
double
rtol
=
1e-3
,
double
atol
=
1e-3
)
double
atol
=
1e-3
)
{
{
found_supporting_instance_
=
true
;
if
(
out
.
size
()
!=
ref
.
size
())
if
(
out
.
size
()
!=
ref
.
size
())
{
{
std
::
cerr
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
std
::
cerr
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
...
@@ -114,20 +129,22 @@ check_err(const Range& out,
...
@@ -114,20 +129,22 @@ check_err(const Range& out,
{
{
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
}
}
correct_results_
=
correct_results_
&&
res
;
return
res
;
return
res
;
}
}
template
<
typename
Range
,
typename
RefRange
>
template
<
typename
Range
,
typename
RefRange
>
typename
std
::
enable_if
<
typename
std
::
enable_if
<
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
half_t
>
,
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
half_t
>
,
bool
>::
type
bool
>::
type
check_err
(
const
Range
&
out
,
check_err
(
const
Range
&
out
,
const
RefRange
&
ref
,
const
RefRange
&
ref
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
double
rtol
=
1e-3
,
double
rtol
=
1e-3
,
double
atol
=
1e-3
)
double
atol
=
1e-3
)
{
{
found_supporting_instance_
=
true
;
if
(
out
.
size
()
!=
ref
.
size
())
if
(
out
.
size
()
!=
ref
.
size
())
{
{
std
::
cerr
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
std
::
cerr
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
...
@@ -160,24 +177,26 @@ check_err(const Range& out,
...
@@ -160,24 +177,26 @@ check_err(const Range& out,
{
{
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
}
}
correct_results_
=
correct_results_
&&
res
;
return
res
;
return
res
;
}
}
template
<
typename
Range
,
typename
RefRange
>
template
<
typename
Range
,
typename
RefRange
>
std
::
enable_if_t
<
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
enable_if_t
<
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
is_integral_v
<
ranges
::
range_value_t
<
Range
>>
&&
std
::
is_integral_v
<
ranges
::
range_value_t
<
Range
>>
&&
!
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
bhalf_t
>
)
!
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
bhalf_t
>
)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
||
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
int4_t
>
||
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
int4_t
>
#endif
#endif
,
,
bool
>
bool
>
check_err
(
const
Range
&
out
,
check_err
(
const
Range
&
out
,
const
RefRange
&
ref
,
const
RefRange
&
ref
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
double
=
0
,
double
=
0
,
double
atol
=
0
)
double
atol
=
0
)
{
{
found_supporting_instance_
=
true
;
if
(
out
.
size
()
!=
ref
.
size
())
if
(
out
.
size
()
!=
ref
.
size
())
{
{
std
::
cerr
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
std
::
cerr
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
...
@@ -211,7 +230,13 @@ check_err(const Range& out,
...
@@ -211,7 +230,13 @@ check_err(const Range& out,
{
{
std
::
cerr
<<
"max err: "
<<
max_err
<<
std
::
endl
;
std
::
cerr
<<
"max err: "
<<
max_err
<<
std
::
endl
;
}
}
correct_results_
=
correct_results_
&&
res
;
return
res
;
return
res
;
}
private:
bool
pass_if_no_instance_
;
bool
found_supporting_instance_
;
bool
correct_results_
;
}
}
}
// namespace utils
}
// namespace utils
...
...
profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp
View file @
ce87bcc7
...
@@ -76,7 +76,7 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
...
@@ -76,7 +76,7 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
using
RefAcc0DataType
=
float
;
using
RefAcc0DataType
=
float
;
using
RefAcc1DataType
=
float
;
using
RefAcc1DataType
=
float
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
const
int
DefaultStrideA0
=
ck
::
is_same_v
<
A0Layout
,
Row
>
?
K
:
M
;
const
int
DefaultStrideA0
=
ck
::
is_same_v
<
A0Layout
,
Row
>
?
K
:
M
;
const
int
DefaultStrideB0
=
ck
::
is_same_v
<
B0Layout
,
Row
>
?
N
:
K
;
const
int
DefaultStrideB0
=
ck
::
is_same_v
<
B0Layout
,
Row
>
?
N
:
K
;
...
@@ -331,7 +331,7 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
...
@@ -331,7 +331,7 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
{
{
e1_g_m_o_device_buf
.
FromDevice
(
e1_g_m_o_device_result
.
mData
.
data
());
e1_g_m_o_device_buf
.
FromDevice
(
e1_g_m_o_device_result
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
e1_g_m_o_device_result
,
e1_g_m_o_host_result
);
validator
.
check_err
(
e1_g_m_o_device_result
,
e1_g_m_o_host_result
);
if
(
do_log
)
if
(
do_log
)
{
{
...
@@ -353,7 +353,7 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
...
@@ -353,7 +353,7 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp
View file @
ce87bcc7
...
@@ -83,7 +83,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
...
@@ -83,7 +83,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
B1ElementOp
,
B1ElementOp
,
CElementOp
>
;
CElementOp
>
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
// A layout [G0, M, G1, K]
// A layout [G0, M, G1, K]
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
...
@@ -355,7 +355,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
...
@@ -355,7 +355,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
atol
=
1e-2
;
atol
=
1e-2
;
}
}
pass
=
pass
&
ck
::
utils
::
check_err
(
c_gs_ms_os_device_result
,
validator
.
check_err
(
c_gs_ms_os_device_result
,
c_gs_ms_os_host_result
,
c_gs_ms_os_host_result
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
rtol
,
rtol
,
...
@@ -388,7 +388,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
...
@@ -388,7 +388,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp
View file @
ce87bcc7
...
@@ -78,7 +78,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
...
@@ -78,7 +78,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
B1ElementOp
,
B1ElementOp
,
CElementOp
>
;
CElementOp
>
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
const
int
DefaultStrideA
=
ck
::
is_same_v
<
ALayout
,
Row
>
?
K
:
M
;
const
int
DefaultStrideA
=
ck
::
is_same_v
<
ALayout
,
Row
>
?
K
:
M
;
const
int
DefaultStrideB0
=
ck
::
is_same_v
<
B0Layout
,
Row
>
?
N
:
K
;
const
int
DefaultStrideB0
=
ck
::
is_same_v
<
B0Layout
,
Row
>
?
N
:
K
;
...
@@ -284,7 +284,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
...
@@ -284,7 +284,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
{
{
c_g_m_o_device_buf
.
FromDevice
(
c_g_m_o_device_result
.
mData
.
data
());
c_g_m_o_device_buf
.
FromDevice
(
c_g_m_o_device_result
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
c_g_m_o_device_result
,
c_g_m_o_host_result
);
validator
.
check_err
(
c_g_m_o_device_result
,
c_g_m_o_host_result
);
if
(
do_log
)
if
(
do_log
)
{
{
...
@@ -312,7 +312,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
...
@@ -312,7 +312,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profiler/profile_batched_gemm_impl.hpp
View file @
ce87bcc7
...
@@ -49,7 +49,7 @@ bool profile_batched_gemm_impl(int do_verification,
...
@@ -49,7 +49,7 @@ bool profile_batched_gemm_impl(int do_verification,
int
StrideC
,
int
StrideC
,
int
BatchCount
)
int
BatchCount
)
{
{
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
batch_count
,
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
batch_count
,
std
::
size_t
row
,
std
::
size_t
row
,
...
@@ -234,7 +234,7 @@ bool profile_batched_gemm_impl(int do_verification,
...
@@ -234,7 +234,7 @@ bool profile_batched_gemm_impl(int do_verification,
{
{
c_device_buf
.
FromDevice
(
c_g_m_n_device_result
.
mData
.
data
());
c_device_buf
.
FromDevice
(
c_g_m_n_device_result
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
c_g_m_n_device_result
,
c_g_m_n_host_result
);
validator
.
check_err
(
c_g_m_n_device_result
,
c_g_m_n_host_result
);
if
(
do_log
)
if
(
do_log
)
{
{
...
@@ -257,7 +257,7 @@ bool profile_batched_gemm_impl(int do_verification,
...
@@ -257,7 +257,7 @@ bool profile_batched_gemm_impl(int do_verification,
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp
View file @
ce87bcc7
...
@@ -72,7 +72,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
...
@@ -72,7 +72,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
int
StrideC
,
int
StrideC
,
int
BatchCount
)
int
BatchCount
)
{
{
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
batch_count
,
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
batch_count
,
std
::
size_t
row
,
std
::
size_t
row
,
...
@@ -316,13 +316,9 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
...
@@ -316,13 +316,9 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
reduce0_device_buf
.
FromDevice
(
d0_g_m_device_result
.
mData
.
data
());
reduce0_device_buf
.
FromDevice
(
d0_g_m_device_result
.
mData
.
data
());
reduce1_device_buf
.
FromDevice
(
d1_g_m_device_result
.
mData
.
data
());
reduce1_device_buf
.
FromDevice
(
d1_g_m_device_result
.
mData
.
data
());
bool
c_error
=
ck
::
utils
::
check_err
(
c_g_m_n_device_result
,
c_g_m_n_host_result
);
validator
.
check_err
(
c_g_m_n_device_result
,
c_g_m_n_host_result
);
bool
d0_error
=
ck
::
utils
::
check_err
(
d0_g_m_device_result
,
d0_g_m_host_result
);
validator
.
check_err
(
d0_g_m_device_result
,
d0_g_m_host_result
);
bool
d1_error
=
ck
::
utils
::
check_err
(
d1_g_m_device_result
,
d1_g_m_host_result
);
validator
.
check_err
(
d1_g_m_device_result
,
d1_g_m_host_result
);
pass
=
pass
&&
(
c_error
==
true
);
pass
=
pass
&&
(
d0_error
==
true
);
pass
=
pass
&&
(
d1_error
==
true
);
if
(
do_log
)
if
(
do_log
)
{
{
...
@@ -355,7 +351,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
...
@@ -355,7 +351,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_gemm_name
<<
std
::
endl
;
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_gemm_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
View file @
ce87bcc7
...
@@ -86,7 +86,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
...
@@ -86,7 +86,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
B1ElementOp
,
B1ElementOp
,
CElementOp
>
;
CElementOp
>
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
const
int
DefaultStrideA
=
ck
::
is_same_v
<
ALayout
,
Row
>
?
K
:
M
;
const
int
DefaultStrideA
=
ck
::
is_same_v
<
ALayout
,
Row
>
?
K
:
M
;
const
int
DefaultStrideB0
=
ck
::
is_same_v
<
B0Layout
,
Row
>
?
N
:
K
;
const
int
DefaultStrideB0
=
ck
::
is_same_v
<
B0Layout
,
Row
>
?
N
:
K
;
...
@@ -312,7 +312,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
...
@@ -312,7 +312,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
{
{
c_g_m_o_device_buf
.
FromDevice
(
c_g_m_o_device_result
.
mData
.
data
());
c_g_m_o_device_buf
.
FromDevice
(
c_g_m_o_device_result
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
c_g_m_o_device_result
,
c_g_m_o_host_result
);
validator
.
check_err
(
c_g_m_o_device_result
,
c_g_m_o_host_result
);
if
(
do_log
)
if
(
do_log
)
{
{
...
@@ -340,7 +340,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
...
@@ -340,7 +340,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
View file @
ce87bcc7
...
@@ -81,7 +81,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
...
@@ -81,7 +81,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
B1ElementOp
,
B1ElementOp
,
CElementOp
>
;
CElementOp
>
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
// A layout [G0, M, G1, K]
// A layout [G0, M, G1, K]
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
...
@@ -327,7 +327,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
...
@@ -327,7 +327,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
atol
=
1e-2
;
atol
=
1e-2
;
}
}
pass
=
pass
&
ck
::
utils
::
check_err
(
c_gs_ms_os_device_result
,
validator
.
check_err
(
c_gs_ms_os_device_result
,
c_gs_ms_os_host_result
,
c_gs_ms_os_host_result
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
rtol
,
rtol
,
...
@@ -360,7 +360,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
...
@@ -360,7 +360,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profiler/profile_batchnorm_backward_impl.hpp
View file @
ce87bcc7
...
@@ -265,7 +265,7 @@ bool profile_batchnorm_backward_impl(bool do_verification,
...
@@ -265,7 +265,7 @@ bool profile_batchnorm_backward_impl(bool do_verification,
}
}
int
num_kernel
=
0
;
int
num_kernel
=
0
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
for
(
auto
&
inst_ptr
:
instance_ptrs
)
for
(
auto
&
inst_ptr
:
instance_ptrs
)
{
{
...
@@ -340,20 +340,17 @@ bool profile_batchnorm_backward_impl(bool do_verification,
...
@@ -340,20 +340,17 @@ bool profile_batchnorm_backward_impl(bool do_verification,
if
(
do_verification
)
if
(
do_verification
)
{
{
using
ck
::
utils
::
check_err
;
using
ck
::
utils
;
bool
single_pass
=
true
;
dx_dev
.
FromDevice
(
dx
.
mData
.
data
());
dx_dev
.
FromDevice
(
dx
.
mData
.
data
());
dscale_dev
.
FromDevice
(
dscale
.
data
());
dscale_dev
.
FromDevice
(
dscale
.
data
());
dbias_dev
.
FromDevice
(
dbias
.
data
());
dbias_dev
.
FromDevice
(
dbias
.
data
());
// clang-format off
// clang-format off
single_pass
=
single_pass
&&
ck
::
utils
::
check_err
(
dx
.
mData
,
dx_ref
.
mData
,
"dx result:"
,
5e-4
,
5e-4
);
validator
.
check_err
(
dx
.
mData
,
dx_ref
.
mData
,
"dx result:"
,
5e-4
,
5e-4
);
single_pass
=
single_pass
&&
ck
::
utils
::
check_err
(
dscale
.
mData
,
dscale_ref
.
mData
,
"dScale result:"
,
3e-3
,
3e-3
);
validator
.
check_err
(
dscale
.
mData
,
dscale_ref
.
mData
,
"dScale result:"
,
3e-3
,
3e-3
);
single_pass
=
single_pass
&&
ck
::
utils
::
check_err
(
dbias
.
mData
,
dbias_ref
.
mData
,
"dBias result:"
,
3e-3
,
3e-3
);
validator
.
check_err
(
dbias
.
mData
,
dbias_ref
.
mData
,
"dBias result:"
,
3e-3
,
3e-3
);
// clang-format on
// clang-format on
pass
=
pass
&&
single_pass
;
};
};
if
(
do_dumpout
)
if
(
do_dumpout
)
...
@@ -383,7 +380,7 @@ bool profile_batchnorm_backward_impl(bool do_verification,
...
@@ -383,7 +380,7 @@ bool profile_batchnorm_backward_impl(bool do_verification,
return
false
;
return
false
;
}
}
return
pass
;
return
validator
.
is_success
()
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profiler/profile_batchnorm_forward_impl.hpp
View file @
ce87bcc7
...
@@ -259,7 +259,7 @@ bool profile_batchnorm_forward_impl(int do_verification,
...
@@ -259,7 +259,7 @@ bool profile_batchnorm_forward_impl(int do_verification,
}
}
int
num_kernel
=
0
;
int
num_kernel
=
0
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
for
(
auto
&
inst_ptr
:
instance_ptrs
)
for
(
auto
&
inst_ptr
:
instance_ptrs
)
{
{
...
@@ -336,15 +336,15 @@ bool profile_batchnorm_forward_impl(int do_verification,
...
@@ -336,15 +336,15 @@ bool profile_batchnorm_forward_impl(int do_verification,
if
(
do_verification
)
if
(
do_verification
)
{
{
using
ck
::
utils
::
check_err
;
using
ck
::
utils
;
bool
single_pass
;
bool
single_pass
;
y_dev
.
FromDevice
(
y
.
mData
.
data
());
y_dev
.
FromDevice
(
y
.
mData
.
data
());
if
constexpr
(
ck
::
is_same_v
<
YDataType
,
ck
::
bhalf_t
>
)
if
constexpr
(
ck
::
is_same_v
<
YDataType
,
ck
::
bhalf_t
>
)
single_pass
=
check_err
(
y
.
mData
,
y_ref
.
mData
,
"y results"
,
1e-2
,
1e-2
);
check_err
(
y
.
mData
,
y_ref
.
mData
,
"y results"
,
1e-2
,
1e-2
);
else
else
single_pass
=
check_err
(
y
.
mData
,
y_ref
.
mData
,
"y results"
,
4e-3
,
4e-3
);
check_err
(
y
.
mData
,
y_ref
.
mData
,
"y results"
,
4e-3
,
4e-3
);
if
(
updateMovingAverage
)
if
(
updateMovingAverage
)
{
{
...
@@ -352,8 +352,8 @@ bool profile_batchnorm_forward_impl(int do_verification,
...
@@ -352,8 +352,8 @@ bool profile_batchnorm_forward_impl(int do_verification,
resultRunningVariance_dev
.
FromDevice
(
resultRunningVariance
.
mData
.
data
());
resultRunningVariance_dev
.
FromDevice
(
resultRunningVariance
.
mData
.
data
());
// clang-format off
// clang-format off
single_pass
=
single_pass
&&
check_err
(
resultRunningMean
.
mData
,
resultRunningMean_ref
.
mData
,
"average mean results"
,
1.5e-5
,
1.5e-5
);
check_err
(
resultRunningMean
.
mData
,
resultRunningMean_ref
.
mData
,
"average mean results"
,
1.5e-5
,
1.5e-5
);
single_pass
=
single_pass
&&
check_err
(
resultRunningVariance
.
mData
,
resultRunningVariance_ref
.
mData
,
"average variance results"
,
1e-5
,
1e-5
);
check_err
(
resultRunningVariance
.
mData
,
resultRunningVariance_ref
.
mData
,
"average variance results"
,
1e-5
,
1e-5
);
// clang-format on
// clang-format on
};
};
...
@@ -363,12 +363,11 @@ bool profile_batchnorm_forward_impl(int do_verification,
...
@@ -363,12 +363,11 @@ bool profile_batchnorm_forward_impl(int do_verification,
resultSaveInvVariance_dev
.
FromDevice
(
resultSaveInvVariance
.
mData
.
data
());
resultSaveInvVariance_dev
.
FromDevice
(
resultSaveInvVariance
.
mData
.
data
());
// clang-format off
// clang-format off
single_pass
=
single_pass
&&
check_err
(
resultSaveMean
.
mData
,
resultSaveMean_ref
.
mData
,
"mean results"
,
3e-5
,
3e-5
);
check_err
(
resultSaveMean
.
mData
,
resultSaveMean_ref
.
mData
,
"mean results"
,
3e-5
,
3e-5
);
single_pass
=
single_pass
&&
check_err
(
resultSaveInvVariance
.
mData
,
resultSaveInvVariance_ref
.
mData
,
"inv-variance results"
,
7e-5
,
7e-5
);
check_err
(
resultSaveInvVariance
.
mData
,
resultSaveInvVariance_ref
.
mData
,
"inv-variance results"
,
7e-5
,
7e-5
);
// clang-format on
// clang-format on
};
};
pass
=
pass
&&
single_pass
;
};
};
if
(
do_dumpout
)
if
(
do_dumpout
)
...
@@ -405,7 +404,7 @@ bool profile_batchnorm_forward_impl(int do_verification,
...
@@ -405,7 +404,7 @@ bool profile_batchnorm_forward_impl(int do_verification,
return
false
;
return
false
;
}
}
return
pass
;
return
validator
.
is_success
()
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profiler/profile_batchnorm_infer_impl.hpp
View file @
ce87bcc7
...
@@ -231,7 +231,7 @@ bool profile_batchnorm_infer_impl(int do_verification,
...
@@ -231,7 +231,7 @@ bool profile_batchnorm_infer_impl(int do_verification,
}
}
int
num_kernel
=
0
;
int
num_kernel
=
0
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
for
(
auto
&
inst_ptr
:
instance_ptrs
)
for
(
auto
&
inst_ptr
:
instance_ptrs
)
{
{
...
@@ -291,17 +291,15 @@ bool profile_batchnorm_infer_impl(int do_verification,
...
@@ -291,17 +291,15 @@ bool profile_batchnorm_infer_impl(int do_verification,
if
(
do_verification
)
if
(
do_verification
)
{
{
using
ck
::
utils
::
check_err
;
using
ck
::
utils
;
bool
single_pass
;
y_dev
.
FromDevice
(
y
.
mData
.
data
());
y_dev
.
FromDevice
(
y
.
mData
.
data
());
if
constexpr
(
ck
::
is_same_v
<
YDataType
,
ck
::
bhalf_t
>
)
if
constexpr
(
ck
::
is_same_v
<
YDataType
,
ck
::
bhalf_t
>
)
single_pass
=
check_err
(
y
.
mData
,
y_ref
.
mData
,
"y results"
,
1e-2
,
1e-2
);
check_err
(
y
.
mData
,
y_ref
.
mData
,
"y results"
,
1e-2
,
1e-2
);
else
else
single_pass
=
check_err
(
y
.
mData
,
y_ref
.
mData
,
"y results"
,
4e-3
,
4e-3
);
check_err
(
y
.
mData
,
y_ref
.
mData
,
"y results"
,
4e-3
,
4e-3
);
pass
=
pass
&&
single_pass
;
};
};
if
(
do_dumpout
)
if
(
do_dumpout
)
...
@@ -328,7 +326,7 @@ bool profile_batchnorm_infer_impl(int do_verification,
...
@@ -328,7 +326,7 @@ bool profile_batchnorm_infer_impl(int do_verification,
return
false
;
return
false
;
}
}
return
pass
;
return
validator
.
is_success
()
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profiler/profile_contraction_impl.hpp
View file @
ce87bcc7
...
@@ -50,7 +50,7 @@ int profile_contraction_impl(ck::index_t do_verification,
...
@@ -50,7 +50,7 @@ int profile_contraction_impl(ck::index_t do_verification,
const
std
::
vector
<
ck
::
index_t
>&
StridesE
,
const
std
::
vector
<
ck
::
index_t
>&
StridesE
,
const
std
::
vector
<
ck
::
index_t
>&
StridesD
)
const
std
::
vector
<
ck
::
index_t
>&
StridesD
)
{
{
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
auto
f_host_tensor_descriptor
=
[](
const
std
::
vector
<
ck
::
index_t
>&
dims01
,
auto
f_host_tensor_descriptor
=
[](
const
std
::
vector
<
ck
::
index_t
>&
dims01
,
const
std
::
vector
<
ck
::
index_t
>&
dims23
,
const
std
::
vector
<
ck
::
index_t
>&
dims23
,
...
@@ -274,7 +274,7 @@ int profile_contraction_impl(ck::index_t do_verification,
...
@@ -274,7 +274,7 @@ int profile_contraction_impl(ck::index_t do_verification,
float
threshold
=
float
threshold
=
static_cast
<
DataType
>
(
nelems_k
)
*
std
::
numeric_limits
<
DataType
>::
epsilon
();
static_cast
<
DataType
>
(
nelems_k
)
*
std
::
numeric_limits
<
DataType
>::
epsilon
();
pass
=
pass
&
ck
::
utils
::
check_err
(
e_m_n_device_result
,
validator
.
check_err
(
e_m_n_device_result
,
e_m_n_host_result
,
e_m_n_host_result
,
"Error: incorrect results!"
,
"Error: incorrect results!"
,
threshold
,
threshold
,
...
@@ -338,7 +338,7 @@ int profile_contraction_impl(ck::index_t do_verification,
...
@@ -338,7 +338,7 @@ int profile_contraction_impl(ck::index_t do_verification,
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
<<
best_op_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profiler/profile_conv_bwd_data_impl.hpp
View file @
ce87bcc7
...
@@ -153,7 +153,7 @@ bool profile_conv_bwd_data_impl(int do_verification,
...
@@ -153,7 +153,7 @@ bool profile_conv_bwd_data_impl(int do_verification,
float
best_gb_per_sec
=
0
;
float
best_gb_per_sec
=
0
;
// profile device Conv instances
// profile device Conv instances
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
for
(
auto
&
op_ptr
:
op_ptrs
)
for
(
auto
&
op_ptr
:
op_ptrs
)
{
{
...
@@ -209,7 +209,7 @@ bool profile_conv_bwd_data_impl(int do_verification,
...
@@ -209,7 +209,7 @@ bool profile_conv_bwd_data_impl(int do_verification,
{
{
in_device_buf
.
FromDevice
(
input_device_result
.
mData
.
data
());
in_device_buf
.
FromDevice
(
input_device_result
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
input_device_result
,
input_host_result
);
validator
.
check_err
(
input_device_result
,
input_host_result
);
if
(
do_log
)
if
(
do_log
)
{
{
...
@@ -241,7 +241,7 @@ bool profile_conv_bwd_data_impl(int do_verification,
...
@@ -241,7 +241,7 @@ bool profile_conv_bwd_data_impl(int do_verification,
<<
"
\n
name: "
<<
best_op_name
<<
"
\n
avg_time: "
<<
best_avg_time
<<
"
\n
name: "
<<
best_op_name
<<
"
\n
avg_time: "
<<
best_avg_time
<<
"
\n
tflops: "
<<
best_tflops
<<
"
\n
GB/s: "
<<
best_gb_per_sec
<<
std
::
endl
;
<<
"
\n
tflops: "
<<
best_tflops
<<
"
\n
GB/s: "
<<
best_gb_per_sec
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp
View file @
ce87bcc7
...
@@ -193,6 +193,8 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
...
@@ -193,6 +193,8 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
float
best_tflops
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
float
best_gb_per_sec
=
0
;
ck
::
utils
::
CorrectnessValidator
validator
;
// profile device Conv instances
// profile device Conv instances
for
(
auto
&
op_ptr
:
op_ptrs
)
for
(
auto
&
op_ptr
:
op_ptrs
)
{
{
...
@@ -251,7 +253,8 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
...
@@ -251,7 +253,8 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
{
{
out_device_buf
.
FromDevice
(
out_n_k_ho_wo_device_result
.
mData
.
data
());
out_device_buf
.
FromDevice
(
out_n_k_ho_wo_device_result
.
mData
.
data
());
ck
::
utils
::
check_err
(
out_n_k_ho_wo_device_result
,
out_n_k_ho_wo_host_result
);
validator
.
check_err
(
out_n_k_ho_wo_device_result
,
out_n_k_ho_wo_host_result
);
validator
.
is_success
();
if
(
do_log
)
if
(
do_log
)
{
{
...
...
profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp
View file @
ce87bcc7
...
@@ -183,6 +183,8 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
...
@@ -183,6 +183,8 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
float
best_tflops
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
float
best_gb_per_sec
=
0
;
ck
::
utils
::
CorrectnessValidator
validator
;
// profile device Conv instances
// profile device Conv instances
for
(
auto
&
op_ptr
:
op_ptrs
)
for
(
auto
&
op_ptr
:
op_ptrs
)
{
{
...
@@ -239,7 +241,8 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
...
@@ -239,7 +241,8 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
{
{
out_device_buf
.
FromDevice
(
out_n_k_ho_wo_device_result
.
mData
.
data
());
out_device_buf
.
FromDevice
(
out_n_k_ho_wo_device_result
.
mData
.
data
());
ck
::
utils
::
check_err
(
out_n_k_ho_wo_device_result
,
out_n_k_ho_wo_host_result
);
validator
.
check_err
(
out_n_k_ho_wo_device_result
,
out_n_k_ho_wo_host_result
);
validator
.
is_success
();
if
(
do_log
)
if
(
do_log
)
{
{
...
...
profiler/include/profiler/profile_conv_fwd_impl.hpp
View file @
ce87bcc7
...
@@ -135,7 +135,7 @@ bool profile_conv_fwd_impl(int do_verification,
...
@@ -135,7 +135,7 @@ bool profile_conv_fwd_impl(int do_verification,
float
best_gb_per_sec
=
0
;
float
best_gb_per_sec
=
0
;
// profile device op instances
// profile device op instances
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
for
(
auto
&
op_ptr
:
op_ptrs
)
for
(
auto
&
op_ptr
:
op_ptrs
)
{
{
...
@@ -191,7 +191,7 @@ bool profile_conv_fwd_impl(int do_verification,
...
@@ -191,7 +191,7 @@ bool profile_conv_fwd_impl(int do_verification,
{
{
out_device_buf
.
FromDevice
(
device_output
.
mData
.
data
());
out_device_buf
.
FromDevice
(
device_output
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
device_output
,
host_output
);
validator
.
check_err
(
device_output
,
host_output
);
if
(
do_log
)
if
(
do_log
)
{
{
...
@@ -214,7 +214,7 @@ bool profile_conv_fwd_impl(int do_verification,
...
@@ -214,7 +214,7 @@ bool profile_conv_fwd_impl(int do_verification,
<<
"
\n
name: "
<<
best_op_name
<<
"
\n
avg_time: "
<<
best_avg_time
<<
"
\n
name: "
<<
best_op_name
<<
"
\n
avg_time: "
<<
best_avg_time
<<
"
\n
tflops: "
<<
best_tflops
<<
"
\n
GB/s: "
<<
best_gb_per_sec
<<
std
::
endl
;
<<
"
\n
tflops: "
<<
best_tflops
<<
"
\n
GB/s: "
<<
best_gb_per_sec
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profiler/profile_elementwise_layernorm_impl.hpp
View file @
ce87bcc7
...
@@ -164,6 +164,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
...
@@ -164,6 +164,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
}
}
int
num_kernel
=
0
;
int
num_kernel
=
0
;
ck
::
utils
::
CorrectnessValidator
validator
;
for
(
auto
&
inst_ptr
:
instance_ptrs
)
for
(
auto
&
inst_ptr
:
instance_ptrs
)
{
{
...
@@ -221,8 +222,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
...
@@ -221,8 +222,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
{
{
y_dev
.
FromDevice
(
y
.
mData
.
data
());
y_dev
.
FromDevice
(
y
.
mData
.
data
());
bool
pass
=
validator
.
check_err
(
y
.
mData
,
host_y
.
mData
,
"Error: Incorrect results"
,
1e-3
,
1e-3
);
ck
::
utils
::
check_err
(
y
.
mData
,
host_y
.
mData
,
"Error: Incorrect results"
,
1e-3
,
1e-3
);
if
(
do_log
)
if
(
do_log
)
{
{
...
@@ -232,7 +232,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
...
@@ -232,7 +232,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
LogRangeAsType
<
float
>
(
std
::
cout
<<
"y : "
,
y
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"y : "
,
y
.
mData
,
","
)
<<
std
::
endl
;
}
}
if
(
!
pass
)
if
(
!
validator
.
is_success
()
)
{
{
std
::
cout
<<
inst_ptr
->
GetTypeString
()
<<
" failed verification: "
;
std
::
cout
<<
inst_ptr
->
GetTypeString
()
<<
" failed verification: "
;
LogRange
(
std
::
cout
<<
"lengths = ["
,
length
,
", "
)
<<
"]."
<<
std
::
endl
;
LogRange
(
std
::
cout
<<
"lengths = ["
,
length
,
", "
)
<<
"]."
<<
std
::
endl
;
...
@@ -253,13 +253,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
...
@@ -253,13 +253,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_instance_name
<<
std
::
endl
;
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_instance_name
<<
std
::
endl
;
}
}
if
(
num_kernel
==
0
)
return
validator
.
is_success
();
{
std
::
cout
<<
"Error: No kernel is tested"
<<
std
::
endl
;
return
false
;
}
return
true
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profiler/profile_gemm_add_add_fastgelu_impl.hpp
View file @
ce87bcc7
...
@@ -165,7 +165,7 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
...
@@ -165,7 +165,7 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
float
best_tflops
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
float
best_gb_per_sec
=
0
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
// profile device operation instances
// profile device operation instances
for
(
auto
&
op_ptr
:
op_ptrs
)
for
(
auto
&
op_ptr
:
op_ptrs
)
...
@@ -223,7 +223,7 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
...
@@ -223,7 +223,7 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
{
{
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
pass
=
pass
&&
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
);
validator
.
check_err
(
e_m_n_device_result
,
e_m_n_host_result
);
}
}
}
}
else
else
...
@@ -235,7 +235,7 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
...
@@ -235,7 +235,7 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp
View file @
ce87bcc7
...
@@ -156,7 +156,7 @@ bool profile_gemm_add_fastgelu_impl(int do_verification,
...
@@ -156,7 +156,7 @@ bool profile_gemm_add_fastgelu_impl(int do_verification,
float
best_tflops
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
float
best_gb_per_sec
=
0
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
// profile device operation instances
// profile device operation instances
for
(
auto
&
op_ptr
:
op_ptrs
)
for
(
auto
&
op_ptr
:
op_ptrs
)
...
@@ -213,7 +213,7 @@ bool profile_gemm_add_fastgelu_impl(int do_verification,
...
@@ -213,7 +213,7 @@ bool profile_gemm_add_fastgelu_impl(int do_verification,
{
{
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
pass
=
pass
&&
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
);
validator
.
check_err
(
e_m_n_device_result
,
e_m_n_host_result
);
}
}
}
}
else
else
...
@@ -225,7 +225,7 @@ bool profile_gemm_add_fastgelu_impl(int do_verification,
...
@@ -225,7 +225,7 @@ bool profile_gemm_add_fastgelu_impl(int do_verification,
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
profiler/include/profiler/profile_gemm_add_multiply_impl.hpp
View file @
ce87bcc7
...
@@ -165,7 +165,7 @@ bool profile_gemm_add_multiply_impl(int do_verification,
...
@@ -165,7 +165,7 @@ bool profile_gemm_add_multiply_impl(int do_verification,
float
best_tflops
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
float
best_gb_per_sec
=
0
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
// profile device operation instances
// profile device operation instances
for
(
auto
&
op_ptr
:
op_ptrs
)
for
(
auto
&
op_ptr
:
op_ptrs
)
...
@@ -222,8 +222,7 @@ bool profile_gemm_add_multiply_impl(int do_verification,
...
@@ -222,8 +222,7 @@ bool profile_gemm_add_multiply_impl(int do_verification,
if
(
do_verification
)
if
(
do_verification
)
{
{
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
validator
.
check_err
(
e_m_n_device_result
,
e_m_n_host_result
);
pass
=
pass
&&
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
);
}
}
}
}
else
else
...
@@ -235,7 +234,7 @@ bool profile_gemm_add_multiply_impl(int do_verification,
...
@@ -235,7 +234,7 @@ bool profile_gemm_add_multiply_impl(int do_verification,
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
return
validator
.
is_success
()
;
}
}
}
// namespace profiler
}
// namespace profiler
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment