Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d0f355a3
Commit
d0f355a3
authored
Dec 19, 2023
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
55a89c74
b305a29e
Changes
81
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1409 additions
and
58 deletions
+1409
-58
cmake/ClangTidy.cmake
cmake/ClangTidy.cmake
+1
-1
docs/sphinx/requirements.in
docs/sphinx/requirements.in
+1
-1
docs/sphinx/requirements.txt
docs/sphinx/requirements.txt
+1
-1
example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp
...4_elementwise_permute/elementwise_permute_4D_fp16_col.cpp
+6
-5
example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp
...4_elementwise_permute/elementwise_permute_4D_fp32_col.cpp
+3
-1
example/53_layernorm2d_bwd/CMakeLists.txt
example/53_layernorm2d_bwd/CMakeLists.txt
+1
-0
example/53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp
example/53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp
+80
-17
example/53_layernorm_bwd/CMakeLists.txt
example/53_layernorm_bwd/CMakeLists.txt
+0
-1
example/54_groupnorm_bwd/CMakeLists.txt
example/54_groupnorm_bwd/CMakeLists.txt
+1
-1
example/54_groupnorm_bwd/groupnorm_bwd_fp32.cpp
example/54_groupnorm_bwd/groupnorm_bwd_fp32.cpp
+87
-14
include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp
...or_operation/gpu/device/device_normalization_bwd_data.hpp
+59
-0
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp
...on/gpu/device/impl/device_normalization_bwd_data_impl.hpp
+465
-0
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp
.../device/impl/device_normalization_bwd_gamma_beta_impl.hpp
+23
-9
include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp
...eration/gpu/device/impl/device_normalization_fwd_impl.hpp
+2
-4
include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp
.../gpu/device/impl/device_normalization_fwd_splitk_impl.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp
...pu/grid/normalization/gridwise_normalization_bwd_data.hpp
+554
-0
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp
...d/normalization/gridwise_normalization_bwd_gamma_beta.hpp
+10
-1
library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp
...eference_tensor_operation/cpu/reference_groupnorm_bwd.hpp
+25
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp
...eference_tensor_operation/cpu/reference_layernorm_bwd.hpp
+24
-0
library/include/ck/library/tensor_operation_instance/gpu/groupnorm_bwd_data.hpp
...rary/tensor_operation_instance/gpu/groupnorm_bwd_data.hpp
+64
-0
No files found.
cmake/ClangTidy.cmake
View file @
d0f355a3
...
@@ -149,7 +149,7 @@ function(clang_tidy_check TARGET)
...
@@ -149,7 +149,7 @@ function(clang_tidy_check TARGET)
add_custom_target
(
${
tidy_target
}
add_custom_target
(
${
tidy_target
}
# for some targets clang-tidy not able to get information from .clang-tidy
# for some targets clang-tidy not able to get information from .clang-tidy
DEPENDS
${
SOURCE
}
DEPENDS
${
SOURCE
}
COMMAND
${
CLANG_TIDY_COMMAND
}
"-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__
\;
__HIP_ROCclr__\}\]\}"
${
SOURCE
}
"-export-fixes=
${
CLANG_TIDY_FIXIT_DIR
}
/
${
TARGET
}
-
${
tidy_file
}
.yaml"
COMMAND
${
CLANG_TIDY_COMMAND
}
"-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__
\;
__HIP_PLATFORM_AMD__
\;
__HIP_ROCclr__\}\]\}"
${
SOURCE
}
"-export-fixes=
${
CLANG_TIDY_FIXIT_DIR
}
/
${
TARGET
}
-
${
tidy_file
}
.yaml"
WORKING_DIRECTORY
${
CMAKE_CURRENT_SOURCE_DIR
}
WORKING_DIRECTORY
${
CMAKE_CURRENT_SOURCE_DIR
}
COMMENT
"clang-tidy: Running clang-tidy on target
${
SOURCE
}
..."
COMMENT
"clang-tidy: Running clang-tidy on target
${
SOURCE
}
..."
)
)
...
...
docs/sphinx/requirements.in
View file @
d0f355a3
rocm-docs-core==0.30.
1
rocm-docs-core==0.30.
2
sphinxcontrib-bibtex==2.6.1
sphinxcontrib-bibtex==2.6.1
docs/sphinx/requirements.txt
View file @
d0f355a3
...
@@ -113,7 +113,7 @@ requests==2.31.0
...
@@ -113,7 +113,7 @@ requests==2.31.0
# via
# via
# pygithub
# pygithub
# sphinx
# sphinx
rocm-docs-core==0.30.
1
rocm-docs-core==0.30.
2
# via -r requirements.in
# via -r requirements.in
six==1.16.0
six==1.16.0
# via
# via
...
...
example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp
View file @
d0f355a3
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
#include <random>
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
...
@@ -48,10 +49,8 @@ void host_elementwise4D(HostTensorB& B_nhwc,
...
@@ -48,10 +49,8 @@ void host_elementwise4D(HostTensorB& B_nhwc,
for
(
std
::
size_t
n
=
0
;
n
<
N
;
++
n
)
for
(
std
::
size_t
n
=
0
;
n
<
N
;
++
n
)
{
{
ADataType
tmp_val
;
ADataType
tmp_val
;
// auto a_val = A_nchw(n, c, h, w);
auto
a_val
=
A_nchw
.
mData
[(
n
)
+
(
c
*
N
)
+
(
h
*
C
*
N
)
+
(
w
*
H
*
C
*
N
)];
auto
a_val
=
A_nchw
.
mData
[(
n
)
+
(
c
*
N
)
+
(
h
*
C
*
N
)
+
(
w
*
H
*
C
*
N
)];
functor_b
(
tmp_val
,
a_val
);
functor_b
(
tmp_val
,
a_val
);
// functor_a(B_nhwc(n, h, w, c), scale * tmp_val);
functor_a
(
B_nhwc
.
mData
[(
n
)
+
(
c
*
W
*
H
*
N
)
+
(
h
*
N
)
+
(
w
*
H
*
N
)],
functor_a
(
B_nhwc
.
mData
[(
n
)
+
(
c
*
W
*
H
*
N
)
+
(
h
*
N
)
+
(
w
*
H
*
N
)],
scale
*
tmp_val
);
scale
*
tmp_val
);
}
}
...
@@ -62,12 +61,14 @@ int main()
...
@@ -62,12 +61,14 @@ int main()
bool
do_verification
=
true
;
bool
do_verification
=
true
;
bool
time_kernel
=
true
;
bool
time_kernel
=
true
;
std
::
vector
<
std
::
size_t
>
nchw
=
{
4
,
2
,
1
,
8
};
std
::
vector
<
std
::
size_t
>
nchw
=
{
16
,
8
,
32
,
64
};
std
::
vector
<
std
::
size_t
>
nhwc
=
{
4
,
1
,
8
,
2
};
std
::
vector
<
std
::
size_t
>
nhwc
=
{
16
,
32
,
64
,
8
};
Tensor
<
ADataType
>
a
(
nchw
);
Tensor
<
ADataType
>
a
(
nchw
);
Tensor
<
BDataType
>
b
(
nhwc
);
Tensor
<
BDataType
>
b
(
nhwc
);
float
scale
=
1.
f
;
float
scale
=
1.
f
;
auto
i
=
0
;
auto
i
=
0
;
std
::
mt19937
gen
(
11939
);
std
::
uniform_int_distribution
<
int
>
dis
(
0
,
1
);
for
(
std
::
size_t
w
=
0
;
w
<
a
.
mDesc
.
GetLengths
()[
3
];
++
w
)
for
(
std
::
size_t
w
=
0
;
w
<
a
.
mDesc
.
GetLengths
()[
3
];
++
w
)
for
(
std
::
size_t
h
=
0
;
h
<
a
.
mDesc
.
GetLengths
()[
2
];
++
h
)
for
(
std
::
size_t
h
=
0
;
h
<
a
.
mDesc
.
GetLengths
()[
2
];
++
h
)
for
(
std
::
size_t
c
=
0
;
c
<
a
.
mDesc
.
GetLengths
()[
1
];
++
c
)
for
(
std
::
size_t
c
=
0
;
c
<
a
.
mDesc
.
GetLengths
()[
1
];
++
c
)
...
@@ -75,7 +76,7 @@ int main()
...
@@ -75,7 +76,7 @@ int main()
{
{
a
.
mData
[(
n
*
nchw
[
1
]
*
nchw
[
2
]
*
nchw
[
3
])
+
(
c
*
nchw
[
2
]
*
nchw
[
3
])
+
a
.
mData
[(
n
*
nchw
[
1
]
*
nchw
[
2
]
*
nchw
[
3
])
+
(
c
*
nchw
[
2
]
*
nchw
[
3
])
+
(
h
*
nchw
[
3
])
+
w
]
=
i
;
(
h
*
nchw
[
3
])
+
w
]
=
i
;
i
++
;
i
=
dis
(
gen
)
;
}
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a
.
mDesc
.
GetElementSpaceSize
());
...
...
example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp
View file @
d0f355a3
...
@@ -67,6 +67,8 @@ int main()
...
@@ -67,6 +67,8 @@ int main()
float
scale
=
1.
f
;
float
scale
=
1.
f
;
auto
i
=
0
;
auto
i
=
0
;
std
::
mt19937
gen
(
11939
);
std
::
uniform_int_distribution
<
int
>
dis
(
0
,
1
);
for
(
std
::
size_t
w
=
0
;
w
<
a
.
mDesc
.
GetLengths
()[
3
];
++
w
)
for
(
std
::
size_t
w
=
0
;
w
<
a
.
mDesc
.
GetLengths
()[
3
];
++
w
)
for
(
std
::
size_t
h
=
0
;
h
<
a
.
mDesc
.
GetLengths
()[
2
];
++
h
)
for
(
std
::
size_t
h
=
0
;
h
<
a
.
mDesc
.
GetLengths
()[
2
];
++
h
)
for
(
std
::
size_t
c
=
0
;
c
<
a
.
mDesc
.
GetLengths
()[
1
];
++
c
)
for
(
std
::
size_t
c
=
0
;
c
<
a
.
mDesc
.
GetLengths
()[
1
];
++
c
)
...
@@ -74,7 +76,7 @@ int main()
...
@@ -74,7 +76,7 @@ int main()
{
{
a
.
mData
[(
n
*
nchw
[
1
]
*
nchw
[
2
]
*
nchw
[
3
])
+
(
c
*
nchw
[
2
]
*
nchw
[
3
])
+
a
.
mData
[(
n
*
nchw
[
1
]
*
nchw
[
2
]
*
nchw
[
3
])
+
(
c
*
nchw
[
2
]
*
nchw
[
3
])
+
(
h
*
nchw
[
3
])
+
w
]
=
i
;
(
h
*
nchw
[
3
])
+
w
]
=
i
;
i
++
;
i
=
dis
(
gen
)
;
}
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a
.
mDesc
.
GetElementSpaceSize
());
...
...
example/53_layernorm2d_bwd/CMakeLists.txt
0 → 100644
View file @
d0f355a3
add_example_executable
(
example_layernorm2d_bwd_fp32 layernorm2d_bwd_fp32.cpp
)
example/53_layernorm_bwd/layernorm2d_bwd_fp
16
.cpp
→
example/53_layernorm
2d
_bwd/layernorm2d_bwd_fp
32
.cpp
View file @
d0f355a3
...
@@ -15,16 +15,17 @@
...
@@ -15,16 +15,17 @@
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp"
using
DYDataType
=
ck
::
half_
t
;
using
DYDataType
=
floa
t
;
using
XDataType
=
ck
::
half_
t
;
using
XDataType
=
floa
t
;
using
GammaDataType
=
ck
::
half_
t
;
using
GammaDataType
=
floa
t
;
using
MeanInvStdDataType
=
float
;
using
MeanInvStdDataType
=
float
;
using
DGammaDataType
=
ck
::
half_
t
;
using
DGammaDataType
=
floa
t
;
using
DBetaDataType
=
ck
::
half_
t
;
using
DBetaDataType
=
floa
t
;
using
DXDataType
=
ck
::
half_
t
;
using
DXDataType
=
floa
t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
constexpr
int
Rank
=
2
;
constexpr
int
Rank
=
2
;
...
@@ -39,6 +40,7 @@ constexpr int NumReduceDim = 1;
...
@@ -39,6 +40,7 @@ constexpr int NumReduceDim = 1;
// inv_std: [M, 1]
// inv_std: [M, 1]
// Output shape
// Output shape
// dx: [M, N]
// dgamma: [1, N]
// dgamma: [1, N]
// dbeta: [1, N]
// dbeta: [1, N]
...
@@ -46,8 +48,34 @@ constexpr int NumReduceDim = 1;
...
@@ -46,8 +48,34 @@ constexpr int NumReduceDim = 1;
// dbeta = reduce_sum(dy, axis=0)
// dbeta = reduce_sum(dy, axis=0)
// [CAUTION]
// [CAUTION]
// In DeviceNormalizationBwdGammaBetaImpl, M is invarient dimension, K is reduced dimension
// In DeviceNormalizationBwdDataImpl & DeviceNormalizationBwdGammaBetaImpl, M is Invariant
// Hence, M in this example and DeviceNormalizationBwdGammaBetaImpl is different
// dimension, K is reduced dimension. Hence, M in this example and
// DeviceNormalizationBwdGammaBetaImpl is different
using
XDeviceInstance
=
ck
::
tensor_operation
::
device
::
DeviceNormalizationBwdDataImpl
<
DYDataType
,
XDataType
,
GammaDataType
,
MeanInvStdDataType
,
ComputeDataType
,
DXDataType
,
Rank
,
NumReduceDim
,
256
,
// BlockSize
8
,
// MThreadClusterSize
32
,
// KThreadClusterSize
1
,
// MThreadSliceSize
4
,
// KThreadSliceSize
true
,
// IsDYFastestDimReduced
4
,
// DYSrcVectorSize
true
,
// IsXFastestDimReduced
4
,
// XSrcVectorSize
true
,
// IsGammaFastestDimReduced
4
,
// GammaSrcVectorSize
false
,
// IsMeanInvStdFastestDimReduced
1
,
// MeanInvStdSrcVectorSize
true
,
// IsDXFastestDimReduced
4
>
;
// DXDstVectorSize
using
GammaBetaDeviceInstance
=
ck
::
tensor_operation
::
device
::
DeviceNormalizationBwdGammaBetaImpl
<
using
GammaBetaDeviceInstance
=
ck
::
tensor_operation
::
device
::
DeviceNormalizationBwdGammaBetaImpl
<
DYDataType
,
DYDataType
,
XDataType
,
XDataType
,
...
@@ -58,18 +86,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
...
@@ -58,18 +86,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
Rank
,
Rank
,
NumReduceDim
,
NumReduceDim
,
256
,
// BlockSize
256
,
// BlockSize
8
,
//
ClusterInvarient
8
,
//
MThreadClusterSize
32
,
// Cluster
Reduc
e
32
,
//
KThread
Cluster
Siz
e
8
,
//
SliceInvarient
4
,
//
MThreadSliceSize
1
,
//
SliceReduc
e
1
,
//
KThreadSliceSiz
e
false
,
// IsDYFastestDimReduced
false
,
// IsDYFastestDimReduced
8
,
// DYSrcVectorSize
4
,
// DYSrcVectorSize
false
,
// IsXFastestDimReduced
false
,
// IsXFastestDimReduced
8
,
// XSrcVectorSize
4
,
// XSrcVectorSize
true
,
// IsMeanInvStdFastestDimReduced
true
,
// IsMeanInvStdFastestDimReduced
1
,
// MeanInvStdSrcVectorSize
1
,
// MeanInvStdSrcVectorSize
1
,
// DGammaDstVectorSize
4
,
// DGammaDstVectorSize
1
>
;
// DBetaDstVectorSize
4
>
;
// DBetaDstVectorSize
int
main
()
int
main
()
{
{
...
@@ -96,16 +124,48 @@ int main()
...
@@ -96,16 +124,48 @@ int main()
DeviceMem
dy_dev
(
sizeof
(
DYDataType
)
*
dy
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
dy_dev
(
sizeof
(
DYDataType
)
*
dy
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
x_dev
(
sizeof
(
XDataType
)
*
x
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
x_dev
(
sizeof
(
XDataType
)
*
x
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
gamma_dev
(
sizeof
(
GammaDataType
)
*
gamma
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
mean_dev
(
sizeof
(
MeanInvStdDataType
)
*
mean
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
mean_dev
(
sizeof
(
MeanInvStdDataType
)
*
mean
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
inv_std_dev
(
sizeof
(
MeanInvStdDataType
)
*
inv_std
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
inv_std_dev
(
sizeof
(
MeanInvStdDataType
)
*
inv_std
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
dx_dev
(
sizeof
(
DXDataType
)
*
dx
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
dgamma_dev
(
sizeof
(
DGammaDataType
)
*
dgamma
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
dgamma_dev
(
sizeof
(
DGammaDataType
)
*
dgamma
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
dbeta_dev
(
sizeof
(
DBetaDataType
)
*
dbeta
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
dbeta_dev
(
sizeof
(
DBetaDataType
)
*
dbeta
.
mDesc
.
GetElementSpaceSize
());
dy_dev
.
ToDevice
(
dy
.
mData
.
data
());
dy_dev
.
ToDevice
(
dy
.
mData
.
data
());
x_dev
.
ToDevice
(
x
.
mData
.
data
());
x_dev
.
ToDevice
(
x
.
mData
.
data
());
gamma_dev
.
ToDevice
(
gamma
.
mData
.
data
());
mean_dev
.
ToDevice
(
mean
.
mData
.
data
());
mean_dev
.
ToDevice
(
mean
.
mData
.
data
());
inv_std_dev
.
ToDevice
(
inv_std
.
mData
.
data
());
inv_std_dev
.
ToDevice
(
inv_std
.
mData
.
data
());
// backward x
auto
x_device_instance
=
XDeviceInstance
{};
auto
x_argument_ptr
=
x_device_instance
.
MakeArgumentPointer
({
M
,
N
},
// lengths
{
N
,
1
},
// dyStrides
{
N
,
1
},
// xStrides
{
0
,
1
},
// gammaStrides
{
1
,
0
},
// meanStrides
{
1
,
0
},
// invStdStrides
{
N
,
1
},
// dxStrides
{
1
},
// reduceDims
dy_dev
.
GetDeviceBuffer
(),
x_dev
.
GetDeviceBuffer
(),
gamma_dev
.
GetDeviceBuffer
(),
mean_dev
.
GetDeviceBuffer
(),
inv_std_dev
.
GetDeviceBuffer
(),
dx_dev
.
GetDeviceBuffer
());
if
(
!
x_device_instance
.
IsSupportedArgument
(
x_argument_ptr
.
get
()))
{
std
::
cout
<<
"The runtime parameters are not supported."
<<
__FILE__
<<
":"
<<
__LINE__
<<
std
::
endl
;
return
1
;
};
auto
x_invoker_ptr
=
x_device_instance
.
MakeInvokerPointer
();
x_invoker_ptr
->
Run
(
x_argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
// backward gamma & beta
auto
gamma_beta_device_instance
=
GammaBetaDeviceInstance
{};
auto
gamma_beta_device_instance
=
GammaBetaDeviceInstance
{};
auto
gamma_beta_argument_ptr
=
auto
gamma_beta_argument_ptr
=
gamma_beta_device_instance
.
MakeArgumentPointer
({
M
,
N
},
// inLengths
gamma_beta_device_instance
.
MakeArgumentPointer
({
M
,
N
},
// inLengths
...
@@ -126,7 +186,8 @@ int main()
...
@@ -126,7 +186,8 @@ int main()
if
(
!
gamma_beta_device_instance
.
IsSupportedArgument
(
gamma_beta_argument_ptr
.
get
()))
if
(
!
gamma_beta_device_instance
.
IsSupportedArgument
(
gamma_beta_argument_ptr
.
get
()))
{
{
std
::
cout
<<
"The runtime parameters are not supported"
<<
std
::
endl
;
std
::
cout
<<
"The runtime parameters are not supported."
<<
__FILE__
<<
":"
<<
__LINE__
<<
std
::
endl
;
return
1
;
return
1
;
};
};
...
@@ -156,9 +217,11 @@ int main()
...
@@ -156,9 +217,11 @@ int main()
dgamma_dev
.
FromDevice
(
dgamma
.
mData
.
data
());
dgamma_dev
.
FromDevice
(
dgamma
.
mData
.
data
());
dbeta_dev
.
FromDevice
(
dbeta
.
mData
.
data
());
dbeta_dev
.
FromDevice
(
dbeta
.
mData
.
data
());
dx_dev
.
FromDevice
(
dx
.
mData
.
data
());
pass
&=
ck
::
utils
::
check_err
(
dgamma
,
host_dgamma
,
"Error: Incorrect dgamma"
,
1e-3
,
1e-3
);
pass
&=
ck
::
utils
::
check_err
(
dgamma
,
host_dgamma
,
"Error: Incorrect dgamma"
,
1e-3
,
1e-3
);
pass
&=
ck
::
utils
::
check_err
(
dbeta
,
host_dbeta
,
"Error: Incorrect dbeta"
,
1e-3
,
1e-3
);
pass
&=
ck
::
utils
::
check_err
(
dbeta
,
host_dbeta
,
"Error: Incorrect dbeta"
,
1e-3
,
1e-3
);
pass
&=
ck
::
utils
::
check_err
(
dx
,
host_dx
,
"Error: Incorrect dx"
,
1e-3
,
1e-3
);
}
}
return
(
pass
?
0
:
1
);
return
(
pass
?
0
:
1
);
...
...
example/53_layernorm_bwd/CMakeLists.txt
deleted
100644 → 0
View file @
55a89c74
add_example_executable
(
example_layernorm2d_bwd_fp16 layernorm2d_bwd_fp16.cpp
)
example/54_groupnorm_bwd/CMakeLists.txt
View file @
d0f355a3
add_example_executable
(
example_groupnorm_bwd_fp
16
groupnorm_bwd_fp
16
.cpp
)
add_example_executable
(
example_groupnorm_bwd_fp
32
groupnorm_bwd_fp
32
.cpp
)
example/54_groupnorm_bwd/groupnorm_bwd_fp
16
.cpp
→
example/54_groupnorm_bwd/groupnorm_bwd_fp
32
.cpp
View file @
d0f355a3
...
@@ -15,23 +15,58 @@
...
@@ -15,23 +15,58 @@
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp"
using
DYDataType
=
ck
::
half_
t
;
using
DYDataType
=
floa
t
;
using
XDataType
=
ck
::
half_
t
;
using
XDataType
=
floa
t
;
using
GammaDataType
=
ck
::
half_
t
;
using
GammaDataType
=
floa
t
;
using
MeanInvStdDataType
=
float
;
using
MeanInvStdDataType
=
float
;
using
DGammaDataType
=
ck
::
half_
t
;
using
DGammaDataType
=
floa
t
;
using
DBetaDataType
=
ck
::
half_
t
;
using
DBetaDataType
=
floa
t
;
using
DXDataType
=
ck
::
half_
t
;
using
DXDataType
=
floa
t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
constexpr
int
Rank
=
5
;
constexpr
int
Rank
=
5
;
constexpr
int
NumReduceDim
=
3
;
constexpr
int
NumReduceDim
=
3
;
// Groupnorm
// Groupnorm
// kernel: M , K
// kernel 1: M , K
// dy: N, H, W, G, C -> N * G, H * W * C
// x: N, H, W, G, C -> N * G, H * W * C
// gamma: 1, 1, 1, G, C -> 1 * G, 1 * 1 * C
// mean: N, 1, 1, G, 1 -> N * G, 1 * 1 * 1
// rstd: N, 1, 1, G, 1 -> N * G, 1 * 1 * 1
// dx: N, H, W, G, C -> N * G, H * W * C
using
XDeviceInstance
=
ck
::
tensor_operation
::
device
::
DeviceNormalizationBwdDataImpl
<
DYDataType
,
XDataType
,
GammaDataType
,
MeanInvStdDataType
,
ComputeDataType
,
DXDataType
,
Rank
,
NumReduceDim
,
256
,
// BlockSize
8
,
// MThreadClusterSize
32
,
// KThreadClusterSize
1
,
// MThreadSliceSize
4
,
// KThreadSliceSize
true
,
// IsDYFastestDimReduced
4
,
// DYSrcVectorSize
true
,
// IsXFastestDimReduced
4
,
// XSrcVectorSize
true
,
// IsGammaFastestDimReduced
4
,
// GammaSrcVectorSize
false
,
// IsMeanInvStdFastestDimReduced
1
,
// MeanInvStdSrcVectorSize
true
,
// IsDXFastestDimReduced
4
>
;
// DXDstVectorSize
// kernel 2: M , K
// dy: N, H, W, G, C -> G * C, N * H * W
// dy: N, H, W, G, C -> G * C, N * H * W
// x: N, H, W, G, C -> G * C, N * H * W
// x: N, H, W, G, C -> G * C, N * H * W
// mean: N, 1, 1, G, 1 -> G * 1, N * 1 * 1
// mean: N, 1, 1, G, 1 -> G * 1, N * 1 * 1
...
@@ -52,18 +87,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
...
@@ -52,18 +87,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
Rank
,
Rank
,
NumReduceDim
,
NumReduceDim
,
256
,
// BlockSize
256
,
// BlockSize
8
,
// ClusterInvari
e
nt
8
,
// ClusterInvari
a
nt
32
,
// ClusterReduce
32
,
// ClusterReduce
8
,
// SliceInvari
e
nt
4
,
// SliceInvari
a
nt
1
,
// SliceReduce
1
,
// SliceReduce
false
,
// IsDYFastestDimReduced
false
,
// IsDYFastestDimReduced
8
,
// DYSrcVectorSize
4
,
// DYSrcVectorSize
false
,
// IsXFastestDimReduced
false
,
// IsXFastestDimReduced
8
,
// XSrcVectorSize
4
,
// XSrcVectorSize
false
,
// IsMeanInvStdFastestDimReduced
false
,
// IsMeanInvStdFastestDimReduced
1
,
// MeanInvStdSrcVectorSize
1
,
// MeanInvStdSrcVectorSize
1
,
// DGammaDstVectorSize
4
,
// DGammaDstVectorSize
1
>
;
// DBetaDstVectorSize
4
>
;
// DBetaDstVectorSize
int
main
()
int
main
()
{
{
...
@@ -93,20 +128,55 @@ int main()
...
@@ -93,20 +128,55 @@ int main()
DeviceMem
dy_dev
(
sizeof
(
DYDataType
)
*
dy
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
dy_dev
(
sizeof
(
DYDataType
)
*
dy
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
x_dev
(
sizeof
(
XDataType
)
*
x
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
x_dev
(
sizeof
(
XDataType
)
*
x
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
gamma_dev
(
sizeof
(
GammaDataType
)
*
gamma
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
mean_dev
(
sizeof
(
MeanInvStdDataType
)
*
mean
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
mean_dev
(
sizeof
(
MeanInvStdDataType
)
*
mean
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
inv_std_dev
(
sizeof
(
MeanInvStdDataType
)
*
inv_std
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
inv_std_dev
(
sizeof
(
MeanInvStdDataType
)
*
inv_std
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
dx_dev
(
sizeof
(
DXDataType
)
*
dx
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
dgamma_dev
(
sizeof
(
DGammaDataType
)
*
dgamma
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
dgamma_dev
(
sizeof
(
DGammaDataType
)
*
dgamma
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
dbeta_dev
(
sizeof
(
DBetaDataType
)
*
dbeta
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
dbeta_dev
(
sizeof
(
DBetaDataType
)
*
dbeta
.
mDesc
.
GetElementSpaceSize
());
dy_dev
.
ToDevice
(
dy
.
mData
.
data
());
dy_dev
.
ToDevice
(
dy
.
mData
.
data
());
x_dev
.
ToDevice
(
x
.
mData
.
data
());
x_dev
.
ToDevice
(
x
.
mData
.
data
());
gamma_dev
.
ToDevice
(
gamma
.
mData
.
data
());
mean_dev
.
ToDevice
(
mean
.
mData
.
data
());
mean_dev
.
ToDevice
(
mean
.
mData
.
data
());
inv_std_dev
.
ToDevice
(
inv_std
.
mData
.
data
());
inv_std_dev
.
ToDevice
(
inv_std
.
mData
.
data
());
std
::
vector
<
ck
::
index_t
>
dyStrides
{
dy
.
mDesc
.
GetStrides
().
begin
(),
dy
.
mDesc
.
GetStrides
().
end
()};
std
::
vector
<
ck
::
index_t
>
dyStrides
{
dy
.
mDesc
.
GetStrides
().
begin
(),
dy
.
mDesc
.
GetStrides
().
end
()};
std
::
vector
<
ck
::
index_t
>
xStrides
{
x
.
mDesc
.
GetStrides
().
begin
(),
x
.
mDesc
.
GetStrides
().
end
()};
std
::
vector
<
ck
::
index_t
>
xStrides
{
x
.
mDesc
.
GetStrides
().
begin
(),
x
.
mDesc
.
GetStrides
().
end
()};
std
::
vector
<
ck
::
index_t
>
gammaStrides
=
{
0
,
0
,
0
,
C
,
1
};
std
::
vector
<
ck
::
index_t
>
meanStrides
=
{
G
,
0
,
0
,
1
,
0
};
std
::
vector
<
ck
::
index_t
>
meanStrides
=
{
G
,
0
,
0
,
1
,
0
};
std
::
vector
<
ck
::
index_t
>
invStdStrides
=
{
G
,
0
,
0
,
1
,
0
};
std
::
vector
<
ck
::
index_t
>
invStdStrides
=
{
G
,
0
,
0
,
1
,
0
};
std
::
vector
<
ck
::
index_t
>
dxStrides
{
dx
.
mDesc
.
GetStrides
().
begin
(),
dx
.
mDesc
.
GetStrides
().
end
()};
// backward x
auto
x_device_instance
=
XDeviceInstance
{};
auto
x_argument_ptr
=
x_device_instance
.
MakeArgumentPointer
({
N
,
H
,
W
,
G
,
C
},
// lengths
dyStrides
,
// dyStrides
xStrides
,
// xStrides
gammaStrides
,
// gammaStrides
meanStrides
,
// meanStrides
invStdStrides
,
// invStdStrides
dxStrides
,
// dxStrides
{
1
,
2
,
4
},
// reduceDims
dy_dev
.
GetDeviceBuffer
(),
x_dev
.
GetDeviceBuffer
(),
gamma_dev
.
GetDeviceBuffer
(),
mean_dev
.
GetDeviceBuffer
(),
inv_std_dev
.
GetDeviceBuffer
(),
dx_dev
.
GetDeviceBuffer
());
if
(
!
x_device_instance
.
IsSupportedArgument
(
x_argument_ptr
.
get
()))
{
std
::
cout
<<
"The runtime parameters are not supported."
<<
__FILE__
<<
":"
<<
__LINE__
<<
std
::
endl
;
return
1
;
};
auto
x_invoker_ptr
=
x_device_instance
.
MakeInvokerPointer
();
x_invoker_ptr
->
Run
(
x_argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
// backward gamma & beta
auto
gamma_beta_device_instance
=
GammaBetaDeviceInstance
{};
auto
gamma_beta_device_instance
=
GammaBetaDeviceInstance
{};
auto
gamma_beta_argument_ptr
=
auto
gamma_beta_argument_ptr
=
...
@@ -128,7 +198,8 @@ int main()
...
@@ -128,7 +198,8 @@ int main()
if
(
!
gamma_beta_device_instance
.
IsSupportedArgument
(
gamma_beta_argument_ptr
.
get
()))
if
(
!
gamma_beta_device_instance
.
IsSupportedArgument
(
gamma_beta_argument_ptr
.
get
()))
{
{
std
::
cout
<<
"The runtime parameters are not supported"
<<
std
::
endl
;
std
::
cout
<<
"The runtime parameters are not supported."
<<
__FILE__
<<
":"
<<
__LINE__
<<
std
::
endl
;
return
1
;
return
1
;
};
};
...
@@ -158,9 +229,11 @@ int main()
...
@@ -158,9 +229,11 @@ int main()
dgamma_dev
.
FromDevice
(
dgamma
.
mData
.
data
());
dgamma_dev
.
FromDevice
(
dgamma
.
mData
.
data
());
dbeta_dev
.
FromDevice
(
dbeta
.
mData
.
data
());
dbeta_dev
.
FromDevice
(
dbeta
.
mData
.
data
());
dx_dev
.
FromDevice
(
dx
.
mData
.
data
());
pass
&=
ck
::
utils
::
check_err
(
dgamma
,
host_dgamma
,
"Error: Incorrect dgamma"
,
1e-3
,
1e-3
);
pass
&=
ck
::
utils
::
check_err
(
dgamma
,
host_dgamma
,
"Error: Incorrect dgamma"
,
1e-3
,
1e-3
);
pass
&=
ck
::
utils
::
check_err
(
dbeta
,
host_dbeta
,
"Error: Incorrect dbeta"
,
1e-3
,
1e-3
);
pass
&=
ck
::
utils
::
check_err
(
dbeta
,
host_dbeta
,
"Error: Incorrect dbeta"
,
1e-3
,
1e-3
);
pass
&=
ck
::
utils
::
check_err
(
dx
,
host_dx
,
"Error: Incorrect dx"
,
1e-3
,
1e-3
);
}
}
return
(
pass
?
0
:
1
);
return
(
pass
?
0
:
1
);
...
...
include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp
0 → 100644
View file @
d0f355a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Abstract device-operation interface for the backward-data pass of a
// normalization op. It derives from BaseOperator and declares two pure
// virtual factory functions; concrete implementations must override both.
//
// Template parameters:
//   DYDataType, XDataType, GammaDataType, MeanInvStdDataType, DXDataType
//       element types of the corresponding tensors
//   Rank          number of tensor dimensions (presumably the size expected
//                 for `lengths` and each stride vector -- confirm in the impl)
//   NumReduceDim  number of reduced dimensions (presumably the size expected
//                 for `reduceDims` -- confirm in the impl)
template
<
typename
DYDataType
,
typename
XDataType
,
typename
GammaDataType
,
typename
MeanInvStdDataType
,
typename
DXDataType
,
index_t
Rank
,
index_t
NumReduceDim
>
struct
DeviceNormalizationBwdData
:
public
BaseOperator
{
// Builds the type-erased argument object for one invocation.
// Takes the tensor lengths, one stride vector per tensor, and the list of
// dimensions to reduce over; every vector is passed by value.
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
lengths
,
const
std
::
vector
<
index_t
>
dyStrides
,
const
std
::
vector
<
index_t
>
xStrides
,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
meanStrides
,
const
std
::
vector
<
index_t
>
invStdStrides
,
const
std
::
vector
<
index_t
>
dxStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
// Untyped buffer pointers. The five const pointers are read-only inputs;
// p_dx is the only mutable pointer (presumably the result buffer -- confirm).
const
void
*
p_dy
,
const
void
*
p_x
,
const
void
*
p_gamma
,
const
void
*
p_mean
,
const
void
*
p_invStd
,
void
*
p_dx
)
=
0
;
// Creates the type-erased invoker object (presumably the one that runs an
// argument built by MakeArgumentPointer -- confirm against the impl).
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
// Convenience alias: an owning std::unique_ptr to a DeviceNormalizationBwdData
// interface instantiated with the same seven template arguments.
template
<
typename
DYDataType
,
typename
XDataType
,
typename
GammaDataType
,
typename
MeanInvStdDataType
,
typename
DXDataType
,
index_t
Rank
,
index_t
NumReduceDim
>
using
DeviceNormalizationBwdDataPtr
=
std
::
unique_ptr
<
DeviceNormalizationBwdData
<
DYDataType
,
XDataType
,
GammaDataType
,
MeanInvStdDataType
,
DXDataType
,
Rank
,
NumReduceDim
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp
0 → 100644
View file @
d0f355a3
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp
View file @
d0f355a3
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
// M is
i
nvari
e
nt dimension, K is reduced dimension
// M is
I
nvari
a
nt dimension, K is reduced dimension
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -87,7 +87,6 @@ struct DeviceNormalizationBwdGammaBetaImpl
...
@@ -87,7 +87,6 @@ struct DeviceNormalizationBwdGammaBetaImpl
Rank
,
Rank
,
NumReduceDim
>
NumReduceDim
>
{
{
static
constexpr
index_t
DYSrcVectorDim
=
IsDYFastestDimReduced
?
1
:
0
;
static
constexpr
index_t
DYSrcVectorDim
=
IsDYFastestDimReduced
?
1
:
0
;
static
constexpr
index_t
XSrcVectorDim
=
IsXFastestDimReduced
?
1
:
0
;
static
constexpr
index_t
XSrcVectorDim
=
IsXFastestDimReduced
?
1
:
0
;
static
constexpr
index_t
MeanInvStdSrcVectorDim
=
IsMeanInvStdFastestDimReduced
?
1
:
0
;
static
constexpr
index_t
MeanInvStdSrcVectorDim
=
IsMeanInvStdFastestDimReduced
?
1
:
0
;
...
@@ -102,18 +101,18 @@ struct DeviceNormalizationBwdGammaBetaImpl
...
@@ -102,18 +101,18 @@ struct DeviceNormalizationBwdGammaBetaImpl
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
)),
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"
);
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"
);
static_assert
(
((
MThreadSliceSize
%
DGammaDstVectorSize
==
0
)
||
(
MThreadSliceSize
%
DBetaDstVectorSize
==
0
)),
"Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
"check!"
);
static_assert
(
static_assert
(
(
MeanInvStdSrcVectorDim
==
0
&&
MThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
)
||
(
MeanInvStdSrcVectorDim
==
0
&&
MThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
)
||
(
MeanInvStdSrcVectorDim
==
1
&&
KThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
),
(
MeanInvStdSrcVectorDim
==
1
&&
KThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
"Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
"check!"
);
"check!"
);
static_assert
(
((
MThreadSliceSize
%
DGammaDstVectorSize
==
0
)
||
(
MThreadSliceSize
%
DBetaDstVectorSize
==
0
)),
"Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
"check!"
);
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
...
@@ -298,7 +297,7 @@ struct DeviceNormalizationBwdGammaBetaImpl
...
@@ -298,7 +297,7 @@ struct DeviceNormalizationBwdGammaBetaImpl
GridDesc_M
dgamma_grid_desc_m_
;
GridDesc_M
dgamma_grid_desc_m_
;
GridDesc_M
dbeta_grid_desc_m_
;
GridDesc_M
dbeta_grid_desc_m_
;
index_t
MRaw_
;
//
i
nvari
e
nt length
index_t
MRaw_
;
//
I
nvari
a
nt length
index_t
KRaw_
;
// reduce length
index_t
KRaw_
;
// reduce length
};
};
...
@@ -457,6 +456,21 @@ struct DeviceNormalizationBwdGammaBetaImpl
...
@@ -457,6 +456,21 @@ struct DeviceNormalizationBwdGammaBetaImpl
{
{
return
std
::
make_unique
<
Invoker
>
();
return
std
::
make_unique
<
Invoker
>
();
}
}
// Returns a textual identifier for this instantiation, assembled from its
// compile-time parameters: BlockSize, the M/K thread-cluster sizes, the M/K
// thread-slice sizes, and the DY/X/DGamma/DBeta vector sizes.
std::string GetTypeString() const override
{
    std::stringstream description;

    // clang-format off
    description << "DeviceNormalizationBwdGammaBetaImpl<" << BlockSize << ","
                << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","
                << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","
                << "VectorSize_DY" << DYSrcVectorSize << "_X" << XSrcVectorSize
                << "_DGamma" << DGammaDstVectorSize << "_DBeta" << DBetaDstVectorSize << ">";
    // clang-format on

    return description.str();
}
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp
View file @
d0f355a3
...
@@ -19,7 +19,7 @@ namespace tensor_operation {
...
@@ -19,7 +19,7 @@ namespace tensor_operation {
namespace
device
{
namespace
device
{
// Y = Normalization(X, Beta, Gamma)
// Y = Normalization(X, Beta, Gamma)
// M: Invari
e
nt length
// M: Invari
a
nt length
// K: Reduce length (Calculate mean and variance along K dimension)
// K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W
// Then, M = N, K = C * H * W
...
@@ -263,7 +263,7 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
...
@@ -263,7 +263,7 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
GridDesc_M
save_inv_std_grid_desc_m_
;
GridDesc_M
save_inv_std_grid_desc_m_
;
bool
isSweeponce_
;
bool
isSweeponce_
;
index_t
MRaw_
;
//
i
nvari
e
nt length
index_t
MRaw_
;
//
I
nvari
a
nt length
index_t
KRaw_
;
// reduce length
index_t
KRaw_
;
// reduce length
index_t
invariant_lowest_length_
;
index_t
invariant_lowest_length_
;
...
@@ -342,8 +342,6 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
...
@@ -342,8 +342,6 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
}
}
else
else
{
{
printf
(
"!!!! %d
\n
"
,
p_arg_
->
invariant_lowest_length_
);
if
(
p_arg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
)
if
(
p_arg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
return
false
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp
View file @
d0f355a3
...
@@ -108,7 +108,7 @@ namespace tensor_operation {
...
@@ -108,7 +108,7 @@ namespace tensor_operation {
namespace
device
{
namespace
device
{
// Y = Normalization(X, Beta, Gamma)
// Y = Normalization(X, Beta, Gamma)
// M: Invari
e
nt length
// M: Invari
a
nt length
// K: Reduce length (Calculate mean and variance along K dimension)
// K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W
// Then, M = N, K = C * H * W
...
@@ -468,7 +468,7 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp
...
@@ -468,7 +468,7 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp
Kernel2MeanVarGridDesc_M_KBlock
kernel2_mean_var_grid_desc_m_kblock_
;
Kernel2MeanVarGridDesc_M_KBlock
kernel2_mean_var_grid_desc_m_kblock_
;
Kernel2CountGridDesc_M_KBlock
kernel2_count_grid_desc_m_kblock_
;
Kernel2CountGridDesc_M_KBlock
kernel2_count_grid_desc_m_kblock_
;
index_t
MRaw_
;
//
i
nvari
e
nt length
index_t
MRaw_
;
//
I
nvari
a
nt length
index_t
KRaw_
;
// reduce length
index_t
KRaw_
;
// reduce length
index_t
invariant_lowest_length_
;
index_t
invariant_lowest_length_
;
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp
0 → 100644
View file @
d0f355a3
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp
View file @
d0f355a3
...
@@ -35,7 +35,7 @@ template <typename DYDataType,
...
@@ -35,7 +35,7 @@ template <typename DYDataType,
index_t
DBetaDstVectorSize
>
index_t
DBetaDstVectorSize
>
struct
GridwiseNormalizationBwdGammaBeta_mk_to_k
struct
GridwiseNormalizationBwdGammaBeta_mk_to_k
{
{
// if we just check ThreadSliceSize
&
VectorSize == 0, the performance may be poor
// if we just check ThreadSliceSize
%
VectorSize == 0, the performance may be poor
(coalesce)
static_assert
(((
DYSrcVectorDim
==
0
&&
MThreadSliceSize
==
DYSrcVectorSize
)
||
static_assert
(((
DYSrcVectorDim
==
0
&&
MThreadSliceSize
==
DYSrcVectorSize
)
||
(
DYSrcVectorDim
==
1
&&
KThreadSliceSize
==
DYSrcVectorSize
)),
(
DYSrcVectorDim
==
1
&&
KThreadSliceSize
==
DYSrcVectorSize
)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!"
);
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!"
);
...
@@ -44,6 +44,15 @@ struct GridwiseNormalizationBwdGammaBeta_mk_to_k
...
@@ -44,6 +44,15 @@ struct GridwiseNormalizationBwdGammaBeta_mk_to_k
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
==
XSrcVectorSize
)),
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
==
XSrcVectorSize
)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"
);
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"
);
// Unlike the dy/x checks, the slice size is not required to equal
// MeanInvStdSrcVectorSize (groupnorm needs this relaxation); it only has to be
// a multiple of it along the vectorized dimension (dim 0 -> M slice, dim 1 -> K slice).
static_assert(
    (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
        (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0),
    "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!");
// Both outputs must be written with a vector width equal to the M thread-slice
// size. The diagnostic names dgamma/dbeta (the quantities actually checked)
// rather than dx, which this kernel's condition does not involve.
static_assert(
    MThreadSliceSize == DGammaDstVectorSize && MThreadSliceSize == DBetaDstVectorSize,
    "Invalid thread slice sizes and/or dgamma and dbeta vector sizes configuration, please "
    "check!");
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
DYThreadBufferDimAccessOrder
=
using
DYThreadBufferDimAccessOrder
=
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp
View file @
d0f355a3
...
@@ -16,6 +16,31 @@ namespace ck {
...
@@ -16,6 +16,31 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
host
{
namespace
host
{
// def normalization_backward_x(dy, x, gamma, x_mean, rstd, reduce_axis, reduce_size):
// ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
// db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
// b = (db * x_mean - ds) * rstd ** (3) / reduce_size
// c = -b * x_mean - db * rstd / reduce_size
// dx = rstd * dy * gamma + b * x + c
// return dx
// def normalization_backward_gamma_beta(dy, x, x_mean, rstd, reduce_axis):
// # Assume shape of gamma and beta are the same
// dgamma = np.sum(dy * (x - x_mean) * rstd, axis=reduce_axis, keepdims=True)
// dbeta = np.sum(dy, axis=reduce_axis, keepdims=True)
// return dgamma, dbeta
// def groupnorm_backward(dy, x, gamma, x_mean, rstd):
// # dy, x = [N, H, W, G, C], gamma = [1, 1, 1, G, C], x_mean, rstd = [N, 1, 1, G, 1]
// N, H, W, G, C = x.shape
// dx = normalization_backward_x(
// dy, x, gamma, x_mean, rstd, (1, 2, 4), H * W * C)
// dgamma, dbeta = normalization_backward_gamma_beta(
// dy, x, x_mean, rstd, (0, 1, 2))
// return dx, dgamma, dbeta
// Reference (Layernorm and groupnorm):
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/group_norm_kernel.cpp#L655
template
<
typename
DYDataType
,
template
<
typename
DYDataType
,
typename
XDataType
,
typename
XDataType
,
typename
GammaDataType
,
typename
GammaDataType
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp
View file @
d0f355a3
...
@@ -16,6 +16,30 @@ namespace ck {
...
@@ -16,6 +16,30 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
host
{
namespace
host
{
// def normalization_backward_x(dy, x, gamma, x_mean, rstd, reduce_axis, reduce_size):
// ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
// db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
// b = (db * x_mean - ds) * rstd ** (3) / reduce_size
// c = -b * x_mean - db * rstd / reduce_size
// dx = rstd * dy * gamma + b * x + c
// return dx
// def normalization_backward_gamma_beta(dy, x, x_mean, rstd, reduce_axis):
// # Assume the shapes of gamma and beta are the same
// dgamma = np.sum(dy * (x - x_mean) * rstd, axis=reduce_axis, keepdims=True)
// dbeta = np.sum(dy, axis=reduce_axis, keepdims=True)
// return dgamma, dbeta
// def layernorm_backward(dy, x, gamma, x_mean, rstd):
// # dy, x = [M, K], gamma = [1, K], x_mean, rstd = [M, 1]
// # dx = [M, K], dgamma, dbeta = [1, K]
// M, K = x.shape
// dx = normalization_backward_x(dy, x, gamma, x_mean, rstd, 1, K)
// dgamma, dbeta = normalization_backward_gamma_beta(dy, x, x_mean, rstd, 0)
// return dx, dgamma, dbeta
// Reference (Layernorm and groupnorm):
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/layer_norm_kernel.cpp#L196
template
<
typename
DYDataType
,
template
<
typename
DYDataType
,
typename
XDataType
,
typename
XDataType
,
typename
GammaDataType
,
typename
GammaDataType
,
...
...
library/include/ck/library/tensor_operation_instance/gpu/groupnorm_bwd_data.hpp
0 → 100644
View file @
d0f355a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
#ifdef CK_ENABLE_FP32
// FP32: registration hook for groupnorm backward-data instances whose five data
// types are all F32. Only declared here; presumably the definition appends
// instances to the passed-in list -- confirm against the instance library.
void add_device_groupnorm_bwd_data_f32_instances(
    std::vector<std::unique_ptr<DeviceNormalizationBwdData<F32, F32, F32, F32, F32, 5, 3>>>&
        instances);
#endif
// Instance factory for DeviceNormalizationBwdData with trailing template
// arguments 5 and 3 (NOTE(review): presumably rank and number of reduced
// dimensions for groupnorm -- confirm against DeviceNormalizationBwdData).
//
// GetInstances() returns the device-op instances available for the requested
// data-type combination; the list is empty when no matching set is enabled.
template <typename DYDataType,
          typename XDataType,
          typename GammaDataType,
          typename MeanInvStdDataType,
          typename DXDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceNormalizationBwdData<DYDataType,
                                                             XDataType,
                                                             GammaDataType,
                                                             MeanInvStdDataType,
                                                             DXDataType,
                                                             5,
                                                             3>>
{
    using DeviceOp = DeviceNormalizationBwdData<DYDataType,
                                                XDataType,
                                                GammaDataType,
                                                MeanInvStdDataType,
                                                DXDataType,
                                                5,
                                                3>;

    // Collects the instances matching the template data types.
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> instances;

#ifdef CK_ENABLE_FP32
        // Only the all-F32 combination has instances declared in this header.
        constexpr bool all_types_are_f32 =
            is_same_v<DYDataType, F32> && is_same_v<XDataType, F32> &&
            is_same_v<GammaDataType, F32> && is_same_v<MeanInvStdDataType, F32> &&
            is_same_v<DXDataType, F32>;

        if constexpr(all_types_are_f32)
        {
            add_device_groupnorm_bwd_data_f32_instances(instances);
        }
#endif

        return instances;
    }
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment