Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3cc57101
Commit
3cc57101
authored
Mar 06, 2022
by
Jing Zhang
Browse files
wrap desc into a struct
parent
698573a9
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
125 additions
and
226 deletions
+125
-226
composable_kernel/include/config.hpp
composable_kernel/include/config.hpp
+0
-1
composable_kernel/include/tensor_operation/gridwise_grouped_gemm_xdlops_v2r3.hpp
...de/tensor_operation/gridwise_grouped_gemm_xdlops_v2r3.hpp
+23
-32
device_operation/include/device_grouped_gemm_xdl.hpp
device_operation/include/device_grouped_gemm_xdl.hpp
+97
-100
example/12_grouped_gemm_xdl/grouped_gemm_xdl.cpp
example/12_grouped_gemm_xdl/grouped_gemm_xdl.cpp
+5
-93
No files found.
composable_kernel/include/config.hpp
View file @
3cc57101
...
...
@@ -176,7 +176,6 @@ struct gemm_desc
ck
::
index_t
M
,
N
,
K
;
ck
::
index_t
StrideA
,
StrideB
,
StrideC
;
ck
::
index_t
OffsetA
,
OffsetB
,
OffsetC
;
ck
::
index_t
BlockStart
,
BlockSize
;
};
}
// namespace ck
...
...
composable_kernel/include/tensor_operation/gridwise_grouped_gemm_xdlops_v2r3.hpp
View file @
3cc57101
...
...
@@ -15,14 +15,10 @@ namespace ck {
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
GemmDesc
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Block2CTileMap
,
bool
HasMainK0BlockLoop
,
index_t
MaxGroupCount
>
__global__
void
...
...
@@ -33,49 +29,43 @@ __global__ void
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
StaticallyIndexedArray
<
AGridDesc_K0_M_K1
,
MaxGroupCount
>
a_grid_desc_k0_m_k1
,
const
StaticallyIndexedArray
<
BGridDesc_K0_N_K1
,
MaxGroupCount
>
b_grid_desc_k0_n_k1
,
const
StaticallyIndexedArray
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
MaxGroupCount
>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
StaticallyIndexedArray
<
GemmDesc
,
MaxGroupCount
>
gemm_shapes
,
const
StaticallyIndexedArray
<
GemmDesc
,
MaxGroupCount
>
gemm_desc_
,
const
index_t
group_count
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
StaticallyIndexedArray
<
Block2CTileMap
,
MaxGroupCount
>
block_2_ctile_map
)
const
CElementwiseOperation
c_element_op
)
{
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
block_id
=
get_block_1d_id
();
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
if
(
block_id
>=
gemm_shapes
[
i
].
BlockStart
&&
block_id
<
(
gemm_shapes
[
i
].
BlockStart
+
gemm_shapes
[
i
].
BlockSize
))
if
(
block_id
>=
gemm_desc_
[
i
].
BlockStart
&&
block_id
<
gemm_desc_
[
i
].
BlockEnd
)
{
const
index_t
group_id
=
i
;
const
index_t
block_id_grp
=
block_id
-
gemm_
shapes
[
i
].
BlockStart
;
const
index_t
a_offset_grp
=
gemm_
shapes
[
i
].
OffsetA
;
const
index_t
b_offset_grp
=
gemm_
shapes
[
i
].
OffsetB
;
const
index_t
c_offset_grp
=
gemm_
shapes
[
i
].
OffsetC
;
const
index_t
block_id_grp
=
block_id
-
gemm_
desc_
[
i
].
BlockStart
;
const
index_t
a_offset_grp
=
gemm_
desc_
[
i
].
OffsetA
;
const
index_t
b_offset_grp
=
gemm_
desc_
[
i
].
OffsetB
;
const
index_t
c_offset_grp
=
gemm_
desc_
[
i
].
OffsetC
;
GridwiseGemm
::
template
Run
<
HasMainK0BlockLoop
>(
p_a_grid
+
a_offset_grp
,
GridwiseGemm
::
template
Run
<
HasMainK0BlockLoop
>(
p_a_grid
+
a_offset_grp
,
p_b_grid
+
b_offset_grp
,
p_c_grid
+
c_offset_grp
,
p_shared
,
a_grid_desc_k0_m_k1
[
i
]
,
b_grid_desc_k0_n_k1
[
i
]
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
[
i
]
,
gemm_desc_
[
i
].
a_grid_desc_k0_m_k1
_
,
gemm_desc_
[
i
].
b_grid_desc_k0_n_k1
_
,
gemm_desc_
[
i
].
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
_
,
a_element_op
,
b_element_op
,
c_element_op
,
block_2_ctile_map
[
i
]
,
gemm_desc_
[
i
].
block_2_ctile_map
_
,
block_id_grp
);
return
;
}
});
}
#if 0
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
...
...
@@ -159,6 +149,7 @@ __global__ void
block_2_ctile_map_[group_id],
block_id_grp);
}
#endif
template
<
index_t
BlockSize
,
typename
FloatAB
,
...
...
device_operation/include/device_grouped_gemm_xdl.hpp
View file @
3cc57101
...
...
@@ -223,6 +223,20 @@ struct DeviceGroupedGemmXdl
CThreadTransferDstScalarPerVector
,
NumPrefetch
>
;
struct
GemmDesc
{
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
ck
::
index_t
OffsetA
,
OffsetB
,
OffsetC
;
ck
::
index_t
BlockStart
,
BlockEnd
;
};
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -246,6 +260,8 @@ struct DeviceGroupedGemmXdl
c_element_op_
{
c_element_op
}
{
grid_size
=
0
;
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
if
(
i
<
gemm_shapes_
.
size
())
{
...
...
@@ -257,27 +273,44 @@ struct DeviceGroupedGemmXdl
const
index_t
StrideB
=
gemm_shapes_
[
i
].
StrideB
;
const
index_t
StrideC
=
gemm_shapes_
[
i
].
StrideC
;
a_grid_desc_k0_m_k1_
(
i
)
=
gemm_desc_
(
i
).
a_grid_desc_k0_m_k1_
=
DeviceGroupedGemmXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
(
i
)
=
gemm_desc_
(
i
).
b_grid_desc_k0_n_k1_
=
DeviceGroupedGemmXdl
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
c_grid_desc_m_n_
(
i
)
=
gemm_desc_
(
i
).
c_grid_desc_m_n_
=
DeviceGroupedGemmXdl
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
[
i
],
b_grid_desc_k0_n_k1_
[
i
],
c_grid_desc_m_n_
[
i
],
const
index_t
grid_size_grp
=
GridwiseGemm
::
CalculateGridSize
(
gemm_desc_
[
i
].
c_grid_desc_m_n_
);
gemm_desc_
(
i
).
BlockStart
=
grid_size
;
gemm_desc_
(
i
).
BlockEnd
=
grid_size
+
grid_size_grp
;
grid_size
+=
grid_size_grp
;
gemm_desc_
(
i
).
OffsetA
=
gemm_shapes_
[
i
].
OffsetA
;
gemm_desc_
(
i
).
OffsetB
=
gemm_shapes_
[
i
].
OffsetB
;
gemm_desc_
(
i
).
OffsetC
=
gemm_shapes_
[
i
].
OffsetC
;
if
(
GridwiseGemm
::
CheckValidity
(
gemm_desc_
[
i
].
a_grid_desc_k0_m_k1_
,
gemm_desc_
[
i
].
b_grid_desc_k0_n_k1_
,
gemm_desc_
[
i
].
c_grid_desc_m_n_
,
M01_
,
N01_
))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
(
i
)
=
gemm_desc_
(
i
).
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
[
i
]
);
gemm_desc_
[
i
].
c_grid_desc_m_n_
);
block_2_ctile_map_
(
i
)
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
[
i
]
,
M01
,
N01
);
gemm_desc_
(
i
).
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
gemm_desc_
[
i
].
c_grid_desc_m_n_
,
M01
,
N01
);
}
}
else
{
gemm_desc_
(
i
).
BlockStart
=
-
1
;
gemm_desc_
(
i
).
BlockEnd
=
-
1
;
}
});
}
...
...
@@ -285,20 +318,16 @@ struct DeviceGroupedGemmXdl
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
StaticallyIndexedArray
<
AGridDesc_K0_M_K1
,
MaxGroupCount
>
a_grid_desc_k0_m_k1_
;
StaticallyIndexedArray
<
BGridDesc_K0_N_K1
,
MaxGroupCount
>
b_grid_desc_k0_n_k1_
;
StaticallyIndexedArray
<
CGridDesc_M_N
,
MaxGroupCount
>
c_grid_desc_m_n_
;
StaticallyIndexedArray
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
MaxGroupCount
>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
StaticallyIndexedArray
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
MaxGroupCount
>
block_2_ctile_map_
;
std
::
vector
<
gemm_desc
>
gemm_shapes_
;
index_t
M01_
;
index_t
N01_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
StaticallyIndexedArray
<
GemmDesc
,
MaxGroupCount
>
gemm_desc_
;
index_t
grid_size
;
};
// Invoker
...
...
@@ -308,134 +337,102 @@ struct DeviceGroupedGemmXdl
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
StaticallyIndexedArray
<
gemm_desc
,
MaxGroupCount
>
gemm_shapes
;
index_t
grid_size
=
0
;
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
if
(
i
<
arg
.
gemm_shapes_
.
size
())
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
[
i
].
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
[
i
].
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
gemm_desc_
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
gemm_desc_
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
gemm_desc_
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_k0_n_k1_{"
<<
arg
.
b_grid_desc_k0_n_k1_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
[
i
].
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
[
i
].
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
[
i
].
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
[
i
],
arg
.
b_grid_desc_k0_n_k1_
[
i
],
arg
.
c_grid_desc_m_n_
[
i
],
<<
arg
.
gemm_desc_
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
gemm_desc_
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
gemm_desc_
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
gemm_desc_
[
i
].
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
gemm_desc_
[
i
].
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"Block: "
<<
arg
.
gemm_desc_
[
i
].
BlockStart
<<
", "
<<
arg
.
gemm_desc_
[
i
].
BlockEnd
<<
std
::
endl
;
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
gemm_desc_
[
i
].
a_grid_desc_k0_m_k1_
,
arg
.
gemm_desc_
[
i
].
b_grid_desc_k0_n_k1_
,
arg
.
gemm_desc_
[
i
].
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
}
const
index_t
grid_size_grp
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
[
i
]);
gemm_shapes
(
i
)
=
arg
.
gemm_shapes_
[
i
];
gemm_shapes
(
i
).
BlockStart
=
grid_size
;
gemm_shapes
(
i
).
BlockSize
=
grid_size_grp
;
grid_size
+=
grid_size_grp
;
std
::
cout
<<
"group_id "
<<
i
<<
" BlockStart "
<<
gemm_shapes
(
i
).
BlockStart
<<
" BlockSize "
<<
gemm_shapes
(
i
).
BlockSize
<<
std
::
endl
;
}
else
{
gemm_shapes
(
i
).
BlockStart
=
-
1
;
gemm_shapes
(
i
).
BlockSize
=
-
1
;
}
});
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
[
Number
<
0
>
{}]
.
GetLength
(
I0
);
const
auto
K0
=
arg
.
gemm_desc_
[
I0
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
float
ave_time
=
0
;
#if 1
if
(
has_main_k0_block_loop
)
{
const
auto
kernel
=
kernel_grouped_gemm_xdlops_v2r3
<
GridwiseGemm
,
const
auto
kernel
=
kernel_grouped_gemm_xdlops_v2r3
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGroupedGemmXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGroupedGemmXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
gemm_desc
>
,
remove_reference_t
<
GemmDesc
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
true
,
MaxGroupCount
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
arg
.
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
gemm_shapes
,
arg
.
gemm_desc_
,
arg
.
gemm_shapes_
.
size
(),
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
arg
.
c_element_op_
);
}
else
{
const
auto
kernel
=
kernel_grouped_gemm_xdlops_v2r3
<
GridwiseGemm
,
const
auto
kernel
=
kernel_grouped_gemm_xdlops_v2r3
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGroupedGemmXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGroupedGemmXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
gemm_desc
>
,
remove_reference_t
<
GemmDesc
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
false
,
MaxGroupCount
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
arg
.
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
gemm_shapes
,
arg
.
gemm_desc_
,
arg
.
gemm_shapes_
.
size
(),
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
arg
.
c_element_op_
);
}
#endif
return
ave_time
;
}
...
...
@@ -454,9 +451,9 @@ struct DeviceGroupedGemmXdl
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
[
Number
<
0
>
{}]
,
arg
.
b_grid_desc_k0_n_k1_
[
Number
<
0
>
{}]
,
arg
.
c_grid
_desc_
m_n_
[
Number
<
0
>
{}],
return
GridwiseGemm
::
CheckValidity
(
arg
.
gemm_desc_
[
Number
<
0
>
{}].
a_grid_desc_k0_m_k1_
,
arg
.
gemm_desc_
[
Number
<
0
>
{}].
b_grid_desc_k0_n_k1_
,
arg
.
gemm
_desc_
[
Number
<
0
>
{}]
.
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
);
}
...
...
example/12_grouped_gemm_xdl/grouped_gemm_xdl.cpp
View file @
3cc57101
...
...
@@ -76,7 +76,7 @@ int main(int argc, char* argv[])
exit
(
0
);
}
int
group_count
=
3
;
int
group_count
=
4
;
// GEMM shape
std
::
vector
<
ck
::
gemm_desc
>
gemm_shapes
;
...
...
@@ -85,11 +85,11 @@ int main(int argc, char* argv[])
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
2048
+
256
*
i
;
int
N
=
2048
+
128
*
i
;
int
K
=
256
+
128
*
i
;
int
M
=
3840
;
int
N
=
1024
;
int
K
=
4096
;
gemm_shapes
.
push_back
({
M
,
N
,
K
,
K
,
K
,
N
,
A_size
,
B_size
,
C_size
,
-
1
,
-
1
});
gemm_shapes
.
push_back
({
M
,
N
,
K
,
K
,
K
,
N
,
A_size
,
B_size
,
C_size
});
A_size
+=
M
*
K
;
B_size
+=
N
*
K
;
...
...
@@ -236,94 +236,6 @@ int main(int argc, char* argv[])
check_error
(
c_host_tensors
[
i
],
c_device_tensors
[
i
]);
}
}
#if 0
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, nrepeat);
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
if(do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
check_error(c_m_n_host_result, c_m_n_device_result);
}
#endif
return
0
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment