Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
857010cc
"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "0f1fca4dd68412b3117fe8bf120eae6c4a590dda"
Commit
857010cc
authored
Mar 14, 2022
by
Jing Zhang
Browse files
clean
parent
6af8e20b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
85 additions
and
85 deletions
+85
-85
example/14_grouped_gemm/grouped_gemm_xdl_fp16.cpp
example/14_grouped_gemm/grouped_gemm_xdl_fp16.cpp
+4
-8
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+53
-47
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
...device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
+17
-18
profiler/include/profile_grouped_gemm_impl.hpp
profiler/include/profile_grouped_gemm_impl.hpp
+11
-7
profiler/src/profile_grouped_gemm.cpp
profiler/src/profile_grouped_gemm.cpp
+0
-5
No files found.
example/14_grouped_gemm/grouped_gemm_xdl_fp16.cpp
View file @
857010cc
...
@@ -83,15 +83,11 @@ int main(int argc, char* argv[])
...
@@ -83,15 +83,11 @@ int main(int argc, char* argv[])
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
{
//
int M = 256 + 256 * i;
int
M
=
256
+
256
*
i
;
//
int N = 128 + 128 * i;
int
N
=
128
+
128
*
i
;
//
int K = 64 + 64 * i;
int
K
=
64
+
64
*
i
;
int
M
=
3840
;
gemm_shapes
.
push_back
({
M
,
N
,
K
,
K
,
K
,
N
,
nullptr
,
nullptr
,
nullptr
});
int
N
=
1024
;
int
K
=
4096
;
gemm_shapes
.
push_back
({
M
,
N
,
K
,
K
,
N
,
N
,
nullptr
,
nullptr
,
nullptr
});
}
}
auto
f_host_tensor_descriptor
=
auto
f_host_tensor_descriptor
=
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
857010cc
...
@@ -223,7 +223,7 @@ struct DeviceGroupedGemmXdl
...
@@ -223,7 +223,7 @@ struct DeviceGroupedGemmXdl
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
NumPrefetch
>
;
NumPrefetch
>
;
struct
GemmDesc
struct
GemmDesc
KernelArg
{
{
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
...
@@ -231,11 +231,13 @@ struct DeviceGroupedGemmXdl
...
@@ -231,11 +231,13 @@ struct DeviceGroupedGemmXdl
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
const
ADataType
*
a_ptr
;
const
ADataType
*
a_ptr
;
const
BDataType
*
b_ptr
;
const
BDataType
*
b_ptr
;
CDataType
*
c_ptr
;
CDataType
*
c_ptr
;
ck
::
index_t
BlockStart
,
BlockEnd
;
ck
::
index_t
BlockStart
,
BlockEnd
;
};
};
...
@@ -254,10 +256,12 @@ struct DeviceGroupedGemmXdl
...
@@ -254,10 +256,12 @@ struct DeviceGroupedGemmXdl
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
c_element_op_
{
c_element_op
}
{
{
grid_size
=
0
;
grid_size
=
0
;
group_count_
=
0
;
for
(
index_t
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
for
(
index_t
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
{
{
group_count_
++
;
const
index_t
M
=
gemm_shapes
[
i
].
M
;
const
index_t
M
=
gemm_shapes
[
i
].
M
;
const
index_t
N
=
gemm_shapes
[
i
].
N
;
const
index_t
N
=
gemm_shapes
[
i
].
N
;
const
index_t
K
=
gemm_shapes
[
i
].
K
;
const
index_t
K
=
gemm_shapes
[
i
].
K
;
...
@@ -289,16 +293,17 @@ struct DeviceGroupedGemmXdl
...
@@ -289,16 +293,17 @@ struct DeviceGroupedGemmXdl
const
auto
block_2_ctile_map_
=
const
auto
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
GemmShape_
.
push_back
(
GemmDesc
{
a_grid_desc_k0_m_k1_
,
gemm_desc_kernel_arg_
.
push_back
(
b_grid_desc_k0_n_k1_
,
GemmDescKernelArg
{
a_grid_desc_k0_m_k1_
,
c_grid_desc_m_n_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
c_grid_desc_m_n_
,
block_2_ctile_map_
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
static_cast
<
const
ADataType
*>
(
gemm_shapes
[
i
].
p_a
),
block_2_ctile_map_
,
static_cast
<
const
BDataType
*>
(
gemm_shapes
[
i
].
p_b
),
static_cast
<
const
ADataType
*>
(
gemm_shapes
[
i
].
p_a
),
static_cast
<
CDataType
*>
(
gemm_shapes
[
i
].
p_c
),
static_cast
<
const
BDataType
*>
(
gemm_shapes
[
i
].
p_b
),
BlockStart
,
static_cast
<
CDataType
*>
(
gemm_shapes
[
i
].
p_c
),
BlockEnd
});
BlockStart
,
BlockEnd
});
}
}
}
}
}
}
...
@@ -306,11 +311,12 @@ struct DeviceGroupedGemmXdl
...
@@ -306,11 +311,12 @@ struct DeviceGroupedGemmXdl
// private:
// private:
index_t
M01_
;
index_t
M01_
;
index_t
N01_
;
index_t
N01_
;
index_t
group_count_
;
AElementwiseOperation
a_element_op_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
std
::
vector
<
GemmDesc
>
G
emm
Shape
_
;
std
::
vector
<
GemmDesc
KernelArg
>
g
emm
_desc_kernel_arg
_
;
index_t
grid_size
;
index_t
grid_size
;
};
};
...
@@ -322,41 +328,48 @@ struct DeviceGroupedGemmXdl
...
@@ -322,41 +328,48 @@ struct DeviceGroupedGemmXdl
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
{
StaticallyIndexedArray
<
GemmDesc
,
MaxGroupCount
>
G
emm
Shape
_arg
;
StaticallyIndexedArray
<
GemmDesc
KernelArg
,
MaxGroupCount
>
g
emm
_desc_kernel_arg
_arg
;
bool
has_main_k0_block_loop
=
true
;
bool
has_main_k0_block_loop
=
true
;
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
if
(
i
<
arg
.
G
emm
Shape
_
.
size
())
if
(
i
<
arg
.
g
emm
_desc_kernel_arg
_
.
size
())
{
{
G
emm
Shape
_arg
(
i
)
=
arg
.
G
emm
Shape
_
[
i
];
g
emm
_desc_kernel_arg
_arg
(
i
)
=
arg
.
g
emm
_desc_kernel_arg
_
[
i
];
std
::
cout
<<
"group: "
<<
i
<<
" arg.a_grid_desc_k0_m_k1_{"
std
::
cout
<<
"group: "
<<
i
<<
" arg.a_grid_desc_k0_m_k1_{"
<<
GemmShape_arg
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
gemm_desc_kernel_arg_arg
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
GemmShape_arg
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
", "
<<
GemmShape_arg
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
;
<<
gemm_desc_kernel_arg_arg
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
gemm_desc_kernel_arg_arg
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
;
std
::
cout
<<
", arg.b_grid_desc_k0_n_k1_{"
std
::
cout
<<
", arg.b_grid_desc_k0_n_k1_{"
<<
GemmShape_arg
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
gemm_desc_kernel_arg_arg
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
GemmShape_arg
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
", "
<<
GemmShape_arg
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
;
<<
gemm_desc_kernel_arg_arg
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
gemm_desc_kernel_arg_arg
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
;
std
::
cout
<<
", arg.c_grid_desc_m_n_{ "
std
::
cout
<<
", arg.c_grid_desc_m_n_{ "
<<
G
emm
Shape
_arg
[
i
].
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
g
emm
_desc_kernel_arg
_arg
[
i
].
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
G
emm
Shape
_arg
[
i
].
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
g
emm
_desc_kernel_arg
_arg
[
i
].
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
std
::
endl
;
if
(
!
GridwiseGemm
::
CheckValidity
(
GemmShape_arg
[
i
].
a_grid_desc_k0_m_k1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
GemmShape_arg
[
i
].
b_grid_desc_k0_n_k1_
,
gemm_desc_kernel_arg_arg
[
i
].
a_grid_desc_k0_m_k1_
,
GemmShape_arg
[
i
].
c_grid_desc_m_n_
,
gemm_desc_kernel_arg_arg
[
i
].
b_grid_desc_k0_n_k1_
,
arg
.
M01_
,
gemm_desc_kernel_arg_arg
[
i
].
c_grid_desc_m_n_
,
arg
.
N01_
))
arg
.
M01_
,
arg
.
N01_
))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
}
}
const
auto
K0
=
G
emm
Shape
_arg
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
auto
K0
=
g
emm
_desc_kernel_arg
_arg
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
if
(
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
)
!=
has_main_k0_block_loop
)
if
(
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
)
!=
has_main_k0_block_loop
)
{
{
...
@@ -373,7 +386,7 @@ struct DeviceGroupedGemmXdl
...
@@ -373,7 +386,7 @@ struct DeviceGroupedGemmXdl
kernel_grouped_gemm_xdlops_v2r3
<
GridwiseGemm
,
kernel_grouped_gemm_xdlops_v2r3
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
CDataType
,
remove_reference_t
<
GemmDesc
>
,
remove_reference_t
<
GemmDesc
KernelArg
>
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
...
@@ -385,8 +398,8 @@ struct DeviceGroupedGemmXdl
...
@@ -385,8 +398,8 @@ struct DeviceGroupedGemmXdl
dim3
(
arg
.
grid_size
),
dim3
(
arg
.
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
G
emm
Shape
_arg
,
g
emm
_desc_kernel_arg
_arg
,
arg
.
G
emm
Shape
_
.
size
(),
arg
.
g
emm
_desc_kernel_arg
_
.
size
(),
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
arg
.
c_element_op_
);
...
@@ -397,7 +410,7 @@ struct DeviceGroupedGemmXdl
...
@@ -397,7 +410,7 @@ struct DeviceGroupedGemmXdl
kernel_grouped_gemm_xdlops_v2r3
<
GridwiseGemm
,
kernel_grouped_gemm_xdlops_v2r3
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
CDataType
,
remove_reference_t
<
GemmDesc
>
,
remove_reference_t
<
GemmDesc
KernelArg
>
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
...
@@ -409,8 +422,8 @@ struct DeviceGroupedGemmXdl
...
@@ -409,8 +422,8 @@ struct DeviceGroupedGemmXdl
dim3
(
arg
.
grid_size
),
dim3
(
arg
.
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
G
emm
Shape
_arg
,
g
emm
_desc_kernel_arg
_arg
,
arg
.
G
emm
Shape
_
.
size
(),
arg
.
g
emm
_desc_kernel_arg
_
.
size
(),
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
arg
.
c_element_op_
);
...
@@ -434,17 +447,10 @@ struct DeviceGroupedGemmXdl
...
@@ -434,17 +447,10 @@ struct DeviceGroupedGemmXdl
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
bool
isValid
=
true
;
if
(
arg
.
gemm_desc_kernel_arg_
.
size
()
!=
arg
.
group_count_
)
for
(
int
i
=
0
;
i
<
arg
.
GemmShape_
.
size
();
i
++
)
return
false
;
{
else
return
true
;
isValid
&=
GridwiseGemm
::
CheckValidity
(
arg
.
GemmShape_
[
i
].
a_grid_desc_k0_m_k1_
,
arg
.
GemmShape_
[
i
].
b_grid_desc_k0_n_k1_
,
arg
.
GemmShape_
[
i
].
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
);
}
return
isValid
;
}
}
// polymorphic
// polymorphic
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
View file @
857010cc
...
@@ -29,24 +29,23 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple<
...
@@ -29,24 +29,23 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple<
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
>
,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
32
,
256
,
4
,
8
,
32
,
32
,
1
,
4
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
true
,
7
,
1
>
,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
32
,
32
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
16
,
256
,
4
,
8
,
16
,
16
,
1
,
8
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
true
,
7
,
1
>
,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
16
,
128
,
4
,
8
,
16
,
16
,
1
,
4
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
16
,
64
,
4
,
8
,
16
,
16
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
128
,
16
,
32
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
>
,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
16
,
16
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
>
DeviceGroupedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
// clang-format on
// clang-format on
>
;
>
;
...
...
profiler/include/profile_grouped_gemm_impl.hpp
View file @
857010cc
...
@@ -69,6 +69,12 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -69,6 +69,12 @@ void profile_grouped_gemm_impl(int do_verification,
}
}
};
};
if
(
!
(
Ms
.
size
()
==
Ns
.
size
()
&&
Ns
.
size
()
==
Ks
.
size
()
&&
Ks
.
size
()
==
StrideAs
.
size
()
&&
StrideAs
.
size
()
==
StrideBs
.
size
()
&&
StrideBs
.
size
()
==
StrideCs
.
size
()))
{
throw
std
::
runtime_error
(
"wrong! inconsistent Ms, Ns, Ks, StrideA/B/Cs size
\n
"
);
}
std
::
vector
<
Tensor
<
ADataType
>>
a_m_k
;
std
::
vector
<
Tensor
<
ADataType
>>
a_m_k
;
std
::
vector
<
Tensor
<
BDataType
>>
b_k_n
;
std
::
vector
<
Tensor
<
BDataType
>>
b_k_n
;
std
::
vector
<
Tensor
<
CDataType
>>
c_m_n_device_results
;
std
::
vector
<
Tensor
<
CDataType
>>
c_m_n_device_results
;
...
@@ -83,11 +89,9 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -83,11 +89,9 @@ void profile_grouped_gemm_impl(int do_verification,
c_m_n_device_results
.
push_back
(
c_m_n_device_results
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
std
::
cout
<<
"a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
std
::
endl
;
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
<<
i
std
::
cout
<<
"b_k_n["
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
std
::
endl
;
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
<<
"]:"
<<
c_m_n_device_results
[
i
].
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n_device_results["
<<
i
<<
"]:"
<<
c_m_n_device_results
[
i
].
mDesc
<<
std
::
endl
;
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
switch
(
init_method
)
switch
(
init_method
)
...
@@ -129,6 +133,7 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -129,6 +133,7 @@ void profile_grouped_gemm_impl(int do_verification,
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_m_k
[
i
].
mDesc
.
GetElementSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_m_k
[
i
].
mDesc
.
GetElementSize
()));
b_device_buf
.
push_back
(
b_device_buf
.
push_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
b_k_n
[
i
].
mDesc
.
GetElementSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
b_k_n
[
i
].
mDesc
.
GetElementSize
()));
c_device_buf
.
push_back
(
std
::
make_unique
<
DeviceMem
>
(
c_device_buf
.
push_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
CDataType
)
*
c_m_n_device_results
[
i
].
mDesc
.
GetElementSize
()));
sizeof
(
CDataType
)
*
c_m_n_device_results
[
i
].
mDesc
.
GetElementSize
()));
...
@@ -254,10 +259,9 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -254,10 +259,9 @@ void profile_grouped_gemm_impl(int do_verification,
std
::
size_t
flop
=
0
,
num_btype
=
0
;
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
int
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
{
{
flop
+=
std
::
size_t
(
2
)
*
Ms
[
i
]
*
Ns
[
i
]
*
Ks
[
i
];
flop
+=
std
::
size_t
(
2
)
*
Ms
[
i
]
*
Ns
[
i
]
*
Ks
[
i
];
num_btype
+=
sizeof
(
ADataType
)
*
Ms
[
i
]
*
Ks
[
i
]
+
sizeof
(
BDataType
)
*
Ks
[
i
]
*
M
s
[
i
]
+
num_btype
+=
sizeof
(
ADataType
)
*
Ms
[
i
]
*
Ks
[
i
]
+
sizeof
(
BDataType
)
*
Ks
[
i
]
*
N
s
[
i
]
+
sizeof
(
CDataType
)
*
Ms
[
i
]
*
Ns
[
i
];
sizeof
(
CDataType
)
*
Ms
[
i
]
*
Ns
[
i
];
}
}
...
...
profiler/src/profile_grouped_gemm.cpp
View file @
857010cc
...
@@ -75,11 +75,6 @@ int profile_grouped_gemm(int argc, char* argv[])
...
@@ -75,11 +75,6 @@ int profile_grouped_gemm(int argc, char* argv[])
const
auto
StrideBs
=
stringToArray
(
argv
[
12
]);
const
auto
StrideBs
=
stringToArray
(
argv
[
12
]);
const
auto
StrideCs
=
stringToArray
(
argv
[
13
]);
const
auto
StrideCs
=
stringToArray
(
argv
[
13
]);
for
(
int
i
=
0
;
i
<
Ms
.
size
();
i
++
)
{
std
::
cout
<<
"M: "
<<
Ms
[
i
]
<<
" N: "
<<
Ns
[
i
]
<<
" K: "
<<
Ks
[
i
]
<<
std
::
endl
;
}
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
ck
::
profiler
::
profile_grouped_gemm_impl
<
ck
::
half_t
,
ck
::
profiler
::
profile_grouped_gemm_impl
<
ck
::
half_t
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment