Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b14a8f83
Commit
b14a8f83
authored
Oct 22, 2024
by
Rostyslav Geyyer
Browse files
Add verification option selection
parent
314d2dd7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
21 deletions
+25
-21
example/01_gemm/common.hpp
example/01_gemm/common.hpp
+8
-7
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+15
-12
example/01_gemm/run_gemm_example_streamk_v2.inc
example/01_gemm/run_gemm_example_streamk_v2.inc
+1
-1
example/01_gemm/run_gemm_example_v2.inc
example/01_gemm/run_gemm_example_v2.inc
+1
-1
No files found.
example/01_gemm/common.hpp
View file @
b14a8f83
...
@@ -75,9 +75,10 @@ struct ProblemSizeSplitK final
...
@@ -75,9 +75,10 @@ struct ProblemSizeSplitK final
struct
ExecutionConfig
final
struct
ExecutionConfig
final
{
{
bool
do_verification
=
true
;
// 0 - no verification, 1 - CPU, 2 - GPU, 3 - CPU + GPU
int
init_method
=
2
;
int
do_verification
=
3
;
bool
time_kernel
=
false
;
int
init_method
=
2
;
bool
time_kernel
=
false
;
};
};
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
...
@@ -126,7 +127,7 @@ bool parse_cmd_args<ProblemSize>(int argc,
...
@@ -126,7 +127,7 @@ bool parse_cmd_args<ProblemSize>(int argc,
}
}
else
else
{
{
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU and GPU)"
<<
std
::
endl
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
...
@@ -176,7 +177,7 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
...
@@ -176,7 +177,7 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
else
else
{
{
std
::
cerr
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU and GPU)"
<<
std
::
endl
<<
"arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC"
<<
std
::
endl
<<
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC"
<<
std
::
endl
...
@@ -225,7 +226,7 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc,
...
@@ -225,7 +226,7 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc,
}
}
else
else
{
{
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU and GPU)"
<<
std
::
endl
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
...
@@ -275,7 +276,7 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
...
@@ -275,7 +276,7 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
}
}
else
else
{
{
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU and GPU)"
<<
std
::
endl
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
...
...
example/01_gemm/run_gemm_example.inc
View file @
b14a8f83
...
@@ -330,7 +330,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -330,7 +330,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
bool
pass
=
true
;
bool
pass
=
true
;
if
(
config
.
do_verification
)
if
(
(
config
.
do_verification
==
1
)
||
(
config
.
do_verification
==
3
)
)
{
{
// CPU verification
// CPU verification
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
...
@@ -353,13 +353,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -353,13 +353,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
#else
#else
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
pass
&=
!
ck
::
utils
::
check_err
(
c_m_n_device_result
,
pass
&=
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
,
c_m_n_host_result
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
get_rtol
<
CDataType
>
(),
get_rtol
<
CDataType
>
(),
get_atol
<
CDataType
>
());
get_atol
<
CDataType
>
());
#endif
#endif
}
if
((
config
.
do_verification
==
2
)
||
(
config
.
do_verification
==
3
))
{
// GPU verification
// GPU verification
auto
ref_gemm_gpu
=
ReferenceGemmInstanceGPU
{};
auto
ref_gemm_gpu
=
ReferenceGemmInstanceGPU
{};
auto
ref_invoker_gpu
=
ref_gemm_gpu
.
MakeInvoker
();
auto
ref_invoker_gpu
=
ref_gemm_gpu
.
MakeInvoker
();
...
@@ -381,14 +384,14 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -381,14 +384,14 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
c_m_n_device_ref_buf
.
FromDevice
(
c_m_n_device_ref_result
.
mData
.
data
());
c_m_n_device_ref_buf
.
FromDevice
(
c_m_n_device_ref_result
.
mData
.
data
());
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
pass
&=
!
ck
::
utils
::
check_err
(
c_m_n_device_result
,
pass
&=
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_device_ref_result
,
c_m_n_device_ref_result
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
get_rtol
<
CDataType
>
(),
get_rtol
<
CDataType
>
(),
get_atol
<
CDataType
>
());
get_atol
<
CDataType
>
());
}
}
return
!
pass
;
return
pass
==
true
;
}
}
bool
run_gemm_example
(
int
argc
,
char
*
argv
[])
bool
run_gemm_example
(
int
argc
,
char
*
argv
[])
...
...
example/01_gemm/run_gemm_example_streamk_v2.inc
View file @
b14a8f83
...
@@ -241,7 +241,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -241,7 +241,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
}
bool
pass
=
true
;
bool
pass
=
true
;
if
(
config
.
do_verification
)
if
(
(
config
.
do_verification
==
1
)
||
(
config
.
do_verification
==
3
)
)
{
{
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
...
example/01_gemm/run_gemm_example_v2.inc
View file @
b14a8f83
...
@@ -228,7 +228,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -228,7 +228,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
}
bool
pass
=
true
;
bool
pass
=
true
;
if
(
config
.
do_verification
)
if
(
(
config
.
do_verification
==
1
)
||
(
config
.
do_verification
==
3
)
)
{
{
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment