Project: gaoqiong/composable_kernel

Commit 9608beee (unverified)
Authored Nov 06, 2022 by arai713; committed by GitHub on Nov 06, 2022

    Merge branch 'develop' into gridwise_2d

Parents: d179a12a, 8a4253ba
Changes: 172 changed files in this commit; this page shows 20 changed files with 726 additions and 76 deletions (+726, −76).
Files changed on this page:

include/ck/tensor_operation/gpu/element/element_wise_operation.hpp                          (+1, −0)
include/ck/tensor_operation/gpu/element/quantization_operation.hpp                          (+86, −0)
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                    (+1, −0)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp         (+8, −5)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp (+9, −11)
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp    (+500, −0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp      (+6, −4)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp   (+6, −4)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp              (+6, −4)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp                    (+43, −0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp                          (+1, −0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp               (+6, −4)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp                      (+5, −10)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp            (+6, −4)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp                    (+7, −4)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp                          (+18, −15)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp                          (+5, −3)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp                          (+5, −3)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp                          (+5, −3)
include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp              (+2, −2)
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp

@@ -7,6 +7,7 @@
 #include "ck/utility/math_v2.hpp"
 #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/element/quantization_operation.hpp"

 namespace ck {
 namespace tensor_operation {
...
include/ck/tensor_operation/gpu/element/quantization_operation.hpp (new file, 0 → 100644)

#pragma once

#include "ck/utility/data_type.hpp"

namespace ck {
namespace tensor_operation {
namespace element_wise {

// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
template <typename Activation>
struct Activation_Mul_Clamp
{
    Activation_Mul_Clamp(float multiplier, Activation activationOp)
        : multiplier_(multiplier), activationOp_(activationOp)
    {
    }

    __host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
    {
        float x_fp32 = ck::type_convert<float>(x);
        activationOp_(x_fp32, x_fp32);
        float y_fp32 = math::clamp(multiplier_ * x_fp32, -128.f, 127.f);
        y            = ck::type_convert<int8_t>(y_fp32);
    }

    __host__ __device__ constexpr void operator()(float& y, const int32_t& x) const
    {
        // We might type_convert to int8 after lambda in someplace
        float x_fp32 = ck::type_convert<float>(x);
        activationOp_(x_fp32, x_fp32);
        y = math::clamp(multiplier_ * x_fp32, -128.f, 127.f);
    }

    float multiplier_;
    Activation activationOp_;
};

// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
template <typename Activation>
struct Add_Activation_Mul_Clamp
{
    Add_Activation_Mul_Clamp(float multiplier, Activation activationOp)
        : multiplier_(multiplier), activationOp_(activationOp)
    {
    }

    __host__ __device__ constexpr void
    operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const
    {
        float y_fp32 = ck::type_convert<float>(x1 + x2);
        activationOp_(y_fp32, y_fp32);
        y_fp32 = math::clamp(multiplier_ * y_fp32, -128.f, 127.f);
        y      = ck::type_convert<int8_t>(y_fp32);
    }

    float multiplier_;
    Activation activationOp_;
};

// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc
template <typename Activation>
struct Add_Mul_Activation_Mul_Clamp
{
    Add_Mul_Activation_Mul_Clamp(float multiplier1, float multiplier2, Activation activationOp)
        : multiplier1_(multiplier1), multiplier2_(multiplier2), activationOp_(activationOp)
    {
    }

    __host__ __device__ constexpr void
    operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const
    {
        float y_fp32 = ck::type_convert<float>(x1 + x2);
        y_fp32       = multiplier1_ * y_fp32;
        activationOp_(y_fp32, y_fp32);
        y_fp32 = math::clamp(multiplier2_ * y_fp32, -128.f, 127.f);
        y      = ck::type_convert<int8_t>(y_fp32);
    }

    float multiplier1_;
    float multiplier2_;
    Activation activationOp_;
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
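The three functors above implement the usual int8 requantization epilogue: accumulate in int32, apply the activation in fp32, scale by the requantization multiplier(s), clamp to [-128, 127], and convert back to int8. The sketch below is a host-side illustration of that arithmetic only, with a stand-in Relu functor and plain C++ types; it is not the CK device code.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Stand-in for a piecewise-linear activation functor (hypothetical; CK's own
// element-wise activations live in unary_element_wise_operation.hpp).
struct Relu
{
    void operator()(float& y, const float& x) const { y = x > 0.f ? x : 0.f; }
};

// Host-side mirror of Activation_Mul_Clamp<Relu>::operator()(int8_t&, const int32_t&).
int8_t requantize(int32_t acc, float multiplier, const Relu& act)
{
    float x_fp32 = static_cast<float>(acc);                                // type_convert<float>(x)
    act(x_fp32, x_fp32);                                                   // activation in fp32
    float y_fp32 = std::min(std::max(multiplier * x_fp32, -128.f), 127.f); // math::clamp
    return static_cast<int8_t>(y_fp32);                                    // type_convert<int8_t>
}

int main()
{
    // e.g. an int32 GEMM accumulator of 1000 with a requantization scale of 0.05
    std::printf("%d\n", requantize(1000, 0.05f, Relu{})); // prints 50
    std::printf("%d\n", requantize(-7, 0.05f, Relu{}));   // negative input -> relu -> 0
    return 0;
}

The split between the variants follows from the activation: a piecewise-linear activation such as relu satisfies relu(s·x) = s·relu(x) for s > 0, so a single multiplier after the activation suffices, while a non-linear activation such as tanh needs the extra multiplier1_ scale applied before the activation and multiplier2_ after it.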
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -4,6 +4,7 @@
 #pragma once

 #include "ck/utility/data_type.hpp"
 #include "ck/utility/math.hpp"
+#include "ck/utility/math_v2.hpp"

 namespace ck {
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp

@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...
@@ -74,7 +74,8 @@ template <typename FloatAB,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
-          LoopScheduler LoopSched>
+          LoopScheduler LoopSched,
+          PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseBatchedGemmGemm_Xdl_CShuffle
 {
     static_assert(LoopSched == LoopScheduler::Default,
...
@@ -101,7 +102,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     template <typename ABlockDesc_AK0_M_AK1>
     __host__ __device__ static constexpr auto
...
@@ -486,8 +488,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
         // gridwise GEMM pipeline
         // Only supports LoopScheduler::Default
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopScheduler::Default>();
+        const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
+                                                                          NumGemmKPrefetchStage,
+                                                                          LoopScheduler::Default>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp

@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...
@@ -77,7 +77,8 @@ template <typename FloatAB,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched,
           bool PadN,
-          bool MaskOutUpperTriangle>
+          bool MaskOutUpperTriangle,
+          PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
 {
     static_assert(LoopSched == LoopScheduler::Default,
...
@@ -108,7 +109,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     template <typename ABlockDesc_AK0_M_AK1>
     __host__ __device__ static constexpr auto
...
@@ -503,8 +505,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         // gridwise GEMM pipeline
         // Only supports LoopScheduler::Default
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopScheduler::Default>();
+        const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
+                                                                          NumGemmKPrefetchStage,
+                                                                          LoopScheduler::Default>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
...
@@ -691,12 +694,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                           FloatGemmAcc,
                                           decltype(threadid_to_m_n_thread_cluster_adaptor),
                                           decltype(thread_cluster_desc_m_n),
-                                          decltype(thread_slice_desc_m_n)
-#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
-                                          ,
-                                          true
-#endif
-                                          >{};
+                                          decltype(thread_slice_desc_m_n)>{};

         const index_t num_gemm1_k_block_outer_loop =
             b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
...
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

// X = Elementwise(input1, input2, input3, ...)
// Y = Normalization(X, beta, gamma)
template <typename InDataTypePointerTuple,
          typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename AccDataType,
          typename XElementwiseOperation,
          typename YElementwiseOperation,
          typename InGrid2dDescTuple,
          typename GridDesc_M_K,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XSrcVectorDim,
          index_t XSrcVectorSize,
          index_t GammaSrcVectorDim,
          index_t GammaSrcVectorSize,
          index_t BetaSrcVectorDim,
          index_t BetaSrcVectorSize,
          index_t YDstVectorDim,
          index_t YDstVectorSize,
          bool SweepOnce>
struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
{
    static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
                      (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr index_t NumInput = InDataTypePointerTuple::Size();

    static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<AccDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr index_t M_BlockTileSize     = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize     = KThreadClusterSize * KThreadSliceSize;
    static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;

    static constexpr auto XThreadBufferNumber     = Number<KThreadSliceSize / XSrcVectorSize>{};
    static constexpr auto GammaThreadBufferNumber = Number<KThreadSliceSize / GammaSrcVectorSize>{};
    static constexpr auto BetaThreadBufferNumber  = Number<KThreadSliceSize / BetaSrcVectorSize>{};
    static constexpr auto YThreadBufferNumber     = Number<KThreadSliceSize / YDstVectorSize>{};

    __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
                                        int thread_k_cluster_id)
    {
        int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
        int kPerThread =
            kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
        int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;

        if(kPerBlockTail > 0)
        {
            static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
                int thread_max_len =
                    (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
                int delta = thread_max_len - kPerBlockTail;
                delta     = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
                kPerThread += XSrcVectorSize - delta;
            });
        }

        return kPerThread;
    }

    __device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple,
                               const GridDesc_M_K& x_grid_desc_m_k,
                               const GridDesc_M_K& gamma_grid_desc_m_k,
                               const GridDesc_M_K& beta_grid_desc_m_k,
                               const GridDesc_M_K& y_grid_desc_m_k,
                               index_t num_k_block_tile_iteration,
                               AccDataType epsilon,
                               const InDataTypePointerTuple p_in_global_tuple,
                               XDataType* const __restrict__ p_x_lds,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global,
                               const XElementwiseOperation x_elementwise_op,
                               const YElementwiseOperation y_elementwise_op)
    {
        if constexpr(SweepOnce)
        {
            num_k_block_tile_iteration = 1;
        }

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t grid_size       = get_grid_size();

        auto in_global_buf_tuple = generate_tuple(
            [&](auto I) {
                static_assert(in_grid_2d_desc_tuple[I].GetNumOfDimension() == 2); // matrix dimension

                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumInput>{});

        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

        auto x_lds_val_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_x_lds, x_grid_desc_m_k.GetElementSpaceSize() / grid_size);

        auto in_thread_buf_tuple = generate_tuple(
            [&](auto) {
                return generate_tuple(
                    [&](auto) {
                        return StaticBuffer<AddressSpaceEnum::Vgpr,
                                            AccDataType,
                                            MThreadSliceSize * XSrcVectorSize,
                                            true>{};
                    },
                    Number<NumInput>{});
            },
            Number<XThreadBufferNumber>{});

        auto x_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    AccDataType,
                                    MThreadSliceSize * XSrcVectorSize,
                                    true>{};
            },
            Number<XThreadBufferNumber>{});

        auto gamma_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    AccDataType,
                                    MThreadSliceSize * GammaSrcVectorSize,
                                    true>{};
            },
            Number<GammaThreadBufferNumber>{});

        auto beta_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    AccDataType,
                                    MThreadSliceSize * BetaSrcVectorSize,
                                    true>{};
            },
            Number<BetaThreadBufferNumber>{});

        auto y_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    AccDataType,
                                    MThreadSliceSize * YDstVectorSize,
                                    true>{};
            },
            Number<YThreadBufferNumber>{});

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths_M_K         = Sequence<MThreadSliceSize, XSrcVectorSize>;
        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

        auto in_global_load_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;

                return ThreadwiseTensorSliceTransfer_v2<DataType,
                                                        AccDataType,
                                                        decltype(in_grid_2d_desc_tuple[I]),
                                                        decltype(thread_buffer_desc_m_k),
                                                        ThreadBufferLengths_M_K,
                                                        ThreadBufferDimAccessOrder,
                                                        XSrcVectorDim,
                                                        XSrcVectorSize,
                                                        1,
                                                        false>{
                    in_grid_2d_desc_tuple[I],
                    make_multi_index(block_global_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     thread_k_cluster_id * XSrcVectorSize)};
            },
            Number<NumInput>{});

        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                  AccDataType,
                                                                  GridDesc_M_K,
                                                                  decltype(thread_buffer_desc_m_k),
                                                                  ThreadBufferLengths_M_K,
                                                                  ThreadBufferDimAccessOrder,
                                                                  XSrcVectorDim,
                                                                  XSrcVectorSize,
                                                                  1,
                                                                  true>(
            x_grid_desc_m_k,
            make_multi_index(thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * XSrcVectorSize));

        auto threadwise_gamma_load =
            ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                             AccDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             GammaSrcVectorDim,
                                             GammaSrcVectorSize,
                                             1,
                                             true>(
                gamma_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * GammaSrcVectorSize));

        auto threadwise_beta_load =
            ThreadwiseTensorSliceTransfer_v2<BetaDataType,
                                             AccDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             BetaSrcVectorDim,
                                             BetaSrcVectorSize,
                                             1,
                                             true>(
                beta_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * BetaSrcVectorSize));

        using PassThrough = tensor_operation::element_wise::PassThrough;
        PassThrough pass_through_op;

        auto threadwise_x_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               XDataType,
                                               decltype(thread_buffer_desc_m_k),
                                               GridDesc_M_K,
                                               YElementwiseOperation,
                                               ThreadBufferLengths_M_K,
                                               ThreadBufferDimAccessOrder,
                                               XSrcVectorDim,
                                               XSrcVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                x_grid_desc_m_k,
                make_multi_index(thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * XSrcVectorSize),
                pass_through_op);

        auto threadwise_y_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               YDataType,
                                               decltype(thread_buffer_desc_m_k),
                                               GridDesc_M_K,
                                               YElementwiseOperation,
                                               ThreadBufferLengths_M_K,
                                               ThreadBufferDimAccessOrder,
                                               YDstVectorDim,
                                               YDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                y_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * YDstVectorSize),
                y_elementwise_op);

        // Copy x from Cache
        // one pass: fwd, second pass: bwd
        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
        constexpr auto thread_copy_bwd_step_m_k =
            make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);

        const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());

        const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());

        auto threadwise_welford       = ThreadwiseWelford();
        threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
            var_thread_buf(I)  = type_convert<AccDataType>(0.0f);
        });

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                static_for<0, NumInput, 1>{}([&](auto I) {
                    // input load loop
                    in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I],
                                                in_global_buf_tuple[I],
                                                thread_buffer_desc_m_k,
                                                make_tuple(I0, I0),
                                                in_thread_buf_tuple(iK0)(I));

                    in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I],
                                                               thread_copy_fwd_step_m_k);
                });

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // input add loop
                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

                        // get reference to in data
                        const auto in_data_refs = generate_tie(
                            // return type should be lvalue
                            [&](auto I) -> const auto& {
                                return in_thread_buf_tuple(iK0)(I)(Number<offset_m_k>{});
                            },
                            Number<NumInput>{});

                        // get reference to dst data
                        auto out_data_refs = generate_tie(
                            // return type should be lvalue
                            [&](auto) -> auto& { return x_thread_buf(iK0)(Number<offset_m_k>{}); },
                            I1);

                        unpack2(x_elementwise_op, out_data_refs, in_data_refs);
                    });
                });

                threadwise_welford.Run(x_thread_buf[iK0], mean_thread_buf, var_thread_buf);

                if constexpr(!SweepOnce)
                {
                    threadwise_x_store.Run(thread_buffer_desc_m_k,
                                           make_tuple(I0, I0),
                                           x_thread_buf(iK0),
                                           x_grid_desc_m_k,
                                           x_lds_val_buf);
                    threadwise_x_store.MoveDstSliceWindow(x_grid_desc_m_k,
                                                          thread_copy_fwd_step_m_k);
                }
            });
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            int count = threadwise_welford.cur_count_;
            BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
        });

        auto thread_copy_tail_m_k =
            (num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k;

        if constexpr(!SweepOnce)
            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail_m_k);
        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
        threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            if constexpr(!SweepOnce)
            {
                static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
                    threadwise_x_load.Run(x_grid_desc_m_k,
                                          x_lds_val_buf,
                                          thread_buffer_desc_m_k,
                                          make_tuple(I0, I0),
                                          x_thread_buf(i));
                    threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k,
                                                         thread_copy_fwd_step_m_k);
                });
            }

            static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                          gamma_global_val_buf,
                                          thread_buffer_desc_m_k,
                                          make_tuple(I0, I0),
                                          gamma_thread_buf(i));
                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                         thread_copy_fwd_step_m_k);
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
                static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

                        // normalize
                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
                            (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
                            divisor;

                        // gamma
                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
                            y_thread_buf(iK0)(Number<offset_m_k>{}) *
                            gamma_thread_buf(iK0)(Number<offset_m_k>{});
                    });
                });
            });

            static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_beta_load.Run(beta_grid_desc_m_k,
                                         beta_global_val_buf,
                                         thread_buffer_desc_m_k,
                                         make_tuple(I0, I0),
                                         beta_thread_buf(i));
                threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
                                                        thread_copy_fwd_step_m_k);
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

                        // beta
                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
                            y_thread_buf(iK0)(Number<offset_m_k>{}) +
                            beta_thread_buf(iK0)(Number<offset_m_k>{});
                    });
                });
            });

            static_for<0, YThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_y_store.Run(thread_buffer_desc_m_k,
                                       make_tuple(I0, I0),
                                       y_thread_buf(i),
                                       y_grid_desc_m_k,
                                       y_global_val_buf);
                threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
            });

            if constexpr(!SweepOnce)
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k,
                                                     2 * thread_copy_bwd_step_m_k);
            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                     2 * thread_copy_bwd_step_m_k);
            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
                                                    2 * thread_copy_bwd_step_m_k);
            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
        }
    }
};

} // namespace ck
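The kernel above accumulates mean and variance with Welford's online algorithm, which is what the ThreadwiseWelford/BlockwiseWelford pair encapsulates: each thread updates a running mean and sum of squared deviations as it streams its slice of K, and the per-thread partials are then merged across the block. A minimal scalar sketch of the update and merge steps those helpers implement (illustrative only, not the CK device code) is:

#include <cstdio>

struct Welford
{
    int count  = 0;
    float mean = 0.f;
    float m2   = 0.f; // sum of squared deviations from the running mean

    // Per-element update (what a threadwise pass does along K).
    void push(float x)
    {
        ++count;
        float delta = x - mean;
        mean += delta / count;
        m2 += delta * (x - mean);
    }

    // Merge two partial results (what the blockwise reduction does).
    static Welford merge(const Welford& a, const Welford& b)
    {
        Welford r;
        r.count = a.count + b.count;
        if(r.count == 0)
            return r;
        float delta = b.mean - a.mean;
        r.mean      = a.mean + delta * b.count / r.count;
        r.m2        = a.m2 + b.m2 + delta * delta * a.count * b.count / r.count;
        return r;
    }

    float variance() const { return count > 0 ? m2 / count : 0.f; }
};

int main()
{
    // Two "threads" each see part of one row; the merged result matches a single pass.
    Welford t0, t1;
    for(float x : {1.f, 2.f, 3.f, 4.f})
        t0.push(x);
    for(float x : {5.f, 6.f, 7.f, 8.f})
        t1.push(x);
    Welford row = Welford::merge(t0, t1);
    std::printf("mean=%f var=%f\n", row.mean, row.variance()); // mean=4.5, var=5.25
    return 0;
}

The second pass of the kernel then applies exactly y = (x - mean) / sqrt(var + epsilon) * gamma + beta, which is the divisor/gamma/beta sequence in the inner static_for loops.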
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp

@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...
@@ -169,7 +169,8 @@ template <typename FloatAB,
           typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
           index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
           index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
-          LoopScheduler LoopSched>
+          LoopScheduler LoopSched,
+          PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -189,7 +190,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
...
@@ -526,7 +528,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
+            GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp

@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...
@@ -68,7 +68,8 @@ template <typename FloatAB,
           typename CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
           index_t CDEReduceThreadTransferScalarPerVector_NPerBlock,
           index_t RThreadTransferDstScalarPerVector_MPerBlock,
-          LoopScheduler LoopSched>
+          LoopScheduler LoopSched,
+          PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 {
     static constexpr index_t NumDTensor = DsDataType::Size();
...
@@ -91,7 +92,8 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
...
@@ -495,7 +497,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
+            GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp

@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
...
@@ -66,7 +66,8 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
-          LoopScheduler LoopSched>
+          LoopScheduler LoopSched,
+          PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseGemmMultipleD_xdl_cshuffle
 {
     static constexpr index_t NumDTensor = DsDataType::Size();
...
@@ -88,7 +89,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
...
@@ -489,7 +491,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
+            GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"

namespace ck {

enum struct PipelineVersion
{
    v1,
    v2,
};

template <PipelineVersion PipelineVer,
          index_t NumPrefetch     = 1,
          LoopScheduler LoopSched = LoopScheduler::Default>
constexpr auto GridwiseGemmPipeline_Selector()
{
    if constexpr(PipelineVer == PipelineVersion::v1)
    {
        if constexpr(LoopSched == LoopScheduler::Default)
        {
            return GridwiseGemmPipeline_v1<NumPrefetch>{};
        }
        else if constexpr(LoopSched == LoopScheduler::Interwave)
        {
            return GridwiseGemmPipelineInterwave_v1<NumPrefetch>{};
        }
    }
    else if constexpr(PipelineVer == PipelineVersion::v2)
    {
        return GridwiseGemmPipeline_v2{};
    }
    else
    {
        std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
    }
}

} // namespace ck
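GridwiseGemmPipeline_Selector returns objects of different types from the same constexpr function depending on the template arguments, which is why the callers throughout this commit spell the pipeline type as remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<...>())>. The following is a small standalone sketch of that compile-time dispatch pattern with toy types (not the CK pipelines themselves):

#include <type_traits>

enum struct Version
{
    v1,
    v2,
};

struct PipelineV1
{
    static constexpr int prefetch_stages = 1;
};

struct PipelineV2
{
    static constexpr int prefetch_stages = 2;
};

// Each if-constexpr branch may return a different type; the return type is
// deduced separately for every instantiation of the function template.
template <Version Ver>
constexpr auto select_pipeline()
{
    if constexpr(Ver == Version::v1)
    {
        return PipelineV1{};
    }
    else
    {
        return PipelineV2{};
    }
}

// Callers recover the selected type the same way the gridwise GEMMs do.
using PipeA = std::remove_cv_t<std::remove_reference_t<decltype(select_pipeline<Version::v1>())>>;
using PipeB = std::remove_cv_t<std::remove_reference_t<decltype(select_pipeline<Version::v2>())>>;

static_assert(std::is_same_v<PipeA, PipelineV1>);
static_assert(PipeB::prefetch_stages == 2);

int main() { return 0; }

The selection happens entirely at compile time, so adding PipelineVer as a defaulted template parameter (PipelineVersion::v1) leaves every existing instantiation unchanged while letting callers opt into the v2 pipeline.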
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp

@@ -352,6 +352,7 @@ struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
 {
 };

+// TODO: deprecate as GridwiseGemmPipeline_Selector covers the functionality
 template <index_t NumPrefetch, LoopScheduler LoopSched>
 constexpr auto GridwiseGemmPipeline_v1_Selector()
 {
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp

@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...
@@ -142,7 +142,8 @@ template <typename FloatAB,
           typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
           index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
           index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
-          LoopScheduler LoopSched>
+          LoopScheduler LoopSched,
+          PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -162,7 +163,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
...
@@ -481,7 +483,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
+            GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp

@@ -8,8 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...
@@ -115,7 +114,8 @@ template <typename FloatAB,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
-          LoopScheduler LoopSched>
+          LoopScheduler LoopSched,
+          PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -136,13 +136,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
-    using GridwiseGemmPipe =
-#if 1
-        remove_cvref_t<decltype(GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>())>;
-#else
-        GridwiseGemmPipeline_v2;
-#endif
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp

@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...
@@ -151,7 +151,8 @@ template <typename FloatAB,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
           index_t CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
-          LoopScheduler LoopSched>
+          LoopScheduler LoopSched,
+          PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -171,7 +172,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
...
@@ -519,7 +521,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
+            GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp

@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...
@@ -243,7 +243,8 @@ template <index_t BlockSize,
           typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           bool ABlockLdsExtraM1Wrw      = false,
           bool BBlockLdsExtraN1Wrw      = false,
-          index_t NumGemmKPrefetchStage = 1>
+          index_t NumGemmKPrefetchStage = 1,
+          PipelineVersion PipelineVer   = PipelineVersion::v1>
 struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -258,8 +259,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};

-    using ThisThreadBlock  = ThisThreadBlock<BlockSize>;
-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     // M0/M1/M1Padding
     static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{};
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp

@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
...
@@ -109,7 +109,9 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector,
-          index_t NumGemmKPrefetchStage = 1>
+          index_t NumGemmKPrefetchStage = 1,
+          LoopScheduler LoopSched       = make_default_loop_scheduler(),
+          PipelineVersion PipelineVer   = PipelineVersion::v1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -126,7 +128,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
...
@@ -423,18 +426,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         // register
         // sanity check
-        auto blockwise_gemm =
-            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
-                                                                FloatAB,
-                                                                FloatAcc,
-                                                                decltype(a_block_desc_k0_m_k1),
-                                                                decltype(b_block_desc_k0_n_k1),
-                                                                MPerXDL,
-                                                                NPerXDL,
-                                                                MXdlPerWave,
-                                                                NXdlPerWave,
-                                                                K1>{};
+        auto blockwise_gemm =
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<BlockSize,
+                                                                      FloatAB,
+                                                                      FloatAcc,
+                                                                      decltype(a_block_desc_k0_m_k1),
+                                                                      decltype(b_block_desc_k0_n_k1),
+                                                                      MPerXDL,
+                                                                      NPerXDL,
+                                                                      MXdlPerWave,
+                                                                      NXdlPerWave,
+                                                                      K1,
+                                                                      LoopSched>();

         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp

@@ -9,7 +9,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...
@@ -117,7 +117,8 @@ template <
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
           index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
-          index_t NumGemmKPrefetchStage = 1>
+          index_t NumGemmKPrefetchStage = 1,
+          PipelineVersion PipelineVer   = PipelineVersion::v1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -137,7 +138,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp

@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp"
...
@@ -123,7 +123,8 @@ template <
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
           index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
-          index_t NumGemmKPrefetchStage = 1>
+          index_t NumGemmKPrefetchStage = 1,
+          PipelineVersion PipelineVer   = PipelineVersion::v1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -140,7 +141,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp

@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp"
...
@@ -132,7 +132,8 @@ template <
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
           index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
-          index_t NumGemmKPrefetchStage = 1>
+          index_t NumGemmKPrefetchStage = 1,
+          PipelineVersion PipelineVer   = PipelineVersion::v1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -149,7 +150,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<
+        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
...
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp → include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp (renamed)

@@ -14,7 +14,7 @@
 namespace ck {

-// Y = LayerNorm(X, Beta, Gamma)
+// Y = Normalization(X, Beta, Gamma)
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
...
@@ -36,7 +36,7 @@ template <typename XDataType,
           index_t YDstVectorDim,
           index_t YDstVectorSize,
           bool SweepOnce>
-struct GridwiseLayernormNaiveVariance_mk_to_mk
+struct GridwiseNormalizationNaiveVariance_mk_to_mk
 {
     static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                       (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
...
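For reference, the renamed naive-variance kernel and the new Welford-variance kernel in this commit compute the same row-wise normalization; they differ only in how the mean and variance are accumulated over K (a direct sum-based formula versus Welford updates). In the usual notation, per row m with reduction over k:

% LaTeX sketch of the normalization both kernels apply
\mu_m      = \frac{1}{K}\sum_{k} x_{m,k}, \qquad
\sigma_m^2 = \frac{1}{K}\sum_{k} \left(x_{m,k} - \mu_m\right)^2, \qquad
y_{m,k}    = \gamma_k \cdot \frac{x_{m,k} - \mu_m}{\sqrt{\sigma_m^2 + \epsilon}} + \beta_k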