Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d862fdf0
You need to sign in or sign up before continuing.
Commit
d862fdf0
authored
Dec 31, 2021
by
ltqin
Browse files
add desiredgridsize parameter to ckProfiler
parent
adc79bdd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
61 deletions
+33
-61
device_operation/include/device_gemm_xdl_instance.hpp
device_operation/include/device_gemm_xdl_instance.hpp
+0
-48
profiler/include/profile_gemm_impl.hpp
profiler/include/profile_gemm_impl.hpp
+17
-5
profiler/profile_gemm.cpp
profiler/profile_gemm.cpp
+16
-8
No files found.
device_operation/include/device_gemm_xdl_instance.hpp
deleted
100644 → 0
View file @
adc79bdd
#ifndef DEVICE_GEMM_XDL_INSTANCE
#define DEVICE_GEMM_XDL_INSTANCE
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

// Explicit-specialization declarations of add_device_splitk_gemm_instance for
// float/float/float with four combinations of tensor layouts.
//
// Only the specializations are declared here: the primary template and the
// DeviceGemmNoOpPtr type are not declared in this header and must already be
// visible wherever it is included.
//
// NOTE(review): going by the profiler call site, the template arguments appear
// to be <A type, B type, C type, A layout, B layout, C layout>, and each
// function presumably appends its GEMM instances to the vector it is given --
// confirm against the definitions.

// Layouts: RowMajor, RowMajor, RowMajor.
template <>
void add_device_splitk_gemm_instance<float,
                                     float,
                                     float,
                                     ck::tensor_layout::gemm::RowMajor,
                                     ck::tensor_layout::gemm::RowMajor,
                                     ck::tensor_layout::gemm::RowMajor>(
    std::vector<DeviceGemmNoOpPtr>& instances);

// Layouts: RowMajor, ColumnMajor, RowMajor.
template <>
void add_device_splitk_gemm_instance<float,
                                     float,
                                     float,
                                     ck::tensor_layout::gemm::RowMajor,
                                     ck::tensor_layout::gemm::ColumnMajor,
                                     ck::tensor_layout::gemm::RowMajor>(
    std::vector<DeviceGemmNoOpPtr>& instances);

// Layouts: ColumnMajor, RowMajor, RowMajor.
template <>
void add_device_splitk_gemm_instance<float,
                                     float,
                                     float,
                                     ck::tensor_layout::gemm::ColumnMajor,
                                     ck::tensor_layout::gemm::RowMajor,
                                     ck::tensor_layout::gemm::RowMajor>(
    std::vector<DeviceGemmNoOpPtr>& instances);

// Layouts: ColumnMajor, ColumnMajor, RowMajor.
template <>
void add_device_splitk_gemm_instance<float,
                                     float,
                                     float,
                                     ck::tensor_layout::gemm::ColumnMajor,
                                     ck::tensor_layout::gemm::ColumnMajor,
                                     ck::tensor_layout::gemm::RowMajor>(
    std::vector<DeviceGemmNoOpPtr>& instances);

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
profiler/include/profile_gemm_impl.hpp
View file @
d862fdf0
#pragma once
#pragma once
#include "device_gemm_instance.hpp"
#include "device_gemm_instance.hpp"
#include "device_gemm_xdl_splitk_instance.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -93,7 +94,8 @@ void profile_gemm_impl(int do_verification,
...
@@ -93,7 +94,8 @@ void profile_gemm_impl(int do_verification,
int
K
,
int
K
,
int
StrideA
,
int
StrideA
,
int
StrideB
,
int
StrideB
,
int
StrideC
)
int
StrideC
,
int
DesiredGridSize
=
1
)
{
{
auto
f_host_tensor_descriptor
=
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
...
@@ -154,9 +156,18 @@ void profile_gemm_impl(int do_verification,
...
@@ -154,9 +156,18 @@ void profile_gemm_impl(int do_verification,
// add device GEMM instances
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
DeviceGemmNoOpPtr
>
gemm_ptrs
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
DeviceGemmNoOpPtr
>
gemm_ptrs
;
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
if
(
DesiredGridSize
>
1
&&
is_same
<
ADataType
,
float
>::
value
)
add_device_gemm_instance
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
{
gemm_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_splitk_gemm_instance
<
float
,
float
,
float
,
ALayout
,
BLayout
,
CLayout
>
(
gemm_ptrs
);
}
else
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_instance
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
gemm_ptrs
);
}
if
(
gemm_ptrs
.
size
()
<=
0
)
if
(
gemm_ptrs
.
size
()
<=
0
)
{
{
...
@@ -183,7 +194,8 @@ void profile_gemm_impl(int do_verification,
...
@@ -183,7 +194,8 @@ void profile_gemm_impl(int do_verification,
StrideC
,
StrideC
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
DesiredGridSize
);
auto
invoker_ptr
=
gemm_ptr
->
MakeInvokerPointer
();
auto
invoker_ptr
=
gemm_ptr
->
MakeInvokerPointer
();
...
...
profiler/profile_gemm.cpp
View file @
d862fdf0
...
@@ -35,7 +35,7 @@ enum GemmDataType
...
@@ -35,7 +35,7 @@ enum GemmDataType
int
profile_gemm
(
int
argc
,
char
*
argv
[])
int
profile_gemm
(
int
argc
,
char
*
argv
[])
{
{
if
(
argc
!
=
14
)
if
(
!
(
argc
=
=
14
||
argc
==
15
)
)
{
{
printf
(
"arg1: tensor operation (gemm: GEMM)
\n
"
);
printf
(
"arg1: tensor operation (gemm: GEMM)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16)
\n
"
);
...
@@ -48,6 +48,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -48,6 +48,7 @@ int profile_gemm(int argc, char* argv[])
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7: run kernel # of times (>1)
\n
"
);
printf
(
"arg7: run kernel # of times (>1)
\n
"
);
printf
(
"arg8 to 13: M, N, K, StrideA, StrideB, StrideC
\n
"
);
printf
(
"arg8 to 13: M, N, K, StrideA, StrideB, StrideC
\n
"
);
printf
(
"arg14: desired grid size
\n
"
);
exit
(
1
);
exit
(
1
);
}
}
...
@@ -62,9 +63,12 @@ int profile_gemm(int argc, char* argv[])
...
@@ -62,9 +63,12 @@ int profile_gemm(int argc, char* argv[])
const
int
N
=
std
::
stoi
(
argv
[
9
]);
const
int
N
=
std
::
stoi
(
argv
[
9
]);
const
int
K
=
std
::
stoi
(
argv
[
10
]);
const
int
K
=
std
::
stoi
(
argv
[
10
]);
const
int
StrideA
=
std
::
stoi
(
argv
[
11
]);
const
int
StrideA
=
std
::
stoi
(
argv
[
11
]);
const
int
StrideB
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideB
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
int
DesiredGridSize
=
1
;
if
(
argc
==
15
)
DesiredGridSize
=
std
::
stoi
(
argv
[
14
]);
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
...
@@ -159,7 +163,8 @@ int profile_gemm(int argc, char* argv[])
...
@@ -159,7 +163,8 @@ int profile_gemm(int argc, char* argv[])
K
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
(
StrideC
<
0
)
?
N
:
StrideC
,
DesiredGridSize
);
}
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
...
@@ -178,7 +183,8 @@ int profile_gemm(int argc, char* argv[])
...
@@ -178,7 +183,8 @@ int profile_gemm(int argc, char* argv[])
K
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
(
StrideC
<
0
)
?
N
:
StrideC
,
DesiredGridSize
);
}
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
{
...
@@ -197,7 +203,8 @@ int profile_gemm(int argc, char* argv[])
...
@@ -197,7 +203,8 @@ int profile_gemm(int argc, char* argv[])
K
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
(
StrideC
<
0
)
?
N
:
StrideC
,
DesiredGridSize
);
}
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
{
...
@@ -216,7 +223,8 @@ int profile_gemm(int argc, char* argv[])
...
@@ -216,7 +223,8 @@ int profile_gemm(int argc, char* argv[])
K
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
(
StrideC
<
0
)
?
N
:
StrideC
,
DesiredGridSize
);
}
}
else
else
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment