gaoqiong / composable_kernel · Commits

Commit a760a732, authored Apr 12, 2022 by rocking
Parent: cb1c4731

    A kernel of elementwise_2d (except global store)
Showing 3 changed files with 118 additions and 42 deletions:
  +7   -7  example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
 +18  -11  include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp
 +93  -24  include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
@@ -84,8 +84,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
             8>;         // CBlockTransferScalarPerVector_NWaveNPerXdl
 // clang-format on

 constexpr int Rank         = 2;
 constexpr int NumReduceDim = 1;

 constexpr ck::ReduceTensorOp ReduceOpId = ck::ReduceTensorOp::MAX;
 constexpr ck::NanPropagation NanOpt     = ck::NanPropagation::PROPAGATE_NAN;
 constexpr bool PropagateNan = (NanOpt == ck::NanPropagation::NOT_PROPAGATE_NAN) ? false : true;

(The two constexpr lines in this hunk were only re-aligned; the change is whitespace-only.)
@@ -118,14 +118,14 @@ using DeviceReduceInstance =
 struct Sub
 {
     __host__ __device__ constexpr void
-    operator()(F16& dst, const F16& src1, const F16& src2) const
+    operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
     {
         dst = src1 - src2;
     }
 };

-using DeviceElementwiseInstance =
-    ck::tensor_operation::device::DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub, 16, 16, 8, 8>;
+using DeviceElementwiseInstance =
+    ck::tensor_operation::device::
+        DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub, 16, 16, 8, 8, 1, 1, 1, 1, 1>;

 using ReferenceGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
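The five trailing 1s correspond one-to-one to the five index_t parameters this commit adds to DeviceElementwise_2D (see device_elementwise_2d.hpp below). As a reading aid, here is the same instantiation with every template argument named per that parameter list; the comments are annotation, not part of the diff:

    // Annotated form of the new DeviceElementwiseInstance above.
    using DeviceElementwiseInstance = ck::tensor_operation::device::DeviceElementwise_2D<
        CDataType, // ADataType
        CDataType, // BDataType
        CDataType, // CDataType
        Sub,       // ElementwiseFunctor
        16,        // MThreadPerBlock
        16,        // NThreadPerBlock
        8,         // MThreadTileSize
        8,         // NThreadTileSize
        1,         // AThreadTransferSrcVectorDim
        1,         // AThreadTransferSrcScalarPerVector
        1,         // BThreadTransferSrcVectorDim
        1,         // BThreadTransferSrcScalarPerVector
        1>;        // CThreadTransferSrcScalarPerVector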
@@ -302,8 +302,8 @@ int main(int argc, char* argv[])
     if(!broadcastSub.IsSupportedArgument(broadcastSub_argument_ptr.get()))
     {
-        throw std::runtime_error(
-            "The runtime parameters seems not supported by the DeviceElementwise_2D instance, exiting!");
+        throw std::runtime_error("The runtime parameters seems not supported by the "
+                                 "DeviceElementwise_2D instance, exiting!");
     };

     auto broadcastSub_invoker_ptr = broadcastSub.MakeInvokerPointer();
include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp
@@ -17,9 +17,17 @@ template <typename ADataType,
           index_t MThreadPerBlock,
           index_t NThreadPerBlock,
           index_t MThreadTileSize,
-          index_t NThreadTileSize>
+          index_t NThreadTileSize,
+          index_t AThreadTransferSrcVectorDim,
+          index_t AThreadTransferSrcScalarPerVector,
+          index_t BThreadTransferSrcVectorDim,
+          index_t BThreadTransferSrcScalarPerVector,
+          index_t CThreadTransferSrcScalarPerVector>
 struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
 {
+    static_assert(NThreadTileSize % AThreadTransferSrcScalarPerVector == 0 &&
+                  NThreadTileSize % BThreadTransferSrcScalarPerVector == 0);
+
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
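The new static_assert encodes the vectorization contract: each thread reads its NThreadTileSize-wide row segment as whole vectors of ScalarPerVector elements, so the tile width must divide evenly. In the example above both ScalarPerVector arguments are 1, which is trivially valid; 2, 4, or 8 would also pass for NThreadTileSize = 8, while 3 would fail at compile time. A minimal standalone illustration of the rule (outside CK):

    #include <cstdio>

    // Standalone illustration of the divisibility rule the commit asserts:
    // a row tile of NThreadTileSize elements must decompose into whole vectors.
    template <int NThreadTileSize, int ScalarPerVector>
    constexpr bool tile_is_vectorizable()
    {
        static_assert(NThreadTileSize % ScalarPerVector == 0,
                      "row tile must be a whole number of vector loads");
        return true;
    }

    int main()
    {
        // NThreadTileSize = 8 as in the gemm_softmax example.
        constexpr bool ok1 = tile_is_vectorizable<8, 1>();
        constexpr bool ok4 = tile_is_vectorizable<8, 4>();
        // tile_is_vectorizable<8, 3>() would fail to compile.
        std::printf("8 %% 1 ok: %d, 8 %% 4 ok: %d\n", ok1, ok4);
        return 0;
    }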
@@ -38,11 +46,16 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
                                                    BDataType,
                                                    CDataType,
-                                                   GridDesc_M_N,
-                                                   GridDesc_M_N,
                                                    GridDesc_M_N,
                                                    ElementwiseFunctor,
                                                    MThreadPerBlock,
                                                    NThreadPerBlock,
                                                    MThreadTileSize,
-                                                   NThreadTileSize>;
+                                                   NThreadTileSize,
+                                                   AThreadTransferSrcVectorDim,
+                                                   AThreadTransferSrcScalarPerVector,
+                                                   BThreadTransferSrcVectorDim,
+                                                   BThreadTransferSrcScalarPerVector,
+                                                   CThreadTransferSrcScalarPerVector>;

     struct Argument : public BaseArgument
     {
@@ -88,18 +101,12 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
     float Run(const Argument& arg, int nrepeat = 1)
     {
         const auto kernel = kernel_elementwise_2d<GridwiseEltwise,
                                                   ADataType,
                                                   BDataType,
                                                   CDataType,
-                                                  GridDesc_M_N,
-                                                  GridDesc_M_N,
                                                   GridDesc_M_N,
                                                   ElementwiseFunctor>;

-        // TODO
-        (void)arg;
-        (void)nrepeat;
-        (void)kernel;

         float avgTime = 0;
         const index_t gridSize = CalculateGridSize(arg.c_grid_desc_m_n_);
         if(nrepeat == 0)
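CalculateGridSize itself is outside this hunk; only its call site appears. Given the block tile constants in the gridwise struct below, the natural reading is one workgroup per M_BlockTileSize x N_BlockTileSize output tile: with the example's 16 x 16 threads each owning an 8 x 8 tile, a block covers 128 x 128 elements, so a 1024 x 1024 C tensor would launch (1024/128) * (1024/128) = 64 blocks. A hedged sketch of that computation — the body is an inference, only the name and call are in the diff:

    // Assumed shape of CalculateGridSize: one workgroup per output tile of
    // M_BlockTileSize x N_BlockTileSize elements. Only the call site above
    // appears in the diff; this body is inferred from the tile constants.
    template <typename GridDesc_M_N>
    static index_t CalculateGridSize(const GridDesc_M_N& grid_desc_m_n)
    {
        const index_t M = grid_desc_m_n.GetLength(Number<0>{});
        const index_t N = grid_desc_m_n.GetLength(Number<1>{});
        return (M / M_BlockTileSize) * (N / N_BlockTileSize); // e.g. 8 * 8 = 64
    }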
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
@@ -10,16 +10,14 @@ template <typename GridwiseEltwise,
           typename ADataType,
           typename BDataType,
           typename CDataType,
-          typename AGridDesc_M_N,
-          typename BGridDesc_M_N,
-          typename CGridDesc_M_N,
+          typename GridDesc_M_N,
           typename ElementwiseFunctor>
 __global__ void kernel_elementwise_2d(const ADataType* __restrict__ p_a_global,
                                       const BDataType* __restrict__ p_b_global,
                                       CDataType* __restrict__ p_c_global,
-                                      const AGridDesc_M_N a_grid_desc_m_k,
-                                      const BGridDesc_M_N b_grid_desc_m_k,
-                                      const CGridDesc_M_N c_grid_desc_m_k,
+                                      const GridDesc_M_N a_grid_desc_m_k,
+                                      const GridDesc_M_N b_grid_desc_m_k,
+                                      const GridDesc_M_N c_grid_desc_m_k,
                                       const ElementwiseFunctor functor)
 {
     GridwiseEltwise::Run(p_a_global,
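Aside from consolidating the three descriptor typenames into one GridDesc_M_N (all three descriptors have the same type here), the wrapper is untouched: a __global__ function cannot be a member function, so CK routes every kernel through a static __device__ Run on a gridwise struct. Reduced to its shape, the pattern is (a sketch, not the diff):

    // Sketch of the wrapper pattern this hunk edits: the __global__ entry
    // point only forwards its arguments to the gridwise struct's static Run.
    template <typename Gridwise, typename... Args>
    __global__ void kernel_entry(Args... args)
    {
        Gridwise::Run(args...);
    }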
@@ -34,26 +32,58 @@ __global__ void kernel_elementwise_2d(const ADataType* __restrict__ p_a_global,
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
-          typename AGridDesc_M_N,
-          typename BGridDesc_M_N,
-          typename CGridDesc_M_N,
+          typename GridDesc_M_N,
           typename ElementwiseFunctor,
           index_t MThreadPerBlock,
           index_t NThreadPerBlock,
           index_t MThreadTileSize,
-          index_t NThreadTileSize>
+          index_t NThreadTileSize,
+          index_t AThreadTransferSrcVectorDim,
+          index_t AThreadTransferSrcScalarPerVector,
+          index_t BThreadTransferSrcVectorDim,
+          index_t BThreadTransferSrcScalarPerVector,
+          index_t CThreadTransferSrcScalarPerVector>
 struct GridwiseElementwise_2D
 {
+    static constexpr auto thread_buf_desc_M_N = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MThreadTileSize>{}, Number<NThreadTileSize>{}));
+    using ThreadBufDesc_M_N = decltype(thread_buf_desc_M_N);
+
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

     static constexpr int M_BlockTileSize = MThreadPerBlock * MThreadTileSize;
     static constexpr int N_BlockTileSize = NThreadPerBlock * NThreadTileSize;

+    static __device__ __host__ auto CalculateElementwiseIndex(const GridDesc_M_N& grid_desc_m_n)
+    {
+        const index_t thread_id = get_thread_local_1d_id();
+        const index_t block_id  = get_block_1d_id();
+
+        const index_t M          = grid_desc_m_n.GetLength(I0);
+        const index_t gridSize_m = M / M_BlockTileSize;
+
+        const index_t block_2d_idx_m = block_id % gridSize_m;
+        const index_t block_2d_idx_n = block_id / gridSize_m;
+
+        constexpr auto thread_desc = make_cluster_descriptor(
+            Sequence<MThreadPerBlock, NThreadPerBlock>{}, Sequence<1, 0>{});
+        const auto thread_2d_idx = thread_desc.CalculateBottomIndex(make_multi_index(thread_id));
+
+        return make_multi_index(
+            block_2d_idx_m * M_BlockTileSize + thread_2d_idx[I0] * MThreadTileSize,
+            block_2d_idx_n * N_BlockTileSize + thread_2d_idx[I1] * NThreadTileSize);
+    }
+
     __device__ static void Run(const ADataType* __restrict__ p_a_global,
                                const BDataType* __restrict__ p_b_global,
                                CDataType* __restrict__ p_c_global,
-                               const AGridDesc_M_N a_grid_desc_m_n,
-                               const BGridDesc_M_N b_grid_desc_m_n,
-                               const CGridDesc_M_N c_grid_desc_m_n,
+                               const GridDesc_M_N a_grid_desc_m_n,
+                               const GridDesc_M_N b_grid_desc_m_n,
+                               const GridDesc_M_N c_grid_desc_m_n,
                                const ElementwiseFunctor functor)
     {
         // const index_t thread_id = get_thread_local_1d_id();
         // const index_t block_id = get_block_1d_id();
         // printf("block_id = %d, thread_id = %d \n", block_id, thread_id);
         const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_global, a_grid_desc_m_n.GetElementSpaceSize());
         const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
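The new CalculateElementwiseIndex gives each thread the upper-left coordinate of its private MThreadTileSize x NThreadTileSize tile: blocks are laid out column-major over the M block rows (block_id % gridSize_m picks the block row), and threads inside a block are decomposed by the cluster descriptor. With the example's configuration (MThreadPerBlock = NThreadPerBlock = 16, MThreadTileSize = NThreadTileSize = 8, so both block tile sizes are 128) and M = 1024: gridSize_m = 1024 / 128 = 8, so block 10 maps to block_2d_idx_m = 10 % 8 = 2 and block_2d_idx_n = 10 / 8 = 1. A host-side mirror of the arithmetic for sanity-checking; the thread-level decomposition order is an assumption here (Sequence<1, 0> read as m-fastest), while the block-level math matches the diff:

    #include <cstdio>

    // Host-side mirror of CalculateElementwiseIndex for the example's sizes.
    // The decomposition of thread_id is an assumption (m taken as fastest);
    // the block-level math matches the diff exactly.
    int main()
    {
        const int MThreadPerBlock = 16, NThreadPerBlock = 16;
        const int MThreadTileSize = 8, NThreadTileSize = 8;
        const int M_BlockTileSize = MThreadPerBlock * MThreadTileSize; // 128
        const int N_BlockTileSize = NThreadPerBlock * NThreadTileSize; // 128

        const int M          = 1024;
        const int gridSize_m = M / M_BlockTileSize; // 8 block rows

        const int block_id = 10, thread_id = 17;
        const int block_2d_idx_m = block_id % gridSize_m; // 2
        const int block_2d_idx_n = block_id / gridSize_m; // 1

        const int thread_m = thread_id % MThreadPerBlock; // 1 (assumed m-fastest)
        const int thread_n = thread_id / MThreadPerBlock; // 1

        std::printf("tile origin = (%d, %d)\n",
                    block_2d_idx_m * M_BlockTileSize + thread_m * MThreadTileSize,  // 264
                    block_2d_idx_n * N_BlockTileSize + thread_n * NThreadTileSize); // 136
        return 0;
    }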
@@ -68,14 +98,53 @@ struct GridwiseElementwise_2D
         StaticBuffer<AddressSpaceEnum::Vgpr, CDataType, MThreadTileSize * NThreadTileSize, true>
             c_thread_buf;

-        // TODO - buffer_load, apply functor, buffer_store
-        (void)a_global_buf;
-        (void)b_global_buf;
+        const auto a_global_load_offset = CalculateElementwiseIndex(a_grid_desc_m_n);
+        const auto b_global_load_offset = CalculateElementwiseIndex(b_grid_desc_m_n);
+
+        auto a_global_load = ThreadwiseTensorSliceTransfer_v2<
+            ADataType,
+            ADataType,
+            GridDesc_M_N,
+            decltype(thread_buf_desc_M_N),
+            Sequence<MThreadTileSize, NThreadTileSize>, // SliceLengths
+            Sequence<0, 1>,                             // DimAccessOrder
+            AThreadTransferSrcVectorDim,
+            AThreadTransferSrcScalarPerVector,
+            1,                                          // SrcScalarStrideInVector
+            false>{a_grid_desc_m_n, a_global_load_offset};
+
+        auto b_global_load = ThreadwiseTensorSliceTransfer_v2<
+            BDataType,
+            BDataType,
+            GridDesc_M_N,
+            decltype(thread_buf_desc_M_N),
+            Sequence<MThreadTileSize, NThreadTileSize>, // SliceLengths
+            Sequence<0, 1>,                             // DimAccessOrder
+            BThreadTransferSrcVectorDim,
+            BThreadTransferSrcScalarPerVector,
+            1,                                          // SrcScalarStrideInVector
+            false>{b_grid_desc_m_n, b_global_load_offset};
+
+        a_global_load.Run(
+            a_grid_desc_m_n, a_global_buf, thread_buf_desc_M_N, make_tuple(I0, I0), a_thread_buf);
+        b_global_load.Run(
+            b_grid_desc_m_n, b_global_buf, thread_buf_desc_M_N, make_tuple(I0, I0), b_thread_buf);
+
+        static_for<0, MThreadTileSize, 1>{}([&](auto m) {
+            static_for<0, NThreadTileSize, 1>{}([&](auto n) {
+                constexpr auto offset = thread_buf_desc_M_N.CalculateOffset(make_tuple(m, n));
+                functor(c_thread_buf(Number<offset>{}),
+                        a_thread_buf(Number<offset>{}),
+                        b_thread_buf(Number<offset>{}));
+            });
+        });

+        // TODO - global write
         (void)c_global_buf;
-        (void)a_thread_buf;
-        (void)b_thread_buf;
-        (void)c_thread_buf;
-        (void)functor;
+        // c_global_write.Run(
+        //     thread_buf_desc_M_N, c_thread_buf, c_grid_desc_m_n, make_tuple(I0, I0), c_global_buf);
     }
 };
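Taken together, the new Run body implements the first two stages of the per-thread pipeline; only the symmetric global write (the commented-out c_global_write.Run) is still missing, as the commit message says. A simplified scalar model of the whole intended pipeline, with CK's descriptor machinery replaced by plain row-major indexing for illustration:

    #include <vector>

    // Simplified scalar model of the intended elementwise_2d pipeline. CK's
    // descriptors are replaced by plain row-major indexing; stage 3 models the
    // commented-out c_global_write.Run that this commit leaves as a TODO.
    template <typename T, typename Functor>
    void elementwise_2d_tile(const T* a, const T* b, T* c,
                             int ld,                   // row stride of the global tensors
                             int tile_m0, int tile_n0, // tile origin, as from CalculateElementwiseIndex
                             int MThreadTileSize, int NThreadTileSize,
                             Functor functor)
    {
        std::vector<T> a_buf(MThreadTileSize * NThreadTileSize);
        std::vector<T> b_buf(MThreadTileSize * NThreadTileSize);
        std::vector<T> c_buf(MThreadTileSize * NThreadTileSize);

        // stage 1: load the thread tile (ThreadwiseTensorSliceTransfer_v2 in the diff)
        for(int m = 0; m < MThreadTileSize; ++m)
            for(int n = 0; n < NThreadTileSize; ++n)
            {
                a_buf[m * NThreadTileSize + n] = a[(tile_m0 + m) * ld + tile_n0 + n];
                b_buf[m * NThreadTileSize + n] = b[(tile_m0 + m) * ld + tile_n0 + n];
            }

        // stage 2: apply the functor per element (the static_for loop in the diff)
        for(int i = 0; i < MThreadTileSize * NThreadTileSize; ++i)
            functor(c_buf[i], a_buf[i], b_buf[i]);

        // stage 3: global write - still a TODO in this commit
        for(int m = 0; m < MThreadTileSize; ++m)
            for(int n = 0; n < NThreadTileSize; ++n)
                c[(tile_m0 + m) * ld + tile_n0 + n] = c_buf[m * NThreadTileSize + n];
    }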