gaoqiong / composable_kernel · Commits · 6ed9ab3a

Commit 6ed9ab3a, authored Jul 05, 2022 by rocking
parent 3bb0cbe7

    Use 1d descriptor for gamma and beta

Showing 3 changed files with 178 additions and 154 deletions:

  example/24_layernorm/layernorm_blockwise.cpp                    (+12, -13)
  include/ck/tensor_operation/gpu/device/device_layernorm.hpp    (+69, -51)
  include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp    (+97, -90)
example/24_layernorm/layernorm_blockwise.cpp

@@ -41,9 +41,7 @@ using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType,
     8,   // SliceK
     1,   // SrcVecDim (0=M, 1=K)
     8,   // SrcScalarPerVector
-    1,   // GammaVecDim (0=M, 1=K)
     8,   // GammaScalarPerVector
-    1,   // BetaVecDim (0=M, 1=K)
     8,   // BetaScalarPerVector
     1>;  // OutScalarPerVector

@@ -97,7 +95,7 @@ int main()
     ck::index_t M      = 1024;
     ck::index_t N      = 1024;
-    ck::index_t Stride = 1024;
+    ck::index_t Stride = N;

     auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
         return HostTensorDescriptor(std::vector<std::size_t>({len}),

@@ -128,16 +126,17 @@ int main()
     beta_dev.ToDevice(beta.mData.data());

     auto device_instance = DeviceInstance{};
-    auto argument_ptr    = device_instance.MakeArgumentPointer({M, N},
-                                                               {Stride, 1},
-                                                               {0, 1},
-                                                               {0, 1},
-                                                               {1},
-                                                               1e-4,
-                                                               x_dev.GetDeviceBuffer(),
-                                                               gamma_dev.GetDeviceBuffer(),
-                                                               beta_dev.GetDeviceBuffer(),
-                                                               y_dev.GetDeviceBuffer());
+    auto argument_ptr    = device_instance.MakeArgumentPointer(
+        {M, N},
+        std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
+        std::vector<ck::index_t>{gamma.mDesc.GetStrides().begin(), gamma.mDesc.GetStrides().end()},
+        std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(), beta.mDesc.GetStrides().end()},
+        {1},
+        1e-4,
+        x_dev.GetDeviceBuffer(),
+        gamma_dev.GetDeviceBuffer(),
+        beta_dev.GetDeviceBuffer(),
+        y_dev.GetDeviceBuffer());

     if(!device_instance.IsSupportedArgument(argument_ptr.get()))
     {
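For orientation, below is a minimal host-side reference of what this example computes once gamma and beta are treated as plain 1-D arrays of length N (matching the 1-D strides now passed to MakeArgumentPointer). This is a standalone sketch, not code from the repository; the sizes and values are made up for illustration.

// Standalone reference sketch: y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta,
// with gamma/beta stored as 1-D arrays of length N.
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int M     = 4;
    const int N     = 8;
    const float eps = 1e-4f;

    std::vector<float> x(M * N), gamma(N, 1.0f), beta(N, 0.0f), y(M * N);
    for(int i = 0; i < M * N; ++i)
        x[i] = static_cast<float>(i % 7);

    for(int m = 0; m < M; ++m)
    {
        // per-row mean and mean of squares over the reduced (N) dimension
        float mean = 0.f, mean_sq = 0.f;
        for(int n = 0; n < N; ++n)
        {
            mean    += x[m * N + n];
            mean_sq += x[m * N + n] * x[m * N + n];
        }
        mean /= N;
        mean_sq /= N;
        const float var = mean_sq - mean * mean;

        // normalize, then apply the 1-D affine parameters indexed by n only
        for(int n = 0; n < N; ++n)
            y[m * N + n] = (x[m * N + n] - mean) / std::sqrt(var + eps) * gamma[n] + beta[n];
    }

    std::printf("y[0] = %f\n", y[0]);
    return 0;
}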
include/ck/tensor_operation/gpu/device/device_layernorm.hpp

@@ -34,21 +34,17 @@ template <typename XDataType,
           index_t KThreadSliceSize,
           index_t XSrcVectorDim,
           index_t XSrcVectorSize,
-          index_t GammaSrcVectorDim,
           index_t GammaSrcVectorSize,
-          index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
           index_t YDstVectorSize>
 struct DeviceLayernorm : public BaseOperator
 {
     static_assert(
-        ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
-         (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
+        (KThreadSliceSize % GammaSrcVectorSize == 0),
         "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");

     static_assert(
-        ((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) ||
-         (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
+        (KThreadSliceSize % BetaSrcVectorSize == 0),
         "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");

     using PassThrough = tensor_operation::element_wise::PassThrough;

@@ -75,7 +71,38 @@ struct DeviceLayernorm : public BaseOperator
                                      XSrcVectorSize,
                                      1>; // YDstVectorSize

+    static auto MakeAffine1dDescriptor(const std::vector<index_t>& Lengths,
+                                       const std::vector<index_t>& Strides,
+                                       int blkGroupSize,
+                                       int numBlockTileIteration)
+    {
+        const auto tupleLengths = make_tuple_from_array(Lengths, Number<NumReduceDim>{});
+        const auto tupleStrides = make_tuple_from_array(Strides, Number<NumReduceDim>{});
+
+        auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
+
+        auto grid_desc_k = transform_tensor_descriptor(
+            desc,
+            make_tuple(make_merge_transform(tupleLengths)),
+            make_tuple(typename arithmetic_sequence_gen<0, NumReduceDim, 1>::type{}),
+            make_tuple(Sequence<0>{}));
+
+        const auto reduceTotalLength = grid_desc_k.GetLength(Number<0>{});
+        const int reduceSizePerBlock = Reduction::K_BlockTileSize * numBlockTileIteration;
+
+        const auto Pad_K = reduceSizePerBlock * blkGroupSize - reduceTotalLength;
+
+        auto grid_desc_k_padded = transform_tensor_descriptor(
+            grid_desc_k,
+            make_tuple(make_right_pad_transform(reduceTotalLength, Pad_K)),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0>{}));
+
+        return (grid_desc_k_padded);
+    };
+
     using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
+    using GridDesc_K   = decltype(MakeAffine1dDescriptor({1}, {1}, 1, 1));

     using GridwiseReduceLayernormGeneric = GridwiseLayernorm_mk_to_mk<XDataType,
                                                                       GammaDataType,

@@ -83,6 +110,7 @@ struct DeviceLayernorm : public BaseOperator
                                                                       YDataType,
                                                                       AccDataType,
                                                                       GridDesc_M_K,
+                                                                      GridDesc_K,
                                                                       BlockSize,
                                                                       MThreadClusterSize,
                                                                       KThreadClusterSize,

@@ -90,9 +118,7 @@ struct DeviceLayernorm : public BaseOperator
                                                                       KThreadSliceSize,
                                                                       XSrcVectorDim,
                                                                       XSrcVectorSize,
-                                                                      GammaSrcVectorDim,
                                                                       GammaSrcVectorSize,
-                                                                      BetaSrcVectorDim,
                                                                       BetaSrcVectorSize,
                                                                       YDstVectorSize,
                                                                       false>;

@@ -103,6 +129,7 @@ struct DeviceLayernorm : public BaseOperator
                                                                       YDataType,
                                                                       AccDataType,
                                                                       GridDesc_M_K,
+                                                                      GridDesc_K,
                                                                       BlockSize,
                                                                       MThreadClusterSize,
                                                                       KThreadClusterSize,

@@ -110,9 +137,7 @@ struct DeviceLayernorm : public BaseOperator
                                                                       KThreadSliceSize,
                                                                       XSrcVectorDim,
                                                                       XSrcVectorSize,
-                                                                      GammaSrcVectorDim,
                                                                       GammaSrcVectorSize,
-                                                                      BetaSrcVectorDim,
                                                                       BetaSrcVectorSize,
                                                                       YDstVectorSize,
                                                                       true>;

@@ -144,16 +169,22 @@ struct DeviceLayernorm : public BaseOperator
                                PassThrough{}),
              epsilon_(epsilon),
              p_gamma_(p_gamma),
-             p_beta_(p_beta)
+             p_beta_(p_beta),
+             gammaStrides_(gammaStrides),
+             betaStrides_(betaStrides)
         {
-            gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
-            betaStrides_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
+            reduceLength_.resize(NumReduceDim);
+            for(int i = 0; i < NumReduceDim; ++i)
+            {
+                reduceLength_[i] = lengths[reduceDims[i]];
+            }
         }

         AccDataType epsilon_;
         const GammaDataType* p_gamma_;
         const BetaDataType* p_beta_;
+        std::vector<index_t> reduceLength_;
         std::vector<index_t> gammaStrides_;
         std::vector<index_t> betaStrides_;
     };

@@ -164,10 +195,10 @@ struct DeviceLayernorm : public BaseOperator
     {
         const auto x_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
             arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
-        const auto gamma_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
-            arg.inLengths_, arg.gammaStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
-        const auto beta_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
-            arg.inLengths_, arg.betaStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
+        const auto gamma_grid_desc_k = MakeAffine1dDescriptor(
+            arg.reduceLength_, arg.gammaStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
+        const auto beta_grid_desc_k = MakeAffine1dDescriptor(
+            arg.reduceLength_, arg.betaStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
         const auto y_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
             arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);

@@ -180,14 +211,16 @@ struct DeviceLayernorm : public BaseOperator
                                            BetaDataType,
                                            YDataType,
                                            AccDataType,
-                                           GridDesc_M_K>
+                                           GridDesc_M_K,
+                                           GridDesc_K>
                 : kernel_layernorm<GridwiseReduceLayernormGeneric,
                                    XDataType,
                                    GammaDataType,
                                    BetaDataType,
                                    YDataType,
                                    AccDataType,
-                                   GridDesc_M_K>;
+                                   GridDesc_M_K,
+                                   GridDesc_K>;

         float avg_time = 0;

         avg_time += launch_and_time_kernel(stream_config,

@@ -196,8 +229,8 @@ struct DeviceLayernorm : public BaseOperator
                                            dim3(BlockSize),
                                            0,
                                            x_grid_desc_m_k,
-                                           gamma_grid_desc_m_k,
-                                           beta_grid_desc_m_k,
+                                           gamma_grid_desc_k,
+                                           beta_grid_desc_k,
                                            y_grid_desc_m_k,
                                            arg.numBlockTileIteration,
                                            arg.epsilon_,

@@ -230,41 +263,26 @@ struct DeviceLayernorm : public BaseOperator
             return false;
         }

-        // if fastest dim is not reduced
-        if constexpr(GammaSrcVectorDim == 0)
-        {
-            if(p_arg_->gammaStrides_[Reduction::NumInvariantDim - 1] != 1)
-                return (false);
-
-            if(p_arg_->invariant_lowest_length % GammaSrcVectorSize != 0)
-                return (false);
-        }
-        else // if fastest dim is reduced
-        {
-            if(p_arg_->gammaStrides_[Rank - 1] != 1)
-                return (false);
-
-            if(p_arg_->reduce_lowest_length % GammaSrcVectorSize != 0)
-                return (false);
-        }
-
-        // if fastest dim is not reduced
-        if constexpr(BetaSrcVectorDim == 0)
-        {
-            if(p_arg_->betaStrides_[Reduction::NumInvariantDim - 1] != 1)
-                return (false);
-
-            if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
-                return (false);
-        }
-        else // if fastest dim is reduced
-        {
-            if(p_arg_->betaStrides_[Rank - 1] != 1)
-                return (false);
-
-            if(p_arg_->reduce_lowest_length % BetaSrcVectorSize != 0)
-                return (false);
-        }
+        if(p_arg_->gammaStrides_.size() != NumReduceDim ||
+           p_arg_->betaStrides_.size() != NumReduceDim)
+            return false;
+
+        auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
+            bool ret = true;
+
+            if(!isLastDimensionCoalesced)
+                ret = scalarPerVector == 1;
+            else
+                ret = KThreadSliceSize % scalarPerVector == 0;
+
+            return ret;
+        };
+
+        if(!IsScalarPerVectorValid(p_arg_->gammaStrides_.back() == 1, GammaSrcVectorSize))
+            return false;
+
+        if(!IsScalarPerVectorValid(p_arg_->betaStrides_.back() == 1, BetaSrcVectorSize))
+            return false;

         return true;
     };
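The new MakeAffine1dDescriptor merges the reduce dimensions into a single K axis and right-pads it so every block covers a whole number of K tiles. The snippet below is a standalone sketch of just that padding arithmetic with made-up tile sizes; the variable names mirror the diff, but nothing here is CK code.

// Standalone sketch of the Pad_K arithmetic (illustrative values only).
#include <cassert>
#include <cstdio>

int main()
{
    const int K_BlockTileSize       = 32 * 8; // e.g. KThreadClusterSize * KThreadSliceSize
    const int blkGroupSize          = 1;
    const int numBlockTileIteration = 4;
    const int reduceTotalLength     = 1000;   // merged length of all reduce dims

    const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
    const int Pad_K              = reduceSizePerBlock * blkGroupSize - reduceTotalLength;

    // After right-padding, the 1-D gamma/beta descriptor spans a multiple of the K block tile.
    assert((reduceTotalLength + Pad_K) % K_BlockTileSize == 0);
    std::printf("padded K length = %d (Pad_K = %d)\n", reduceTotalLength + Pad_K, Pad_K);
    return 0;
}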
include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp

@@ -20,10 +20,11 @@ template <typename GridwiseReduction,
           typename BetaDataType,
           typename YDataType,
           typename AccDataType,
-          typename GridDesc_M_K>
+          typename GridDesc_M_K,
+          typename GridDesc_K>
 __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
-                                 const GridDesc_M_K gamma_grid_desc_m_k,
-                                 const GridDesc_M_K beta_grid_desc_m_k,
+                                 const GridDesc_K gamma_grid_desc_k,
+                                 const GridDesc_K beta_grid_desc_k,
                                  const GridDesc_M_K y_grid_desc_m_k,
                                  index_t num_k_block_tile_iteration,
                                  AccDataType epsilon,

@@ -33,8 +34,8 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
                                  YDataType* const __restrict__ p_y_global)
 {
     GridwiseReduction::Run(x_grid_desc_m_k,
-                           gamma_grid_desc_m_k,
-                           beta_grid_desc_m_k,
+                           gamma_grid_desc_k,
+                           beta_grid_desc_k,
                            y_grid_desc_m_k,
                            num_k_block_tile_iteration,
                            epsilon,

@@ -50,6 +51,7 @@ template <typename XDataType,
           typename YDataType,
           typename AccDataType,
           typename GridDesc_M_K,
+          typename GridDesc_K,
           index_t BlockSize,
           index_t MThreadClusterSize,
           index_t KThreadClusterSize,

@@ -57,9 +59,7 @@ template <typename XDataType,
           index_t KThreadSliceSize,
           index_t XSrcVectorDim,
           index_t XSrcVectorSize,
-          index_t GammaSrcVectorDim,
           index_t GammaSrcVectorSize,
-          index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
           index_t YDstVectorSize,
           bool SweepOnce>

@@ -114,8 +114,8 @@ struct GridwiseLayernorm_mk_to_mk
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

     __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
-                               const GridDesc_M_K& gamma_grid_desc_m_k,
-                               const GridDesc_M_K& beta_grid_desc_m_k,
+                               const GridDesc_K& gamma_grid_desc_k,
+                               const GridDesc_K& beta_grid_desc_k,
                                const GridDesc_M_K& y_grid_desc_m_k,
                                index_t num_k_block_tile_iteration,
                                AccDataType epsilon,

@@ -141,11 +141,9 @@ struct GridwiseLayernorm_mk_to_mk
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             x_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
-            gamma_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
-            beta_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> beta_thread_buf;
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             y_thread_buf;

@@ -175,15 +173,18 @@ struct GridwiseLayernorm_mk_to_mk
         const auto thread_m_cluster_id = thread_cluster_idx[I0];
         const auto thread_k_cluster_id = thread_cluster_idx[I1];

-        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
-        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
+        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
+        using ThreadBufferLengths_K   = Sequence<KThreadSliceSize>;
+
+        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
             make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
+        constexpr auto thread_buffer_desc_k =
+            make_naive_tensor_descriptor_packed(make_tuple(Number<KThreadSliceSize>{}));

         auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                   AccDataType,
                                                                   GridDesc_M_K,
-                                                                  decltype(thread_buffer_desc),
-                                                                  ThreadBufferLengths,
+                                                                  decltype(thread_buffer_desc_m_k),
+                                                                  ThreadBufferLengths_M_K,
                                                                   ThreadBufferDimAccessOrder,
                                                                   XSrcVectorDim,
                                                                   XSrcVectorSize,

@@ -194,67 +195,68 @@ struct GridwiseLayernorm_mk_to_mk
                                                  thread_m_cluster_id * MThreadSliceSize,
                                                  thread_k_cluster_id * KThreadSliceSize));

-        auto threadwise_gamma_load = ThreadwiseTensorSliceTransfer_v2<GammaDataType,
-                                                                      AccDataType,
-                                                                      GridDesc_M_K,
-                                                                      decltype(thread_buffer_desc),
-                                                                      ThreadBufferLengths,
-                                                                      ThreadBufferDimAccessOrder,
-                                                                      GammaSrcVectorDim,
-                                                                      GammaSrcVectorSize,
-                                                                      1,
-                                                                      true>(
-            gamma_grid_desc_m_k,
-            make_multi_index(block_global_id * M_BlockTileSize +
-                                 thread_m_cluster_id * MThreadSliceSize,
-                             thread_k_cluster_id * KThreadSliceSize));
+        auto threadwise_gamma_load =
+            ThreadwiseTensorSliceTransfer_v2<GammaDataType,
+                                             AccDataType,
+                                             GridDesc_K,
+                                             decltype(thread_buffer_desc_k),
+                                             ThreadBufferLengths_K,
+                                             Sequence<0>,
+                                             0,
+                                             GammaSrcVectorSize,
+                                             1,
+                                             true>(
+                gamma_grid_desc_k,
+                make_multi_index(thread_k_cluster_id * KThreadSliceSize));

-        auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
-                                                                     AccDataType,
-                                                                     GridDesc_M_K,
-                                                                     decltype(thread_buffer_desc),
-                                                                     ThreadBufferLengths,
-                                                                     ThreadBufferDimAccessOrder,
-                                                                     BetaSrcVectorDim,
-                                                                     BetaSrcVectorSize,
-                                                                     1,
-                                                                     true>(
-            beta_grid_desc_m_k,
-            make_multi_index(block_global_id * M_BlockTileSize +
-                                 thread_m_cluster_id * MThreadSliceSize,
-                             thread_k_cluster_id * KThreadSliceSize));
+        auto threadwise_beta_load =
+            ThreadwiseTensorSliceTransfer_v2<BetaDataType,
+                                             AccDataType,
+                                             GridDesc_K,
+                                             decltype(thread_buffer_desc_k),
+                                             ThreadBufferLengths_K,
+                                             Sequence<0>,
+                                             0,
+                                             BetaSrcVectorSize,
+                                             1,
+                                             true>(
+                beta_grid_desc_k,
+                make_multi_index(thread_k_cluster_id * KThreadSliceSize));

         auto threadwise_y_store =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                YDataType,
-                                               decltype(thread_buffer_desc),
+                                               decltype(thread_buffer_desc_m_k),
                                                GridDesc_M_K,
                                                PassThroughOp,
-                                               ThreadBufferLengths,
+                                               ThreadBufferLengths_M_K,
                                                ThreadBufferDimAccessOrder,
                                                XSrcVectorDim,
                                                YDstVectorSize,
                                                InMemoryDataOperationEnum::Set,
                                                1,
                                                true>(
                 y_grid_desc_m_k,
                 make_multi_index(block_global_id * M_BlockTileSize +
                                      thread_m_cluster_id * MThreadSliceSize,
                                  thread_k_cluster_id * KThreadSliceSize),
                 PassThroughOp{});

         // Copy x from Cache
         // one pass: fwd, second pass: bwd
-        constexpr auto thread_copy_fwd_step = make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
-        constexpr auto thread_copy_bwd_step = make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
+        constexpr auto thread_copy_fwd_step_k   = make_multi_index(SweepOnce ? 0 : K_BlockTileSize);
+        constexpr auto thread_copy_bwd_step_k   = make_multi_index(SweepOnce ? 0 : -K_BlockTileSize);
+        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
+        constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);

         const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
         const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
+            p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize());
         const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
+            p_beta_global, beta_grid_desc_k.GetElementSpaceSize());

         // E(x), E[x^2], var(x)
         int reduce_length = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];

@@ -264,22 +266,23 @@ struct GridwiseLayernorm_mk_to_mk
         {
             threadwise_x_load.Run(x_grid_desc_m_k,
                                   x_global_val_buf,
-                                  thread_buffer_desc,
+                                  thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   x_thread_buf);

             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                 static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                    constexpr auto offset_m_k =
+                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

-                    x_square_thread_buf(Number<offset>{}) =
-                        x_thread_buf(Number<offset>{}) * x_thread_buf(Number<offset>{});
+                    x_square_thread_buf(Number<offset_m_k>{}) =
+                        x_thread_buf(Number<offset_m_k>{}) * x_thread_buf(Number<offset_m_k>{});
                 });
             });

             ThreadwiseSumReduce::Reduce(x_thread_buf, mean_thread_buf);
             ThreadwiseSumReduce::Reduce(x_square_thread_buf, mean_square_thread_buf);

-            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step);
+            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);

             ++reducedTiles;
         } while(reducedTiles < num_k_block_tile_iteration);

@@ -297,12 +300,13 @@ struct GridwiseLayernorm_mk_to_mk
         });

         // y = (x - E[x]) / sqrt(var[x] + epsilon)
-        auto thread_copy_tail = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step;
+        auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
+        auto thread_copy_tail_k   = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k;

-        threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step);
-        threadwise_gamma_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail);
-        threadwise_beta_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail);
-        threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail);
+        threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
+        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k);
+        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k);
+        threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);

         reducedTiles = 0;
         do

@@ -311,48 +315,51 @@ struct GridwiseLayernorm_mk_to_mk
             {
                 threadwise_x_load.Run(x_grid_desc_m_k,
                                       x_global_val_buf,
-                                      thread_buffer_desc,
+                                      thread_buffer_desc_m_k,
                                       make_tuple(I0, I0),
                                       x_thread_buf);
             }

-            threadwise_gamma_load.Run(gamma_grid_desc_m_k,
+            threadwise_gamma_load.Run(gamma_grid_desc_k,
                                       gamma_global_val_buf,
-                                      thread_buffer_desc,
-                                      make_tuple(I0, I0),
+                                      thread_buffer_desc_k,
+                                      make_tuple(I0),
                                       gamma_thread_buf);

-            threadwise_beta_load.Run(beta_grid_desc_m_k,
+            threadwise_beta_load.Run(beta_grid_desc_k,
                                      beta_global_val_buf,
-                                     thread_buffer_desc,
-                                     make_tuple(I0, I0),
+                                     thread_buffer_desc_k,
+                                     make_tuple(I0),
                                      beta_thread_buf);

             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                 static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                    constexpr auto offset_m_k =
+                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
+                    constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));

                     // normalize
-                    y_thread_buf(Number<offset>{}) =
-                        (x_thread_buf(Number<offset>{}) - mean_thread_buf(iM)) /
+                    y_thread_buf(Number<offset_m_k>{}) =
+                        (x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
                         sqrt(var_value_buf(iM) + epsilon);

                     // affine
-                    y_thread_buf(Number<offset>{}) =
-                        y_thread_buf(Number<offset>{}) * gamma_thread_buf(Number<offset>{}) +
-                        beta_thread_buf(Number<offset>{});
+                    y_thread_buf(Number<offset_m_k>{}) =
+                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{}) +
+                        beta_thread_buf(Number<offset_k>{});
                 });
             });

-            threadwise_y_store.Run(thread_buffer_desc,
+            threadwise_y_store.Run(thread_buffer_desc_m_k,
                                    make_tuple(I0, I0),
                                    y_thread_buf,
                                    y_grid_desc_m_k,
                                    y_global_val_buf);

-            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step);
-            threadwise_gamma_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step);
-            threadwise_beta_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step);
-            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step);
+            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
+            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
+            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
+            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);

             ++reducedTiles;
         } while(reducedTiles < num_k_block_tile_iteration);
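To illustrate the indexing change in the affine step: the (M, K) thread buffer keeps its packed 2-D offset, while gamma and beta are now addressed by the K offset alone. The sketch below is standalone host code with hypothetical slice sizes, not CK code; it only mirrors the offset_m_k vs. offset_k pattern from the kernel loop.

// Standalone sketch of the per-thread affine step with 1-D gamma/beta buffers.
#include <array>
#include <cstdio>

int main()
{
    constexpr int MThreadSliceSize = 2;
    constexpr int KThreadSliceSize = 4;

    std::array<float, MThreadSliceSize * KThreadSliceSize> y_thread_buf{}; // normalized values would live here
    std::array<float, KThreadSliceSize> gamma_thread_buf{1.f, 2.f, 3.f, 4.f};
    std::array<float, KThreadSliceSize> beta_thread_buf{0.5f, 0.5f, 0.5f, 0.5f};

    for(int iM = 0; iM < MThreadSliceSize; ++iM)
        for(int iK = 0; iK < KThreadSliceSize; ++iK)
        {
            const int offset_m_k = iM * KThreadSliceSize + iK; // packed (M, K) offset
            const int offset_k   = iK;                         // packed (K) offset, shared across M
            y_thread_buf[offset_m_k] =
                y_thread_buf[offset_m_k] * gamma_thread_buf[offset_k] + beta_thread_buf[offset_k];
        }

    std::printf("y_thread_buf[0] = %f\n", y_thread_buf[0]);
    return 0;
}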