Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3e38e358
Commit
3e38e358
authored
Jul 06, 2022
by
rocking
Browse files
Add accElementwiseOp
parent
6ed9ab3a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
13 deletions
+28
-13
example/24_layernorm/layernorm_blockwise.cpp
example/24_layernorm/layernorm_blockwise.cpp
+4
-1
include/ck/tensor_operation/gpu/device/device_layernorm.hpp
include/ck/tensor_operation/gpu/device/device_layernorm.hpp
+14
-5
include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp
include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp
+10
-7
No files found.
example/24_layernorm/layernorm_blockwise.cpp
View file @
3e38e358
...
...
@@ -23,6 +23,7 @@ using GammaDataType = ck::half_t;
using
BetaDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
constexpr
int
Rank
=
2
;
constexpr
int
NumReduceDim
=
1
;
...
...
@@ -32,6 +33,7 @@ using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType,
BetaDataType
,
AccDataType
,
YDataType
,
PassThrough
,
Rank
,
NumReduceDim
,
256
,
// BlockSize
...
...
@@ -136,7 +138,8 @@ int main()
x_dev
.
GetDeviceBuffer
(),
gamma_dev
.
GetDeviceBuffer
(),
beta_dev
.
GetDeviceBuffer
(),
y_dev
.
GetDeviceBuffer
());
y_dev
.
GetDeviceBuffer
(),
PassThrough
{});
if
(
!
device_instance
.
IsSupportedArgument
(
argument_ptr
.
get
()))
{
...
...
include/ck/tensor_operation/gpu/device/device_layernorm.hpp
View file @
3e38e358
...
...
@@ -25,6 +25,7 @@ template <typename XDataType,
typename
BetaDataType
,
typename
AccDataType
,
typename
YDataType
,
typename
AccElementwiseOperation
,
index_t
Rank
,
index_t
NumReduceDim
,
index_t
BlockSize
,
...
...
@@ -56,8 +57,8 @@ struct DeviceLayernorm : public BaseOperator
Rank
,
NumReduceDim
,
reduce
::
Add
,
PassThrough
,
// InElementwiseOperation
PassThrough
,
// AccElementwiseOperation
PassThrough
,
// InElementwiseOperation
AccElementwiseOperation
,
// AccElementwiseOperation
InMemoryDataOperationEnum
::
Set
,
false
,
// PropagateNan
false
,
// OutputIndex
...
...
@@ -109,6 +110,7 @@ struct DeviceLayernorm : public BaseOperator
BetaDataType
,
YDataType
,
AccDataType
,
AccElementwiseOperation
,
GridDesc_M_K
,
GridDesc_K
,
BlockSize
,
...
...
@@ -128,6 +130,7 @@ struct DeviceLayernorm : public BaseOperator
BetaDataType
,
YDataType
,
AccDataType
,
AccElementwiseOperation
,
GridDesc_M_K
,
GridDesc_K
,
BlockSize
,
...
...
@@ -149,6 +152,7 @@ struct DeviceLayernorm : public BaseOperator
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
AccElementwiseOperation
acc_elementwise_op
,
AccDataType
epsilon
,
const
XDataType
*
p_x
,
const
GammaDataType
*
p_gamma
,
...
...
@@ -165,7 +169,7 @@ struct DeviceLayernorm : public BaseOperator
nullptr
,
p_y
,
nullptr
,
PassThrough
{}
,
acc_elementwise_op
,
PassThrough
{}),
epsilon_
(
epsilon
),
p_gamma_
(
p_gamma
),
...
...
@@ -211,6 +215,7 @@ struct DeviceLayernorm : public BaseOperator
BetaDataType
,
YDataType
,
AccDataType
,
AccElementwiseOperation
,
GridDesc_M_K
,
GridDesc_K
>
:
kernel_layernorm
<
GridwiseReduceLayernormGeneric
,
...
...
@@ -219,6 +224,7 @@ struct DeviceLayernorm : public BaseOperator
BetaDataType
,
YDataType
,
AccDataType
,
AccElementwiseOperation
,
GridDesc_M_K
,
GridDesc_K
>
;
...
...
@@ -237,7 +243,8 @@ struct DeviceLayernorm : public BaseOperator
arg
.
in_dev_
,
arg
.
p_gamma_
,
arg
.
p_beta_
,
arg
.
out_dev_
);
arg
.
out_dev_
,
arg
.
acc_elementwise_op_
);
return
(
avg_time
);
};
...
...
@@ -296,13 +303,15 @@ struct DeviceLayernorm : public BaseOperator
const
void
*
p_x
,
const
void
*
p_gamma
,
const
void
*
p_beta
,
void
*
p_y
)
void
*
p_y
,
AccElementwiseOperation
acc_elementwise_op
)
{
return
std
::
make_unique
<
Argument
>
(
lengths
,
xStrides
,
gammaStrides
,
betaStrides
,
reduceDims
,
acc_elementwise_op
,
epsilon
,
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
GammaDataType
*>
(
p_gamma
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp
View file @
3e38e358
...
...
@@ -20,6 +20,7 @@ template <typename GridwiseReduction,
typename
BetaDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
AccElementwiseOperation
,
typename
GridDesc_M_K
,
typename
GridDesc_K
>
__global__
void
kernel_layernorm
(
const
GridDesc_M_K
x_grid_desc_m_k
,
...
...
@@ -31,7 +32,8 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
const
XDataType
*
const
__restrict__
p_x_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
)
YDataType
*
const
__restrict__
p_y_global
,
const
AccElementwiseOperation
acc_elementwise_op
)
{
GridwiseReduction
::
Run
(
x_grid_desc_m_k
,
gamma_grid_desc_k
,
...
...
@@ -42,7 +44,8 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
p_x_global
,
p_gamma_global
,
p_beta_global
,
p_y_global
);
p_y_global
,
acc_elementwise_op
);
};
template
<
typename
XDataType
,
...
...
@@ -50,6 +53,7 @@ template <typename XDataType,
typename
BetaDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
AccElementwiseOperation
,
typename
GridDesc_M_K
,
typename
GridDesc_K
,
index_t
BlockSize
,
...
...
@@ -105,8 +109,6 @@ struct GridwiseLayernorm_mk_to_mk
false
,
// ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Add
,
AccDataType
>>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -122,7 +124,8 @@ struct GridwiseLayernorm_mk_to_mk
const
XDataType
*
const
__restrict__
p_x_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
)
YDataType
*
const
__restrict__
p_y_global
,
const
AccElementwiseOperation
acc_elementwise_op
)
{
if
constexpr
(
SweepOnce
)
{
...
...
@@ -225,7 +228,7 @@ struct GridwiseLayernorm_mk_to_mk
YDataType
,
decltype
(
thread_buffer_desc_m_k
),
GridDesc_M_K
,
PassThroughOp
,
AccElementwiseOperation
,
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
XSrcVectorDim
,
...
...
@@ -237,7 +240,7 @@ struct GridwiseLayernorm_mk_to_mk
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
),
PassThroughOp
{}
);
acc_elementwise_op
);
// Copy x from Cache
// one pass: fwd, second pass: bwd
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment