gaoqiong / composable_kernel
Commit f99f614e authored May 25, 2022 by Chao Liu

update profiler

parent b238662a
Showing 1 changed file with 57 additions and 62 deletions.
profiler/include/profile_gemm_gelu_impl.hpp  +57 -62

@@ -75,6 +75,7 @@ int profile_gemm_gelu_impl(int do_verification,
     Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
     Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
     std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
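
Note: f_host_tensor_descriptor is defined earlier in this file and is not part of the hunk above. As a rough standalone sketch (the names, tags, and return type below are illustrative assumptions, not the library's implementation), the row-/column-major stride logic such a helper encodes looks like this:

#include <cstddef>
#include <type_traits>
#include <utility>
#include <vector>

// Illustrative layout tags standing in for tensor_layout::gemm::RowMajor / ColumnMajor.
struct RowMajor {};
struct ColumnMajor {};

// Sketch: lengths {rows, cols} and strides derived from the leading dimension.
// A row-major matrix with leading dimension `stride` has strides {stride, 1};
// a column-major one has strides {1, stride}.
template <typename Layout>
std::pair<std::vector<std::size_t>, std::vector<std::size_t>>
make_host_descriptor(std::size_t rows, std::size_t cols, std::size_t stride, Layout)
{
    if constexpr(std::is_same_v<Layout, RowMajor>)
        return {{rows, cols}, {stride, 1}};
    else
        return {{rows, cols}, {1, stride}};
}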

@@ -101,16 +102,9 @@ int profile_gemm_gelu_impl(int do_verification,
     const auto b_element_op = BElementOp{};
     const auto c_element_op = CElementOp{};

-    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
-    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
-    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
-
-    a_device_buf.ToDevice(a_m_k.mData.data());
-    b_device_buf.ToDevice(b_k_n.mData.data());
-    c_device_buf.ToDevice(c_m_n_device_result.mData.data());
-
     // add device GEMM instances
-    std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmGeluPtr> gemm_ptrs;
+    std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmGeluPtr> device_op_ptrs;

     if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                  is_same_v<CDataType, half_t>)
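
DeviceMem, GetElementSpace, and ToDevice are host-side utilities from the library and are not defined in this diff. As a minimal sketch of what an owning device buffer with host-to-device upload looks like on top of plain HIP calls (this class is an illustration in the spirit of DeviceMem, not the library's implementation):

#include <hip/hip_runtime.h>
#include <cstddef>
#include <stdexcept>

// Owning device allocation with upload and zero-fill, mirroring how the
// profiler uses DeviceMem in this file.
struct SimpleDeviceMem
{
    void* ptr = nullptr;
    std::size_t bytes = 0;

    explicit SimpleDeviceMem(std::size_t size) : bytes(size)
    {
        if(hipMalloc(&ptr, bytes) != hipSuccess)
            throw std::runtime_error("hipMalloc failed");
    }

    void ToDevice(const void* host) { hipMemcpy(ptr, host, bytes, hipMemcpyHostToDevice); }
    void SetZero() { hipMemset(ptr, 0, bytes); }

    ~SimpleDeviceMem() { hipFree(ptr); }
};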

@@ -120,48 +114,66 @@ int profile_gemm_gelu_impl(int do_verification,
                      is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
         {
             ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
+                add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(device_op_ptrs);
         }
         else if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> &&
                           is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
                           is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
         {
             ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
+                add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(device_op_ptrs);
         }
         else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
                           is_same_v<BLayout, tensor_layout::gemm::RowMajor> &&
                           is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
         {
             ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
+                add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(device_op_ptrs);
         }
         else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
                           is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
                           is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
         {
             ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
+                add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(device_op_ptrs);
         }
     }

-    if(gemm_ptrs.size() <= 0)
+    std::cout << "found " << device_op_ptrs.size() << " instances" << std::endl;
+
+    // run reference
+    if(do_verification)
     {
-        throw std::runtime_error("wrong! no device operation instance found");
+        using ReferenceOpInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                              BDataType,
+                                                                              CDataType,
+                                                                              AElementOp,
+                                                                              BElementOp,
+                                                                              CElementOp>;
+
+        auto ref_op       = ReferenceOpInstance{};
+        auto ref_invoker  = ref_op.MakeInvoker();
+        auto ref_argument = ref_op.MakeArgument(
+            a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
+
+        ref_invoker.Run(ref_argument);
     }

-    std::string best_gemm_name;
+    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
+
+    a_device_buf.ToDevice(a_m_k.mData.data());
+    b_device_buf.ToDevice(b_k_n.mData.data());
+
+    std::string best_device_op_name;
     float best_ave_time   = 0;
     float best_tflops     = 0;
     float best_gb_per_sec = 0;

     bool pass = true;

-    // profile device GEMM instances
-    for(auto& gemm_ptr : gemm_ptrs)
+    // profile device operation instances
+    for(auto& device_op_ptr : device_op_ptrs)
     {
         auto argument_ptr =
-            gemm_ptr->MakeArgumentPointer(
+            device_op_ptr->MakeArgumentPointer(
                 static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                 static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                 static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                 M,

@@ -174,15 +186,15 @@ int profile_gemm_gelu_impl(int do_verification,
                 b_element_op,
                 c_element_op);

-        auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
+        auto invoker_ptr = device_op_ptr->MakeInvokerPointer();
+
+        std::string device_op_name = device_op_ptr->GetTypeString();

-        if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
+        if(device_op_ptr->IsSupportedArgument(argument_ptr.get()))
         {
-            // re-init C to zero before profiling next kernel
+            // re-init C to zero before profiling a kernel
             c_device_buf.SetZero();

-            std::string gemm_name = gemm_ptr->GetTypeString();
-
             float ave_time =
                 invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

@@ -196,11 +208,11 @@ int profile_gemm_gelu_impl(int do_verification,
             float gb_per_sec = num_btype / 1.E6 / ave_time;

             std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
-                      << gb_per_sec << " GB/s, " << gemm_name << std::endl;
+                      << gb_per_sec << " GB/s, " << device_op_name << std::endl;

             if(tflops > best_tflops)
             {
-                best_gemm_name  = gemm_name;
+                best_device_op_name = device_op_name;
                 best_tflops     = tflops;
                 best_ave_time   = ave_time;
                 best_gb_per_sec = gb_per_sec;
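
The tflops and num_btype values printed above are computed a few lines earlier in the file, outside this hunk. The usual accounting for a GEMM of size M x N x K, written as a self-contained sketch (the helper name and signature are assumptions, but the gb_per_sec formula matches the line shown above):

#include <cstddef>

// 2*M*N*K floating-point operations; bytes moved cover one read of A and B
// plus one write of C. ave_time_ms is in milliseconds, matching the "ms"
// output above, so dividing by 1e9 yields TFLOP/s and by 1e6 yields GB/s.
inline void gemm_perf_metrics(std::size_t M, std::size_t N, std::size_t K,
                              std::size_t a_elem_bytes, std::size_t b_elem_bytes,
                              std::size_t c_elem_bytes, float ave_time_ms,
                              float& tflops, float& gb_per_sec)
{
    const std::size_t flop      = std::size_t(2) * M * N * K;
    const std::size_t num_btype = a_elem_bytes * M * K + b_elem_bytes * K * N + c_elem_bytes * M * N;

    tflops     = static_cast<float>(flop) / 1.0e9f / ave_time_ms;
    gb_per_sec = static_cast<float>(num_btype) / 1.0e6f / ave_time_ms;
}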

@@ -210,23 +222,6 @@ int profile_gemm_gelu_impl(int do_verification,
             {
                 c_device_buf.FromDevice(c_m_n_device_result.mData.data());

-                Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-
-                using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
-                                                                                        BDataType,
-                                                                                        CDataType,
-                                                                                        AElementOp,
-                                                                                        BElementOp,
-                                                                                        CElementOp>;
-
-                auto ref_gemm     = ReferenceGemmInstance{};
-                auto ref_invoker  = ref_gemm.MakeInvoker();
-
-                auto ref_argument = ref_gemm.MakeArgument(
-                    a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
-
-                ref_invoker.Run(ref_argument);
-
                 pass = pass &&
                        ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
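
ck::utils::check_err is a library helper whose definition is not part of this change; conceptually it is a tolerance-based elementwise comparison of the device output against the host reference. A minimal sketch of that kind of check (the name and tolerances below are illustrative assumptions, not the library's implementation):

#include <cmath>
#include <cstddef>
#include <vector>

// Elementwise comparison with a combined absolute/relative tolerance;
// returns false on the first mismatch.
template <typename T>
bool check_close(const std::vector<T>& out, const std::vector<T>& ref,
                 double atol = 1e-3, double rtol = 1e-3)
{
    if(out.size() != ref.size())
        return false;

    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const double diff = std::abs(static_cast<double>(out[i]) - static_cast<double>(ref[i]));
        if(diff > atol + rtol * std::abs(static_cast<double>(ref[i])))
            return false;
    }
    return true;
}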

@@ -243,12 +238,12 @@ int profile_gemm_gelu_impl(int do_verification,
         }
         else
         {
-            std::cout << "does not support this problem" << std::endl;
+            std::cout << device_op_name << " does not support this problem" << std::endl;
         }
     }

     std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
-              << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
+              << best_gb_per_sec << " GB/s, " << best_device_op_name << std::endl;

     return pass ? 0 : 1;
 }
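
For reference, the host computation that ReferenceGemm combined with a GELU output element-op amounts to can be written as a short standalone sketch. The tanh-based GELU approximation below is an assumption about what CElementOp applies, not something stated in this commit; all matrices are row-major floats for simplicity:

#include <cmath>
#include <cstddef>
#include <vector>

// Tanh approximation of GELU: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))).
inline float gelu(float x)
{
    return 0.5f * x * (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
}

// Naive host GEMM followed by GELU on each output element:
// C[m][n] = gelu(sum_k A[m][k] * B[k][n]).
void reference_gemm_gelu(const std::vector<float>& a, const std::vector<float>& b,
                         std::vector<float>& c, std::size_t M, std::size_t N, std::size_t K)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(std::size_t k = 0; k < K; ++k)
                acc += a[m * K + k] * b[k * N + n];
            c[m * N + n] = gelu(acc);
        }
}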