Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6fae79b6
Commit
6fae79b6
authored
Nov 20, 2023
by
rocking
Browse files
Check vector size
parent
ee6cd44d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
7 deletions
+15
-7
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp
.../device/impl/device_normalization_bwd_gamma_beta_impl.hpp
+6
-6
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp
...d/normalization/gridwise_normalization_bwd_gamma_beta.hpp
+9
-1
No files found.
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp
View file @
6fae79b6
...
@@ -101,18 +101,18 @@ struct DeviceNormalizationBwdGammaBetaImpl
...
@@ -101,18 +101,18 @@ struct DeviceNormalizationBwdGammaBetaImpl
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
)),
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"
);
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"
);
static_assert
(
((
MThreadSliceSize
%
DGammaDstVectorSize
==
0
)
||
(
MThreadSliceSize
%
DBetaDstVectorSize
==
0
)),
"Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
"check!"
);
static_assert
(
static_assert
(
(
MeanInvStdSrcVectorDim
==
0
&&
MThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
)
||
(
MeanInvStdSrcVectorDim
==
0
&&
MThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
)
||
(
MeanInvStdSrcVectorDim
==
1
&&
KThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
),
(
MeanInvStdSrcVectorDim
==
1
&&
KThreadSliceSize
%
MeanInvStdSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
"Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
"check!"
);
"check!"
);
static_assert
(
((
MThreadSliceSize
%
DGammaDstVectorSize
==
0
)
||
(
MThreadSliceSize
%
DBetaDstVectorSize
==
0
)),
"Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
"check!"
);
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp
View file @
6fae79b6
...
@@ -35,7 +35,7 @@ template <typename DYDataType,
...
@@ -35,7 +35,7 @@ template <typename DYDataType,
index_t
DBetaDstVectorSize
>
index_t
DBetaDstVectorSize
>
struct
GridwiseNormalizationBwdGammaBeta_mk_to_k
struct
GridwiseNormalizationBwdGammaBeta_mk_to_k
{
{
// if we just check ThreadSliceSize
&
VectorSize == 0, the performance may be poor
// if we just check ThreadSliceSize
%
VectorSize == 0, the performance may be poor
(coalesce)
static_assert
(((
DYSrcVectorDim
==
0
&&
MThreadSliceSize
==
DYSrcVectorSize
)
||
static_assert
(((
DYSrcVectorDim
==
0
&&
MThreadSliceSize
==
DYSrcVectorSize
)
||
(
DYSrcVectorDim
==
1
&&
KThreadSliceSize
==
DYSrcVectorSize
)),
(
DYSrcVectorDim
==
1
&&
KThreadSliceSize
==
DYSrcVectorSize
)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!"
);
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!"
);
...
@@ -44,6 +44,14 @@ struct GridwiseNormalizationBwdGammaBeta_mk_to_k
...
@@ -44,6 +44,14 @@ struct GridwiseNormalizationBwdGammaBeta_mk_to_k
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
==
XSrcVectorSize
)),
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
==
XSrcVectorSize
)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"
);
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"
);
static_assert
(
((
MeanInvStdSrcVectorDim
==
0
&&
MThreadSliceSize
==
MeanInvStdSrcVectorSize
)
||
(
MeanInvStdSrcVectorDim
==
1
&&
KThreadSliceSize
==
MeanInvStdSrcVectorSize
)),
"Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!"
);
static_assert
(
MThreadSliceSize
==
DGammaDstVectorSize
&&
MThreadSliceSize
==
DBetaDstVectorSize
,
"Invalid thread slice sizes and/or dx vector sizes configuration, please check!"
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
DYThreadBufferDimAccessOrder
=
using
DYThreadBufferDimAccessOrder
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment