Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
56532f77
Commit
56532f77
authored
Mar 01, 2023
by
rocking
Browse files
Add second kernel of normalization splitK
parent
28ebcfe7
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
595 additions
and
45 deletions
+595
-45
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
...tion/gpu/device/impl/device_normalization_splitk_impl.hpp
+175
-44
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp
.../grid/normalization/gridwise_normalization_splitk_1st.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
.../grid/normalization/gridwise_normalization_splitk_2nd.hpp
+419
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
View file @
56532f77
...
@@ -12,12 +12,13 @@
...
@@ -12,12 +12,13 @@
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
ck
{
template
<
typename
GridwiseWelford
1
,
template
<
typename
GridwiseWelford
,
typename
XDataType
,
typename
XDataType
,
typename
MeanVarDataType
,
typename
MeanVarDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
...
@@ -32,13 +33,64 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
...
@@ -32,13 +33,64 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
MeanVarDataType
*
const
__restrict__
p_welford_variance
,
MeanVarDataType
*
const
__restrict__
p_welford_variance
,
int32_t
*
const
__restrict__
p_welford_count
)
int32_t
*
const
__restrict__
p_welford_count
)
{
{
GridwiseWelford1
::
Run
(
x_grid_desc_m_k
,
GridwiseWelford
::
Run
(
x_grid_desc_m_k
,
mean_var_grid_desc_m_kblock
,
mean_var_grid_desc_m_kblock
,
num_k_block_tile_iteration
,
num_k_block_tile_iteration
,
p_x_global
,
p_x_global
,
p_welford_mean
,
p_welford_mean
,
p_welford_variance
,
p_welford_variance
,
p_welford_count
);
p_welford_count
);
};
/// Device entry point for the second kernel of the split-K normalization.
///
/// The kernel contains no logic of its own: every argument is passed, unchanged
/// and in declaration order, to GridwiseWelfordNormalization::Run.
///
/// @tparam GridwiseWelfordNormalization  Type providing the static Run(...) invoked here.
/// @tparam MeanVarDataType               Element type of p_mean_global / p_variance_global.
/// @tparam XDataType                     Element type of p_x_global.
/// @tparam GammaDataType                 Element type of p_gamma_global.
/// @tparam BetaDataType                  Element type of p_beta_global.
/// @tparam YDataType                     Element type of p_y_global.
/// @tparam ComputeDataType               Type of epsilon.
/// @tparam YElementwiseOperation         Type of y_elementwise_op.
/// @tparam MeanVarGridDesc_M_KBlock      Type of mean_var_grid_desc_m_kblock.
/// @tparam CountVarGridDesc_M_KBlock     Type of count_grid_desc_m_kblock.
/// @tparam XYGammaBetaGridDesc_M_K       Shared type of the x / gamma / beta / y descriptors.
///
/// NOTE(review): the meaning of the iteration counts, k_grid_size and the
/// buffers is defined by GridwiseWelfordNormalization::Run, which is not
/// visible here -- presumably the mean/variance/count pointers are the
/// partial results written by the first split-K kernel; confirm against
/// gridwise_normalization_splitk_2nd.hpp.
template <typename GridwiseWelfordNormalization,
          typename MeanVarDataType,
          typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename MeanVarGridDesc_M_KBlock,
          typename CountVarGridDesc_M_KBlock,
          typename XYGammaBetaGridDesc_M_K>
__global__ void
kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
                              const CountVarGridDesc_M_KBlock count_grid_desc_m_kblock,
                              const XYGammaBetaGridDesc_M_K x_grid_desc_m_k,
                              const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k,
                              const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k,
                              const XYGammaBetaGridDesc_M_K y_grid_desc_m_k,
                              index_t num_k_mean_var_count_iteration,
                              index_t num_k_block_tile_iteration,
                              index_t k_grid_size,
                              ComputeDataType epsilon,
                              const MeanVarDataType* const p_mean_global,
                              const MeanVarDataType* const p_variance_global,
                              const int32_t* const p_welford_count_global,
                              const XDataType* const __restrict__ p_x_global,
                              const GammaDataType* const __restrict__ p_gamma_global,
                              const BetaDataType* const __restrict__ p_beta_global,
                              YDataType* const __restrict__ p_y_global,
                              const YElementwiseOperation y_elementwise_op)
{
    // Pure pass-through: descriptors first, then scalars, then buffers, then the
    // element-wise functor -- the same order as the kernel's own parameter list.
    GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock,
                                      count_grid_desc_m_kblock,
                                      x_grid_desc_m_k,
                                      gamma_grid_desc_m_k,
                                      beta_grid_desc_m_k,
                                      y_grid_desc_m_k,
                                      num_k_mean_var_count_iteration,
                                      num_k_block_tile_iteration,
                                      k_grid_size,
                                      epsilon,
                                      p_mean_global,
                                      p_variance_global,
                                      p_welford_count_global,
                                      p_x_global,
                                      p_gamma_global,
                                      p_beta_global,
                                      p_y_global,
                                      y_elementwise_op);
};
};
}
// namespace ck
}
// namespace ck
...
@@ -64,7 +116,7 @@ template <typename XDataType,
...
@@ -64,7 +116,7 @@ template <typename XDataType,
index_t
KThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XY
Src
VectorDim
,
index_t
XYVectorDim
,
index_t
XSrcVectorSize
,
index_t
XSrcVectorSize
,
index_t
GammaSrcVectorDim
,
index_t
GammaSrcVectorDim
,
index_t
GammaSrcVectorSize
,
index_t
GammaSrcVectorSize
,
...
@@ -184,21 +236,52 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -184,21 +236,52 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
}
}
using
SrcGridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
SrcGridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
Welford
1MeanVarGridDesc_M_KBlock
=
using
Kernel
1MeanVarGridDesc_M_KBlock
=
decltype
(
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
1
,
1
>
(
1
,
1
));
decltype
(
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
1
,
1
>
(
1
,
1
));
using
GridwiseWelford1
=
GridwiseNormalizationSplitK1st
<
XDataType
,
using
Kernel2MeanVarGridDesc_M_KBlock
=
ComputeDataType
,
decltype
(
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
));
MeanVarDataType
,
SrcGridDesc_M_K
,
using
Kernel2CountGridDesc_M_KBlock
=
Welford1MeanVarGridDesc_M_KBlock
,
decltype
(
MakeCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
));
BlockSize
,
MThreadClusterSize
,
using
GridwiseWelford
=
GridwiseNormalizationSplitK1st
<
XDataType
,
KThreadClusterSize
,
ComputeDataType
,
MThreadSliceSize
,
MeanVarDataType
,
KThreadSliceSize
,
SrcGridDesc_M_K
,
XYSrcVectorDim
,
Kernel1MeanVarGridDesc_M_KBlock
,
XSrcVectorSize
>
;
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XYVectorDim
,
XSrcVectorSize
>
;
using
GridwiseWelfordNormalization
=
GridwiseNormalizationSplitK2nd
<
MeanVarDataType
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ComputeDataType
,
YElementwiseOperation
,
Kernel2MeanVarGridDesc_M_KBlock
,
Kernel2CountGridDesc_M_KBlock
,
SrcGridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XYVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
XYVectorDim
,
YDstVectorSize
>
;
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -236,19 +319,19 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -236,19 +319,19 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
numBlockTileIteration_
=
1
;
numBlockTileIteration_
=
1
;
while
(
true
)
while
(
true
)
{
{
int
testKGridSize
_
=
int
testKGridSize
=
math
::
integer_divide_ceil
(
KRaw_
,
K_BlockTileSize
*
numBlockTileIteration_
);
math
::
integer_divide_ceil
(
KRaw_
,
K_BlockTileSize
*
numBlockTileIteration_
);
// we want the
testK
GridSize_ be not more than 128
// we want kGridSize_ to be no more than 128
if
(
testKGridSize
_
<=
128
)
if
(
testKGridSize
<=
128
)
break
;
break
;
++
numBlockTileIteration_
;
++
numBlockTileIteration_
;
};
};
kGridSize_
=
math
::
integer_divide_ceil
(
KRaw_
,
K_BlockTileSize
*
numBlockTileIteration_
);
kGridSize_
=
math
::
integer_divide_ceil
(
KRaw_
,
K_BlockTileSize
*
numBlockTileIteration_
);
gridSize_
=
math
::
integer_divide_ceil
(
MRaw_
,
M_BlockTileSize
)
*
kGridSize_
;
gridSize
_
=
math
::
integer_divide_ceil
(
MRaw_
,
M_BlockTileSize
)
*
kGrid
Size
_
;
numMeanVarCountIteration
_
=
math
::
integer_divide_ceil
(
kGridSize_
,
KThreadCluster
Size
)
;
x_grid_desc_m_k_
=
x_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
xStrides_
,
kGridSize_
,
numBlockTileIteration_
);
MakeSrc2dDescriptor
(
Lengths_
,
xStrides_
,
kGridSize_
,
numBlockTileIteration_
);
...
@@ -260,9 +343,17 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -260,9 +343,17 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
MakeSrc2dDescriptor
(
Lengths_
,
yStrides_
,
kGridSize_
,
numBlockTileIteration_
);
MakeSrc2dDescriptor
(
Lengths_
,
yStrides_
,
kGridSize_
,
numBlockTileIteration_
);
// We don't need to pad in K dimension for Welford1. Set KPerTile 1.
// We don't need to pad in K dimension for Welford1. Set KPerTile 1.
mean_var_grid_desc_m_kblock_
=
kernel1_
mean_var_grid_desc_m_kblock_
=
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
M_BlockTileSize
,
1
>
(
MRaw_
,
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
M_BlockTileSize
,
1
>
(
MRaw_
,
kGridSize_
);
kGridSize_
);
kernel2_mean_var_grid_desc_m_kblock_
=
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
M_BlockTileSize
,
K_BlockTileSize
>
(
MRaw_
,
kGridSize_
);
kernel2_count_grid_desc_m_kblock_
=
MakeCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
M_BlockTileSize
,
K_BlockTileSize
>
(
MRaw_
,
kGridSize_
);
}
}
ComputeDataType
epsilon_
;
ComputeDataType
epsilon_
;
...
@@ -284,6 +375,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -284,6 +375,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
YElementwiseOperation
y_elementwise_op_
;
YElementwiseOperation
y_elementwise_op_
;
int
kGridSize_
;
int
kGridSize_
;
int
numMeanVarCountIteration_
;
int
numBlockTileIteration_
;
int
numBlockTileIteration_
;
size_t
gridSize_
;
size_t
gridSize_
;
...
@@ -292,7 +384,9 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -292,7 +384,9 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
SrcGridDesc_M_K
beta_grid_desc_m_k_
;
SrcGridDesc_M_K
beta_grid_desc_m_k_
;
SrcGridDesc_M_K
y_grid_desc_m_k_
;
SrcGridDesc_M_K
y_grid_desc_m_k_
;
Welford1MeanVarGridDesc_M_KBlock
mean_var_grid_desc_m_kblock_
;
Kernel1MeanVarGridDesc_M_KBlock
kernel1_mean_var_grid_desc_m_kblock_
;
Kernel2MeanVarGridDesc_M_KBlock
kernel2_mean_var_grid_desc_m_kblock_
;
Kernel2CountGridDesc_M_KBlock
kernel2_count_grid_desc_m_kblock_
;
index_t
MRaw_
;
// invariant length
index_t
MRaw_
;
// invariant length
index_t
KRaw_
;
// reduce length
index_t
KRaw_
;
// reduce length
...
@@ -306,12 +400,24 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -306,12 +400,24 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
arg
.
p_workspace_count_
==
nullptr
)
arg
.
p_workspace_count_
==
nullptr
)
throw
std
::
runtime_error
(
"wrong! WorkSpace pointer has not been set"
);
throw
std
::
runtime_error
(
"wrong! WorkSpace pointer has not been set"
);
auto
kernel1
=
kernel_normalizationSplitK1st
<
GridwiseWelford
1
,
auto
kernel1
=
kernel_normalizationSplitK1st
<
GridwiseWelford
,
XDataType
,
XDataType
,
MeanVarDataType
,
MeanVarDataType
,
ComputeDataType
,
ComputeDataType
,
SrcGridDesc_M_K
,
SrcGridDesc_M_K
,
Welford1MeanVarGridDesc_M_KBlock
>
;
Kernel1MeanVarGridDesc_M_KBlock
>
;
auto
kernel2
=
kernel_normalizationSplitK2nd
<
GridwiseWelfordNormalization
,
MeanVarDataType
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ComputeDataType
,
YElementwiseOperation
,
Kernel2MeanVarGridDesc_M_KBlock
,
Kernel2CountGridDesc_M_KBlock
,
SrcGridDesc_M_K
>
;
float
avg_time
=
0
;
float
avg_time
=
0
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
avg_time
+=
launch_and_time_kernel
(
stream_config
,
...
@@ -320,16 +426,38 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -320,16 +426,38 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
x_grid_desc_m_k_
,
arg
.
mean_var_grid_desc_m_kblock_
,
arg
.
kernel1_
mean_var_grid_desc_m_kblock_
,
arg
.
numBlockTileIteration_
,
arg
.
numBlockTileIteration_
,
arg
.
p_x_
,
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
));
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
));
// TODO - welford2 + elementwise
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel2
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
kernel2_mean_var_grid_desc_m_kblock_
,
arg
.
kernel2_count_grid_desc_m_kblock_
,
arg
.
x_grid_desc_m_k_
,
arg
.
gamma_grid_desc_m_k_
,
arg
.
beta_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
arg
.
numMeanVarCountIteration_
,
arg
.
numBlockTileIteration_
,
arg
.
kGridSize_
,
arg
.
epsilon_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
),
arg
.
p_x_
,
arg
.
p_gamma_
,
arg
.
p_beta_
,
arg
.
p_y_
,
arg
.
y_elementwise_op_
);
return
(
avg_time
)
;
return
avg_time
;
};
};
float
Run
(
const
BaseArgument
*
p_arg
,
float
Run
(
const
BaseArgument
*
p_arg
,
...
@@ -390,7 +518,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -390,7 +518,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
if
constexpr
(
XY
Src
VectorDim
==
0
)
if
constexpr
(
XYVectorDim
==
0
)
{
{
if
constexpr
(
NumInvariantDim
==
0
)
if
constexpr
(
NumInvariantDim
==
0
)
{
{
...
@@ -423,38 +551,41 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -423,38 +551,41 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if
constexpr
(
GammaSrcVectorDim
==
0
)
if
constexpr
(
GammaSrcVectorDim
==
0
)
{
{
if
(
p_arg_
->
gammaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
if
(
p_arg_
->
gammaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
(
false
)
;
return
false
;
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
GammaSrcVectorSize
!=
0
)
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
GammaSrcVectorSize
!=
0
)
return
(
false
)
;
return
false
;
}
}
else
// if fastest dim is reduced
else
// if fastest dim is reduced
{
{
if
(
p_arg_
->
gammaStrides_
[
Rank
-
1
]
!=
1
)
if
(
p_arg_
->
gammaStrides_
[
Rank
-
1
]
!=
1
)
return
(
false
)
;
return
false
;
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
GammaSrcVectorSize
!=
0
)
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
GammaSrcVectorSize
!=
0
)
return
(
false
)
;
return
false
;
}
}
// if fastest dim is not reduced
// if fastest dim is not reduced
if
constexpr
(
BetaSrcVectorDim
==
0
)
if
constexpr
(
BetaSrcVectorDim
==
0
)
{
{
if
(
p_arg_
->
betaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
if
(
p_arg_
->
betaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
(
false
)
;
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
BetaSrcVectorSize
!=
0
)
if
(
p_arg_
->
invariant_lowest_length
%
BetaSrcVectorSize
!=
0
)
return
(
false
)
;
return
false
;
}
}
else
// if fastest dim is reduced
else
// if fastest dim is reduced
{
{
if
(
p_arg_
->
betaStrides_
[
Rank
-
1
]
!=
1
)
if
(
p_arg_
->
betaStrides_
[
Rank
-
1
]
!=
1
)
return
(
false
)
;
return
false
;
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
BetaSrcVectorSize
!=
0
)
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
BetaSrcVectorSize
!=
0
)
return
(
false
)
;
return
false
;
}
}
if
(
p_arg_
->
kGridSize_
<=
1
)
return
false
;
return
true
;
return
true
;
};
};
...
@@ -507,7 +638,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
...
@@ -507,7 +638,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
str
<<
"DeviceNormalizationImpl<"
<<
BlockSize
<<
","
;
str
<<
"DeviceNormalizationImpl<"
<<
BlockSize
<<
","
;
str
<<
"Cluster_MK_"
<<
MThreadClusterSize
<<
"_"
<<
KThreadClusterSize
<<
","
;
str
<<
"Cluster_MK_"
<<
MThreadClusterSize
<<
"_"
<<
KThreadClusterSize
<<
","
;
str
<<
"Slice_MK_"
<<
MThreadSliceSize
<<
"_"
<<
KThreadSliceSize
<<
","
;
str
<<
"Slice_MK_"
<<
MThreadSliceSize
<<
"_"
<<
KThreadSliceSize
<<
","
;
str
<<
"XYSrcVectorDim_"
<<
XY
Src
VectorDim
<<
","
;
str
<<
"XYSrcVectorDim_"
<<
XYVectorDim
<<
","
;
str
<<
"VectorSize_X"
<<
XSrcVectorSize
<<
"_Gamma"
<<
GammaSrcVectorSize
<<
"_Beta"
<<
BetaSrcVectorSize
<<
"_Y"
<<
YDstVectorSize
<<
">"
;
str
<<
"VectorSize_X"
<<
XSrcVectorSize
<<
"_Gamma"
<<
GammaSrcVectorSize
<<
"_Beta"
<<
BetaSrcVectorSize
<<
"_Y"
<<
YDstVectorSize
<<
">"
;
// clang-format on
// clang-format on
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp
View file @
56532f77
...
@@ -203,7 +203,7 @@ struct GridwiseNormalizationSplitK1st
...
@@ -203,7 +203,7 @@ struct GridwiseNormalizationSplitK1st
var_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
0.0
f
);
var_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
0.0
f
);
});
});
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
for
(
index_t
k
=
0
;
k
<
num_k_block_tile_iteration
;
++
k
)
{
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
0 → 100644
View file @
56532f77
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment