Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d862fdf0
Commit
d862fdf0
authored
Dec 31, 2021
by
ltqin
Browse files
add desiredgridsize parameter to ckProfiler
parent
adc79bdd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
61 deletions
+33
-61
device_operation/include/device_gemm_xdl_instance.hpp
device_operation/include/device_gemm_xdl_instance.hpp
+0
-48
profiler/include/profile_gemm_impl.hpp
profiler/include/profile_gemm_impl.hpp
+17
-5
profiler/profile_gemm.cpp
profiler/profile_gemm.cpp
+16
-8
No files found.
device_operation/include/device_gemm_xdl_instance.hpp
deleted
100644 → 0
View file @
adc79bdd
#ifndef DEVICE_GEMM_XDL_INSTANCE
#define DEVICE_GEMM_XDL_INSTANCE
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
template
<
>
void
add_device_splitk_gemm_instance
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_splitk_gemm_instance
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_splitk_gemm_instance
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_splitk_gemm_instance
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
#endif
profiler/include/profile_gemm_impl.hpp
View file @
d862fdf0
#pragma once
#pragma once
#include "device_gemm_instance.hpp"
#include "device_gemm_instance.hpp"
#include "device_gemm_xdl_splitk_instance.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -93,7 +94,8 @@ void profile_gemm_impl(int do_verification,
...
@@ -93,7 +94,8 @@ void profile_gemm_impl(int do_verification,
int
K
,
int
K
,
int
StrideA
,
int
StrideA
,
int
StrideB
,
int
StrideB
,
int
StrideC
)
int
StrideC
,
int
DesiredGridSize
=
1
)
{
{
auto
f_host_tensor_descriptor
=
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
...
@@ -154,9 +156,18 @@ void profile_gemm_impl(int do_verification,
...
@@ -154,9 +156,18 @@ void profile_gemm_impl(int do_verification,
// add device GEMM instances
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
DeviceGemmNoOpPtr
>
gemm_ptrs
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
DeviceGemmNoOpPtr
>
gemm_ptrs
;
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
if
(
DesiredGridSize
>
1
&&
is_same
<
ADataType
,
float
>::
value
)
add_device_gemm_instance
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
{
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_splitk_gemm_instance
<
float
,
float
,
float
,
ALayout
,
BLayout
,
CLayout
>
(
gemm_ptrs
);
}
else
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_instance
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
gemm_ptrs
);
}
if
(
gemm_ptrs
.
size
()
<=
0
)
if
(
gemm_ptrs
.
size
()
<=
0
)
{
{
...
@@ -183,7 +194,8 @@ void profile_gemm_impl(int do_verification,
...
@@ -183,7 +194,8 @@ void profile_gemm_impl(int do_verification,
StrideC
,
StrideC
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
DesiredGridSize
);
auto
invoker_ptr
=
gemm_ptr
->
MakeInvokerPointer
();
auto
invoker_ptr
=
gemm_ptr
->
MakeInvokerPointer
();
...
...
profiler/profile_gemm.cpp
View file @
d862fdf0
...
@@ -35,7 +35,7 @@ enum GemmDataType
...
@@ -35,7 +35,7 @@ enum GemmDataType
int
profile_gemm
(
int
argc
,
char
*
argv
[])
int
profile_gemm
(
int
argc
,
char
*
argv
[])
{
{
if
(
argc
!
=
14
)
if
(
!
(
argc
=
=
14
||
argc
==
15
)
)
{
{
printf
(
"arg1: tensor operation (gemm: GEMM)
\n
"
);
printf
(
"arg1: tensor operation (gemm: GEMM)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16)
\n
"
);
...
@@ -48,6 +48,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -48,6 +48,7 @@ int profile_gemm(int argc, char* argv[])
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7: run kernel # of times (>1)
\n
"
);
printf
(
"arg7: run kernel # of times (>1)
\n
"
);
printf
(
"arg8 to 13: M, N, K, StrideA, StrideB, StrideC
\n
"
);
printf
(
"arg8 to 13: M, N, K, StrideA, StrideB, StrideC
\n
"
);
printf
(
"arg14: desired grid size
\n
"
);
exit
(
1
);
exit
(
1
);
}
}
...
@@ -62,9 +63,12 @@ int profile_gemm(int argc, char* argv[])
...
@@ -62,9 +63,12 @@ int profile_gemm(int argc, char* argv[])
const
int
N
=
std
::
stoi
(
argv
[
9
]);
const
int
N
=
std
::
stoi
(
argv
[
9
]);
const
int
K
=
std
::
stoi
(
argv
[
10
]);
const
int
K
=
std
::
stoi
(
argv
[
10
]);
const
int
StrideA
=
std
::
stoi
(
argv
[
11
]);
const
int
StrideA
=
std
::
stoi
(
argv
[
11
]);
const
int
StrideB
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideB
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
int
DesiredGridSize
=
1
;
if
(
argc
==
15
)
DesiredGridSize
=
std
::
stoi
(
argv
[
14
]);
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
...
@@ -159,7 +163,8 @@ int profile_gemm(int argc, char* argv[])
...
@@ -159,7 +163,8 @@ int profile_gemm(int argc, char* argv[])
K
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
(
StrideC
<
0
)
?
N
:
StrideC
,
DesiredGridSize
);
}
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
...
@@ -178,7 +183,8 @@ int profile_gemm(int argc, char* argv[])
...
@@ -178,7 +183,8 @@ int profile_gemm(int argc, char* argv[])
K
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
(
StrideC
<
0
)
?
N
:
StrideC
,
DesiredGridSize
);
}
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
{
...
@@ -197,7 +203,8 @@ int profile_gemm(int argc, char* argv[])
...
@@ -197,7 +203,8 @@ int profile_gemm(int argc, char* argv[])
K
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
(
StrideC
<
0
)
?
N
:
StrideC
,
DesiredGridSize
);
}
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
{
...
@@ -216,7 +223,8 @@ int profile_gemm(int argc, char* argv[])
...
@@ -216,7 +223,8 @@ int profile_gemm(int argc, char* argv[])
K
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
(
StrideC
<
0
)
?
N
:
StrideC
,
DesiredGridSize
);
}
}
else
else
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment