Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
894f8bf5
"...resnet50_tensorflow.git" did not exist on "c5ad244e8e1ac263e8dd33533e9c2c0b53763d46"
Commit
894f8bf5
authored
Oct 11, 2023
by
Rostyslav Geyyer
Browse files
Update profiler api
parent
b73919f4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
22 deletions
+47
-22
profiler/include/profiler/profile_gemm_splitk_impl.hpp
profiler/include/profiler/profile_gemm_splitk_impl.hpp
+6
-3
profiler/src/profile_gemm_splitk.cpp
profiler/src/profile_gemm_splitk.cpp
+41
-19
No files found.
profiler/include/profiler/profile_gemm_splitk_impl.hpp
View file @
894f8bf5
...
...
@@ -30,7 +30,8 @@ template <typename ADataType,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
typename
CLayout
,
typename
ComputeType
=
CDataType
>
bool
profile_gemm_splitk_impl
(
int
do_verification
,
int
init_method
,
bool
do_log
,
...
...
@@ -103,7 +104,8 @@ bool profile_gemm_splitk_impl(int do_verification,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
CElementOp
,
ComputeType
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
...
...
@@ -120,7 +122,8 @@ bool profile_gemm_splitk_impl(int do_verification,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
CElementOp
,
ComputeType
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
...
profiler/src/profile_gemm_splitk.cpp
View file @
894f8bf5
...
...
@@ -25,6 +25,7 @@ enum struct GemmDataType
INT8_INT8_INT8
,
// 3
F8_F16_F16
,
// 4
F16_F8_F16
,
// 5
F16_F16_F16_F8
,
// 6
};
#define OP_NAME "gemm_splitk"
...
...
@@ -35,7 +36,8 @@ int profile_gemm_splitk(int argc, char* argv[])
if
(
argc
!=
15
)
{
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, "
"comp f8)
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
...
...
@@ -80,7 +82,8 @@ int profile_gemm_splitk(int argc, char* argv[])
auto
c_type
,
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
auto
c_layout
,
auto
compute_type
)
{
using
ADataType
=
decltype
(
a_type
);
using
BDataType
=
decltype
(
b_type
);
using
AccDataType
=
decltype
(
acc_type
);
...
...
@@ -90,6 +93,8 @@ int profile_gemm_splitk(int argc, char* argv[])
using
BLayout
=
decltype
(
b_layout
);
using
CLayout
=
decltype
(
c_layout
);
using
ComputeType
=
decltype
(
compute_type
);
const
int
DefaultStrideA
=
ck
::
is_same_v
<
ALayout
,
Row
>
?
K
:
M
;
const
int
DefaultStrideB
=
ck
::
is_same_v
<
BLayout
,
Row
>
?
N
:
K
;
const
int
DefaultStrideC
=
ck
::
is_same_v
<
CLayout
,
Row
>
?
N
:
M
;
...
...
@@ -100,7 +105,8 @@ int profile_gemm_splitk(int argc, char* argv[])
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
CLayout
,
ComputeType
>
(
do_verification
,
init_method
,
do_log
,
...
...
@@ -118,68 +124,84 @@ int profile_gemm_splitk(int argc, char* argv[])
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Row
{},
Row
{},
Row
{}
,
F32
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Row
{},
Col
{},
Row
{});
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Row
{},
Col
{},
Row
{}
,
F32
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Col
{},
Row
{},
Row
{});
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Col
{},
Row
{},
Row
{}
,
F32
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Col
{},
Col
{},
Row
{});
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Col
{},
Col
{},
Row
{}
,
F32
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{}
,
F16
{}
);
}
#if defined CK_ENABLE_FP8
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{});
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{});
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{}
,
F16
{}
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F16
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F8
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{},
F8
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F8
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{},
F8
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F8
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{},
F8
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F8
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F8
{});
}
#endif
else
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment