Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0a2a25e3
Commit
0a2a25e3
authored
Jul 04, 2022
by
rocking
Browse files
Support sweep once mode if we can put k dimension data inside one block
parent
eb6405ee
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
8 deletions
+37
-8
include/ck/tensor_operation/gpu/device/device_layernorm.hpp
include/ck/tensor_operation/gpu/device/device_layernorm.hpp
+37
-8
No files found.
include/ck/tensor_operation/gpu/device/device_layernorm.hpp
View file @
0a2a25e3
...
...
@@ -97,6 +97,26 @@ struct DeviceLayernorm : public BaseOperator
YDstVectorSize
,
false
>
;
using
GridwiseReduceLayernormSweepOnce
=
GridwiseLayernorm_mk_to_mk
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
AccDataType
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
YDstVectorSize
,
true
>
;
struct
Argument
:
public
Reduction
::
Argument
{
Argument
(
const
std
::
vector
<
index_t
>
lengths
,
...
...
@@ -151,16 +171,25 @@ struct DeviceLayernorm : public BaseOperator
const
auto
y_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
arg
.
inLengths_
,
arg
.
inStrides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
const
auto
kernel_main
=
kernel_layernorm
<
GridwiseReduceLayernormGeneric
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
AccDataType
,
GridDesc_M_K
>
;
bool
sweep_once
=
x_grid_desc_m_k
.
GetLength
(
Number
<
1
>
{})
<=
KThreadClusterSize
*
KThreadSliceSize
;
const
auto
kernel_main
=
sweep_once
?
kernel_layernorm
<
GridwiseReduceLayernormSweepOnce
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
AccDataType
,
GridDesc_M_K
>
:
kernel_layernorm
<
GridwiseReduceLayernormGeneric
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
AccDataType
,
GridDesc_M_K
>
;
float
avg_time
=
0
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_main
,
dim3
(
arg
.
gridSize
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment