composable_kernel · Commit dba65b1c
authored Apr 14, 2022 by rocking

Rewrite the gridwise_elementwise_2d as 1d version

parent 6a781e51
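The gist of the change: the device op still takes a 2-D (shape, stride) problem, but it now merges the two dimensions into a single m * n dimension (MakeDescriptor_M0 in device_elementwise_2d.hpp below) and drives a 1-D grid over it, so the per-dimension tuning parameters collapse to a single ThreadPerBlock / ThreadTileSize / ScalarPerVector triple. The snippet below is a standalone illustration only, not CK code: it assumes a packed row-major tensor and shows that one flat loop over m * n touches the same elements as the original two nested loops (the real code handles arbitrary strides through transform_tensor_descriptor / make_merge_transform).

    // Standalone sketch: flattened 1-D traversal vs. the original 2-D traversal
    #include <cassert>
    #include <vector>

    int main()
    {
        const int M = 4, N = 3;
        std::vector<float> a(M * N), b(M * N), c_1d(M * N), c_2d(M * N);

        for(int id = 0; id < M * N; ++id)
        {
            a[id] = static_cast<float>(id);
            b[id] = 2.0f;
        }

        // 1-D traversal over the merged [m * n] view
        for(int id = 0; id < M * N; ++id)
            c_1d[id] = a[id] + b[id];

        // original 2-D traversal over [m, n] (packed row-major)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c_2d[m * N + n] = a[m * N + n] + b[m * N + n];

        for(int id = 0; id < M * N; ++id)
            assert(c_1d[id] == c_2d[id]); // same elements, same result
        return 0;
    }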
Showing 4 changed files with 200 additions and 230 deletions
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp                   +2   -2
include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp    +49  -62
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp    +149 -0
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp    +0   -166
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp  (+2 -2)
@@ -165,10 +165,10 @@ struct Div
 };
 
 using DeviceElementwiseSubExpInstance = ck::tensor_operation::device::
-    DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub_Exp, 16, 16, 8, 8, 1, 1, 1, 1, 1>;
+    DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub_Exp, 256, 32, 8>;
 
 using DeviceElementwiseDivInstance = ck::tensor_operation::device::
-    DeviceElementwise_2D<CDataType, CDataType, CDataType, Div, 16, 16, 8, 8, 1, 1, 1, 1, 1>;
+    DeviceElementwise_2D<CDataType, CDataType, CDataType, Div, 256, 32, 8>;
 
 using ReferenceGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
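Reading the two instantiations side by side (my arithmetic, not stated in the commit): the old DeviceElementwise_2D instances carried nine integer tuning parameters (MThreadPerBlock = 16, NThreadPerBlock = 16, MThreadTileSize = 8, NThreadTileSize = 8, and per-tensor vector dims/widths of 1), while the new ones carry three: ThreadPerBlock = 256, ThreadTileSize = 32, ScalarPerVector = 8. Both keep 256 threads per block; the work per block changes from a 128 x 128 tile to 8192 consecutive elements of the flattened tensor. A small standalone check of those numbers, not CK code:

    #include <cstdio>

    int main()
    {
        // old DeviceElementwise_2D instance: <..., 16, 16, 8, 8, 1, 1, 1, 1, 1>
        const int m_thread_per_block = 16, n_thread_per_block = 16;
        const int m_thread_tile = 8, n_thread_tile = 8;
        std::printf("old: %d threads/block, %d x %d = %d elements/block\n",
                    m_thread_per_block * n_thread_per_block,
                    m_thread_per_block * m_thread_tile,
                    n_thread_per_block * n_thread_tile,
                    (m_thread_per_block * m_thread_tile) * (n_thread_per_block * n_thread_tile));

        // new DeviceElementwise_2D instance: <..., 256, 32, 8>
        const int thread_per_block = 256, thread_tile = 32, scalar_per_vector = 8;
        std::printf("new: %d threads/block, %d contiguous elements/block, %d vector loads/thread\n",
                    thread_per_block,
                    thread_per_block * thread_tile,
                    thread_tile / scalar_per_vector);
        return 0;
    }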
include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp  (+49 -62)
@@ -4,7 +4,7 @@
 #include "device.hpp"
 #include "device_elementwise.hpp"
-#include "gridwise_elementwise_2d.hpp"
+#include "gridwise_elementwise_1d.hpp"
 
 namespace ck {
 namespace tensor_operation {
@@ -14,48 +14,40 @@ template <typename ADataType,
           typename BDataType,
           typename CDataType,
           typename ElementwiseFunctor,
-          index_t MThreadPerBlock,
-          index_t NThreadPerBlock,
-          index_t MThreadTileSize,
-          index_t NThreadTileSize,
-          index_t AThreadTransferSrcVectorDim,
-          index_t AThreadTransferSrcScalarPerVector,
-          index_t BThreadTransferSrcVectorDim,
-          index_t BThreadTransferSrcScalarPerVector,
-          index_t CThreadTransferSrcScalarPerVector>
+          index_t ThreadPerBlock,
+          index_t ThreadTileSize,
+          index_t ScalarPerVector>
 struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
 {
-    static_assert(NThreadTileSize % AThreadTransferSrcScalarPerVector == 0 &&
-                  NThreadTileSize % BThreadTransferSrcScalarPerVector == 0);
+    static_assert(ThreadTileSize % ScalarPerVector == 0);
 
+    static constexpr int BlockTileSize = ThreadPerBlock * ThreadTileSize;
     static constexpr auto I0 = Number<0>{};
-    static constexpr auto I1 = Number<1>{};
 
-    static auto Make2dDescriptor_M_N(const std::vector<int>& shape, const std::vector<int>& stride)
+    static auto MakeDescriptor_M0(const std::vector<int>& shape, const std::vector<int>& stride)
     {
-        return make_naive_tensor_descriptor(make_tuple(shape[0], shape[1]),
-                                            make_tuple(stride[0], stride[1]));
+        const int m = shape[0];
+        const int n = shape[1];
+
+        // 2d desc - [m, n]
+        const auto desc_m_n =
+            make_naive_tensor_descriptor(make_tuple(m, n), make_tuple(stride[0], stride[1]));
+
+        // 1d desc - [m * n]
+        return transform_tensor_descriptor(desc_m_n,
+                                           make_tuple(make_merge_transform(make_tuple(m, n))),
+                                           make_tuple(Sequence<0, 1>{}),
+                                           make_tuple(Sequence<0>{}));
     }
 
-    static constexpr index_t BlockSize     = MThreadPerBlock * NThreadPerBlock;
-    static constexpr int M_BlockTileSize   = MThreadPerBlock * MThreadTileSize;
-    static constexpr int N_BlockTileSize   = NThreadPerBlock * NThreadTileSize;
-
-    using GridDesc_M_N    = decltype(Make2dDescriptor_M_N({1, 1}, {1, 1}));
-    using GridwiseEltwise = GridwiseElementwise_2D<ADataType,
+    using GridDesc_M0     = decltype(MakeDescriptor_M0({1, 1}, {1, 1}));
+    using GridwiseEltwise = GridwiseElementwise_1D<ADataType,
                                                    BDataType,
                                                    CDataType,
-                                                   GridDesc_M_N,
+                                                   GridDesc_M0,
                                                    ElementwiseFunctor,
-                                                   MThreadPerBlock,
-                                                   NThreadPerBlock,
-                                                   MThreadTileSize,
-                                                   NThreadTileSize,
-                                                   AThreadTransferSrcVectorDim,
-                                                   AThreadTransferSrcScalarPerVector,
-                                                   BThreadTransferSrcVectorDim,
-                                                   BThreadTransferSrcScalarPerVector,
-                                                   CThreadTransferSrcScalarPerVector>;
+                                                   ThreadPerBlock,
+                                                   ThreadTileSize,
+                                                   ScalarPerVector>;
 
     struct Argument : public BaseArgument
     {
@@ -70,9 +62,9 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
             : p_a_(p_a),
               p_b_(p_b),
               p_c_(p_c),
-              a_grid_desc_m_n_(Make2dDescriptor_M_N(shape, stride_a)),
-              b_grid_desc_m_n_(Make2dDescriptor_M_N(shape, stride_b)),
-              c_grid_desc_m_n_(Make2dDescriptor_M_N(shape, stride_c)),
+              a_grid_desc_m0_(MakeDescriptor_M0(shape, stride_a)),
+              b_grid_desc_m0_(MakeDescriptor_M0(shape, stride_b)),
+              c_grid_desc_m0_(MakeDescriptor_M0(shape, stride_c)),
               functor_(functor)
         {
         }
@@ -80,47 +72,42 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
         const ADataType* p_a_;
         const BDataType* p_b_;
         CDataType* p_c_;
-        GridDesc_M_N a_grid_desc_m_n_;
-        GridDesc_M_N b_grid_desc_m_n_;
-        GridDesc_M_N c_grid_desc_m_n_;
+        GridDesc_M0 a_grid_desc_m0_;
+        GridDesc_M0 b_grid_desc_m0_;
+        GridDesc_M0 c_grid_desc_m0_;
         ElementwiseFunctor functor_;
     };
 
     struct Invoker : public BaseInvoker
     {
-        index_t CalculateGridSize(const GridDesc_M_N& grid_desc_m_n)
+        index_t CalculateGridSize(const GridDesc_M0& grid_desc_m0)
         {
-            const auto M = grid_desc_m_n.GetLength(I0);
-            const auto N = grid_desc_m_n.GetLength(I1);
-
-            assert(M % M_BlockTileSize == 0);
-            assert(N % N_BlockTileSize == 0);
-
-            return (M / M_BlockTileSize) * (N / N_BlockTileSize);
+            const auto gridTileSize = grid_desc_m0.GetLength(I0);
+            return gridTileSize / BlockTileSize;
         }
 
         float Run(const Argument& arg, int nrepeat = 1)
         {
-            const auto kernel = kernel_elementwise_2d<GridwiseEltwise,
+            const auto kernel = kernel_elementwise_1d<GridwiseEltwise,
                                                       ADataType,
                                                       BDataType,
                                                       CDataType,
-                                                      GridDesc_M_N,
+                                                      GridDesc_M0,
                                                       ElementwiseFunctor>;
 
             float avgTime = 0;
-            const index_t gridSize = CalculateGridSize(arg.c_grid_desc_m_n_);
+            const index_t gridSize = CalculateGridSize(arg.c_grid_desc_m0_);
 
             if(nrepeat == 0)
             {
                 launch_kernel(kernel,
                               dim3(gridSize),
-                              dim3(BlockSize),
+                              dim3(ThreadPerBlock),
                               0,
                               arg.p_a_,
                               arg.p_b_,
                               arg.p_c_,
-                              arg.a_grid_desc_m_n_,
-                              arg.b_grid_desc_m_n_,
-                              arg.c_grid_desc_m_n_,
+                              arg.a_grid_desc_m0_,
+                              arg.b_grid_desc_m0_,
+                              arg.c_grid_desc_m0_,
                               arg.functor_);
             }
             else
@@ -128,14 +115,14 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
                 avgTime = launch_and_time_kernel(kernel,
                                                  nrepeat,
                                                  dim3(gridSize),
-                                                 dim3(BlockSize),
+                                                 dim3(ThreadPerBlock),
                                                  0,
                                                  arg.p_a_,
                                                  arg.p_b_,
                                                  arg.p_c_,
-                                                 arg.a_grid_desc_m_n_,
-                                                 arg.b_grid_desc_m_n_,
-                                                 arg.c_grid_desc_m_n_,
+                                                 arg.a_grid_desc_m0_,
+                                                 arg.b_grid_desc_m0_,
+                                                 arg.c_grid_desc_m0_,
                                                  arg.functor_);
             }
 
             return avgTime;
@@ -154,10 +141,10 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
         if(pArg == nullptr)
             return false;
 
-        const auto M = pArg->c_grid_desc_m_n_.GetLength(I0);
-        const auto N = pArg->c_grid_desc_m_n_.GetLength(I1);
+        const auto m0 = pArg->c_grid_desc_m0_.GetLength(I0); // m * n
 
-        if(M % M_BlockTileSize != 0 && N % N_BlockTileSize != 0)
+        if(m0 % BlockTileSize != 0)
             return false;
 
         return true;
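With the merged descriptor, the launch math in Invoker::Run reduces to dividing the flattened length by BlockTileSize, and IsSupportedArgument only has to check that single divisibility. A standalone sketch of that sizing (plain C++, hypothetical problem size, not CK code), using the parameter values from the example above:

    #include <cassert>
    #include <cstdio>

    int main()
    {
        const int ThreadPerBlock = 256, ThreadTileSize = 32; // values from the example instance
        const int BlockTileSize  = ThreadPerBlock * ThreadTileSize; // 8192

        const int M = 1024, N = 1024; // hypothetical problem size
        const int m0 = M * N;         // length of the merged [m * n] descriptor

        assert(m0 % BlockTileSize == 0); // mirrors the new IsSupportedArgument check
        std::printf("gridSize = %d blocks of %d threads\n", m0 / BlockTileSize, ThreadPerBlock);
        return 0;
    }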
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp  (new file, mode 100644, +149 -0)
#pragma once

#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename GridwiseEltwise,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename GridDesc_M0,
          typename ElementwiseFunctor>
__global__ void kernel_elementwise_1d(const ADataType* __restrict__ p_a_global,
                                      const BDataType* __restrict__ p_b_global,
                                      CDataType* __restrict__ p_c_global,
                                      const GridDesc_M0 a_grid_desc_m0,
                                      const GridDesc_M0 b_grid_desc_m0,
                                      const GridDesc_M0 c_grid_desc_m0,
                                      const ElementwiseFunctor functor)
{
    GridwiseEltwise::Run(p_a_global,
                         p_b_global,
                         p_c_global,
                         a_grid_desc_m0,
                         b_grid_desc_m0,
                         c_grid_desc_m0,
                         functor);
}

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename GridDesc_M0,
          typename ElementwiseFunctor,
          index_t ThreadPerBlock,
          index_t ThreadTileSize,
          index_t ScalarPerVector>
struct GridwiseElementwise_1D
{
    static constexpr auto I0            = Number<0>{};
    static constexpr int BlockTileSize  = ThreadPerBlock * ThreadTileSize;

    static constexpr auto thread_desc_M0 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));

    using PassThrough = tensor_operation::element_wise::PassThrough;

    static __device__ __host__ auto CalculateElementwiseIndex()
    {
        const index_t thread_id = get_thread_local_1d_id();
        const index_t block_id  = get_block_1d_id();

        return make_multi_index(block_id * BlockTileSize + thread_id * ScalarPerVector);
    }

    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               const BDataType* __restrict__ p_b_global,
                               CDataType* __restrict__ p_c_global,
                               const GridDesc_M0 a_grid_desc_m0,
                               const GridDesc_M0 b_grid_desc_m0,
                               const GridDesc_M0 c_grid_desc_m0,
                               const ElementwiseFunctor functor)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_grid_desc_m0.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_grid_desc_m0.GetElementSpaceSize());
        auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_grid_desc_m0.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, ScalarPerVector, true> a_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, ScalarPerVector, true> b_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, CDataType, ScalarPerVector, true> c_thread_buf;

        const auto thread_to_global_offset = CalculateElementwiseIndex();

        auto a_global_load =
            ThreadwiseTensorSliceTransfer_v2<ADataType,
                                             ADataType,
                                             GridDesc_M0,
                                             decltype(thread_desc_M0),
                                             Sequence<ScalarPerVector>, // SliceLengths
                                             Sequence<0>,               // DimAccessOrder
                                             0,                         // SrcVectorDim
                                             ScalarPerVector,
                                             1, // SrcScalarStrideInVector
                                             false>{a_grid_desc_m0, thread_to_global_offset};

        auto b_global_load =
            ThreadwiseTensorSliceTransfer_v2<BDataType,
                                             BDataType,
                                             GridDesc_M0,
                                             decltype(thread_desc_M0),
                                             Sequence<ScalarPerVector>, // SliceLengths
                                             Sequence<0>,               // DimAccessOrder
                                             0,                         // SrcVectorDim
                                             ScalarPerVector,
                                             1, // SrcScalarStrideInVector
                                             false>{b_grid_desc_m0, thread_to_global_offset};

        auto c_global_write =
            ThreadwiseTensorSliceTransfer_v1r3<CDataType,
                                               CDataType,
                                               decltype(thread_desc_M0),
                                               GridDesc_M0,
                                               PassThrough,
                                               Sequence<ScalarPerVector>, // SliceLengths
                                               Sequence<0>,               // DimAccessOrder
                                               0,                         // DstVectorDim
                                               ScalarPerVector,
                                               InMemoryDataOperationEnum::Set,
                                               1, // DstScalarStrideInVector
                                               false>{
                c_grid_desc_m0, thread_to_global_offset, PassThrough{}};

        int num_iter = ThreadTileSize / ScalarPerVector;
        constexpr auto thread_to_global_step = make_multi_index(ThreadPerBlock * ScalarPerVector);

        do
        {
            // read and process ScalarPerVector elements
            a_global_load.Run(
                a_grid_desc_m0, a_global_buf, thread_desc_M0, make_tuple(I0), a_thread_buf);
            b_global_load.Run(
                b_grid_desc_m0, b_global_buf, thread_desc_M0, make_tuple(I0), b_thread_buf);

            static_for<0, ScalarPerVector, 1>{}([&](auto m) {
                constexpr auto offset = thread_desc_M0.CalculateOffset(make_tuple(m));
                functor(c_thread_buf(Number<offset>{}),
                        a_thread_buf(Number<offset>{}),
                        b_thread_buf(Number<offset>{}));
            });

            c_global_write.Run(thread_desc_M0,
                               make_tuple(I0), // SrcSliceOriginIdx
                               c_thread_buf,
                               c_grid_desc_m0,
                               c_global_buf);

            a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, thread_to_global_step);
            b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, thread_to_global_step);
            c_global_write.MoveDstSliceWindow(c_grid_desc_m0, thread_to_global_step);
        } while(--num_iter);
    }
};

} // namespace ck
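For reference, the access pattern GridwiseElementwise_1D::Run sets up is: each thread starts at block_id * BlockTileSize + thread_id * ScalarPerVector, handles ScalarPerVector contiguous elements per iteration, and the slice window then moves by ThreadPerBlock * ScalarPerVector until ThreadTileSize / ScalarPerVector iterations are done. The standalone CPU sketch below (not CK code; parameter values taken from the example instantiation) replays that pattern for one block and checks the block tile is covered exactly once:

    #include <cassert>
    #include <vector>

    int main()
    {
        const int ThreadPerBlock = 256, ThreadTileSize = 32, ScalarPerVector = 8;
        const int BlockTileSize  = ThreadPerBlock * ThreadTileSize; // 8192

        std::vector<int> hits(BlockTileSize, 0); // how often each element of one block tile is touched

        for(int thread_id = 0; thread_id < ThreadPerBlock; ++thread_id)
        {
            // global offset would be block_id * BlockTileSize + local; only the local part matters here
            int local = thread_id * ScalarPerVector;
            for(int iter = 0; iter < ThreadTileSize / ScalarPerVector; ++iter)
            {
                for(int s = 0; s < ScalarPerVector; ++s)
                    ++hits[local + s];
                local += ThreadPerBlock * ScalarPerVector; // MoveSrcSliceWindow / MoveDstSliceWindow step
            }
        }

        for(int h : hits)
            assert(h == 1); // every element covered exactly once, with coalesced per-iteration accesses
        return 0;
    }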
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp  (deleted, was mode 100644, +0 -166)
#pragma once

#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename GridwiseEltwise,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename GridDesc_M_N,
          typename ElementwiseFunctor>
__global__ void kernel_elementwise_2d(const ADataType* __restrict__ p_a_global,
                                      const BDataType* __restrict__ p_b_global,
                                      CDataType* __restrict__ p_c_global,
                                      const GridDesc_M_N a_grid_desc_m_k,
                                      const GridDesc_M_N b_grid_desc_m_k,
                                      const GridDesc_M_N c_grid_desc_m_k,
                                      const ElementwiseFunctor functor)
{
    GridwiseEltwise::Run(p_a_global,
                         p_b_global,
                         p_c_global,
                         a_grid_desc_m_k,
                         b_grid_desc_m_k,
                         c_grid_desc_m_k,
                         functor);
}

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename GridDesc_M_N,
          typename ElementwiseFunctor,
          index_t MThreadPerBlock,
          index_t NThreadPerBlock,
          index_t MThreadTileSize,
          index_t NThreadTileSize,
          index_t AThreadTransferSrcVectorDim,
          index_t AThreadTransferSrcScalarPerVector,
          index_t BThreadTransferSrcVectorDim,
          index_t BThreadTransferSrcScalarPerVector,
          index_t CThreadTransferSrcScalarPerVector>
struct GridwiseElementwise_2D
{
    static constexpr auto thread_buf_desc_M_N = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadTileSize>{}, Number<NThreadTileSize>{}));

    using PassThrough       = tensor_operation::element_wise::PassThrough;
    using ThreadBufDesc_M_N = decltype(thread_buf_desc_M_N);

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr int M_BlockTileSize = MThreadPerBlock * MThreadTileSize;
    static constexpr int N_BlockTileSize = NThreadPerBlock * NThreadTileSize;

    static __device__ __host__ auto CalculateElementwiseIndex(const GridDesc_M_N& grid_desc_m_n)
    {
        const index_t thread_id = get_thread_local_1d_id();
        const index_t block_id  = get_block_1d_id();

        const index_t M          = grid_desc_m_n.GetLength(I0);
        const index_t gridSize_m = M / M_BlockTileSize;

        const index_t block_2d_idx_m = block_id % gridSize_m;
        const index_t block_2d_idx_n = block_id / gridSize_m;

        constexpr auto thread_desc =
            make_cluster_descriptor(Sequence<MThreadPerBlock, NThreadPerBlock>{}, Sequence<1, 0>{});
        const auto thread_2d_idx = thread_desc.CalculateBottomIndex(make_multi_index(thread_id));

        return make_multi_index(
            block_2d_idx_m * M_BlockTileSize + thread_2d_idx[I0] * MThreadTileSize,
            block_2d_idx_n * N_BlockTileSize + thread_2d_idx[I1] * NThreadTileSize);
    }

    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               const BDataType* __restrict__ p_b_global,
                               CDataType* __restrict__ p_c_global,
                               const GridDesc_M_N a_grid_desc_m_n,
                               const GridDesc_M_N b_grid_desc_m_n,
                               const GridDesc_M_N c_grid_desc_m_n,
                               const ElementwiseFunctor functor)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_grid_desc_m_n.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_grid_desc_m_n.GetElementSpaceSize());
        auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_grid_desc_m_n.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, MThreadTileSize * NThreadTileSize, true>
            a_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, MThreadTileSize * NThreadTileSize, true>
            b_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, CDataType, MThreadTileSize * NThreadTileSize, true>
            c_thread_buf;

        const auto a_global_load_offset = CalculateElementwiseIndex(a_grid_desc_m_n);
        const auto b_global_load_offset = CalculateElementwiseIndex(b_grid_desc_m_n);

        auto a_global_load =
            ThreadwiseTensorSliceTransfer_v2<ADataType,
                                             ADataType,
                                             GridDesc_M_N,
                                             decltype(thread_buf_desc_M_N),
                                             Sequence<MThreadTileSize, NThreadTileSize>, // SliceLengths
                                             Sequence<0, 1>,                             // DimAccessOrder
                                             AThreadTransferSrcVectorDim,
                                             AThreadTransferSrcScalarPerVector,
                                             1, // SrcScalarStrideInVector
                                             false>{a_grid_desc_m_n, a_global_load_offset};

        auto b_global_load =
            ThreadwiseTensorSliceTransfer_v2<BDataType,
                                             BDataType,
                                             GridDesc_M_N,
                                             decltype(thread_buf_desc_M_N),
                                             Sequence<MThreadTileSize, NThreadTileSize>, // SliceLengths
                                             Sequence<0, 1>,                             // DimAccessOrder
                                             BThreadTransferSrcVectorDim,
                                             BThreadTransferSrcScalarPerVector,
                                             1, // SrcScalarStrideInVector
                                             false>{b_grid_desc_m_n, b_global_load_offset};

        a_global_load.Run(
            a_grid_desc_m_n, a_global_buf, thread_buf_desc_M_N, make_tuple(I0, I0), a_thread_buf);
        b_global_load.Run(
            b_grid_desc_m_n, b_global_buf, thread_buf_desc_M_N, make_tuple(I0, I0), b_thread_buf);

        static_for<0, MThreadTileSize, 1>{}([&](auto m) {
            static_for<0, NThreadTileSize, 1>{}([&](auto n) {
                constexpr auto offset = thread_buf_desc_M_N.CalculateOffset(make_tuple(m, n));
                functor(c_thread_buf(Number<offset>{}),
                        a_thread_buf(Number<offset>{}),
                        b_thread_buf(Number<offset>{}));
            });
        });

        // TODO - global write
        const auto c_global_write_offset = CalculateElementwiseIndex(c_grid_desc_m_n);

        auto c_global_write =
            ThreadwiseTensorSliceTransfer_v1r3<CDataType,
                                               CDataType,
                                               decltype(thread_buf_desc_M_N),
                                               GridDesc_M_N,
                                               PassThrough,
                                               Sequence<MThreadTileSize, NThreadTileSize>, // SliceLengths
                                               Sequence<0, 1>,                             // DimAccessOrder
                                               1,                                 // DstVectorDim
                                               CThreadTransferSrcScalarPerVector, // DstScalarPerVector
                                               InMemoryDataOperationEnum::Set,    // DstInMemOp
                                               1,                                 // DstScalarStrideInVector
                                               false>{
                c_grid_desc_m_n, c_global_write_offset, PassThrough{}};

        c_global_write.Run(
            thread_buf_desc_M_N, make_tuple(I0, I0), c_thread_buf, c_grid_desc_m_n, c_global_buf);
    }
};

} // namespace ck