Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
a1e344b1
Unverified
Commit
a1e344b1
authored
May 11, 2023
by
rocking
Committed by
GitHub
May 11, 2023
Browse files
Normalization/split k (#615)
parent
b076a02a
Changes
17
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
1514 additions
and
82 deletions
+1514
-82
example/27_layernorm/CMakeLists.txt
example/27_layernorm/CMakeLists.txt
+2
-1
example/27_layernorm/common.hpp
example/27_layernorm/common.hpp
+22
-0
example/27_layernorm/layernorm_fp16.cpp
example/27_layernorm/layernorm_fp16.cpp
+39
-0
example/27_layernorm/layernorm_splitk_fp16.cpp
example/27_layernorm/layernorm_splitk_fp16.cpp
+40
-0
example/27_layernorm/run_layernorm_example.inc
example/27_layernorm/run_layernorm_example.inc
+10
-53
example/42_groupnorm/CMakeLists.txt
example/42_groupnorm/CMakeLists.txt
+1
-0
example/42_groupnorm/common.hpp
example/42_groupnorm/common.hpp
+1
-0
example/42_groupnorm/groupnorm_splitk_fp16.cpp
example/42_groupnorm/groupnorm_splitk_fp16.cpp
+40
-0
example/42_groupnorm/run_groupnorm_example.inc
example/42_groupnorm/run_groupnorm_example.inc
+4
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
...ce/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
...r_operation/gpu/device/impl/device_normalization_impl.hpp
+24
-25
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
...tion/gpu/device/impl/device_normalization_splitk_impl.hpp
+658
-0
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp
...d/normalization/gridwise_normalization_naive_variance.hpp
+0
-0
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp
...pu/grid/normalization/gridwise_normalization_selector.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp
.../grid/normalization/gridwise_normalization_splitk_1st.hpp
+252
-0
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
.../grid/normalization/gridwise_normalization_splitk_2nd.hpp
+418
-0
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp
...normalization/gridwise_normalization_welford_variance.hpp
+0
-0
No files found.
example/27_layernorm/CMakeLists.txt
View file @
a1e344b1
add_example_executable
(
example_layernorm_blockwise layernorm_blockwise.cpp
)
add_example_executable
(
example_layernorm_fp16 layernorm_fp16.cpp
)
add_example_executable
(
example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp
)
example/27_layernorm/common.hpp
0 → 100644
View file @
a1e344b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
example/27_layernorm/layernorm_fp16.cpp
0 → 100644
View file @
a1e344b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using
XDataType
=
ck
::
half_t
;
using
GammaDataType
=
ck
::
half_t
;
using
BetaDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
ComputeDataType
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
constexpr
int
Rank
=
2
;
constexpr
int
NumReduceDim
=
1
;
using
DeviceInstance
=
ck
::
tensor_operation
::
device
::
DeviceNormalizationImpl
<
XDataType
,
GammaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
PassThrough
,
Rank
,
NumReduceDim
,
256
,
// BlockSize
8
,
// ClusterM
32
,
// ClusterK
1
,
// SliceM
8
,
// SliceK
1
,
// XYVectorDim (0=M, 1=K)
8
,
// SrcScalarPerVector
1
,
// GammaVecDim (0=M, 1=K)
8
,
// GammaScalarPerVector
1
,
// BetaVecDim (0=M, 1=K)
8
,
// BetaScalarPerVector
8
>
;
// OutScalarPerVector
#include "run_layernorm_example.inc"
int
main
()
{
return
run_groupnorm_example
<
DeviceInstance
>
();
}
example/27_layernorm/layernorm_splitk_fp16.cpp
0 → 100644
View file @
a1e344b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using
XDataType
=
ck
::
half_t
;
using
GammaDataType
=
ck
::
half_t
;
using
BetaDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
ComputeDataType
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
constexpr
int
Rank
=
2
;
constexpr
int
NumReduceDim
=
1
;
using
DeviceInstance
=
ck
::
tensor_operation
::
device
::
DeviceNormalizationSplitKImpl
<
XDataType
,
GammaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
PassThrough
,
Rank
,
NumReduceDim
,
256
,
// BlockSize
8
,
// ClusterM
32
,
// ClusterK
1
,
// SliceM
8
,
// SliceK
1
,
// XYVectorDim (0=M, 1=K)
8
,
// XScalarPerVector
1
,
// GammaVecDim (0=M, 1=K)
8
,
// GammaScalarPerVector
1
,
// BetaVecDim (0=M, 1=K)
8
,
// BetaScalarPerVector
8
>
;
// YScalarPerVector
#include "run_layernorm_example.inc"
int
main
()
{
return
run_groupnorm_example
<
DeviceInstance
>
();
}
example/27_layernorm/layernorm_
blockwise.cpp
→
example/27_layernorm/
run_
layernorm_
example.inc
View file @
a1e344b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
using
XDataType
=
ck
::
half_t
;
using
GammaDataType
=
ck
::
half_t
;
using
BetaDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
ComputeDataType
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
constexpr
int
Rank
=
2
;
constexpr
int
NumReduceDim
=
1
;
using
DeviceInstance
=
ck
::
tensor_operation
::
device
::
DeviceNormalizationImpl
<
XDataType
,
GammaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
PassThrough
,
Rank
,
NumReduceDim
,
256
,
// BlockSize
8
,
// ClusterM
32
,
// ClusterK
1
,
// SliceM
8
,
// SliceK
1
,
// SrcVecDim (0=M, 1=K)
8
,
// SrcScalarPerVector
1
,
// GammaVecDim (0=M, 1=K)
8
,
// GammaScalarPerVector
1
,
// BetaVecDim (0=M, 1=K)
8
,
// BetaScalarPerVector
8
>
;
// OutScalarPerVector
int
main
()
#pragma once
template
<
typename
DeviceInstance
>
int
run_groupnorm_example
()
{
bool
time_kernel
=
false
;
...
...
@@ -111,6 +63,10 @@ int main()
return
1
;
};
size_t
workspace_sz
=
device_instance
.
GetWorkSpaceSize
(
argument_ptr
.
get
());
DeviceMem
workspace_dev
(
workspace_sz
);
device_instance
.
SetWorkSpacePointer
(
argument_ptr
.
get
(),
workspace_dev
.
GetDeviceBuffer
());
auto
invoker_ptr
=
device_instance
.
MakeInvokerPointer
();
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
...
...
@@ -133,7 +89,8 @@ int main()
ref_invoker
.
Run
(
ref_argument
);
y_dev
.
FromDevice
(
y
.
mData
.
data
());
pass
&=
ck
::
utils
::
check_err
(
y
,
host_y
,
"Error: Incorrect results
d1
"
,
1e-3
,
1e-3
);
pass
&=
ck
::
utils
::
check_err
(
y
,
host_y
,
"Error: Incorrect results"
,
1
e
-
3
,
1
e
-
3
);
}
return
(
pass
?
0
:
1
);
}
example/42_groupnorm/CMakeLists.txt
View file @
a1e344b1
add_example_executable
(
example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp
)
add_example_executable
(
example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp
)
add_example_executable
(
example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp
)
example/42_groupnorm/common.hpp
View file @
a1e344b1
...
...
@@ -12,6 +12,7 @@
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/library/utility/fill.hpp"
...
...
example/42_groupnorm/groupnorm_splitk_fp16.cpp
0 → 100644
View file @
a1e344b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
constexpr
int
Rank
=
5
;
constexpr
int
NumReduceDim
=
3
;
using
XDataType
=
ck
::
half_t
;
using
GammaDataType
=
ck
::
half_t
;
using
BetaDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
ComputeDataType
=
float
;
using
YElementOp
=
ck
::
tensor_operation
::
element_wise
::
Swish
;
using
DeviceInstance
=
ck
::
tensor_operation
::
device
::
DeviceNormalizationSplitKImpl
<
XDataType
,
GammaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
YElementOp
,
Rank
,
NumReduceDim
,
256
,
// BlockSize
1
,
// ClusterM
256
,
// ClusterK
1
,
// SliceM
16
,
// SliceK
1
,
// SrcVecDim (0=M, 1=K)
2
,
// SrcScalarPerVector
1
,
// GammaVecDim (0=M, 1=K)
2
,
// GammaScalarPerVector
1
,
// BetaVecDim (0=M, 1=K)
2
,
// BetaScalarPerVector
2
>
;
// OutScalarPerVector
#include "run_groupnorm_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
run_groupnorm_example
(
argc
,
argv
);
}
example/42_groupnorm/run_groupnorm_example.inc
View file @
a1e344b1
...
...
@@ -73,6 +73,10 @@ int run_groupnorm_example(int argc, char* argv[])
return
1
;
};
size_t
workspace_sz
=
device_instance
.
GetWorkSpaceSize
(
argument_ptr
.
get
());
DeviceMem
workspace_dev
(
workspace_sz
);
device_instance
.
SetWorkSpacePointer
(
argument_ptr
.
get
(),
workspace_dev
.
GetDeviceBuffer
());
auto
invoker_ptr
=
device_instance
.
MakeInvokerPointer
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
,
true
});
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
a1e344b1
...
...
@@ -807,7 +807,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
// workspace for welford intermediate mean
workspace_size
+=
gemm_welford_size
*
sizeof
(
EMeanVarDataType
)
+
64
;
// workspace for welford intermediate
mean
// workspace for welford intermediate
variance
workspace_size
+=
gemm_welford_size
*
sizeof
(
EMeanVarDataType
)
+
64
;
// workspace for welford intermediate count
...
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
View file @
a1e344b1
...
...
@@ -10,8 +10,7 @@
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -20,6 +19,10 @@ namespace tensor_operation {
namespace
device
{
// Y = Normalization(X, Beta, Gamma)
// M: Invarient length
// K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
...
...
@@ -68,7 +71,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
static
auto
MakeSrc2dDescriptor
(
const
std
::
vector
<
index_t
>&
inLengths
,
const
std
::
vector
<
index_t
>&
inStrides
,
int
blkGroupSize
,
int
numBlockTileIteration
)
{
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
...
...
@@ -117,10 +119,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const
auto
invariantLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
0
>
{});
const
auto
reduceLength
=
in_grid_desc_m_k
.
GetLength
(
Number
<
1
>
{});
const
int
reduceSizePerBlock
=
K_BlockTileSize
*
numBlockTileIteration
;
const
auto
inPad_M
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
const
auto
inPad_K
=
reduceSizePerBlock
*
blkGroupSize
-
reduceLength
;
const
auto
inPad_K
=
K_BlockTileSize
*
numBlockTileIteration
-
reduceLength
;
auto
in_grid_desc_m_k_padded
=
transform_tensor_descriptor
(
in_grid_desc_m_k
,
...
...
@@ -132,7 +133,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
return
(
in_grid_desc_m_k_padded
);
};
using
GridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
GridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
));
struct
Argument
:
public
BaseArgument
{
...
...
@@ -162,26 +163,22 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
gammaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
gammaStrides
,
reduceDims
);
betaStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
betaStrides
,
reduceDims
);
long_index_t
invariant_
total_
length
;
long_index_t
reduce_
total_
length
;
long_index_t
invariant_length
;
long_index_t
reduce_length
;
std
::
tie
(
invariant_
total_
length
,
reduce_
total_
length
)
=
std
::
tie
(
invariant_length
,
reduce_length
)
=
get_2d_lengths
<
Rank
,
NumReduceDim
>
(
Lengths_
);
blkGroupSize_
=
1
;
numBlockTileIteration_
=
(
reduce_total_length
+
K_BlockTileSize
-
1
)
/
K_BlockTileSize
;
numBlockTileIteration_
=
math
::
integer_divide_ceil
(
reduce_length
,
K_BlockTileSize
);
gridSize_
=
math
::
integer_least_multiple
(
invariant_total_length
,
M_BlockTileSize
)
/
M_BlockTileSize
*
blkGroupSize_
;
gridSize_
=
math
::
integer_divide_ceil
(
invariant_length
,
M_BlockTileSize
);
x_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
xStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
x_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
xStrides_
,
numBlockTileIteration_
);
gamma_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
gammaStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
MakeSrc2dDescriptor
(
Lengths_
,
gammaStrides_
,
numBlockTileIteration_
);
beta_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
betaStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
y_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
yStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
MakeSrc2dDescriptor
(
Lengths_
,
betaStrides_
,
numBlockTileIteration_
);
y_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
yStrides_
,
numBlockTileIteration_
);
isSweeponce_
=
x_grid_desc_m_k_
.
GetLength
(
Number
<
1
>
{})
<=
KThreadClusterSize
*
KThreadSliceSize
;
...
...
@@ -202,7 +199,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
YElementwiseOperation
y_elementwise_op_
;
int
blkGroupSize_
;
int
numBlockTileIteration_
;
size_t
gridSize_
;
...
...
@@ -286,6 +282,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
if
(
p_arg_
->
invariant_lowest_length
%
XSrcVectorSize
!=
0
)
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
YDstVectorSize
!=
0
)
return
false
;
};
}
else
...
...
@@ -295,12 +294,12 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
XSrcVectorSize
!=
0
)
return
false
;
};
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
YDstVectorSize
!=
0
)
{
return
false
;
}
};
// if fastest dim is not reduced
if
constexpr
(
GammaSrcVectorDim
==
0
)
...
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
0 → 100644
View file @
a1e344b1
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp
→
include/ck/tensor_operation/gpu/grid/
normalization/
gridwise_normalization_naive_variance.hpp
View file @
a1e344b1
File moved
include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp
→
include/ck/tensor_operation/gpu/grid/
normalization/
gridwise_normalization_selector.hpp
View file @
a1e344b1
...
...
@@ -3,8 +3,8 @@
#pragma once
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp"
#include "ck/tensor_operation/gpu/grid/
normalization/
gridwise_normalization_naive_variance.hpp"
#include "ck/tensor_operation/gpu/grid/
normalization/
gridwise_normalization_welford_variance.hpp"
namespace
ck
{
template
<
typename
GridwiseReduction
,
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp
0 → 100644
View file @
a1e344b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
template
<
typename
XDataType
,
typename
ComputeDataType
,
typename
MeanVarDataType
,
typename
XGridDesc_M_K
,
typename
MeanVarGridDesc_M_KBlock
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XSrcVectorDim
,
index_t
XSrcVectorSize
>
struct
GridwiseNormalizationSplitK1st
{
static_assert
((
XSrcVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
)
||
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
bool
reorder_thread_cluster
=
(
XSrcVectorDim
==
0
);
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadBufferDimAccessOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
ThreadClusterArrangeOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadBufferLengths_M_K
=
Sequence
<
MThreadSliceSize
,
XSrcVectorSize
>
;
static
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{}));
using
ThreadBufferLengths_M_1
=
Sequence
<
MThreadSliceSize
,
1
>
;
static
constexpr
auto
thread_buffer_desc_m_1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
I1
));
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
using
ThreadwiseWelford
=
ThreadwiseWelford
<
ComputeDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
>
;
using
BlockwiseWelford
=
BlockwiseWelford
<
ComputeDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
false
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
index_t
K_BlockTileStepSize
=
KThreadClusterSize
*
XSrcVectorSize
;
static
constexpr
auto
ThreadBufferNumber
=
Number
<
KThreadSliceSize
/
XSrcVectorSize
>
{};
__device__
static
int
GetKPerThread
(
int
kRaw
,
int
kGridSize
,
int
block_k_cluster_id
,
int
thread_k_cluster_id
)
{
bool
is_rightmost_block
=
block_k_cluster_id
==
kGridSize
-
1
;
if
(
is_rightmost_block
)
{
int
left_kPerBlock
=
math
::
integer_divide_ceil
(
kRaw
,
kGridSize
);
int
kPerBlock
=
kRaw
%
kGridSize
==
0
?
left_kPerBlock
:
kRaw
%
left_kPerBlock
;
int
kPerThread
=
kPerBlock
<
K_BlockTileSize
?
0
:
KThreadSliceSize
*
(
kPerBlock
/
K_BlockTileSize
);
int
kPerBlockTail
=
kPerBlock
-
kPerThread
*
KThreadClusterSize
;
if
(
kPerBlockTail
>
0
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
int
thread_max_len
=
(
thread_k_cluster_id
+
1
)
*
XSrcVectorSize
+
K_BlockTileStepSize
*
i
;
int
delta
=
thread_max_len
-
kPerBlockTail
;
delta
=
math
::
clamp
(
thread_max_len
-
kPerBlockTail
,
0
,
XSrcVectorSize
);
kPerThread
+=
XSrcVectorSize
-
delta
;
});
}
return
kPerThread
;
}
else
{
int
kPerBlock
=
math
::
integer_divide_ceil
(
kRaw
,
kGridSize
);
return
KThreadSliceSize
*
(
kPerBlock
/
K_BlockTileSize
);
}
}
// Calculate mean and variance by welford along k dimension
__device__
static
void
Run
(
const
XGridDesc_M_K
&
x_grid_desc_m_k
,
const
MeanVarGridDesc_M_KBlock
&
mean_var_grid_desc_m_kblock
,
index_t
num_k_block_tile_iteration
,
const
XDataType
*
const
__restrict__
p_x_global
,
MeanVarDataType
*
const
p_mean_global
,
MeanVarDataType
*
const
p_variance_global
,
int32_t
*
const
p_welford_count_global
)
{
auto
x_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
*
XSrcVectorSize
,
true
>
{};
},
Number
<
ThreadBufferNumber
>
{});
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>
mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>
var_thread_buf
;
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
k_grid_size
=
mean_var_grid_desc_m_kblock
.
GetLength
(
I1
);
const
index_t
block_m_cluster_id
=
block_global_id
/
k_grid_size
;
const
index_t
block_k_cluster_id
=
block_global_id
%
k_grid_size
;
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
const
index_t
reduceSizePerBlock
=
K_BlockTileSize
*
num_k_block_tile_iteration
;
auto
threadwise_x_load
=
ThreadwiseTensorSliceTransfer_v2
<
XDataType
,
ComputeDataType
,
XGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
XSrcVectorDim
,
XSrcVectorSize
,
1
,
true
>
(
x_grid_desc_m_k
,
make_multi_index
(
block_m_cluster_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_k_cluster_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
XSrcVectorSize
));
auto
mean_var_count_store_index
=
make_multi_index
(
block_m_cluster_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_k_cluster_id
);
auto
threadwise_welford_mean_var_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
MeanVarDataType
,
decltype
(
thread_buffer_desc_m_1
),
MeanVarGridDesc_M_KBlock
,
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_grid_desc_m_kblock
,
mean_var_count_store_index
,
PassThroughOp
{});
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileStepSize
);
const
auto
x_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_x_global
,
x_grid_desc_m_k
.
GetElementSpaceSize
());
auto
mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_mean_global
,
mean_var_grid_desc_m_kblock
.
GetElementSpaceSize
());
auto
var_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_variance_global
,
mean_var_grid_desc_m_kblock
.
GetElementSpaceSize
());
auto
threadwise_welford
=
ThreadwiseWelford
();
int
kRaw
=
x_grid_desc_m_k
.
GetTransforms
()[
I2
].
GetUpperLengths
()[
I0
];
threadwise_welford
.
max_count_
=
GetKPerThread
(
kRaw
,
k_grid_size
,
block_k_cluster_id
,
thread_k_cluster_id
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
mean_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
0.0
f
);
var_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
0.0
f
);
});
for
(
index_t
k
=
0
;
k
<
num_k_block_tile_iteration
;
++
k
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
(
i
));
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
threadwise_welford
.
Run
(
x_thread_buf
[
i
],
mean_thread_buf
,
var_thread_buf
);
});
}
int
welford_count
=
0
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
int
count
=
threadwise_welford
.
cur_count_
;
BlockwiseWelford
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count
);
// The value of count is same for all I
if
constexpr
(
I
==
MThreadSliceSize
-
1
)
welford_count
=
count
;
});
if
(
thread_k_cluster_id
==
0
)
{
threadwise_welford_mean_var_store
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
mean_thread_buf
,
mean_var_grid_desc_m_kblock
,
mean_global_val_buf
);
threadwise_welford_mean_var_store
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
var_thread_buf
,
mean_var_grid_desc_m_kblock
,
var_global_val_buf
);
if
(
block_m_cluster_id
==
0
&&
thread_m_cluster_id
==
0
)
p_welford_count_global
[
block_k_cluster_id
]
=
welford_count
;
}
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
0 → 100644
View file @
a1e344b1
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp
→
include/ck/tensor_operation/gpu/grid/
normalization/
gridwise_normalization_welford_variance.hpp
View file @
a1e344b1
File moved
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment