Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f591ad27
Commit
f591ad27
authored
Jul 01, 2022
by
rocking
Browse files
1. Separate gamma aand beta from affine
2. Check if argument is valid
parent
8e2d0ae7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
79 additions
and
25 deletions
+79
-25
example/24_layernorm/layernorm_blockwise.cpp
example/24_layernorm/layernorm_blockwise.cpp
+6
-3
include/ck/tensor_operation/gpu/device/device_layernorm.hpp
include/ck/tensor_operation/gpu/device/device_layernorm.hpp
+65
-16
include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp
include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp
+8
-6
No files found.
example/24_layernorm/layernorm_blockwise.cpp
View file @
f591ad27
...
@@ -40,9 +40,11 @@ using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType,
...
@@ -40,9 +40,11 @@ using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType,
1
,
// SliceM
1
,
// SliceM
8
,
// SliceK
8
,
// SliceK
1
,
// SrcVecDim (0=M, 1=K)
1
,
// SrcVecDim (0=M, 1=K)
1
,
// SrcScalarPerVector
8
,
// SrcScalarPerVector
1
,
// AffineVecDim (0=M, 1=K)
1
,
// GammaVecDim (0=M, 1=K)
1
,
// AffineScalarPerVector
8
,
// GammaScalarPerVector
1
,
// BetaVecDim (0=M, 1=K)
8
,
// BetaScalarPerVector
1
>
;
// OutScalarPerVector
1
>
;
// OutScalarPerVector
template
<
typename
XDataType
,
template
<
typename
XDataType
,
...
@@ -129,6 +131,7 @@ int main()
...
@@ -129,6 +131,7 @@ int main()
auto
argument_ptr
=
device_instance
.
MakeArgumentPointer
({
M
,
N
},
auto
argument_ptr
=
device_instance
.
MakeArgumentPointer
({
M
,
N
},
{
Stride
,
1
},
{
Stride
,
1
},
{
0
,
1
},
{
0
,
1
},
{
0
,
1
},
{
1
},
{
1
},
1e-4
,
1e-4
,
x_dev
.
GetDeviceBuffer
(),
x_dev
.
GetDeviceBuffer
(),
...
...
include/ck/tensor_operation/gpu/device/device_layernorm.hpp
View file @
f591ad27
...
@@ -34,15 +34,22 @@ template <typename XDataType,
...
@@ -34,15 +34,22 @@ template <typename XDataType,
index_t
KThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
InSrcVectorDim
,
index_t
InSrcVectorDim
,
index_t
InSrcVectorSize
,
index_t
InSrcVectorSize
,
index_t
AffineSrcVectorDim
,
index_t
GammaSrcVectorDim
,
index_t
AffineSrcVectorSize
,
index_t
GammaSrcVectorSize
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
OutDstVectorSize
>
index_t
OutDstVectorSize
>
struct
DeviceLayernorm
:
public
BaseOperator
struct
DeviceLayernorm
:
public
BaseOperator
{
{
static_assert
(
static_assert
(
((
AffineSrcVectorDim
==
0
&&
MThreadSliceSize
%
AffineSrcVectorSize
==
0
)
||
((
GammaSrcVectorDim
==
0
&&
MThreadSliceSize
%
GammaSrcVectorSize
==
0
)
||
(
AffineSrcVectorDim
==
1
&&
KThreadSliceSize
%
AffineSrcVectorSize
==
0
)),
(
GammaSrcVectorDim
==
1
&&
KThreadSliceSize
%
GammaSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or affine vector sizes configuration, please check!"
);
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"
);
static_assert
(
((
BetaSrcVectorDim
==
0
&&
MThreadSliceSize
%
BetaSrcVectorSize
==
0
)
||
(
BetaSrcVectorDim
==
1
&&
KThreadSliceSize
%
BetaSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!"
);
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
...
@@ -83,8 +90,10 @@ struct DeviceLayernorm : public BaseOperator
...
@@ -83,8 +90,10 @@ struct DeviceLayernorm : public BaseOperator
KThreadSliceSize
,
KThreadSliceSize
,
InSrcVectorDim
,
InSrcVectorDim
,
InSrcVectorSize
,
InSrcVectorSize
,
AffineSrcVectorDim
,
GammaSrcVectorDim
,
AffineSrcVectorSize
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
OutDstVectorSize
,
OutDstVectorSize
,
false
>
;
false
>
;
...
@@ -92,7 +101,8 @@ struct DeviceLayernorm : public BaseOperator
...
@@ -92,7 +101,8 @@ struct DeviceLayernorm : public BaseOperator
{
{
Argument
(
const
std
::
vector
<
index_t
>
inLengths
,
Argument
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
inStrides
,
const
std
::
vector
<
index_t
>
inStrides
,
const
std
::
vector
<
index_t
>
affineStrides
,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
const
std
::
vector
<
index_t
>
reduceDims
,
AccDataType
epsilon
,
AccDataType
epsilon
,
const
XDataType
*
p_x
,
const
XDataType
*
p_x
,
...
@@ -116,14 +126,16 @@ struct DeviceLayernorm : public BaseOperator
...
@@ -116,14 +126,16 @@ struct DeviceLayernorm : public BaseOperator
p_gamma_
(
p_gamma
),
p_gamma_
(
p_gamma
),
p_beta_
(
p_beta
)
p_beta_
(
p_beta
)
{
{
affineStrides_
=
gammaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
gammaStrides
,
reduceDims
);
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
affineStrides
,
reduceDims
);
betaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
betaStrides
,
reduceDims
);
}
}
AccDataType
epsilon_
;
AccDataType
epsilon_
;
const
GammaDataType
*
p_gamma_
;
const
GammaDataType
*
p_gamma_
;
const
BetaDataType
*
p_beta_
;
const
BetaDataType
*
p_beta_
;
std
::
vector
<
index_t
>
affineStrides_
;
std
::
vector
<
index_t
>
gammaStrides_
;
std
::
vector
<
index_t
>
betaStrides_
;
};
};
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
...
@@ -133,9 +145,9 @@ struct DeviceLayernorm : public BaseOperator
...
@@ -133,9 +145,9 @@ struct DeviceLayernorm : public BaseOperator
const
auto
in_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
const
auto
in_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
arg
.
inLengths_
,
arg
.
inStrides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
arg
.
inLengths_
,
arg
.
inStrides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
const
auto
gamma_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
const
auto
gamma_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
arg
.
inLengths_
,
arg
.
affine
Strides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
arg
.
inLengths_
,
arg
.
gamma
Strides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
const
auto
beta_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
const
auto
beta_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
arg
.
inLengths_
,
arg
.
affine
Strides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
arg
.
inLengths_
,
arg
.
beta
Strides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
const
auto
out_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
const
auto
out_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
arg
.
inLengths_
,
arg
.
inStrides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
arg
.
inLengths_
,
arg
.
inStrides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
...
@@ -189,14 +201,50 @@ struct DeviceLayernorm : public BaseOperator
...
@@ -189,14 +201,50 @@ struct DeviceLayernorm : public BaseOperator
{
{
return
false
;
return
false
;
}
}
// TODO - Check AffineSrcVectorDim and AffineSrcVectorSize
// if fastest dim is not reduced
if
constexpr
(
GammaSrcVectorDim
==
0
)
{
if
(
p_arg_
->
gammaStrides_
[
Reduction
::
NumInvariantDim
-
1
]
!=
1
)
return
(
false
);
if
(
p_arg_
->
invariant_lowest_length
%
GammaSrcVectorSize
!=
0
)
return
(
false
);
}
else
// if fastest dim is reduced
{
if
(
p_arg_
->
gammaStrides_
[
Rank
-
1
]
!=
1
)
return
(
false
);
if
(
p_arg_
->
reduce_lowest_length
%
GammaSrcVectorSize
!=
0
)
return
(
false
);
}
// if fastest dim is not reduced
if
constexpr
(
BetaSrcVectorDim
==
0
)
{
if
(
p_arg_
->
betaStrides_
[
Reduction
::
NumInvariantDim
-
1
]
!=
1
)
return
(
false
);
if
(
p_arg_
->
invariant_lowest_length
%
BetaSrcVectorSize
!=
0
)
return
(
false
);
}
else
// if fastest dim is reduced
{
if
(
p_arg_
->
betaStrides_
[
Rank
-
1
]
!=
1
)
return
(
false
);
if
(
p_arg_
->
reduce_lowest_length
%
BetaSrcVectorSize
!=
0
)
return
(
false
);
}
return
true
;
return
true
;
};
};
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
inLengths
,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
inStrides
,
const
std
::
vector
<
index_t
>
inStrides
,
const
std
::
vector
<
index_t
>
affineStrides
,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
int
>
reduceDims
,
const
std
::
vector
<
int
>
reduceDims
,
AccDataType
epsilon
,
AccDataType
epsilon
,
const
void
*
p_x
,
const
void
*
p_x
,
...
@@ -206,7 +254,8 @@ struct DeviceLayernorm : public BaseOperator
...
@@ -206,7 +254,8 @@ struct DeviceLayernorm : public BaseOperator
{
{
return
std
::
make_unique
<
Argument
>
(
inLengths
,
return
std
::
make_unique
<
Argument
>
(
inLengths
,
inStrides
,
inStrides
,
affineStrides
,
gammaStrides
,
betaStrides
,
reduceDims
,
reduceDims
,
epsilon
,
epsilon
,
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
XDataType
*>
(
p_x
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp
View file @
f591ad27
...
@@ -59,8 +59,10 @@ template <typename XDataType,
...
@@ -59,8 +59,10 @@ template <typename XDataType,
index_t
KThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
InSrcVectorDim
,
index_t
InSrcVectorDim
,
index_t
InSrcVectorSize
,
index_t
InSrcVectorSize
,
index_t
AffineSrcVectorDim
,
index_t
GammaSrcVectorDim
,
index_t
AffineSrcVectorSize
,
index_t
GammaSrcVectorSize
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
OutDstVectorSize
,
index_t
OutDstVectorSize
,
bool
SweepOnce
>
bool
SweepOnce
>
struct
GridwiseLayernorm_mk_to_mk
struct
GridwiseLayernorm_mk_to_mk
...
@@ -205,8 +207,8 @@ struct GridwiseLayernorm_mk_to_mk
...
@@ -205,8 +207,8 @@ struct GridwiseLayernorm_mk_to_mk
decltype
(
thread_buffer_desc
),
decltype
(
thread_buffer_desc
),
ThreadBufferLengths
,
ThreadBufferLengths
,
ThreadBufferDimAccessOrder
,
ThreadBufferDimAccessOrder
,
Affine
SrcVectorDim
,
Gamma
SrcVectorDim
,
Affine
SrcVectorSize
,
Gamma
SrcVectorSize
,
1
,
1
,
true
>
(
true
>
(
gamma_grid_desc_m_k
,
gamma_grid_desc_m_k
,
...
@@ -220,8 +222,8 @@ struct GridwiseLayernorm_mk_to_mk
...
@@ -220,8 +222,8 @@ struct GridwiseLayernorm_mk_to_mk
decltype
(
thread_buffer_desc
),
decltype
(
thread_buffer_desc
),
ThreadBufferLengths
,
ThreadBufferLengths
,
ThreadBufferDimAccessOrder
,
ThreadBufferDimAccessOrder
,
Affine
SrcVectorDim
,
Beta
SrcVectorDim
,
Affine
SrcVectorSize
,
Beta
SrcVectorSize
,
1
,
1
,
true
>
(
true
>
(
beta_grid_desc_m_k
,
beta_grid_desc_m_k
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment