gaoqiong / composable_kernel / Commits / 2066a3d4
"vscode:/vscode.git/clone" did not exist on "6790b8f3cc4fd32a9d9a43c6c9d80b826d969980"
Commit 2066a3d4, authored Nov 20, 2021 by Chao Liu

    add pointwise operation to A/B matrix

Parent: 496e2ec6

Showing 8 changed files with 201 additions and 72 deletions (+201 / -72).
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp        +13  -6
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp              +40  -16
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp       +8   -7
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp  +24  -9
device_operation/include/device_gemm.hpp                                              +10  -3
device_operation/include/device_gemm_xdl.hpp                                          +48  -10
example/1_gemm_xdl/gemm_xdl.cpp                                                       +48  -19
host/host_tensor/include/host_gemm.hpp                                                +10  -2
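The common thread of the diffs below is plumbing per-matrix pointwise ("elementwise") operations through every layer of the GEMM pipeline, from the public DeviceGemm interface down to the per-thread tile copy. As a rough standalone model of the pattern (not code from this commit; names are illustrative), the A/B operations act on each element while it is being loaded:

    #include <cstddef>

    // A user-supplied functor is applied to each source element on the load
    // path; the destination only ever sees transformed values.
    template <typename T, typename SrcElementwiseOperation>
    void copy_with_src_op(const T* src, T* dst, std::size_t n,
                          const SrcElementwiseOperation& src_element_op)
    {
        for(std::size_t i = 0; i < n; ++i)
            dst[i] = src_element_op(src[i]);
    }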
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp

@@ -14,6 +14,7 @@ namespace ck {
 // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
 // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
 template <index_t BlockSize,
+          typename SrcElementwiseOperation,
           InMemoryDataOperationEnum_t DstInMemOp,
           typename BlockSliceLengths,
           typename ThreadSliceLengths,

@@ -39,12 +40,17 @@ struct BlockwiseTensorSliceTransfer_v4
     using Index = MultiIndex<nDim>;

-    __device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc,
-                                                         const Index& src_block_slice_origin,
-                                                         const DstDesc& dst_desc,
-                                                         const Index& dst_block_slice_origin)
-        : threadwise_transfer_(
-              src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
+    __device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc,
+                                                         const Index& src_block_slice_origin,
+                                                         const DstDesc& dst_desc,
+                                                         const Index& dst_block_slice_origin,
+                                                         const SrcElementwiseOperation& src_element_op)
+        : threadwise_transfer_(src_desc,
+                               make_zero_multi_index<nDim>(),
+                               dst_desc,
+                               make_zero_multi_index<nDim>(),
+                               src_element_op)
     {
         static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&

@@ -147,6 +153,7 @@ struct BlockwiseTensorSliceTransfer_v4
     using ThreadwiseTransfer =
         ThreadwiseTensorSliceTransfer_v3r2<ThreadSliceLengths,
+                                           SrcElementwiseOperation,
                                            DstInMemOp,
                                            SrcData,
                                            DstData,
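Note that the blockwise transfer never applies src_element_op itself; it only forwards the functor into the threadwise transfer that performs the actual loads. A minimal self-contained stand-in for that forwarding pattern (illustrative names, not CK code):

    // Innermost per-thread copy: the only layer that actually calls the op.
    template <typename SrcElementwiseOperation>
    struct ThreadwiseTransferModel
    {
        SrcElementwiseOperation src_element_op_;

        float load(float v) const { return src_element_op_(v); }
    };

    // Blockwise layer: applies nothing itself, just forwards the functor down.
    template <typename SrcElementwiseOperation>
    struct BlockwiseTransferModel
    {
        explicit BlockwiseTransferModel(const SrcElementwiseOperation& src_element_op)
            : threadwise_transfer_{src_element_op}
        {
        }

        ThreadwiseTransferModel<SrcElementwiseOperation> threadwise_transfer_;
    };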
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp

@@ -19,9 +19,11 @@ template <typename GridwiseGemm,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K1,
           typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
           typename Block2CTileMap,
-          bool HasMainKBlockLoop,
-          typename CElementwiseOperation>
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)

@@ -33,8 +35,10 @@ __global__ void
             const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
             const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
             const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-            const Block2CTileMap block_2_ctile_map,
-            const CElementwiseOperation c_element_op)
+            const AElementwiseOperation a_element_op,
+            const BElementwiseOperation b_element_op,
+            const CElementwiseOperation c_element_op,
+            const Block2CTileMap block_2_ctile_map)
 {
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

@@ -48,8 +52,10 @@ __global__ void
                                                   a_grid_desc_k0_m_k1,
                                                   b_grid_desc_k0_n_k1,
                                                   c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                                  block_2_ctile_map,
-                                                  c_element_op);
+                                                  a_element_op,
+                                                  b_element_op,
+                                                  c_element_op,
+                                                  block_2_ctile_map);
 }
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
 template <typename GridwiseGemm,

@@ -58,8 +64,10 @@ template <typename GridwiseGemm,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K1,
           typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
-          typename Block2CTileMap,
-          typename CElementwiseOperation>
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
+          typename Block2CTileMap>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)

@@ -70,8 +78,10 @@ __global__ void
             const void CONSTANT* p_a_grid_desc_k0_m_k1,
             const void CONSTANT* p_b_grid_desc_k0_n_k1,
             const void CONSTANT* p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-            const void CONSTANT* p_block_2_ctile_map,
-            const void CONSTANT* p_c_element_op)
+            const void CONSTANT* p_a_element_op,
+            const void CONSTANT* p_b_element_op,
+            const void CONSTANT* p_c_element_op,
+            const void CONSTANT* p_block_2_ctile_map)
 {
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

@@ -85,6 +95,10 @@ __global__ void
         cast_pointer_to_generic_address_space(p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2));
     const auto block_2_ctile_map = *reinterpret_cast<const Block2CTileMap*>(
         cast_pointer_to_generic_address_space(p_block_2_ctile_map));
+    const auto a_element_op = *reinterpret_cast<const AElementwiseOperation*>(
+        cast_pointer_to_generic_address_space(p_a_element_op));
+    const auto b_element_op = *reinterpret_cast<const BElementwiseOperation*>(
+        cast_pointer_to_generic_address_space(p_b_element_op));
     const auto c_element_op = *reinterpret_cast<const CElementwiseOperation*>(
         cast_pointer_to_generic_address_space(p_c_element_op));

@@ -97,8 +111,10 @@ __global__ void
                                                   a_grid_desc_k0_m_k1,
                                                   b_grid_desc_k0_n_k1,
                                                   c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                                  block_2_ctile_map,
-                                                  c_element_op);
+                                                  a_element_op,
+                                                  b_element_op,
+                                                  c_element_op,
+                                                  block_2_ctile_map);
 }
 #endif

@@ -110,6 +126,8 @@ template <index_t BlockSize,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K1,
           typename CGridDesc_M_N,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
           typename CElementwiseOperation,
           index_t MPerBlock,
           index_t NPerBlock,

@@ -362,8 +380,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                                const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                                const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                                const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                               const Block2CTileMap& block_2_ctile_map,
-                               const CElementwiseOperation& c_element_op)
+                               const AElementwiseOperation& a_element_op,
+                               const BElementwiseOperation& b_element_op,
+                               const CElementwiseOperation& c_element_op,
+                               const Block2CTileMap& block_2_ctile_map)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());

@@ -421,6 +441,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // A matrix blockwise copy
         auto a_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4<BlockSize,
+                                            AElementwiseOperation,
                                             InMemoryDataOperationEnum_t::Set,
                                             Sequence<K0PerBlock, MPerBlock, K1>,
                                             ABlockTransferThreadSliceLengths_K0_M_K1,

@@ -442,11 +463,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                 true>(a_grid_desc_k0_m_k1,
                       make_multi_index(0, m_block_data_idx_on_grid, 0),
                       a_block_desc_k0_m_k1,
-                      make_multi_index(0, 0, 0));
+                      make_multi_index(0, 0, 0),
+                      a_element_op);

         // B matrix blockwise copy
         auto b_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4<BlockSize,
+                                            BElementwiseOperation,
                                             InMemoryDataOperationEnum_t::Set,
                                             Sequence<K0PerBlock, NPerBlock, K1>,
                                             BBlockTransferThreadSliceLengths_K0_N_K1,

@@ -468,7 +491,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                 true>(b_grid_desc_k0_n_k1,
                       make_multi_index(0, n_block_data_idx_on_grid, 0),
                       b_block_desc_k0_n_k1,
-                      make_multi_index(0, 0, 0));
+                      make_multi_index(0, 0, 0),
+                      b_element_op);

         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
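In the CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER path above, each new element-op argument follows the same convention as the tensor descriptors: it crosses the kernel boundary as a type-erased constant pointer and is rematerialized by value inside the kernel (the real code first moves the pointer to the generic address space with cast_pointer_to_generic_address_space). A hedged, host-side model of that round trip, without the CONSTANT address-space qualifier:

    // Each argument arrives type-erased and is copied back out by value.
    template <typename Op>
    Op load_op_from_erased_pointer(const void* p_op)
    {
        return *reinterpret_cast<const Op*>(p_op);
    }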
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp

@@ -50,7 +50,7 @@ template <typename SrcData,
           typename DstData,
           typename SrcDesc,
           typename DstDesc,
-          typename ElementwiseOp,
+          typename SrcElementwiseOperation,
           typename SliceLengths,
           typename DimAccessOrder,
           index_t DstVectorDim,

@@ -69,11 +69,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
     using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

-    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
-                                                            const Index& dst_slice_origin_idx,
-                                                            const ElementwiseOp element_op)
+    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(
+        const DstDesc& dst_desc,
+        const Index& dst_slice_origin_idx,
+        const SrcElementwiseOperation src_element_op)
         : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
-          element_op_{element_op}
+          src_element_op_{src_element_op}
     {
         static_assert(SrcDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc need to known at compile-time");

@@ -200,7 +201,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
                 // apply element-wise operation and type convert
                 dst_vector.template AsType<DstData>()(i) =
-                    type_convert<DstData>(element_op_(src_buf[Number<src_offset>{}]));
+                    type_convert<DstData>(src_element_op_(src_buf[Number<src_offset>{}]));
             });

             const bool is_dst_valid =

@@ -377,7 +378,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
     private:
     DstCoord dst_coord_;
-    ElementwiseOp element_op_;
+    SrcElementwiseOperation src_element_op_;
 };

 // Assume:
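The v1r3 change is mostly a rename (ElementwiseOp to SrcElementwiseOperation, element_op_ to src_element_op_), but the "apply element-wise operation and type convert" line spells out the contract each destination scalar obeys. A standalone restatement, with static_cast standing in for CK's type_convert:

    // Op first, then conversion to the destination type.
    template <typename DstData, typename SrcData, typename SrcElementwiseOperation>
    DstData transfer_one_element(SrcData s, const SrcElementwiseOperation& src_element_op)
    {
        return static_cast<DstData>(src_element_op(s));
    }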
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp

@@ -46,6 +46,7 @@ struct lambda_scalar_per_access_for_src_and_dst
 // 3. src_slice_origin and dst_slice_origin are not known at compile-time,
 // 4. Use thread buffer
 template <typename SliceLengths,
+          typename SrcElementwiseOperation,
           InMemoryDataOperationEnum_t DstInMemOp,
           typename SrcData,
           typename DstData,

@@ -76,12 +77,15 @@ struct ThreadwiseTensorSliceTransfer_v3r2
     using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
     using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

-    __device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(const SrcDesc& src_desc,
-                                                            const Index& src_slice_origin,
-                                                            const DstDesc& dst_desc,
-                                                            const Index& dst_slice_origin)
+    __device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(
+        const SrcDesc& src_desc,
+        const Index& src_slice_origin,
+        const DstDesc& dst_desc,
+        const Index& dst_slice_origin,
+        const SrcElementwiseOperation& src_element_op)
         : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
-          dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
+          dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
+          src_element_op_(src_element_op)
     {
     }

@@ -191,12 +195,22 @@ struct ThreadwiseTensorSliceTransfer_v3r2
             const bool is_src_valid =
                 coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);

-            using src_vector_t = typename vector_type_maker_t<SrcData, SrcScalarPerVector>::type;
+            using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
+            using src_vector_t    = typename src_vector_type::type;

-            // copy data from src_buf to src_thread_scratch_
+            // copy data from src_buf into src_vector_container
+            auto src_vector_container = src_vector_type{
+                src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
+
+            // apply SrcElementwiseOperation on src_vector_container
+            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                src_vector_container.template AsType<SrcData>()(i) =
+                    src_element_op_(src_vector_container.template AsType<SrcData>()[i]);
+            });
+
+            // copy data from src_vector_container into src_thread_scratch_
             src_thread_scratch_.template SetAsType<src_vector_t>(
-                src_data_idx_seq,
-                src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid));
+                src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);

             constexpr auto move_on_dim = [&]() constexpr
             {

@@ -796,6 +810,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
     SrcCoord src_coord_;
     DstCoord dst_coord_;
+    SrcElementwiseOperation src_element_op_;
 };
 } // namespace ck
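The v3r2 rewrite is the heart of the commit: instead of forwarding the raw vector load straight into the thread scratch buffer, the transfer now stages the vector in a container, applies the source op lane by lane, and only then stores the whole vector. A plain-array model of that three-step sequence (CK's vector_type and static_for replaced by arrays and loops; names illustrative):

    template <typename SrcData, int SrcScalarPerVector, typename SrcElementwiseOperation>
    void load_vector_with_op(const SrcData* src,
                             SrcData (&scratch)[SrcScalarPerVector],
                             const SrcElementwiseOperation& src_element_op)
    {
        SrcData v[SrcScalarPerVector];

        // 1. one (conceptually vectorized) load from the source buffer
        for(int i = 0; i < SrcScalarPerVector; ++i)
            v[i] = src[i];

        // 2. apply SrcElementwiseOperation on each lane of the container
        for(int i = 0; i < SrcScalarPerVector; ++i)
            v[i] = src_element_op(v[i]);

        // 3. one vector store into the thread scratch
        for(int i = 0; i < SrcScalarPerVector; ++i)
            scratch[i] = v[i];
    }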
device_operation/include/device_gemm.hpp

@@ -8,7 +8,9 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-template <typename CElementwiseOperation>
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
 struct DeviceGemm : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument>

@@ -21,13 +23,18 @@ struct DeviceGemm : public BaseOperator
                         ck::index_t StrideA,
                         ck::index_t StrideB,
                         ck::index_t StrideC,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
                         CElementwiseOperation c_element_op) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

-template <typename CElementwiseOperation>
-using DeviceGemmPtr = std::unique_ptr<DeviceGemm<CElementwiseOperation>>;
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+using DeviceGemmPtr = std::unique_ptr<
+    DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;

 } // namespace device
 } // namespace tensor_operation
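With the widened interface, every caller now supplies three functors. A hedged usage sketch of the call shape (gemm_ptr is an assumed DeviceGemmPtr<Equal, Equal, Relu>, using the functors defined later in example/1_gemm_xdl/gemm_xdl.cpp; the surrounding pointers and sizes are caller-side state):

    // auto argument_ptr = gemm_ptr->MakeArgumentPointer(p_a, p_b, p_c,
    //                                                   M, N, K,
    //                                                   StrideA, StrideB, StrideC,
    //                                                   Equal{},   // applied to A elements
    //                                                   Equal{},   // applied to B elements
    //                                                   Relu{});   // applied to C results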
device_operation/include/device_gemm_xdl.hpp

@@ -22,6 +22,8 @@ template <typename ADataType,
           typename ALayout,
           typename BLayout,
           typename CLayout,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
           typename CElementwiseOperation,
           ck::index_t BlockSize,
           ck::index_t MPerBlock,

@@ -50,7 +52,8 @@ template <typename ADataType,
           ck::index_t CThreadTransferDstScalarPerVector,
           bool ABlockLdsAddExtraM,
           bool BBlockLdsAddExtraN>
-struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
+struct DeviceGemmXdl
+    : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

@@ -177,6 +180,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
+        AElementwiseOperation,
+        BElementwiseOperation,
         CElementwiseOperation,
         MPerBlock,
         NPerBlock,

@@ -233,6 +238,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
                  index_t StrideC,
                  index_t M01,
                  index_t N01,
+                 AElementwiseOperation a_element_op,
+                 BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},

@@ -244,6 +251,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
               block_2_ctile_map_{},
               M01_{M01},
               N01_{N01},
+              a_element_op_{a_element_op},
+              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
         {
             a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);

@@ -271,6 +280,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
         Block2CTileMap block_2_ctile_map_;
         index_t M01_;
         index_t N01_;
+        AElementwiseOperation a_element_op_;
+        BElementwiseOperation b_element_op_;
         CElementwiseOperation c_element_op_;
     };

@@ -321,9 +332,11 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
                     remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
                     remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
                     remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    AElementwiseOperation,
+                    BElementwiseOperation,
+                    CElementwiseOperation,
                     remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
-                    true,
-                    CElementwiseOperation>;
+                    true>;

                 ave_time = launch_and_time_kernel(kernel,
                                                   nrepeat,

@@ -336,8 +349,10 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
                                                   arg.a_grid_desc_k0_m_k1_,
                                                   arg.b_grid_desc_k0_n_k1_,
                                                   arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                  arg.block_2_ctile_map_,
-                                                  arg.c_element_op_);
+                                                  arg.a_element_op_,
+                                                  arg.b_element_op_,
+                                                  arg.c_element_op_,
+                                                  arg.block_2_ctile_map_);
             }
             else
             {

@@ -348,9 +363,11 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
                     remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
                     remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
                     remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    AElementwiseOperation,
+                    BElementwiseOperation,
+                    CElementwiseOperation,
                     remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
-                    false,
-                    CElementwiseOperation>;
+                    false>;

                 ave_time = launch_and_time_kernel(kernel,
                                                   nrepeat,

@@ -363,8 +380,10 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
                                                   arg.a_grid_desc_k0_m_k1_,
                                                   arg.b_grid_desc_k0_n_k1_,
                                                   arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                  arg.block_2_ctile_map_,
-                                                  arg.c_element_op_);
+                                                  arg.a_element_op_,
+                                                  arg.b_element_op_,
+                                                  arg.c_element_op_,
+                                                  arg.block_2_ctile_map_);
             }

             return ave_time;

@@ -407,9 +426,24 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
                              index_t StrideA,
                              index_t StrideB,
                              index_t StrideC,
+                             AElementwiseOperation a_element_op,
+                             BElementwiseOperation b_element_op,
                              CElementwiseOperation c_element_op)
     {
-        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, 1, 1, c_element_op};
+        return Argument{p_a,
+                        p_b,
+                        p_c,
+                        M,
+                        N,
+                        K,
+                        StrideA,
+                        StrideB,
+                        StrideC,
+                        1,
+                        1,
+                        a_element_op,
+                        b_element_op,
+                        c_element_op};
     }

     static auto MakeInvoker() { return Invoker{}; }

@@ -424,6 +458,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
                         index_t StrideA,
                         index_t StrideB,
                         index_t StrideC,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
                         CElementwiseOperation c_element_op) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),

@@ -437,6 +473,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
                                           StrideC,
                                           1,
                                           1,
+                                          a_element_op,
+                                          b_element_op,
                                           c_element_op);
     }
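One side effect visible in the Run() hunks above: the element-op template arguments now precede Block2CTileMap in the kernel's parameter list, and the bool HasMainKBlockLoop moved to the end, so the two kernel instantiations differ only in their final true/false argument. A minimal model of that compile-time dispatch (illustrative, not CK code):

    template <bool HasMainKBlockLoop>
    void run_gemm_model()
    {
        if constexpr(HasMainKBlockLoop)
        {
            // main K loop plus tail iteration
        }
        else
        {
            // tail iteration only
        }
    }

    void dispatch_model(bool has_main_k_block_loop)
    {
        // runtime flag selects one of two fully specialized instantiations
        if(has_main_k_block_loop)
            run_gemm_model<true>();
        else
            run_gemm_model<false>();
    }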
example/1_gemm_xdl/gemm_xdl.cpp

@@ -14,13 +14,22 @@
 #include "device_base.hpp"
 #include "device_gemm_xdl.hpp"

-struct Activation
+struct Equal
+{
+    template <typename T>
+    __host__ __device__ constexpr T operator()(T v) const
+    {
+        return v;
+    }
+};
+
+struct Relu
 {
     float alpha = 0.1; // ReLU

     template <typename T>
-    __host__ __device__ T operator()(T v) const
+    __host__ __device__ constexpr T operator()(T v) const
     {
         T tmp = alpha * v;
         return tmp > 0 ? tmp : 0;

@@ -33,16 +42,22 @@ template <typename ADataType,
           typename ALayout,
           typename BLayout,
           typename CLayout,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
           typename CElementwiseOperation>
 struct DeviceGemmInstance;

-template <typename CElementwiseOperation>
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
 struct DeviceGemmInstance<ck::half_t,
                           ck::half_t,
                           ck::half_t,
                           ck::tensor_layout::gemm::RowMajor,
                           ck::tensor_layout::gemm::ColumnMajor,
                           ck::tensor_layout::gemm::RowMajor,
+                          AElementwiseOperation,
+                          BElementwiseOperation,
                           CElementwiseOperation>
 {
     using F16 = ck::half_t;

@@ -54,24 +69,32 @@ struct DeviceGemmInstance<ck::half_t,
     template <ck::index_t... Is>
     using S = ck::Sequence<Is...>;

+    using AOp = AElementwiseOperation;
+    using BOp = BElementwiseOperation;
+    using COp = CElementwiseOperation;
+
     // Compilation parameters for NT problem
     // clang-format off
     using type =
-        //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| CElementwiseOperation| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
-        //########################################|  Type|  Type|  Type|    Type|        |        |        |                      |  Size| Block| Block| Block|   | XDL| XDL|  Per|  Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
-        //########################################|      |      |      |        |        |        |        |                      |      |      |      |      |   |    |    | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
-        //########################################|      |      |      |        |        |        |        |                      |      |      |      |      |   |    |    |     |     |                |                |             | | |          |             |                |                |             | | |          |             | |          | | |
-        ck::tensor_operation::device::DeviceGemmXdl<F16, F16, F16, F32, Row, Col, Row, CElementwiseOperation, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
+        //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
+        //########################################|  Type|  Type|  Type|    Type|        |        |        |    Operation|    Operation|    Operation|  Size| Block| Block| Block|   | XDL| XDL|  Per|  Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
+        //########################################|      |      |      |        |        |        |        |             |             |             |      |      |      |      |   |    |    | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
+        //########################################|      |      |      |        |        |        |        |             |             |             |      |      |      |      |   |    |    |     |     |                |                |             | | |          |             |                |                |             | | |          |             | |          | | |
+        ck::tensor_operation::device::DeviceGemmXdl<F16, F16, F16, F32, Row, Col, Row, AOp, BOp, COp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
     // clang-format on
 };

-template <typename CElementwiseOperation>
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
 struct DeviceGemmInstance<float,
                           float,
                           float,
                           ck::tensor_layout::gemm::RowMajor,
                           ck::tensor_layout::gemm::ColumnMajor,
                           ck::tensor_layout::gemm::RowMajor,
+                          AElementwiseOperation,
+                          BElementwiseOperation,
                           CElementwiseOperation>
 {
     using F16 = ck::half_t;

@@ -83,14 +106,18 @@ struct DeviceGemmInstance<float,
     template <ck::index_t... Is>
     using S = ck::Sequence<Is...>;

+    using AOp = AElementwiseOperation;
+    using BOp = BElementwiseOperation;
+    using COp = CElementwiseOperation;
+
     // Compilation parameters for NT problem
     // clang-format off
     using type =
-        //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| CElementwiseOperation| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
-        //########################################|  Type|  Type|  Type|    Type|        |        |        |                      |  Size| Block| Block| Block|   | XDL| XDL|  Per|  Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
-        //########################################|      |      |      |        |        |        |        |                      |      |      |      |      |   |    |    | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
-        //########################################|      |      |      |        |        |        |        |                      |      |      |      |      |   |    |    |     |     |                |                |             | | |          |             |                |                |             | | |          |             | |          | | |
-        ck::tensor_operation::device::DeviceGemmXdl<F32, F32, F32, F32, Row, Col, Row, CElementwiseOperation, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>;
+        //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
+        //########################################|  Type|  Type|  Type|    Type|        |        |        |    Operation|    Operation|    Operation|  Size| Block| Block| Block|   | XDL| XDL|  Per|  Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
+        //########################################|      |      |      |        |        |        |        |             |             |             |      |      |      |      |   |    |    | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
+        //########################################|      |      |      |        |        |        |        |             |             |             |      |      |      |      |   |    |    |     |     |                |                |             | | |          |             |                |                |             | | |          |             | |          | | |
+        ck::tensor_operation::device::DeviceGemmXdl<F32, F32, F32, F32, Row, Col, Row, AOp, BOp, COp, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>;
     // clang-format on
 };

@@ -177,9 +204,9 @@ int main(int argc, char* argv[])
                                      ALayout,
                                      BLayout,
                                      CLayout,
-                                     Activation>::type{};
-
-    auto activation = Activation{};
+                                     Equal,
+                                     Equal,
+                                     Relu>::type{};

     auto invoker  = gemm.MakeInvoker();
     auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),

@@ -191,7 +218,9 @@ int main(int argc, char* argv[])
                                       StrideA,
                                       StrideB,
                                       StrideC,
-                                      activation);
+                                      Equal{},
+                                      Equal{},
+                                      Relu{});

     if(!gemm.IsSupportedArgument(argument))
     {

@@ -217,7 +246,7 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-        host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result, activation);
+        host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result, Equal{}, Equal{}, Relu{});

         check_error(c_m_n_host_result, c_m_n_device_result);
     }
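Worth noting about the functors above: Relu scales by alpha = 0.1 before clamping, so it computes max(alpha * v, 0) rather than the textbook max(v, 0). A quick host-only sanity check of both functors (assumes the Equal and Relu definitions from this file are in scope):

    #include <cassert>

    int check_pointwise_ops()
    {
        Equal equal;
        Relu relu; // alpha = 0.1

        assert(equal(3.0f) == 3.0f);   // identity
        assert(relu(10.0f) > 0.0f);    // positive input: scaled by alpha, kept
        assert(relu(-10.0f) == 0.0f);  // negative input: clamped to zero
        return 0;
    }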
host/host_tensor/include/host_gemm.hpp

 #pragma once
 #include "host_tensor.hpp"

-template <typename AType, typename BType, typename CType, typename CElementwiseOperation>
+template <typename AType,
+          typename BType,
+          typename CType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
 void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
                         const Tensor<BType>& b_k_n,
                         Tensor<CType>& c_m_n,
+                        const AElementwiseOperation& a_element_op,
+                        const BElementwiseOperation& b_element_op,
                         const CElementwiseOperation& c_element_op)
 {
     auto f_mk_kn_mn = [&](auto m, auto n) {

@@ -14,7 +21,8 @@ void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
         for(int k = 0; k < K; ++k)
         {
-            v += static_cast<const double>(a_m_k(m, k)) * static_cast<const double>(b_k_n(k, n));
+            v += static_cast<const double>(a_element_op(a_m_k(m, k))) *
+                 static_cast<const double>(b_element_op(b_k_n(k, n)));
         }

         c_m_n(m, n) = c_element_op(v);
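The reference check now mirrors the device semantics exactly: the A and B ops act on the inputs of every multiply, and the C op acts once on the accumulated sum, i.e. c(m, n) = c_op(sum_k a_op(a(m, k)) * b_op(b(k, n))). A compact standalone restatement for a single output element (assumes contiguous row and column views of A and B):

    template <typename AOp, typename BOp, typename COp>
    double host_gemm_entry(const double* a_row, const double* b_col, int K,
                           const AOp& a_op, const BOp& b_op, const COp& c_op)
    {
        double v = 0.0;
        for(int k = 0; k < K; ++k)
            v += a_op(a_row[k]) * b_op(b_col[k]); // input ops per multiply
        return c_op(v);                           // output op on the sum
    }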