gaoqiong / composable_kernel

Commit 12673f3f, authored Sep 14, 2022 by rocking
Parent: 45220e05

Let the shape of gamma and beta be the same as x
Showing 4 changed files with 183 additions and 206 deletions.
example/27_layernorm/layernorm_blockwise.cpp (+23, -20)
include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp (+59, -85)
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp (+52, -52)
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp (+49, -49)
example/27_layernorm/layernorm_blockwise.cpp
@@ -29,7 +29,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 constexpr int Rank         = 2;
 constexpr int NumReduceDim = 1;
 
-using DeviceInstance = ck::tensor_operation::device::DeviceLayernormImpl<XDataType,
+using DeviceInstance = ck::tensor_operation::device::DeviceLayernormImpl<XDataType,
                                                                          GammaDataType,
                                                                          BetaDataType,
                                                                          AccDataType,
@@ -44,7 +45,9 @@ using DeviceInstance = ck::tensor_operation::device::DeviceLayernormImpl<XDataTy
                                                                          8,  // SliceK
                                                                          1,  // SrcVecDim (0=M, 1=K)
                                                                          8,  // SrcScalarPerVector
+                                                                         1,  // GammaVecDim (0=M, 1=K)
                                                                          8,  // GammaScalarPerVector
+                                                                         1,  // BetaVecDim (0=M, 1=K)
                                                                          8,  // BetaScalarPerVector
                                                                          8>; // OutScalarPerVector
@@ -88,8 +91,8 @@ int main()
     auto argument_ptr = device_instance.MakeArgumentPointer(
         {M, N},
         std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
-        std::vector<ck::index_t>{gamma.mDesc.GetStrides().begin(), gamma.mDesc.GetStrides().end()},
-        std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(), beta.mDesc.GetStrides().end()},
+        {0, 1},
+        {0, 1},
         std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
         {1},
         1e-4,
include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp
@@ -23,11 +23,10 @@ template <typename GridwiseReduction,
           typename YDataType,
           typename AccDataType,
           typename AccElementwiseOperation,
-          typename GridDesc_M_K,
-          typename GridDesc_K>
+          typename GridDesc_M_K>
 __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
-                                 const GridDesc_K gamma_grid_desc_k,
-                                 const GridDesc_K beta_grid_desc_k,
+                                 const GridDesc_M_K gamma_grid_desc_m_k,
+                                 const GridDesc_M_K beta_grid_desc_m_k,
                                  const GridDesc_M_K y_grid_desc_m_k,
                                  index_t num_k_block_tile_iteration,
                                  AccDataType epsilon,
@@ -38,8 +37,8 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
                                  const AccElementwiseOperation acc_elementwise_op)
 {
     GridwiseReduction::Run(x_grid_desc_m_k,
-                           gamma_grid_desc_k,
-                           beta_grid_desc_k,
+                           gamma_grid_desc_m_k,
+                           beta_grid_desc_m_k,
                            y_grid_desc_m_k,
                            num_k_block_tile_iteration,
                            epsilon,
@@ -71,7 +70,9 @@ template <typename XDataType,
           index_t KThreadSliceSize,
           index_t XYSrcVectorDim,
           index_t XSrcVectorSize,
+          index_t GammaSrcVectorDim,
           index_t GammaSrcVectorSize,
+          index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
           index_t YDstVectorSize>
 struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
@@ -84,11 +85,13 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                     NumReduceDim>
 {
-    static_assert((KThreadSliceSize % GammaSrcVectorSize == 0),
-                  "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
+    static_assert(((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
+                   (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
+                  "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
 
-    static_assert((KThreadSliceSize % BetaSrcVectorSize == 0),
-                  "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
+    static_assert(((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) ||
+                   (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
+                  "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
 
     using PassThrough = tensor_operation::element_wise::PassThrough;
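For illustration, the new compile-time constraint can be exercised on its own. The sketch below mirrors the gamma assertion with made-up slice and vector sizes (the constant values are assumptions, not taken from the commit):

// Standalone sketch of the constraint added above.
constexpr int MThreadSliceSize   = 4;
constexpr int KThreadSliceSize   = 8;
constexpr int GammaSrcVectorDim  = 1; // 0 = vectorize along M, 1 = vectorize along K
constexpr int GammaSrcVectorSize = 8; // must divide the slice size of the chosen dim

static_assert((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
                  (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0),
              "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");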
@@ -162,38 +165,7 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
         return (in_grid_desc_m_k_padded);
     };
 
-    static auto MakeAffine1dDescriptor(const std::vector<index_t>& Lengths,
-                                       const std::vector<index_t>& Strides,
-                                       int blkGroupSize,
-                                       int numBlockTileIteration)
-    {
-        const auto tupleLengths = make_tuple_from_array(Lengths, Number<NumReduceDim>{});
-        const auto tupleStrides = make_tuple_from_array(Strides, Number<NumReduceDim>{});
-
-        auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
-
-        auto grid_desc_k = transform_tensor_descriptor(
-            desc,
-            make_tuple(make_merge_transform(tupleLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, NumReduceDim, 1>::type{}),
-            make_tuple(Sequence<0>{}));
-
-        const auto reduceTotalLength = grid_desc_k.GetLength(Number<0>{});
-        const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
-        const auto Pad_K             = reduceSizePerBlock * blkGroupSize - reduceTotalLength;
-
-        auto grid_desc_k_padded = transform_tensor_descriptor(
-            grid_desc_k,
-            make_tuple(make_right_pad_transform(reduceTotalLength, Pad_K)),
-            make_tuple(Sequence<0>{}),
-            make_tuple(Sequence<0>{}));
-
-        return (grid_desc_k_padded);
-    };
-
     using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
-    using GridDesc_K   = decltype(MakeAffine1dDescriptor({1}, {1}, 1, 1));
 
     using GridwiseReduceLayernormGeneric = GridwiseLayernormWelfordVariance_mk_to_mk<XDataType,
@@ -203,7 +175,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                AccDataType,
                                                AccElementwiseOperation,
                                                GridDesc_M_K,
-                                               GridDesc_K,
                                                BlockSize,
                                                MThreadClusterSize,
                                                KThreadClusterSize,
@@ -211,12 +182,13 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                KThreadSliceSize,
                                                XYSrcVectorDim,
                                                XSrcVectorSize,
+                                               GammaSrcVectorDim,
                                                GammaSrcVectorSize,
+                                               BetaSrcVectorDim,
                                                BetaSrcVectorSize,
                                                XYSrcVectorDim,
                                                YDstVectorSize,
                                                false>;
 
     using GridwiseReduceLayernormSweepOnce = GridwiseLayernormWelfordVariance_mk_to_mk<XDataType,
                                                                                        GammaDataType,
@@ -225,7 +197,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                AccDataType,
                                                AccElementwiseOperation,
                                                GridDesc_M_K,
-                                               GridDesc_K,
                                                BlockSize,
                                                MThreadClusterSize,
                                                KThreadClusterSize,
@@ -233,7 +204,9 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                KThreadSliceSize,
                                                XYSrcVectorDim,
                                                XSrcVectorSize,
+                                               GammaSrcVectorDim,
                                                GammaSrcVectorSize,
+                                               BetaSrcVectorDim,
                                                BetaSrcVectorSize,
                                                XYSrcVectorDim,
                                                YDstVectorSize,
@@ -258,13 +231,13 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
               p_gamma_(p_gamma),
               p_beta_(p_beta),
               p_y_(p_y),
-              gammaStrides_(gammaStrides),
-              betaStrides_(betaStrides),
               acc_elementwise_op_(acc_elementwise_op)
         {
             Lengths_      = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
             xStrides_     = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
             yStrides_     = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
+            gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
+            betaStrides_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
 
             long_index_t invariant_total_length;
             long_index_t reduce_total_length;
@@ -277,13 +250,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
             gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                         M_BlockTileSize * blkGroupSize_;
-
-            reduceLengths_.resize(NumReduceDim);
-
-            for(int i = 0; i < NumReduceDim; ++i)
-            {
-                reduceLengths_[i] = lengths[reduceDims[i]];
-            }
         }
 
         AccDataType epsilon_;
@@ -295,7 +261,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
         std::vector<index_t> Lengths_;
         std::vector<index_t> xStrides_;
-        std::vector<index_t> reduceLengths_;
         std::vector<index_t> gammaStrides_;
         std::vector<index_t> betaStrides_;
         std::vector<index_t> yStrides_;
@@ -313,14 +278,10 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
         {
             const auto x_grid_desc_m_k = MakeSrc2dDescriptor(
                 arg.Lengths_, arg.xStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
-            const auto gamma_grid_desc_k = MakeAffine1dDescriptor(
-                arg.reduceLengths_, arg.gammaStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
-            const auto beta_grid_desc_k = MakeAffine1dDescriptor(
-                arg.reduceLengths_, arg.betaStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
+            const auto gamma_grid_desc_m_k = MakeSrc2dDescriptor(
+                arg.Lengths_, arg.gammaStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
+            const auto beta_grid_desc_m_k = MakeSrc2dDescriptor(
+                arg.Lengths_, arg.betaStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
             const auto y_grid_desc_m_k = MakeSrc2dDescriptor(
                 arg.Lengths_, arg.yStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
@@ -334,8 +295,7 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                        YDataType,
                                        AccDataType,
                                        AccElementwiseOperation,
-                                       GridDesc_M_K,
-                                       GridDesc_K>
+                                       GridDesc_M_K>
                     : kernel_layernorm<GridwiseReduceLayernormGeneric,
                                        XDataType,
                                        GammaDataType,
@@ -343,8 +303,7 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                        YDataType,
                                        AccDataType,
                                        AccElementwiseOperation,
-                                       GridDesc_M_K,
-                                       GridDesc_K>;
+                                       GridDesc_M_K>;
 
             float avg_time = 0;
             avg_time += launch_and_time_kernel(stream_config,
@@ -353,8 +312,8 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                dim3(BlockSize),
                                                0,
                                                x_grid_desc_m_k,
-                                               gamma_grid_desc_k,
-                                               beta_grid_desc_k,
+                                               gamma_grid_desc_m_k,
+                                               beta_grid_desc_m_k,
                                                y_grid_desc_m_k,
                                                arg.numBlockTileIteration_,
                                                arg.epsilon_,
@@ -409,26 +368,41 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                 return false;
         }
 
-        if(p_arg_->gammaStrides_.size() != NumReduceDim ||
-           p_arg_->betaStrides_.size() != NumReduceDim)
-            return false;
-
-        auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
-            bool ret = true;
-
-            if(!isLastDimensionCoalesced)
-                ret = scalarPerVector == 1;
-            else
-                ret = KThreadSliceSize % scalarPerVector == 0;
-
-            return ret;
-        };
-
-        if(!IsScalarPerVectorValid(p_arg_->gammaStrides_.back() == 1, GammaSrcVectorSize))
-            return false;
-
-        if(!IsScalarPerVectorValid(p_arg_->betaStrides_.back() == 1, BetaSrcVectorSize))
-            return false;
+        // if fastest dim is not reduced
+        if constexpr(GammaSrcVectorDim == 0)
+        {
+            if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1)
+                return (false);
+
+            if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
+                return (false);
+        }
+        else // if fastest dim is reduced
+        {
+            if(p_arg_->gammaStrides_[Rank - 1] != 1)
+                return (false);
+
+            if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
+                return (false);
+        }
+
+        // if fastest dim is not reduced
+        if constexpr(BetaSrcVectorDim == 0)
+        {
+            if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
+                return (false);
+
+            if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
+                return (false);
+        }
+        else // if fastest dim is reduced
+        {
+            if(p_arg_->betaStrides_[Rank - 1] != 1)
+                return (false);
+
+            if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0)
+                return (false);
+        }
 
         return true;
     };
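A rough standalone rendering of what the updated runtime check enforces for gamma when it is vectorized along the reduced (last) dimension; it assumes Rank = 2 and an [M, N] problem, and the helper name and signature are hypothetical:

#include <vector>

// Returns true when a rank-2 gamma stride vector is compatible with vector loads
// of width gammaSrcVectorSize along the reduced (last) dimension.
bool gamma_layout_supported(const std::vector<int>& gammaStrides, int N, int gammaSrcVectorSize)
{
    if(gammaStrides.back() != 1)        // reduced dim must be contiguous for vector loads
        return false;
    return N % gammaSrcVectorSize == 0; // vector width must divide the reduce length
}

Both the broadcast layout {0, 1} used by the example and a dense [M, N] layout {N, 1} satisfy this check.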
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp
@@ -22,7 +22,6 @@ template <typename XDataType,
           typename AccDataType,
           typename AccElementwiseOperation,
           typename GridDesc_M_K,
-          typename GridDesc_K,
           index_t BlockSize,
           index_t MThreadClusterSize,
           index_t KThreadClusterSize,
@@ -30,7 +29,9 @@ template <typename XDataType,
           index_t KThreadSliceSize,
           index_t XSrcVectorDim,
           index_t XSrcVectorSize,
+          index_t GammaSrcVectorDim,
           index_t GammaSrcVectorSize,
+          index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
           index_t YDstVectorDim,
           index_t YDstVectorSize,
@@ -83,8 +84,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
 
     __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
-                               const GridDesc_K& gamma_grid_desc_k,
-                               const GridDesc_K& beta_grid_desc_k,
+                               const GridDesc_M_K& gamma_grid_desc_m_k,
+                               const GridDesc_M_K& beta_grid_desc_m_k,
                                const GridDesc_M_K& y_grid_desc_m_k,
                                index_t num_k_block_tile_iteration,
                                AccDataType epsilon,
@@ -111,11 +112,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             x_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> gamma_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true>& beta_thread_buf =
-            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
+            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>&
+            beta_thread_buf = gamma_thread_buf;
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             y_thread_buf;
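Note that beta_thread_buf remains a reference alias of gamma_thread_buf, so one VGPR-resident buffer (now sized MThreadSliceSize * KThreadSliceSize) is reused: gamma is loaded and applied first, then beta is loaded into the same storage. A tiny host-side analogue of that aliasing pattern (illustrative only, not code from the kernel):

#include <array>
#include <iostream>

int main()
{
    std::array<float, 8> gamma_thread_buf{};                   // holds gamma, then is consumed
    std::array<float, 8>& beta_thread_buf = gamma_thread_buf;  // alias: no second buffer

    gamma_thread_buf.fill(2.0f); // "load gamma"; the scale pass would read it here
    beta_thread_buf.fill(0.5f);  // later "load beta" overwrites the same storage
    std::cout << gamma_thread_buf[0] << '\n'; // prints 0.5: both names share storage
}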
@@ -127,7 +131,7 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_square_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_value_buf =
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_thread_buf =
             mean_square_thread_buf;
 
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -145,11 +149,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         const auto thread_k_cluster_id = thread_cluster_idx[I1];
 
         using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
-        using ThreadBufferLengths_K   = Sequence<KThreadSliceSize>;
         constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
             make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
-        constexpr auto thread_buffer_desc_k =
-            make_naive_tensor_descriptor_packed(make_tuple(Number<KThreadSliceSize>{}));
 
         auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                   AccDataType,
@@ -169,27 +170,34 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         auto threadwise_gamma_load = ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                                                       AccDataType,
-                                                                      GridDesc_K,
-                                                                      decltype(thread_buffer_desc_k),
-                                                                      ThreadBufferLengths_K,
-                                                                      Sequence<0>,
-                                                                      0,
+                                                                      GridDesc_M_K,
+                                                                      decltype(thread_buffer_desc_m_k),
+                                                                      ThreadBufferLengths_M_K,
+                                                                      ThreadBufferDimAccessOrder,
+                                                                      GammaSrcVectorDim,
                                                                       GammaSrcVectorSize,
                                                                       1,
                                                                       true>(
-            gamma_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
+            gamma_grid_desc_m_k,
+            make_multi_index(block_global_id * M_BlockTileSize +
+                                 thread_m_cluster_id * MThreadSliceSize,
+                             thread_k_cluster_id * KThreadSliceSize));
 
-        auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
+        auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
                                                                      AccDataType,
-                                                                     GridDesc_K,
-                                                                     decltype(thread_buffer_desc_k),
-                                                                     ThreadBufferLengths_K,
-                                                                     Sequence<0>,
-                                                                     0,
+                                                                     GridDesc_M_K,
+                                                                     decltype(thread_buffer_desc_m_k),
+                                                                     ThreadBufferLengths_M_K,
+                                                                     ThreadBufferDimAccessOrder,
+                                                                     BetaSrcVectorDim,
                                                                      BetaSrcVectorSize,
                                                                      1,
                                                                      true>(
-            beta_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
+            beta_grid_desc_m_k,
+            make_multi_index(block_global_id * M_BlockTileSize +
+                                 thread_m_cluster_id * MThreadSliceSize,
+                             thread_k_cluster_id * KThreadSliceSize));
 
         auto threadwise_y_store = ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
@@ -212,9 +220,6 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         // Copy x from Cache
         // one pass: fwd, second pass: bwd
-        constexpr auto thread_copy_fwd_step_k = make_multi_index(SweepOnce ? 0 : K_BlockTileSize);
-        constexpr auto thread_copy_bwd_step_k = make_multi_index(SweepOnce ? 0 : -K_BlockTileSize);
-
         constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
         constexpr auto thread_copy_bwd_step_m_k =
@@ -224,10 +229,10 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
             p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
         const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize());
+            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
         const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_beta_global, beta_grid_desc_k.GetElementSpaceSize());
+            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
 
         // E(x), E[x^2], var(x)
         int reduce_length = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];
@@ -271,17 +276,16 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
             mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
 
             // var(x) = E[x^2] - E[x]^2
-            var_value_buf(I) =
+            var_thread_buf(I) =
                 mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
         });
 
         // y = (x - E[x]) / sqrt(var[x] + epsilon)
         auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
-        auto thread_copy_tail_k   = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k;
 
         threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k);
-        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k);
+        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
+        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
         threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
 
         reducedTiles = 0;
@@ -296,10 +300,10 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
                                   x_thread_buf);
             }
 
-            threadwise_gamma_load.Run(gamma_grid_desc_k,
+            threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                       gamma_global_val_buf,
-                                      thread_buffer_desc_k,
-                                      make_tuple(I0),
+                                      thread_buffer_desc_m_k,
+                                      make_tuple(I0, I0),
                                       gamma_thread_buf);
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
@@ -307,23 +311,21 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
                     constexpr auto offset_m_k =
                         thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
 
                     // normalize
                     y_thread_buf(Number<offset_m_k>{}) =
                         (x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
-                        sqrt(var_value_buf(iM) + epsilon);
+                        sqrt(var_thread_buf(iM) + epsilon);
 
                     // gamma
                     y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{});
+                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
                 });
             });
 
-            threadwise_beta_load.Run(beta_grid_desc_k,
+            threadwise_beta_load.Run(beta_grid_desc_m_k,
                                      beta_global_val_buf,
-                                     thread_buffer_desc_k,
-                                     make_tuple(I0),
+                                     thread_buffer_desc_m_k,
+                                     make_tuple(I0, I0),
                                      beta_thread_buf);
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
@@ -331,11 +333,9 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
                     constexpr auto offset_m_k =
                         thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
 
                     // beta
                     y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_k>{});
+                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
                 });
             });
@@ -346,8 +346,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
                                   y_global_val_buf);
 
             threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
-            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
+            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
+            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
             threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
 
             ++reducedTiles;
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp
@@ -19,7 +19,6 @@ template <typename XDataType,
           typename AccDataType,
           typename AccElementwiseOperation,
           typename GridDesc_M_K,
-          typename GridDesc_K,
           index_t BlockSize,
           index_t MThreadClusterSize,
           index_t KThreadClusterSize,
@@ -27,7 +26,9 @@ template <typename XDataType,
           index_t KThreadSliceSize,
           index_t XSrcVectorDim,
           index_t XSrcVectorSize,
+          index_t GammaSrcVectorDim,
           index_t GammaSrcVectorSize,
+          index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
           index_t YDstVectorDim,
           index_t YDstVectorSize,
@@ -94,8 +95,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
     }
 
     __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
-                               const GridDesc_K& gamma_grid_desc_k,
-                               const GridDesc_K& beta_grid_desc_k,
+                               const GridDesc_M_K& gamma_grid_desc_m_k,
+                               const GridDesc_M_K& beta_grid_desc_m_k,
                                const GridDesc_M_K& y_grid_desc_m_k,
                                index_t num_k_block_tile_iteration,
                                AccDataType epsilon,
@@ -116,11 +117,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             x_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> gamma_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true>& beta_thread_buf =
-            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
+            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>&
+            beta_thread_buf = gamma_thread_buf;
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             y_thread_buf;
@@ -137,11 +141,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         const auto thread_k_cluster_id = thread_cluster_idx[I1];
 
         using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
-        using ThreadBufferLengths_K   = Sequence<KThreadSliceSize>;
         constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
             make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
-        constexpr auto thread_buffer_desc_k =
-            make_naive_tensor_descriptor_packed(make_tuple(Number<KThreadSliceSize>{}));
 
         auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                   AccDataType,
@@ -161,27 +162,34 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         auto threadwise_gamma_load = ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                                                       AccDataType,
-                                                                      GridDesc_K,
-                                                                      decltype(thread_buffer_desc_k),
-                                                                      ThreadBufferLengths_K,
-                                                                      Sequence<0>,
-                                                                      0,
+                                                                      GridDesc_M_K,
+                                                                      decltype(thread_buffer_desc_m_k),
+                                                                      ThreadBufferLengths_M_K,
+                                                                      ThreadBufferDimAccessOrder,
+                                                                      GammaSrcVectorDim,
                                                                       GammaSrcVectorSize,
                                                                       1,
                                                                       true>(
-            gamma_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
+            gamma_grid_desc_m_k,
+            make_multi_index(block_global_id * M_BlockTileSize +
+                                 thread_m_cluster_id * MThreadSliceSize,
+                             thread_k_cluster_id * KThreadSliceSize));
 
-        auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
+        auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
                                                                      AccDataType,
-                                                                     GridDesc_K,
-                                                                     decltype(thread_buffer_desc_k),
-                                                                     ThreadBufferLengths_K,
-                                                                     Sequence<0>,
-                                                                     0,
+                                                                     GridDesc_M_K,
+                                                                     decltype(thread_buffer_desc_m_k),
+                                                                     ThreadBufferLengths_M_K,
+                                                                     ThreadBufferDimAccessOrder,
+                                                                     BetaSrcVectorDim,
                                                                      BetaSrcVectorSize,
                                                                      1,
                                                                      true>(
-            beta_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
+            beta_grid_desc_m_k,
+            make_multi_index(block_global_id * M_BlockTileSize +
+                                 thread_m_cluster_id * MThreadSliceSize,
+                             thread_k_cluster_id * KThreadSliceSize));
 
         auto threadwise_y_store = ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
@@ -204,9 +212,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         // Copy x from Cache
         // one pass: fwd, second pass: bwd
-        constexpr auto thread_copy_fwd_step_k = make_multi_index(SweepOnce ? 0 : K_BlockTileSize);
-        constexpr auto thread_copy_bwd_step_k = make_multi_index(SweepOnce ? 0 : -K_BlockTileSize);
-
        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
        constexpr auto thread_copy_bwd_step_m_k =
@@ -216,10 +221,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
             p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
         const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize());
+            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
         const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_beta_global, beta_grid_desc_k.GetElementSpaceSize());
+            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
 
         auto threadwise_welford       = ThreadwiseWelford();
         threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);
@@ -250,11 +255,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         });
 
         auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
-        auto thread_copy_tail_k   = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k;
 
         threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k);
-        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k);
+        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
+        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
         threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
 
         for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
@@ -268,10 +272,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
                                   x_thread_buf);
             }
 
-            threadwise_gamma_load.Run(gamma_grid_desc_k,
+            threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                       gamma_global_val_buf,
-                                      thread_buffer_desc_k,
-                                      make_tuple(I0),
+                                      thread_buffer_desc_m_k,
+                                      make_tuple(I0, I0),
                                       gamma_thread_buf);
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
@@ -279,8 +283,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
                     constexpr auto offset_m_k =
                         thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
 
                     // normalize
                     y_thread_buf(Number<offset_m_k>{}) =
                         (x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
@@ -288,14 +290,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
                     // gamma
                     y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{});
+                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
                 });
             });
 
-            threadwise_beta_load.Run(beta_grid_desc_k,
+            threadwise_beta_load.Run(beta_grid_desc_m_k,
                                      beta_global_val_buf,
-                                     thread_buffer_desc_k,
-                                     make_tuple(I0),
+                                     thread_buffer_desc_m_k,
+                                     make_tuple(I0, I0),
                                      beta_thread_buf);
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
@@ -303,11 +305,9 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
                     constexpr auto offset_m_k =
                         thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
 
                     // beta
                     y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_k>{});
+                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
                 });
             });
@@ -318,8 +318,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
                                   y_global_val_buf);
 
             threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
-            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
+            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
+            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
             threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
         }
     }