gaoqiong / composable_kernel · Commits

Commit 7e493730
Authored Oct 13, 2022 by Adam Osewski

    Merge branch 'develop' into wavelet_model

Parents: b89a88b5, 40942b90
Changes: 114 files
Showing 20 changed files with 1839 additions and 283 deletions (+1839, -283):

    include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp  (+15, -0)
    include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp  (+44, -0)
    include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp  (+93, -29)
    include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp  (+4, -0)
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp  (+17, -25)
    include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp  (+55, -53)
    include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp  (+170, -114)
    include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp  (+339, -0)
    include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp  (+1, -0)
    include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp  (+583, -0)
    include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp  (+24, -0)
    include/ck/utility/ignore.hpp  (+1, -3)
    include/ck/utility/span.hpp  (+67, -0)
    include/ck/utility/transpose_vectors.hpp  (+17, -21)
    library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp  (+191, -0)
    library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp  (+100, -0)
    library/include/ck/library/tensor_operation_instance/gpu/layernorm.hpp  (+36, -12)
    library/include/ck/library/utility/check_err.hpp  (+25, -13)
    library/include/ck/library/utility/fill.hpp  (+12, -0)
    library/include/ck/library/utility/host_tensor.hpp  (+45, -13)

include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

```diff
@@ -232,6 +232,21 @@ struct Gelu
     }
 };
 
+struct Sigmoid
+{
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, ck::half_t>::value,
+                      "Data type is not supported by this operation!");
+        y = 1 / (ck::type_convert<T>(1) + exp(-x));
+    };
+
+    int32_t divider_ = 1;
+};
+
 } // namespace element_wise
 } // namespace tensor_operation
 } // namespace ck
```
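
The new Sigmoid functor follows the calling convention of the other unary element-wise operations in this header: the result is written through the first reference argument. A minimal host-only sketch of that usage, with a float-only stand-in for the functor (the real one is __host__ __device__ and also accepts double and ck::half_t), assuming nothing beyond standard C++:

```cpp
#include <cmath>
#include <cstdio>

// Host-only, float-only stand-in for the Sigmoid functor shown in the diff above.
struct Sigmoid
{
    void operator()(float& y, const float& x) const { y = 1.0f / (1.0f + std::exp(-x)); }
};

int main()
{
    Sigmoid sigmoid;
    float y = 0.0f;
    sigmoid(y, 0.0f); // writes 0.5 into y
    std::printf("sigmoid(0) = %f\n", y);
    return 0;
}
```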

include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

```diff
@@ -486,4 +486,48 @@ __host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx,
     return is_valid;
 }
 
+// This wrapper class is for grouped gemm where it subtracts blockIdx by a value so that the
+// workgroups assigned to a given gemm problem have top index offsetted to range [0,
+// grid_size_per_gemm]
+template <typename UnderlyingBlockToCTileMap>
+struct OffsettedBlockToCTileMap
+{
+    using underlying_type = UnderlyingBlockToCTileMap;
+
+    OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
+    {
+        block_to_ctile_map_ = block_to_ctile_map;
+        block_start_        = block_start;
+    }
+
+    template <typename TopIdx>
+    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+    {
+        return block_to_ctile_map_.CalculateBottomIndex(
+            make_multi_index(idx_top[Number<0>{}] - block_start_));
+    }
+
+    template <typename CTileIdx, typename CTileDim>
+    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
+                                             const CTileDim& c_tile_dim) const
+    {
+        return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    {
+        return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
+    {
+        return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
+    }
+
+    UnderlyingBlockToCTileMap block_to_ctile_map_;
+    index_t block_start_;
+};
+
 } // namespace ck
```
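
The comment on OffsettedBlockToCTileMap describes its purpose: in a grouped GEMM every problem owns a contiguous range of workgroup indices, so the wrapper subtracts the first block id of the current problem before delegating to the wrapped map. A self-contained sketch of that idea under simplified assumptions (RowMajorTileMap, OffsettedMap and the int-based index_t below are illustrative stand-ins, not the CK types):

```cpp
#include <cstdio>

using index_t = int;

// Illustrative underlying map: linear block id -> (m0, n0) tile index, row-major over tiles.
struct RowMajorTileMap
{
    index_t n_tiles_; // number of tiles along N
    void CalculateBottomIndex(index_t block_id, index_t& m0, index_t& n0) const
    {
        m0 = block_id / n_tiles_;
        n0 = block_id % n_tiles_;
    }
};

// Same idea as OffsettedBlockToCTileMap: shift the global block id back into
// [0, grid_size_per_gemm) and delegate to the underlying map.
template <typename UnderlyingMap>
struct OffsettedMap
{
    UnderlyingMap map_;
    index_t block_start_;
    void CalculateBottomIndex(index_t block_id, index_t& m0, index_t& n0) const
    {
        map_.CalculateBottomIndex(block_id - block_start_, m0, n0);
    }
};

int main()
{
    // Suppose the second GEMM of a group starts at global block id 12 and has 4 tiles along N.
    OffsettedMap<RowMajorTileMap> map{{4}, 12};
    index_t m0 = 0, n0 = 0;
    map.CalculateBottomIndex(14, m0, n0); // global block 14 -> local block 2 -> tile (0, 2)
    std::printf("(m0, n0) = (%d, %d)\n", m0, n0);
    return 0;
}
```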

include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp

```diff
@@ -76,7 +76,8 @@ template <typename FloatAB,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched,
-          bool PadN>
+          bool PadN,
+          bool MaskOutUpperTriangle>
 struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
 {
     static_assert(LoopSched == LoopScheduler::Default,
@@ -97,6 +98,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
     static constexpr auto AK1 = Number<AK1Value>{};
     static constexpr auto BK1 = Number<BK1Value>{};
+
+    static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
+    static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
+
     // Gemm1
     static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
     static constexpr auto B1K1 = Number<B1K1Value>{};
@@ -361,7 +366,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         }
     };
 
-    template <bool HasMainKBlockLoop, typename Block2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                const FloatAB* __restrict__ p_b1_grid,
@@ -377,22 +382,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
                                const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
-                               const Block2CTileMap& block_2_ctile_map)
+                               const Block2CTileMap& block_2_ctile_map,
+                               const C0MatrixMask& c0_matrix_mask)
     {
-        const auto a_grid_buf = conditional_expr<PadN>(
-            make_dynamic_buffer<AddressSpaceEnum::Global>(
-                p_a_grid,
-                a_grid_desc_ak0_m_ak1.GetElementSpaceSize(),
-                NumericLimits<FloatAB>::QuietNaN()),
-            make_dynamic_buffer<AddressSpaceEnum::Global>(
-                p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()));
-        const auto b_grid_buf = conditional_expr<PadN>(
-            make_dynamic_buffer<AddressSpaceEnum::Global>(
-                p_b_grid,
-                b_grid_desc_bk0_n_bk1.GetElementSpaceSize(),
-                NumericLimits<FloatAB>::QuietNaN()),
-            make_dynamic_buffer<AddressSpaceEnum::Global>(
-                p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()));
+        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
+        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
         const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -749,10 +745,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         running_max     = NumericLimits<FloatGemmAcc>::Lowest();
         running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
 
+        // decoder lower triangular mask
+        const auto thread_cluster_idx =
+            threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(
+                make_multi_index(get_thread_local_1d_id()));
+        const auto thread_m_cluster_id = thread_cluster_idx[I0];
+        const auto thread_n_cluster_id = thread_cluster_idx[I1];
+        const index_t MPerRepeat       = MPerBlock / MXdlPerWave;
+        const index_t NPerRepeat       = NPerBlock / NXdlPerWave;
+        const index_t mstart           = m_block_data_idx_on_grid + thread_m_cluster_id;
+
         // gemm1 K loop
         index_t gemm1_k_block_outer_index = 0;
         do
         {
+            if constexpr(MaskOutUpperTriangle)
+            {
+                auto gemm0_n_block_idx =
+                    __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
+                if(c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid, gemm0_n_block_idx) &&
+                   c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid + MPerBlock - 1,
+                                                  gemm0_n_block_idx))
+                {
+                    continue;
+                }
+            }
+
             // gemm0
             gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
                                                                    a_block_desc_ak0_m_ak1,
@@ -770,16 +786,63 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                                    acc_thread_buf,
                                                                    num_k_block_main_loop);
 
-            // Acc0 elementwise Op
-#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
-            static_for<0, acc_thread_buf.Size(), 1>{}(
-                [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
-#else
-            static_for<0, acc_thread_buf.Size(), 1>{}([&](auto i) {
-                ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
-                    acc_thread_buf(i), acc_element_op, acc_thread_buf[i]);
-            });
-#endif
+            // do MNK padding or upper triangular masking
+            if constexpr(MaskOutUpperTriangle || PadN)
+            {
+                const index_t nstart = gemm1_k_block_outer_index * NPerBlock;
+
+                static_for<0, m0, 1>{}([&](auto m0_i) {
+                    const index_t m_global   = mstart + m0_i * MPerRepeat;
+                    const index_t acc_idx_m0 = m0_i * n0 * n2 * n4;
+                    static_for<0, n0, 1>{}([&](auto n0_i) {
+                        // constexpr auto nrepeat_i = n0_i * NPerRepeat;
+                        // const index_t nstartxdl = nstart + nrepeat_i;
+                        const index_t nstartxdl  = nstart + n0_i * NPerRepeat;
+                        const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4;
+                        static_for<0, n2, 1>{}([&](auto n2_i) {
+                            const index_t nstartgroup =
+                                nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
+                            const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4;
+                            static_for<0, n4, 1>{}([&](auto n4_i) {
+                                const index_t n_global = nstartgroup + n4_i;
+                                const auto acc_offset  = Number<acc_idx_n2 + n4_i>{};
+                                if constexpr(MaskOutUpperTriangle)
+                                {
+                                    if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
+                                    {
+                                        acc_thread_buf(acc_offset) =
+                                            -ck::NumericLimits<float>::Infinity();
+                                    }
+                                    else
+                                    {
+                                        acc_element_op(acc_thread_buf(acc_offset),
+                                                       acc_thread_buf[acc_offset]);
+                                    }
+                                }
+                                else
+                                {
+                                    // ignore m_global;
+                                    if(c0_matrix_mask.IsNOutOfBound(n_global))
+                                    {
+                                        acc_thread_buf(acc_offset) =
+                                            -ck::NumericLimits<float>::Infinity();
+                                    }
+                                    else
+                                    {
+                                        acc_element_op(acc_thread_buf(acc_offset),
+                                                       acc_thread_buf[acc_offset]);
+                                    }
+                                }
+                            });
+                        });
+                    });
+                });
+            }
+            else
+            {
+                static_for<0, acc_thread_buf.Size(), 1>{}(
+                    [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
+            }
 
             block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
@@ -881,9 +944,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                     FloatGemmAcc c_new =
                         (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
                          math::exp(max[iM] - running_max_new[iM]) * acc1) /
-                        running_sum_new[iM]; // O_new
+                        running_sum_new[iM]; // Formula by Dao et al.,
+                                             // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
 
-                    c_thread_buf(I) = c_new;
+                    c_thread_buf(I) = c_new; // O_new
                 });
             });
```
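
The c_new update in the last hunk is the running-softmax output correction cited in the code (Dao et al., FlashAttention, section 3.1). In the notation below (my own, not the kernel's), m and l are the running row maximum and row sum of exponentials, O is the running output, and the tilde quantities are the statistics and unnormalized output of the K-block just processed:

```latex
m^{\mathrm{new}} = \max(m, \tilde{m}), \qquad
\ell^{\mathrm{new}} = e^{\,m - m^{\mathrm{new}}}\,\ell + e^{\,\tilde{m} - m^{\mathrm{new}}}\,\tilde{\ell},
\qquad
O^{\mathrm{new}} = \frac{\ell\, e^{\,m - m^{\mathrm{new}}}\, O \;+\; e^{\,\tilde{m} - m^{\mathrm{new}}}\, \widetilde{O}}{\ell^{\mathrm{new}}}
```

This maps term by term onto running_sum, running_max, c, max, acc1 and running_sum_new in the hunk above.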

include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp

```diff
@@ -83,6 +83,8 @@ struct GridwiseElementwise_1D
         auto in_global_buf_tuple = generate_tuple(
             [&](auto I) {
+                static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
+
                 return make_dynamic_buffer<AddressSpaceEnum::Global>(
                     p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
             },
@@ -90,6 +92,8 @@ struct GridwiseElementwise_1D
         auto out_global_buf_tuple = generate_tuple(
             [&](auto I) {
+                static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
+
                 return make_dynamic_buffer<AddressSpaceEnum::Global>(
                     p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
             },
```

include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp

```diff
@@ -35,10 +35,6 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
           typename BElementwiseOperation,
           typename CDEElementwiseOperation,
           InMemoryDataOperationEnum EGlobalMemoryDataOperation,
-          typename AGridDesc_M_K,
-          typename BGridDesc_N_K,
-          typename DsGridDesc_M_N,
-          typename EGridDesc_M_N,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
           index_t MPerBlock,
@@ -166,6 +162,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
     }
 
     // A desc for source in blockwise copy
+    template <typename AGridDesc_M_K>
     __host__ __device__ static constexpr auto
     MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
     {
@@ -182,6 +179,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
     }
 
     // B desc for source in blockwise copy
+    template <typename BGridDesc_N_K>
     __host__ __device__ static constexpr auto
     MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
     {
@@ -198,9 +196,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
     }
 
     // E desc for destination in blockwise copy
-    template <typename EGridDescriptor_M_N>
-    __host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-        const EGridDescriptor_M_N& e_grid_desc_m_n)
+    template <typename EGridDesc_M_N>
+    __host__ __device__ static constexpr auto
+    MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
     {
         const auto M = e_grid_desc_m_n.GetLength(I0);
         const auto N = e_grid_desc_m_n.GetLength(I1);
@@ -219,10 +217,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
     }
 
     // Ds desc for source in blockwise copy
-    template <typename DsGridDescriptor_M_N>
+    template <typename DsGridDesc_M_N>
     __host__ __device__ static constexpr auto
     MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-        const DsGridDescriptor_M_N& ds_grid_desc_m_n)
+        const DsGridDesc_M_N& ds_grid_desc_m_n)
     {
         return generate_tuple(
             [&](auto i) {
@@ -232,6 +229,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
     }
 
     // return block_id to E matrix tile idx (m0, n0) mapping
+    template <typename EGridDesc_M_N>
     __host__ __device__ static constexpr auto
     MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
     {
@@ -240,7 +238,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
     }
 
     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
-    template <typename Block2ETileMap>
+    template <typename AGridDesc_M_K,
+              typename BGridDesc_N_K,
+              typename DsGridDesc_M_N,
+              typename EGridDesc_M_N,
+              typename Block2ETileMap>
     __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
                                                             const BGridDesc_N_K& b_grid_desc_n_k,
                                                             const DsGridDesc_M_N& ds_grid_desc_m_n,
@@ -314,23 +316,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }
 
-    using DefaultAGridDesc_AK0_M_AK1 =
-        remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-    using DefaultBGridDesc_BK0_N_BK1 =
-        remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
-    using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
-    using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
-    using DefaultBlock2ETileMap =
-        remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
-
     using DsGridPointer = decltype(MakeDsGridPointer());
 
     template <bool HasMainKBlockLoop,
               typename AGridDesc_AK0_M_AK1,
               typename BGridDesc_BK0_N_BK1,
+              typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
               typename Block2ETileMap>
     __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
                                const ABDataType* __restrict__ p_b_grid,
@@ -342,9 +334,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
                                const CDEElementwiseOperation& cde_element_op,
                                const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                                const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-                               const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
+                               const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                    ds_grid_desc_mblock_mperblock_nblock_nperblock,
-                               const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
+                               const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                    e_grid_desc_mblock_mperblock_nblock_nperblock,
                                const Block2ETileMap& block_2_etile_map)
     {
```

include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp

```diff
@@ -22,7 +22,6 @@ template <typename XDataType,
           typename AccDataType,
           typename AccElementwiseOperation,
           typename GridDesc_M_K,
-          typename GridDesc_K,
           index_t BlockSize,
           index_t MThreadClusterSize,
           index_t KThreadClusterSize,
@@ -30,7 +29,9 @@ template <typename XDataType,
           index_t KThreadSliceSize,
           index_t XSrcVectorDim,
           index_t XSrcVectorSize,
+          index_t GammaSrcVectorDim,
           index_t GammaSrcVectorSize,
+          index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
           index_t YDstVectorDim,
           index_t YDstVectorSize,
@@ -78,13 +79,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
 
     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
 
     __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
-                               const GridDesc_K& gamma_grid_desc_k,
-                               const GridDesc_K& beta_grid_desc_k,
+                               const GridDesc_M_K& gamma_grid_desc_m_k,
+                               const GridDesc_M_K& beta_grid_desc_m_k,
                                const GridDesc_M_K& y_grid_desc_m_k,
                                index_t num_k_block_tile_iteration,
                                AccDataType epsilon,
@@ -111,11 +113,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             x_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> gamma_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true>& beta_thread_buf =
-            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
+            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>&
+            beta_thread_buf = gamma_thread_buf;
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             y_thread_buf;
@@ -127,7 +132,7 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
             mean_square_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_value_buf =
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_thread_buf =
             mean_square_thread_buf;
 
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -145,11 +150,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         const auto thread_k_cluster_id = thread_cluster_idx[I1];
 
         using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
-        using ThreadBufferLengths_K   = Sequence<KThreadSliceSize>;
         constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
             make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
-        constexpr auto thread_buffer_desc_k =
-            make_naive_tensor_descriptor_packed(make_tuple(Number<KThreadSliceSize>{}));
 
         auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                   AccDataType,
@@ -169,27 +171,34 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         auto threadwise_gamma_load =
             ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                              AccDataType,
-                                             GridDesc_K,
-                                             decltype(thread_buffer_desc_k),
-                                             ThreadBufferLengths_K,
-                                             Sequence<0>,
-                                             0,
+                                             GridDesc_M_K,
+                                             decltype(thread_buffer_desc_m_k),
+                                             ThreadBufferLengths_M_K,
+                                             ThreadBufferDimAccessOrder,
+                                             GammaSrcVectorDim,
                                              GammaSrcVectorSize,
                                              1,
                                              true>(
-                gamma_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
+                gamma_grid_desc_m_k,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_k_cluster_id * KThreadSliceSize));
 
         auto threadwise_beta_load =
             ThreadwiseTensorSliceTransfer_v2<BetaDataType,
                                              AccDataType,
-                                             GridDesc_K,
-                                             decltype(thread_buffer_desc_k),
-                                             ThreadBufferLengths_K,
-                                             Sequence<0>,
-                                             0,
+                                             GridDesc_M_K,
+                                             decltype(thread_buffer_desc_m_k),
+                                             ThreadBufferLengths_M_K,
+                                             ThreadBufferDimAccessOrder,
+                                             BetaSrcVectorDim,
                                              BetaSrcVectorSize,
                                              1,
                                              true>(
-                beta_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
+                beta_grid_desc_m_k,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_k_cluster_id * KThreadSliceSize));
 
         auto threadwise_y_store =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
@@ -212,9 +221,6 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         // Copy x from Cache
         // one pass: fwd, second pass: bwd
-        constexpr auto thread_copy_fwd_step_k = make_multi_index(SweepOnce ? 0 : K_BlockTileSize);
-        constexpr auto thread_copy_bwd_step_k = make_multi_index(SweepOnce ? 0 : -K_BlockTileSize);
         constexpr auto thread_copy_fwd_step_m_k =
             make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
         constexpr auto thread_copy_bwd_step_m_k =
@@ -224,13 +230,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
             p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
         const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize());
+            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
         const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_beta_global, beta_grid_desc_k.GetElementSpaceSize());
+            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
 
         // E(x), E[x^2], var(x)
-        int reduce_length = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];
+        // FIXME: Should not hack the transform from deviceOP
+        int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
 
         index_t reducedTiles = 0;
 
         do
@@ -271,17 +278,16 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
             mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
 
             // var(x) = E[x^2] - E[x]^2
-            var_value_buf(I) =
+            var_thread_buf(I) =
                 mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
         });
 
         // y = (x - E[x]) / sqrt(var[x] + epsilon)
         auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
-        auto thread_copy_tail_k   = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k;
 
         threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k);
-        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k);
+        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
+        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
         threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
 
         reducedTiles = 0;
@@ -296,10 +302,10 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
                     x_thread_buf);
             }
 
-            threadwise_gamma_load.Run(gamma_grid_desc_k,
+            threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                       gamma_global_val_buf,
-                                      thread_buffer_desc_k,
-                                      make_tuple(I0),
+                                      thread_buffer_desc_m_k,
+                                      make_tuple(I0, I0),
                                       gamma_thread_buf);
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
@@ -307,23 +313,21 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
                     constexpr auto offset_m_k =
                         thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
 
                     // normalize
                     y_thread_buf(Number<offset_m_k>{}) =
                         (x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
-                        sqrt(var_value_buf(iM) + epsilon);
+                        sqrt(var_thread_buf(iM) + epsilon);
 
                     // gamma
                     y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{});
+                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
                 });
             });
 
-            threadwise_beta_load.Run(beta_grid_desc_k,
+            threadwise_beta_load.Run(beta_grid_desc_m_k,
                                      beta_global_val_buf,
-                                     thread_buffer_desc_k,
-                                     make_tuple(I0),
+                                     thread_buffer_desc_m_k,
+                                     make_tuple(I0, I0),
                                      beta_thread_buf);
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
@@ -331,11 +335,9 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
                     constexpr auto offset_m_k =
                         thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
 
                     // beta
                     y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_k>{});
+                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
                 });
             });
@@ -346,8 +348,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
                 y_global_val_buf);
 
             threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
-            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
+            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
+            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
             threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
 
             ++reducedTiles;
```
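
As the comments in the hunks above state, this kernel derives the variance from the two running moments rather than from Welford updates, and then applies the standard LayerNorm transform. Restated as formulas (a restatement of the code, not an addition to it):

```latex
\mathrm{Var}(x) = \mathbb{E}[x^2] - \mathbb{E}[x]^2,
\qquad
y = \frac{x - \mathbb{E}[x]}{\sqrt{\mathrm{Var}(x) + \epsilon}}\,\gamma + \beta
```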
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp
View file @
7e493730
...
@@ -19,7 +19,6 @@ template <typename XDataType,
...
@@ -19,7 +19,6 @@ template <typename XDataType,
typename
AccDataType
,
typename
AccDataType
,
typename
AccElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
GridDesc_M_K
,
typename
GridDesc_M_K
,
typename
GridDesc_K
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
KThreadClusterSize
,
...
@@ -27,7 +26,9 @@ template <typename XDataType,
...
@@ -27,7 +26,9 @@ template <typename XDataType,
index_t
KThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XSrcVectorDim
,
index_t
XSrcVectorDim
,
index_t
XSrcVectorSize
,
index_t
XSrcVectorSize
,
index_t
GammaSrcVectorDim
,
index_t
GammaSrcVectorSize
,
index_t
GammaSrcVectorSize
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorDim
,
index_t
YDstVectorDim
,
index_t
YDstVectorSize
,
index_t
YDstVectorSize
,
...
@@ -56,7 +57,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
...
@@ -56,7 +57,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSlice
Size
>
{})));
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVector
Size
>
{})));
using
ThreadReduceDstDesc_M
=
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
...
@@ -70,32 +71,43 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
...
@@ -70,32 +71,43 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
index_t
K_BlockTileStepSize
=
KThreadClusterSize
*
XSrcVectorSize
;
static
constexpr
auto
XThreadBufferNumber
=
Number
<
KThreadSliceSize
/
XSrcVectorSize
>
{};
static
constexpr
auto
GammaThreadBufferNumber
=
Number
<
KThreadSliceSize
/
XSrcVectorSize
>
{};
static
constexpr
auto
BetaThreadBufferNumber
=
Number
<
KThreadSliceSize
/
XSrcVectorSize
>
{};
static
constexpr
auto
YThreadBufferNumber
=
Number
<
KThreadSliceSize
/
XSrcVectorSize
>
{};
__device__
static
int
GetKPerThread
(
const
GridDesc_M_K
&
x_grid_desc_m_k
,
__device__
static
int
GetKPerThread
(
const
GridDesc_M_K
&
x_grid_desc_m_k
,
int
thread_k_cluster_id
)
int
thread_k_cluster_id
)
{
{
int
kPerBlock
=
x_grid_desc_m_k
.
GetTransforms
()[
I0
].
GetUpperLengths
()[
I1
];
// FIXME: Should not hack the transform from deviceOP
int
kPerBlock
=
x_grid_desc_m_k
.
GetTransforms
()[
I2
].
GetUpperLengths
()[
I0
];
int
kPerThread
=
int
kPerThread
=
kPerBlock
<
K_BlockTileSize
?
0
:
KThreadSliceSize
*
(
kPerBlock
/
K_BlockTileSize
);
kPerBlock
<
K_BlockTileSize
?
0
:
KThreadSliceSize
*
(
kPerBlock
/
K_BlockTileSize
);
int
kPerBlockTail
=
kPerBlock
-
kPerThread
*
KThreadClusterSize
;
int
kPerBlockTail
=
kPerBlock
-
kPerThread
*
KThreadClusterSize
;
if
(
kPerBlockTail
>
0
)
if
(
kPerBlockTail
>
0
)
{
{
int
thread_max_len
=
(
thread_k_cluster_id
+
1
)
*
KThreadSliceSize
;
static_for
<
0
,
XThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
int
delta
=
thread_max_len
-
kPerBlockTail
;
int
thread_max_len
=
delta
=
math
::
clamp
(
thread_max_len
-
kPerBlockTail
,
0
,
KThreadSliceSize
);
(
thread_k_cluster_id
+
1
)
*
XSrcVectorSize
+
K_BlockTileStepSize
*
i
;
kPerThread
+=
KThreadSliceSize
-
delta
;
int
delta
=
thread_max_len
-
kPerBlockTail
;
delta
=
math
::
clamp
(
thread_max_len
-
kPerBlockTail
,
0
,
XSrcVectorSize
);
kPerThread
+=
XSrcVectorSize
-
delta
;
});
}
}
return
kPerThread
;
return
kPerThread
;
}
}
__device__
static
void
Run
(
const
GridDesc_M_K
&
x_grid_desc_m_k
,
__device__
static
void
Run
(
const
GridDesc_M_K
&
x_grid_desc_m_k
,
const
GridDesc_K
&
gamma_grid_desc_k
,
const
GridDesc_
M_
K
&
gamma_grid_desc_
m_
k
,
const
GridDesc_K
&
beta_grid_desc_k
,
const
GridDesc_
M_
K
&
beta_grid_desc_
m_
k
,
const
GridDesc_M_K
&
y_grid_desc_m_k
,
const
GridDesc_M_K
&
y_grid_desc_m_k
,
index_t
num_k_block_tile_iteration
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
AccDataType
epsilon
,
...
@@ -113,16 +125,41 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
...
@@ -113,16 +125,41 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y_global
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
p_y_global
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
auto
x_thread_buf
=
generate_tuple
(
x_thread_buf
;
[
&
](
auto
)
{
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
KThreadSliceSize
,
true
>
gamma_thread_buf
;
AccDataType
,
MThreadSliceSize
*
XSrcVectorSize
,
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
KThreadSliceSize
,
true
>&
beta_thread_buf
=
true
>
{};
gamma_thread_buf
;
},
Number
<
XThreadBufferNumber
>
{});
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
y_thread_buf
;
auto
gamma_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
GammaSrcVectorSize
,
true
>
{};
},
Number
<
GammaThreadBufferNumber
>
{});
auto
beta_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
BetaSrcVectorSize
,
true
>
{};
},
Number
<
BetaThreadBufferNumber
>
{});
auto
y_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
YDstVectorSize
,
true
>
{};
},
Number
<
YThreadBufferNumber
>
{});
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
var_thread_buf
;
...
@@ -136,12 +173,9 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
...
@@ -136,12 +173,9 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
using
ThreadBufferLengths_M_K
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
using
ThreadBufferLengths_M_K
=
Sequence
<
MThreadSliceSize
,
XSrcVectorSize
>
;
using
ThreadBufferLengths_K
=
Sequence
<
KThreadSliceSize
>
;
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{}));
constexpr
auto
thread_buffer_desc_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KThreadSliceSize
>
{}));
auto
threadwise_x_load
=
ThreadwiseTensorSliceTransfer_v2
<
XDataType
,
auto
threadwise_x_load
=
ThreadwiseTensorSliceTransfer_v2
<
XDataType
,
AccDataType
,
AccDataType
,
...
@@ -156,32 +190,39 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
...
@@ -156,32 +190,39 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
x_grid_desc_m_k
,
x_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSlice
Size
));
thread_k_cluster_id
*
XSrcVector
Size
));
auto
threadwise_gamma_load
=
auto
threadwise_gamma_load
=
ThreadwiseTensorSliceTransfer_v2
<
GammaDataType
,
ThreadwiseTensorSliceTransfer_v2
<
GammaDataType
,
AccDataType
,
AccDataType
,
GridDesc_K
,
GridDesc_
M_
K
,
decltype
(
thread_buffer_desc_k
),
decltype
(
thread_buffer_desc_
m_
k
),
ThreadBufferLengths_K
,
ThreadBufferLengths_
M_
K
,
Sequence
<
0
>
,
ThreadBufferDimAccessOrder
,
0
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
GammaSrcVectorSize
,
1
,
1
,
true
>
(
true
>
(
gamma_grid_desc_k
,
make_multi_index
(
thread_k_cluster_id
*
KThreadSliceSize
));
gamma_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
auto
threadwise_beta_load
=
ThreadwiseTensorSliceTransfer_v2
<
BetaDataType
,
thread_m_cluster_id
*
MThreadSliceSize
,
AccDataType
,
thread_k_cluster_id
*
GammaSrcVectorSize
));
GridDesc_K
,
decltype
(
thread_buffer_desc_k
),
auto
threadwise_beta_load
=
ThreadBufferLengths_K
,
ThreadwiseTensorSliceTransfer_v2
<
BetaDataType
,
Sequence
<
0
>
,
AccDataType
,
0
,
GridDesc_M_K
,
BetaSrcVectorSize
,
decltype
(
thread_buffer_desc_m_k
),
1
,
ThreadBufferLengths_M_K
,
true
>
(
ThreadBufferDimAccessOrder
,
beta_grid_desc_k
,
make_multi_index
(
thread_k_cluster_id
*
KThreadSliceSize
));
BetaSrcVectorDim
,
BetaSrcVectorSize
,
1
,
true
>
(
beta_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
BetaSrcVectorSize
));
auto
threadwise_y_store
=
auto
threadwise_y_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
...
@@ -199,16 +240,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
...
@@ -199,16 +240,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
y_grid_desc_m_k
,
y_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSlice
Size
),
thread_k_cluster_id
*
YDstVector
Size
),
acc_elementwise_op
);
acc_elementwise_op
);
// Copy x from Cache
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileStepSize
);
// one pass: fwd, second pass: bwd
constexpr
auto
thread_copy_fwd_step_k
=
make_multi_index
(
SweepOnce
?
0
:
K_BlockTileSize
);
constexpr
auto
thread_copy_bwd_step_k
=
make_multi_index
(
SweepOnce
?
0
:
-
K_BlockTileSize
);
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
SweepOnce
?
0
:
K_BlockTileSize
);
constexpr
auto
thread_copy_bwd_step_m_k
=
constexpr
auto
thread_copy_bwd_step_m_k
=
make_multi_index
(
0
,
SweepOnce
?
0
:
-
K_BlockTileSize
);
make_multi_index
(
0
,
SweepOnce
?
0
:
-
K_BlockTileSize
);
...
@@ -216,10 +251,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
...
@@ -216,10 +251,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
p_x_global
,
x_grid_desc_m_k
.
GetElementSpaceSize
());
p_x_global
,
x_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
gamma_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
gamma_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_gamma_global
,
gamma_grid_desc_k
.
GetElementSpaceSize
());
p_gamma_global
,
gamma_grid_desc_
m_
k
.
GetElementSpaceSize
());
const
auto
beta_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
beta_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_beta_global
,
beta_grid_desc_k
.
GetElementSpaceSize
());
p_beta_global
,
beta_grid_desc_
m_
k
.
GetElementSpaceSize
());
auto
threadwise_welford
=
ThreadwiseWelford
();
auto
threadwise_welford
=
ThreadwiseWelford
();
threadwise_welford
.
max_count_
=
GetKPerThread
(
x_grid_desc_m_k
,
thread_k_cluster_id
);
threadwise_welford
.
max_count_
=
GetKPerThread
(
x_grid_desc_m_k
,
thread_k_cluster_id
);
...
@@ -231,14 +266,15 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
...
@@ -231,14 +266,15 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
{
static_for
<
0
,
XThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
make_tuple
(
I0
,
I0
),
x_thread_buf
);
x_thread_buf
(
i
));
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
threadwise_welford
.
Run
(
x_thread_buf
,
mean_thread_buf
,
var_thread_buf
);
threadwise_welford
.
Run
(
x_thread_buf
[
i
],
mean_thread_buf
,
var_thread_buf
);
});
}
}
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
...
@@ -249,78 +285,98 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
...
@@ -249,78 +285,98 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
BlockwiseWelford
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count
);
BlockwiseWelford
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count
);
});
});
auto
thread_copy_tail_m_k
=
(
num_k_block_tile_iteration
-
1
)
*
thread_copy_fwd_step_m_k
;
auto
thread_copy_tail_m_k
=
auto
thread_copy_tail_k
=
(
num_k_block_tile_iteration
-
1
)
*
thread_copy_fwd_step_k
;
(
num_k_block_tile_iteration
-
1
)
*
XThreadBufferNumber
*
thread_copy_fwd_step_
m_
k
;
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
gamma_grid_desc_k
,
thread_copy_tail_k
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
gamma_grid_desc_
m_
k
,
thread_copy_tail_
m_
k
);
threadwise_beta_load
.
MoveSrcSliceWindow
(
beta_grid_desc_k
,
thread_copy_tail_k
);
threadwise_beta_load
.
MoveSrcSliceWindow
(
beta_grid_desc_
m_
k
,
thread_copy_tail_
m_
k
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
thread_copy_tail_m_k
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
thread_copy_tail_m_k
);
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
{
if
constexpr
(
!
SweepOnce
)
if
constexpr
(
!
SweepOnce
)
{
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
static_for
<
0
,
XThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
x_global_val_buf
,
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
thread_buffer_desc_m_k
,
x_global_val_buf
,
make_tuple
(
I0
,
I0
),
thread_buffer_desc_m_k
,
x_thread_buf
);
make_tuple
(
I0
,
I0
),
x_thread_buf
(
i
));
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
});
}
}
Tail of the gridwise layernorm (Welford variance) kernel's Run().

Old version: gamma/beta are loaded once through rank-1 K descriptors and every element divides by sqrt():

    threadwise_gamma_load.Run(gamma_grid_desc_k,
                              gamma_global_val_buf,
                              thread_buffer_desc_k,
                              make_tuple(I0),
                              gamma_thread_buf);

    static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
        static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
            constexpr auto offset_m_k = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
            constexpr auto offset_k   = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));

            // normalize
            y_thread_buf(Number<offset_m_k>{}) =
                (x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
                sqrt(var_thread_buf(iM) + epsilon);

            // gamma
            y_thread_buf(Number<offset_m_k>{}) =
                y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{});
        });
    });

    threadwise_beta_load.Run(beta_grid_desc_k,
                             beta_global_val_buf,
                             thread_buffer_desc_k,
                             make_tuple(I0),
                             beta_thread_buf);

    static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
        static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
            constexpr auto offset_m_k = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
            constexpr auto offset_k   = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));

            // beta
            y_thread_buf(Number<offset_m_k>{}) =
                y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_k>{});
        });
    });

    threadwise_y_store.Run(thread_buffer_desc_m_k,
                           make_tuple(I0, I0),
                           y_thread_buf,
                           y_grid_desc_m_k,
                           y_global_val_buf);

    threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
    threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
    threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
    threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);

New version: gamma/beta are streamed through rank-2 M_K descriptors in GammaThreadBufferNumber / BetaThreadBufferNumber chunks, the reciprocal square root is hoisted out of the inner loop, and the X/Y buffers are processed per vector chunk:

    static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) {
        threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                  gamma_global_val_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  gamma_thread_buf(i));

        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_fwd_step_m_k);
    });

    static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
        auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);

        static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
            static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                constexpr auto offset_m_k =
                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

                // normalize
                y_thread_buf(iK0)(Number<offset_m_k>{}) =
                    (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * divisor;

                // gamma
                y_thread_buf(iK0)(Number<offset_m_k>{}) =
                    y_thread_buf(iK0)(Number<offset_m_k>{}) *
                    gamma_thread_buf(iK0)(Number<offset_m_k>{});
            });
        });
    });

    static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) {
        threadwise_beta_load.Run(beta_grid_desc_m_k,
                                 beta_global_val_buf,
                                 thread_buffer_desc_m_k,
                                 make_tuple(I0, I0),
                                 beta_thread_buf(i));

        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_fwd_step_m_k);
    });

    static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
        static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
            static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                constexpr auto offset_m_k =
                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

                // beta
                y_thread_buf(iK0)(Number<offset_m_k>{}) =
                    y_thread_buf(iK0)(Number<offset_m_k>{}) +
                    beta_thread_buf(iK0)(Number<offset_m_k>{});
            });
        });
    });

    static_for<0, YThreadBufferNumber, 1>{}([&](auto i) {
        threadwise_y_store.Run(thread_buffer_desc_m_k,
                               make_tuple(I0, I0),
                               y_thread_buf(i),
                               y_grid_desc_m_k,
                               y_global_val_buf);

        threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
    });

    threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
    threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
    threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
    threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
        }
    }
};
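Both versions compute the same per-row transform. The following is a minimal host-side sketch of that math in plain C++ (function name, rank-2 [M, K] layout, and float types are assumptions for illustration; this is not the library's kernel interface):

    #include <cmath>
    #include <vector>

    // y[m][k] = (x[m][k] - mean[m]) / sqrt(var[m] + eps) * gamma[k] + beta[k]
    void layernorm_rows(const std::vector<float>& x, const std::vector<float>& gamma,
                        const std::vector<float>& beta, std::vector<float>& y,
                        int M, int K, float eps)
    {
        for(int m = 0; m < M; ++m)
        {
            float mean = 0.f, var = 0.f;
            for(int k = 0; k < K; ++k) { mean += x[m * K + k]; }
            mean /= K;
            for(int k = 0; k < K; ++k) { float d = x[m * K + k] - mean; var += d * d; }
            var /= K;

            // reciprocal square root hoisted out of the inner loop, as in the new kernel
            const float divisor = 1.f / std::sqrt(var + eps);
            for(int k = 0; k < K; ++k)
                y[m * K + k] = (x[m * K + k] - mean) * divisor * gamma[k] + beta[k];
        }
    }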
include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <functional>
#include <numeric>
#include <iterator>

#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwisePermute, typename InGridDesc, typename OutGridDesc,
          typename InDataType, typename OutDataType, typename ElementwiseOperation,
          typename Block2TileMap>
__global__ void kernel_nd_permute(const InGridDesc in_grid_desc, const OutGridDesc out_grid_desc,
                                  const InDataType* p_in_global, OutDataType* p_out_global,
                                  const ElementwiseOperation elementwise_op,
                                  const Block2TileMap block_2_tile_map)
{
    __shared__ char p_shared[GridwisePermute::GetSharedMemoryNumberOfByte()];

    GridwisePermute::Run(in_grid_desc, out_grid_desc, p_in_global, p_out_global, p_shared,
                         elementwise_op, block_2_tile_map);
}

template <typename InGridDesc, typename OutGridDesc, typename InDataType, typename OutDataType,
          typename ElementwiseOperation, index_t BlockSize, index_t NPerBlock, index_t HPerBlock,
          index_t WPerBlock, index_t InBlockLdsExtraW,
          typename InBlockTransferThreadClusterLengths,
          typename InBlockTransferThreadClusterArrangeOrder,
          index_t SrcVectorDim, index_t DstVectorDim,
          index_t SrcScalarPerVector, index_t DstScalarPerVector>
struct GridwisePermute
{
    static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension());
    static_assert(3 <= InGridDesc::GetNumOfDimension());
    static_assert((InGridDesc::GetNumOfDimension() - 2) <= SrcVectorDim &&
                  SrcVectorDim < InGridDesc::GetNumOfDimension());
    static_assert((OutGridDesc::GetNumOfDimension() - 2) <= DstVectorDim &&
                  DstVectorDim < OutGridDesc::GetNumOfDimension());
    static_assert(SrcVectorDim != DstVectorDim);

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    struct Block2TileMap
    {
        static constexpr index_t NumDim = InGridDesc::GetNumOfDimension();
        static_assert(3 <= NumDim);

        static constexpr auto I0 = Number<0>{};

        Block2TileMap()                     = delete;
        Block2TileMap(const Block2TileMap&) = default;
        Block2TileMap(Block2TileMap&&)      = delete;
        ~Block2TileMap()                    = default;
        Block2TileMap& operator=(const Block2TileMap&) = delete;
        Block2TileMap& operator=(Block2TileMap&&) = delete;

        explicit Block2TileMap(const InGridDesc& desc) : desc_(desc) {}

        __host__ constexpr index_t CalculateGridSize(const InGridDesc& desc) const
        {
            const auto N0 = math::integer_divide_ceil(desc.GetLength(Number<NumDim - 3>{}), NPerBlock);
            const auto H0 = math::integer_divide_ceil(desc.GetLength(Number<NumDim - 2>{}), HPerBlock);
            const auto W0 = math::integer_divide_ceil(desc.GetLength(Number<NumDim - 1>{}), WPerBlock);

            const index_t grid_size = N0 * H0 * W0;
            return grid_size;
        }

        template <typename TopIdx>
        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
        {
            static_assert(TopIdx::Size() == 1);

            auto block_1d_id = idx_top[I0];

            const auto N0 = math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 3>{}), NPerBlock);
            const auto H0 = math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 2>{}), HPerBlock);
            const auto W0 = math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 1>{}), WPerBlock);

            block_1d_id = block_1d_id % (N0 * H0 * W0);

            index_t idx_N0 = block_1d_id / (H0 * W0);
            index_t idx_H0 = (block_1d_id % (H0 * W0)) / W0;
            index_t idx_W0 = block_1d_id % W0;

            return make_tuple(idx_N0, idx_H0, idx_W0);
        }

        private:
        const InGridDesc desc_;
    };

    using DefaultBlock2TileMap = Block2TileMap;

    // use an [NPerBlock, HPerBlock, WPerBlock] tensor as element-copy relay
    __host__ __device__ static constexpr auto GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock()
    {
        return make_naive_tensor_descriptor(
            make_tuple(Number<NPerBlock>{}, Number<HPerBlock>{}, Number<WPerBlock>{}),
            make_tuple(Number<HPerBlock * (WPerBlock + InBlockLdsExtraW)>{},
                       Number<WPerBlock + InBlockLdsExtraW>{}, I1));
    }

    // for N-dimension descriptor, reserve its last 2 dimensions, then merge its leading dimensions
    // into single one. finally, form a 3D descriptor: [d(0), d(1), ..., d(N - 2), d(N - 1)] ->
    // [(d(0) x d(1) x ...), d(N - 2), d(N - 1)]
    template <typename GridDesc>
    __host__ __device__ static constexpr auto GetMergedDesc(const GridDesc& desc)
    {
        constexpr index_t NumDim = GridDesc::GetNumOfDimension();
        static_assert(3 <= NumDim);

        const auto merged_desc = transform_tensor_descriptor(
            desc,
            make_tuple(make_merge_transform(generate_tuple(
                           [&](auto I) { return desc.GetLength(I); }, Number<NumDim - 2>{})),
                       make_pass_through_transform(desc.GetLength(Number<NumDim - 2>{})),
                       make_pass_through_transform(desc.GetLength(Number<NumDim - 1>{}))),
            make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
                       Sequence<NumDim - 2>{}, Sequence<NumDim - 1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        return merged_desc;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto in_block_desc_nperblock_hperblock_wperblock =
            GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();

        return in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize() * sizeof(InDataType);
    }

    __host__ __device__ static constexpr auto MakeDefaultBlock2TileMap(const InGridDesc& desc)
    {
        return DefaultBlock2TileMap{desc};
    }

    __host__ __device__ static constexpr bool CheckValidity(const InGridDesc& in_grid_desc,
                                                            const OutGridDesc& out_grid_desc)
    {
        constexpr index_t NumDim = InGridDesc::GetNumOfDimension();

        // check if we only swap last 2 dimensions
        bool valid = true;
        static_for<0, NumDim - 2, 1>{}([&](auto I) {
            if(valid && in_grid_desc.GetLength(I) != out_grid_desc.GetLength(I))
            {
                valid = false;
            }
        });

        return valid &&
               (in_grid_desc.GetLength(Number<NumDim - 1>{}) ==
                out_grid_desc.GetLength(Number<NumDim - 2>{})) &&
               (in_grid_desc.GetLength(Number<NumDim - 2>{}) ==
                out_grid_desc.GetLength(Number<NumDim - 1>{}));
    }

    template <typename Block2TileMap>
    __device__ static void Run(const InGridDesc in_grid_desc, const OutGridDesc out_grid_desc,
                               const InDataType* p_in_global, OutDataType* p_out_global,
                               void* __restrict__ p_shared,
                               const ElementwiseOperation elementwise_op,
                               const Block2TileMap& block_2_tile_map)
    {
        auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_global, in_grid_desc.GetElementSpaceSize());

        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_global, out_grid_desc.GetElementSpaceSize());

        // each workgroup handles an [NPerBlock, HPerBlock, WPerBLock] slice-transpose problem
        const auto block_work_idx =
            block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * NPerBlock);
        const index_t h_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * HPerBlock);
        const index_t w_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I2] * WPerBlock);

        // create [NPerBlock, HPerBlock, WPerBLock] shaped LDS buffer
        constexpr auto in_block_desc_nperblock_hperblock_wperblock =
            GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();

        auto in_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<InDataType*>(p_shared),
            in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize());

        using BlockSliceLengths          = Sequence<NPerBlock, HPerBlock, WPerBlock>;
        using InBlockTransferAccessOrder = Sequence<0, 1, 2>;

        constexpr index_t SrcVectorDimAfterMerge =
            SrcVectorDim - (InGridDesc::GetNumOfDimension() - 3);
        constexpr index_t DstVectorDimAfterMerge = SrcVectorDimAfterMerge;

        using ck::tensor_operation::element_wise::PassThrough;

        // merge input descriptor into [(in_grid_desc.GetLength(0) x in_grid_desc.GetLength(1) x
        // ...), in_grid_desc.GetLength(NumDim - 2), in_grid_desc.GetLength(NumDim - 1)]
        const auto in_grid_desc_n_h_w = GetMergedDesc(in_grid_desc);

        // a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from global memory to LDS
        auto in_global_load = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock, ElementwiseOperation, PassThrough, InMemoryDataOperationEnum::Set,
            BlockSliceLengths, InBlockTransferThreadClusterLengths,
            InBlockTransferThreadClusterArrangeOrder, InDataType, InDataType,
            decltype(in_grid_desc_n_h_w), decltype(in_block_desc_nperblock_hperblock_wperblock),
            InBlockTransferAccessOrder, InBlockTransferAccessOrder, SrcVectorDimAfterMerge, 2,
            SrcScalarPerVector, 1, 1, 1, true, true>(
            in_grid_desc_n_h_w,
            make_multi_index(n_block_data_idx_on_grid, h_block_data_idx_on_grid,
                             w_block_data_idx_on_grid),
            PassThrough{}, in_block_desc_nperblock_hperblock_wperblock, make_multi_index(0, 0, 0),
            PassThrough{});

        // merge output descriptor into [(out_grid_desc.GetLength(0) x out_grid_desc.GetLength(1) x
        // ...), out_grid_desc.GetLength(NumDim - 2), out_grid_desc.GetLength(NumDim - 1)]
        const auto out_grid_desc_n_w_h = GetMergedDesc(out_grid_desc);

        // create transposed view of output tensor
        const auto out_grid_desc_n_h_w = transform_tensor_descriptor(
            out_grid_desc_n_w_h,
            make_tuple(make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I0)),
                       make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I1)),
                       make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I2))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}));

        // a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from LDS to global memory
        auto out_global_store = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock, ElementwiseOperation, PassThrough, InMemoryDataOperationEnum::Set,
            BlockSliceLengths, InBlockTransferThreadClusterLengths,
            InBlockTransferThreadClusterArrangeOrder, InDataType, OutDataType,
            decltype(in_block_desc_nperblock_hperblock_wperblock), decltype(out_grid_desc_n_h_w),
            InBlockTransferAccessOrder, InBlockTransferAccessOrder, 2, DstVectorDimAfterMerge, 1,
            DstScalarPerVector, 1, 1, true, true>(
            in_block_desc_nperblock_hperblock_wperblock, make_multi_index(0, 0, 0), PassThrough{},
            out_grid_desc_n_h_w,
            make_multi_index(n_block_data_idx_on_grid, h_block_data_idx_on_grid,
                             w_block_data_idx_on_grid),
            elementwise_op);

        in_global_load.Run(in_grid_desc_n_h_w, in_global_buf,
                           in_block_desc_nperblock_hperblock_wperblock, in_block_buf, I0);

        out_global_store.Run(in_block_desc_nperblock_hperblock_wperblock, in_block_buf,
                             out_grid_desc_n_h_w, out_global_buf, I0);
    }
};

} // namespace ck
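To make the block-to-tile mapping concrete, here is a standalone sketch of how a 1-D block id decomposes into the (N0, H0, W0) tile coordinate. It uses plain runtime ints instead of CK's compile-time Number/Sequence machinery, and the 5x4x7 tile grid is an assumed example, not anything from the library:

    #include <cstdio>

    int main()
    {
        const int N0 = 5, H0 = 4, W0 = 7; // hypothetical tile counts per dimension

        for(int block_1d_id = 0; block_1d_id < N0 * H0 * W0; ++block_1d_id)
        {
            // same arithmetic as Block2TileMap::CalculateBottomIndex above
            const int id     = block_1d_id % (N0 * H0 * W0);
            const int idx_N0 = id / (H0 * W0);
            const int idx_H0 = (id % (H0 * W0)) / W0;
            const int idx_W0 = id % W0;

            std::printf("block %3d -> tile (%d, %d, %d)\n", block_1d_id, idx_N0, idx_H0, idx_W0);
        }
        return 0;
    }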
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp

@@ -6,6 +6,7 @@
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor/static_tensor.hpp"

 namespace ck {
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"

namespace ck {
namespace tensor_operation {

template <index_t NDimSpatial,
          ck::tensor_operation::device::ConvolutionBackwardDataSpecialization
              ConvBwdDataSpecialization,
          index_t AK1,
          index_t BK1,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          bool DoPadGemmM,
          bool DoPadGemmN>
struct TransformConvBwdDataToGemm_v1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    template <typename ALayout,
              typename std::enable_if<NDimSpatial == 2 &&
                                          is_same_v<ALayout, tensor_layout::convolution::GNHWK>,
                                      bool>::type = false>
    static auto MakeADescriptor_AK0_M_AK1(
        const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
        const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
        const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
        const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
        const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
        const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
        const std::array<index_t, NDimSpatial>& conv_filter_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
        const std::array<index_t, NDimSpatial>& input_left_pads,
        const std::array<index_t, NDimSpatial>& /* input_right_pads */,
        const std::array<index_t, NDimSpatial>& tildes)
    {
        index_t i_ytilde = tildes[0];
        index_t i_xtilde = tildes[1];

        const index_t N  = in_g_n_c_wis_lengths[1];
        const index_t K  = wei_g_k_c_xs_lengths[1];
        const index_t Hi = in_g_n_c_wis_lengths[3];
        const index_t Wi = in_g_n_c_wis_lengths[4];
        const index_t Ho = out_g_n_k_wos_lengths[3];
        const index_t Wo = out_g_n_k_wos_lengths[4];
        const index_t Y  = wei_g_k_c_xs_lengths[3];
        const index_t X  = wei_g_k_c_xs_lengths[4];

        const index_t InLeftPadH    = input_left_pads[0];
        const index_t InLeftPadW    = input_left_pads[1];
        const index_t ConvStrideH   = conv_filter_strides[0];
        const index_t ConvStrideW   = conv_filter_strides[1];
        const index_t ConvDilationH = conv_filter_dilations[0];
        const index_t ConvDilationW = conv_filter_dilations[1];

        const index_t AK0 = K / AK1;

        // assume packed
        const auto out_n_ho_wo_k_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));

        if constexpr(ConvBwdDataSpecialization ==
                     ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
                         Filter1x1Stride1Pad0)
        {
            // A: output tensor
            const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
                make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
                make_tuple(make_pass_through_transform(N * Ho * Wo),
                           make_unmerge_transform(make_tuple(AK0, AK1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<1>{}, Sequence<0, 2>{}));

            const auto out_gemmak0_gemmm_gemmak1_grid_desc =
                ck::tensor_operation::device::PadTensorDescriptor(
                    out_gemmak0_gemmmraw_gemmak1_grid_desc,
                    make_tuple(AK0, GemmMPerBlock, AK1),
                    Sequence<false, DoPadGemmM, false>{});

            return out_gemmak0_gemmm_gemmak1_grid_desc;
        }
        else
        {
            const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
            const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

            const auto YTilde = ConvStrideH / GcdStrideDilationH;
            const auto XTilde = ConvStrideW / GcdStrideDilationW;

            const auto YDot = math::integer_divide_ceil(Y, YTilde);
            const auto XDot = math::integer_divide_ceil(X, XTilde);

            const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
            const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);

            // only work on HTilde and WTilde that contribute to non-padding area of input tensor
            const auto IHTildeSliceBegin = math::integer_divide_floor(
                math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
            const auto IWTildeSliceBegin = math::integer_divide_floor(
                math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);

            const auto IHTildeSliceEnd =
                math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
            const auto IWTildeSliceEnd =
                math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);

            const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
            const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;

            // GemmK is different for each GEMM
            const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
            const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);

            // A: output tensor
            const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
                out_n_ho_wo_k_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Ho, I0, I0),
                           make_pad_transform(Wo, I0, I0),
                           make_pass_through_transform(K)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

            const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
                out_n_hop_wop_k_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(YDot, HTilde),
                                                make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
                           make_embed_transform(make_tuple(XDot, WTilde),
                                                make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
                           make_pass_through_transform(K)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

            const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc =
                transform_tensor_descriptor(
                    out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
                    make_tuple(make_pass_through_transform(N),
                               make_slice_transform(YDot, I0, YDotSlice),
                               make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
                               make_slice_transform(XDot, I0, XDotSlice),
                               make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
                               make_unmerge_transform(make_tuple(AK0, AK1))),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
                               Sequence<4>{}, Sequence<5>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
                               Sequence<4>{}, Sequence<5, 6>{}));

            const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
                out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc,
                make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, AK0)),
                           make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
                           make_pass_through_transform(AK1)),
                make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto out_gemmak0_gemmm_gemmak1_grid_desc =
                ck::tensor_operation::device::PadTensorDescriptor(
                    out_gemmak0_gemmmraw_gemmak1_grid_desc,
                    make_tuple(AK0, GemmMPerBlock, AK1),
                    Sequence<false, DoPadGemmM, false>{});

            return out_gemmak0_gemmm_gemmak1_grid_desc;
        }
    }
    template <typename BLayout,
              typename std::enable_if<NDimSpatial == 2 &&
                                          is_same_v<BLayout, tensor_layout::convolution::GKYXC>,
                                      bool>::type = false>
    static auto MakeBDescriptor_BK0_N_BK1(
        const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
        const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
        const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
        const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
        const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
        const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
        const std::array<index_t, NDimSpatial>& conv_filter_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
        const std::array<index_t, NDimSpatial>& /* input_left_pads */,
        const std::array<index_t, NDimSpatial>& /* input_right_pads */,
        const std::array<index_t, NDimSpatial>& tildes)
    {
        index_t i_ytilde = tildes[0];
        index_t i_xtilde = tildes[1];

        const index_t N  = in_g_n_c_wis_lengths[1];
        const index_t K  = wei_g_k_c_xs_lengths[1];
        const index_t C  = wei_g_k_c_xs_lengths[2];
        const index_t Ho = out_g_n_k_wos_lengths[3];
        const index_t Wo = out_g_n_k_wos_lengths[4];
        const index_t Y  = wei_g_k_c_xs_lengths[3];
        const index_t X  = wei_g_k_c_xs_lengths[4];

        const index_t ConvStrideH   = conv_filter_strides[0];
        const index_t ConvStrideW   = conv_filter_strides[1];
        const index_t ConvDilationH = conv_filter_dilations[0];
        const index_t ConvDilationW = conv_filter_dilations[1];

        const index_t BK0 = K / BK1;

        // assume packed
        const auto wei_k_y_x_c_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));

        if constexpr(ConvBwdDataSpecialization ==
                     ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
                         Filter1x1Stride1Pad0)
        {
            // B: weight tensor
            const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
                make_naive_tensor_descriptor_packed(make_tuple(K, C)),
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1));

            const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
                ck::tensor_operation::device::PadTensorDescriptor(
                    wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
                    make_tuple(BK0, GemmNPerBlock, BK1),
                    Sequence<false, DoPadGemmN, false>{});

            return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
        }
        else
        {
            const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
            const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

            const auto YTilde = ConvStrideH / GcdStrideDilationH;
            const auto XTilde = ConvStrideW / GcdStrideDilationW;

            const auto YDot = math::integer_divide_ceil(Y, YTilde);
            const auto XDot = math::integer_divide_ceil(X, XTilde);

            // GemmK is different for each GEMM
            const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
            const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);

            // B weight tensor
            const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
                wei_k_y_x_c_grid_desc,
                make_tuple(make_pass_through_transform(K),
                           make_embed_transform(make_tuple(YDot, YTilde),
                                                make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
                           make_embed_transform(make_tuple(XDot, XTilde),
                                                make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

            const auto wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
                wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_slice_transform(YDot, I0, YDotSlice),
                           make_slice_transform(XDot, I0, XDotSlice),
                           make_freeze_transform(i_ytilde),
                           make_freeze_transform(i_xtilde),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<2>{},
                           Sequence<4>{}, Sequence<5>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<>{},
                           Sequence<>{}, Sequence<4>{}));

            const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
                wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, BK0)),
                           make_pass_through_transform(C),
                           make_pass_through_transform(BK1)),
                make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
                ck::tensor_operation::device::PadTensorDescriptor(
                    wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
                    make_tuple(wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0),
                               GemmNPerBlock, BK1),
                    Sequence<false, DoPadGemmN, false>{});

            return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
        }
    }
    template <typename CLayout,
              typename std::enable_if<NDimSpatial == 2 &&
                                          (is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
                                           is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
                                           is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>),
                                      bool>::type = false>
    static auto MakeCDescriptor_M_N(
        const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
        const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
        const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
        const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
        const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
        const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
        const std::array<index_t, NDimSpatial>& input_left_pads,
        const std::array<index_t, NDimSpatial>& input_right_pads,
        const std::array<index_t, NDimSpatial>& tildes)
    {
        index_t i_ytilde = tildes[0];
        index_t i_xtilde = tildes[1];

        const index_t N  = in_g_n_c_wis_lengths[1];
        const index_t C  = wei_g_k_c_xs_lengths[2];
        const index_t Hi = in_g_n_c_wis_lengths[3];
        const index_t Wi = in_g_n_c_wis_lengths[4];
        const index_t Ho = out_g_n_k_wos_lengths[3];
        const index_t Wo = out_g_n_k_wos_lengths[4];
        const index_t Y  = wei_g_k_c_xs_lengths[3];
        const index_t X  = wei_g_k_c_xs_lengths[4];

        const index_t InLeftPadH    = input_left_pads[0];
        const index_t InLeftPadW    = input_left_pads[1];
        const index_t InRightPadH   = input_right_pads[0];
        const index_t InRightPadW   = input_right_pads[1];
        const index_t ConvStrideH   = conv_filter_strides[0];
        const index_t ConvStrideW   = conv_filter_strides[1];
        const index_t ConvDilationH = conv_filter_dilations[0];
        const index_t ConvDilationW = conv_filter_dilations[1];

        // assume strided
        const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
            make_tuple(N, Hi, Wi, C),
            make_tuple(in_g_n_c_wis_strides[1], in_g_n_c_wis_strides[3],
                       in_g_n_c_wis_strides[4], in_g_n_c_wis_strides[2]));

        if constexpr(ConvBwdDataSpecialization ==
                     ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
                         Filter1x1Stride1Pad0)
        {
            // C: input tensor
            const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_hi_wi_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
                           make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

            const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
                in_n_y_ho_x_wo_c_grid_desc,
                make_tuple(make_freeze_transform(I0),
                           make_freeze_transform(I0),
                           make_merge_transform(make_tuple(N, Ho, Wo)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
                make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));

            const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
                in_gemmmraw_gemmnraw_grid_desc,
                make_tuple(GemmMPerBlock, GemmNPerBlock),
                Sequence<DoPadGemmM, DoPadGemmN>{});

            return in_gemmm_gemmn_grid_desc;
        }
        else
        {
            const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
            const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

            const auto YTilde = ConvStrideH / GcdStrideDilationH;
            const auto XTilde = ConvStrideW / GcdStrideDilationW;

            const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
            const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);

            // only work on HTilde and WTilde that contribute to non-padding area of input tensor
            const auto IHTildeSliceBegin = math::integer_divide_floor(
                math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
            const auto IWTildeSliceBegin = math::integer_divide_floor(
                math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);

            const auto IHTildeSliceEnd =
                math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
            const auto IWTildeSliceEnd =
                math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);

            const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
            const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;

            // C: input tensor
            const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
                in_n_hi_wi_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Hi, InLeftPadH, InRightPadH),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

            const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(YTilde, HTilde),
                                                make_tuple(ConvDilationH, ConvStrideH)),
                           make_embed_transform(make_tuple(XTilde, WTilde),
                                                make_tuple(ConvDilationW, ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

            const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
                in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_freeze_transform(i_ytilde),
                           make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
                           make_freeze_transform(i_xtilde),
                           make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
                           Sequence<4>{}, Sequence<5>{}),
                make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<>{},
                           Sequence<2>{}, Sequence<3>{}));

            const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
                in_n_htildeslice_wtildeslice_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
                in_gemmmraw_gemmnraw_grid_desc,
                make_tuple(GemmMPerBlock, GemmNPerBlock),
                Sequence<DoPadGemmM, DoPadGemmN>{});

            return in_gemmm_gemmn_grid_desc;
        }
    }
    // for input bias
    template <typename CLayout,
              typename std::enable_if<NDimSpatial == 2 &&
                                          (is_same_v<CLayout, tensor_layout::convolution::GC> ||
                                           is_same_v<CLayout, tensor_layout::convolution::G_C>),
                                      bool>::type = false>
    static auto MakeCDescriptor_M_N(
        const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
        const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
        const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
        const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
        const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
        const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
        const std::array<index_t, NDimSpatial>& conv_filter_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
        const std::array<index_t, NDimSpatial>& input_left_pads,
        const std::array<index_t, NDimSpatial>& /* input_right_pads */,
        const std::array<index_t, NDimSpatial>& /* tildes */)
    {
        const index_t N  = in_g_n_c_wis_lengths[1];
        const index_t C  = wei_g_k_c_xs_lengths[2];
        const index_t Hi = in_g_n_c_wis_lengths[3];
        const index_t Wi = in_g_n_c_wis_lengths[4];
        const index_t Ho = out_g_n_k_wos_lengths[3];
        const index_t Wo = out_g_n_k_wos_lengths[4];
        const index_t Y  = wei_g_k_c_xs_lengths[3];
        const index_t X  = wei_g_k_c_xs_lengths[4];

        const index_t InLeftPadH    = input_left_pads[0];
        const index_t InLeftPadW    = input_left_pads[1];
        const index_t ConvStrideH   = conv_filter_strides[0];
        const index_t ConvStrideW   = conv_filter_strides[1];
        const index_t ConvDilationH = conv_filter_dilations[0];
        const index_t ConvDilationW = conv_filter_dilations[1];

        if constexpr(ConvBwdDataSpecialization ==
                     ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
                         Filter1x1Stride1Pad0)
        {
            const auto in_gemmm_gemmn_grid_desc =
                make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1));

            return in_gemmm_gemmn_grid_desc;
        }
        else
        {
            const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
            const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

            const auto YTilde = ConvStrideH / GcdStrideDilationH;
            const auto XTilde = ConvStrideW / GcdStrideDilationW;

            const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
            const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);

            // only work on HTilde and WTilde that contribute to non-padding area of input tensor
            const auto IHTildeSliceBegin = math::integer_divide_floor(
                math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
            const auto IWTildeSliceBegin = math::integer_divide_floor(
                math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);

            const auto IHTildeSliceEnd =
                math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
            const auto IWTildeSliceEnd =
                math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);

            const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
            const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;

            // bias tensor
            const auto in_gemmmraw_gemmnraw_grid_desc = make_naive_tensor_descriptor(
                make_tuple(N * HTildeSlice * WTildeSlice, C), make_tuple(I0, I1));

            const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
                in_gemmmraw_gemmnraw_grid_desc,
                make_tuple(GemmMPerBlock, GemmNPerBlock),
                Sequence<DoPadGemmM, DoPadGemmN>{});

            return in_gemmm_gemmn_grid_desc;
        }
    }
};

} // namespace tensor_operation
} // namespace ck
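The YTilde/XTilde arithmetic above decides how many sub-GEMMs a backward-data pass is split into and how large the padded output tile becomes. A worked example with assumed parameters (3x3 filter, stride 2, dilation 1, 32-wide output) is shown below; it is a plain-C++ sketch of the same integer math, not part of CK:

    #include <cstdio>
    #include <numeric>

    int main()
    {
        const int Y = 3, Ho = 32, ConvStrideH = 2, ConvDilationH = 1; // assumed example values

        const int GcdStrideDilationH = std::gcd(ConvStrideH, ConvDilationH);          // 1
        const int YTilde = ConvStrideH / GcdStrideDilationH;                          // 2 sub-GEMMs along H
        const int YDot   = (Y + YTilde - 1) / YTilde;                                 // ceil(3 / 2) = 2
        const int HTilde = Ho + (ConvDilationH * (Y - 1) + ConvStrideH - 1) / ConvStrideH; // 32 + 1 = 33

        std::printf("YTilde=%d YDot=%d HTilde=%d\n", YTilde, YDot, HTilde);
        return 0;
    }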
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp

@@ -16,6 +16,7 @@ namespace tensor_operation {
 template <index_t NDimSpatial, device::ConvolutionForwardSpecialization ConvForwardSpecialization>
 struct TransformConvFwdToGemm
 {
+    static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

     template <typename ALayout,
@@ -864,6 +865,29 @@ struct TransformConvFwdToGemm
         return out_gemmm_gemmn_desc;
     }

+    // for output bias
+    template <typename CLayout,
+              typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GK> ||
+                                          is_same_v<CLayout, tensor_layout::convolution::G_K>,
+                                      bool>::type = false>
+    static auto
+    MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
+                        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
+    {
+        const index_t N = c_g_n_k_wos_lengths[1];
+        const index_t K = c_g_n_k_wos_lengths[2];
+
+        const index_t NHoWo =
+            N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
+                                c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
+                                index_t{1},
+                                std::multiplies<index_t>());
+
+        const auto out_gemmm_gemmn_desc =
+            make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1));
+
+        return out_gemmm_gemmn_desc;
+    }
 };

 } // namespace tensor_operation
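The bias descriptors in both transforms rely on a zero M-stride to broadcast one bias vector across every GEMM row. A minimal sketch of that addressing trick in plain C++ (names invented for illustration, not CK code):

    #include <cassert>

    // With strides (0, 1) over a logical (M, K) view, the flattened offset of
    // element (m, k) ignores m entirely, so every row reads the same K values.
    int bias_offset(int m, int k)
    {
        const int stride_m = 0, stride_k = 1;
        return m * stride_m + k * stride_k; // independent of m
    }

    int main()
    {
        assert(bias_offset(0, 5) == bias_offset(123, 5)); // same element for every row
        return 0;
    }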
include/ck/utility/ignore.hpp

 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

-#ifndef CK_IGNORE_HPP
-#define CK_IGNORE_HPP
+#pragma once

 // https://en.cppreference.com/w/cpp/utility/tuple/ignore
@@ -21,4 +20,3 @@ struct ignore_t
 inline constexpr detail::ignore_t ignore;

 } // namespace ck
-#endif
include/ck/utility/span.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstddef>
#include <array>
#include <type_traits>

namespace ck {

template <typename T>
class span
{
    public:
    using element_type    = T;
    using value_type      = std::remove_cv_t<element_type>;
    using size_type       = std::size_t;
    using difference_type = std::ptrdiff_t;
    using pointer         = element_type*;
    using const_pointer   = const element_type*;
    using reference       = element_type&;
    using const_reference = const element_type&;
    using iterator        = pointer;
    using const_iterator  = pointer;

    constexpr span() : span(nullptr, size_type{0}) {}

    constexpr span(pointer first, size_type count) : ptr_(first), size_(count) {}

    constexpr span(pointer first, pointer last) : span(first, last - first) {}

    template <std::size_t N>
    constexpr span(element_type (&arr)[N]) noexcept : span(arr, N)
    {
    }

    template <std::size_t N>
    constexpr span(std::array<value_type, N>& arr) noexcept : span(arr.data(), N)
    {
    }

    template <typename Container>
    constexpr span(const Container& container) : span(container.data(), container.size())
    {
    }

    constexpr iterator begin() const noexcept { return ptr_; }
    constexpr const_iterator cbegin() const noexcept { return begin(); }

    constexpr iterator end() const noexcept { return begin() + size(); }
    constexpr const_iterator cend() const noexcept { return end(); }

    constexpr reference front() const { return *begin(); }
    constexpr reference back() const { return *(--end()); }

    constexpr reference operator[](size_type idx) const { return *(begin() + idx); }
    constexpr pointer data() const noexcept { return ptr_; }

    constexpr size_type size() const noexcept { return size_; }

    private:
    pointer ptr_;
    size_type size_;
};

} // namespace ck
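A hypothetical usage sketch (function and values invented for illustration): a std::vector binds to ck::span<const float> through the generic Container constructor, which only needs data() and size(), so callers can pass vectors or raw buffers without copying.

    #include <vector>
    #include "ck/utility/span.hpp"

    // Sums every element of a read-only view; the view does not own the data.
    float sum_all(ck::span<const float> values)
    {
        float acc = 0.f;
        for(float v : values)
            acc += v;
        return acc;
    }

    // Example call site:
    //   std::vector<float> data{1.f, 2.f, 3.f};
    //   float total = sum_all(data); // total == 6.f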
include/ck/utility/transpose_vectors.hpp

@@ -34,17 +34,15 @@ __device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t
     y0 = vy0.template AsType<half2_t>()[I0];
     y1 = vy1.template AsType<half2_t>()[I0];
 #else
-    asm volatile("\n \
-            v_pack_b32_f16 %0, %1, %2 \n \
-            "
-                 : "=v"(y0)
-                 : "v"(x0), "v"(x1));
-
-    asm volatile("\n \
-            v_pack_b32_f16 %0, %1, %2, op_sel:[1, 1] \n \
-            "
-                 : "=v"(y1)
-                 : "v"(x0), "v"(x1));
+    constexpr int32_t m0 = 0x05040100;
+    constexpr int32_t m1 = 0x07060302;
+
+    // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
+    //                   -- -- -- --     -- -- -- --      -  -  -  -
+    //             index  7  6  5  4      3  2  1  0     33 77 44 88
+    // index is reversed because of little endianness (least significant bits first)
+    y0 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
+    y1 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
 #endif
 }

@@ -106,16 +104,14 @@ __device__ void transpose_int8_4x4(const int8x4_t& x0,
     // -- -- -- -- -- -- -- -- - - - -
     // index 7 6 5 4 3 2 1 0 33 77 44 88
     // index is reversed because of little endianness (least significant bits first)
-    // clang-format off
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast<int32_t>(x1)), "v"(bit_cast<int32_t>(x0)), "s"(m0));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast<int32_t>(x3)), "v"(bit_cast<int32_t>(x2)), "s"(m0));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z0) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m1));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z1) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m2));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast<int32_t>(x1)), "v"(bit_cast<int32_t>(x0)), "s"(m3));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast<int32_t>(x3)), "v"(bit_cast<int32_t>(x2)), "s"(m3));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z2) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m1));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z3) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m2));
-    // clang-format on
+    t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
+    t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
+    z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
+    z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
+    t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
+    t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
+    z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
+    z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);

     y0 = bit_cast<int8x4_t>(z0);
     y1 = bit_cast<int8x4_t>(z1);
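For readers unfamiliar with the byte-select masks, the host-side sketch below emulates the selection, assuming the selector semantics described in the comment above (selector byte i picks byte sel[i] from the concatenated {hi, lo} doubleword, least-significant byte first). It reproduces the 2x2 fp16 transpose performed by the masks 0x05040100 and 0x07060302; it is illustrative code, not something shipped with the library:

    #include <cstdint>
    #include <cstdio>

    static uint32_t perm_b32(uint32_t hi, uint32_t lo, uint32_t sel)
    {
        const uint64_t combined = (static_cast<uint64_t>(hi) << 32) | lo;
        uint32_t out = 0;
        for(int i = 0; i < 4; ++i)
        {
            const uint32_t byte_idx = (sel >> (8 * i)) & 0xff;
            out |= static_cast<uint32_t>((combined >> (8 * byte_idx)) & 0xff) << (8 * i);
        }
        return out;
    }

    int main()
    {
        // x0 packs lanes {a0, a1}, x1 packs {b0, b1}; the transpose should give
        // y0 = {a0, b0} and y1 = {a1, b1}.
        const uint32_t x0 = 0x22221111u, x1 = 0x44443333u;
        std::printf("y0 = 0x%08x\n", perm_b32(x1, x0, 0x05040100u)); // 0x33331111
        std::printf("y1 = 0x%08x\n", perm_b32(x1, x0, 0x07060302u)); // 0x44442222
        return 0;
    }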
library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>

#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename AccDataType,
          typename AccElementwiseOperation>
struct ReferenceGroupnorm : public device::BaseOperator
{
    // x = [N, H, W, G, C]
    // y = [N, H, W, G, C]
    // reduce dim [H, W, C], mean, var = [N, G]
    // gamma, beta = [G, C]
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<XDataType>& x,
                 const Tensor<GammaDataType>& gamma,
                 const Tensor<BetaDataType>& beta,
                 Tensor<YDataType>& y,
                 AccElementwiseOperation acc_elementwise_op,
                 const std::vector<index_t> lengths,
                 AccDataType epsilon)
            : x_(x),
              gamma_(gamma),
              beta_(beta),
              y_(y),
              acc_elementwise_op_(acc_elementwise_op),
              lengths_(lengths),
              epsilon_(epsilon)
        {
        }

        const Tensor<XDataType> x_;
        const Tensor<XDataType> gamma_;
        const Tensor<XDataType> beta_;
        Tensor<YDataType>& y_;
        AccElementwiseOperation acc_elementwise_op_;
        std::vector<index_t> lengths_;
        AccDataType epsilon_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        float Run(const Argument& arg)
        {
            int N = arg.lengths_[0];
            int H = arg.lengths_[1];
            int W = arg.lengths_[2];
            int G = arg.lengths_[3];
            int C = arg.lengths_[4];

            Tensor<AccDataType> mean({N, G});
            Tensor<AccDataType> var({N, G});

            // Compute mean & var in [H, W, C] by Welford Algorithm
            // TODO - parallel for each HWC
            // TODO - address calculation
            for(int n = 0; n < N; ++n)
            {
                for(int g = 0; g < G; ++g)
                {
                    AccDataType mean_val = type_convert<AccDataType>(0.0f);
                    AccDataType var_val  = type_convert<AccDataType>(0.0f);
                    int32_t curr_count   = 0;

                    for(int h = 0; h < H; ++h)
                    {
                        for(int w = 0; w < W; ++w)
                        {
                            for(int c = 0; c < C; ++c)
                            {
                                curr_count++;
                                AccDataType x = type_convert<AccDataType>(arg.x_(n, h, w, g, c));

                                AccDataType delta = x - mean_val;
                                mean_val += delta / curr_count;
                                AccDataType delta2 = x - mean_val;
                                var_val += delta * delta2;
                            }
                        }
                    }

                    mean(n, g) = mean_val;
                    var(n, g)  = var_val / curr_count;
                }
            }

            // Normalization
            for(int n = 0; n < N; ++n)
            {
                for(int h = 0; h < H; ++h)
                {
                    for(int w = 0; w < W; ++w)
                    {
                        for(int g = 0; g < G; ++g)
                        {
                            for(int c = 0; c < C; ++c)
                            {
                                AccDataType x     = type_convert<AccDataType>(arg.x_(n, h, w, g, c));
                                AccDataType gamma = type_convert<AccDataType>(arg.gamma_(g, c));
                                AccDataType beta  = type_convert<AccDataType>(arg.beta_(g, c));

                                AccDataType mean_val = type_convert<AccDataType>(mean(n, g));
                                AccDataType var_val  = type_convert<AccDataType>(var(n, g));

                                AccDataType y =
                                    gamma * (x - mean_val) /
                                        ck::math::sqrt(arg.epsilon_ + var_val) +
                                    beta;

                                arg.acc_elementwise_op_(y, y);
                                arg.y_(n, h, w, g, c) = type_convert<YDataType>(y);
                            }
                        }
                    }
                }
            }

            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument* p_arg) override
    {
        const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);

        if(p_arg_->lengths_.size() != 5)
            return false;

        return true;
    }

    static auto MakeArgument(const Tensor<XDataType>& x,
                             const Tensor<GammaDataType>& gamma,
                             const Tensor<BetaDataType>& beta,
                             Tensor<YDataType>& y,
                             AccElementwiseOperation acc_elementwise_op,
                             const std::vector<index_t> lengths,
                             AccDataType epsilon)
    {
        return Argument{x, gamma, beta, y, acc_elementwise_op, lengths, epsilon};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();
        // clang-format off
        str << "ReferenceLayernorm" << std::endl;
        // clang-format on
        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
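The mean/variance loop in the reference is the textbook single-pass Welford recurrence. Isolated as a small self-contained helper (plain C++ with an invented struct name, not the CK reference class):

    #include <cstddef>

    // Each new sample x updates the running mean and the sum of squared
    // deviations m2; the (population) variance is m2 / count at the end.
    struct Welford
    {
        double mean       = 0.0;
        double m2         = 0.0;
        std::size_t count = 0;

        void push(double x)
        {
            ++count;
            const double delta = x - mean;
            mean += delta / static_cast<double>(count);
            const double delta2 = x - mean;
            m2 += delta * delta2;
        }

        double variance() const { return count == 0 ? 0.0 : m2 / static_cast<double>(count); }
    };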
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
using CPermuteNumDims_G_M_O = S<2, 1, 1>;

void add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough>>>& instances);

template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
          typename CPermuteNumDims_G_M_Gemm1N,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<
        ALayout, B0Layout, B1Layout, CPermuteNumDims_G_M_Gemm1N,
        ADataType, B0DataType, B1DataType, CDataType,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough>>
{
    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<
        ALayout, B0Layout, B1Layout, CPermuteNumDims_G_M_Gemm1N,
        ADataType, B0DataType, B1DataType, CDataType,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
                     is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
                         is_same_v<B1Layout, Row> &&
                         is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>)
            {
                add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
                    op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
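The factory pattern above resolves everything at compile time: only the configuration branches that match a pre-built instance set contribute registration calls, and any other combination yields an empty vector. A self-contained sketch of that dispatch pattern, with invented names standing in for the CK types, is:

    #include <memory>
    #include <string>
    #include <type_traits>
    #include <vector>

    struct half_tag {}; // stand-in for the fp16 element type

    struct OpBase
    {
        virtual ~OpBase()                = default;
        virtual std::string Name() const = 0;
    };

    struct Fp16RowColRowOp : OpBase
    {
        std::string Name() const override { return "fp16_gmk_gnk_gno_gmo"; }
    };

    // Only the matching compile-time configuration registers an instance.
    template <typename DataType, bool IsRowColRow>
    std::vector<std::unique_ptr<OpBase>> GetInstances()
    {
        std::vector<std::unique_ptr<OpBase>> ops;

        if constexpr(std::is_same_v<DataType, half_tag> && IsRowColRow)
        {
            ops.push_back(std::make_unique<Fp16RowColRowOp>());
        }

        return ops; // empty for unsupported combinations
    }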
library/include/ck/library/tensor_operation_instance/gpu/layernorm.hpp

@@ -17,17 +17,25 @@ namespace tensor_operation {
 namespace device {
 namespace instance {

-void add_device_layernorm_f16_rank2_instances(
-    std::vector<DeviceLayernormPtr<F16, F16, F16, F32, F16, PassThrough, 2, 1>>&);
+// FP16
+void add_device_layernorm_rank_2_1_f16_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F16, F16, F16, F32, F16, PassThrough, 2, 1>>>&);

-void add_device_layernorm_f16_rank4_instances(
-    std::vector<DeviceLayernormPtr<F16, F16, F16, F32, F16, PassThrough, 4, 3>>&);
+void add_device_layernorm_rank_4_3_f16_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F16, F16, F16, F32, F16, PassThrough, 4, 3>>>&);

-void add_device_layernorm_f32_rank2_instances(
-    std::vector<DeviceLayernormPtr<F32, F32, F32, F32, F32, PassThrough, 2, 1>>&);
+void add_device_layernorm_rank_5_3_f16_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F16, F16, F16, F32, F16, PassThrough, 5, 3>>>&);

-void add_device_layernorm_f32_rank4_instances(
-    std::vector<DeviceLayernormPtr<F32, F32, F32, F32, F32, PassThrough, 4, 3>>&);
+// FP32
+void add_device_layernorm_rank_2_1_f32_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F32, F32, F32, F32, F32, PassThrough, 2, 1>>>&);
+
+void add_device_layernorm_rank_4_3_f32_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
+
+void add_device_layernorm_rank_5_3_f32_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F32, F32, F32, F32, F32, PassThrough, 5, 3>>>&);

 template <typename XDataType,
           typename GammaDataType,
@@ -62,17 +70,33 @@ struct DeviceOperationInstanceFactory<
                      is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
         {
             if constexpr(Rank == 2 && NumReduceDim == 1)
-                add_device_layernorm_f16_rank2_instances(op_ptrs);
+            {
+                add_device_layernorm_rank_2_1_f16_instances(op_ptrs);
+            }
             else if constexpr(Rank == 4 && NumReduceDim == 3)
-                add_device_layernorm_f16_rank4_instances(op_ptrs);
+            {
+                add_device_layernorm_rank_4_3_f16_instances(op_ptrs);
+            }
+            else if constexpr(Rank == 5 && NumReduceDim == 3)
+            {
+                add_device_layernorm_rank_5_3_f16_instances(op_ptrs);
+            }
         }
         else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
                           is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
         {
             if constexpr(Rank == 2 && NumReduceDim == 1)
-                add_device_layernorm_f32_rank2_instances(op_ptrs);
+            {
+                add_device_layernorm_rank_2_1_f32_instances(op_ptrs);
+            }
             else if constexpr(Rank == 4 && NumReduceDim == 3)
-                add_device_layernorm_f32_rank4_instances(op_ptrs);
+            {
+                add_device_layernorm_rank_4_3_f32_instances(op_ptrs);
+            }
+            else if constexpr(Rank == 5 && NumReduceDim == 3)
+            {
+                add_device_layernorm_rank_5_3_f32_instances(op_ptrs);
+            }
         }

         return op_ptrs;
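The new naming encodes (Rank, NumReduceDim) directly, and the factory selects an instance set from that pair. A tiny illustrative mapping (plain C++, not CK code; the shape examples are assumptions about typical use) is:

    #include <cstdio>

    constexpr const char* pick_instance_set(int rank, int num_reduce_dim)
    {
        if(rank == 2 && num_reduce_dim == 1) return "rank_2_1";
        if(rank == 4 && num_reduce_dim == 3) return "rank_4_3";
        if(rank == 5 && num_reduce_dim == 3) return "rank_5_3";
        return "unsupported";
    }

    int main()
    {
        std::printf("[N, C] reduced over C            -> %s\n", pick_instance_set(2, 1));
        std::printf("[N, H, W, C] reduced over H,W,C   -> %s\n", pick_instance_set(4, 3));
        std::printf("[N, H, W, G, C] reduced over H,W,C -> %s\n", pick_instance_set(5, 3));
        return 0;
    }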
library/include/ck/library/utility/check_err.hpp

@@ -15,6 +15,7 @@
 #include "ck/ck.hpp"
 #include "ck/utility/data_type.hpp"
+#include "ck/utility/span.hpp"
 #include "ck/utility/type.hpp"
 #include "ck/host_utility/io.hpp"

@@ -32,7 +33,7 @@ check_err(const std::vector<T>& out,
 {
     if(out.size() != ref.size())
     {
-        std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                   << std::endl;
         return false;
     }
@@ -50,7 +51,7 @@ check_err(const std::vector<T>& out,
             err_count++;
             if(err_count < 5)
             {
-                std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
+                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                           << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] << std::endl;
             }
             res = false;
@@ -58,7 +59,7 @@ check_err(const std::vector<T>& out,
     }
     if(!res)
     {
-        std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
     }
     return res;
 }
@@ -73,7 +74,7 @@ check_err(const std::vector<T>& out,
 {
     if(out.size() != ref.size())
     {
-        std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                   << std::endl;
         return false;
     }
@@ -94,7 +95,7 @@ check_err(const std::vector<T>& out,
             err_count++;
             if(err_count < 5)
             {
-                std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
+                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                           << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
             }
             res = false;
@@ -102,22 +103,22 @@ check_err(const std::vector<T>& out,
     }
     if(!res)
     {
-        std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
     }
     return res;
 }

 template <typename T>
-typename std::enable_if<std::is_same<T, half_t>::value, bool>::type
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
+typename std::enable_if<std::is_same_v<T, half_t>, bool>::type
+check_err(span<const T> out,
+          span<const T> ref,
           const std::string& msg = "Error: Incorrect results!",
           double rtol            = 1e-3,
           double atol            = 1e-3)
 {
     if(out.size() != ref.size())
     {
-        std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                   << std::endl;
         return false;
     }
...
@@ -137,7 +138,7 @@ check_err(const std::vector<T>& out,
...
@@ -137,7 +138,7 @@ check_err(const std::vector<T>& out,
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
5
)
{
{
std
::
c
out
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
std
::
c
err
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
}
}
res
=
false
;
res
=
false
;
...
@@ -145,11 +146,22 @@ check_err(const std::vector<T>& out,
...
@@ -145,11 +146,22 @@ check_err(const std::vector<T>& out,
}
}
if
(
!
res
)
if
(
!
res
)
{
{
std
::
c
out
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
std
::
c
err
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
}
}
return
res
;
return
res
;
}
}
template
<
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
T
,
half_t
>::
value
,
bool
>::
type
check_err
(
const
std
::
vector
<
T
>&
out
,
const
std
::
vector
<
T
>&
ref
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
double
rtol
=
1e-3
,
double
atol
=
1e-3
)
{
return
check_err
(
span
<
const
T
>
{
out
},
span
<
const
T
>
{
ref
},
msg
,
rtol
,
atol
);
}
template
<
typename
T
>
template
<
typename
T
>
std
::
enable_if_t
<
(
std
::
is_integral_v
<
T
>
&&
!
std
::
is_same_v
<
T
,
bhalf_t
>
)
std
::
enable_if_t
<
(
std
::
is_integral_v
<
T
>
&&
!
std
::
is_same_v
<
T
,
bhalf_t
>
)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
...
@@ -194,7 +206,7 @@ check_err(const std::vector<T>& out,
...
@@ -194,7 +206,7 @@ check_err(const std::vector<T>& out,
}
}
if
(
!
res
)
if
(
!
res
)
{
{
std
::
c
out
<<
"max err: "
<<
max_err
<<
std
::
endl
;
std
::
c
err
<<
"max err: "
<<
max_err
<<
std
::
endl
;
}
}
return
res
;
return
res
;
}
}
...
...
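For context, a minimal usage sketch of the new span-based overload (not part of this commit); it assumes check_err lives in the ck::utils namespace and that ck::span is constructible from a std::vector, as the forwarding overload above implies:

#include <vector>

#include "ck/library/utility/check_err.hpp"

bool compare_fp16(const std::vector<ck::half_t>& out, const std::vector<ck::half_t>& ref)
{
    // The std::vector overload now just forwards to the span-based implementation,
    // so both calls below run the same comparison loop and print mismatches to std::cerr.
    const bool ok_vec  = ck::utils::check_err(out, ref);
    const bool ok_span = ck::utils::check_err(ck::span<const ck::half_t>{out},
                                              ck::span<const ck::half_t>{ref});
    return ok_vec && ok_span;
}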
library/include/ck/library/utility/fill.hpp
View file @ 7e493730
...
@@ -5,7 +5,10 @@
 #include <algorithm>
 #include <cmath>
+#include <iterator>
 #include <random>
+#include <type_traits>
+#include <utility>

 #include "ck/utility/data_type.hpp"
...
@@ -25,6 +28,15 @@ struct FillUniformDistribution
         std::uniform_real_distribution<float> dis(a_, b_);
         std::generate(first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); });
     }
+
+    template <typename ForwardRange>
+    auto operator()(ForwardRange&& range) -> std::void_t<decltype(
+        std::declval<FillUniformDistribution>()(std::begin(std::forward<ForwardRange>(range)),
+                                                std::end(std::forward<ForwardRange>(range))))>
+    {
+        (*this)(std::begin(std::forward<ForwardRange>(range)),
+                std::end(std::forward<ForwardRange>(range)));
+    }
 };

 // Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as below.
...
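A short sketch of the new range overload in use (illustrative only; the ck::utils namespace, the braced initialization of the functor, and the tensor shape are assumptions):

#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"

void fill_example()
{
    Tensor<ck::half_t> a({64, 128});

    // Existing iterator-pair form.
    ck::utils::FillUniformDistribution<ck::half_t>{-1.f, 1.f}(a.begin(), a.end());

    // New range form added above; it simply forwards std::begin/std::end of the
    // container to the iterator-pair overload.
    ck::utils::FillUniformDistribution<ck::half_t>{-1.f, 1.f}(a);
}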
library/include/ck/library/utility/host_tensor.hpp
View file @ 7e493730
...
@@ -3,15 +3,16 @@
 #pragma once

-#include <thread>
-#include <vector>
-#include <numeric>
 #include <algorithm>
-#include <utility>
 #include <cassert>
 #include <iostream>
+#include <numeric>
+#include <thread>
+#include <utility>
+#include <vector>

 #include "ck/utility/data_type.hpp"
+#include "ck/utility/span.hpp"

 template <typename Range>
 std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
...
@@ -235,6 +236,9 @@ auto make_ParallelTensorFunctor(F f, Xs... xs)
 template <typename T>
 struct Tensor
 {
+    using Descriptor = HostTensorDescriptor;
+    using Data       = std::vector<T>;
+
     template <typename X>
     Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
     {
...
@@ -251,7 +255,7 @@ struct Tensor
     {
     }

-    Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
+    Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}

     template <typename OutT>
     Tensor<OutT> CopyAsType() const
...
@@ -278,9 +282,9 @@ struct Tensor
     {
     }

-    const std::vector<std::size_t>& GetLengths() const { return mDesc.GetLengths(); }
-    const std::vector<std::size_t>& GetStrides() const { return mDesc.GetStrides(); }
+    decltype(auto) GetLengths() const { return mDesc.GetLengths(); }
+    decltype(auto) GetStrides() const { return mDesc.GetStrides(); }

     std::size_t GetNumOfDimension() const { return mDesc.GetNumOfDimension(); }
...
@@ -288,6 +292,8 @@ struct Tensor
     std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }

+    std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
+
     void SetZero()
     {
         for(auto& v : mData)
...
@@ -425,14 +431,40 @@ struct Tensor
         return mData[mDesc.GetOffsetFromMultiIndex(idx)];
     }

-    typename std::vector<T>::iterator begin() { return mData.begin(); }
-    typename std::vector<T>::iterator end() { return mData.end(); }
-    typename std::vector<T>::const_iterator begin() const { return mData.begin(); }
-    typename std::vector<T>::const_iterator end() const { return mData.end(); }
+    typename Data::iterator begin() { return mData.begin(); }
+    typename Data::iterator end() { return mData.end(); }
+    typename Data::pointer data() { return mData.data(); }
+
+    typename Data::const_iterator begin() const { return mData.begin(); }
+    typename Data::const_iterator end() const { return mData.end(); }
+    typename Data::const_pointer data() const { return mData.data(); }
+
+    typename Data::size_type size() const { return mData.size(); }
+
+    template <typename U = T>
+    auto AsSpan() const
+    {
+        constexpr std::size_t FromSize = sizeof(T);
+        constexpr std::size_t ToSize   = sizeof(U);
+
+        using Element = std::add_const_t<std::remove_reference_t<U>>;
+        return ck::span<Element>{reinterpret_cast<Element*>(data()), size() * FromSize / ToSize};
+    }
+
+    template <typename U = T>
+    auto AsSpan()
+    {
+        constexpr std::size_t FromSize = sizeof(T);
+        constexpr std::size_t ToSize   = sizeof(U);

+        using Element = std::remove_reference_t<U>;
+        return ck::span<Element>{reinterpret_cast<Element*>(data()), size() * FromSize / ToSize};
+    }
+
-    HostTensorDescriptor mDesc;
-    std::vector<T> mData;
+    Descriptor mDesc;
+    Data mData;
 };
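A sketch tying the new Tensor helpers to the span-based check_err above (the ck::utils namespace is assumed; tensor contents are illustrative):

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/host_tensor.hpp"

bool verify(const Tensor<ck::half_t>& out, const Tensor<ck::half_t>& ref)
{
    // AsSpan() views mData in place as a ck::span, so no temporary std::vector is
    // needed to call the span-based overload; GetElementSpaceSizeInBytes() is the
    // matching convenience for sizing device-to-host copies.
    return ck::utils::check_err(out.AsSpan<const ck::half_t>(),
                                ref.AsSpan<const ck::half_t>(),
                                "Error: Incorrect results!");
}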