gaoqiong / composable_kernel · Commits · e3a4b967
"vscode:/vscode.git/clone" did not exist on "7967161abf66de3cf42eb44672d61e6d4a8ddf56"
Commit e3a4b967, authored Mar 12, 2022 by Jing Zhang
Parent: 8fb2b172

Commit message: fixed mem issue with unique_ptr
Showing 10 changed files with 268 additions and 910 deletions (+268, −910)
composable_kernel/include/tensor_operation/gridwise_grouped_gemm_xdlops_v2r3.hpp  (+0, −698)
example/14_grouped_gemm/grouped_gemm_xdl_fp16.cpp  (+28, −57)
include/ck/tensor_operation/gpu/device/device_gemm.hpp  (+1, −1)
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp  (+14, −9)
include/ck/tensor_operation/gpu/grid/gridwise_grouped_gemm_xdlops_v2r3.hpp  (+6, −8)
library/include/ck/library/host_tensor/device.hpp  (+1, −0)
library/src/host_tensor/device.cpp  (+6, −0)
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp  (+6, −6)
profiler/include/profile_grouped_gemm_impl.hpp  (+184, −113)
profiler/src/profile_grouped_gemm.cpp  (+22, −18)
composable_kernel/include/tensor_operation/gridwise_grouped_gemm_xdlops_v2r3.hpp — deleted (100644 → 0, view at parent 8fb2b172)
#ifndef CK_GRIDWISE_GROUPED_GEMM_XDLOPS_V2R3_HPP
#define CK_GRIDWISE_GROUPED_GEMM_XDLOPS_V2R3_HPP

#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"

namespace ck {

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename GemmDesc,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          bool HasMainK0BlockLoop,
          index_t MaxGroupCount>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_grouped_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
                                        const FloatAB* __restrict__ p_b_grid,
                                        FloatC* __restrict__ p_c_grid,
                                        const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_,
                                        const index_t group_count,
                                        const AElementwiseOperation a_element_op,
                                        const BElementwiseOperation b_element_op,
                                        const CElementwiseOperation c_element_op)
{
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const index_t block_id = get_block_1d_id();

    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
        if(block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd)
        {
            const index_t group_id     = i;
            const index_t block_id_grp = block_id - gemm_desc_[i].BlockStart;

            const index_t a_offset_grp = gemm_desc_[i].OffsetA;
            const index_t b_offset_grp = gemm_desc_[i].OffsetB;
            const index_t c_offset_grp = gemm_desc_[i].OffsetC;

            GridwiseGemm::template Run<HasMainK0BlockLoop>(
                p_a_grid + a_offset_grp,
                p_b_grid + b_offset_grp,
                p_c_grid + c_offset_grp,
                p_shared,
                gemm_desc_[i].a_grid_desc_k0_m_k1_,
                gemm_desc_[i].b_grid_desc_k0_n_k1_,
                gemm_desc_[i].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
                a_element_op,
                b_element_op,
                c_element_op,
                gemm_desc_[i].block_2_ctile_map_,
                block_id_grp);
        }
    });
}
#if 0
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename GemmDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainK0BlockLoop,
index_t MaxGroupCount>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r4(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const StaticallyIndexedArray<AGridDesc_K0_M_K1, MaxGroupCount> a_grid_desc_k0_m_k1,
const StaticallyIndexedArray<BGridDesc_K0_N_K1, MaxGroupCount> b_grid_desc_k0_n_k1,
const StaticallyIndexedArray<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2, MaxGroupCount>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_shapes,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const StaticallyIndexedArray<Block2CTileMap, MaxGroupCount> block_2_ctile_map)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
__shared__ AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_[MaxGroupCount];
__shared__ BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_[MaxGroupCount];
__shared__ CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_[MaxGroupCount];
__shared__ Block2CTileMap block_2_ctile_map_[MaxGroupCount];
__shared__ GemmDesc gemm_shapes_[MaxGroupCount];
if(get_thread_local_1d_id())
{
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
a_grid_desc_k0_m_k1_[i] = a_grid_desc_k0_m_k1[i];
b_grid_desc_k0_n_k1_[i] = b_grid_desc_k0_n_k1[i];
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_[i] = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2[i];
block_2_ctile_map_[i] = block_2_ctile_map[i];
gemm_shapes_[i] = gemm_shapes[i];
});
}
block_sync_lds();
index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
group_id = (block_id >= gemm_shapes[i].BlockStart &&
block_id < (gemm_shapes[i].BlockStart + gemm_shapes[i].BlockSize))
? i
: group_id;
});
const index_t block_id_grp = block_id - gemm_shapes_[group_id].BlockStart;
const index_t a_offset_grp = gemm_shapes_[group_id].OffsetA;
const index_t b_offset_grp = gemm_shapes_[group_id].OffsetB;
const index_t c_offset_grp = gemm_shapes_[group_id].OffsetC;
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp,
p_shared,
a_grid_desc_k0_m_k1_[group_id],
b_grid_desc_k0_n_k1_[group_id],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_[group_id],
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map_[group_id],
block_id_grp);
}
#endif
template <index_t BlockSize,
          typename FloatAB, typename FloatAcc, typename FloatC,
          InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename AGridDesc_K0_M_K1, typename BGridDesc_K0_N_K1, typename CGridDesc_M_N,
          typename AElementwiseOperation, typename BElementwiseOperation, typename CElementwiseOperation,
          index_t MPerBlock, index_t NPerBlock, index_t K0PerBlock,
          index_t MPerXDL, index_t NPerXDL, index_t K1Value,
          index_t MXdlPerWave, index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          bool BBlockLdsExtraN,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          index_t NumPrefetch = 1>
struct GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};

    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        return a_block_desc_k0_m_k1;
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_k0_n_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        return b_block_desc_k0_n_k1;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        constexpr auto max_lds_align = K1;

        constexpr auto a_block_space_size_aligned =
            math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned =
            math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                                                            const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                                                            const CGridDesc_M_N& c_grid_desc_m_n,
                                                            index_t M01,
                                                            index_t N01)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 need to be known at compile-time");

        static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
                      "Invalid tuning param!");

        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
             K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
            return false;

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
            return false;

        // check NumPrefetch
        if constexpr(NumPrefetch == 1)
        {
            // 1-stage prefetch always supported
        }
        else if constexpr(NumPrefetch == 2)
        {
            // 2-stage prefetch currently only support even number of K0 loop
            // TODO: add support for odd number of K0 loop
            if(!((K0 / K0PerBlock) % 2 == 0))
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check M01, N01
        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        if(!(M0 % M01 == 0 && N0 % N01 == 0))
            return false;

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

        return grid_size;
    }

    // TODO move this function into GEMM-pipeline class
    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
        const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;

        return has_main_k0_block_loop;
    }

    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_k0_n_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        using BlockwiseGemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_block_desc_k0_m_k1),
                                                                decltype(b_block_desc_k0_n_k1),
                                                                MPerXDL, NPerXDL,
                                                                MXdlPerWave, NXdlPerWave,
                                                                K1>;

        return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto M00 = M0 / M01;
        const auto N00 = N0 / N01;

        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));

        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3>{}),
                make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }

    using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
        decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));

    using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));

    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
                               const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                               const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                               const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const CElementwiseOperation& c_element_op,
                               const Block2CTileMap& block_2_ctile_map,
                               const index_t block_id)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());

        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

        // divide block work by [M, N]
        const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_id));

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                              AElementwiseOperation,
                                              ck::tensor_operation::element_wise::PassThrough,
                                              InMemoryDataOperationEnum_t::Set,
                                              Sequence<K0PerBlock, MPerBlock, K1>,
                                              ABlockTransferThreadClusterLengths_K0_M_K1,
                                              ABlockTransferThreadClusterArrangeOrder,
                                              FloatAB, FloatAB,
                                              decltype(a_grid_desc_k0_m_k1),
                                              decltype(a_block_desc_k0_m_k1),
                                              ABlockTransferSrcAccessOrder,
                                              Sequence<1, 0, 2>,
                                              ABlockTransferSrcVectorDim,
                                              2,
                                              ABlockTransferSrcScalarPerVector,
                                              ABlockTransferDstScalarPerVector_K1,
                                              1, 1,
                                              AThreadTransferSrcResetCoordinateAfterRun,
                                              true,
                                              NumPrefetch>(
                a_grid_desc_k0_m_k1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_k0_m_k1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                              BElementwiseOperation,
                                              ck::tensor_operation::element_wise::PassThrough,
                                              InMemoryDataOperationEnum_t::Set,
                                              Sequence<K0PerBlock, NPerBlock, K1>,
                                              BBlockTransferThreadClusterLengths_K0_N_K1,
                                              BBlockTransferThreadClusterArrangeOrder,
                                              FloatAB, FloatAB,
                                              decltype(b_grid_desc_k0_n_k1),
                                              decltype(b_block_desc_k0_n_k1),
                                              BBlockTransferSrcAccessOrder,
                                              Sequence<1, 0, 2>,
                                              BBlockTransferSrcVectorDim,
                                              2,
                                              BBlockTransferSrcScalarPerVector,
                                              BBlockTransferDstScalarPerVector_K1,
                                              1, 1,
                                              BThreadTransferSrcResetCoordinateAfterRun,
                                              true,
                                              NumPrefetch>(
                b_grid_desc_k0_n_k1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_k0_n_k1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
        //     b_mtx[K0PerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //       register
        // sanity check
        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_block_desc_k0_m_k1),
                                                                decltype(b_block_desc_k0_n_k1),
                                                                MPerXDL, NPerXDL,
                                                                MXdlPerWave, NXdlPerWave,
                                                                K1>{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned =
            math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_k0_n_k1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

        // gridwise GEMM pipeline
        const auto gridwise_gemm_pipeline =
            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
                                    remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
                                    remove_cvref_t<decltype(a_blockwise_copy)>,
                                    remove_cvref_t<decltype(a_grid_buf)>,
                                    remove_cvref_t<decltype(a_block_buf)>,
                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
                                    remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
                                    remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
                                    remove_cvref_t<decltype(b_blockwise_copy)>,
                                    remove_cvref_t<decltype(b_grid_buf)>,
                                    remove_cvref_t<decltype(b_block_buf)>,
                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
                                    remove_cvref_t<decltype(blockwise_gemm)>,
                                    remove_cvref_t<decltype(c_thread_buf)>,
                                    NumPrefetch,
                                    HasMainK0BlockLoop>{};

        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);

        gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
                                   a_block_desc_k0_m_k1,
                                   a_blockwise_copy,
                                   a_grid_buf,
                                   a_block_buf,
                                   a_block_slice_copy_step,
                                   b_grid_desc_k0_n_k1,
                                   b_block_desc_k0_n_k1,
                                   b_blockwise_copy,
                                   b_grid_buf,
                                   b_block_buf,
                                   b_block_slice_copy_step,
                                   blockwise_gemm,
                                   c_thread_buf,
                                   K0BlockMainLoop);

        // output: register to global memory
        {
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);

            // calculate origin of thread output tensor on global memory
            //     blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_grid =
                m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];

            const index_t n_thread_data_on_grid =
                n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_grid_idx =
                m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_grid));

            const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_grid_idx =
                n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_grid));

            auto c_thread_copy =
                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                                   FloatC,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   CElementwiseOperation,
                                                   Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
                                                   CThreadTransferSrcDstAccessOrder,
                                                   CThreadTransferSrcDstVectorDim,
                                                   CThreadTransferDstScalarPerVector,
                                                   CGlobalMemoryDataOperation,
                                                   1,
                                                   true>{
                    c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(m_thread_data_on_grid_idx[I0],
                                     n_thread_data_on_grid_idx[I0],
                                     m_thread_data_on_grid_idx[I1],
                                     n_thread_data_on_grid_idx[I1],
                                     m_thread_data_on_grid_idx[I2],
                                     m_thread_data_on_grid_idx[I3],
                                     m_thread_data_on_grid_idx[I4],
                                     n_thread_data_on_grid_idx[I2]),
                    c_element_op};

            c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                              c_thread_buf,
                              c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                              c_grid_buf);
        }
    }
};

} // namespace ck
#endif
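Reviewer note: the deleted kernel above picks the group a workgroup belongs to from half-open block ranges [BlockStart, BlockEnd) carried in each GemmDesc. A minimal host-side sketch of how such ranges can be laid out from per-group tile counts follows; GroupRange, tiles_for and make_block_ranges are illustrative helpers, not types from this repository — only the BlockStart/BlockEnd convention is taken from the code above.

#include <cstddef>
#include <vector>

struct GroupRange
{
    int BlockStart;
    int BlockEnd; // one past the last block of this group
};

// Each workgroup owns one MPerBlock x NPerBlock tile of C.
inline int tiles_for(int M, int N, int MPerBlock, int NPerBlock)
{
    return (M / MPerBlock) * (N / NPerBlock);
}

inline std::vector<GroupRange> make_block_ranges(const std::vector<int>& Ms,
                                                 const std::vector<int>& Ns,
                                                 int MPerBlock,
                                                 int NPerBlock)
{
    std::vector<GroupRange> ranges;
    int start = 0;
    for(std::size_t g = 0; g < Ms.size(); ++g)
    {
        const int n_tiles = tiles_for(Ms[g], Ns[g], MPerBlock, NPerBlock);
        ranges.push_back({start, start + n_tiles});
        start += n_tiles; // the final value of start is the grid size of the whole launch
    }
    return ranges;
}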
example/14_grouped_gemm/grouped_gemm_xdl_fp16.cpp

@@ -81,19 +81,17 @@ int main(int argc, char* argv[])
     // GEMM shape
     std::vector<ck::GemmShape> gemm_shapes;

     int A_size = 0, B_size = 0, C_size = 0;

     for(int i = 0; i < group_count; i++)
     {
         int M = 256 + 256 * i;
         int N = 128 + 128 * i;
         int K = 64 + 64 * i;
         // int M = 256 + 256 * i;
         // int N = 128 + 128 * i;
         // int K = 64 + 64 * i;

         gemm_shapes.push_back({M, N, K, K, K, N, nullptr, nullptr, nullptr});

         int M = 3840;
         int N = 1024;
         int K = 4096;

         A_size += gemm_shapes[i].M * gemm_shapes[i].K;
         B_size += gemm_shapes[i].N * gemm_shapes[i].K;
         C_size += gemm_shapes[i].M * gemm_shapes[i].N;

         gemm_shapes.push_back({M, N, K, K, N, N, nullptr, nullptr, nullptr});
     }

     auto f_host_tensor_descriptor =

@@ -115,6 +113,10 @@ int main(int argc, char* argv[])
     std::vector<Tensor<CDataType>> c_host_tensors;
     std::vector<Tensor<CDataType>> c_device_tensors;

+    using DeviceMemPtr = std::unique_ptr<DeviceMem>;
+
+    std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
+
     std::size_t flop = 0, num_btype = 0;

     for(int i = 0; i < gemm_shapes.size(); i++)

@@ -133,13 +135,10 @@ int main(int argc, char* argv[])
                   << std::endl;

         flop += std::size_t(2) * gemm_shapes[i].M * gemm_shapes[i].K * gemm_shapes[i].N;

         num_btype += sizeof(ADataType) * gemm_shapes[i].M * gemm_shapes[i].K +
                      sizeof(BDataType) * gemm_shapes[i].K * gemm_shapes[i].N +
                      sizeof(CDataType) * gemm_shapes[i].M * gemm_shapes[i].N;
     }

         num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
                      sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
                      sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize();

     for(int i = 0; i < gemm_shapes.size(); i++)
     {
         switch(init_method)
         {
         case 0: break;

@@ -157,38 +156,23 @@ int main(int argc, char* argv[])
         }
     }

-    DeviceMem a_tensors_device_buf(sizeof(ADataType) * A_size);
-    DeviceMem b_tensors_device_buf(sizeof(BDataType) * B_size);
-    DeviceMem c_tensors_device_buf(sizeof(CDataType) * C_size);
-
-    std::vector<ADataType> a_tensors_data, b_tensors_data, c_tensors_data;
-
-    A_size = 0;
-    B_size = 0;
-    C_size = 0;
-
     for(int i = 0; i < gemm_shapes.size(); i++)
     {
-        a_tensors_data.insert(a_tensors_data.end(), a_tensors[i].mData.begin(), a_tensors[i].mData.end());
-        b_tensors_data.insert(b_tensors_data.end(), b_tensors[i].mData.begin(), b_tensors[i].mData.end());
-
-        gemm_shapes[i].p_a = static_cast<ADataType*>(a_tensors_device_buf.GetDeviceBuffer()) + A_size;
-        gemm_shapes[i].p_b = static_cast<BDataType*>(b_tensors_device_buf.GetDeviceBuffer()) + B_size;
-        gemm_shapes[i].p_c = static_cast<CDataType*>(c_tensors_device_buf.GetDeviceBuffer()) + C_size;
-
-        A_size += gemm_shapes[i].M * gemm_shapes[i].K;
-        B_size += gemm_shapes[i].N * gemm_shapes[i].K;
-        C_size += gemm_shapes[i].M * gemm_shapes[i].N;
+        a_tensors_device.push_back(
+            std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize()));
+        b_tensors_device.push_back(
+            std::make_unique<DeviceMem>(sizeof(BDataType) * a_tensors[i].mDesc.GetElementSize()));
+        c_tensors_device.push_back(
+            std::make_unique<DeviceMem>(sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize()));
+
+        a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
+        b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
+
+        gemm_shapes[i].p_a = a_tensors_device[i]->GetDeviceBuffer();
+        gemm_shapes[i].p_b = b_tensors_device[i]->GetDeviceBuffer();
+        gemm_shapes[i].p_c = c_tensors_device[i]->GetDeviceBuffer();
     }

-    a_tensors_device_buf.ToDevice(a_tensors_data.data());
-    b_tensors_device_buf.ToDevice(b_tensors_data.data());
-
     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
     auto c_element_op = CElementOp{};

@@ -214,24 +198,11 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
               << gemm.GetTypeString() << std::endl;

-    c_tensors_data.resize(C_size);
-    c_tensors_device_buf.FromDevice(c_tensors_data.data());
-
-    C_size = 0;
-    for(int i = 0; i < gemm_shapes.size(); i++)
-    {
-        memcpy(c_device_tensors[i].mData.data(),
-               c_tensors_data.data() + C_size,
-               c_device_tensors[i].mData.size() * sizeof(CDataType));
-        C_size += gemm_shapes[i].M * gemm_shapes[i].N;
-    }
-
     if(do_verification)
     {
         for(int i = 0; i < gemm_shapes.size(); i++)
         {
+            c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
+
             auto ref_gemm    = ReferenceGemmInstance{};
             auto ref_invoker = ref_gemm.MakeInvoker();
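Reviewer note: the example above sums flop and num_btype over all groups and reports "Perf: ... ms, ... TFlops, ... GB/s" from the averaged kernel time. A minimal sketch of that arithmetic, for reference; the function and struct names here are illustrative, only the 2*M*N*K / byte-traffic formulas come from the example.

#include <cstddef>
#include <vector>

struct Shape { int M, N, K; };

// tflops and gb_per_sec as the example computes them, with ave_time_ms in milliseconds.
inline void perf_metrics(const std::vector<Shape>& shapes,
                         std::size_t a_elem_bytes, std::size_t b_elem_bytes, std::size_t c_elem_bytes,
                         float ave_time_ms, float& tflops, float& gb_per_sec)
{
    std::size_t flop = 0, num_btype = 0;
    for(const auto& s : shapes)
    {
        flop      += std::size_t(2) * s.M * s.N * s.K;                    // one FMA = 2 flop
        num_btype += a_elem_bytes * s.M * s.K + b_elem_bytes * s.K * s.N + // A read, B read,
                     c_elem_bytes * s.M * s.N;                             // C write
    }
    tflops     = static_cast<float>(flop) / 1.0e9f / ave_time_ms; // flop / (ms * 1e9) = Tflop/s
    gb_per_sec = num_btype / 1.0e6f / ave_time_ms;                // bytes / (ms * 1e6) = GB/s
}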
include/ck/tensor_operation/gpu/device/device_gemm.hpp

@@ -70,7 +70,7 @@ template <typename AElementwiseOperation,
           typename CElementwiseOperation>
 struct DeviceGroupedGemm : public BaseOperator
 {
-    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape> gemm_shapes,
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape>& gemm_shapes,
                                                               AElementwiseOperation a_element_op,
                                                               BElementwiseOperation b_element_op,
                                                               CElementwiseOperation c_element_op,
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp

@@ -242,7 +242,7 @@ struct DeviceGroupedGemmXdl
     // Argument
     struct Argument : public BaseArgument
     {
-        Argument(std::vector<GemmShape> gemm_shapes,
+        Argument(std::vector<GemmShape>& gemm_shapes,
                  index_t M01,
                  index_t N01,
                  AElementwiseOperation a_element_op,

@@ -360,8 +360,7 @@ struct DeviceGroupedGemmXdl
                 if(GridwiseGemm::CalculateHasMainK0BlockLoop(K0) != has_main_k0_block_loop)
                 {
-                    throw std::runtime_error(
-                        "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
+                    throw std::runtime_error("wrong! not all gemm has_main_k0_block_loop");
                 }
             }
         });

@@ -435,11 +434,17 @@ struct DeviceGroupedGemmXdl
     static bool IsSupportedArgument(const Argument& arg)
     {
-        return GridwiseGemm::CheckValidity(arg.GemmShape_[0].a_grid_desc_k0_m_k1_,
-                                           arg.GemmShape_[0].b_grid_desc_k0_n_k1_,
-                                           arg.GemmShape_[0].c_grid_desc_m_n_,
-                                           arg.M01_,
-                                           arg.N01_);
+        bool isValid = true;
+
+        for(int i = 0; i < arg.GemmShape_.size(); i++)
+        {
+            isValid &= GridwiseGemm::CheckValidity(arg.GemmShape_[i].a_grid_desc_k0_m_k1_,
+                                                   arg.GemmShape_[i].b_grid_desc_k0_n_k1_,
+                                                   arg.GemmShape_[i].c_grid_desc_m_n_,
+                                                   arg.M01_,
+                                                   arg.N01_);
+        }
+
+        return isValid;
     }

     // polymorphic

@@ -459,7 +464,7 @@ struct DeviceGroupedGemmXdl
     static auto MakeInvoker() { return Invoker{}; }

     // polymorphic
-    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape> gemm_shapes,
+    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape>& gemm_shapes,
                                                       AElementwiseOperation a_element_op,
                                                       BElementwiseOperation b_element_op,
                                                       CElementwiseOperation c_element_op,
include/ck/tensor_operation/gpu/grid/gridwise_grouped_gemm_xdlops_v2r3.hpp

@@ -60,24 +60,22 @@ __global__ void
         }
     });
 #else
-    const GemmDesc* gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
+    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);

     index_t group_id = 0;

     static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd)
+        group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
+                    i < group_count)
                        ? i
                        : group_id;
     });

     const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;

     const index_t a_offset_grp = gemm_desc_ptr[group_id].OffsetA;
     const index_t b_offset_grp = gemm_desc_ptr[group_id].OffsetB;
     const index_t c_offset_grp = gemm_desc_ptr[group_id].OffsetC;

-    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp,
-                                                   p_b_grid + b_offset_grp,
-                                                   p_c_grid + c_offset_grp,
+    GridwiseGemm::template Run<HasMainK0BlockLoop>(gemm_desc_ptr[group_id].a_ptr,
+                                                   gemm_desc_ptr[group_id].b_ptr,
+                                                   gemm_desc_ptr[group_id].c_ptr,
                                                    p_shared,
                                                    gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
                                                    gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
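Reviewer note: the group lookup above is a compile-time static_for over MaxGroupCount with a ternary select, which keeps the search branch-light on the GPU; the i < group_count guard added here stops an out-of-range descriptor slot from matching. The sketch below writes the same lookup as a plain loop, for readability only; this cut-down GemmDesc is an illustrative stand-in for the kernel's template parameter, and only the BlockStart/BlockEnd fields come from the diff.

// Illustrative only: the half-open [BlockStart, BlockEnd) lookup as a plain loop.
struct GemmDescSketch
{
    int BlockStart;
    int BlockEnd;
};

inline int find_group(const GemmDescSketch* gemm_desc_ptr, int group_count, int block_id)
{
    int group_id = 0;
    for(int i = 0; i < group_count; ++i)
    {
        // a block belongs to group i if BlockStart <= block_id < BlockEnd
        if(block_id >= gemm_desc_ptr[i].BlockStart && block_id < gemm_desc_ptr[i].BlockEnd)
        {
            group_id = i;
        }
    }
    return group_id;
}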
library/include/ck/library/host_tensor/device.hpp

@@ -12,6 +12,7 @@ struct DeviceMem
 {
     DeviceMem() = delete;
     DeviceMem(std::size_t mem_size);
+    DeviceMem(const DeviceMem& p);
     void* GetDeviceBuffer();
     void ToDevice(const void* p);
     void FromDevice(void* p);
library/src/host_tensor/device.cpp

@@ -5,6 +5,12 @@ DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
     hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
 }

+DeviceMem::DeviceMem(const DeviceMem& p) : mpDeviceBuf(p.mpDeviceBuf), mMemSize(p.mMemSize)
+{
+    // hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+    // hipGetErrorString(hipMemcpy(mpDeviceBuf, p.mpDeviceBuf, mMemSize, hipMemcpyDeviceToDevice));
+}
+
 void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; }

 void DeviceMem::ToDevice(const void* p)
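Reviewer note: the new copy constructor only copies mpDeviceBuf and mMemSize (the hipMalloc/hipMemcpy deep-copy calls stay commented out), so a copied DeviceMem aliases the same device allocation. The sketch below illustrates why that is fragile when DeviceMem values are kept directly in a std::vector, and why the example and profiler in this commit hold std::unique_ptr<DeviceMem> instead. It assumes, which the diff does not show, that ~DeviceMem frees mpDeviceBuf, as is usual for an RAII device wrapper; the two functions are illustrative, not library code.

#include <cstddef>
#include <memory>
#include <vector>

#include "device.hpp" // ck's DeviceMem (path as in library/include/ck/library/host_tensor/)

void risky(std::size_t n)
{
    std::vector<DeviceMem> bufs;
    bufs.push_back(DeviceMem(n)); // shallow copy: temporary and element alias one device pointer
    bufs.push_back(DeviceMem(n)); // vector growth copies elements, multiplying the owners
}                                 // several destructions may then hit the same allocation

void safe(std::size_t n)
{
    std::vector<std::unique_ptr<DeviceMem>> bufs;
    bufs.push_back(std::make_unique<DeviceMem>(n)); // exactly one owner per allocation
    bufs.push_back(std::make_unique<DeviceMem>(n)); // vector growth only moves the pointers
}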
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp

@@ -23,9 +23,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;

 // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
 using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple<
     // clang-format off
     //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
     //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
     //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|

@@ -48,13 +47,14 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances =
     //DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
     //DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>
     DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
     // clang-format on
     >;

 void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
     std::vector<DeviceGroupedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
 {
-    add_device_operation_instances(instances, device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{});
+    add_device_operation_instances(instances,
+                                   device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{});
 }

 } // namespace device_grouped_gemm_instance
profiler/include/profile_grouped_gemm_impl.hpp

@@ -16,15 +16,19 @@ namespace tensor_operation {
 namespace device {
 namespace device_grouped_gemm_instance {

 using DeviceGroupedGemmNoOpPtr =
     ck::tensor_operation::device::DeviceGroupedGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                                        ck::tensor_operation::element_wise::PassThrough,
                                                        ck::tensor_operation::element_wise::PassThrough>;

 void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);

-//void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
-//void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
-//void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
+// void
+// add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
+// void
+// add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
+// void
+// add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);

 } // namespace device_grouped_gemm_instance
 } // namespace device

@@ -41,15 +45,15 @@ template <typename ADataType,
           typename BLayout,
           typename CLayout>
 void profile_grouped_gemm_impl(int do_verification,
                                int init_method,
                                bool do_log,
                                int nrepeat,
                                std::vector<int> Ms,
                                std::vector<int> Ns,
                                std::vector<int> Ks,
                                std::vector<int> StrideAs,
                                std::vector<int> StrideBs,
                                std::vector<int> StrideCs)
 {
     auto f_host_tensor_descriptor =
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {

@@ -65,41 +69,48 @@ void profile_grouped_gemm_impl(int do_verification,
         }
     };

     std::vector<Tensor<ADataType>> a_m_k;
     std::vector<Tensor<BDataType>> b_k_n;
-    std::vector<Tensor<CDataType>> c_m_n;
+    std::vector<Tensor<CDataType>> c_m_n_device_results;
+
+    // int A_size = 0, B_size = 0, C_size = 0;

     for(int i = 0; i < Ms.size(); i++)
     {
         a_m_k.push_back(Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
         b_k_n.push_back(Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
-        c_m_n.push_back(Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
+        c_m_n_device_results.push_back(
+            Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));

         std::cout << "a_m_k[" << i << "]:" << a_m_k[i].mDesc << std::endl;
         std::cout << "b_k_n[" << i << "]:" << b_k_n[i].mDesc << std::endl;
-        std::cout << "c_m_n[" << i << "]:" << c_m_n[i].mDesc << std::endl;
+        std::cout << "c_m_n_device_results[" << i << "]:" << c_m_n_device_results[i].mDesc << std::endl;

         std::size_t num_thread = std::thread::hardware_concurrency();

         switch(init_method)
         {
         case 0: break;
         case 1:
             a_m_k[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
             b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
             break;
         default:
             a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
             b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
         }

         // set zero to c_device_buf
-        c_m_n[i].GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
+        c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
+
+        // A_size += a_m_k[i].mDesc.GetElementSpace();
+        // B_size += b_k_n[i].mDesc.GetElementSpace();
+        // C_size += c_m_n_device_results[i].mDesc.GetElementSpace();
     }

     using AElementOp = ck::tensor_operation::element_wise::PassThrough;
     using BElementOp = ck::tensor_operation::element_wise::PassThrough;

@@ -114,28 +125,112 @@ void profile_grouped_gemm_impl(int do_verification,
     // }

-    std::vector<DeviceMem> a_device_buf, b_device_buf, c_device_buf;
-
-    //DeviceMem a_device_buf(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace());
-    //DeviceMem b_device_buf(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace());
-    //DeviceMem c_device_buf(sizeof(CDataType) * c_m_n[i].mDesc.GetElementSpace());
+    // std::vector<DeviceMem> a_device_buf, b_device_buf, c_device_buf;
+    std::vector<void*> a_device_buf, b_device_buf, c_device_buf;
+
+    // DeviceMem a_device_buf_(sizeof(ADataType) * A_size);
+    // DeviceMem b_device_buf_(sizeof(BDataType) * B_size);
+    // DeviceMem c_device_buf_(sizeof(CDataType) * C_size);
+
+    // std::vector<ADataType> a_tensors_data;
+    // std::vector<BDataType> b_tensors_data;
+    // std::vector<CDataType> c_tensors_data;
+
+    std::vector<GemmShape> gemm_shapes;
+
+    // A_size = 0;
+    // B_size = 0;
+    // C_size = 0;

     for(int i = 0; i < Ms.size(); i++)
     {
-        a_device_buf.push_back(DeviceMem(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace()));
-        b_device_buf.push_back(DeviceMem(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace()));
-        c_device_buf.push_back(DeviceMem(sizeof(CDataType) * c_m_n[i].mDesc.GetElementSpace()));
-
-        a_device_buf[i].ToDevice(a_m_k[i].mData.data());
-        b_device_buf[i].ToDevice(b_k_n[i].mData.data());
-        c_device_buf[i].ToDevice(c_m_n[i].mData.data());
+        // a_tensors_data.insert(a_tensors_data.end(), a_m_k[i].mData.begin(),
+        // a_m_k[i].mData.end()); b_tensors_data.insert(b_tensors_data.end(),
+        // b_k_n[i].mData.begin(), b_k_n[i].mData.end());
+        // c_tensors_data.insert(c_tensors_data.end(), c_m_n_device_results[i].mData.begin(),
+        // c_m_n_device_results[i].mData.end());
+
+        void *a_device_buf_, *b_device_buf_, *c_device_buf_;
+
+        hipGetErrorString(hipMalloc(static_cast<void**>(&a_device_buf_),
+                                    sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace()));
+        hipGetErrorString(hipMalloc(static_cast<void**>(&b_device_buf_),
+                                    sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace()));
+        hipGetErrorString(hipMalloc(static_cast<void**>(&c_device_buf_),
+                                    sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpace()));
+
+        // DeviceMem a_device_buf_(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace());
+        // DeviceMem b_device_buf_(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace());
+        // DeviceMem c_device_buf_(sizeof(CDataType) *
+        // c_m_n_device_results[i].mDesc.GetElementSpace());
+
+        hipGetErrorString(hipMemcpy(a_device_buf_,
+                                    a_m_k[i].mData.data(),
+                                    sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace(),
+                                    hipMemcpyHostToDevice));
+        hipGetErrorString(hipMemcpy(b_device_buf_,
+                                    b_k_n[i].mData.data(),
+                                    sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace(),
+                                    hipMemcpyHostToDevice));
+        hipGetErrorString(hipMemcpy(c_device_buf_,
+                                    c_m_n_device_results[i].mData.data(),
+                                    sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpace(),
+                                    hipMemcpyHostToDevice));
+
+        // a_device_buf_.ToDevice(a_m_k[i].mData.data());
+        // b_device_buf_.ToDevice(b_k_n[i].mData.data());
+        // c_device_buf_.ToDevice(c_m_n_device_results[i].mData.data());
+
+        a_device_buf.push_back(a_device_buf_);
+        b_device_buf.push_back(b_device_buf_);
+        c_device_buf.push_back(c_device_buf_);
+
+        // a_device_buf.push_back(a_device_buf_);
+        // b_device_buf.push_back(b_device_buf_);
+        // c_device_buf.push_back(c_device_buf_);
+
+        // gemm_shapes.push_back({Ms[i],
+        //                        Ns[i],
+        //                        Ks[i],
+        //                        StrideAs[i],
+        //                        StrideBs[i],
+        //                        StrideCs[i],
+        //                        a_device_buf[i].GetDeviceBuffer(),
+        //                        b_device_buf[i].GetDeviceBuffer(),
+        //                        c_device_buf[i].GetDeviceBuffer()});
+
+        // printf("%p %p %p\n",
+        //        a_device_buf[i].GetDeviceBuffer(),
+        //        b_device_buf[i].GetDeviceBuffer(),
+        //        c_device_buf[i].GetDeviceBuffer());
+
+        gemm_shapes.push_back({Ms[i],
+                               Ns[i],
+                               Ks[i],
+                               StrideAs[i],
+                               StrideBs[i],
+                               StrideCs[i],
+                               a_device_buf_,
+                               b_device_buf_,
+                               c_device_buf_});
+
+        // A_size += a_m_k[i].mDesc.GetElementSpace();
+        // B_size += b_k_n[i].mDesc.GetElementSpace();
+        // C_size += c_m_n_device_results[i].mDesc.GetElementSpace();
     }

+    // a_device_buf_.ToDevice(a_tensors_data.data());
+    // b_device_buf_.ToDevice(b_tensors_data.data());
+    // c_device_buf_.ToDevice(c_tensors_data.data());
+
     // add device GEMM instances
     std::vector<ck::tensor_operation::device::device_grouped_gemm_instance::DeviceGroupedGemmNoOpPtr>
         gemm_ptrs;

     if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                  is_same<CDataType, half_t>::value)
     {
         if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                      is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&

@@ -143,7 +238,6 @@ void profile_grouped_gemm_impl(int do_verification,
         {
             ck::tensor_operation::device::device_grouped_gemm_instance::
                 add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
         }
 #if 0
     else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&

@@ -216,24 +310,15 @@ void profile_grouped_gemm_impl(int do_verification,
     float best_tflops     = 0;
     float best_gb_per_sec = 0;

-#if 0
+#if 1
     // profile device GEMM instances
     for(auto& gemm_ptr : gemm_ptrs)
     {
         auto argument_ptr =
-            gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
-                                          static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
-                                          static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                          M,
-                                          N,
-                                          K,
-                                          StrideA,
-                                          StrideB,
-                                          StrideC,
-                                          ck::tensor_operation::element_wise::PassThrough{},
-                                          ck::tensor_operation::element_wise::PassThrough{},
-                                          ck::tensor_operation::element_wise::PassThrough{},
-                                          KBatch);
+            gemm_ptr->MakeArgumentPointer(gemm_shapes,
+                                          ck::tensor_operation::element_wise::PassThrough{},
+                                          ck::tensor_operation::element_wise::PassThrough{},
+                                          ck::tensor_operation::element_wise::PassThrough{});

         auto invoker_ptr = gemm_ptr->MakeInvokerPointer();

@@ -243,6 +328,7 @@ void profile_grouped_gemm_impl(int do_verification,
         float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
+
 #if 0
         std::size_t flop = std::size_t(2) * M * N * K;
         std::size_t num_btype =

@@ -262,54 +348,36 @@ void profile_grouped_gemm_impl(int do_verification,
             best_ave_time   = ave_time;
             best_gb_per_sec = gb_per_sec;
         }
 #endif

         if(do_verification)
         {
-            c_device_buf.FromDevice(c_m_n_device_result.mData.data());
+            // c_tensors_data.resize(C_size);
+            // c_device_buf_.FromDevice(c_tensors_data.data());

-            if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
-                         is_same<BDataType, ck::bhalf_t>::value &&
-                         is_same<CDataType, ck::bhalf_t>::value)
-            {
-                Tensor<float> a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
-                Tensor<float> b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
-                Tensor<float> c_m_n_host_result(
-                    f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-                Tensor<float> c_m_n_device_f32_result(
-                    f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-
-                bf16_to_f32_(a_m_k, a_f32_m_k);
-                bf16_to_f32_(b_k_n, b_f32_k_n);
-                bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result);
-
-                using ReferenceGemmInstance = ck::tensor_operation::host::
-                    ReferenceGemm<float, float, float, AElementOp, BElementOp, CElementOp>;
-                auto ref_gemm    = ReferenceGemmInstance{};
-                auto ref_invoker = ref_gemm.MakeInvoker();
-
-                auto ref_argument = ref_gemm.MakeArgument(a_f32_m_k,
-                                                          b_f32_k_n,
-                                                          c_m_n_host_result,
-                                                          a_element_op,
-                                                          b_element_op,
-                                                          c_element_op);
-
-                ref_invoker.Run(ref_argument);
-
-                check_error(c_m_n_host_result, c_m_n_device_f32_result);
-
-                if(do_log)
-                {
-                    LogRangeAsType<float>(
-                        std::cout << "c_host  : ", c_m_n_host_result.mData, ",")
-                        << std::endl;
-                }
-            }
-            else
-            {
+            // C_size = 0;
+            // for(int i = 0; i < gemm_shapes.size(); i++)
+            //{
+            //    memcpy(c_m_n_device_results[i].mData.data(),
+            //           c_tensors_data.data() + C_size,
+            //           c_m_n_device_results[i].mDesc.GetElementSpace() * sizeof(CDataType));
+            //    C_size += c_m_n_device_results[i].mDesc.GetElementSpace();
+            //}
+
+            for(int i = 0; i < gemm_shapes.size(); i++)
+            {
+                hipGetErrorString(hipMemcpy(c_m_n_device_results[i].mData.data(),
+                                            c_device_buf[i],
+                                            sizeof(CDataType) *
+                                                c_m_n_device_results[i].mDesc.GetElementSpace(),
+                                            hipMemcpyDeviceToHost));
+
+                // hipGetErrorString(hipFree(c_device_buf[i]));
+
                 Tensor<CDataType> c_m_n_host_result(
-                    f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+                    f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}));

                 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,

@@ -322,27 +390,30 @@ void profile_grouped_gemm_impl(int do_verification,
                 auto ref_gemm    = ReferenceGemmInstance{};
                 auto ref_invoker = ref_gemm.MakeInvoker();

-                auto ref_argument = ref_gemm.MakeArgument(
-                    a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
+                auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
+                                                          b_k_n[i],
+                                                          c_m_n_host_result,
+                                                          a_element_op,
+                                                          b_element_op,
+                                                          c_element_op);

                 ref_invoker.Run(ref_argument);

-                check_error(c_m_n_host_result, c_m_n_device_result);
+                check_error(c_m_n_host_result, c_m_n_device_results[i]);

                 if(do_log)
                 {
                     // LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
                     //<< std::endl;
                     // LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") <<
                     // std::endl;
-                    LogRangeAsType<float>(
-                        std::cout << "c_host  : ", c_m_n_host_result.mData, ",")
-                        << std::endl;
+                    LogRangeAsType<float>(
+                        std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
+                        << std::endl;
+                    // LogRangeAsType<float>(
+                    //     std::cout << "c_host  : ", c_m_n_host_result.mData, ",")
+                    //<< std::endl;
                 }
             }
-
-        if(do_log)
-        {
-            LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
-            LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
-            LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
-                << std::endl;
-        }
     }
     else
profiler/src/profile_grouped_gemm.cpp

@@ -26,7 +26,7 @@ enum GemmDataType
     INT8_INT8_INT8, // 3
 };

-std::vector<int> stringToArray(char *input)
+std::vector<int> stringToArray(char* input)
 {
     std::vector<int> out;

@@ -34,7 +34,8 @@ std::vector<int> stringToArray(char* input)
     std::string item;
-    while(std::getline(in, item, ',')) {
+    while(std::getline(in, item, ','))
+    {
         out.push_back(std::stoi(item));
     }

@@ -69,30 +70,33 @@ int profile_grouped_gemm(int argc, char* argv[])
     const auto Ms = stringToArray(argv[8]);
     const auto Ns = stringToArray(argv[9]);
     const auto Ks = stringToArray(argv[10]);

     const auto StrideAs = stringToArray(argv[11]);
     const auto StrideBs = stringToArray(argv[12]);
     const auto StrideCs = stringToArray(argv[13]);

+    for(int i = 0; i < Ms.size(); i++)
+    {
+        std::cout << "M: " << Ms[i] << " N: " << Ns[i] << " K: " << Ks[i] << std::endl;
+    }
+
     if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
     {
         ck::profiler::profile_grouped_gemm_impl<ck::half_t,
                                                 ck::half_t,
                                                 ck::half_t,
                                                 ck::tensor_layout::gemm::RowMajor,
                                                 ck::tensor_layout::gemm::RowMajor,
                                                 ck::tensor_layout::gemm::RowMajor>(
             do_verification, init_method, do_log, nrepeat, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
     }
 #if 0
     else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
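Reviewer note: the profiler reads Ms, Ns, Ks, StrideAs, StrideBs and StrideCs as comma-separated lists through stringToArray (argv[8]..argv[13] above). The standalone sketch below mirrors that parsing so the expected argument format is easy to see; the const-qualified parameter and the main() are illustrative, the parsing logic matches the function in profiler/src/profile_grouped_gemm.cpp.

// Sketch: comma-separated list parsing as used for the grouped-GEMM profiler arguments.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

std::vector<int> stringToArray(const char* input)
{
    std::vector<int> out;
    std::istringstream in(input);
    std::string item;
    while(std::getline(in, item, ','))
    {
        out.push_back(std::stoi(item));
    }
    return out;
}

int main()
{
    // e.g. the Ms argument: one GEMM size per group
    for(int m : stringToArray("256,512,768,1024"))
        std::cout << m << "\n";
    return 0;
}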