composable_kernel — Commit 8b7aeb35
Authored Jun 29, 2022 by rocking
Parent: 12235112

Implement layernorm kernel and deviceOp
Showing 6 changed files with 711 additions and 0 deletions
  example/23_softmax/softmax_blockwise.cpp                      +2   -0
  example/24_layernorm/CMakeLists.txt                           +1   -0
  example/24_layernorm/layernorm_blockwise.cpp                  +104 -0
  example/CMakeLists.txt                                        +1   -0
  include/ck/tensor_operation/gpu/device/device_layernorm.hpp   +237 -0
  include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp   +366 -0
example/23_softmax/softmax_blockwise.cpp

@@ -209,6 +209,8 @@ int main(int argc, char* argv[])
     auto device_instance = DeviceInstance{};
 
+    std::cout << i_inLengths.size() << ", " << i_inStrides.size() << std::endl;
+
     auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths,
                                                             i_inStrides,
                                                             reduceDims,
...
example/24_layernorm/CMakeLists.txt (new file, mode 100644)

add_example_executable(example_layernorm_blockwise layernorm_blockwise.cpp)
\ No newline at end of file
example/24_layernorm/layernorm_blockwise.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>

#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/device_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_common_util.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"

using XDataType     = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType  = ck::half_t;
using YDataType     = ck::half_t;
using AccDataType   = float;

constexpr int Rank         = 2;
constexpr int NumReduceDim = 1;

using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType,
                                                                     GammaDataType,
                                                                     BetaDataType,
                                                                     AccDataType,
                                                                     YDataType,
                                                                     Rank,
                                                                     NumReduceDim,
                                                                     256, // BlockSize
                                                                     8,   // ClusterM
                                                                     32,  // ClusterK
                                                                     1,   // SliceM
                                                                     8,   // SliceK
                                                                     1,   // SrcVecDim (0=M, 1=K)
                                                                     8,   // SrcScalarPerVector
                                                                     1,   // AffineVecDim (0=M, 1=K)
                                                                     1,   // AffineScalarPerVector
                                                                     8>;  // OutScalarPerVector

int main()
{
    bool time_kernel = false;

    ck::index_t M      = 1024;
    ck::index_t N      = 1024;
    ck::index_t Stride = 1024;

    auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
        return HostTensorDescriptor(std::vector<std::size_t>({len}),
                                    std::vector<std::size_t>({stride}));
    };

    auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
        return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                    std::vector<std::size_t>({stride, 1}));
    };

    Tensor<XDataType> x(f_host_tensor_descriptor2d(M, N, Stride));
    Tensor<GammaDataType> gamma(f_host_tensor_descriptor1d(N, 1));
    Tensor<BetaDataType> beta(f_host_tensor_descriptor1d(N, 1));
    Tensor<YDataType> y(f_host_tensor_descriptor2d(M, N, Stride));

    x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
    gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
    beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});

    DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpace());
    DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpace());
    DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpace());
    DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpace());

    x_dev.ToDevice(x.mData.data());

    auto device_instance = DeviceInstance{};
    auto argument_ptr    = device_instance.MakeArgumentPointer({M, N},
                                                               {Stride, 1},
                                                               {0, 1},
                                                               {1},
                                                               1e-4,
                                                               x_dev.GetDeviceBuffer(),
                                                               gamma_dev.GetDeviceBuffer(),
                                                               beta_dev.GetDeviceBuffer(),
                                                               y_dev.GetDeviceBuffer());

    if(!device_instance.IsSupportedArgument(argument_ptr.get()))
    {
        std::cout << "The runtime parameters are not supported" << std::endl;
        return 1;
    };

    auto invoker_ptr = device_instance.MakeInvokerPointer();
    invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

    bool pass = true;
    return (pass ? 0 : 1);
}
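Note that the example hard-codes pass = true and never reads y_dev back or compares it against a host result. If one wanted to add such a check, a minimal host reference could look like the sketch below. This is an assumption, not part of the commit: the host_layernorm helper and its plain-float, row-major interface are invented here for illustration; the actual example would presumably copy the output back through DeviceMem and compare with the check_err utility it already includes.

// Sketch of a host-side layernorm reference for validating the example output.
// x and y are M*N row-major; gamma/beta are length-N vectors broadcast over rows.
#include <cmath>
#include <vector>

void host_layernorm(const std::vector<float>& x,     // M*N, row-major
                    const std::vector<float>& gamma, // N
                    const std::vector<float>& beta,  // N
                    std::vector<float>& y,           // M*N, row-major
                    int M, int N, float epsilon)
{
    for(int m = 0; m < M; ++m)
    {
        // per-row mean and variance, using var(x) = E[x^2] - E[x]^2 as the kernel does
        float sum = 0.f, sum_sq = 0.f;
        for(int n = 0; n < N; ++n)
        {
            float v = x[m * N + n];
            sum += v;
            sum_sq += v * v;
        }
        float mean = sum / N;
        float var  = sum_sq / N - mean * mean;

        // normalize, then apply the affine transform with gamma and beta
        for(int n = 0; n < N; ++n)
            y[m * N + n] = (x[m * N + n] - mean) / std::sqrt(var + epsilon) * gamma[n] + beta[n];
    }
}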
example/CMakeLists.txt

@@ -42,3 +42,4 @@ add_subdirectory(20_convnd_bwd_weight_xdl)
 add_subdirectory(21_gemm_layernorm)
 add_subdirectory(22_cgemm)
 add_subdirectory(23_softmax)
+add_subdirectory(24_layernorm)
include/ck/tensor_operation/gpu/device/device_layernorm.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename AccDataType,
          typename YDataType,
          index_t Rank,
          index_t NumReduceDim,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t AffineSrcVectorDim,
          index_t AffineSrcVectorSize,
          index_t OutDstVectorSize>
struct DeviceLayernorm : public BaseOperator
{
    static_assert(
        ((AffineSrcVectorDim == 0 && MThreadSliceSize % AffineSrcVectorSize == 0) ||
         (AffineSrcVectorDim == 1 && KThreadSliceSize % AffineSrcVectorSize == 0)),
        "Invalid thread slice sizes and/or affine vector sizes configuration, please check!");

    using PassThrough = tensor_operation::element_wise::PassThrough;

    // Used for freeloading of some handy functions from DeviceReduceMultiBlock
    using Reduction = DeviceReduceMultiBlock<XDataType,
                                             AccDataType,
                                             YDataType,
                                             Rank,
                                             NumReduceDim,
                                             reduce::Add,
                                             PassThrough, // InElementwiseOperation
                                             PassThrough, // AccElementwiseOperation
                                             InMemoryDataOperationEnum::Set,
                                             false, // PropagateNan
                                             false, // OutputIndex
                                             false, // HaveIndexInputIfOutputIndex
                                             BlockSize,
                                             MThreadClusterSize,
                                             KThreadClusterSize,
                                             MThreadSliceSize,
                                             KThreadSliceSize,
                                             InSrcVectorDim,
                                             InSrcVectorSize,
                                             1>; // OutDstVectorSize

    using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));

    using GridwiseReduce = GridwiseLayernorm_mk_to_mk<XDataType,
                                                      GammaDataType,
                                                      BetaDataType,
                                                      YDataType,
                                                      AccDataType,
                                                      GridDesc_M_K,
                                                      BlockSize,
                                                      MThreadClusterSize,
                                                      KThreadClusterSize,
                                                      MThreadSliceSize,
                                                      KThreadSliceSize,
                                                      InSrcVectorDim,
                                                      InSrcVectorSize,
                                                      AffineSrcVectorDim,
                                                      AffineSrcVectorSize,
                                                      OutDstVectorSize,
                                                      false>;

    struct Argument : public Reduction::Argument
    {
        Argument(const std::vector<index_t> inLengths,
                 const std::vector<index_t> inStrides,
                 const std::vector<index_t> affineStrides,
                 const std::vector<index_t> reduceDims,
                 AccDataType epsilon,
                 const XDataType* p_x,
                 const GammaDataType* p_gamma,
                 const BetaDataType* p_beta,
                 YDataType* p_y)
            : Reduction::Argument(inLengths,
                                  inStrides,
                                  {},
                                  {},
                                  reduceDims,
                                  0.0f, // alpha
                                  0.0f, // beta
                                  p_x,
                                  nullptr,
                                  p_y,
                                  nullptr,
                                  PassThrough{},
                                  PassThrough{}),
              epsilon_(epsilon),
              p_gamma_(p_gamma),
              p_beta_(p_beta)
        {
            affineStrides_ =
                shuffle_tensor_dimensions<Rank, NumReduceDim>(affineStrides, reduceDims);
        }

        AccDataType epsilon_;

        const GammaDataType* p_gamma_;
        const BetaDataType* p_beta_;

        std::vector<index_t> affineStrides_;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
                arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
            const auto gamma_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
                arg.inLengths_, arg.affineStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
            const auto beta_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
                arg.inLengths_, arg.affineStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
            const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
                arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);

            const auto kernel_main = kernel_layernorm<GridwiseReduce,
                                                      XDataType,
                                                      GammaDataType,
                                                      BetaDataType,
                                                      YDataType,
                                                      AccDataType,
                                                      GridDesc_M_K>;

            float avg_time = 0;

            avg_time += launch_and_time_kernel(stream_config,
                                               kernel_main,
                                               dim3(arg.gridSize),
                                               dim3(BlockSize),
                                               0,
                                               in_grid_desc_m_k,
                                               gamma_grid_desc_m_k,
                                               beta_grid_desc_m_k,
                                               out_grid_desc_m_k,
                                               arg.blkGroupSize,
                                               arg.numBlockTileIteration,
                                               arg.epsilon_,
                                               arg.in_dev_,
                                               arg.p_gamma_,
                                               arg.p_beta_,
                                               arg.out_dev_);

            return (avg_time);
        };

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        };
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);

        if(!Reduction::IsSupportedArgument(p_arg_))
        {
            return false;
        }

        if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0)
        {
            return false;
        }

        // TODO - Check AffineSrcVectorDim and AffineSrcVectorSize

        return true;
    };

    std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
                                                      const std::vector<index_t> inStrides,
                                                      const std::vector<index_t> affineStrides,
                                                      const std::vector<int> reduceDims,
                                                      AccDataType epsilon,
                                                      const void* p_x,
                                                      const void* p_gamma,
                                                      const void* p_beta,
                                                      void* p_y)
    {
        return std::make_unique<Argument>(inLengths,
                                          inStrides,
                                          affineStrides,
                                          reduceDims,
                                          epsilon,
                                          static_cast<const XDataType*>(p_x),
                                          static_cast<const GammaDataType*>(p_gamma),
                                          static_cast<const BetaDataType*>(p_beta),
                                          static_cast<YDataType*>(p_y));
    };

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceLayernorm<" << BlockSize << ",";
        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
        str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseReduction,
          typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename AccDataType,
          typename GridDesc_M_K>
__global__ void kernel_layernorm(const GridDesc_M_K in_grid_desc_m_k,
                                 const GridDesc_M_K gamma_grid_desc_m_k,
                                 const GridDesc_M_K beta_grid_desc_m_k,
                                 const GridDesc_M_K out_grid_desc_m_k,
                                 index_t block_group_size,
                                 index_t num_k_block_tile_iteration,
                                 AccDataType epsilon,
                                 const XDataType* const __restrict__ p_x_global,
                                 const GammaDataType* const __restrict__ p_gamma_global,
                                 const BetaDataType* const __restrict__ p_beta_global,
                                 YDataType* const __restrict__ p_y_global)
{
    GridwiseReduction::Run(in_grid_desc_m_k,
                           gamma_grid_desc_m_k,
                           beta_grid_desc_m_k,
                           out_grid_desc_m_k,
                           block_group_size,
                           num_k_block_tile_iteration,
                           epsilon,
                           p_x_global,
                           p_gamma_global,
                           p_beta_global,
                           p_y_global);
};

template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename AccDataType,
          typename GridDesc_M_K,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t AffineSrcVectorDim,
          index_t AffineSrcVectorSize,
          index_t OutDstVectorSize,
          bool SweepOnce>
struct GridwiseLayernorm_mk_to_mk
{
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
                      (KThreadSliceSize % OutDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using BlockwiseSumReduce =
        PartitionedBlockwiseReduction<AccDataType,
                                      BlockSize,
                                      ThreadClusterLengths_M_K,
                                      ThreadClusterArrangeOrder,
                                      reduce::Add,
                                      false, // ignored
                                      detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;

    using ThreadwiseSumReduce =
        ThreadwiseReduction<AccDataType,
                            ThreadReduceSrcDesc_M_K,
                            ThreadReduceDstDesc_M,
                            reduce::Add,
                            false, // ignored
                            detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    __device__ static void Run(const GridDesc_M_K& in_grid_desc_m_k,
                               const GridDesc_M_K& gamma_grid_desc_m_k,
                               const GridDesc_M_K& beta_grid_desc_m_k,
                               const GridDesc_M_K& out_grid_desc_m_k,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               AccDataType epsilon,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global)
    {
        if constexpr(SweepOnce)
        {
            num_k_block_tile_iteration = 1;
        }

        // LDS
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_global, out_grid_desc_m_k.GetElementSpaceSize());

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            gamma_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            beta_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            out_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>&
            in_square_thread_buf = out_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            mean_square_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_value_buf =
            mean_square_thread_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I)        = reduce::Add::template GetIdentityValue<AccDataType>();
            mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
        });

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / block_group_size;
        const index_t block_local_id  = block_global_id % block_group_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

        using ThreadBufferLengths         = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                  AccDataType,
                                                                  GridDesc_M_K,
                                                                  decltype(thread_buffer_desc),
                                                                  ThreadBufferLengths,
                                                                  ThreadBufferDimAccessOrder,
                                                                  InSrcVectorDim,
                                                                  InSrcVectorSize,
                                                                  1,
                                                                  true>(
            in_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_gamma_load =
            ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                             AccDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc),
                                             ThreadBufferLengths,
                                             ThreadBufferDimAccessOrder,
                                             AffineSrcVectorDim,
                                             AffineSrcVectorSize,
                                             1,
                                             true>(
                gamma_grid_desc_m_k,
                make_multi_index(
                    blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                    block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_beta_load =
            ThreadwiseTensorSliceTransfer_v2<BetaDataType,
                                             AccDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc),
                                             ThreadBufferLengths,
                                             ThreadBufferDimAccessOrder,
                                             AffineSrcVectorDim,
                                             AffineSrcVectorSize,
                                             1,
                                             true>(
                beta_grid_desc_m_k,
                make_multi_index(
                    blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                    block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_y_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               YDataType,
                                               decltype(thread_buffer_desc),
                                               GridDesc_M_K,
                                               PassThroughOp,
                                               ThreadBufferLengths,
                                               ThreadBufferDimAccessOrder,
                                               InSrcVectorDim,
                                               OutDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                out_grid_desc_m_k,
                make_multi_index(
                    blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                    block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
                PassThroughOp{});

        constexpr auto in_thread_copy_fwd_step =
            make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
        constexpr auto in_thread_copy_bwd_step =
            make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);

        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x_global, in_grid_desc_m_k.GetElementSpaceSize());
        const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
        const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());

        // E(x), E[x^2], var(x)
        int reduce_length = in_grid_desc_m_k.GetLength(I1);

        index_t reducedTiles = 0;
        do
        {
            threadwise_x_load.Run(in_grid_desc_m_k,
                                  in_global_val_buf,
                                  thread_buffer_desc,
                                  make_tuple(I0, I0),
                                  in_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                    in_square_thread_buf(Number<offset>{}) =
                        in_thread_buf(Number<offset>{}) * in_thread_buf(Number<offset>{});
                });
            });

            ThreadwiseSumReduce::Reduce(in_thread_buf, mean_thread_buf);
            ThreadwiseSumReduce::Reduce(in_square_thread_buf, mean_square_thread_buf);

            threadwise_x_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);

            ++reducedTiles;
        } while(reducedTiles < num_k_block_tile_iteration);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
            mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;

            BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
            mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;

            // var(x) = E[x^2] - E[x]^2
            var_value_buf(I) =
                mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
        });

        // y = (x - E[x]) / sqrt(var[x] + epsilon)
        auto thread_copy_tail = (num_k_block_tile_iteration - 1) * in_thread_copy_fwd_step;

        threadwise_x_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
        threadwise_gamma_load.MoveSrcSliceWindow(in_grid_desc_m_k, thread_copy_tail);
        threadwise_beta_load.MoveSrcSliceWindow(in_grid_desc_m_k, thread_copy_tail);
        threadwise_y_store.MoveDstSliceWindow(out_grid_desc_m_k, thread_copy_tail);

        reducedTiles = 0;
        do
        {
            if constexpr(!SweepOnce)
            {
                threadwise_x_load.Run(in_grid_desc_m_k,
                                      in_global_val_buf,
                                      thread_buffer_desc,
                                      make_tuple(I0, I0),
                                      in_thread_buf);
            }

            threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                      gamma_global_val_buf,
                                      thread_buffer_desc,
                                      make_tuple(I0, I0),
                                      gamma_thread_buf);

            threadwise_beta_load.Run(beta_grid_desc_m_k,
                                     beta_global_val_buf,
                                     thread_buffer_desc,
                                     make_tuple(I0, I0),
                                     beta_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                    // normalize
                    out_thread_buf(Number<offset>{}) =
                        (in_thread_buf(Number<offset>{}) - mean_thread_buf(iM)) /
                        sqrt(var_value_buf(iM) + epsilon);

                    // affine
                    out_thread_buf(Number<offset>{}) =
                        out_thread_buf(Number<offset>{}) * gamma_thread_buf(Number<offset>{}) +
                        beta_thread_buf(Number<offset>{});
                });
            });

            threadwise_y_store.Run(thread_buffer_desc,
                                   make_tuple(I0, I0),
                                   out_thread_buf,
                                   out_grid_desc_m_k,
                                   out_global_val_buf);

            threadwise_x_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
            threadwise_gamma_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
            threadwise_beta_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
            threadwise_y_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_bwd_step);

            ++reducedTiles;
        } while(reducedTiles < num_k_block_tile_iteration);
    }
};

} // namespace ck
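To summarize the structure of GridwiseLayernorm_mk_to_mk::Run above: each block makes two sweeps over its K tiles. The first sweep accumulates per-thread sums of x and x*x, which the block-wise reduction turns into E[x] and E[x^2], giving var(x) = E[x^2] - E[x]^2; the second sweep revisits the same tiles (the reload of x is skipped when SweepOnce is true, since the data is still in registers), normalizes, applies gamma and beta, and stores y. A scalar, single-row analogue of that two-sweep scheme is sketched below purely for illustration; layernorm_row_two_sweep and its tile_k parameter are invented here and are not part of the commit.

// Scalar analogue of the kernel's two-sweep, tiled traversal of one row of length N.
#include <algorithm>
#include <cmath>
#include <vector>

void layernorm_row_two_sweep(const std::vector<float>& x,     // one row, length N
                             const std::vector<float>& gamma, // length N
                             const std::vector<float>& beta,  // length N
                             std::vector<float>& y,           // length N
                             float epsilon,
                             int tile_k)
{
    const int N = static_cast<int>(x.size());

    // Sweep 1: accumulate sum(x) and sum(x^2) tile by tile (the kernel keeps these
    // per thread and finishes them with a block-wise reduction).
    float sum = 0.f, sum_sq = 0.f;
    for(int k0 = 0; k0 < N; k0 += tile_k)
        for(int k = k0; k < std::min(k0 + tile_k, N); ++k)
        {
            sum += x[k];
            sum_sq += x[k] * x[k];
        }

    const float mean = sum / N;
    const float var  = sum_sq / N - mean * mean; // var(x) = E[x^2] - E[x]^2

    // Sweep 2: revisit the tiles, normalize, then apply the affine transform.
    for(int k0 = 0; k0 < N; k0 += tile_k)
        for(int k = k0; k < std::min(k0 + tile_k, N); ++k)
            y[k] = (x[k] - mean) / std::sqrt(var + epsilon) * gamma[k] + beta[k];
}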