Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
51fc99a8
Commit
51fc99a8
authored
Aug 11, 2022
by
Anthony Chang
Browse files
adds acc0 elementwise op to interface
parent
8672733f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
50 additions
and
13 deletions
+50
-13
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp
+12
-8
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp
.../tensor_operation/gpu/device/device_batched_gemm_gemm.hpp
+4
-0
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
...tion/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
+17
-2
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
...n/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
+4
-1
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+13
-2
No files found.
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp
View file @
51fc99a8
...
...
@@ -48,10 +48,11 @@ using B0Layout = Col;
using
B1Layout
=
Row
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
B1ElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
AElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
PassThrough
;
using
B1ElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
...
...
@@ -68,6 +69,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmDefault
,
...
...
@@ -287,10 +289,11 @@ int main(int argc, char* argv[])
b0_g_k_n_device_buf
.
ToDevice
(
b0_g_k_n
.
mData
.
data
());
b1_g_n_o_device_buf
.
ToDevice
(
b1_g_n_o
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b0_element_op
=
B0ElementOp
{};
auto
b1_element_op
=
B1ElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
a_element_op
=
AElementOp
{};
auto
b0_element_op
=
B0ElementOp
{};
auto
acc0_element_op
=
Acc0ElementOp
{};
auto
b1_element_op
=
B1ElementOp
{};
auto
c_element_op
=
CElementOp
{};
// do GEMM
auto
gemm
=
DeviceGemmInstance
{};
...
...
@@ -315,6 +318,7 @@ int main(int argc, char* argv[])
BatchStrideC
,
a_element_op
,
b0_element_op
,
acc0_element_op
,
b1_element_op
,
c_element_op
);
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp
View file @
51fc99a8
...
...
@@ -22,6 +22,7 @@ template <typename ALayout,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
Acc0ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceBatchedGemmGemm
:
public
BaseOperator
...
...
@@ -46,6 +47,7 @@ struct DeviceBatchedGemmGemm : public BaseOperator
ck
::
index_t
BatchStrideC
,
AElementwiseOperation
a_element_op
,
B0ElementwiseOperation
b0_element_op
,
Acc0ElementwiseOperation
acc0_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
=
0
;
...
...
@@ -62,6 +64,7 @@ template <typename ALayout,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
Acc0ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceBatchedGemmGemmPtr
=
std
::
unique_ptr
<
DeviceBatchedGemmGemm
<
ALayout
,
...
...
@@ -74,6 +77,7 @@ using DeviceBatchedGemmGemmPtr = std::unique_ptr<DeviceBatchedGemmGemm<ALayout,
CDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
Acc0ElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
>>
;
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
View file @
51fc99a8
...
...
@@ -25,6 +25,7 @@ template <typename GridwiseGemm,
typename
FloatC
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
...
...
@@ -45,6 +46,7 @@ __global__ void
FloatC
*
__restrict__
p_c_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
AccElementwiseOperation
acc_element_op
,
const
B1ElementwiseOperation
b1_element_op
,
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
...
...
@@ -78,6 +80,7 @@ __global__ void
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
...
...
@@ -92,6 +95,7 @@ __global__ void
ignore
=
p_c_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
acc_element_op
;
ignore
=
b1_element_op
;
ignore
=
c_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
...
...
@@ -119,6 +123,7 @@ template <typename ALayout,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
GemmSpecialization
GemmSpec
,
...
...
@@ -173,6 +178,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
>
{
...
...
@@ -550,6 +556,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
...
...
@@ -624,6 +631,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
index_t
BatchStrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
:
p_a_grid_
{
p_a_grid
},
...
...
@@ -639,6 +647,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
batch_count_
(
Batch
),
...
...
@@ -670,6 +679,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
AccElementwiseOperation
acc_element_op_
;
B1ElementwiseOperation
b1_element_op_
;
CElementwiseOperation
c_element_op_
;
index_t
batch_count_
;
...
...
@@ -708,6 +718,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
...
...
@@ -729,6 +740,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
arg
.
p_c_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
acc_element_op_
,
arg
.
b1_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
...
...
@@ -807,14 +819,15 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
index_t
BatchStrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_b1
,
p_c
,
MRaw
,
NRaw
,
KRaw
,
Gemm1NRaw
,
Batch
,
StrideA
,
StrideB
,
StrideB1
,
StrideC
,
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
BatchStrideC
,
a_element_op
,
b_element_op
,
b1
_element_op
,
c_element_op
};
BatchStrideB1
,
BatchStrideC
,
a_element_op
,
b_element_op
,
acc
_element_op
,
b1_element_op
,
c_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -839,6 +852,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
index_t
BatchStrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
override
{
...
...
@@ -861,6 +875,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
BatchStrideC
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
);
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
View file @
51fc99a8
...
...
@@ -23,6 +23,7 @@ template <typename FloatAB,
typename
FloatC
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
...
...
@@ -319,6 +320,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
AccElementwiseOperation
&
acc_element_op
,
const
B1ElementwiseOperation
&
b1_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
...
...
@@ -544,10 +546,11 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
FloatAB
,
decltype
(
acc_thread_desc_k0_m_k1
),
decltype
(
a1_thread_desc_k0_m_k1
),
decltype
(
acc_element_op
),
Sequence
<
A1ThreadSliceK0
,
A1ThreadSliceM
,
A1ThreadSliceK1
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
n4
>
{};
n4
>
{
acc_element_op
};
// B1 matrix blockwise copy
auto
b1_blockwise_copy
=
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
51fc99a8
...
...
@@ -1202,6 +1202,7 @@ template <typename SrcData,
typename
DstData
,
typename
SrcDesc
,
typename
DstDesc
,
typename
ElementwiseOperation
,
typename
SliceLengths
,
typename
DimAccessOrder
,
index_t
DstVectorDim
,
...
...
@@ -1214,7 +1215,9 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
using
Index
=
MultiIndex
<
nDim
>
;
__device__
constexpr
ThreadwiseTensorSliceTransfer_StaticToStatic
()
__device__
constexpr
ThreadwiseTensorSliceTransfer_StaticToStatic
(
const
ElementwiseOperation
&
element_op
)
:
element_op_
{
element_op
}
{
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc need to known at compile-time"
);
...
...
@@ -1277,10 +1280,18 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_slice_origin_idx
+
idx_md
+
i
*
dst_scalar_step_in_vector
);
dst_buf
(
Number
<
dst_offset
>
{})
=
src_buf
[
Number
<
src_offset
>
{}];
SrcData
v
;
// apply element-wise operation
element_op_
(
v
,
src_buf
[
Number
<
src_offset
>
{}]);
// apply type convert
dst_buf
(
Number
<
dst_offset
>
{})
=
type_convert
<
DstData
>
(
v
);
});
});
}
ElementwiseOperation
element_op_
;
};
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment