Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
24771ab7
Commit
24771ab7
authored
Oct 24, 2024
by
Andriy Roshchenko
Browse files
Optionaly run either CPU or GPU verifications with GEMM examples.
parent
02958ba5
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
15 additions
and
11 deletions
+15
-11
example/01_gemm/common.hpp
example/01_gemm/common.hpp
+8
-7
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+4
-1
example/01_gemm/run_gemm_example_streamk_v2.inc
example/01_gemm/run_gemm_example_streamk_v2.inc
+1
-1
example/01_gemm/run_gemm_example_v2.inc
example/01_gemm/run_gemm_example_v2.inc
+1
-1
library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp
...library/reference_tensor_operation/gpu/reference_gemm.hpp
+1
-1
No files found.
example/01_gemm/common.hpp
View file @
24771ab7
...
@@ -75,9 +75,10 @@ struct ProblemSizeSplitK final
...
@@ -75,9 +75,10 @@ struct ProblemSizeSplitK final
struct
ExecutionConfig
final
struct
ExecutionConfig
final
{
{
bool
do_verification
=
true
;
// 0 - no verification, 1 - CPU, 2 - GPU, 3 - CPU + GPU
int
init_method
=
2
;
int
do_verification
=
3
;
bool
time_kernel
=
false
;
int
init_method
=
2
;
bool
time_kernel
=
false
;
};
};
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
...
@@ -126,7 +127,7 @@ bool parse_cmd_args<ProblemSize>(int argc,
...
@@ -126,7 +127,7 @@ bool parse_cmd_args<ProblemSize>(int argc,
}
}
else
else
{
{
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU and GPU)"
<<
std
::
endl
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU
, 2=GPU, 3=CPU
and GPU)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
...
@@ -176,7 +177,7 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
...
@@ -176,7 +177,7 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
else
else
{
{
std
::
cerr
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU and GPU)"
<<
std
::
endl
<<
"arg1: verification (0=no, 1=CPU
, 2=GPU, 3=CPU
and GPU)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC"
<<
std
::
endl
<<
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC"
<<
std
::
endl
...
@@ -225,7 +226,7 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc,
...
@@ -225,7 +226,7 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc,
}
}
else
else
{
{
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU and GPU)"
<<
std
::
endl
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU
, 2=GPU, 3=CPU
and GPU)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
...
@@ -275,7 +276,7 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
...
@@ -275,7 +276,7 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
}
}
else
else
{
{
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU and GPU)"
<<
std
::
endl
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU
, 2=GPU, 3=CPU
and GPU)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
...
...
example/01_gemm/run_gemm_example.inc
View file @
24771ab7
...
@@ -337,7 +337,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -337,7 +337,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
bool
pass
=
true
;
bool
pass
=
true
;
if
(
config
.
do_verification
)
if
(
(
config
.
do_verification
==
1
)
||
(
config
.
do_verification
==
3
)
)
{
{
// CPU verification
// CPU verification
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
...
@@ -368,7 +368,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -368,7 +368,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
#endif
#endif
if
(
pass
)
if
(
pass
)
std
::
cout
<<
"Verification on CPU: PASS"
<<
std
::
endl
;
std
::
cout
<<
"Verification on CPU: PASS"
<<
std
::
endl
;
}
if
((
config
.
do_verification
==
2
)
||
(
config
.
do_verification
==
3
))
{
// GPU verification
// GPU verification
auto
ref_gemm_gpu
=
ReferenceGemmInstanceGPU
{};
auto
ref_gemm_gpu
=
ReferenceGemmInstanceGPU
{};
auto
ref_invoker_gpu
=
ref_gemm_gpu
.
MakeInvoker
();
auto
ref_invoker_gpu
=
ref_gemm_gpu
.
MakeInvoker
();
...
...
example/01_gemm/run_gemm_example_streamk_v2.inc
View file @
24771ab7
...
@@ -241,7 +241,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -241,7 +241,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
}
bool
pass
=
true
;
bool
pass
=
true
;
if
(
config
.
do_verification
)
if
(
(
config
.
do_verification
==
1
)
||
(
config
.
do_verification
==
3
)
)
{
{
std
::
cout
<<
"Compute reference GEMM on CPU... "
;
std
::
cout
<<
"Compute reference GEMM on CPU... "
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
...
...
example/01_gemm/run_gemm_example_v2.inc
View file @
24771ab7
...
@@ -228,7 +228,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -228,7 +228,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
}
bool
pass
=
true
;
bool
pass
=
true
;
if
(
config
.
do_verification
)
if
(
(
config
.
do_verification
==
1
)
||
(
config
.
do_verification
==
3
)
)
{
{
std
::
cout
<<
"Compute reference GEMM on CPU... "
;
std
::
cout
<<
"Compute reference GEMM on CPU... "
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
...
...
library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp
View file @
24771ab7
...
@@ -76,7 +76,7 @@ __global__ void
...
@@ -76,7 +76,7 @@ __global__ void
// apply b_element_op
// apply b_element_op
b_element_op
(
v_b
,
p_b_grid
[
element_idx_b
]);
b_element_op
(
v_b
,
p_b_grid
[
element_idx_b
]);
// multiply and accumulate
// multiply and accumulate
v_acc
+=
static_cas
t
<
AccDataType
>
(
v_a
)
*
static_cas
t
<
AccDataType
>
(
v_b
);
v_acc
+=
type_conver
t
<
AccDataType
>
(
v_a
)
*
type_conver
t
<
AccDataType
>
(
v_b
);
}
}
// apply c_element_op
// apply c_element_op
c_element_op
(
v_c
,
v_acc
);
c_element_op
(
v_c
,
v_acc
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment