Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
208ac1a5
Commit
208ac1a5
authored
May 18, 2022
by
myamlak
Browse files
Consuming binary ops to do A+B / A-B
parent
5e104742
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
166 additions
and
4 deletions
+166
-4
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
..._operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
+133
-4
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+33
-0
No files found.
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
208ac1a5
...
@@ -9,6 +9,8 @@
...
@@ -9,6 +9,8 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "binary_element_wise_operation.hpp"
#include "gridwise_binary_elementwise_1d.hpp"
#include "tensor_operation/gpu/device/gemm_specialization.hpp"
#include "tensor_operation/gpu/device/gemm_specialization.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -66,6 +68,41 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -66,6 +68,41 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
ScalarPerVector
=
Number
<
4
>
{};
template
<
typename
Desc_M0
>
static
auto
PadDescriptor_M0_1d
(
Desc_M0
desc_m0
,
index_t
gridSize
,
index_t
threadPerBlock
)
{
const
auto
m0
=
desc_m0
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
threadPerBlock
*
ScalarPerVector
;
const
auto
pad
=
math
::
integer_least_multiple
(
m0
,
loop_step
)
-
m0
;
const
auto
desc_m0_pad
=
transform_tensor_descriptor
(
desc_m0
,
make_tuple
(
make_right_pad_transform
(
m0
,
pad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
desc_m0_pad
;
}
static
auto
MakeDescriptor_M0
(
const
std
::
vector
<
int
>&
shape
,
const
std
::
vector
<
int
>&
stride
,
index_t
gridSize
,
index_t
threadPerBlock
)
{
auto
tupleOfShape
=
generate_tuple
([
&
](
auto
I
)
{
return
shape
[
I
];
},
Number
<
2
>
{});
auto
tupleOfStride
=
generate_tuple
([
&
](
auto
I
)
{
return
stride
[
I
];
},
Number
<
2
>
{});
const
auto
desc
=
make_naive_tensor_descriptor
(
tupleOfShape
,
tupleOfStride
);
const
auto
desc_m0
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
tupleOfShape
)),
make_tuple
(
generate_sequence_v2
([
&
](
auto
I
)
{
return
I
;
},
Number
<
2
>
{})),
make_tuple
(
Sequence
<
0
>
{}));
return
PadDescriptor_M0_1d
(
desc_m0
,
gridSize
,
threadPerBlock
);
}
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
...
@@ -333,6 +370,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -333,6 +370,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
GridDesc_M0
=
decltype
(
MakeDescriptor_M0
({
1
,
1
},
{
1
,
1
},
1
,
1
));
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
...
@@ -426,6 +464,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -426,6 +464,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
);
}
}
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
c_grid_desc_m_n_
);
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
c_grid_desc_m0_
=
DeviceOp
::
MakeDescriptor_M0
({
MRaw
,
NRaw
},
{
StrideC
,
I1
},
grid_size
,
BlockSize
);
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
c_grid_desc_m0_
=
DeviceOp
::
MakeDescriptor_M0
({
MRaw
,
NRaw
},
{
I1
,
StrideC
},
grid_size
,
BlockSize
);
}
}
}
// private:
// private:
...
@@ -440,6 +491,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -440,6 +491,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
GridDesc_M0
c_grid_desc_m0_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
...
@@ -468,6 +520,35 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -468,6 +520,35 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
float
ave_time
=
0
;
float
ave_time
=
0
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
;
using
Substract
=
ck
::
tensor_operation
::
binary_element_wise
::
Substract
;
using
GridwiseBinAdd
=
GridwiseBinaryElementwise_1D
<
CDataType
,
CDataType
,
CDataType
,
CDataType
,
GridDesc_M0
,
Add
,
ScalarPerVector
>
;
using
GridwiseBinSubstract
=
GridwiseBinaryElementwise_1D
<
CDataType
,
CDataType
,
CDataType
,
CDataType
,
GridDesc_M0
,
Substract
,
ScalarPerVector
>
;
const
auto
add_kernel
=
kernel_elementwise_1d
<
GridwiseBinAdd
,
CDataType
,
CDataType
,
CDataType
,
GridDesc_M0
,
Add
>
;
const
auto
substract_kernel
=
kernel_elementwise_1d
<
GridwiseBinSubstract
,
CDataType
,
CDataType
,
CDataType
,
GridDesc_M0
,
Substract
>
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
...
@@ -517,7 +598,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -517,7 +598,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
// c_real = aux - aux_2 needed here!!!
// c_real = aux - aux_2
ave_time
+=
launch_and_time_kernel
(
stream_config
,
substract_kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_aux_grid_
,
arg
.
p_aux_2_grid_
,
arg
.
p_c_grid_real_
,
arg
.
c_grid_desc_m0_
,
arg
.
c_grid_desc_m0_
,
arg
.
c_grid_desc_m0_
,
Substract
{});
ave_time
+=
ave_time
+=
launch_and_time_kernel
(
stream_config
,
launch_and_time_kernel
(
stream_config
,
...
@@ -553,7 +646,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -553,7 +646,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
// c_imag = aux + aux_2 needed here!!!
// c_imag = aux + aux_2
ave_time
+=
launch_and_time_kernel
(
stream_config
,
add_kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_aux_grid_
,
arg
.
p_aux_2_grid_
,
arg
.
p_c_grid_imag_
,
arg
.
c_grid_desc_m0_
,
arg
.
c_grid_desc_m0_
,
arg
.
c_grid_desc_m0_
,
Add
{});
}
}
else
else
{
{
...
@@ -604,7 +709,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -604,7 +709,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
// // c_real = aux - aux_2 needed here!!!
// c_real = aux - aux_2
ave_time
+=
launch_and_time_kernel
(
stream_config
,
substract_kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_aux_grid_
,
arg
.
p_aux_2_grid_
,
arg
.
p_c_grid_real_
,
arg
.
c_grid_desc_m0_
,
arg
.
c_grid_desc_m0_
,
arg
.
c_grid_desc_m0_
,
Substract
{});
ave_time
+=
ave_time
+=
launch_and_time_kernel
(
stream_config
,
launch_and_time_kernel
(
stream_config
,
...
@@ -640,7 +757,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -640,7 +757,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
// c_imag = aux + aux_2 needed here!!!
// c_imag = aux + aux_2
ave_time
+=
launch_and_time_kernel
(
stream_config
,
add_kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_aux_grid_
,
arg
.
p_aux_2_grid_
,
arg
.
p_c_grid_imag_
,
arg
.
c_grid_desc_m0_
,
arg
.
c_grid_desc_m0_
,
arg
.
c_grid_desc_m0_
,
Add
{});
}
}
return
ave_time
;
return
ave_time
;
...
...
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
208ac1a5
...
@@ -12,6 +12,39 @@ struct Add
...
@@ -12,6 +12,39 @@ struct Add
{
{
dst
=
src1
+
src2
;
dst
=
src1
+
src2
;
}
}
__host__
__device__
constexpr
void
operator
()(
half_t
&
dst
,
const
half_t
&
src1
,
const
half_t
&
src2
)
const
{
dst
=
src1
+
src2
;
}
__host__
__device__
constexpr
void
operator
()(
bhalf_t
&
dst
,
const
bhalf_t
&
src1
,
const
bhalf_t
&
src2
)
const
{
dst
=
src1
+
src2
;
}
};
struct
Substract
{
__host__
__device__
constexpr
void
operator
()(
float
&
dst
,
const
float
&
src1
,
const
float
&
src2
)
const
{
dst
=
src1
-
src2
;
}
__host__
__device__
constexpr
void
operator
()(
half_t
&
dst
,
const
half_t
&
src1
,
const
half_t
&
src2
)
const
{
dst
=
src1
-
src2
;
}
__host__
__device__
constexpr
void
operator
()(
bhalf_t
&
dst
,
const
bhalf_t
&
src1
,
const
bhalf_t
&
src2
)
const
{
dst
=
src1
-
src2
;
}
};
};
}
// namespace binary_element_wise
}
// namespace binary_element_wise
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment