gaoqiong / composable_kernel · Commits

Commit 24af0144 (unverified)
Authored Nov 12, 2022 by Po Yen Chen; committed by GitHub, Nov 12, 2022

Merge branch 'develop' into gemm_layernorm_welford

Parents: 961f5e9e, b79bbbc2
Showing 20 changed files with 1175 additions and 177 deletions (+1175 −177):

- include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp (+67 −114)
- include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp (+482 −0)
- include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp (+500 −0)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp (+6 −4)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp (+2 −1)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp (+6 −4)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp (+6 −4)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp (+43 −0)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp (+1 −0)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp (+6 −4)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp (+5 −10)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp (+6 −4)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp (+7 −4)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp (+18 −15)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp (+5 −3)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp (+5 −3)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp (+5 −3)
- include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp (+2 −2)
- include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp (+2 −2)
- include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp (+1 −0)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp

```diff
@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
```
```diff
@@ -77,7 +77,8 @@ template <typename FloatAB,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched,
           bool PadN,
-          bool MaskOutUpperTriangle>
+          bool MaskOutUpperTriangle,
+          PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
 {
     static_assert(LoopSched == LoopScheduler::Default,
```
```diff
@@ -108,7 +109,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<decltype(
+        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     template <typename ABlockDesc_AK0_M_AK1>
     __host__ __device__ static constexpr auto
```
```diff
@@ -336,36 +338,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
     };

-    template <bool Pred>
-    struct ElementOpPredicatedResetNaNToMinusInf;
-
-    template <>
-    struct ElementOpPredicatedResetNaNToMinusInf<true>
-    {
-        template <typename ElementOp, typename OutT, typename InT>
-        __host__ __device__ void Run(OutT& y, const ElementOp& op, const InT& x)
-        {
-            if(ck::math::isnan(x))
-            {
-                y = -ck::NumericLimits<float>::Infinity();
-            }
-            else
-            {
-                op(y, x);
-            }
-        }
-    };
-
-    template <>
-    struct ElementOpPredicatedResetNaNToMinusInf<false>
-    {
-        template <typename ElementOp, typename OutT, typename InT>
-        __host__ __device__ void Run(OutT& y, const ElementOp& op, const InT& x)
-        {
-            op(y, x);
-        }
-    };
-
     template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
```
```diff
@@ -406,11 +378,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             return;
         }

-        // HACK: this force m/n_block_data_idx_on_grid into SGPR
+        // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

-        const index_t n_block_data_idx_on_grid =
+        const index_t gemm1_n_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);

         // A matrix in LDS memory, dst of blockwise copy
```
```diff
@@ -533,8 +505,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopScheduler::Default>();
+        // Only supports LoopScheduler::Default
+        const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
+                                                                          NumGemmKPrefetchStage,
+                                                                          LoopScheduler::Default>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
```
```diff
@@ -627,7 +600,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 true, // DstResetCoord
                 NumGemmKPrefetchStage>(b1_grid_desc_bk0_n_bk1,
-                                       make_multi_index(0, n_block_data_idx_on_grid, 0),
+                                       make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
                                        b1_element_op,
                                        b1_block_desc_bk0_n_bk1,
                                        make_multi_index(0, 0, 0),
```
```diff
@@ -721,12 +694,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                FloatGemmAcc,
                                decltype(threadid_to_m_n_thread_cluster_adaptor),
                                decltype(thread_cluster_desc_m_n),
-                               decltype(thread_slice_desc_m_n)
-#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
-                               ,
-                               true
-#endif
-                               >{};
+                               decltype(thread_slice_desc_m_n)>{};

         const index_t num_gemm1_k_block_outer_loop =
             b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
```
```diff
@@ -745,29 +713,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         running_max     = NumericLimits<FloatGemmAcc>::Lowest();
         running_max_new = NumericLimits<FloatGemmAcc>::Lowest();

-        // decoder lower triangular mask
-        const auto thread_cluster_idx = threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(
-            make_multi_index(get_thread_local_1d_id()));
-        const auto thread_m_cluster_id = thread_cluster_idx[I0];
-        const auto thread_n_cluster_id = thread_cluster_idx[I1];
-        const index_t MPerRepeat       = MPerBlock / MXdlPerWave;
-        const index_t NPerRepeat       = NPerBlock / NXdlPerWave;
-        const index_t mstart           = m_block_data_idx_on_grid + thread_m_cluster_id;
-
         // gemm1 K loop
         index_t gemm1_k_block_outer_index = 0;
         do
         {
-            if constexpr(MaskOutUpperTriangle)
+            auto n_block_data_idx_on_grid =
+                __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
+            if(c0_matrix_mask.IsTileSkippable(
+                   m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
             {
-                auto gemm0_n_block_idx =
-                    __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
-                if(c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid, gemm0_n_block_idx) &&
-                   c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid + MPerBlock - 1,
-                                                  gemm0_n_block_idx))
-                {
-                    continue;
-                }
+                continue;
             }
             // gemm0
             gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
```
```diff
@@ -789,60 +744,58 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             // do MNK padding or upper triangular masking
             if constexpr(MaskOutUpperTriangle || PadN)
             {
-                const index_t nstart = gemm1_k_block_outer_index * NPerBlock;
-                static_for<0, m0, 1>{}([&](auto m0_i) {
-                    const index_t m_global   = mstart + m0_i * MPerRepeat;
-                    const index_t acc_idx_m0 = m0_i * n0 * n2 * n4;
-                    static_for<0, n0, 1>{}([&](auto n0_i) {
-                        // constexpr auto nrepeat_i = n0_i * NPerRepeat;
-                        // const index_t nstartxdl = nstart + nrepeat_i;
-                        const index_t nstartxdl  = nstart + n0_i * NPerRepeat;
-                        const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4;
-                        static_for<0, n2, 1>{}([&](auto n2_i) {
-                            const index_t nstartgroup =
-                                nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
-                            const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4;
-                            static_for<0, n4, 1>{}([&](auto n4_i) {
-                                const index_t n_global = nstartgroup + n4_i;
-                                const auto acc_offset  = Number<acc_idx_n2 + n4_i>{};
-                                if constexpr(MaskOutUpperTriangle)
-                                {
-                                    if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
-                                    {
-                                        acc_thread_buf(acc_offset) =
-                                            -ck::NumericLimits<float>::Infinity();
-                                    }
-                                    else
-                                    {
-                                        acc_element_op(acc_thread_buf(acc_offset),
-                                                       acc_thread_buf[acc_offset]);
-                                    }
-                                }
-                                else
-                                {
-                                    // ignore m_global;
-                                    if(c0_matrix_mask.IsNOutOfBound(n_global))
-                                    {
-                                        acc_thread_buf(acc_offset) =
-                                            -ck::NumericLimits<float>::Infinity();
-                                    }
-                                    else
-                                    {
-                                        acc_element_op(acc_thread_buf(acc_offset),
-                                                       acc_thread_buf[acc_offset]);
-                                    }
-                                }
-                            });
-                        });
-                    });
-                });
+                // 8d thread_desc in thread scope
+                constexpr auto c_thread_lengths =
+                    blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+
+                // 8d block_desc in block scope
+                constexpr auto c_block_lengths =
+                    blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+
+                constexpr auto M0 = c_block_lengths[I0];
+                constexpr auto N0 = c_block_lengths[I1];
+                constexpr auto M1 = c_block_lengths[I2];
+                constexpr auto N1 = c_block_lengths[I3];
+                constexpr auto M2 = c_block_lengths[I4];
+                constexpr auto N2 = c_block_lengths[I5];
+                constexpr auto N3 = c_block_lengths[I6];
+                constexpr auto N4 = c_block_lengths[I7];
+
+                // works like multi-dimension static_for (static_ford), but provides both the linear
+                // index as well as n-d index
+                using Acc0TileIterator = SpaceFillingCurve<
+                    decltype(c_thread_lengths),
+                    typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
+                    typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
+                    false>; // SnakeCurved
+
+                auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D(
+                    Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
+
+                constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
+                    make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)),
+                               make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}),
+                    make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
+
+                static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
+                    auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
+                    auto m_local =
+                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+                    auto n_local =
+                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+                    auto m_global = m_local + m_block_data_idx_on_grid;
+                    auto n_global = n_local + n_block_data_idx_on_grid;
+                    if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
+                    {
+                        acc_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
+                    }
+                    else
+                    {
+                        acc_element_op(acc_thread_buf(i), acc_thread_buf[i]);
+                    }
+                });
             }
             else
             {
                 static_for<0, acc_thread_buf.Size(), 1>{}(
                     [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
             }

             block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
```
include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp (new file, 0 → 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseBatchrNormForwardWithBlockwiseWelford_,
          typename XDataType, typename YDataType, typename AccDataType,
          typename ScaleDataType, typename BiasDataType, typename MeanVarDataType,
          typename YElementwiseOp, typename XYGridDesc_M_K,
          typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M,
          typename GetReduceCountPerThreadFunctor>
__global__ void kernel_batchnorm_forward_with_blockwise_welford(
    const XYGridDesc_M_K x_grid_desc_m_k,
    const XYGridDesc_M_K y_grid_desc_m_k,
    const ScaleBiasGridDesc_M scale_grid_desc_m,
    const ScaleBiasGridDesc_M bias_grid_desc_m,
    const MeanVarGridDesc_M mean_var_grid_desc_m,
    const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
    index_t num_k_block_tile_iteration,
    AccDataType epsilon,
    const XDataType* const __restrict__ p_x,
    const ScaleDataType* const __restrict__ p_scale,
    const BiasDataType* const __restrict__ p_bias,
    const YElementwiseOp y_elementwise_op,
    YDataType* const __restrict__ p_y,
    bool updateMovingAverage,
    AccDataType averageFactor,
    MeanVarDataType* const __restrict__ resultRunningMean,
    MeanVarDataType* const __restrict__ resultRunningVariance,
    bool saveMeanInvVariance,
    MeanVarDataType* const __restrict__ resultSaveMean,
    MeanVarDataType* const __restrict__ resultSaveInvVariance)
{
    GridwiseBatchrNormForwardWithBlockwiseWelford_::Run(x_grid_desc_m_k,
                                                        y_grid_desc_m_k,
                                                        scale_grid_desc_m,
                                                        bias_grid_desc_m,
                                                        mean_var_grid_desc_m,
                                                        get_reduce_count_per_thread,
                                                        num_k_block_tile_iteration,
                                                        epsilon,
                                                        p_x,
                                                        p_scale,
                                                        p_bias,
                                                        y_elementwise_op,
                                                        p_y,
                                                        updateMovingAverage,
                                                        averageFactor,
                                                        resultRunningMean,
                                                        resultRunningVariance,
                                                        saveMeanInvVariance,
                                                        resultSaveMean,
                                                        resultSaveInvVariance);
};

template <typename XDataType, typename YDataType, typename AccDataType,
          typename ScaleDataType, typename BiasDataType, typename MeanVarDataType,
          typename YElementwiseOp, typename XYGridDesc_M_K,
          typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M,
          typename GetReduceCountPerThreadFunctor,
          index_t BlockSize,
          index_t MThreadClusterSize, index_t KThreadClusterSize,
          index_t MThreadSliceSize, index_t KThreadSliceSize,
          index_t XSrcYDstVectorDim,
          index_t XSrcVectorSize, index_t YDstVectorSize,
          index_t ScaleSrcVectorSize, index_t BiasSrcVectorSize,
          index_t MeanVarSrcDstVectorSize>
struct GridwiseBatchNormForwardWithBlockwiseWelford
{
    static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
                      (XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;

    using BlockwiseWelford =
        BlockwiseWelford<AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
                               const XYGridDesc_M_K& y_grid_desc_m_k,
                               const ScaleBiasGridDesc_M& scale_grid_desc_m,
                               const ScaleBiasGridDesc_M& bias_grid_desc_m,
                               const MeanVarGridDesc_M& mean_var_grid_desc_m,
                               const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
                               index_t num_k_block_tile_iteration,
                               AccDataType epsilon,
                               const XDataType* const __restrict__ p_x,
                               const ScaleDataType* const __restrict__ p_scale,
                               const BiasDataType* const __restrict__ p_bias,
                               const YElementwiseOp y_elementwise_op,
                               YDataType* const __restrict__ p_y,
                               bool updateMovingAverage,
                               AccDataType averageFactor,
                               MeanVarDataType* const __restrict__ resultRunningMean,
                               MeanVarDataType* const __restrict__ resultRunningVariance,
                               bool saveMeanInvVariance,
                               MeanVarDataType* const __restrict__ resultSaveMean,
                               MeanVarDataType* const __restrict__ resultSaveInvVariance)
    {
        using ck::math::sqrt;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            x_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> bias_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            y_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
        using ThreadBufferLengths_M   = Sequence<MThreadSliceSize>;

        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

        auto threadwise_x_load =
            ThreadwiseTensorSliceTransfer_v2<XDataType, AccDataType, XYGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K, ThreadBufferDimAccessOrder,
                                             XSrcYDstVectorDim, XSrcVectorSize, 1, true>(
                x_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_y_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType, YDataType,
                                               decltype(thread_buffer_desc_m_k), XYGridDesc_M_K,
                                               YElementwiseOp, ThreadBufferLengths_M_K,
                                               ThreadBufferDimAccessOrder, XSrcYDstVectorDim,
                                               YDstVectorSize, InMemoryDataOperationEnum::Set,
                                               1, true>(
                y_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize),
                y_elementwise_op);

        auto threadwise_scale_load =
            ThreadwiseTensorSliceTransfer_v2<ScaleDataType, AccDataType, ScaleBiasGridDesc_M,
                                             decltype(thread_buffer_desc_m), ThreadBufferLengths_M,
                                             Sequence<0>, 0, ScaleSrcVectorSize, 1, true>(
                scale_grid_desc_m,
                make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize));

        auto threadwise_bias_load =
            ThreadwiseTensorSliceTransfer_v2<BiasDataType, AccDataType, ScaleBiasGridDesc_M,
                                             decltype(thread_buffer_desc_m), ThreadBufferLengths_M,
                                             Sequence<0>, 0, BiasSrcVectorSize, 1, true>(
                bias_grid_desc_m,
                make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize));

        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
        constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x, x_grid_desc_m_k.GetElementSpaceSize());
        const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_scale, scale_grid_desc_m.GetElementSpaceSize());
        const auto bias_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_bias, bias_grid_desc_m.GetElementSpaceSize());
        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y, y_grid_desc_m_k.GetElementSpaceSize());

        // Step 1: do welford reduction to get mean and variance

        auto threadwise_welford       = ThreadwiseWelford();
        threadwise_welford.max_count_ = get_reduce_count_per_thread(thread_k_cluster_id);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
            var_thread_buf(I)  = type_convert<AccDataType>(0.0f);
        });

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf, thread_buffer_desc_m_k,
                                  make_tuple(I0, I0), x_thread_buf);
            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
            threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            int count = threadwise_welford.cur_count_;
            BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
        });

        // Step 2: do normalization and output y

        threadwise_scale_load.Run(scale_grid_desc_m, scale_global_val_buf, thread_buffer_desc_m,
                                  make_tuple(I0), scale_thread_buf);
        threadwise_bias_load.Run(bias_grid_desc_m, bias_global_val_buf, thread_buffer_desc_m,
                                 make_tuple(I0), bias_thread_buf);

        auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;

        threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
        threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf, thread_buffer_desc_m_k,
                                  make_tuple(I0, I0), x_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                AccDataType multiplier =
                    scale_thread_buf[Number<iM>{}] / sqrt(var_thread_buf[iM] + epsilon);

                AccDataType fused_mean_bias =
                    bias_thread_buf[Number<iM>{}] - mean_thread_buf[iM] * multiplier;

                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset =
                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                    // normalize
                    y_thread_buf(Number<offset>{}) =
                        x_thread_buf[Number<offset>{}] * multiplier + fused_mean_bias;
                });
            });

            threadwise_y_store.Run(thread_buffer_desc_m_k, make_tuple(I0, I0), y_thread_buf,
                                   y_grid_desc_m_k, y_global_val_buf);

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
        }

        // Step 3: update the moving average of mean and variance (optional)

        if(updateMovingAverage && thread_k_cluster_id == 0)
        {
            StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
                running_mean_thread_buf;
            StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
                running_var_thread_buf;

            auto running_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize());
            auto running_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize());

            auto threadwise_mean_var_load =
                ThreadwiseTensorSliceTransfer_v2<MeanVarDataType, AccDataType, MeanVarGridDesc_M,
                                                 decltype(thread_buffer_desc_m),
                                                 ThreadBufferLengths_M, Sequence<0>, 0,
                                                 MeanVarSrcDstVectorSize, 1, true>(
                    mean_var_grid_desc_m,
                    make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize));

            threadwise_mean_var_load.Run(mean_var_grid_desc_m, running_mean_global_buf,
                                         thread_buffer_desc_m, make_tuple(I0),
                                         running_mean_thread_buf);
            threadwise_mean_var_load.Run(mean_var_grid_desc_m, running_var_global_buf,
                                         thread_buffer_desc_m, make_tuple(I0),
                                         running_var_thread_buf);

            AccDataType oneMinusAverageFactor = type_convert<AccDataType>(1.0) - averageFactor;

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor +
                                             mean_thread_buf[I] * averageFactor;
                running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor +
                                            var_thread_buf[I] * averageFactor;
            });

            auto threadwise_mean_var_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType, MeanVarDataType,
                                                   decltype(thread_buffer_desc_m),
                                                   MeanVarGridDesc_M, PassThroughOp,
                                                   ThreadBufferLengths_M, Sequence<0>, 0,
                                                   MeanVarSrcDstVectorSize,
                                                   InMemoryDataOperationEnum::Set, 1, true>(
                    mean_var_grid_desc_m,
                    make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp{});

            threadwise_mean_var_store.Run(thread_buffer_desc_m, make_tuple(I0),
                                          running_mean_thread_buf, mean_var_grid_desc_m,
                                          running_mean_global_buf);
            threadwise_mean_var_store.Run(thread_buffer_desc_m, make_tuple(I0),
                                          running_var_thread_buf, mean_var_grid_desc_m,
                                          running_var_global_buf);
        };

        // Step 4: save mean and inv-variance (optional)

        if(saveMeanInvVariance && thread_k_cluster_id == 0)
        {
            auto result_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize());
            auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                var_thread_buf(I) =
                    type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
            });

            auto threadwise_mean_inv_var_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType, MeanVarDataType,
                                                   decltype(thread_buffer_desc_m),
                                                   MeanVarGridDesc_M, PassThroughOp,
                                                   ThreadBufferLengths_M, Sequence<0>, 0,
                                                   MeanVarSrcDstVectorSize,
                                                   InMemoryDataOperationEnum::Set, 1, true>(
                    mean_var_grid_desc_m,
                    make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp{});

            threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, make_tuple(I0),
                                              mean_thread_buf, mean_var_grid_desc_m,
                                              result_mean_global_buf);
            threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, make_tuple(I0),
                                              var_thread_buf, mean_var_grid_desc_m,
                                              result_inv_var_global_buf);
        };
    }
};

} // namespace ck
```
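The ThreadwiseWelford/BlockwiseWelford helpers used above implement Welford's online mean/variance recurrences; their internals are not shown in this diff, so here is a minimal standalone sketch of the same recurrences with illustrative names (not the ck implementation):

```cpp
// Sketch of Welford's online mean/variance, assuming scalar float state.
// The ck kernels keep one such state per M-slice and merge states across the K thread cluster.
struct WelfordState
{
    float mean = 0.f;
    float m2   = 0.f; // running sum of squared deviations
    int count  = 0;

    // per-element update: mean_n = mean_{n-1} + (x - mean_{n-1}) / n
    void push(float x)
    {
        ++count;
        float delta = x - mean;
        mean += delta / count;
        m2 += delta * (x - mean);
    }

    // pairwise merge, as a block-wide reduction would do across threads
    void merge(const WelfordState& other)
    {
        int n = count + other.count;
        if(n == 0)
            return;
        float delta = other.mean - mean;
        mean += delta * other.count / n;
        m2 += other.m2 + delta * delta * (float(count) * other.count / n);
        count = n;
    }

    // biased (population) variance, matching what batchnorm consumes
    float variance() const { return count > 0 ? m2 / count : 0.f; }
};
```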
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp (new file, 0 → 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

// X = Elementwise(input1, input2, input3, ...)
// Y = Normalization(X, beta, gamma)
template <typename InDataTypePointerTuple,
          typename XDataType, typename GammaDataType, typename BetaDataType,
          typename YDataType, typename AccDataType,
          typename XElementwiseOperation, typename YElementwiseOperation,
          typename InGrid2dDescTuple, typename GridDesc_M_K,
          index_t BlockSize,
          index_t MThreadClusterSize, index_t KThreadClusterSize,
          index_t MThreadSliceSize, index_t KThreadSliceSize,
          index_t XSrcVectorDim, index_t XSrcVectorSize,
          index_t GammaSrcVectorDim, index_t GammaSrcVectorSize,
          index_t BetaSrcVectorDim, index_t BetaSrcVectorSize,
          index_t YDstVectorDim, index_t YDstVectorSize,
          bool SweepOnce>
struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
{
    static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
                      (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr index_t NumInput = InDataTypePointerTuple::Size();

    static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;

    using BlockwiseWelford =
        BlockwiseWelford<AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr index_t M_BlockTileSize     = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize     = KThreadClusterSize * KThreadSliceSize;
    static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;

    static constexpr auto XThreadBufferNumber     = Number<KThreadSliceSize / XSrcVectorSize>{};
    static constexpr auto GammaThreadBufferNumber = Number<KThreadSliceSize / GammaSrcVectorSize>{};
    static constexpr auto BetaThreadBufferNumber  = Number<KThreadSliceSize / BetaSrcVectorSize>{};
    static constexpr auto YThreadBufferNumber     = Number<KThreadSliceSize / YDstVectorSize>{};

    __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
                                        int thread_k_cluster_id)
    {
        int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
        int kPerThread =
            kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
        int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;

        if(kPerBlockTail > 0)
        {
            static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
                int thread_max_len =
                    (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
                int delta = thread_max_len - kPerBlockTail;
                delta     = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
                kPerThread += XSrcVectorSize - delta;
            });
        }

        return kPerThread;
    }

    __device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple,
                               const GridDesc_M_K& x_grid_desc_m_k,
                               const GridDesc_M_K& gamma_grid_desc_m_k,
                               const GridDesc_M_K& beta_grid_desc_m_k,
                               const GridDesc_M_K& y_grid_desc_m_k,
                               index_t num_k_block_tile_iteration,
                               AccDataType epsilon,
                               const InDataTypePointerTuple p_in_global_tuple,
                               XDataType* const __restrict__ p_x_lds,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global,
                               const XElementwiseOperation x_elementwise_op,
                               const YElementwiseOperation y_elementwise_op)
    {
        if constexpr(SweepOnce)
        {
            num_k_block_tile_iteration = 1;
        }

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t grid_size       = get_grid_size();

        auto in_global_buf_tuple = generate_tuple(
            [&](auto I) {
                static_assert(in_grid_2d_desc_tuple[I].GetNumOfDimension() == 2); // matrix dimension

                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumInput>{});

        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

        auto x_lds_val_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_x_lds, x_grid_desc_m_k.GetElementSpaceSize() / grid_size);

        auto in_thread_buf_tuple = generate_tuple(
            [&](auto) {
                return generate_tuple(
                    [&](auto) {
                        return StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType,
                                            MThreadSliceSize * XSrcVectorSize, true>{};
                    },
                    Number<NumInput>{});
            },
            Number<XThreadBufferNumber>{});

        auto x_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType,
                                    MThreadSliceSize * XSrcVectorSize, true>{};
            },
            Number<XThreadBufferNumber>{});

        auto gamma_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType,
                                    MThreadSliceSize * GammaSrcVectorSize, true>{};
            },
            Number<GammaThreadBufferNumber>{});

        auto beta_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType,
                                    MThreadSliceSize * BetaSrcVectorSize, true>{};
            },
            Number<BetaThreadBufferNumber>{});

        auto y_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType,
                                    MThreadSliceSize * YDstVectorSize, true>{};
            },
            Number<YThreadBufferNumber>{});

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;

        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

        auto in_global_load_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;

                return ThreadwiseTensorSliceTransfer_v2<DataType, AccDataType,
                                                        decltype(in_grid_2d_desc_tuple[I]),
                                                        decltype(thread_buffer_desc_m_k),
                                                        ThreadBufferLengths_M_K,
                                                        ThreadBufferDimAccessOrder,
                                                        XSrcVectorDim, XSrcVectorSize, 1, false>{
                    in_grid_2d_desc_tuple[I],
                    make_multi_index(block_global_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     thread_k_cluster_id * XSrcVectorSize)};
            },
            Number<NumInput>{});

        auto threadwise_x_load =
            ThreadwiseTensorSliceTransfer_v2<XDataType, AccDataType, GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K, ThreadBufferDimAccessOrder,
                                             XSrcVectorDim, XSrcVectorSize, 1, true>(
                x_grid_desc_m_k,
                make_multi_index(thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * XSrcVectorSize));

        auto threadwise_gamma_load =
            ThreadwiseTensorSliceTransfer_v2<GammaDataType, AccDataType, GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K, ThreadBufferDimAccessOrder,
                                             GammaSrcVectorDim, GammaSrcVectorSize, 1, true>(
                gamma_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * GammaSrcVectorSize));

        auto threadwise_beta_load =
            ThreadwiseTensorSliceTransfer_v2<BetaDataType, AccDataType, GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K, ThreadBufferDimAccessOrder,
                                             BetaSrcVectorDim, BetaSrcVectorSize, 1, true>(
                beta_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * BetaSrcVectorSize));

        using PassThrough = tensor_operation::element_wise::PassThrough;
        PassThrough pass_through_op;

        auto threadwise_x_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType, XDataType,
                                               decltype(thread_buffer_desc_m_k), GridDesc_M_K,
                                               PassThrough, ThreadBufferLengths_M_K,
                                               ThreadBufferDimAccessOrder, XSrcVectorDim,
                                               XSrcVectorSize, InMemoryDataOperationEnum::Set,
                                               1, true>(
                x_grid_desc_m_k,
                make_multi_index(thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * XSrcVectorSize),
                pass_through_op);

        auto threadwise_y_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType, YDataType,
                                               decltype(thread_buffer_desc_m_k), GridDesc_M_K,
                                               YElementwiseOperation, ThreadBufferLengths_M_K,
                                               ThreadBufferDimAccessOrder, YDstVectorDim,
                                               YDstVectorSize, InMemoryDataOperationEnum::Set,
                                               1, true>(
                y_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * YDstVectorSize),
                y_elementwise_op);

        // Copy x from Cache
        // one pass: fwd, second pass: bwd
        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
        constexpr auto thread_copy_bwd_step_m_k =
            make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);

        const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());

        const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());

        auto threadwise_welford       = ThreadwiseWelford();
        threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
            var_thread_buf(I)  = type_convert<AccDataType>(0.0f);
        });

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                static_for<0, NumInput, 1>{}([&](auto I) {
                    // input load loop
                    in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I],
                                                in_global_buf_tuple[I],
                                                thread_buffer_desc_m_k,
                                                make_tuple(I0, I0),
                                                in_thread_buf_tuple(iK0)(I));

                    in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I],
                                                               thread_copy_fwd_step_m_k);
                });

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // input add loop
                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

                        // get reference to in data
                        const auto in_data_refs = generate_tie(
                            // return type should be lvalue
                            [&](auto I) -> const auto& {
                                return in_thread_buf_tuple(iK0)(I)(Number<offset_m_k>{});
                            },
                            Number<NumInput>{});

                        // get reference to dst data
                        auto out_data_refs = generate_tie(
                            // return type should be lvalue
                            [&](auto) -> auto& { return x_thread_buf(iK0)(Number<offset_m_k>{}); },
                            I1);

                        unpack2(x_elementwise_op, out_data_refs, in_data_refs);
                    });
                });

                threadwise_welford.Run(x_thread_buf[iK0], mean_thread_buf, var_thread_buf);

                if constexpr(!SweepOnce)
                {
                    threadwise_x_store.Run(thread_buffer_desc_m_k, make_tuple(I0, I0),
                                           x_thread_buf(iK0), x_grid_desc_m_k, x_lds_val_buf);
                    threadwise_x_store.MoveDstSliceWindow(x_grid_desc_m_k,
                                                          thread_copy_fwd_step_m_k);
                }
            });
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            int count = threadwise_welford.cur_count_;
            BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
        });

        auto thread_copy_tail_m_k =
            (num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k;

        if constexpr(!SweepOnce)
            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail_m_k);
        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
        threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            if constexpr(!SweepOnce)
            {
                static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
                    threadwise_x_load.Run(x_grid_desc_m_k, x_lds_val_buf, thread_buffer_desc_m_k,
                                          make_tuple(I0, I0), x_thread_buf(i));
                    threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k,
                                                         thread_copy_fwd_step_m_k);
                });
            }

            static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_gamma_load.Run(gamma_grid_desc_m_k, gamma_global_val_buf,
                                          thread_buffer_desc_m_k, make_tuple(I0, I0),
                                          gamma_thread_buf(i));
                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                         thread_copy_fwd_step_m_k);
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
                static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

                        // normalize
                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
                            (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
                            divisor;

                        // gamma
                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
                            y_thread_buf(iK0)(Number<offset_m_k>{}) *
                            gamma_thread_buf(iK0)(Number<offset_m_k>{});
                    });
                });
            });

            static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_beta_load.Run(beta_grid_desc_m_k, beta_global_val_buf,
                                         thread_buffer_desc_m_k, make_tuple(I0, I0),
                                         beta_thread_buf(i));
                threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
                                                        thread_copy_fwd_step_m_k);
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

                        // beta
                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
                            y_thread_buf(iK0)(Number<offset_m_k>{}) +
                            beta_thread_buf(iK0)(Number<offset_m_k>{});
                    });
                });
            });

            static_for<0, YThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_y_store.Run(thread_buffer_desc_m_k, make_tuple(I0, I0),
                                       y_thread_buf(i), y_grid_desc_m_k, y_global_val_buf);
                threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
            });

            if constexpr(!SweepOnce)
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k,
                                                     2 * thread_copy_bwd_step_m_k);
            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                     2 * thread_copy_bwd_step_m_k);
            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
                                                    2 * thread_copy_bwd_step_m_k);
            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
        }
    }
};

} // namespace ck
```
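For reference, the per-element "normalize", "gamma", and "beta" steps in the second loop above compute the usual layer normalization over the K dimension (with gamma and beta indexed through the same M×K descriptors as x in this kernel):

```latex
\mu_m = \frac{1}{K}\sum_{k} x_{m,k}, \qquad
\sigma_m^2 = \frac{1}{K}\sum_{k} \left(x_{m,k}-\mu_m\right)^2, \qquad
y_{m,k} = \frac{x_{m,k}-\mu_m}{\sqrt{\sigma_m^2+\epsilon}}\,\gamma_{m,k} + \beta_{m,k}
```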
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp

```diff
@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -169,7 +169,8 @@ template <typename FloatAB,
           typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
           index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
           index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
-          LoopScheduler LoopSched>
+          LoopScheduler LoopSched,
+          PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 {
     static constexpr auto I0 = Number<0>{};
@@ -189,7 +190,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<decltype(
+        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
@@ -526,7 +528,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
+            GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
```
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp

```diff
@@ -66,6 +66,7 @@ template <index_t BlockSize,
           index_t MPerBlock,
           index_t NPerBlock,
           index_t K0PerBlock,
+          index_t K1Value,
           index_t M1PerThreadM111,
           index_t N1PerThreadN111,
           index_t KPerThread,
@@ -96,7 +97,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
     static constexpr auto I3 = Number<3>{};

     // K1 should be Number<...>
-    static constexpr auto K1 = AGridDesc_K0_M_K1{}.GetLength(I2);
+    static constexpr auto K1 = Number<K1Value>{};

     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
```
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp

```diff
@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -68,7 +68,8 @@ template <typename FloatAB,
           typename CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
           index_t CDEReduceThreadTransferScalarPerVector_NPerBlock,
           index_t RThreadTransferDstScalarPerVector_MPerBlock,
-          LoopScheduler LoopSched>
+          LoopScheduler LoopSched,
+          PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 {
     static constexpr index_t NumDTensor = DsDataType::Size();
@@ -91,7 +92,8 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<decltype(
+        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
@@ -495,7 +497,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
+            GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
```
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp

```diff
@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
@@ -66,7 +66,8 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
-          LoopScheduler LoopSched>
+          LoopScheduler LoopSched,
+          PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseGemmMultipleD_xdl_cshuffle
 {
     static constexpr index_t NumDTensor = DsDataType::Size();
@@ -88,7 +89,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<decltype(
+        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
@@ -489,7 +491,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
+            GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
```
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp (new file, 0 → 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"

namespace ck {

enum struct PipelineVersion
{
    v1,
    v2,
};

template <PipelineVersion PipelineVer,
          index_t NumPrefetch     = 1,
          LoopScheduler LoopSched = LoopScheduler::Default>
constexpr auto GridwiseGemmPipeline_Selector()
{
    if constexpr(PipelineVer == PipelineVersion::v1)
    {
        if constexpr(LoopSched == LoopScheduler::Default)
        {
            return GridwiseGemmPipeline_v1<NumPrefetch>{};
        }
        else if constexpr(LoopSched == LoopScheduler::Interwave)
        {
            return GridwiseGemmPipelineInterwave_v1<NumPrefetch>{};
        }
    }
    else if constexpr(PipelineVer == PipelineVersion::v2)
    {
        return GridwiseGemmPipeline_v2{};
    }
    else
    {
        std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
    }
}

} // namespace ck
```
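The gridwise GEMM structs in this commit resolve their pipeline type from this selector at compile time: the selector returns a value of the chosen pipeline type, and decltype plus remove_cvref_t turn that value back into a type alias. A minimal sketch of the consuming pattern (the wrapper struct name here is hypothetical, for illustration only; the alias mirrors the GridwiseGemmPipe changes shown in the diffs above):

```cpp
// Sketch: how a gridwise GEMM consumes GridwiseGemmPipeline_Selector.
template <PipelineVersion PipelineVer      = PipelineVersion::v1,
          index_t NumGemmKPrefetchStage    = 1,
          LoopScheduler LoopSched          = LoopScheduler::Default>
struct ExampleGridwiseGemm // hypothetical wrapper, not part of the commit
{
    // decltype(selector()) is the concrete pipeline type (v1, interwave v1, or v2);
    // remove_cvref_t strips the value category so it can be used as a plain type alias
    using GridwiseGemmPipe = remove_cvref_t<decltype(
        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
};
```

Because PipelineVer defaults to PipelineVersion::v1, existing instantiations that do not pass the new template parameter keep their previous behavior.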
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp

```diff
@@ -352,6 +352,7 @@ struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
 {
 };

+// TODO: deprecate as GridwiseGemmPipeline_Selector covers the functionality
 template <index_t NumPrefetch, LoopScheduler LoopSched>
 constexpr auto GridwiseGemmPipeline_v1_Selector()
 {
```
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp

```diff
@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -142,7 +142,8 @@ template <typename FloatAB,
           typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
           index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
           index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
-          LoopScheduler LoopSched>
+          LoopScheduler LoopSched,
+          PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 {
     static constexpr auto I0 = Number<0>{};
@@ -162,7 +163,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<decltype(
+        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
@@ -481,7 +483,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
+            GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
```
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
24af0144
...
...
@@ -8,8 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...
...
@@ -115,7 +114,8 @@ template <typename FloatAB,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
-         LoopScheduler LoopSched>
+         LoopScheduler LoopSched,
+         PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
    static constexpr auto I0 = Number<0>{};
...
...
@@ -136,13 +136,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-   // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
-   using GridwiseGemmPipe =
-#if 1
-       remove_cvref_t<decltype(GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>())>;
-#else
-       GridwiseGemmPipeline_v2;
-#endif
+   using GridwiseGemmPipe = remove_cvref_t<decltype(
+       GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
...
...
@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_
v1
.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_
selector
.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...
...
@@ -151,7 +151,8 @@ template <typename FloatAB,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
          index_t CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
-         LoopScheduler LoopSched>
+         LoopScheduler LoopSched,
+         PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
    static constexpr auto I0 = Number<0>{};
...
...
@@ -171,7 +172,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-   using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+   using GridwiseGemmPipe = remove_cvref_t<decltype(
+       GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
...
...
@@ -519,7 +521,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
    // gridwise GEMM pipeline
    const auto gridwise_gemm_pipeline =
-       GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
+       GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();

    const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
        (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
...
...
@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_
v1
.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_
selector
.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...
...
@@ -243,7 +243,8 @@ template <index_t BlockSize,
          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          bool ABlockLdsExtraM1Wrw = false,
          bool BBlockLdsExtraN1Wrw = false,
-         index_t NumGemmKPrefetchStage = 1>
+         index_t NumGemmKPrefetchStage = 1,
+         PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
{
    static constexpr auto I0 = Number<0>{};
...
...
@@ -258,8 +259,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};

-   using ThisThreadBlock = ThisThreadBlock<BlockSize>;
-   using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+   using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+   using GridwiseGemmPipe = remove_cvref_t<decltype(
+       GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

    // M0/M1/M1Padding
    static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...
...
@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_
v1
.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_
selector
.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
...
...
@@ -109,7 +109,9 @@ template <index_t BlockSize,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
-         index_t NumGemmKPrefetchStage = 1>
+         index_t NumGemmKPrefetchStage = 1,
+         LoopScheduler LoopSched = make_default_loop_scheduler(),
+         PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
    static constexpr auto I0 = Number<0>{};
...
...
@@ -126,7 +128,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-   using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+   using GridwiseGemmPipe = remove_cvref_t<decltype(
+       GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    {
...
...
@@ -423,18 +426,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
    // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
    //   register
    // sanity check
-   auto blockwise_gemm =
-       BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
-                                                           FloatAB,
-                                                           FloatAcc,
-                                                           decltype(a_block_desc_k0_m_k1),
-                                                           decltype(b_block_desc_k0_n_k1),
-                                                           MPerXDL,
-                                                           NPerXDL,
-                                                           MXdlPerWave,
-                                                           NXdlPerWave,
-                                                           K1>{};
+   auto blockwise_gemm =
+       BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<BlockSize,
+                                                                 FloatAB,
+                                                                 FloatAcc,
+                                                                 decltype(a_block_desc_k0_m_k1),
+                                                                 decltype(b_block_desc_k0_n_k1),
+                                                                 MPerXDL,
+                                                                 NPerXDL,
+                                                                 MXdlPerWave,
+                                                                 NXdlPerWave,
+                                                                 K1,
+                                                                 LoopSched>();

    auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
...
...
@@ -9,7 +9,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_
v1
.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_
selector
.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...
...
@@ -117,7 +117,8 @@ template <
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
          index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
-         index_t NumGemmKPrefetchStage = 1>
+         index_t NumGemmKPrefetchStage = 1,
+         PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
{
    static constexpr auto I0 = Number<0>{};
...
...
@@ -137,7 +138,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-   using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+   using GridwiseGemmPipe = remove_cvref_t<decltype(
+       GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
...
...
@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_
v1
.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_
selector
.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp"
...
...
@@ -123,7 +123,8 @@ template <
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
          index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
-         index_t NumGemmKPrefetchStage = 1>
+         index_t NumGemmKPrefetchStage = 1,
+         PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
{
    static constexpr auto I0 = Number<0>{};
...
...
@@ -140,7 +141,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-   using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+   using GridwiseGemmPipe = remove_cvref_t<decltype(
+       GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    {
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
...
...
@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_
v1
.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_
selector
.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp"
...
...
@@ -132,7 +132,8 @@ template <
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
          index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
-         index_t NumGemmKPrefetchStage = 1>
+         index_t NumGemmKPrefetchStage = 1,
+         PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
{
    static constexpr auto I0 = Number<0>{};
...
...
@@ -149,7 +150,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-   using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+   using GridwiseGemmPipe = remove_cvref_t<decltype(
+       GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    {
...
...
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp → include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp
...
...
@@ -14,7 +14,7 @@
namespace ck {

-// Y = LayerNorm(X, Beta, Gamma)
+// Y = Normalization(X, Beta, Gamma)
template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
...
...
@@ -36,7 +36,7 @@ template <typename XDataType,
          index_t YDstVectorDim,
          index_t YDstVectorSize,
          bool SweepOnce>
-struct GridwiseLayernormNaiveVariance_mk_to_mk
+struct GridwiseNormalizationNaiveVariance_mk_to_mk
{
    static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
...
...
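The rename from "LayerNorm" to "Normalization" presumably signals that this gridwise kernel is shared by the general mean/variance normalization operators rather than layernorm alone. For reference, the standard formulation these kernels compute per reduced slice is sketched below; the naive-variance variant uses the E[x^2] - (E[x])^2 identity, and epsilon handling and the reduction axis are configuration details not shown in this excerpt:

\[
  \mathrm{E}[x] = \frac{1}{K}\sum_{k} x_k, \qquad
  \mathrm{Var}[x] = \mathrm{E}[x^2] - \left(\mathrm{E}[x]\right)^2, \qquad
  y = \gamma \cdot \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} + \beta
\]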
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp → include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp
...
...
@@ -11,7 +11,7 @@
namespace ck {

-// Y = LayerNorm(X, Beta, Gamma)
+// Y = Normalization(X, Beta, Gamma)
template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
...
...
@@ -33,7 +33,7 @@ template <typename XDataType,
          index_t YDstVectorDim,
          index_t YDstVectorSize,
          bool SweepOnce>
-struct GridwiseLayernormWelfordVariance_mk_to_mk
+struct GridwiseNormalizationWelfordVariance_mk_to_mk
{
    static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
...
...
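The Welford variant of the same kernel avoids the cancellation error of the E[x^2] - (E[x])^2 identity by maintaining a running mean and a running sum of squared deviations. A minimal host-side sketch of the update the *_welford_variance kernels are named after; the struct and member names are illustrative, and the real kernels perform this update per thread and then merge partial results across the block:

// Welford's one-pass mean/variance update (illustrative host-side sketch).
struct WelfordState
{
    int   count = 0;
    float mean  = 0.f;
    float m2    = 0.f; // running sum of squared deviations from the mean

    void push(float x)
    {
        ++count;
        const float delta = x - mean;
        mean += delta / count;
        m2 += delta * (x - mean); // note: uses the updated mean
    }

    float variance() const { return count > 0 ? m2 / count : 0.f; } // population variance
};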
include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp
...
...
@@ -3,6 +3,7 @@
#pragma once
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace ck {
...
...