gaoqiong / composable_kernel · Commits

Commit 12673f3f, authored Sep 14, 2022 by rocking (parent 45220e05)

    Let shape of gamma and beta can be same as x
Showing 4 changed files with 183 additions and 206 deletions:

    example/27_layernorm/layernorm_blockwise.cpp                                    +23  -20
    include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp                +59  -85
    include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp      +52  -52
    include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp    +49  -49
example/27_layernorm/layernorm_blockwise.cpp

@@ -29,24 +29,27 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 constexpr int Rank         = 2;
 constexpr int NumReduceDim = 1;

 using DeviceInstance = ck::tensor_operation::device::DeviceLayernormImpl<XDataType,
                                                                          GammaDataType,
                                                                          BetaDataType,
                                                                          AccDataType,
                                                                          YDataType,
                                                                          PassThrough,
                                                                          Rank,
                                                                          NumReduceDim,
                                                                          256, // BlockSize
                                                                          8,   // ClusterM
                                                                          32,  // ClusterK
                                                                          1,   // SliceM
                                                                          8,   // SliceK
                                                                          1,   // SrcVecDim (0=M, 1=K)
                                                                          8,   // SrcScalarPerVector
+                                                                         1,   // GammaVecDim (0=M, 1=K)
                                                                          8,   // GammaScalarPerVector
+                                                                         1,   // BetaVecDim (0=M, 1=K)
                                                                          8,   // BetaScalarPerVector
                                                                          8>;  // OutScalarPerVector

 int main()
 {
...
@@ -88,8 +91,8 @@ int main()
     auto argument_ptr = device_instance.MakeArgumentPointer(
         {M, N},
         std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
-        {0, 1},
-        {0, 1},
+        std::vector<ck::index_t>{gamma.mDesc.GetStrides().begin(), gamma.mDesc.GetStrides().end()},
+        std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(), beta.mDesc.GetStrides().end()},
         std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
         {1},
         1e-4,
...
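The host-side change above stops hard-coding {0, 1} for the gamma and beta strides and instead forwards the strides of the gamma/beta host tensors, which is what lets them carry the same 2-D shape as x. A minimal sketch of the two stride layouts a caller might pass, under the assumption of row-major [M, N] tensors with the reduction over N; the concrete values and the zero-stride broadcast option are illustrative assumptions, not part of the commit:

    #include <vector>
    // ck::index_t comes from the composable_kernel headers already included by the example.

    // Hypothetical sizes; only the stride layout matters here.
    ck::index_t M = 1024;
    ck::index_t N = 1024;

    // x and y are row-major [M, N]; layernorm reduces over N (the fastest dimension).
    std::vector<ck::index_t> xStrides{N, 1};
    std::vector<ck::index_t> yStrides{N, 1};

    // Option A (what the updated example does): gamma and beta are described with the
    // same [M, N] shape and strides as x.
    std::vector<ck::index_t> gammaStrides{N, 1};
    std::vector<ck::index_t> betaStrides{N, 1};

    // Option B (the previously hard-coded behaviour, assumed still expressible):
    // a single [N] row of gamma/beta broadcast over M by giving the M axis a stride of 0.
    // std::vector<ck::index_t> gammaStrides{0, 1};
    // std::vector<ck::index_t> betaStrides{0, 1};

    // These vectors are then passed to device_instance.MakeArgumentPointer() in the
    // gamma/beta stride positions shown in the hunk above.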
include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp

@@ -23,11 +23,10 @@ template <typename GridwiseReduction,
           typename YDataType,
           typename AccDataType,
           typename AccElementwiseOperation,
-          typename GridDesc_M_K,
-          typename GridDesc_K>
+          typename GridDesc_M_K>
 __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
-                                 const GridDesc_K gamma_grid_desc_k,
-                                 const GridDesc_K beta_grid_desc_k,
+                                 const GridDesc_M_K gamma_grid_desc_m_k,
+                                 const GridDesc_M_K beta_grid_desc_m_k,
                                  const GridDesc_M_K y_grid_desc_m_k,
                                  index_t num_k_block_tile_iteration,
                                  AccDataType epsilon,
...
@@ -38,8 +37,8 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
                                  const AccElementwiseOperation acc_elementwise_op)
 {
     GridwiseReduction::Run(x_grid_desc_m_k,
-                           gamma_grid_desc_k,
-                           beta_grid_desc_k,
+                           gamma_grid_desc_m_k,
+                           beta_grid_desc_m_k,
                            y_grid_desc_m_k,
                            num_k_block_tile_iteration,
                            epsilon,
...
@@ -71,7 +70,9 @@ template <typename XDataType,
           index_t KThreadSliceSize,
           index_t XYSrcVectorDim,
           index_t XSrcVectorSize,
+          index_t GammaSrcVectorDim,
           index_t GammaSrcVectorSize,
+          index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
           index_t YDstVectorSize>
 struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
...
@@ -84,11 +85,13 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                     NumReduceDim>
 {
-    static_assert((KThreadSliceSize % GammaSrcVectorSize == 0),
-                  "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
+    static_assert(((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
+                   (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
+                  "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");

-    static_assert((KThreadSliceSize % BetaSrcVectorSize == 0),
-                  "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
+    static_assert(((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) ||
+                   (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
+                  "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");

     using PassThrough = tensor_operation::element_wise::PassThrough;
...
@@ -162,38 +165,7 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
         return (in_grid_desc_m_k_padded);
     };

-    static auto MakeAffine1dDescriptor(const std::vector<index_t>& Lengths,
-                                       const std::vector<index_t>& Strides,
-                                       int blkGroupSize,
-                                       int numBlockTileIteration)
-    {
-        const auto tupleLengths = make_tuple_from_array(Lengths, Number<NumReduceDim>{});
-        const auto tupleStrides = make_tuple_from_array(Strides, Number<NumReduceDim>{});
-
-        auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
-
-        auto grid_desc_k = transform_tensor_descriptor(
-            desc,
-            make_tuple(make_merge_transform(tupleLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, NumReduceDim, 1>::type{}),
-            make_tuple(Sequence<0>{}));
-
-        const auto reduceTotalLength  = grid_desc_k.GetLength(Number<0>{});
-        const int reduceSizePerBlock  = K_BlockTileSize * numBlockTileIteration;
-        const auto Pad_K              = reduceSizePerBlock * blkGroupSize - reduceTotalLength;
-
-        auto grid_desc_k_padded = transform_tensor_descriptor(
-            grid_desc_k,
-            make_tuple(make_right_pad_transform(reduceTotalLength, Pad_K)),
-            make_tuple(Sequence<0>{}),
-            make_tuple(Sequence<0>{}));
-
-        return (grid_desc_k_padded);
-    };
-
     using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
-    using GridDesc_K   = decltype(MakeAffine1dDescriptor({1}, {1}, 1, 1));

     using GridwiseReduceLayernormGeneric =
         GridwiseLayernormWelfordVariance_mk_to_mk<XDataType,
...
@@ -203,7 +175,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                   AccDataType,
                                                   AccElementwiseOperation,
                                                   GridDesc_M_K,
-                                                  GridDesc_K,
                                                   BlockSize,
                                                   MThreadClusterSize,
                                                   KThreadClusterSize,
...
@@ -211,12 +182,13 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                   KThreadSliceSize,
                                                   XYSrcVectorDim,
                                                   XSrcVectorSize,
+                                                  GammaSrcVectorDim,
                                                   GammaSrcVectorSize,
+                                                  BetaSrcVectorDim,
                                                   BetaSrcVectorSize,
                                                   XYSrcVectorDim,
                                                   YDstVectorSize,
                                                   false>;

     using GridwiseReduceLayernormSweepOnce =
         GridwiseLayernormWelfordVariance_mk_to_mk<XDataType,
                                                   GammaDataType,
...
@@ -225,7 +197,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                   AccDataType,
                                                   AccElementwiseOperation,
                                                   GridDesc_M_K,
-                                                  GridDesc_K,
                                                   BlockSize,
                                                   MThreadClusterSize,
                                                   KThreadClusterSize,
...
@@ -233,7 +204,9 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                   KThreadSliceSize,
                                                   XYSrcVectorDim,
                                                   XSrcVectorSize,
+                                                  GammaSrcVectorDim,
                                                   GammaSrcVectorSize,
+                                                  BetaSrcVectorDim,
                                                   BetaSrcVectorSize,
                                                   XYSrcVectorDim,
                                                   YDstVectorSize,
...
@@ -258,13 +231,13 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
               p_gamma_(p_gamma),
               p_beta_(p_beta),
               p_y_(p_y),
-              gammaStrides_(gammaStrides),
-              betaStrides_(betaStrides),
               acc_elementwise_op_(acc_elementwise_op)
         {
             Lengths_      = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
             xStrides_     = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
             yStrides_     = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
+            gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
+            betaStrides_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);

             long_index_t invariant_total_length;
             long_index_t reduce_total_length;
...
@@ -277,13 +250,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
             gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                         M_BlockTileSize * blkGroupSize_;
-
-            reduceLengths_.resize(NumReduceDim);
-
-            for(int i = 0; i < NumReduceDim; ++i)
-            {
-                reduceLengths_[i] = lengths[reduceDims[i]];
-            }
         }

         AccDataType epsilon_;
...
@@ -295,7 +261,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
         std::vector<index_t> Lengths_;
         std::vector<index_t> xStrides_;
-        std::vector<index_t> reduceLengths_;
         std::vector<index_t> gammaStrides_;
         std::vector<index_t> betaStrides_;
         std::vector<index_t> yStrides_;
...
@@ -313,15 +278,11 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
         {
             const auto x_grid_desc_m_k = MakeSrc2dDescriptor(
                 arg.Lengths_, arg.xStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
-            const auto gamma_grid_desc_k = MakeAffine1dDescriptor(
-                arg.reduceLengths_, arg.gammaStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
-            const auto beta_grid_desc_k = MakeAffine1dDescriptor(
-                arg.reduceLengths_, arg.betaStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
+            const auto gamma_grid_desc_m_k = MakeSrc2dDescriptor(
+                arg.Lengths_, arg.gammaStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
+            const auto beta_grid_desc_m_k = MakeSrc2dDescriptor(
+                arg.Lengths_, arg.betaStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
             const auto y_grid_desc_m_k = MakeSrc2dDescriptor(
                 arg.Lengths_, arg.yStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);

             bool sweep_once =
...
@@ -334,8 +295,7 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                        YDataType,
                                        AccDataType,
                                        AccElementwiseOperation,
-                                       GridDesc_M_K,
-                                       GridDesc_K>
+                                       GridDesc_M_K>
                     : kernel_layernorm<GridwiseReduceLayernormGeneric,
                                        XDataType,
                                        GammaDataType,
...
@@ -343,8 +303,7 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                        YDataType,
                                        AccDataType,
                                        AccElementwiseOperation,
-                                       GridDesc_M_K,
-                                       GridDesc_K>;
+                                       GridDesc_M_K>;

             float avg_time = 0;

             avg_time += launch_and_time_kernel(stream_config,
...
@@ -353,8 +312,8 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                dim3(BlockSize),
                                                0,
                                                x_grid_desc_m_k,
-                                               gamma_grid_desc_k,
-                                               beta_grid_desc_k,
+                                               gamma_grid_desc_m_k,
+                                               beta_grid_desc_m_k,
                                                y_grid_desc_m_k,
                                                arg.numBlockTileIteration_,
                                                arg.epsilon_,
...
@@ -409,26 +368,41 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
             return false;
         }

-        if(p_arg_->gammaStrides_.size() != NumReduceDim ||
-           p_arg_->betaStrides_.size() != NumReduceDim)
-            return false;
-
-        auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
-            bool ret = true;
-
-            if(!isLastDimensionCoalesced)
-                ret = scalarPerVector == 1;
-            else
-                ret = KThreadSliceSize % scalarPerVector == 0;
-
-            return ret;
-        };
-
-        if(!IsScalarPerVectorValid(p_arg_->gammaStrides_.back() == 1, GammaSrcVectorSize))
-            return false;
-
-        if(!IsScalarPerVectorValid(p_arg_->betaStrides_.back() == 1, BetaSrcVectorSize))
-            return false;
+        // if fastest dim is not reduced
+        if constexpr(GammaSrcVectorDim == 0)
+        {
+            if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1)
+                return (false);
+
+            if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
+                return (false);
+        }
+        else // if fastest dim is reduced
+        {
+            if(p_arg_->gammaStrides_[Rank - 1] != 1)
+                return (false);
+
+            if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
+                return (false);
+        }
+
+        // if fastest dim is not reduced
+        if constexpr(BetaSrcVectorDim == 0)
+        {
+            if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
+                return (false);
+
+            if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
+                return (false);
+        }
+        else // if fastest dim is reduced
+        {
+            if(p_arg_->betaStrides_[Rank - 1] != 1)
+                return (false);
+
+            if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0)
+                return (false);
+        }

         return true;
     };
...
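The reworked static_asserts tie each scalar-per-vector value to the dimension it vectorizes over: a gamma or beta read vectorized along M must divide MThreadSliceSize, and one vectorized along K must divide KThreadSliceSize. The same predicate, pulled into a standalone helper purely for clarity (not part of the commit; plain int is used instead of ck::index_t to keep it self-contained):

    // Restates the divisibility rule used by the new static_asserts.
    template <int VectorDim,        // 0 = vectorize along M, 1 = vectorize along K
              int ScalarPerVector,
              int MThreadSliceSize,
              int KThreadSliceSize>
    constexpr bool IsAffineParamVectorConfigValid()
    {
        return (VectorDim == 0 && MThreadSliceSize % ScalarPerVector == 0) ||
               (VectorDim == 1 && KThreadSliceSize % ScalarPerVector == 0);
    }

    // Example: the layernorm_blockwise.cpp instance uses SliceM = 1, SliceK = 8,
    // GammaVecDim = 1 and GammaScalarPerVector = 8, which satisfies the K branch.
    static_assert(IsAffineParamVectorConfigValid<1, 8, 1, 8>(), "gamma vector config");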
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp

@@ -22,7 +22,6 @@ template <typename XDataType,
           typename AccDataType,
           typename AccElementwiseOperation,
           typename GridDesc_M_K,
-          typename GridDesc_K,
           index_t BlockSize,
           index_t MThreadClusterSize,
           index_t KThreadClusterSize,
...
@@ -30,7 +29,9 @@ template <typename XDataType,
           index_t KThreadSliceSize,
           index_t XSrcVectorDim,
           index_t XSrcVectorSize,
+          index_t GammaSrcVectorDim,
           index_t GammaSrcVectorSize,
+          index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
           index_t YDstVectorDim,
           index_t YDstVectorSize,
...
@@ -83,8 +84,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

     __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
-                               const GridDesc_K& gamma_grid_desc_k,
-                               const GridDesc_K& beta_grid_desc_k,
+                               const GridDesc_M_K& gamma_grid_desc_m_k,
+                               const GridDesc_M_K& beta_grid_desc_m_k,
                                const GridDesc_M_K& y_grid_desc_m_k,
                                index_t num_k_block_tile_iteration,
                                AccDataType epsilon,
...
@@ -111,11 +112,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             x_thread_buf;

-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> gamma_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true>& beta_thread_buf =
-            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
+            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>&
+            beta_thread_buf = gamma_thread_buf;

         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             y_thread_buf;
...
@@ -127,7 +131,7 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
             mean_square_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_value_buf =
-            mean_square_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_thread_buf =
+            mean_square_thread_buf;

         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
...
@@ -145,11 +149,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         const auto thread_k_cluster_id = thread_cluster_idx[I1];

         using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
-        using ThreadBufferLengths_K   = Sequence<KThreadSliceSize>;
         constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
             make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
-        constexpr auto thread_buffer_desc_k =
-            make_naive_tensor_descriptor_packed(make_tuple(Number<KThreadSliceSize>{}));

         auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                   AccDataType,
...
@@ -169,27 +170,34 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         auto threadwise_gamma_load =
             ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                              AccDataType,
-                                             GridDesc_K,
-                                             decltype(thread_buffer_desc_k),
-                                             ThreadBufferLengths_K,
-                                             Sequence<0>,
-                                             0,
+                                             GridDesc_M_K,
+                                             decltype(thread_buffer_desc_m_k),
+                                             ThreadBufferLengths_M_K,
+                                             ThreadBufferDimAccessOrder,
+                                             GammaSrcVectorDim,
                                              GammaSrcVectorSize,
                                              1,
-                                             true>(
-                gamma_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
+                                             true>(
+                gamma_grid_desc_m_k,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_k_cluster_id * KThreadSliceSize));

         auto threadwise_beta_load =
             ThreadwiseTensorSliceTransfer_v2<BetaDataType,
                                              AccDataType,
-                                             GridDesc_K,
-                                             decltype(thread_buffer_desc_k),
-                                             ThreadBufferLengths_K,
-                                             Sequence<0>,
-                                             0,
+                                             GridDesc_M_K,
+                                             decltype(thread_buffer_desc_m_k),
+                                             ThreadBufferLengths_M_K,
+                                             ThreadBufferDimAccessOrder,
+                                             BetaSrcVectorDim,
                                              BetaSrcVectorSize,
                                              1,
-                                             true>(
-                beta_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
+                                             true>(
+                beta_grid_desc_m_k,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_k_cluster_id * KThreadSliceSize));

         auto threadwise_y_store =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
...
@@ -212,9 +220,6 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         // Copy x from Cache
         // one pass: fwd, second pass: bwd
-        constexpr auto thread_copy_fwd_step_k = make_multi_index(SweepOnce ? 0 : K_BlockTileSize);
-        constexpr auto thread_copy_bwd_step_k = make_multi_index(SweepOnce ? 0 : -K_BlockTileSize);
         constexpr auto thread_copy_fwd_step_m_k =
             make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
         constexpr auto thread_copy_bwd_step_m_k =
...
@@ -224,10 +229,10 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
             p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
         const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize());
+            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
         const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_beta_global, beta_grid_desc_k.GetElementSpaceSize());
+            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());

         // E(x), E[x^2], var(x)
         int reduce_length = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];
...
@@ -271,17 +276,16 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
             mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;

             // var(x) = E[x^2] - E[x]^2
-            var_value_buf(I) =
-                mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
+            var_thread_buf(I) =
+                mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
         });

         // y = (x - E[x]) / sqrt(var[x] + epsilon)
         auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
-        auto thread_copy_tail_k   = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k;

         threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k);
-        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k);
+        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
+        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
         threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);

         reducedTiles = 0;
...
@@ -296,10 +300,10 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
                                      x_thread_buf);
             }

-            threadwise_gamma_load.Run(gamma_grid_desc_k,
+            threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                       gamma_global_val_buf,
-                                      thread_buffer_desc_k,
-                                      make_tuple(I0),
+                                      thread_buffer_desc_m_k,
+                                      make_tuple(I0, I0),
                                       gamma_thread_buf);

             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
...
@@ -307,23 +311,21 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
                     constexpr auto offset_m_k =
                         thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));

                     // normalize
                     y_thread_buf(Number<offset_m_k>{}) =
                         (x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
-                        sqrt(var_value_buf(iM) + epsilon);
+                        sqrt(var_thread_buf(iM) + epsilon);

                     // gamma
                     y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{});
+                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
                 });
             });

-            threadwise_beta_load.Run(beta_grid_desc_k,
+            threadwise_beta_load.Run(beta_grid_desc_m_k,
                                      beta_global_val_buf,
-                                     thread_buffer_desc_k,
-                                     make_tuple(I0),
+                                     thread_buffer_desc_m_k,
+                                     make_tuple(I0, I0),
                                      beta_thread_buf);

             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
...
@@ -331,11 +333,9 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
                     constexpr auto offset_m_k =
                         thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));

                     // beta
                     y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_k>{});
+                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
                 });
             });
...
@@ -346,8 +346,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
                                    y_global_val_buf);

             threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
-            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
+            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
+            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
             threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);

             ++reducedTiles;
...
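After this change the gamma and beta thread buffers are sized and indexed with the same (m, k) offsets as x and y, so the per-element math reads all four tensors at the same coordinate; whether a gamma row actually repeats over m is decided by the descriptor's strides. A scalar host-side reference of the naive two-pass math in this kernel, written only as a sketch (names and the flat [M * N] layout are assumptions, not CK code):

    #include <cmath>
    #include <vector>

    // y[m][n] = (x[m][n] - mean[m]) / sqrt(var[m] + eps) * gamma[m][n] + beta[m][n]
    void layernorm_reference(const std::vector<float>& x,     // [M * N]
                             const std::vector<float>& gamma, // [M * N] view (possibly broadcast)
                             const std::vector<float>& beta,  // [M * N] view (possibly broadcast)
                             std::vector<float>& y,           // [M * N]
                             int M, int N, float epsilon)
    {
        for(int m = 0; m < M; ++m)
        {
            float mean = 0.f, mean_sq = 0.f;
            for(int n = 0; n < N; ++n)
            {
                mean    += x[m * N + n];
                mean_sq += x[m * N + n] * x[m * N + n];
            }
            mean    /= N;
            mean_sq /= N;
            const float var = mean_sq - mean * mean; // var(x) = E[x^2] - E[x]^2

            for(int n = 0; n < N; ++n)
            {
                const float normalized = (x[m * N + n] - mean) / std::sqrt(var + epsilon);
                y[m * N + n] = normalized * gamma[m * N + n] + beta[m * N + n];
            }
        }
    }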
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp

@@ -19,7 +19,6 @@ template <typename XDataType,
           typename AccDataType,
           typename AccElementwiseOperation,
           typename GridDesc_M_K,
-          typename GridDesc_K,
           index_t BlockSize,
           index_t MThreadClusterSize,
           index_t KThreadClusterSize,
...
@@ -27,7 +26,9 @@ template <typename XDataType,
           index_t KThreadSliceSize,
           index_t XSrcVectorDim,
           index_t XSrcVectorSize,
+          index_t GammaSrcVectorDim,
           index_t GammaSrcVectorSize,
+          index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
           index_t YDstVectorDim,
           index_t YDstVectorSize,
...
@@ -94,8 +95,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
     }

     __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
-                               const GridDesc_K& gamma_grid_desc_k,
-                               const GridDesc_K& beta_grid_desc_k,
+                               const GridDesc_M_K& gamma_grid_desc_m_k,
+                               const GridDesc_M_K& beta_grid_desc_m_k,
                                const GridDesc_M_K& y_grid_desc_m_k,
                                index_t num_k_block_tile_iteration,
                                AccDataType epsilon,
...
@@ -116,11 +117,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             x_thread_buf;

-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> gamma_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true>& beta_thread_buf =
-            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
+            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>&
+            beta_thread_buf = gamma_thread_buf;

         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             y_thread_buf;
...
@@ -137,11 +141,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         const auto thread_k_cluster_id = thread_cluster_idx[I1];

         using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
-        using ThreadBufferLengths_K   = Sequence<KThreadSliceSize>;
         constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
             make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
-        constexpr auto thread_buffer_desc_k =
-            make_naive_tensor_descriptor_packed(make_tuple(Number<KThreadSliceSize>{}));

         auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                   AccDataType,
...
@@ -161,27 +162,34 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         auto threadwise_gamma_load =
             ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                              AccDataType,
-                                             GridDesc_K,
-                                             decltype(thread_buffer_desc_k),
-                                             ThreadBufferLengths_K,
-                                             Sequence<0>,
-                                             0,
+                                             GridDesc_M_K,
+                                             decltype(thread_buffer_desc_m_k),
+                                             ThreadBufferLengths_M_K,
+                                             ThreadBufferDimAccessOrder,
+                                             GammaSrcVectorDim,
                                              GammaSrcVectorSize,
                                              1,
-                                             true>(
-                gamma_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
+                                             true>(
+                gamma_grid_desc_m_k,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_k_cluster_id * KThreadSliceSize));

         auto threadwise_beta_load =
             ThreadwiseTensorSliceTransfer_v2<BetaDataType,
                                              AccDataType,
-                                             GridDesc_K,
-                                             decltype(thread_buffer_desc_k),
-                                             ThreadBufferLengths_K,
-                                             Sequence<0>,
-                                             0,
+                                             GridDesc_M_K,
+                                             decltype(thread_buffer_desc_m_k),
+                                             ThreadBufferLengths_M_K,
+                                             ThreadBufferDimAccessOrder,
+                                             BetaSrcVectorDim,
                                              BetaSrcVectorSize,
                                              1,
-                                             true>(
-                beta_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
+                                             true>(
+                beta_grid_desc_m_k,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_k_cluster_id * KThreadSliceSize));

         auto threadwise_y_store =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
...
@@ -204,9 +212,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         // Copy x from Cache
         // one pass: fwd, second pass: bwd
-        constexpr auto thread_copy_fwd_step_k = make_multi_index(SweepOnce ? 0 : K_BlockTileSize);
-        constexpr auto thread_copy_bwd_step_k = make_multi_index(SweepOnce ? 0 : -K_BlockTileSize);
         constexpr auto thread_copy_fwd_step_m_k =
             make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
         constexpr auto thread_copy_bwd_step_m_k =
...
@@ -216,10 +221,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
             p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
         const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize());
+            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
         const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_beta_global, beta_grid_desc_k.GetElementSpaceSize());
+            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());

         auto threadwise_welford       = ThreadwiseWelford();
         threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);
...
@@ -250,11 +255,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         });

         auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
-        auto thread_copy_tail_k   = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k;

         threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k);
-        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k);
+        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
+        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
         threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);

         for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
...
@@ -268,10 +272,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
                                      x_thread_buf);
             }

-            threadwise_gamma_load.Run(gamma_grid_desc_k,
+            threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                       gamma_global_val_buf,
-                                      thread_buffer_desc_k,
-                                      make_tuple(I0),
+                                      thread_buffer_desc_m_k,
+                                      make_tuple(I0, I0),
                                       gamma_thread_buf);

             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
...
@@ -279,8 +283,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
                     constexpr auto offset_m_k =
                         thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));

                     // normalize
                     y_thread_buf(Number<offset_m_k>{}) =
                         (x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
...
@@ -288,14 +290,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
                     // gamma
                     y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{});
+                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
                 });
             });

-            threadwise_beta_load.Run(beta_grid_desc_k,
+            threadwise_beta_load.Run(beta_grid_desc_m_k,
                                      beta_global_val_buf,
-                                     thread_buffer_desc_k,
-                                     make_tuple(I0),
+                                     thread_buffer_desc_m_k,
+                                     make_tuple(I0, I0),
                                      beta_thread_buf);

             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
...
@@ -303,11 +305,9 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
                     constexpr auto offset_m_k =
                         thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));

                     // beta
                     y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_k>{});
+                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
                 });
             });
...
@@ -318,8 +318,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
                                    y_global_val_buf);

             threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
-            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
+            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
+            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
             threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
         }
     }
...
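The Welford variant changes gamma/beta handling in exactly the same way as the naive-variance kernel; the difference between the two files is only how the row statistics are accumulated. The naive kernel forms E[x] and E[x^2] and takes var = E[x^2] - E[x]^2, while this one updates mean and variance incrementally. For reference, a generic single-pass Welford recurrence of the kind the ThreadwiseWelford helper is built around (a textbook sketch, not the CK implementation):

    // Generic Welford recurrence: after Update() has seen `count` samples,
    // `mean` is the running mean and m2 / count is the (biased) variance.
    struct WelfordState
    {
        float mean  = 0.f;
        float m2    = 0.f;
        int   count = 0;

        void Update(float x)
        {
            ++count;
            const float delta = x - mean;
            mean += delta / count;
            m2   += delta * (x - mean); // uses the already-updated mean
        }

        float Variance() const { return count > 0 ? m2 / count : 0.f; }
    };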