Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
99cc8431
Commit
99cc8431
authored
Aug 06, 2024
by
Jing Zhang
Browse files
format
parent
7d69eb3b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
93 additions
and
71 deletions
+93
-71
profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp
.../include/profiler/profile_gemm_multiply_multiply_impl.hpp
+93
-71
No files found.
profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp
View file @
99cc8431
...
@@ -187,8 +187,21 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
...
@@ -187,8 +187,21 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
// profile device GEMM instances
// profile device GEMM instances
for
(
auto
&
op_ptr
:
op_ptrs
)
for
(
auto
&
op_ptr
:
op_ptrs
)
{
{
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
std
::
vector
<
int
>
kbatch_list
=
{
1
,
2
,
4
,
8
,
12
,
16
,
19
,
20
,
32
,
38
};
if
(
KBatch
>
0
)
{
kbatch_list
=
{
KBatch
};
}
for
(
std
::
size_t
i
=
0
;
i
<
kbatch_list
.
size
();
i
++
)
{
auto
kbatch_curr
=
kbatch_list
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
std
::
array
<
const
void
*
,
2
>
{
d0_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
2
>
{
d0_device_buf
.
GetDeviceBuffer
(),
d1_device_buf
.
GetDeviceBuffer
()},
d1_device_buf
.
GetDeviceBuffer
()},
...
@@ -213,7 +226,8 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
...
@@ -213,7 +226,8 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
// re-init C to zero before profiling next kernel
// re-init C to zero before profiling next kernel
c_device_buf
.
SetZero
();
c_device_buf
.
SetZero
();
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
,
0
,
n_warmup
,
n_iter
});
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
,
0
,
n_warmup
,
n_iter
});
if
(
do_verification
)
if
(
do_verification
)
{
{
...
@@ -225,31 +239,37 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
...
@@ -225,31 +239,37 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
{
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a_m_k
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a_m_k
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b: "
,
b_k_n
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b: "
,
b_k_n
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host : "
,
e_m_n_host_result
.
mData
,
","
)
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host : "
,
e_m_n_host_result
.
mData
,
","
)
<<
std
::
endl
;
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
e_m_n_device_result
.
mData
,
","
)
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
e_m_n_device_result
.
mData
,
","
)
<<
std
::
endl
;
<<
std
::
endl
;
}
}
}
}
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
StreamConfig
{
time_kernel
,
nullptr
,
time_kernel
,
0
,
n_warmup
,
n_iter
,
rotating_count
>
1
,
rotating_count
});
0
,
n_warmup
,
n_iter
,
rotating_count
>
1
,
rotating_count
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
EDataType
)
*
M
*
N
;
sizeof
(
EDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
tflops
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
#if defined CK_ENABLE_FP8
#if defined CK_ENABLE_FP8
// set softer tolerances for fp8
// set softer tolerances for fp8
...
@@ -280,7 +300,9 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
...
@@ -280,7 +300,9 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
}
}
else
else
{
{
std
::
cout
<<
op_ptr
->
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
std
::
cout
<<
op_ptr
->
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment