gaoqiong / composable_kernel_ROCM · Commits · d0f355a3
"example/36_permute/run_permute_example.inc" did not exist on "665b73ff1dd5c2b5fbdf2ddd9c19da2a583fdd2d"
Commit d0f355a3, authored Dec 19, 2023 by Jun Liu

    Merge branch 'develop' into amd-develop

Parents: 55a89c74, b305a29e
Changes: 81 files in the merge. Showing 20 changed files with 1409 additions and 58 deletions (+1409 −58).
cmake/ClangTidy.cmake  (+1 −1)
docs/sphinx/requirements.in  (+1 −1)
docs/sphinx/requirements.txt  (+1 −1)
example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp  (+6 −5)
example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp  (+3 −1)
example/53_layernorm2d_bwd/CMakeLists.txt  (+1 −0)
example/53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp  (+80 −17)
example/53_layernorm_bwd/CMakeLists.txt  (+0 −1)
example/54_groupnorm_bwd/CMakeLists.txt  (+1 −1)
example/54_groupnorm_bwd/groupnorm_bwd_fp32.cpp  (+87 −14)
include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp  (+59 −0)
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp  (+465 −0)
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp  (+23 −9)
include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp  (+2 −4)
include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp  (+2 −2)
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp  (+554 −0)
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp  (+10 −1)
library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp  (+25 −0)
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp  (+24 −0)
library/include/ck/library/tensor_operation_instance/gpu/groupnorm_bwd_data.hpp  (+64 −0)
cmake/ClangTidy.cmake  (view file @ d0f355a3)

@@ -149,7 +149,7 @@ function(clang_tidy_check TARGET)
         add_custom_target(${tidy_target}
             # for some targets clang-tidy not able to get information from .clang-tidy
             DEPENDS ${SOURCE}
-            COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\;__HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml"
+            COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\;__HIP_PLATFORM_AMD__\;__HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml"
             WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
             COMMENT "clang-tidy: Running clang-tidy on target ${SOURCE}..."
         )
docs/sphinx/requirements.in  (view file @ d0f355a3)

-rocm-docs-core==0.30.1
+rocm-docs-core==0.30.2
 sphinxcontrib-bibtex==2.6.1
docs/sphinx/requirements.txt  (view file @ d0f355a3)

@@ -113,7 +113,7 @@ requests==2.31.0
     # via
     #   pygithub
     #   sphinx
-rocm-docs-core==0.30.1
+rocm-docs-core==0.30.2
     # via -r requirements.in
 six==1.16.0
     # via
example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp  (view file @ d0f355a3)

 #include <iostream>
 #include <cstdlib>
+#include <random>
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
@@ -48,10 +49,8 @@ void host_elementwise4D(HostTensorB& B_nhwc,
         for(std::size_t n = 0; n < N; ++n)
         {
             ADataType tmp_val;
-            // auto a_val = A_nchw(n, c, h, w);
             auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
             functor_b(tmp_val, a_val);
-            // functor_a(B_nhwc(n, h, w, c), scale * tmp_val);
             functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)],
                       scale * tmp_val);
         }
@@ -62,12 +61,14 @@ int main()
     bool do_verification = true;
     bool time_kernel     = true;

-    std::vector<std::size_t> nchw = {4, 2, 1, 8};
-    std::vector<std::size_t> nhwc = {4, 1, 8, 2};
+    std::vector<std::size_t> nchw = {16, 8, 32, 64};
+    std::vector<std::size_t> nhwc = {16, 32, 64, 8};
     Tensor<ADataType> a(nchw);
     Tensor<BDataType> b(nhwc);
     float scale = 1.f;

     auto i = 0;
+    std::mt19937 gen(11939);
+    std::uniform_int_distribution<int> dis(0, 1);
     for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w)
         for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h)
             for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c)
@@ -75,7 +76,7 @@ int main()
             {
                 a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) +
                         (h * nchw[3]) + w] = i;
-                i++;
+                i = dis(gen);
             }

     DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
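The change above replaces the deterministic `i++` fill of the input tensor with small random values. A minimal standalone sketch of the same initialization pattern, with the generator seed (11939) and the 0/1 value range taken from the diff and the CK Tensor replaced by a plain std::vector for illustration:

    #include <iostream>
    #include <random>
    #include <vector>

    int main()
    {
        // Same generator, seed, and value range as the updated example.
        std::mt19937 gen(11939);
        std::uniform_int_distribution<int> dis(0, 1);

        // Stand-in for the 16 x 8 x 32 x 64 NCHW tensor used in the example.
        std::vector<int> a(16 * 8 * 32 * 64);
        for(auto& v : a)
            v = dis(gen); // values are 0 or 1, exactly representable in fp16 and fp32

        std::cout << "initialized " << a.size() << " elements\n";
        return 0;
    }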
example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp  (view file @ d0f355a3)

@@ -67,6 +67,8 @@ int main()
     float scale = 1.f;

     auto i = 0;
+    std::mt19937 gen(11939);
+    std::uniform_int_distribution<int> dis(0, 1);
     for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w)
         for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h)
             for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c)
@@ -74,7 +76,7 @@ int main()
             {
                 a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) +
                         (h * nchw[3]) + w] = i;
-                i++;
+                i = dis(gen);
             }

     DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
example/53_layernorm2d_bwd/CMakeLists.txt  (new file, 0 → 100644; view file @ d0f355a3)

add_example_executable(example_layernorm2d_bwd_fp32 layernorm2d_bwd_fp32.cpp)
example/53_layernorm_bwd/layernorm2d_bwd_fp16.cpp → example/53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp  (view file @ d0f355a3)

@@ -15,16 +15,17 @@
 #include "ck/library/utility/literals.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp"

-using DYDataType         = ck::half_t;
-using XDataType          = ck::half_t;
-using GammaDataType      = ck::half_t;
+using DYDataType         = float;
+using XDataType          = float;
+using GammaDataType      = float;
 using MeanInvStdDataType = float;
-using DGammaDataType     = ck::half_t;
-using DBetaDataType      = ck::half_t;
-using DXDataType         = ck::half_t;
+using DGammaDataType     = float;
+using DBetaDataType      = float;
+using DXDataType         = float;
 using ComputeDataType    = float;

 constexpr int Rank = 2;
@@ -39,6 +40,7 @@ constexpr int NumReduceDim = 1;
 // inv_std: [M, 1]
 // Output shape
+// dx: [M, N]
 // dgamma: [1, N]
 // dbeta: [1, N]
@@ -46,8 +48,34 @@ constexpr int NumReduceDim = 1;
 // dbeta = reduce_sum(dy, axis=0)
 // [CAUSION]
-// In DeviceNormalizationBwdGammaBetaImpl, M is invarient dimension, K is reduced dimension
-// Hence, M in this example and DeviceNormalizationBwdGammaBetaImpl is different
+// In DeviceNormalizationBwdDataImpl & DeviceNormalizationBwdGammaBetaImpl, M is Invariant
+// dimension, K is reduced dimension Hence, M in this example and
+// DeviceNormalizationBwdGammaBetaImpl is different
+
+using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl<
+    DYDataType,
+    XDataType,
+    GammaDataType,
+    MeanInvStdDataType,
+    ComputeDataType,
+    DXDataType,
+    Rank,
+    NumReduceDim,
+    256,   // BlockSize
+    8,     // MThreadClusterSize
+    32,    // KThreadClusterSize
+    1,     // MThreadSliceSize
+    4,     // KThreadSliceSize
+    true,  // IsDYFastestDimReduced
+    4,     // DYSrcVectorSize
+    true,  // IsXFastestDimReduced
+    4,     // XSrcVectorSize
+    true,  // IsGammaFastestDimReduced
+    4,     // GammaSrcVectorSize
+    false, // IsMeanInvStdFastestDimReduced
+    1,     // MeanInvStdSrcVectorSize
+    true,  // IsDXFastestDimReduced
+    4>;    // DXDstVectorSize

 using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl<
     DYDataType,
     XDataType,
@@ -58,18 +86,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
     Rank,
     NumReduceDim,
     256,   // BlockSize
-    8,     // ClusterInvarient
-    32,    // ClusterReduce
-    8,     // SliceInvarient
-    1,     // SliceReduce
+    8,     // MThreadClusterSize
+    32,    // KThreadClusterSize
+    4,     // MThreadSliceSize
+    1,     // KThreadSliceSize
     false, // IsDYFastestDimReduced
-    8,     // DYSrcVectorSize
+    4,     // DYSrcVectorSize
     false, // IsXFastestDimReduced
-    8,     // XSrcVectorSize
+    4,     // XSrcVectorSize
     true,  // IsMeanInvStdFastestDimReduced
     1,     // MeanInvStdSrcVectorSize
-    1,     // DGammaDstVectorSize
-    1>;    // DBetaDstVectorSize
+    4,     // DGammaDstVectorSize
+    4>;    // DBetaDstVectorSize

 int main()
 {
@@ -96,16 +124,48 @@ int main()
     DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
     DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
+    DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
     DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
     DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
+    DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize());
     DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
     DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());

     dy_dev.ToDevice(dy.mData.data());
     x_dev.ToDevice(x.mData.data());
+    gamma_dev.ToDevice(gamma.mData.data());
     mean_dev.ToDevice(mean.mData.data());
     inv_std_dev.ToDevice(inv_std.mData.data());

+    // backward x
+    auto x_device_instance = XDeviceInstance{};
+    auto x_argument_ptr    = x_device_instance.MakeArgumentPointer({M, N}, // lengths
+                                                                   {N, 1}, // dyStrides
+                                                                   {N, 1}, // xStrides
+                                                                   {0, 1}, // gammaStrides
+                                                                   {1, 0}, // meanStrides
+                                                                   {1, 0}, // invStdStrides
+                                                                   {N, 1}, // dxStrides
+                                                                   {1},    // reduceDims
+                                                                   dy_dev.GetDeviceBuffer(),
+                                                                   x_dev.GetDeviceBuffer(),
+                                                                   gamma_dev.GetDeviceBuffer(),
+                                                                   mean_dev.GetDeviceBuffer(),
+                                                                   inv_std_dev.GetDeviceBuffer(),
+                                                                   dx_dev.GetDeviceBuffer());
+
+    if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get()))
+    {
+        std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
+                  << std::endl;
+        return 1;
+    };
+
+    auto x_invoker_ptr = x_device_instance.MakeInvokerPointer();
+    x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+    // backward gamma & beta
     auto gamma_beta_device_instance = GammaBetaDeviceInstance{};
     auto gamma_beta_argument_ptr =
         gamma_beta_device_instance.MakeArgumentPointer({M, N}, // inLengths
@@ -126,7 +186,8 @@ int main()
     if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get()))
     {
-        std::cout << "The runtime parameters are not supported" << std::endl;
+        std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
+                  << std::endl;
         return 1;
     };
@@ -156,9 +217,11 @@ int main()
         dgamma_dev.FromDevice(dgamma.mData.data());
         dbeta_dev.FromDevice(dbeta.mData.data());
+        dx_dev.FromDevice(dx.mData.data());

         pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3);
         pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3);
+        pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3);
     }

     return (pass ? 0 : 1);
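For reference, the dgamma/dbeta path that this example continues to verify reduces over the M (row) dimension, while the new dx path reduces over N. A hedged host-side sketch of the usual layernorm dgamma/dbeta definitions follows; it is illustrative plain C++, not the reference_layernorm_bwd.hpp implementation, and the row-major [M, N] layout and function name are assumptions:

    #include <cstddef>
    #include <vector>

    // dy, x: [M, N] row-major; mean, inv_std: [M]; dgamma, dbeta: [N]
    void layernorm_bwd_gamma_beta_host(const std::vector<float>& dy,
                                       const std::vector<float>& x,
                                       const std::vector<float>& mean,
                                       const std::vector<float>& inv_std,
                                       std::vector<float>& dgamma,
                                       std::vector<float>& dbeta,
                                       std::size_t M,
                                       std::size_t N)
    {
        dgamma.assign(N, 0.f);
        dbeta.assign(N, 0.f);
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                const float x_hat = (x[m * N + n] - mean[m]) * inv_std[m];
                dgamma[n] += dy[m * N + n] * x_hat; // dgamma = reduce_sum(dy * x_hat, axis=0)
                dbeta[n]  += dy[m * N + n];         // dbeta  = reduce_sum(dy, axis=0)
            }
    }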
example/53_layernorm_bwd/CMakeLists.txt  (deleted, 100644 → 0; view file @ 55a89c74)

-add_example_executable(example_layernorm2d_bwd_fp16 layernorm2d_bwd_fp16.cpp)
example/54_groupnorm_bwd/CMakeLists.txt  (view file @ d0f355a3)

-add_example_executable(example_groupnorm_bwd_fp16 groupnorm_bwd_fp16.cpp)
+add_example_executable(example_groupnorm_bwd_fp32 groupnorm_bwd_fp32.cpp)
example/54_groupnorm_bwd/groupnorm_bwd_fp16.cpp → example/54_groupnorm_bwd/groupnorm_bwd_fp32.cpp  (view file @ d0f355a3)

@@ -15,23 +15,58 @@
 #include "ck/library/utility/literals.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp"

-using DYDataType         = ck::half_t;
-using XDataType          = ck::half_t;
-using GammaDataType      = ck::half_t;
+using DYDataType         = float;
+using XDataType          = float;
+using GammaDataType      = float;
 using MeanInvStdDataType = float;
-using DGammaDataType     = ck::half_t;
-using DBetaDataType      = ck::half_t;
-using DXDataType         = ck::half_t;
+using DGammaDataType     = float;
+using DBetaDataType      = float;
+using DXDataType         = float;
 using ComputeDataType    = float;

 constexpr int Rank         = 5;
 constexpr int NumReduceDim = 3;

 // Grouprnorm
-// kernel: M , K
+// kernel 1: M , K
+// dy:    N, H, W, G, C -> N * G, H * W * C
+// x:     N, H, W, G, C -> N * G, H * W * C
+// gamma: 1, 1, 1, G, C -> 1 * G, 1 * 1 * C
+// mean:  N, 1, 1, G, 1 -> N * G, 1 * 1 * 1
+// rstd:  N, 1, 1, G, 1 -> N * G, 1 * 1 * 1
+// dx:    N, H, W, G, C -> N * G, H * W * C
+
+using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl<
+    DYDataType,
+    XDataType,
+    GammaDataType,
+    MeanInvStdDataType,
+    ComputeDataType,
+    DXDataType,
+    Rank,
+    NumReduceDim,
+    256,   // BlockSize
+    8,     // MThreadClusterSize
+    32,    // KThreadClusterSize
+    1,     // MThreadSliceSize
+    4,     // KThreadSliceSize
+    true,  // IsDYFastestDimReduced
+    4,     // DYSrcVectorSize
+    true,  // IsXFastestDimReduced
+    4,     // XSrcVectorSize
+    true,  // IsGammaFastestDimReduced
+    4,     // GammaSrcVectorSize
+    false, // IsMeanInvStdFastestDimReduced
+    1,     // MeanInvStdSrcVectorSize
+    true,  // IsDXFastestDimReduced
+    4>;    // DXDstVectorSize
+
+// kernel 2: M , K
 // dy:   N, H, W, G, C -> G * C, N * H * W
 // x:    N, H, W, G, C -> G * C, N * H * W
 // mean: N, 1, 1, G, 1 -> G * 1, N * 1 * 1
@@ -52,18 +87,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
     Rank,
     NumReduceDim,
     256,   // BlockSize
-    8,     // ClusterInvarient
+    8,     // ClusterInvariant
     32,    // ClusterReduce
-    8,     // SliceInvarient
+    4,     // SliceInvariant
     1,     // SliceReduce
     false, // IsDYFastestDimReduced
-    8,     // DYSrcVectorSize
+    4,     // DYSrcVectorSize
     false, // IsXFastestDimReduced
-    8,     // XSrcVectorSize
+    4,     // XSrcVectorSize
     false, // IsMeanInvStdFastestDimReduced
     1,     // MeanInvStdSrcVectorSize
-    1,     // DGammaDstVectorSize
-    1>;    // DBetaDstVectorSize
+    4,     // DGammaDstVectorSize
+    4>;    // DBetaDstVectorSize

 int main()
 {
@@ -93,20 +128,55 @@ int main()
     DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
     DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
+    DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
     DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
     DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
+    DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize());
     DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
     DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());

     dy_dev.ToDevice(dy.mData.data());
     x_dev.ToDevice(x.mData.data());
+    gamma_dev.ToDevice(gamma.mData.data());
     mean_dev.ToDevice(mean.mData.data());
     inv_std_dev.ToDevice(inv_std.mData.data());

     std::vector<ck::index_t> dyStrides{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()};
     std::vector<ck::index_t> xStrides{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()};
+    std::vector<ck::index_t> gammaStrides  = {0, 0, 0, C, 1};
     std::vector<ck::index_t> meanStrides   = {G, 0, 0, 1, 0};
     std::vector<ck::index_t> invStdStrides = {G, 0, 0, 1, 0};
+    std::vector<ck::index_t> dxStrides{dx.mDesc.GetStrides().begin(), dx.mDesc.GetStrides().end()};
+
+    // backward x
+    auto x_device_instance = XDeviceInstance{};
+    auto x_argument_ptr    = x_device_instance.MakeArgumentPointer({N, H, W, G, C}, // lengths
+                                                                   dyStrides,       // dyStrides
+                                                                   xStrides,        // xStrides
+                                                                   gammaStrides,    // gammaStrides
+                                                                   meanStrides,     // meanStrides
+                                                                   invStdStrides,   // invStdStrides
+                                                                   dxStrides,       // dxStrides
+                                                                   {1, 2, 4},       // reduceDims
+                                                                   dy_dev.GetDeviceBuffer(),
+                                                                   x_dev.GetDeviceBuffer(),
+                                                                   gamma_dev.GetDeviceBuffer(),
+                                                                   mean_dev.GetDeviceBuffer(),
+                                                                   inv_std_dev.GetDeviceBuffer(),
+                                                                   dx_dev.GetDeviceBuffer());
+
+    if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get()))
+    {
+        std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
+                  << std::endl;
+        return 1;
+    };
+
+    auto x_invoker_ptr = x_device_instance.MakeInvokerPointer();
+    x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+    // backward gamma & beta
     auto gamma_beta_device_instance = GammaBetaDeviceInstance{};
     auto gamma_beta_argument_ptr =
@@ -128,7 +198,8 @@ int main()
     if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get()))
     {
-        std::cout << "The runtime parameters are not supported" << std::endl;
+        std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
+                  << std::endl;
         return 1;
     };
@@ -158,9 +229,11 @@ int main()
         dgamma_dev.FromDevice(dgamma.mData.data());
         dbeta_dev.FromDevice(dbeta.mData.data());
+        dx_dev.FromDevice(dx.mData.data());

         pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3);
         pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3);
+        pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3);
     }

     return (pass ? 0 : 1);
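The kernel 1 / kernel 2 comments added above describe two different 2D views of the same NHWGC tensors: backward data keeps (N, G) invariant, while backward gamma/beta keeps (G, C) invariant. A small sketch of how those invariant (M) and reduced (K) sizes work out; the dimensions below are illustrative assumptions, not values from the example:

    #include <iostream>

    int main()
    {
        // Assumed groupnorm problem size (illustrative only).
        const int N = 2, H = 4, W = 4, G = 8, C = 16;

        // kernel 1 (backward data): reduceDims = {H, W, C}  ->  M = N * G, K = H * W * C
        const int M1 = N * G, K1 = H * W * C;

        // kernel 2 (backward gamma/beta): reduceDims = {N, H, W}  ->  M = G * C, K = N * H * W
        const int M2 = G * C, K2 = N * H * W;

        std::cout << "kernel 1: M = " << M1 << ", K = " << K1 << "\n"   // 16, 256
                  << "kernel 2: M = " << M2 << ", K = " << K2 << "\n";  // 128, 32
        return 0;
    }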
include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp  (new file, 0 → 100644; view file @ d0f355a3)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename DYDataType,
          typename XDataType,
          typename GammaDataType,
          typename MeanInvStdDataType,
          typename DXDataType,
          index_t Rank,
          index_t NumReduceDim>
struct DeviceNormalizationBwdData : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::vector<index_t> lengths,
                        const std::vector<index_t> dyStrides,
                        const std::vector<index_t> xStrides,
                        const std::vector<index_t> gammaStrides,
                        const std::vector<index_t> meanStrides,
                        const std::vector<index_t> invStdStrides,
                        const std::vector<index_t> dxStrides,
                        const std::vector<index_t> reduceDims,
                        const void* p_dy,
                        const void* p_x,
                        const void* p_gamma,
                        const void* p_mean,
                        const void* p_invStd,
                        void* p_dx) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename DYDataType,
          typename XDataType,
          typename GammaDataType,
          typename MeanInvStdDataType,
          typename DXDataType,
          index_t Rank,
          index_t NumReduceDim>
using DeviceNormalizationBwdDataPtr = std::unique_ptr<DeviceNormalizationBwdData<DYDataType,
                                                                                 XDataType,
                                                                                 GammaDataType,
                                                                                 MeanInvStdDataType,
                                                                                 DXDataType,
                                                                                 Rank,
                                                                                 NumReduceDim>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
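A minimal sketch of how a caller might hold a concrete instance through this new base class and its DeviceNormalizationBwdDataPtr alias. The template parameters of the concrete Impl mirror the layernorm example earlier in this diff; the surrounding code (includes and the variable name) is an assumption for illustration, not part of the commit:

    #include <memory>
    #include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp"
    #include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp"

    // 2D layernorm backward-data instance (Rank = 2, NumReduceDim = 1), fp32 in/out.
    using BwdDataInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl<
        float, float, float, float, float, float, 2, 1,
        256, 8, 32, 1, 4,          // BlockSize, M/K cluster, M/K slice
        true, 4, true, 4, true, 4, // dy, x, gamma vector config
        false, 1, true, 4>;        // mean/inv_std, dx vector config

    // Held through the base-class pointer alias defined above.
    ck::tensor_operation::device::DeviceNormalizationBwdDataPtr<float, float, float, float, float, 2, 1>
        op = std::make_unique<BwdDataInstance>();
    // op->MakeArgumentPointer(...), op->IsSupportedArgument(...), op->MakeInvokerPointer()
    // are then used exactly as in the example/53_layernorm2d_bwd code.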
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp  (new file, 0 → 100644; view file @ d0f355a3)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

// M is Invariant dimension, K is reduced dimension
namespace ck {
namespace tensor_operation {
namespace device {

template <typename GridwiseNormalizationBwd,
          typename DYDataType,
          typename XDataType,
          typename GammaDataType,
          typename MeanInvStdDataType,
          typename DXDataType,
          typename GridDesc_M_K>
__global__ void
kernel_normalization_bwd_data(const GridDesc_M_K dy_grid_desc_m_k,
                              const GridDesc_M_K x_grid_desc_m_k,
                              const GridDesc_M_K gamma_grid_desc_m_k,
                              const GridDesc_M_K mean_grid_desc_m_k,
                              const GridDesc_M_K inv_std_grid_desc_m_k,
                              const GridDesc_M_K dx_grid_desc_m_k,
                              index_t num_k_block_tile_iteration,
                              const DYDataType* const __restrict__ p_dy_global,
                              const XDataType* const __restrict__ p_x_global,
                              const GammaDataType* const __restrict__ p_gamma_global,
                              const MeanInvStdDataType* const __restrict__ p_mean_global,
                              const MeanInvStdDataType* const __restrict__ p_inv_std_global,
                              DXDataType* const __restrict__ p_dx_global)
{
    GridwiseNormalizationBwd::Run(dy_grid_desc_m_k,
                                  x_grid_desc_m_k,
                                  gamma_grid_desc_m_k,
                                  mean_grid_desc_m_k,
                                  inv_std_grid_desc_m_k,
                                  dx_grid_desc_m_k,
                                  num_k_block_tile_iteration,
                                  p_dy_global,
                                  p_x_global,
                                  p_gamma_global,
                                  p_mean_global,
                                  p_inv_std_global,
                                  p_dx_global);
};

template <typename DYDataType,
          typename XDataType,
          typename GammaDataType,
          typename MeanInvStdDataType,
          typename ComputeDataType,
          typename DXDataType,
          index_t Rank,
          index_t NumReduceDim,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          bool IsDYFastestDimReduced,
          index_t DYSrcVectorSize,
          bool IsXFastestDimReduced,
          index_t XSrcVectorSize,
          bool IsGammaFastestDimReduced,
          index_t GammaSrcVectorSize,
          bool IsMeanInvStdFastestDimReduced,
          index_t MeanInvStdSrcVectorSize,
          bool IsDxFastestDimReduced,
          index_t DXDstVectorSize>
struct DeviceNormalizationBwdDataImpl : public DeviceNormalizationBwdData<DYDataType,
                                                                          XDataType,
                                                                          GammaDataType,
                                                                          MeanInvStdDataType,
                                                                          DXDataType,
                                                                          Rank,
                                                                          NumReduceDim>
{
    static constexpr index_t DYSrcVectorDim         = IsDYFastestDimReduced ? 1 : 0;
    static constexpr index_t XSrcVectorDim          = IsXFastestDimReduced ? 1 : 0;
    static constexpr index_t GammaSrcVectorDim      = IsGammaFastestDimReduced ? 1 : 0;
    static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0;
    static constexpr index_t DXDstVectorDim         = IsDxFastestDimReduced ? 1 : 0;

    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);

    static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) ||
                   (DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)),
                  "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");

    static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                   (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)),
                  "Invalid thread slice sizes and/or x vector sizes configuration, please check!");

    static_assert(
        ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
         (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
        "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");

    static_assert(
        (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
            (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0),
        "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
        "check!");

    static_assert(((DXDstVectorDim == 0 && MThreadSliceSize % DXDstVectorSize == 0) ||
                   (DXDstVectorDim == 1 && KThreadSliceSize % DXDstVectorSize == 0)),
                  "Invalid thread slice sizes and/or dx vector sizes configuration, please check!");

    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    static constexpr bool reduceAllDim = (NumInvariantDim == 0);
    static_assert(!reduceAllDim);

    static auto Make2dDescriptor(const std::vector<index_t>& lengths,
                                 const std::vector<index_t>& strides,
                                 int numBlockTileIteration)
    {
        const auto tupleLengths = make_tuple_from_array(lengths, Number<Rank>{});
        const auto tupleStrides = make_tuple_from_array(strides, Number<Rank>{});

        const auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);

        const auto grid_desc_m_k = [&]() {
            using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
            using ReduceDims    = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

            const auto reduceDimLengths =
                make_tuple_from_array_and_index_seq(lengths, ReduceDims{});
            const auto invariantDimLengths =
                make_tuple_from_array_and_index_seq(lengths, InvariantDims{});

            return transform_tensor_descriptor(
                desc,
                make_tuple(make_merge_transform(invariantDimLengths),
                           make_merge_transform(reduceDimLengths)),
                make_tuple(InvariantDims{}, ReduceDims{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }();

        const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
        const auto reduceLength    = grid_desc_m_k.GetLength(Number<1>{});

        const auto pad_M =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const auto pad_K = K_BlockTileSize * numBlockTileIteration - reduceLength;

        auto grid_desc_m_k_padded = transform_tensor_descriptor(
            grid_desc_m_k,
            make_tuple(make_right_pad_transform(invariantLength, pad_M),
                       make_right_pad_transform(reduceLength, pad_K)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return grid_desc_m_k_padded;
    }

    using GridDesc_M_K = decltype(Make2dDescriptor({1}, {1}, 1));

    using GridwiseNormalizationBwdDataGeneric =
        GridwiseNormalizationBwdData_mk_to_mk<DYDataType, XDataType, GammaDataType,
                                              MeanInvStdDataType, ComputeDataType, DXDataType,
                                              GridDesc_M_K, BlockSize,
                                              MThreadClusterSize, KThreadClusterSize,
                                              MThreadSliceSize, KThreadSliceSize,
                                              DYSrcVectorDim, DYSrcVectorSize,
                                              XSrcVectorDim, XSrcVectorSize,
                                              GammaSrcVectorDim, GammaSrcVectorSize,
                                              MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize,
                                              DXDstVectorDim, DXDstVectorSize,
                                              false>;

    using GridwiseNormalizationBwdDataSweepOnce =
        GridwiseNormalizationBwdData_mk_to_mk<DYDataType, XDataType, GammaDataType,
                                              MeanInvStdDataType, ComputeDataType, DXDataType,
                                              GridDesc_M_K, BlockSize,
                                              MThreadClusterSize, KThreadClusterSize,
                                              MThreadSliceSize, KThreadSliceSize,
                                              DYSrcVectorDim, DYSrcVectorSize,
                                              XSrcVectorDim, XSrcVectorSize,
                                              GammaSrcVectorDim, GammaSrcVectorSize,
                                              MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize,
                                              DXDstVectorDim, DXDstVectorSize,
                                              true>;

    struct Argument : public BaseArgument
    {
        Argument(const std::vector<index_t> lengths,
                 const std::vector<index_t> dyStrides,
                 const std::vector<index_t> xStrides,
                 const std::vector<index_t> gammaStrides,
                 const std::vector<index_t> meanStrides,
                 const std::vector<index_t> invStdStrides,
                 const std::vector<index_t> dxStrides,
                 const std::vector<index_t> reduceDims,
                 const DYDataType* p_dy,
                 const XDataType* p_x,
                 const GammaDataType* p_gamma,
                 const MeanInvStdDataType* p_mean,
                 const MeanInvStdDataType* p_invStd,
                 DXDataType* p_dx)
            : p_dy_(p_dy),
              p_x_(p_x),
              p_gamma_(p_gamma),
              p_mean_(p_mean),
              p_invStd_(p_invStd),
              p_dx_(p_dx)
        {
            lengths_      = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
            dyStrides_    = shuffle_tensor_dimensions<Rank, NumReduceDim>(dyStrides, reduceDims);
            xStrides_     = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
            gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
            meanStrides_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(meanStrides, reduceDims);
            invStdStrides_ =
                shuffle_tensor_dimensions<Rank, NumReduceDim>(invStdStrides, reduceDims);
            dxStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(dxStrides, reduceDims);

            std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(lengths_);

            numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);
            gridSize_              = math::integer_divide_ceil(MRaw_, M_BlockTileSize);

            dy_grid_desc_m_k_    = Make2dDescriptor(lengths_, dyStrides_, numBlockTileIteration_);
            x_grid_desc_m_k_     = Make2dDescriptor(lengths_, xStrides_, numBlockTileIteration_);
            gamma_grid_desc_m_k_ = Make2dDescriptor(lengths_, gammaStrides_, numBlockTileIteration_);
            mean_grid_desc_m_k_  = Make2dDescriptor(lengths_, meanStrides_, numBlockTileIteration_);
            inv_std_grid_desc_m_k_ =
                Make2dDescriptor(lengths_, invStdStrides_, numBlockTileIteration_);
            dx_grid_desc_m_k_ = Make2dDescriptor(lengths_, dxStrides_, numBlockTileIteration_);

            isSweeponce_ = dy_grid_desc_m_k_.GetLength(Number<1>{}) <= K_BlockTileSize;
        }

        const DYDataType* p_dy_;
        const XDataType* p_x_;
        const GammaDataType* p_gamma_;
        const MeanInvStdDataType* p_mean_;
        const MeanInvStdDataType* p_invStd_;
        DXDataType* p_dx_;

        std::vector<index_t> lengths_;
        std::vector<index_t> dyStrides_;
        std::vector<index_t> xStrides_;
        std::vector<index_t> gammaStrides_;
        std::vector<index_t> meanStrides_;
        std::vector<index_t> invStdStrides_;
        std::vector<index_t> dxStrides_;

        int numBlockTileIteration_;
        size_t gridSize_;

        // tensor descriptor
        GridDesc_M_K dy_grid_desc_m_k_;
        GridDesc_M_K x_grid_desc_m_k_;
        GridDesc_M_K gamma_grid_desc_m_k_;
        GridDesc_M_K mean_grid_desc_m_k_;
        GridDesc_M_K inv_std_grid_desc_m_k_;
        GridDesc_M_K dx_grid_desc_m_k_;

        bool isSweeponce_;

        index_t MRaw_; // Invariant length
        index_t KRaw_; // reduce length
    };

    struct Invoker : public BaseInvoker
    {
        auto KernelSelector(bool isSweepOnce)
        {
            return isSweepOnce
                       ? kernel_normalization_bwd_data<GridwiseNormalizationBwdDataSweepOnce,
                                                       DYDataType, XDataType, GammaDataType,
                                                       MeanInvStdDataType, DXDataType, GridDesc_M_K>
                       : kernel_normalization_bwd_data<GridwiseNormalizationBwdDataGeneric,
                                                       DYDataType, XDataType, GammaDataType,
                                                       MeanInvStdDataType, DXDataType, GridDesc_M_K>;
        }

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto kernel_main = KernelSelector(arg.isSweeponce_);

            return launch_and_time_kernel(stream_config,
                                          kernel_main,
                                          dim3(arg.gridSize_),
                                          dim3(BlockSize),
                                          0,
                                          arg.dy_grid_desc_m_k_,
                                          arg.x_grid_desc_m_k_,
                                          arg.gamma_grid_desc_m_k_,
                                          arg.mean_grid_desc_m_k_,
                                          arg.inv_std_grid_desc_m_k_,
                                          arg.dx_grid_desc_m_k_,
                                          arg.numBlockTileIteration_,
                                          arg.p_dy_,
                                          arg.p_x_,
                                          arg.p_gamma_,
                                          arg.p_mean_,
                                          arg.p_invStd_,
                                          arg.p_dx_);
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    template <index_t SrcVectorDim, index_t SrcVectorSize>
    bool IsVectorDimSizeValid(const std::vector<index_t>& lengths,
                              const std::vector<index_t>& strides)
    {
        if constexpr(SrcVectorSize == 1)
            return true;

        // Fastest dimension is not reduced
        if constexpr(SrcVectorDim == 0)
        {
            if constexpr(NumInvariantDim == 0)
                return false;

            if(strides[NumInvariantDim - 1] != 1)
                return false;

            if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0)
                return false;
        }
        else // Fastest dimension is reduced
        {
            if(strides[Rank - 1] != 1)
                return false;

            if(lengths[Rank - 1] % SrcVectorSize != 0)
                return false;
        };

        return true;
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);

        bool pass = true;
        pass &= IsVectorDimSizeValid<DYSrcVectorDim, DYSrcVectorSize>(p_arg_->lengths_,
                                                                      p_arg_->dyStrides_);
        pass &= IsVectorDimSizeValid<XSrcVectorDim, XSrcVectorSize>(p_arg_->lengths_,
                                                                    p_arg_->xStrides_);
        pass &= IsVectorDimSizeValid<GammaSrcVectorDim, GammaSrcVectorSize>(p_arg_->lengths_,
                                                                            p_arg_->gammaStrides_);
        pass &= IsVectorDimSizeValid<MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize>(
            p_arg_->lengths_, p_arg_->meanStrides_);
        pass &= IsVectorDimSizeValid<MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize>(
            p_arg_->lengths_, p_arg_->invStdStrides_);
        pass &= IsVectorDimSizeValid<DXDstVectorDim, DXDstVectorSize>(p_arg_->lengths_,
                                                                      p_arg_->dxStrides_);
        return pass;
    }

    std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> lengths,
                                                      const std::vector<index_t> dyStrides,
                                                      const std::vector<index_t> xStrides,
                                                      const std::vector<index_t> gammaStrides,
                                                      const std::vector<index_t> meanStrides,
                                                      const std::vector<index_t> invStdStrides,
                                                      const std::vector<index_t> dxStrides,
                                                      const std::vector<index_t> reduceDims,
                                                      const void* p_dy,
                                                      const void* p_x,
                                                      const void* p_gamma,
                                                      const void* p_mean,
                                                      const void* p_invStd,
                                                      void* p_dx) override
    {
        if(lengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank ||
           gammaStrides.size() != Rank || meanStrides.size() != Rank ||
           invStdStrides.size() != Rank || dxStrides.size() != Rank)
            throw std::runtime_error("dimension is incorrect");

        return std::make_unique<Argument>(lengths,
                                          dyStrides,
                                          xStrides,
                                          gammaStrides,
                                          meanStrides,
                                          invStdStrides,
                                          dxStrides,
                                          reduceDims,
                                          static_cast<const DYDataType*>(p_dy),
                                          static_cast<const XDataType*>(p_x),
                                          static_cast<const GammaDataType*>(p_gamma),
                                          static_cast<const MeanInvStdDataType*>(p_mean),
                                          static_cast<const MeanInvStdDataType*>(p_invStd),
                                          static_cast<DXDataType*>(p_dx));
    }

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceNormalizationBwdDataImpl<" << BlockSize << ",";
        str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
        str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
        str << "DYSrcVectorSize" << DYSrcVectorSize << "_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_MeanRstd" << MeanInvStdSrcVectorSize << "_Dx" << DXDstVectorSize;
        str << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
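The Argument constructor above sizes the launch from the flattened (M, K) problem: K is covered by numBlockTileIteration_ tiles of K_BlockTileSize, M by gridSize_ blocks of M_BlockTileSize, the padded descriptor rounds both up, and the SweepOnce kernel is chosen when one tile covers all of K. A small sketch of that arithmetic under assumed, illustrative numbers (the tuning parameters match the layernorm example in this diff; the problem size is made up):

    #include <iostream>

    int main()
    {
        // BlockSize = 256 = MThreadClusterSize(8) * KThreadClusterSize(32),
        // MThreadSliceSize = 1, KThreadSliceSize = 4 (as in the layernorm example).
        const int M_BlockTileSize = 8 * 1;  // rows handled per block
        const int K_BlockTileSize = 32 * 4; // reduced elements handled per tile

        // Illustrative problem size.
        const int M = 1000, K = 300;

        const int numBlockTileIteration = (K + K_BlockTileSize - 1) / K_BlockTileSize; // 3
        const int gridSize              = (M + M_BlockTileSize - 1) / M_BlockTileSize; // 125

        const int pad_M = gridSize * M_BlockTileSize - M;              // 0
        const int pad_K = K_BlockTileSize * numBlockTileIteration - K; // 84

        const bool sweepOnce = (K <= K_BlockTileSize); // false: generic kernel is selected

        std::cout << "iterations=" << numBlockTileIteration << " grid=" << gridSize
                  << " pad_M=" << pad_M << " pad_K=" << pad_K
                  << " sweepOnce=" << sweepOnce << "\n";
        return 0;
    }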
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp  (view file @ d0f355a3)

@@ -14,7 +14,7 @@
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"

-// M is invarient dimension, K is reduced dimension
+// M is Invariant dimension, K is reduced dimension
 namespace ck {
 namespace tensor_operation {
 namespace device {
@@ -87,7 +87,6 @@ struct DeviceNormalizationBwdGammaBetaImpl
                                            Rank,
                                            NumReduceDim>
 {
     static constexpr index_t DYSrcVectorDim         = IsDYFastestDimReduced ? 1 : 0;
     static constexpr index_t XSrcVectorDim          = IsXFastestDimReduced ? 1 : 0;
-
     static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0;
@@ -102,18 +101,18 @@ struct DeviceNormalizationBwdGammaBetaImpl
                    (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)),
                   "Invalid thread slice sizes and/or x vector sizes configuration, please check!");

+    static_assert(
+        ((MThreadSliceSize % DGammaDstVectorSize == 0) ||
+         (MThreadSliceSize % DBetaDstVectorSize == 0)),
+        "Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
+        "check!");
+
     static_assert(
         (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
             (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0),
         "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
         "check!");

-    static_assert(
-        ((MThreadSliceSize % DGammaDstVectorSize == 0) ||
-         (MThreadSliceSize % DBetaDstVectorSize == 0)),
-        "Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
-        "check!");
-
     static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
@@ -298,7 +297,7 @@ struct DeviceNormalizationBwdGammaBetaImpl
     GridDesc_M dgamma_grid_desc_m_;
     GridDesc_M dbeta_grid_desc_m_;

-    index_t MRaw_; // invarient length
+    index_t MRaw_; // Invariant length
     index_t KRaw_; // reduce length
 };
@@ -457,6 +456,21 @@ struct DeviceNormalizationBwdGammaBetaImpl
     {
         return std::make_unique<Invoker>();
     }
+
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "DeviceNormalizationBwdGammaBetaImpl<" << BlockSize << ",";
+        str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
+        str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
+        str << "VectorSize_DY" << DYSrcVectorSize << "_X" << XSrcVectorSize;
+        str << "_DGamma" << DGammaDstVectorSize << "_DBeta" << DBetaDstVectorSize << ">";
+        // clang-format on
+
+        return str.str();
+    }
 };

 } // namespace device
include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp  (view file @ d0f355a3)

@@ -19,7 +19,7 @@ namespace tensor_operation {
 namespace device {

 // Y = Normalization(X, Beta, Gamma)
-// M: Invarient length
+// M: Invariant length
 // K: Reduce length (Calculate mean and variance along K dimension)
 // eg. Length = [N, C, H, W], reduce dim = [C, H, W]
 // Then, M = N, K = C * H * W
@@ -263,7 +263,7 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
     GridDesc_M save_inv_std_grid_desc_m_;

     bool isSweeponce_;

-    index_t MRaw_; // invarient length
+    index_t MRaw_; // Invariant length
     index_t KRaw_; // reduce length

     index_t invariant_lowest_length_;
@@ -342,8 +342,6 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
         }
         else
         {
-            printf("!!!! %d \n", p_arg_->invariant_lowest_length_);
-
             if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
                 return false;
include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp  (view file @ d0f355a3)

@@ -108,7 +108,7 @@ namespace tensor_operation {
 namespace device {

 // Y = Normalization(X, Beta, Gamma)
-// M: Invarient length
+// M: Invariant length
 // K: Reduce length (Calculate mean and variance along K dimension)
 // eg. Length = [N, C, H, W], reduce dim = [C, H, W]
 // Then, M = N, K = C * H * W
@@ -468,7 +468,7 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp
     Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_;
     Kernel2CountGridDesc_M_KBlock kernel2_count_grid_desc_m_kblock_;

-    index_t MRaw_; // invarient length
+    index_t MRaw_; // Invariant length
     index_t KRaw_; // reduce length

     index_t invariant_lowest_length_;
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp  (new file, 0 → 100644; view file @ d0f355a3)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"

namespace ck {

// Tensor Shape
// dy, x = [M, K], gamma = [1, K], x_mean, inv_std = [M, 1]
// Flow:
// def normalization_backward_x(dy, x, gamma, x_mean, inv_std, reduce_axis, reduce_size):
//     ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
//     db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
//     b = (db * x_mean - ds) * inv_std ** (3) / reduce_size
//     c = -b * x_mean - db * inv_std / reduce_size
//     dx = inv_std * dy * gamma + b * x + c
//     return dx

template <typename DYDataType,
          typename XDataType,
          typename GammaDataType,
          typename MeanInvStdDataType,
          typename ComputeDataType,
          typename DXDataType,
          typename GridDesc_M_K,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t DYSrcVectorDim,
          index_t DYSrcVectorSize,
          index_t XSrcVectorDim,
          index_t XSrcVectorSize,
          index_t GammaSrcVectorDim,
          index_t GammaSrcVectorSize,
          index_t MeanInvStdSrcVectorDim,
          index_t MeanInvStdSrcVectorSize,
          index_t DXDstVectorDim,
          index_t DXDstVectorSize,
          bool SweepOnce>
struct GridwiseNormalizationBwdData_mk_to_mk
{
    // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce)
    static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
                   (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
                  "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");

    static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) ||
                   (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
                  "Invalid thread slice sizes and/or x vector sizes configuration, please check!");

    static_assert(
        ((GammaSrcVectorDim == 0 && MThreadSliceSize == GammaSrcVectorSize) ||
         (GammaSrcVectorDim == 1 && KThreadSliceSize == GammaSrcVectorSize)),
        "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");

    static_assert(
        ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize == MeanInvStdSrcVectorSize) ||
         (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize == MeanInvStdSrcVectorSize)),
        "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!");

    static_assert(((DXDstVectorDim == 0 && MThreadSliceSize == DXDstVectorSize) ||
                   (DXDstVectorDim == 1 && KThreadSliceSize == DXDstVectorSize)),
                  "Invalid thread slice sizes and/or dx vector sizes configuration, please check!");

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using DYThreadBufferDimAccessOrder =
        typename conditional<DYSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
    using XThreadBufferDimAccessOrder =
        typename conditional<XSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
    using GammaThreadBufferDimAccessOrder =
        typename conditional<GammaSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
    using MeanInvStdThreadBufferDimAccessOrder =
        typename conditional<MeanInvStdSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
    using DXThreadBufferDimAccessOrder =
        typename conditional<DXDstVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;

    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

    static constexpr auto thread_buffer_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    using BlockwiseSumReduce = PartitionedBlockwiseReduction<ComputeDataType,
                                                             BlockSize,
                                                             ThreadClusterLengths_M_K,
                                                             ThreadClusterArrangeOrder,
                                                             reduce::Add,
                                                             true>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k,
                               const GridDesc_M_K& x_grid_desc_m_k,
                               const GridDesc_M_K& gamma_grid_desc_m_k,
                               const GridDesc_M_K& mean_grid_desc_m_k,
                               const GridDesc_M_K& inv_std_grid_desc_m_k,
                               const GridDesc_M_K& dx_grid_desc_m_k,
                               index_t num_k_block_tile_iteration,
                               const DYDataType* const __restrict__ p_dy_global,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const MeanInvStdDataType* const __restrict__ p_mean_global,
                               const MeanInvStdDataType* const __restrict__ p_inv_std_global,
                               DXDataType* const __restrict__ p_dx_global)
    {
        // LDS
        __shared__ ComputeDataType p_reduce_work_buffer[BlockSize];

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        // Global
        const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize());
        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
        auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
        const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize());
        const auto inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize());
        auto dx_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dx_global, dx_grid_desc_m_k.GetElementSpaceSize());

        // VGPR
        auto dy_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType,
                                          MThreadSliceSize * KThreadSliceSize, true>{};
        auto x_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType,
                                         MThreadSliceSize * KThreadSliceSize, true>{};
        auto gamma_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType,
                                             MThreadSliceSize * KThreadSliceSize, true>{};
        auto mean_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType,
                                            MThreadSliceSize * KThreadSliceSize, true>{};
        auto inv_std_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType,
                                               MThreadSliceSize * KThreadSliceSize, true>{};
        auto dx_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType,
                                          MThreadSliceSize * KThreadSliceSize, true>{};

        auto ds_thread_buf =
            StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};
        auto db_thread_buf =
            StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};

        // thread id
        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        // IO
        auto threadwise_dy_load =
            ThreadwiseTensorSliceTransfer_v2<DYDataType,
                                             ComputeDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             DYThreadBufferDimAccessOrder,
                                             DYSrcVectorDim,
                                             DYSrcVectorSize,
                                             1,
                                             false>(
                dy_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_x_load =
            ThreadwiseTensorSliceTransfer_v2<XDataType,
                                             ComputeDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             XThreadBufferDimAccessOrder,
                                             XSrcVectorDim,
                                             XSrcVectorSize,
                                             1,
                                             false>(
                x_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_gamma_load =
            ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                             ComputeDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             XThreadBufferDimAccessOrder,
                                             GammaSrcVectorDim,
                                             GammaSrcVectorSize,
                                             1,
                                             false>(
                gamma_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_mean_load =
            ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
                                             ComputeDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             MeanInvStdThreadBufferDimAccessOrder,
                                             MeanInvStdSrcVectorDim,
                                             MeanInvStdSrcVectorSize,
                                             1,
                                             false>(
                mean_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_inv_std_load
=
ThreadwiseTensorSliceTransfer_v2
<
MeanInvStdDataType
,
ComputeDataType
,
GridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
MeanInvStdThreadBufferDimAccessOrder
,
MeanInvStdSrcVectorDim
,
MeanInvStdSrcVectorSize
,
1
,
false
>
(
inv_std_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_dx_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
DXDataType
,
decltype
(
thread_buffer_desc_m_k
),
GridDesc_M_K
,
PassThroughOp
,
ThreadBufferLengths_M_K
,
DXThreadBufferDimAccessOrder
,
DXDstVectorDim
,
DXDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
false
>
(
dx_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
),
PassThroughOp
{});
ComputeDataType
reduce_size
=
type_convert
<
ComputeDataType
>
(
dy_grid_desc_m_k
.
GetTransforms
()[
I2
].
GetUpperLengths
()[
I0
]);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
ds_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
0.0
f
);
db_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
0.0
f
);
});
// Separate sweep once and sweep twice pipeline
// Sweep once: for small k, if KThreadClusterSize * KThreadSliceSize > K
// we don't need to use loop to read x, dy, gamma twice
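        // In both pipelines, ds accumulates the per-row sum of dy * gamma * x and db the
        // per-row sum of dy * gamma; the sweep-twice path re-reads dy, x and gamma after
        // the blockwise reduction to rebuild dx tile by tile.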
        if constexpr(SweepOnce)
        {
            threadwise_dy_load.Run(dy_grid_desc_m_k,
                                   dy_global_val_buf,
                                   thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   dy_thread_buf);

            threadwise_x_load.Run(x_grid_desc_m_k,
                                  x_global_val_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  x_thread_buf);

            threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                      gamma_global_val_buf,
                                      thread_buffer_desc_m_k,
                                      make_tuple(I0, I0),
                                      gamma_thread_buf);

            threadwise_mean_load.Run(mean_grid_desc_m_k,
                                     mean_global_val_buf,
                                     thread_buffer_desc_m_k,
                                     make_tuple(I0, I0),
                                     mean_thread_buf);

            threadwise_inv_std_load.Run(inv_std_grid_desc_m_k,
                                        inv_std_global_val_buf,
                                        thread_buffer_desc_m_k,
                                        make_tuple(I0, I0),
                                        inv_std_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                constexpr auto offset_m =
                    Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};

                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset_m_k =
                        Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};

                    ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] *
                                               gamma_thread_buf[offset_m_k] *
                                               x_thread_buf[offset_m_k];

                    db_thread_buf(offset_m) +=
                        dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k];
                });
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                if constexpr(I > 0)
                    block_sync_lds();

                BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I));
                block_sync_lds();
                BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I));
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                constexpr auto offset_m =
                    Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};

                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset_m_k =
                        Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};

                    // b = (db * x_mean - ds) * rstd ** (3) / reduce_size
                    // c = -b * x_mean - db * rstd / reduce_size
                    // dx = rstd * dy * gamma + b * x + c
                    ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] -
                                        ds_thread_buf[offset_m];

                    b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
                         inv_std_thread_buf[offset_m_k] / reduce_size;

                    ComputeDataType c = -b * mean_thread_buf(offset_m_k);
                    c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size;

                    dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] *
                                                    gamma_thread_buf[offset_m_k] *
                                                    inv_std_thread_buf[offset_m_k] +
                                                b * x_thread_buf[offset_m_k] + c;
                });
            });

            threadwise_dx_store.Run(thread_buffer_desc_m_k,
                                    make_tuple(I0, I0),
                                    dx_thread_buf,
                                    dx_grid_desc_m_k,
                                    dx_global_val_buf);
        } // end of sweep once
        else // Sweep Twice pipeline
        {
            constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);

            for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
            {
                threadwise_dy_load.Run(dy_grid_desc_m_k,
                                       dy_global_val_buf,
                                       thread_buffer_desc_m_k,
                                       make_tuple(I0, I0),
                                       dy_thread_buf);

                threadwise_x_load.Run(x_grid_desc_m_k,
                                      x_global_val_buf,
                                      thread_buffer_desc_m_k,
                                      make_tuple(I0, I0),
                                      x_thread_buf);

                threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                          gamma_global_val_buf,
                                          thread_buffer_desc_m_k,
                                          make_tuple(I0, I0),
                                          gamma_thread_buf);

                threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                         thread_copy_fwd_step_m_k);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    constexpr auto offset_m =
                        Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};

                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset_m_k =
                            Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};

                        ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] *
                                                   gamma_thread_buf[offset_m_k] *
                                                   x_thread_buf[offset_m_k];

                        db_thread_buf(offset_m) +=
                            dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k];
                    });
                });
            } // end of first sweep

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                if constexpr(I > 0)
                    block_sync_lds();

                BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I));
                block_sync_lds();
                BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I));
            });

            // reverse read for using dy, gamma and x in the cache
            constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);
            auto thread_copy_tail_m_k =
                (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;

            // move to tail
            threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);

            // move from start to tail
            threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_tail_m_k);
            threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, thread_copy_tail_m_k);
            threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k);

            for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
            {
                threadwise_dy_load.Run(dy_grid_desc_m_k,
                                       dy_global_val_buf,
                                       thread_buffer_desc_m_k,
                                       make_tuple(I0, I0),
                                       dy_thread_buf);

                threadwise_x_load.Run(x_grid_desc_m_k,
                                      x_global_val_buf,
                                      thread_buffer_desc_m_k,
                                      make_tuple(I0, I0),
                                      x_thread_buf);

                threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                          gamma_global_val_buf,
                                          thread_buffer_desc_m_k,
                                          make_tuple(I0, I0),
                                          gamma_thread_buf);

                threadwise_mean_load.Run(mean_grid_desc_m_k,
                                         mean_global_val_buf,
                                         thread_buffer_desc_m_k,
                                         make_tuple(I0, I0),
                                         mean_thread_buf);

                threadwise_inv_std_load.Run(inv_std_grid_desc_m_k,
                                            inv_std_global_val_buf,
                                            thread_buffer_desc_m_k,
                                            make_tuple(I0, I0),
                                            inv_std_thread_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    constexpr auto offset_m =
                        Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};

                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset_m_k =
                            Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};

                        // b = (db * x_mean - ds) * rstd ** (3) / reduce_size
                        // c = -b * x_mean - db * rstd / reduce_size
                        // dx = rstd * dy * gamma + b * x + c
                        ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] -
                                            ds_thread_buf[offset_m];

                        b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
                             inv_std_thread_buf[offset_m_k] / reduce_size;

                        ComputeDataType c = -b * mean_thread_buf(offset_m_k);
                        c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] /
                             reduce_size;

                        dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] *
                                                        gamma_thread_buf[offset_m_k] *
                                                        inv_std_thread_buf[offset_m_k] +
                                                    b * x_thread_buf[offset_m_k] + c;
                    });
                });

                threadwise_dx_store.Run(thread_buffer_desc_m_k,
                                        make_tuple(I0, I0),
                                        dx_thread_buf,
                                        dx_grid_desc_m_k,
                                        dx_global_val_buf);

                threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                         thread_copy_bwd_step_m_k);
                threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k,
                                                        thread_copy_bwd_step_m_k);
                threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k,
                                                           thread_copy_bwd_step_m_k);
                threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k);
            }
        }
    }
};

} // namespace ck
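For readers checking the dx formula used in both pipelines above, the following is a minimal host-side sketch of the same row-wise computation. It is illustrative only: the function name layernorm_bwd_x_host, the plain std::vector containers and the row-major [M, K] layout are assumptions, not part of the library.

#include <cstddef>
#include <vector>

// ds = sum_k(dy * gamma * x), db = sum_k(dy * gamma),
// b  = (db * mean - ds) * rstd^3 / K,
// c  = -b * mean - db * rstd / K,
// dx = rstd * dy * gamma + b * x + c.
std::vector<float> layernorm_bwd_x_host(const std::vector<float>& dy,      // [M, K]
                                        const std::vector<float>& x,       // [M, K]
                                        const std::vector<float>& gamma,   // [K]
                                        const std::vector<float>& mean,    // [M]
                                        const std::vector<float>& inv_std, // [M]
                                        std::size_t M,
                                        std::size_t K)
{
    std::vector<float> dx(M * K);
    for(std::size_t m = 0; m < M; ++m)
    {
        float ds = 0.f, db = 0.f;
        for(std::size_t k = 0; k < K; ++k)
        {
            ds += dy[m * K + k] * gamma[k] * x[m * K + k];
            db += dy[m * K + k] * gamma[k];
        }
        const float rstd = inv_std[m];
        const float b    = (db * mean[m] - ds) * rstd * rstd * rstd / K;
        const float c    = -b * mean[m] - db * rstd / K;
        for(std::size_t k = 0; k < K; ++k)
            dx[m * K + k] = rstd * dy[m * K + k] * gamma[k] + b * x[m * K + k] + c;
    }
    return dx;
}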
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp
View file @ d0f355a3
...
@@ -35,7 +35,7 @@ template <typename DYDataType,
           index_t DBetaDstVectorSize>
 struct GridwiseNormalizationBwdGammaBeta_mk_to_k
 {
-    // if we just check ThreadSliceSize & VectorSize == 0, the performance may be poor
+    // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce)
     static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
                    (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
                   "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
...
@@ -44,6 +44,15 @@ struct GridwiseNormalizationBwdGammaBeta_mk_to_k
                    (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
                   "Invalid thread slice sizes and/or x vector sizes configuration, please check!");
 
+    // do not force SliceSize == MeanInvStdSrcVectorSize for groupnorm
+    static_assert(
+        ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
+         (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0)),
+        "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!");
+
+    static_assert(MThreadSliceSize == DGammaDstVectorSize && MThreadSliceSize == DBetaDstVectorSize,
+                  "Invalid thread slice sizes and/or dgamma/dbeta vector sizes configuration, please check!");
+
     using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
 
     using DYThreadBufferDimAccessOrder =
...
library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp
View file @ d0f355a3
...
@@ -16,6 +16,31 @@ namespace ck {
 namespace tensor_operation {
 namespace host {
 
+// def normalization_backward_x(dy, x, gamma, x_mean, rstd, reduce_axis, reduce_size):
+//     ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
+//     db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
+//     b = (db * x_mean - ds) * rstd ** (3) / reduce_size
+//     c = -b * x_mean - db * rstd / reduce_size
+//     dx = rstd * dy * gamma + b * x + c
+//     return dx
+
+// def normalization_backward_gamma_beta(dy, x, x_mean, rstd, reduce_axis):
+//     # Assume shape of gamma and beta are the same
+//     dgamma = np.sum(dy * (x - x_mean) * rstd, axis=reduce_axis, keepdims=True)
+//     dbeta = np.sum(dy, axis=reduce_axis, keepdims=True)
+//     return dgamma, dbeta
+
+// def groupnorm_backward(dy, x, gamma, x_mean, rstd):
+//     # dy, x = [N, H, W, G, C], gamma = [1, 1, 1, G, C], x_mean, rstd = [N, 1, 1, G, 1]
+//     N, H, W, G, C = x.shape
+//     dx = normalization_backward_x(dy, x, gamma, x_mean, rstd, (1, 2, 4), H * W * C)
+//     dgamma, dbeta = normalization_backward_gamma_beta(dy, x, x_mean, rstd, (0, 1, 2))
+//     return dx, dgamma, dbeta
+
+// Reference (Layernorm and groupnorm):
+// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/group_norm_kernel.cpp#L655
 template <typename DYDataType,
           typename XDataType,
           typename GammaDataType,
...
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp
View file @ d0f355a3
...
@@ -16,6 +16,30 @@ namespace ck {
 namespace tensor_operation {
 namespace host {
 
+// def normalization_backward_x(dy, x, gamma, x_mean, rstd, reduce_axis, reduce_size):
+//     ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
+//     db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
+//     b = (db * x_mean - ds) * rstd ** (3) / reduce_size
+//     c = -b * x_mean - db * rstd / reduce_size
+//     dx = rstd * dy * gamma + b * x + c
+//     return dx
+
+// def normalization_backward_gamma_beta(dy, x, x_mean, rstd, reduce_axis):
+//     # Assume shape of gamma and beta are the same
+//     dgamma = np.sum(dy * (x - x_mean) * rstd, axis=reduce_axis, keepdims=True)
+//     dbeta = np.sum(dy, axis=reduce_axis, keepdims=True)
+//     return dgamma, dbeta
+
+// def layernorm_backward(dy, x, gamma, x_mean, rstd):
+//     # dy, x = [M, K], gamma = [1, K], x_mean, rstd = [M, 1]
+//     # dx = [M, K], dgamma, dbeta = [1, K]
+//     M, K = x.shape
+//     dx = normalization_backward_x(dy, x, gamma, x_mean, rstd, 1, K)
+//     dgamma, dbeta = normalization_backward_gamma_beta(dy, x, x_mean, rstd, 0)
+//     return dx, dgamma, dbeta
+
+// Reference (Layernorm and groupnorm):
+// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/layer_norm_kernel.cpp#L196
 template <typename DYDataType,
           typename XDataType,
           typename GammaDataType,
...
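To complement the pseudocode above, here is a minimal host-side sketch of the layernorm gamma/beta backward in C++. The function name layernorm_bwd_gamma_beta_host, the std::vector containers and the row-major [M, K] layout are illustrative assumptions, not part of the library.

#include <cstddef>
#include <utility>
#include <vector>

// dgamma[k] = sum_m(dy * (x - mean) * rstd), dbeta[k] = sum_m(dy),
// i.e. the reduction runs over the M (row) dimension.
std::pair<std::vector<float>, std::vector<float>>
layernorm_bwd_gamma_beta_host(const std::vector<float>& dy,      // [M, K]
                              const std::vector<float>& x,       // [M, K]
                              const std::vector<float>& mean,    // [M]
                              const std::vector<float>& inv_std, // [M]
                              std::size_t M,
                              std::size_t K)
{
    std::vector<float> dgamma(K, 0.f), dbeta(K, 0.f);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t k = 0; k < K; ++k)
        {
            dgamma[k] += dy[m * K + k] * (x[m * K + k] - mean[m]) * inv_std[m];
            dbeta[k] += dy[m * K + k];
        }
    return {dgamma, dbeta};
}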
library/include/ck/library/tensor_operation_instance/gpu/groupnorm_bwd_data.hpp 0 → 100644
View file @ d0f355a3

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <vector>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#ifdef CK_ENABLE_FP32
// FP32
void add_device_groupnorm_bwd_data_f32_instances(
    std::vector<std::unique_ptr<DeviceNormalizationBwdData<F32, F32, F32, F32, F32, 5, 3>>>&);
#endif

template <typename DYDataType,
          typename XDataType,
          typename GammaDataType,
          typename MeanInvStdDataType,
          typename DXDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceNormalizationBwdData<DYDataType,
                                                             XDataType,
                                                             GammaDataType,
                                                             MeanInvStdDataType,
                                                             DXDataType,
                                                             5,
                                                             3>>
{
    using DeviceOp = DeviceNormalizationBwdData<DYDataType,
                                                XDataType,
                                                GammaDataType,
                                                MeanInvStdDataType,
                                                DXDataType,
                                                5,
                                                3>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_ENABLE_FP32
        if constexpr(is_same_v<DYDataType, F32> && is_same_v<XDataType, F32> &&
                     is_same_v<GammaDataType, F32> && is_same_v<MeanInvStdDataType, F32> &&
                     is_same_v<DXDataType, F32>)
        {
            add_device_groupnorm_bwd_data_f32_instances(op_ptrs);
        }
#endif

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
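As a rough usage sketch for this factory header (illustrative only: it assumes CK's F32 alias is float and that the instance library providing add_device_groupnorm_bwd_data_f32_instances is linked; it stops at enumerating instances rather than guessing the argument-setup API):

#include <cstdio>

#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_data.hpp"

int main()
{
    // FP32 groupnorm backward-data, rank-5 input with 3 reduced dimensions,
    // matching the factory specialization declared above.
    using Factory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        ck::tensor_operation::device::
            DeviceNormalizationBwdData<float, float, float, float, float, 5, 3>>;

    auto op_ptrs = Factory::GetInstances();
    std::printf("found %zu groupnorm bwd-data instances\n", op_ptrs.size());
    return 0;
}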