Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
4fec5ad3
Commit
4fec5ad3
authored
Oct 28, 2022
by
aska-0096
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/composable_kernel
into wmma_op
parents
24faa1fc
87fd1152
Changes
282
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2364 additions
and
368 deletions
+2364
-368
include/ck/tensor_operation/gpu/device/masking_specialization.hpp
...ck/tensor_operation/gpu/device/masking_specialization.hpp
+82
-0
include/ck/tensor_operation/gpu/device/welford_helper.hpp
include/ck/tensor_operation/gpu/device/welford_helper.hpp
+89
-0
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp
...orm_multiblock/gridwise_multiblock_welford_first_half.hpp
+258
-0
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp
...ultiblock_welford_second_half_batchnorm_forward_final.hpp
+570
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+58
-103
include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp
...gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp
+482
-0
include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp
...k/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp
+1
-0
include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp
...ude/ck/tensor_operation/gpu/thread/threadwise_welford.hpp
+59
-0
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+12
-1
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
...tion/operator_transform/transform_contraction_to_gemm.hpp
+288
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp
...nsor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp
+102
-71
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp
...tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp
+52
-39
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+6
-1
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp
...nsor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp
+34
-6
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
...ration_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
+129
-0
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp
..._operation_instance/gpu/reduce/device_reduce_instance.hpp
+74
-21
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp
..._instance/gpu/reduce/device_reduce_instance_blockwise.hpp
+10
-67
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp
...u/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp
+0
-59
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp
...duce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp
+27
-0
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp
...uce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp
+31
-0
No files found.
include/ck/tensor_operation/gpu/device/masking_specialization.hpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
enum
struct
MaskingSpecialization
{
MaskDisabled
,
MaskOutUpperTriangle
};
inline
std
::
string
getMaskingSpecializationString
(
const
MaskingSpecialization
&
s
)
{
switch
(
s
)
{
case
MaskingSpecialization
::
MaskDisabled
:
return
"MaskDisabled"
;
case
MaskingSpecialization
::
MaskOutUpperTriangle
:
return
"MaskOutUpperTriangle"
;
default:
return
"Unrecognized specialization!"
;
}
}
struct
MaskDisabledPredicate
{
__host__
__device__
constexpr
bool
operator
()(
index_t
/*m*/
,
index_t
/*n*/
)
const
{
return
false
;
};
__host__
__device__
constexpr
bool
IsTileSkippable
(
index_t
/*m*/
,
index_t
/*n*/
,
index_t
/*m_tile*/
,
index_t
/*n_tile*/
)
const
{
return
false
;
}
};
struct
MaskOutUpperTrianglePredicate
{
__host__
__device__
constexpr
bool
operator
()(
index_t
m
,
index_t
n
)
const
{
return
n
>
m
;
}
__host__
__device__
constexpr
bool
IsTileSkippable
(
index_t
m
,
index_t
n
,
index_t
m_tile
,
index_t
/*n_tile*/
)
const
{
return
operator
()(
m
+
m_tile
-
1
,
n
);
}
};
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
template
<
typename
MaskOutPredicate
>
struct
C0MatrixMask_impl
{
C0MatrixMask_impl
(
index_t
NRaw
)
:
NRaw_
(
NRaw
),
predicate_
(
MaskOutPredicate
{})
{}
__host__
__device__
constexpr
bool
IsNOutOfBound
(
/*index_t m, */
index_t
n
)
const
{
return
n
>=
NRaw_
;
}
__host__
__device__
constexpr
bool
IsMaskedElement
(
index_t
m
,
index_t
n
)
const
{
return
predicate_
(
m
,
n
)
||
IsNOutOfBound
(
n
);
}
__host__
__device__
constexpr
bool
IsTileSkippable
(
index_t
m
,
index_t
n
,
index_t
m_tile
,
index_t
n_tile
)
const
{
return
predicate_
.
IsTileSkippable
(
m
,
n
,
m_tile
,
n_tile
);
}
private:
// index_t MRaw_;
index_t
NRaw_
;
MaskOutPredicate
predicate_
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/welford_helper.hpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
K_BlockTileSize
,
index_t
KThreadSliceSize
>
struct
GetReduceCountPerThreadForBlockwiseWelford
{
GetReduceCountPerThreadForBlockwiseWelford
(
index_t
numBlockTileIteration
,
long_index_t
reduce_length
)
:
numBlockTileIteration_
{
numBlockTileIteration
}
{
count_in_last_tile_
=
reduce_length
%
K_BlockTileSize
;
};
__device__
index_t
operator
()(
index_t
thread_k_cluster_id
)
const
{
if
(
count_in_last_tile_
==
0
)
return
(
KThreadSliceSize
*
numBlockTileIteration_
);
else
{
index_t
num_complete_slice
=
count_in_last_tile_
/
KThreadSliceSize
;
index_t
count_in_last_slice
=
count_in_last_tile_
%
KThreadSliceSize
;
if
(
thread_k_cluster_id
<
num_complete_slice
)
return
(
KThreadSliceSize
*
numBlockTileIteration_
);
else
if
(
thread_k_cluster_id
==
num_complete_slice
)
return
(
KThreadSliceSize
*
(
numBlockTileIteration_
-
1
)
+
count_in_last_slice
);
else
return
(
KThreadSliceSize
*
(
numBlockTileIteration_
-
1
));
};
};
index_t
numBlockTileIteration_
;
index_t
count_in_last_tile_
;
};
template
<
index_t
K_BlockTileSize
,
index_t
KThreadSliceSize
>
struct
GetReduceCountPerThreadForMultiblockWelford
{
GetReduceCountPerThreadForMultiblockWelford
(
index_t
blkGroupSize
,
index_t
numBlockTileIteration
,
long_index_t
reduce_length
)
:
blkGroupSize_
(
blkGroupSize
),
numBlockTileIteration_
{
numBlockTileIteration
}
{
last_block_reduce_length_
=
reduce_length
-
K_BlockTileSize
*
numBlockTileIteration_
*
(
blkGroupSize_
-
1
);
numBlockTileIterationByLastBlock_
=
(
last_block_reduce_length_
+
K_BlockTileSize
-
1
)
/
K_BlockTileSize
;
};
__device__
index_t
operator
()(
index_t
block_local_id
,
index_t
thread_k_cluster_id
)
const
{
if
(
last_block_reduce_length_
==
K_BlockTileSize
*
numBlockTileIteration_
||
block_local_id
<
blkGroupSize_
-
1
)
return
(
KThreadSliceSize
*
numBlockTileIteration_
);
index_t
count_in_last_tile
=
last_block_reduce_length_
%
K_BlockTileSize
;
if
(
count_in_last_tile
==
0
)
return
(
KThreadSliceSize
*
numBlockTileIterationByLastBlock_
);
else
{
index_t
num_complete_slice
=
count_in_last_tile
/
KThreadSliceSize
;
if
(
thread_k_cluster_id
<
num_complete_slice
)
return
(
KThreadSliceSize
*
numBlockTileIterationByLastBlock_
);
else
if
(
thread_k_cluster_id
==
num_complete_slice
)
return
(
KThreadSliceSize
*
(
numBlockTileIterationByLastBlock_
-
1
)
+
count_in_last_tile
);
else
return
(
KThreadSliceSize
*
(
numBlockTileIterationByLastBlock_
-
1
));
};
};
index_t
blkGroupSize_
;
index_t
numBlockTileIteration_
;
index_t
last_block_reduce_length_
;
index_t
numBlockTileIterationByLastBlock_
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
template
<
typename
GridwiseMultiblockWelfordFirstHalf_
,
typename
XDataType
,
typename
MeanVarDataType
,
typename
XGridDesc_M_K
,
typename
MeanVarCountGridDesc_M_G
,
typename
GetReduceCountPerThreadFunctor
>
__global__
void
kernel_multiblock_welford_first_half
(
const
XGridDesc_M_K
x_grid_desc_m_k
,
const
MeanVarCountGridDesc_M_G
mean_var_count_grid_desc_m_g
,
const
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
,
index_t
num_k_block_tile_iteration
,
const
XDataType
*
const
__restrict__
p_x
,
MeanVarDataType
*
const
p_welford_mean
,
MeanVarDataType
*
const
p_welford_variance
,
int32_t
*
const
p_welford_count
)
{
GridwiseMultiblockWelfordFirstHalf_
::
Run
(
x_grid_desc_m_k
,
mean_var_count_grid_desc_m_g
,
get_reduce_count_per_thread
,
num_k_block_tile_iteration
,
p_x
,
p_welford_mean
,
p_welford_variance
,
p_welford_count
);
};
template
<
typename
XDataType
,
typename
AccDataType
,
typename
MeanVarDataType
,
typename
XGridDesc_M_K
,
typename
MeanVarCountGridDesc_M_G
,
typename
GetReduceCountPerThreadFunctor
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XSrcCountSrcVectorDim
,
index_t
XSrcCountSrcVectorSize
>
struct
GridwiseMultiblockWelfordFirstHalf
{
static_assert
((
XSrcCountSrcVectorDim
==
0
&&
MThreadSliceSize
%
XSrcCountSrcVectorSize
==
0
)
||
(
XSrcCountSrcVectorDim
==
1
&&
KThreadSliceSize
%
XSrcCountSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
bool
reorder_thread_cluster
=
(
XSrcCountSrcVectorDim
==
0
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadBufferDimAccessOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
ThreadClusterArrangeOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
using
ThreadwiseWelford
=
ThreadwiseWelford
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
>
;
using
BlockwiseWelford
=
BlockwiseWelford
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
false
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
__device__
static
void
Run
(
const
XGridDesc_M_K
&
x_grid_desc_m_k
,
const
MeanVarCountGridDesc_M_G
&
mean_var_count_grid_desc_m_g
,
const
GetReduceCountPerThreadFunctor
&
get_reduce_count_per_thread
,
index_t
num_k_block_tile_iteration
,
const
XDataType
*
const
__restrict__
p_x
,
MeanVarDataType
*
const
p_welford_mean
,
MeanVarDataType
*
const
p_welford_variance
,
int32_t
*
const
p_welford_count
)
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
x_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
welford_mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
welford_var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
,
true
>
welford_count_thread_buf
;
const
index_t
blkgroup_size
=
mean_var_count_grid_desc_m_g
.
GetLength
(
I1
);
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
blkgroup_id
=
block_global_id
/
blkgroup_size
;
const
index_t
block_local_id
=
block_global_id
%
blkgroup_size
;
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
using
ThreadBufferLengths_M_K
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
using
ThreadBufferLengths_M_1
=
Sequence
<
MThreadSliceSize
,
1
>
;
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_m_1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
1
>
{}));
const
index_t
reduceSizePerBlock
=
K_BlockTileSize
*
num_k_block_tile_iteration
;
auto
threadwise_x_load
=
ThreadwiseTensorSliceTransfer_v2
<
XDataType
,
AccDataType
,
XGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
XSrcCountSrcVectorDim
,
XSrcCountSrcVectorSize
,
1
,
true
>
(
x_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_welford_mean_var_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanVarDataType
,
decltype
(
thread_buffer_desc_m_1
),
MeanVarCountGridDesc_M_G
,
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_count_grid_desc_m_g
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
),
PassThroughOp
{});
auto
threadwise_welford_count_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
int32_t
,
int32_t
,
decltype
(
thread_buffer_desc_m_1
),
MeanVarCountGridDesc_M_G
,
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_count_grid_desc_m_g
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
),
PassThroughOp
{});
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileSize
);
const
auto
x_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_x
,
x_grid_desc_m_k
.
GetElementSpaceSize
());
auto
welford_mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_welford_mean
,
mean_var_count_grid_desc_m_g
.
GetElementSpaceSize
());
auto
welford_var_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_welford_variance
,
mean_var_count_grid_desc_m_g
.
GetElementSpaceSize
());
auto
welford_count_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_welford_count
,
mean_var_count_grid_desc_m_g
.
GetElementSpaceSize
());
auto
threadwise_welford
=
ThreadwiseWelford
();
threadwise_welford
.
max_count_
=
get_reduce_count_per_thread
(
block_local_id
,
thread_k_cluster_id
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
welford_mean_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
welford_var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
});
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
threadwise_welford
.
Run
(
x_thread_buf
,
welford_mean_thread_buf
,
welford_var_thread_buf
);
}
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
welford_count_thread_buf
(
I
)
=
threadwise_welford
.
cur_count_
;
BlockwiseWelford
::
Run
(
welford_mean_thread_buf
(
I
),
welford_var_thread_buf
(
I
),
welford_count_thread_buf
(
I
));
});
if
(
thread_k_cluster_id
==
0
)
{
threadwise_welford_mean_var_store
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
welford_mean_thread_buf
,
mean_var_count_grid_desc_m_g
,
welford_mean_global_val_buf
);
threadwise_welford_mean_var_store
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
welford_var_thread_buf
,
mean_var_count_grid_desc_m_g
,
welford_var_global_val_buf
);
threadwise_welford_count_store
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
welford_count_thread_buf
,
mean_var_count_grid_desc_m_g
,
welford_count_global_val_buf
);
};
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
template
<
typename
GridwiseWelfordSecondHalfBatchNormForwardFinal_
,
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
typename
XYGridDesc_M_K
,
typename
MeanVarCountGridDesc_M_K
,
typename
ScaleBiasGridDesc_M
,
typename
MeanVarGridDesc_M
>
__global__
void
kernel_welford_second_half_batchnorm_forward_final
(
const
XYGridDesc_M_K
x_grid_desc_m_k
,
const
XYGridDesc_M_K
y_grid_desc_m_k
,
const
MeanVarCountGridDesc_M_K
mean_var_count_grid_desc_m_k
,
const
ScaleBiasGridDesc_M
scale_grid_desc_m
,
const
ScaleBiasGridDesc_M
bias_grid_desc_m
,
const
MeanVarGridDesc_M
mean_var_grid_desc_m
,
index_t
blkgroup_size
,
index_t
num_xy_k_block_tile_iteration
,
index_t
num_mean_var_count_k_block_tile_iteration
,
AccDataType
epsilon
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_mean
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_variance
,
const
int32_t
*
const
__restrict__
p_in_welford_count
,
const
XDataType
*
const
__restrict__
p_x
,
const
ScaleDataType
*
const
__restrict__
p_scale
,
const
BiasDataType
*
const
__restrict__
p_bias
,
const
YElementwiseOp
y_elementwise_op
,
YDataType
*
const
__restrict__
p_y
,
bool
updateMovingAverage
,
AccDataType
averageFactor
,
MeanVarDataType
*
const
__restrict__
resultRunningMean
,
MeanVarDataType
*
const
__restrict__
resultRunningVariance
,
bool
saveMeanInvVariance
,
MeanVarDataType
*
const
__restrict__
resultSaveMean
,
MeanVarDataType
*
const
__restrict__
resultSaveInvVariance
)
{
GridwiseWelfordSecondHalfBatchNormForwardFinal_
::
Run
(
x_grid_desc_m_k
,
y_grid_desc_m_k
,
mean_var_count_grid_desc_m_k
,
scale_grid_desc_m
,
bias_grid_desc_m
,
mean_var_grid_desc_m
,
blkgroup_size
,
num_xy_k_block_tile_iteration
,
num_mean_var_count_k_block_tile_iteration
,
epsilon
,
p_in_welford_mean
,
p_in_welford_variance
,
p_in_welford_count
,
p_x
,
p_scale
,
p_bias
,
y_elementwise_op
,
p_y
,
updateMovingAverage
,
averageFactor
,
resultRunningMean
,
resultRunningVariance
,
saveMeanInvVariance
,
resultSaveMean
,
resultSaveInvVariance
);
};
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
typename
XYGridDesc_M_K
,
typename
MeanVarCountGridDesc_M_K
,
typename
ScaleBiasGridDesc_M
,
typename
MeanVarGridDesc_M
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XSrcYDstVectorDim
,
index_t
XSrcVectorSize
,
index_t
YDstVectorSize
,
index_t
ScaleSrcVectorSize
,
index_t
BiasSrcVectorSize
,
index_t
MeanVarSrcDstVectorSize
>
struct
GridwiseWelfordSecondHalfBatchNormForwardFinal
{
static_assert
((
XSrcYDstVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
)
||
(
XSrcYDstVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static_assert
((
XSrcYDstVectorDim
==
0
&&
MThreadSliceSize
%
YDstVectorSize
==
0
)
||
(
XSrcYDstVectorDim
==
1
&&
KThreadSliceSize
%
YDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
bool
reorder_thread_cluster
=
(
XSrcYDstVectorDim
==
0
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadBufferDimAccessOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
ThreadClusterArrangeOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadReduceSrcDesc_M_1
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
1
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
using
ThreadwiseWelford
=
ThreadwiseWelfordMerge
<
AccDataType
,
ThreadReduceSrcDesc_M_1
,
ThreadReduceDstDesc_M
>
;
using
BlockwiseWelford
=
BlockwiseWelford
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
__device__
static
void
Run
(
const
XYGridDesc_M_K
&
x_grid_desc_m_k
,
const
XYGridDesc_M_K
&
y_grid_desc_m_k
,
const
MeanVarCountGridDesc_M_K
&
mean_var_count_grid_desc_m_k
,
const
ScaleBiasGridDesc_M
&
scale_grid_desc_m
,
const
ScaleBiasGridDesc_M
&
bias_grid_desc_m
,
const
MeanVarGridDesc_M
&
mean_var_grid_desc_m
,
index_t
blkgroup_size
,
index_t
num_xy_k_block_tile_iteration
,
index_t
num_mean_var_count_k_block_tile_iteration
,
AccDataType
epsilon
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_mean
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_variance
,
const
int32_t
*
const
__restrict__
p_in_welford_count
,
const
XDataType
*
const
__restrict__
p_x
,
const
ScaleDataType
*
const
__restrict__
p_scale
,
const
BiasDataType
*
const
__restrict__
p_bias
,
const
YElementwiseOp
y_elementwise_op
,
YDataType
*
const
__restrict__
p_y
,
bool
updateMovingAverage
,
AccDataType
averageFactor
,
MeanVarDataType
*
const
__restrict__
resultRunningMean
,
MeanVarDataType
*
const
__restrict__
resultRunningVariance
,
bool
saveMeanInvVariance
,
MeanVarDataType
*
const
__restrict__
resultSaveMean
,
MeanVarDataType
*
const
__restrict__
resultSaveInvVariance
)
{
using
ck
::
math
::
sqrt
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
1
,
true
>
in_welford_mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
1
,
true
>
in_welford_var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
*
1
,
true
>
in_welford_count_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
welford_mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
welford_var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
,
true
>
welford_count_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
x_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
y_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
scale_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
bias_thread_buf
;
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
blkgroup_id
=
block_global_id
/
blkgroup_size
;
const
index_t
block_local_id
=
block_global_id
%
blkgroup_size
;
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
using
ThreadBufferLengths_M_K
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
using
ThreadBufferLengths_M
=
Sequence
<
MThreadSliceSize
>
;
using
ThreadBufferLengths_M_1
=
Sequence
<
MThreadSliceSize
,
1
>
;
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_m_1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
1
>
{}));
auto
threadwise_mean_var_load_m_k
=
ThreadwiseTensorSliceTransfer_v2
<
MeanVarDataType
,
AccDataType
,
MeanVarCountGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
1
,
1
,
true
>
(
mean_var_count_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
1
));
auto
threadwise_count_load_m_k
=
ThreadwiseTensorSliceTransfer_v2
<
int32_t
,
int32_t
,
MeanVarCountGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
1
,
1
,
true
>
(
mean_var_count_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
1
));
const
auto
welford_mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_welford_mean
,
mean_var_count_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
welford_var_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_welford_variance
,
mean_var_count_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
welford_count_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_welford_count
,
mean_var_count_grid_desc_m_k
.
GetElementSpaceSize
());
constexpr
auto
mean_var_count_thread_copy_step_m_k
=
make_multi_index
(
0
,
KThreadClusterSize
*
1
);
// Step 1: do final welford reduction to get mean and variance
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
welford_mean_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
welford_var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
welford_count_thread_buf
(
I
)
=
0
;
});
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_mean_var_count_k_block_tile_iteration
;
++
reducedTiles
)
{
threadwise_mean_var_load_m_k
.
Run
(
mean_var_count_grid_desc_m_k
,
welford_mean_global_val_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
in_welford_mean_thread_buf
);
threadwise_mean_var_load_m_k
.
Run
(
mean_var_count_grid_desc_m_k
,
welford_var_global_val_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
in_welford_var_thread_buf
);
threadwise_count_load_m_k
.
Run
(
mean_var_count_grid_desc_m_k
,
welford_count_global_val_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
in_welford_count_thread_buf
);
ThreadwiseWelford
::
Run
(
in_welford_mean_thread_buf
,
in_welford_var_thread_buf
,
in_welford_count_thread_buf
,
welford_mean_thread_buf
,
welford_var_thread_buf
,
welford_count_thread_buf
);
threadwise_mean_var_load_m_k
.
MoveSrcSliceWindow
(
mean_var_count_grid_desc_m_k
,
mean_var_count_thread_copy_step_m_k
);
threadwise_count_load_m_k
.
MoveSrcSliceWindow
(
mean_var_count_grid_desc_m_k
,
mean_var_count_thread_copy_step_m_k
);
}
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
BlockwiseWelford
::
Run
(
welford_mean_thread_buf
(
I
),
welford_var_thread_buf
(
I
),
welford_count_thread_buf
(
I
));
});
// Step 2: do normalization and output y
const
index_t
workSizePerBlock
=
K_BlockTileSize
*
num_xy_k_block_tile_iteration
;
auto
threadwise_x_load
=
ThreadwiseTensorSliceTransfer_v2
<
XDataType
,
AccDataType
,
XYGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
XSrcYDstVectorDim
,
XSrcVectorSize
,
1
,
true
>
(
x_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
workSizePerBlock
*
block_local_id
+
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_y_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
YDataType
,
decltype
(
thread_buffer_desc_m_k
),
XYGridDesc_M_K
,
YElementwiseOp
,
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
XSrcYDstVectorDim
,
YDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
y_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
workSizePerBlock
*
block_local_id
+
thread_k_cluster_id
*
KThreadSliceSize
),
y_elementwise_op
);
auto
threadwise_scale_load
=
ThreadwiseTensorSliceTransfer_v2
<
ScaleDataType
,
AccDataType
,
ScaleBiasGridDesc_M
,
decltype
(
thread_buffer_desc_m
),
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
ScaleSrcVectorSize
,
1
,
true
>
(
scale_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
auto
threadwise_bias_load
=
ThreadwiseTensorSliceTransfer_v2
<
BiasDataType
,
AccDataType
,
ScaleBiasGridDesc_M
,
decltype
(
thread_buffer_desc_m
),
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
BiasSrcVectorSize
,
1
,
true
>
(
bias_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
const
auto
x_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_x
,
x_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
scale_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_scale
,
scale_grid_desc_m
.
GetElementSpaceSize
());
const
auto
bias_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_bias
,
bias_grid_desc_m
.
GetElementSpaceSize
());
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
threadwise_scale_load
.
Run
(
scale_grid_desc_m
,
scale_global_val_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
scale_thread_buf
);
threadwise_bias_load
.
Run
(
bias_grid_desc_m
,
bias_global_val_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
bias_thread_buf
);
constexpr
auto
xy_thread_copy_step_m_k
=
make_multi_index
(
0
,
K_BlockTileSize
);
for
(
index_t
workTiles
=
0
;
workTiles
<
num_xy_k_block_tile_iteration
;
++
workTiles
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
AccDataType
multiplier
=
scale_thread_buf
[
iM
]
/
sqrt
(
welford_var_thread_buf
[
iM
]
+
epsilon
);
AccDataType
fused_mean_bias
=
bias_thread_buf
[
iM
]
-
welford_mean_thread_buf
[
iM
]
*
multiplier
;
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
y_thread_buf
(
Number
<
offset
>
{})
=
x_thread_buf
[
Number
<
offset
>
{}]
*
multiplier
+
fused_mean_bias
;
});
});
threadwise_y_store
.
Run
(
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
y_thread_buf
,
y_grid_desc_m_k
,
y_global_val_buf
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
xy_thread_copy_step_m_k
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
xy_thread_copy_step_m_k
);
}
// Step 3: update the moving average of mean and variance (optional)
if
(
updateMovingAverage
&&
block_local_id
==
0
&&
thread_k_cluster_id
==
0
)
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
running_mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
running_var_thread_buf
;
auto
running_mean_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultRunningMean
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
auto
running_var_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultRunningVariance
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
auto
threadwise_mean_var_load_m
=
ThreadwiseTensorSliceTransfer_v2
<
MeanVarDataType
,
AccDataType
,
MeanVarGridDesc_M
,
decltype
(
thread_buffer_desc_m
),
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
MeanVarSrcDstVectorSize
,
1
,
true
>
(
mean_var_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
threadwise_mean_var_load_m
.
Run
(
mean_var_grid_desc_m
,
running_mean_global_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_mean_thread_buf
);
threadwise_mean_var_load_m
.
Run
(
mean_var_grid_desc_m
,
running_var_global_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_var_thread_buf
);
AccDataType
oneMinusAverageFactor
=
type_convert
<
AccDataType
>
(
1.0
)
-
averageFactor
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
running_mean_thread_buf
(
I
)
=
running_mean_thread_buf
[
I
]
*
oneMinusAverageFactor
+
welford_mean_thread_buf
[
I
]
*
averageFactor
;
running_var_thread_buf
(
I
)
=
running_var_thread_buf
[
I
]
*
oneMinusAverageFactor
+
welford_var_thread_buf
[
I
]
*
averageFactor
;
});
auto
threadwise_mean_var_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanVarDataType
,
decltype
(
thread_buffer_desc_m
),
MeanVarGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
MeanVarSrcDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
threadwise_mean_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_mean_thread_buf
,
mean_var_grid_desc_m
,
running_mean_global_buf
);
threadwise_mean_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_var_thread_buf
,
mean_var_grid_desc_m
,
running_var_global_buf
);
};
// Step 4: save mean and inv-variance (optional)
if
(
saveMeanInvVariance
&&
block_local_id
==
0
&&
thread_k_cluster_id
==
0
)
{
auto
result_mean_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultSaveMean
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
auto
result_inv_var_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultSaveInvVariance
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
welford_var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
1.0
f
)
/
sqrt
(
epsilon
+
welford_var_thread_buf
[
I
]);
});
auto
threadwise_mean_inv_var_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanVarDataType
,
decltype
(
thread_buffer_desc_m
),
MeanVarGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
MeanVarSrcDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
threadwise_mean_inv_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
welford_mean_thread_buf
,
mean_var_grid_desc_m
,
result_mean_global_buf
);
threadwise_mean_inv_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
welford_var_thread_buf
,
mean_var_grid_desc_m
,
result_inv_var_global_buf
);
};
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
4fec5ad3
...
...
@@ -336,36 +336,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
};
template
<
bool
Pred
>
struct
ElementOpPredicatedResetNaNToMinusInf
;
template
<
>
struct
ElementOpPredicatedResetNaNToMinusInf
<
true
>
{
template
<
typename
ElementOp
,
typename
OutT
,
typename
InT
>
__host__
__device__
void
Run
(
OutT
&
y
,
const
ElementOp
&
op
,
const
InT
&
x
)
{
if
(
ck
::
math
::
isnan
(
x
))
{
y
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
}
else
{
op
(
y
,
x
);
}
}
};
template
<
>
struct
ElementOpPredicatedResetNaNToMinusInf
<
false
>
{
template
<
typename
ElementOp
,
typename
OutT
,
typename
InT
>
__host__
__device__
void
Run
(
OutT
&
y
,
const
ElementOp
&
op
,
const
InT
&
x
)
{
op
(
y
,
x
);
}
};
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
,
typename
C0MatrixMask
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
...
...
@@ -406,11 +376,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
return
;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
// HACK: this force m/
gemm1_
n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
const
index_t
gemm1_
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
Gemm1NPerBlock
);
// A matrix in LDS memory, dst of blockwise copy
...
...
@@ -627,7 +597,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
true
,
// DstResetCoord
NumGemmKPrefetchStage
>
(
b1_grid_desc_bk0_n_bk1
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
make_multi_index
(
0
,
gemm1_
n_block_data_idx_on_grid
,
0
),
b1_element_op
,
b1_block_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
...
...
@@ -745,29 +715,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
running_max
=
NumericLimits
<
FloatGemmAcc
>::
Lowest
();
running_max_new
=
NumericLimits
<
FloatGemmAcc
>::
Lowest
();
// decoder lower triangular mask
const
auto
thread_cluster_idx
=
threadid_to_m_n_thread_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_n_cluster_id
=
thread_cluster_idx
[
I1
];
const
index_t
MPerRepeat
=
MPerBlock
/
MXdlPerWave
;
const
index_t
NPerRepeat
=
NPerBlock
/
NXdlPerWave
;
const
index_t
mstart
=
m_block_data_idx_on_grid
+
thread_m_cluster_id
;
// gemm1 K loop
index_t
gemm1_k_block_outer_index
=
0
;
do
{
if
constexpr
(
MaskOutUpperTriangle
)
auto
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
gemm1_k_block_outer_index
*
NPerBlock
);
if
(
c0_matrix_mask
.
IsTileSkippable
(
m_block_data_idx_on_grid
,
n_block_data_idx_on_grid
,
MPerBlock
,
NPerBlock
))
{
auto
gemm0_n_block_idx
=
__builtin_amdgcn_readfirstlane
(
gemm1_k_block_outer_index
*
NPerBlock
);
if
(
c0_matrix_mask
.
IsUpperTriangle
(
m_block_data_idx_on_grid
,
gemm0_n_block_idx
)
&&
c0_matrix_mask
.
IsUpperTriangle
(
m_block_data_idx_on_grid
+
MPerBlock
-
1
,
gemm0_n_block_idx
))
{
continue
;
}
continue
;
}
// gemm0
gridwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
>(
a_grid_desc_ak0_m_ak1
,
...
...
@@ -789,60 +746,58 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// do MNK padding or upper triangular masking
if
constexpr
(
MaskOutUpperTriangle
||
PadN
)
{
const
index_t
nstart
=
gemm1_k_block_outer_index
*
NPerBlock
;
static_for
<
0
,
m0
,
1
>
{}([
&
](
auto
m0_i
)
{
const
index_t
m_global
=
mstart
+
m0_i
*
MPerRepeat
;
const
index_t
acc_idx_m0
=
m0_i
*
n0
*
n2
*
n4
;
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
n0_i
)
{
// constexpr auto nrepeat_i = n0_i * NPerRepeat;
// const index_t nstartxdl = nstart + nrepeat_i;
const
index_t
nstartxdl
=
nstart
+
n0_i
*
NPerRepeat
;
const
index_t
acc_idx_n0
=
acc_idx_m0
+
n0_i
*
n2
*
n4
;
static_for
<
0
,
n2
,
1
>
{}([
&
](
auto
n2_i
)
{
const
index_t
nstartgroup
=
nstartxdl
+
thread_n_cluster_id
*
n4
+
n2_i
*
AccN3
*
n4
;
const
index_t
acc_idx_n2
=
acc_idx_n0
+
n2_i
*
n4
;
static_for
<
0
,
n4
,
1
>
{}([
&
](
auto
n4_i
)
{
const
index_t
n_global
=
nstartgroup
+
n4_i
;
const
auto
acc_offset
=
Number
<
acc_idx_n2
+
n4_i
>
{};
if
constexpr
(
MaskOutUpperTriangle
)
{
if
(
c0_matrix_mask
.
IsMaskedElement
(
m_global
,
n_global
))
{
acc_thread_buf
(
acc_offset
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
}
else
{
acc_element_op
(
acc_thread_buf
(
acc_offset
),
acc_thread_buf
[
acc_offset
]);
}
}
else
{
// ignore m_global;
if
(
c0_matrix_mask
.
IsNOutOfBound
(
n_global
))
{
acc_thread_buf
(
acc_offset
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
}
else
{
acc_element_op
(
acc_thread_buf
(
acc_offset
),
acc_thread_buf
[
acc_offset
]);
}
}
});
});
});
// 8d thread_desc in thread scope
constexpr
auto
c_thread_lengths
=
blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
// 8d block_desc in block scope
constexpr
auto
c_block_lengths
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
constexpr
auto
M0
=
c_block_lengths
[
I0
];
constexpr
auto
N0
=
c_block_lengths
[
I1
];
constexpr
auto
M1
=
c_block_lengths
[
I2
];
constexpr
auto
N1
=
c_block_lengths
[
I3
];
constexpr
auto
M2
=
c_block_lengths
[
I4
];
constexpr
auto
N2
=
c_block_lengths
[
I5
];
constexpr
auto
N3
=
c_block_lengths
[
I6
];
constexpr
auto
N4
=
c_block_lengths
[
I7
];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using
Acc0TileIterator
=
SpaceFillingCurve
<
decltype
(
c_thread_lengths
),
typename
arithmetic_sequence_gen
<
0
,
c_thread_lengths
.
Size
(),
1
>::
type
,
typename
uniform_sequence_gen
<
c_thread_lengths
.
Size
(),
1
>::
type
,
false
>
;
// SnakeCurved
auto
acc0_thread_origin
=
blockwise_gemm
.
CalculateCThreadOriginDataIndex8D
(
Number
<
0
>
{},
Number
<
0
>
{},
Number
<
0
>
{},
Number
<
0
>
{});
constexpr
auto
block_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M1
,
M2
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
,
N3
,
N4
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
static_for
<
0
,
Acc0TileIterator
::
GetNumOfAccess
(),
1
>
{}([
&
](
auto
i
)
{
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
i
)
+
acc0_thread_origin
;
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
if
(
c0_matrix_mask
.
IsMaskedElement
(
m_global
,
n_global
))
{
acc_thread_buf
(
i
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
}
else
{
acc_element_op
(
acc_thread_buf
(
i
),
acc_thread_buf
[
i
]);
}
});
}
else
{
static_for
<
0
,
acc_thread_buf
.
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
acc_element_op
(
acc_thread_buf
(
i
),
acc_thread_buf
[
i
]);
});
}
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
template
<
typename
GridwiseBatchrNormForwardWithBlockwiseWelford_
,
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
typename
XYGridDesc_M_K
,
typename
ScaleBiasGridDesc_M
,
typename
MeanVarGridDesc_M
,
typename
GetReduceCountPerThreadFunctor
>
__global__
void
kernel_batchnorm_forward_with_blockwise_welford
(
const
XYGridDesc_M_K
x_grid_desc_m_k
,
const
XYGridDesc_M_K
y_grid_desc_m_k
,
const
ScaleBiasGridDesc_M
scale_grid_desc_m
,
const
ScaleBiasGridDesc_M
bias_grid_desc_m
,
const
MeanVarGridDesc_M
mean_var_grid_desc_m
,
const
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
const
XDataType
*
const
__restrict__
p_x
,
const
ScaleDataType
*
const
__restrict__
p_scale
,
const
BiasDataType
*
const
__restrict__
p_bias
,
const
YElementwiseOp
y_elementwise_op
,
YDataType
*
const
__restrict__
p_y
,
bool
updateMovingAverage
,
AccDataType
averageFactor
,
MeanVarDataType
*
const
__restrict__
resultRunningMean
,
MeanVarDataType
*
const
__restrict__
resultRunningVariance
,
bool
saveMeanInvVariance
,
MeanVarDataType
*
const
__restrict__
resultSaveMean
,
MeanVarDataType
*
const
__restrict__
resultSaveInvVariance
)
{
GridwiseBatchrNormForwardWithBlockwiseWelford_
::
Run
(
x_grid_desc_m_k
,
y_grid_desc_m_k
,
scale_grid_desc_m
,
bias_grid_desc_m
,
mean_var_grid_desc_m
,
get_reduce_count_per_thread
,
num_k_block_tile_iteration
,
epsilon
,
p_x
,
p_scale
,
p_bias
,
y_elementwise_op
,
p_y
,
updateMovingAverage
,
averageFactor
,
resultRunningMean
,
resultRunningVariance
,
saveMeanInvVariance
,
resultSaveMean
,
resultSaveInvVariance
);
};
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
typename
XYGridDesc_M_K
,
typename
ScaleBiasGridDesc_M
,
typename
MeanVarGridDesc_M
,
typename
GetReduceCountPerThreadFunctor
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XSrcYDstVectorDim
,
index_t
XSrcVectorSize
,
index_t
YDstVectorSize
,
index_t
ScaleSrcVectorSize
,
index_t
BiasSrcVectorSize
,
index_t
MeanVarSrcDstVectorSize
>
struct
GridwiseBatchNormForwardWithBlockwiseWelford
{
static_assert
((
XSrcYDstVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
)
||
(
XSrcYDstVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static_assert
((
XSrcYDstVectorDim
==
0
&&
MThreadSliceSize
%
YDstVectorSize
==
0
)
||
(
XSrcYDstVectorDim
==
1
&&
KThreadSliceSize
%
YDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
bool
reorder_thread_cluster
=
(
XSrcYDstVectorDim
==
0
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadBufferDimAccessOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
ThreadClusterArrangeOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
using
ThreadwiseWelford
=
ThreadwiseWelford
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
>
;
using
BlockwiseWelford
=
BlockwiseWelford
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
__device__
static
void
Run
(
const
XYGridDesc_M_K
&
x_grid_desc_m_k
,
const
XYGridDesc_M_K
&
y_grid_desc_m_k
,
const
ScaleBiasGridDesc_M
&
scale_grid_desc_m
,
const
ScaleBiasGridDesc_M
&
bias_grid_desc_m
,
const
MeanVarGridDesc_M
&
mean_var_grid_desc_m
,
const
GetReduceCountPerThreadFunctor
&
get_reduce_count_per_thread
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
const
XDataType
*
const
__restrict__
p_x
,
const
ScaleDataType
*
const
__restrict__
p_scale
,
const
BiasDataType
*
const
__restrict__
p_bias
,
const
YElementwiseOp
y_elementwise_op
,
YDataType
*
const
__restrict__
p_y
,
bool
updateMovingAverage
,
AccDataType
averageFactor
,
MeanVarDataType
*
const
__restrict__
resultRunningMean
,
MeanVarDataType
*
const
__restrict__
resultRunningVariance
,
bool
saveMeanInvVariance
,
MeanVarDataType
*
const
__restrict__
resultSaveMean
,
MeanVarDataType
*
const
__restrict__
resultSaveInvVariance
)
{
using
ck
::
math
::
sqrt
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
x_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
scale_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
bias_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
y_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
var_thread_buf
;
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
using
ThreadBufferLengths_M_K
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
using
ThreadBufferLengths_M
=
Sequence
<
MThreadSliceSize
>
;
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{}));
auto
threadwise_x_load
=
ThreadwiseTensorSliceTransfer_v2
<
XDataType
,
AccDataType
,
XYGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
XSrcYDstVectorDim
,
XSrcVectorSize
,
1
,
true
>
(
x_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_y_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
YDataType
,
decltype
(
thread_buffer_desc_m_k
),
XYGridDesc_M_K
,
YElementwiseOp
,
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
XSrcYDstVectorDim
,
YDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
y_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
),
y_elementwise_op
);
auto
threadwise_scale_load
=
ThreadwiseTensorSliceTransfer_v2
<
ScaleDataType
,
AccDataType
,
ScaleBiasGridDesc_M
,
decltype
(
thread_buffer_desc_m
),
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
ScaleSrcVectorSize
,
1
,
true
>
(
scale_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
auto
threadwise_bias_load
=
ThreadwiseTensorSliceTransfer_v2
<
BiasDataType
,
AccDataType
,
ScaleBiasGridDesc_M
,
decltype
(
thread_buffer_desc_m
),
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
BiasSrcVectorSize
,
1
,
true
>
(
bias_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileSize
);
constexpr
auto
thread_copy_bwd_step_m_k
=
make_multi_index
(
0
,
-
K_BlockTileSize
);
const
auto
x_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_x
,
x_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
scale_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_scale
,
scale_grid_desc_m
.
GetElementSpaceSize
());
const
auto
bias_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_bias
,
bias_grid_desc_m
.
GetElementSpaceSize
());
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
// Step 1: do welford reduction to get mean and variance
auto
threadwise_welford
=
ThreadwiseWelford
();
threadwise_welford
.
max_count_
=
get_reduce_count_per_thread
(
thread_k_cluster_id
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
mean_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
});
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
threadwise_welford
.
Run
(
x_thread_buf
,
mean_thread_buf
,
var_thread_buf
);
}
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
int
count
=
threadwise_welford
.
cur_count_
;
BlockwiseWelford
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count
);
});
// Step 2: do normalization and output y
threadwise_scale_load
.
Run
(
scale_grid_desc_m
,
scale_global_val_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
scale_thread_buf
);
threadwise_bias_load
.
Run
(
bias_grid_desc_m
,
bias_global_val_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
bias_thread_buf
);
auto
thread_copy_tail_m_k
=
(
num_k_block_tile_iteration
-
1
)
*
thread_copy_fwd_step_m_k
;
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
thread_copy_tail_m_k
);
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
AccDataType
multiplier
=
scale_thread_buf
[
Number
<
iM
>
{}]
/
sqrt
(
var_thread_buf
[
iM
]
+
epsilon
);
AccDataType
fused_mean_bias
=
bias_thread_buf
[
Number
<
iM
>
{}]
-
mean_thread_buf
[
iM
]
*
multiplier
;
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
// normalize
y_thread_buf
(
Number
<
offset
>
{})
=
x_thread_buf
[
Number
<
offset
>
{}]
*
multiplier
+
fused_mean_bias
;
});
});
threadwise_y_store
.
Run
(
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
y_thread_buf
,
y_grid_desc_m_k
,
y_global_val_buf
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
}
// Step 3: update the moving average of mean and variance (optional)
if
(
updateMovingAverage
&&
thread_k_cluster_id
==
0
)
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
running_mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
running_var_thread_buf
;
auto
running_mean_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultRunningMean
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
auto
running_var_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultRunningVariance
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
auto
threadwise_mean_var_load
=
ThreadwiseTensorSliceTransfer_v2
<
MeanVarDataType
,
AccDataType
,
MeanVarGridDesc_M
,
decltype
(
thread_buffer_desc_m
),
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
MeanVarSrcDstVectorSize
,
1
,
true
>
(
mean_var_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
threadwise_mean_var_load
.
Run
(
mean_var_grid_desc_m
,
running_mean_global_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_mean_thread_buf
);
threadwise_mean_var_load
.
Run
(
mean_var_grid_desc_m
,
running_var_global_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_var_thread_buf
);
AccDataType
oneMinusAverageFactor
=
type_convert
<
AccDataType
>
(
1.0
)
-
averageFactor
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
running_mean_thread_buf
(
I
)
=
running_mean_thread_buf
[
I
]
*
oneMinusAverageFactor
+
mean_thread_buf
[
I
]
*
averageFactor
;
running_var_thread_buf
(
I
)
=
running_var_thread_buf
[
I
]
*
oneMinusAverageFactor
+
var_thread_buf
[
I
]
*
averageFactor
;
});
auto
threadwise_mean_var_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanVarDataType
,
decltype
(
thread_buffer_desc_m
),
MeanVarGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
MeanVarSrcDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
threadwise_mean_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_mean_thread_buf
,
mean_var_grid_desc_m
,
running_mean_global_buf
);
threadwise_mean_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_var_thread_buf
,
mean_var_grid_desc_m
,
running_var_global_buf
);
};
// Step 4: save mean and inv-variance (optional)
if
(
saveMeanInvVariance
&&
thread_k_cluster_id
==
0
)
{
auto
result_mean_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultSaveMean
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
auto
result_inv_var_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultSaveInvVariance
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
1.0
f
)
/
sqrt
(
epsilon
+
var_thread_buf
[
I
]);
});
auto
threadwise_mean_inv_var_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanVarDataType
,
decltype
(
thread_buffer_desc_m
),
MeanVarGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
MeanVarSrcDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
threadwise_mean_inv_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
mean_thread_buf
,
mean_var_grid_desc_m
,
result_mean_global_buf
);
threadwise_mean_inv_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
var_thread_buf
,
mean_var_grid_desc_m
,
result_inv_var_global_buf
);
};
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp
View file @
4fec5ad3
...
...
@@ -3,6 +3,7 @@
#pragma once
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace
ck
{
...
...
include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp
View file @
4fec5ad3
...
...
@@ -75,4 +75,63 @@ struct ThreadwiseWelford
int
max_count_
;
};
template
<
typename
T
,
typename
SrcMeanVarCountThreadDesc_M_K
,
typename
DstMeanVarThreadDesc_M
,
bool
GetActualVariance
=
false
>
struct
ThreadwiseWelfordMerge
{
static
constexpr
auto
src_thread_desc_m_k
=
SrcMeanVarCountThreadDesc_M_K
{};
static
constexpr
auto
dst_thread_desc_m
=
DstMeanVarThreadDesc_M
{};
static
constexpr
auto
src_length_m
=
src_thread_desc_m_k
.
GetLength
(
Number
<
0
>
{});
static
constexpr
auto
src_length_k
=
src_thread_desc_m_k
.
GetLength
(
Number
<
1
>
{});
static
constexpr
auto
dst_length_m
=
dst_thread_desc_m
.
GetLength
(
Number
<
0
>
{});
static_assert
(
src_length_m
==
dst_length_m
,
"lengths of source and dst buffer must match!"
);
__device__
static
void
Merge
(
T
&
mean_a
,
T
&
var_a
,
int32_t
&
count_a
,
T
mean_b
,
T
var_b
,
int32_t
count_b
)
{
int
count
=
count_a
+
count_b
;
T
count_b_over_count
=
count
==
0
?
type_convert
<
T
>
(
0
)
:
type_convert
<
T
>
(
count_b
)
/
count
;
T
delta
=
mean_b
-
mean_a
;
mean_a
+=
delta
*
count_b_over_count
;
var_a
+=
var_b
+
delta
*
delta
*
count_a
*
count_b_over_count
;
count_a
=
count
;
}
template
<
typename
SrcMeanBufferType
,
typename
SrcVarBufferType
,
typename
SrcCountBufferType
,
typename
DstMeanBufferType
,
typename
DstVarBufferType
,
typename
DstCountBufferType
>
__device__
static
void
Run
(
const
SrcMeanBufferType
&
src_mean_buf
,
const
SrcVarBufferType
&
src_var_buf
,
const
SrcCountBufferType
&
src_count_buf
,
DstMeanBufferType
&
dst_mean_buf
,
DstVarBufferType
&
dst_var_buf
,
DstCountBufferType
&
dst_count_buf
)
{
static_for
<
0
,
src_length_m
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
src_length_k
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
src_offset
=
src_thread_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
Merge
(
dst_mean_buf
(
iM
),
dst_var_buf
(
iM
),
dst_count_buf
(
iM
),
src_mean_buf
[
Number
<
src_offset
>
{}],
src_var_buf
[
Number
<
src_offset
>
{}],
src_count_buf
[
Number
<
src_offset
>
{}]);
});
if
constexpr
(
GetActualVariance
)
{
dst_var_buf
(
iM
)
=
dst_var_buf
[
iM
]
/
dst_count_buf
[
iM
];
};
});
};
};
}
// namespace ck
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
4fec5ad3
...
...
@@ -593,7 +593,8 @@ struct XdlopsGemm
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
using
CIndex
=
MultiIndex
<
2
>
;
using
CIndex
=
MultiIndex
<
2
>
;
using
CIndex4D
=
MultiIndex
<
4
>
;
__device__
static
constexpr
index_t
GetNumBlks
()
{
return
mfma_instr
.
num_output_blks
;
}
...
...
@@ -822,6 +823,16 @@ struct XdlopsGemm
return
TransposeC
?
CIndex
{
n_offset
,
m_offset
}
:
CIndex
{
m_offset
,
n_offset
};
}
__device__
static
CIndex4D
GetBeginOfThreadBlk4D
(
index_t
/* xdlops_i */
,
index_t
/* blk_i */
)
{
const
auto
blk_idx
=
GetBlkIdx
();
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_td
=
blk_idx
[
I1
];
return
TransposeC
?
CIndex4D
{
blk_td
,
I0
,
blk_id
,
I0
}
:
CIndex4D
{
I0
,
blk_id
,
I0
,
blk_td
};
}
static
constexpr
auto
mfma
=
MfmaSelector
<
base_type
,
MPerXdlops
,
NPerXdlops
>
{};
static
constexpr
auto
mfma_instr
=
mfma
.
selected_mfma
;
...
...
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
namespace
ck
{
namespace
tensor_operation
{
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
device
::
TensorSpecialization
TensorSpec
>
static
auto
MakeGridDescriptorPair
(
const
std
::
vector
<
index_t
>&
gs_ms_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
gs_ms_ns_strides_vec
)
{
if
(
!
(
gs_ms_ns_lengths_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimN
&&
gs_ms_ns_strides_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimN
))
{
throw
std
::
runtime_error
(
"wrong! dimension must match input lengths"
);
}
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
start
,
auto
end
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
start
+
i
];
},
Number
<
end
-
start
>
{});
};
const
auto
gs_ms_ns_lengths
=
to_tuple
(
gs_ms_ns_lengths_vec
,
Number
<
0
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimN
>
{});
const
auto
gs_ms_ns_strides
=
to_tuple
(
gs_ms_ns_strides_vec
,
Number
<
0
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimN
>
{});
// dimension Ids for G0, G1, ...
constexpr
auto
gDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimG
,
1
>::
type
{};
// dimension Ids for M0, M1, ...
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
NumDimG
,
NumDimG
+
NumDimM
,
1
>::
type
{};
// dimension Ids for N0, N1, ...
constexpr
auto
nDimIds
=
typename
arithmetic_sequence_gen
<
NumDimG
+
NumDimM
,
NumDimG
+
NumDimM
+
NumDimN
,
1
>::
type
{};
// lengths for G0, G1, ...
const
auto
gLengths
=
get_container_subset
(
gs_ms_ns_lengths
,
gDimIds
);
// lengths for M0, M1, ...
const
auto
mLengths
=
get_container_subset
(
gs_ms_ns_lengths
,
mDimIds
);
// lengths for N0, N1, ...
const
auto
nLengths
=
get_container_subset
(
gs_ms_ns_lengths
,
nDimIds
);
if
constexpr
(
TensorSpec
==
device
::
TensorSpecialization
::
Packed
)
{
auto
G
=
container_reduce
(
gLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
M
=
container_reduce
(
mLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
N
=
container_reduce
(
nLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
const
auto
grid_desc_g_mraw_nraw
=
make_naive_tensor_descriptor
(
make_tuple
(
G
,
M
,
N
),
make_tuple
(
gs_ms_ns_strides
[
Number
<
NumDimG
-
1
>
{}],
gs_ms_ns_strides
[
Number
<
NumDimG
+
NumDimM
-
1
>
{}],
gs_ms_ns_strides
[
Number
<
NumDimG
+
NumDimM
+
NumDimN
-
1
>
{}]));
const
auto
grid_desc_mraw_nraw
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
gs_ms_ns_strides
[
Number
<
NumDimG
+
NumDimM
-
1
>
{}],
gs_ms_ns_strides
[
Number
<
NumDimG
+
NumDimM
+
NumDimN
-
1
>
{}]));
return
std
::
make_pair
(
grid_desc_g_mraw_nraw
,
grid_desc_mraw_nraw
);
}
else
{
// naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
const
auto
grid_desc_gs_ms_ns
=
make_naive_tensor_descriptor
(
gs_ms_ns_lengths
,
gs_ms_ns_strides
);
// transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
// Note: This does not require padding as it only provides G offset calculation. Technically
// descriptor for only G is needed. Here we opt for backward compatibility purpose to return
// G_M_N
const
auto
grid_desc_g_mraw_nraw
=
transform_tensor_descriptor
(
grid_desc_gs_ms_ns
,
make_tuple
(
make_merge_transform
(
gLengths
),
make_merge_transform
(
mLengths
),
make_merge_transform
(
nLengths
)),
make_tuple
(
gDimIds
,
mDimIds
,
nDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
c_ms_ns_lengths
=
to_tuple
(
gs_ms_ns_lengths_vec
,
Number
<
NumDimG
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimN
>
{});
const
auto
c_ms_ns_strides
=
to_tuple
(
gs_ms_ns_strides_vec
,
Number
<
NumDimG
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimN
>
{});
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
const
auto
grid_desc_ms_ns
=
make_naive_tensor_descriptor
(
c_ms_ns_lengths
,
c_ms_ns_strides
);
const
auto
grid_desc_mraw_nraw
=
transform_tensor_descriptor
(
grid_desc_ms_ns
,
make_tuple
(
make_merge_transform
(
mLengths
),
make_merge_transform
(
nLengths
)),
make_tuple
(
mDimIds
-
Number
<
NumDimG
>
{},
nDimIds
-
Number
<
NumDimG
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
std
::
make_pair
(
grid_desc_g_mraw_nraw
,
grid_desc_mraw_nraw
);
}
}
template
<
typename
NumDims_G_M_N_K_O
,
// Sequence<>
typename
PerBlock_M_N_K_O
,
// Sequence<>
device
::
GemmSpecialization
GemmSpec
,
device
::
TensorSpecialization
ASpec
,
device
::
TensorSpecialization
B0Spec
,
device
::
TensorSpecialization
B1Spec
,
device
::
TensorSpecialization
CSpec
>
struct
TransformBatchedContractionContractionToBatchedGemmGemm
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
index_t
NumDimG
=
NumDims_G_M_N_K_O
::
At
(
I0
);
static
constexpr
index_t
NumDimM
=
NumDims_G_M_N_K_O
::
At
(
I1
);
static
constexpr
index_t
NumDimN
=
NumDims_G_M_N_K_O
::
At
(
I2
);
static
constexpr
index_t
NumDimK
=
NumDims_G_M_N_K_O
::
At
(
I3
);
static
constexpr
index_t
NumDimO
=
NumDims_G_M_N_K_O
::
At
(
I4
);
static
constexpr
index_t
MPerBlock
=
PerBlock_M_N_K_O
::
At
(
I0
);
static
constexpr
index_t
NPerBlock
=
PerBlock_M_N_K_O
::
At
(
I1
);
static
constexpr
index_t
KPerBlock
=
PerBlock_M_N_K_O
::
At
(
I2
);
static
constexpr
index_t
OPerBlock
=
PerBlock_M_N_K_O
::
At
(
I3
);
static
constexpr
auto
matrix_padder
=
device
::
GemmGemmPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
,
OPerBlock
};
//
// A
//
static
auto
MakeAGridDescriptorPair
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
{
return
MakeGridDescriptorPair
<
NumDimG
,
NumDimM
,
NumDimK
,
ASpec
>
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
);
}
// TODO: rename to G_MRaw_KRaw
static
auto
MakeAGridDescriptor_G_M_K
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
{
return
MakeAGridDescriptorPair
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
).
first
;
}
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
{
return
matrix_padder
.
PadADescriptor_M_K
(
MakeAGridDescriptorPair
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
).
second
);
}
template
<
typename
AGridDesc_M_K
,
typename
Number
>
__host__
__device__
static
constexpr
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
,
const
Number
&
AK1
)
{
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
AK0
=
K
/
AK1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
//
// B (alias of B0)
//
static
auto
MakeB0GridDescriptorPair
(
const
std
::
vector
<
index_t
>&
b0_gs_ns_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
b0_gs_ns_ks_strides_vec
)
{
return
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimK
,
B0Spec
>
(
b0_gs_ns_ks_lengths_vec
,
b0_gs_ns_ks_strides_vec
);
}
// TODO: rename to G_MRaw_NRaw
static
auto
MakeB0GridDescriptor_G_N_K
(
const
std
::
vector
<
index_t
>&
b0_gs_ns_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
b0_gs_ns_ks_strides_vec
)
{
return
MakeB0GridDescriptorPair
(
b0_gs_ns_ks_lengths_vec
,
b0_gs_ns_ks_strides_vec
).
first
;
}
static
auto
MakeB0GridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
b0_gs_ns_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
b0_gs_ns_ks_strides_vec
)
{
// alias of matrix_padder.PadB0Descriptor_N_K
return
matrix_padder
.
PadBDescriptor_N_K
(
MakeB0GridDescriptorPair
(
b0_gs_ns_ks_lengths_vec
,
b0_gs_ns_ks_strides_vec
).
second
);
}
template
<
typename
BGridDesc_N_K
,
typename
Number
>
__host__
__device__
static
constexpr
auto
MakeB0GridDescriptor_BK0_N_BK1
(
const
BGridDesc_N_K
&
b_grid_desc_n_k
,
const
Number
&
BK1
)
{
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
BK0
=
K
/
BK1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
//
// B1
//
static
auto
MakeB1GridDescriptorPair
(
const
std
::
vector
<
index_t
>&
b1_gs_os_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
b1_gs_os_ns_strides_vec
)
{
return
MakeGridDescriptorPair
<
NumDimG
,
NumDimO
,
NumDimN
,
B1Spec
>
(
b1_gs_os_ns_lengths_vec
,
b1_gs_os_ns_strides_vec
);
}
// TODO: rename to G_NRaw_KRaw
static
auto
MakeB1GridDescriptor_G_N_K
(
const
std
::
vector
<
index_t
>&
b1_gs_os_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
b1_gs_os_ns_strides_vec
)
{
return
MakeB1GridDescriptorPair
(
b1_gs_os_ns_lengths_vec
,
b1_gs_os_ns_strides_vec
).
first
;
}
static
auto
MakeB1GridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
b1_gs_os_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
b1_gs_os_ns_strides_vec
)
{
// alias of matrix_padder.PadB1Descriptor_O_N
return
matrix_padder
.
PadB1Descriptor_N_K
(
MakeB1GridDescriptorPair
(
b1_gs_os_ns_lengths_vec
,
b1_gs_os_ns_strides_vec
).
second
);
}
template
<
typename
B1GridDesc_N_K
,
typename
Number
>
__host__
__device__
static
constexpr
auto
MakeB1GridDescriptor_BK0_N_BK1
(
const
B1GridDesc_N_K
&
b1_grid_desc_n_k
,
const
Number
&
B1K1
)
{
const
auto
N
=
b1_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b1_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
B1K0
=
K
/
B1K1
;
return
transform_tensor_descriptor
(
b1_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
//
// C
//
static
auto
MakeCGridDescriptorPair
(
const
std
::
vector
<
index_t
>&
c_gs_ms_os_lengths_vec
,
const
std
::
vector
<
index_t
>&
c_gs_ms_os_strides_vec
)
{
return
MakeGridDescriptorPair
<
NumDimG
,
NumDimM
,
NumDimO
,
CSpec
>
(
c_gs_ms_os_lengths_vec
,
c_gs_ms_os_strides_vec
);
}
// TODO: rename to G_MRaw_NRaw
static
auto
MakeCGridDescriptor_G_M_N
(
const
std
::
vector
<
index_t
>&
c_gs_ms_os_lengths_vec
,
const
std
::
vector
<
index_t
>&
c_gs_ms_os_strides_vec
)
{
return
MakeCGridDescriptorPair
(
c_gs_ms_os_lengths_vec
,
c_gs_ms_os_strides_vec
).
first
;
}
static
auto
MakeCGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
c_gs_ms_os_lengths_vec
,
const
std
::
vector
<
index_t
>&
c_gs_ms_os_strides_vec
)
{
return
matrix_padder
.
PadCDescriptor_M_N
(
MakeCGridDescriptorPair
(
c_gs_ms_os_lengths_vec
,
c_gs_ms_os_strides_vec
).
second
);
}
};
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp
View file @
4fec5ad3
...
...
@@ -9,46 +9,61 @@
#include <algorithm>
#include <thread>
#include "ck/utility/math_v2.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
host
{
template
<
typename
InOutDataType
,
typename
AccDataType
>
struct
ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
:
public
device
::
DeviceBatchNormFwd
<
4
,
3
>
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
>
struct
ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
:
public
device
::
DeviceBatchNormFwd
<
4
,
3
,
YElementwiseOp
>
{
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
4
>
xyLengths
,
const
std
::
array
<
index_t
,
4
>
xStrides
,
const
std
::
array
<
index_t
,
4
>
yStrides
,
const
std
::
array
<
int
,
3
>
reduceDims
,
const
std
::
array
<
index_t
,
1
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
1
>
bnScaleBiasMeanVarStrides
,
const
InOutDataType
*
p_x
,
const
AccDataType
*
bnScale
,
const
AccDataType
*
bnBias
,
InOutDataType
*
p_y
,
double
exponentialAverageFactor
,
AccDataType
*
resultRunningMean
,
AccDataType
*
resultRunningVariance
,
const
std
::
array
<
index_t
,
1
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
1
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
1
>
bnMeanVarStrides
,
const
XDataType
*
p_x
,
const
ScaleDataType
*
bnScale
,
const
BiasDataType
*
bnBias
,
double
epsilon
,
AccDataType
*
resultSaveMean
,
AccDataType
*
resultSaveInvVariance
)
const
YElementwiseOp
y_elementwise_op
,
YDataType
*
p_y
,
MeanVarDataType
*
resultSaveMean
,
MeanVarDataType
*
resultSaveInvVariance
,
double
averageFactor
,
MeanVarDataType
*
resultRunningMean
,
MeanVarDataType
*
resultRunningVariance
)
:
p_x_
(
p_x
),
bnScale_
(
bnScale
),
bnBias_
(
bnBias
),
y_elementwise_op_
(
y_elementwise_op
),
p_y_
(
p_y
),
resultRunningMean_
(
resultRunningMean
),
resultRunningVariance_
(
resultRunningVariance
),
resultSaveMean_
(
resultSaveMean
),
resultSaveInvVariance_
(
resultSaveInvVariance
),
exponentialAverageFactor_
(
exponentialAverageFactor
),
epsilon_
(
epsilon
)
resultRunningMean_
(
resultRunningMean
),
resultRunningVariance_
(
resultRunningVariance
)
{
(
void
)
xStrides
;
(
void
)
yStrides
;
(
void
)
bnScaleBiasMeanVarStrides
;
ignore
=
xStrides
;
ignore
=
yStrides
;
ignore
=
bnScaleStrides
;
ignore
=
bnBiasStrides
;
ignore
=
bnMeanVarStrides
;
ignore
=
reduceDims
;
if
(
xyLengths
.
size
()
!=
4
||
bnScaleBiasMeanVarLengths
.
size
()
!=
1
||
bnScaleBiasMeanVarLengths
[
0
]
!=
xyLengths
[
3
])
...
...
@@ -59,26 +74,30 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
w
=
xyLengths
[
2
];
c
=
xyLengths
[
3
];
epsilon_
=
type_convert
<
AccDataType
>
(
epsilon
);
averageFactor_
=
type_convert
<
AccDataType
>
(
averageFactor
);
resultSave
=
(
resultSaveMean
!=
nullptr
&&
resultSaveInvVariance
!=
nullptr
);
resultRunning
=
(
resultRunningMean
!=
nullptr
&&
resultRunningVariance
!=
nullptr
);
}
const
InOutDataType
*
p_x_
;
const
AccDataType
*
bnScale_
;
const
AccDataType
*
bnBias_
;
InOutDataType
*
p_y_
;
const
XDataType
*
p_x_
;
const
ScaleDataType
*
bnScale_
;
const
BiasDataType
*
bnBias_
;
const
YElementwiseOp
y_elementwise_op_
;
YDataType
*
p_y_
;
Acc
DataType
*
result
Running
Mean_
;
Acc
DataType
*
result
Running
Variance_
;
Acc
DataType
*
result
Save
Mean_
;
Acc
DataType
*
result
SaveInv
Variance_
;
MeanVar
DataType
*
result
Save
Mean_
;
MeanVar
DataType
*
result
SaveInv
Variance_
;
MeanVar
DataType
*
result
Running
Mean_
;
MeanVar
DataType
*
result
Running
Variance_
;
bool
resultSave
,
resultRunning
;
index_t
n
,
h
,
w
,
c
;
double
exponentialA
verageFactor_
;
doubl
e
epsilon_
;
AccDataType
a
verageFactor_
;
AccDataTyp
e
epsilon_
;
};
struct
Invoker
:
public
device
::
BaseInvoker
...
...
@@ -86,14 +105,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
float
Run
(
const
Argument
&
arg
)
{
auto
thread_reduce_func
=
[
&
](
auto
iC
)
{
AccDataType
reduceSize
=
type_convert
<
AccDataType
>
(
arg
.
n
)
*
type_convert
<
AccDataType
>
(
arg
.
h
)
*
type_convert
<
AccDataType
>
(
arg
.
w
);
index_t
offset_C
=
iC
;
AccDataType
mean
=
type_convert
<
AccDataType
>
(
0.0
f
);
AccDataType
meansquare
=
type_convert
<
AccDataType
>
(
0.0
f
);
// compute mean, meanquare, variance, invVariance
index_t
offset_C
=
iC
;
AccDataType
mean
=
type_convert
<
AccDataType
>
(
0.0
f
);
AccDataType
variance
=
type_convert
<
AccDataType
>
(
0.0
f
);
int32_t
curr_count
=
0
;
// compute mean, variance using welford method
for
(
index_t
iN
=
0
;
iN
<
arg
.
n
;
iN
++
)
{
index_t
offset_N
=
iN
*
arg
.
h
*
arg
.
w
*
arg
.
c
;
...
...
@@ -106,40 +123,46 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
auto
offset
=
offset_N
+
offset_H
+
offset_W
+
offset_C
;
curr_count
++
;
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
p_x_
[
offset
]);
mean
+=
x
;
meansquare
+=
x
*
x
;
AccDataType
delta
=
x
-
mean
;
mean
+=
delta
/
curr_count
;
AccDataType
delta2
=
x
-
mean
;
variance
+=
delta
*
delta2
;
};
}
};
mean
=
mean
/
reduceSize
;
meansquare
=
meansquare
/
reduceSize
;
// actual variance
variance
=
variance
/
curr_count
;
AccDataType
variance
=
meansquare
-
mean
*
mean
;
AccDataType
invVariance
=
type_convert
<
AccDataType
>
(
1.0
f
)
/
std
::
sqrt
(
type_convert
<
AccDataType
>
(
arg
.
epsilon_
)
+
variance
);
type_convert
<
AccDataType
>
(
1.0
f
)
/
ck
::
math
::
sqrt
(
arg
.
epsilon_
+
variance
);
// save the mean/invVariance if required
if
(
arg
.
resultSave
)
{
arg
.
resultSaveMean_
[
iC
]
=
mean
;
arg
.
resultSaveInvVariance_
[
iC
]
=
invVariance
;
arg
.
resultSaveMean_
[
iC
]
=
type_convert
<
MeanVarDataType
>
(
mean
)
;
arg
.
resultSaveInvVariance_
[
iC
]
=
type_convert
<
MeanVarDataType
>
(
invVariance
)
;
};
// update the moving average if required
if
(
arg
.
resultRunning
)
{
arg
.
resultRunningMean_
[
iC
]
=
arg
.
resultRunningMean_
[
iC
]
*
type_convert
<
AccDataType
>
(
1.0
-
arg
.
exponentialAverageFactor_
)
+
mean
*
arg
.
exponentialAverageFactor_
;
arg
.
resultRunningVariance_
[
iC
]
=
arg
.
resultRunningVariance_
[
iC
]
*
type_convert
<
AccDataType
>
(
1.0
-
arg
.
exponentialAverageFactor_
)
+
variance
*
arg
.
exponentialAverageFactor_
;
AccDataType
oneMinusAverageFactor
=
type_convert
<
AccDataType
>
(
1.0
)
-
arg
.
averageFactor_
;
arg
.
resultRunningMean_
[
iC
]
=
type_convert
<
MeanVarDataType
>
(
type_convert
<
AccDataType
>
(
arg
.
resultRunningMean_
[
iC
])
*
oneMinusAverageFactor
+
mean
*
arg
.
averageFactor_
);
arg
.
resultRunningVariance_
[
iC
]
=
type_convert
<
MeanVarDataType
>
(
arg
.
resultRunningVariance_
[
iC
]
*
oneMinusAverageFactor
+
variance
*
arg
.
averageFactor_
);
};
// Normalization
...
...
@@ -160,7 +183,7 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
AccDataType
norm_x
=
arg
.
bnScale_
[
iC
]
*
(
x
-
mean
)
*
invVariance
+
arg
.
bnBias_
[
iC
];
arg
.
p_y_
[
offset
]
=
type_convert
<
InOut
DataType
>
(
norm_x
);
arg
.
p_y_
[
offset
]
=
type_convert
<
Y
DataType
>
(
norm_x
);
};
}
};
...
...
@@ -207,34 +230,42 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
4
>
xyLengths
,
const
std
::
array
<
index_t
,
4
>
xStrides
,
const
std
::
array
<
index_t
,
4
>
yStrides
,
const
std
::
array
<
int
,
3
>
reduceDims
,
const
std
::
array
<
index_t
,
1
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
1
>
bnScaleBiasMeanVarStrides
,
const
std
::
array
<
index_t
,
1
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
1
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
1
>
bnMeanVarStrides
,
const
void
*
p_x
,
const
void
*
bnScale
,
const
void
*
bnBias
,
void
*
p_y
,
double
exponentialAverageFactor
,
void
*
resultRunningMean
,
void
*
resultRunningVariance
,
double
epsilon
,
const
YElementwiseOp
y_elementwise_op
,
void
*
p_y
,
void
*
resultSaveMean
,
void
*
resultSaveInvVariance
)
override
void
*
resultSaveInvVariance
,
double
averageFactor
,
void
*
resultRunningMean
,
void
*
resultRunningVariance
)
override
{
return
std
::
make_unique
<
Argument
>
(
xyLengths
,
xStrides
,
yStrides
,
reduceDims
,
bnScaleBiasMeanVarLengths
,
bnScaleBiasMeanVarStrides
,
static_cast
<
const
InOutDataType
*>
(
p_x
),
static_cast
<
const
AccDataType
*>
(
bnScale
),
static_cast
<
const
AccDataType
*>
(
bnBias
),
static_cast
<
InOutDataType
*>
(
p_y
),
exponentialAverageFactor
,
static_cast
<
AccDataType
*>
(
resultRunningMean
),
static_cast
<
AccDataType
*>
(
resultRunningVariance
),
bnScaleStrides
,
bnBiasStrides
,
bnMeanVarStrides
,
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
ScaleDataType
*>
(
bnScale
),
static_cast
<
const
BiasDataType
*>
(
bnBias
),
epsilon
,
static_cast
<
AccDataType
*>
(
resultSaveMean
),
static_cast
<
AccDataType
*>
(
resultSaveInvVariance
));
y_elementwise_op
,
static_cast
<
YDataType
*>
(
p_y
),
static_cast
<
MeanVarDataType
*>
(
resultSaveMean
),
static_cast
<
MeanVarDataType
*>
(
resultSaveInvVariance
),
averageFactor
,
static_cast
<
MeanVarDataType
*>
(
resultRunningMean
),
static_cast
<
MeanVarDataType
*>
(
resultRunningVariance
));
};
std
::
unique_ptr
<
device
::
BaseInvoker
>
MakeInvokerPointer
()
override
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp
View file @
4fec5ad3
...
...
@@ -14,7 +14,12 @@ namespace ck {
namespace
tensor_operation
{
namespace
host
{
template
<
typename
InOutDataType
,
typename
AccDataType
>
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
>
struct
ReferenceBatchNormInfer_Input_N_H_W_C_Output_C
:
public
device
::
DeviceBatchNormInfer
<
4
,
3
>
{
struct
Argument
:
public
device
::
BaseArgument
...
...
@@ -23,14 +28,16 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
const
std
::
array
<
index_t
,
4
>
xStrides
,
const
std
::
array
<
index_t
,
4
>
yStrides
,
const
std
::
array
<
index_t
,
1
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
1
>
bnScaleBiasMeanVarStrides
,
const
InOutDataType
*
p_x
,
const
AccDataType
*
bnScale
,
const
AccDataType
*
bnBias
,
const
std
::
array
<
index_t
,
1
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
1
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
1
>
bnMeanVarStrides
,
const
XDataType
*
p_x
,
const
ScaleDataType
*
bnScale
,
const
BiasDataType
*
bnBias
,
double
epsilon
,
const
Acc
DataType
*
estimatedMean
,
const
Acc
DataType
*
estimatedVariance
,
InOut
DataType
*
p_y
)
const
MeanVar
DataType
*
estimatedMean
,
const
MeanVar
DataType
*
estimatedVariance
,
Y
DataType
*
p_y
)
:
p_x_
(
p_x
),
bnScale_
(
bnScale
),
bnBias_
(
bnBias
),
...
...
@@ -39,32 +46,34 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
estimatedVariance_
(
estimatedVariance
),
p_y_
(
p_y
)
{
(
void
)
xStrides
;
(
void
)
yStrides
;
(
void
)
bnScaleBiasMeanVarStrides
;
ignore
=
xStrides
;
ignore
=
yStrides
;
ignore
=
bnScaleStrides
;
ignore
=
bnBiasStrides
;
ignore
=
bnMeanVarStrides
;
if
(
xyLengths
.
size
()
!=
4
||
bnScaleBiasMeanVarLengths
.
size
()
!=
1
||
bnScaleBiasMeanVarLengths
[
0
]
!=
xyLengths
[
3
])
throw
std
::
runtime_error
(
"Invalid tensor dimensions!"
);
n
=
xyLengths
[
0
];
h
=
xyLengths
[
1
];
w
=
xyLengths
[
2
];
c
=
xyLengths
[
3
];
n
_
=
xyLengths
[
0
];
h
_
=
xyLengths
[
1
];
w
_
=
xyLengths
[
2
];
c
_
=
xyLengths
[
3
];
}
const
InOut
DataType
*
p_x_
;
const
Acc
DataType
*
bnScale_
;
const
Acc
DataType
*
bnBias_
;
const
X
DataType
*
p_x_
;
const
Scale
DataType
*
bnScale_
;
const
Bias
DataType
*
bnBias_
;
double
epsilon_
;
const
Acc
DataType
*
estimatedMean_
;
const
Acc
DataType
*
estimatedVariance_
;
const
MeanVar
DataType
*
estimatedMean_
;
const
MeanVar
DataType
*
estimatedVariance_
;
InOut
DataType
*
p_y_
;
Y
DataType
*
p_y_
;
index_t
n
,
h
,
w
,
c
;
index_t
n
_
,
h
_
,
w
_
,
c
_
;
};
struct
Invoker
:
public
device
::
BaseInvoker
...
...
@@ -81,15 +90,15 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
std
::
sqrt
(
type_convert
<
AccDataType
>
(
arg
.
epsilon_
)
+
variance
);
// Normalization
for
(
index_t
iN
=
0
;
iN
<
arg
.
n
;
iN
++
)
for
(
index_t
iN
=
0
;
iN
<
arg
.
n
_
;
iN
++
)
{
index_t
offset_N
=
iN
*
arg
.
h
*
arg
.
w
*
arg
.
c
;
for
(
index_t
iH
=
0
;
iH
<
arg
.
h
;
iH
++
)
index_t
offset_N
=
iN
*
arg
.
h
_
*
arg
.
w
_
*
arg
.
c
_
;
for
(
index_t
iH
=
0
;
iH
<
arg
.
h
_
;
iH
++
)
{
index_t
offset_H
=
iH
*
arg
.
w
*
arg
.
c
;
for
(
index_t
iW
=
0
;
iW
<
arg
.
w
;
iW
++
)
index_t
offset_H
=
iH
*
arg
.
w
_
*
arg
.
c
_
;
for
(
index_t
iW
=
0
;
iW
<
arg
.
w
_
;
iW
++
)
{
index_t
offset_W
=
iW
*
arg
.
c
;
index_t
offset_W
=
iW
*
arg
.
c
_
;
auto
offset
=
offset_N
+
offset_H
+
offset_W
+
offset_C
;
...
...
@@ -98,21 +107,21 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
AccDataType
norm_x
=
arg
.
bnScale_
[
iC
]
*
(
x
-
mean
)
*
invVariance
+
arg
.
bnBias_
[
iC
];
arg
.
p_y_
[
offset
]
=
type_convert
<
InOut
DataType
>
(
norm_x
);
arg
.
p_y_
[
offset
]
=
type_convert
<
Y
DataType
>
(
norm_x
);
};
}
};
};
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
std
::
size_t
work_per_thread
=
(
arg
.
c
+
num_thread
-
1
)
/
num_thread
;
std
::
size_t
work_per_thread
=
(
arg
.
c
_
+
num_thread
-
1
)
/
num_thread
;
std
::
vector
<
joinable_thread
>
threads
(
num_thread
);
for
(
std
::
size_t
it
=
0
;
it
<
num_thread
;
++
it
)
{
std
::
size_t
ic_begin
=
it
*
work_per_thread
;
std
::
size_t
ic_end
=
std
::
min
(
static_cast
<
int
>
((
it
+
1
)
*
work_per_thread
),
arg
.
c
);
std
::
size_t
ic_end
=
std
::
min
(
static_cast
<
int
>
((
it
+
1
)
*
work_per_thread
),
arg
.
c
_
);
auto
f
=
[
=
]
{
for
(
std
::
size_t
ic
=
ic_begin
;
ic
<
ic_end
;
++
ic
)
...
...
@@ -146,7 +155,9 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
const
std
::
array
<
index_t
,
4
>
xStrides
,
const
std
::
array
<
index_t
,
4
>
yStrides
,
const
std
::
array
<
index_t
,
1
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
1
>
bnScaleBiasMeanVarStrides
,
const
std
::
array
<
index_t
,
1
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
1
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
1
>
bnMeanVarStrides
,
const
void
*
p_x
,
const
void
*
bnScale
,
const
void
*
bnBias
,
...
...
@@ -159,14 +170,16 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
xStrides
,
yStrides
,
bnScaleBiasMeanVarLengths
,
bnScaleBiasMeanVarStrides
,
static_cast
<
const
InOutDataType
*>
(
p_x
),
static_cast
<
const
AccDataType
*>
(
bnScale
),
static_cast
<
const
AccDataType
*>
(
bnBias
),
bnScaleStrides
,
bnBiasStrides
,
bnMeanVarStrides
,
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
ScaleDataType
*>
(
bnScale
),
static_cast
<
const
BiasDataType
*>
(
bnBias
),
epsilon
,
static_cast
<
const
Acc
DataType
*>
(
estimatedMean
),
static_cast
<
const
Acc
DataType
*>
(
estimatedVariance
),
static_cast
<
InOut
DataType
*>
(
p_y
));
static_cast
<
const
MeanVar
DataType
*>
(
estimatedMean
),
static_cast
<
const
MeanVar
DataType
*>
(
estimatedVariance
),
static_cast
<
Y
DataType
*>
(
p_y
));
};
std
::
unique_ptr
<
device
::
BaseInvoker
>
MakeInvokerPointer
()
override
...
...
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
4fec5ad3
...
...
@@ -3,7 +3,10 @@
#pragma once
#include <cstdlib>
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -15,6 +18,8 @@ using F64 = double;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
using
Empty_Tuple
=
ck
::
Tuple
<>
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp
View file @
4fec5ad3
...
...
@@ -28,9 +28,26 @@ void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_g
F16
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
false
>>>&
instances
);
void
add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemm
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
true
>>>&
instances
);
template
<
typename
ALayout
,
typename
B0Layout
,
...
...
@@ -39,7 +56,8 @@ template <typename ALayout,
typename
ADataType
,
typename
B0DataType
,
typename
B1DataType
,
typename
CDataType
>
typename
CDataType
,
bool
MaskOutUpperTriangle
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemm
<
ALayout
,
B0Layout
,
...
...
@@ -51,9 +69,10 @@ struct DeviceOperationInstanceFactory<
CDataType
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
P
as
sThrough
>>
M
as
kOutUpperTriangle
>>
{
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemm
<
ALayout
,
B0Layout
,
...
...
@@ -65,9 +84,10 @@ struct DeviceOperationInstanceFactory<
CDataType
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
P
as
sThrough
>
;
M
as
kOutUpperTriangle
>
;
static
auto
GetInstances
()
{
...
...
@@ -79,8 +99,16 @@ struct DeviceOperationInstanceFactory<
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
B0Layout
,
Col
>
&&
is_same_v
<
B1Layout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
op_ptrs
);
if
constexpr
(
MaskOutUpperTriangle
)
{
add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
op_ptrs
);
}
else
{
add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
op_ptrs
);
}
}
}
return
op_ptrs
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_
masking_scale_
softmax_gemm_permute.hpp
→
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
View file @
4fec5ad3
...
...
@@ -17,63 +17,89 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
void
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
using
CPermuteNumDims_G_M_O
=
S
<
2
,
1
,
1
>
;
// "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
void
add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
Row
,
Col
,
Row
,
CPermuteNumDims_G_M_O
,
F16
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
>>>&
instances
);
template
<
typename
ALayout
,
typename
B0Layout
,
typename
B1Layout
,
typename
CPermuteNumDims_G_M_Gemm1N
,
typename
ADataType
,
template
<
typename
ADataType
,
typename
B0DataType
,
typename
B1DataType
,
typename
CDataType
>
typename
CDataType
,
MaskingSpecialization
MaskingSpec
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute
<
ALayout
,
B0Layout
,
B1Layout
,
CPermuteNumDims_G_M_Gemm1N
,
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
>>
PassThrough
,
MaskingSpec
>>
{
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute
<
ALayout
,
B0Layout
,
B1Layout
,
CPermuteNumDims_G_M_Gemm1N
,
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
>
;
PassThrough
,
MaskingSpec
>
;
static
auto
GetInstances
()
{
...
...
@@ -82,11 +108,14 @@ struct DeviceOperationInstanceFactory<
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
B0DataType
,
half_t
>
&&
is_same_v
<
B1DataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
B0Layout
,
Col
>
&&
is_same_v
<
B1Layout
,
Row
>
&&
is_same_v
<
CPermuteNumDims_G_M_Gemm1N
,
CPermuteNumDims_G_M_O
>
)
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskOutUpperTriangle
)
{
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
op_ptrs
);
}
else
if
(
MaskingSpec
==
MaskingSpecialization
::
MaskDisabled
)
{
add_device_batched_gemm_
masking_scale_
softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
s
(
op_ptrs
);
}
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp
View file @
4fec5ad3
...
...
@@ -3,24 +3,77 @@
#pragma once
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.hpp"
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp
View file @
4fec5ad3
...
...
@@ -5,6 +5,8 @@
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp"
namespace
ck
{
...
...
@@ -63,33 +65,20 @@ using reduce_configuration_2_instances_blockwise = std::tuple<
>
;
#endif
template
<
ReduceTensorOp
ReduceOpId
>
using
deviceReduceBlockWisePtrType
=
DeviceReducePtr
<
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
,
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
>
;
template
<
typename
InDataType
,
typename
AccDataType
,
typename
OutDataType
,
int
Rank
,
int
NumReduceDim
,
ReduceTensorOp
ReduceOpId
,
typename
ReduceOperation
,
typename
InElementwiseOp
,
typename
AccElementwiseOp
,
bool
PropagateNan
,
bool
Use
Index
>
bool
Output
Index
>
void
add_device_reduce_instance_blockwise
(
std
::
vector
<
deviceReduceBlockWisePtrType
<
ReduceOpId
>>&
device_op_instances
)
std
::
vector
<
DeviceReducePtr
<
Rank
,
NumReduceDim
,
InElementwiseOp
,
AccElementwiseOp
>>&
device_op_instances
)
{
using
ReduceOperation
=
typename
reduce_binary_operator
<
ReduceOpId
>::
opType
;
using
InElementwiseOperation
=
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
using
AccElementwiseOperation
=
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
constexpr
bool
Indexable
=
(
ReduceOpId
==
ReduceTensorOp
::
MIN
||
ReduceOpId
==
ReduceTensorOp
::
MAX
||
ReduceOpId
==
ReduceTensorOp
::
AMAX
);
constexpr
bool
OutputIndex
=
Indexable
&&
UseIndex
;
static_for
<
0
,
std
::
tuple_size
<
reduce_configuration_1_instances_blockwise
>::
value
,
1
>
{}(
[
&
](
auto
i
)
{
using
cfg1
=
remove_cvref_t
<
decltype
(
...
...
@@ -107,8 +96,8 @@ void add_device_reduce_instance_blockwise(
Rank
,
NumReduceDim
,
ReduceOperation
,
InElementwiseOp
eration
,
AccElementwiseOp
eration
,
InElementwiseOp
,
AccElementwiseOp
,
InMemoryDataOperationEnum
::
Set
,
PropagateNan
,
OutputIndex
,
...
...
@@ -128,52 +117,6 @@ void add_device_reduce_instance_blockwise(
});
};
#define ADD_BLOCKWISE_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_BLOCKWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_INST_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp>(ReduceOpId), \
static_cast<bool>(NanOpt), \
static_cast<bool>(IndicesOpt), \
Rank, \
NumReduceDim)
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_BLOCKWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp>(ReduceOpId), \
static_cast<bool>(NanOpt), \
static_cast<bool>(IndicesOpt), \
Rank, \
NumReduceDim)
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp
deleted
100644 → 0
View file @
24faa1fc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
0
,
0
,
0
,
4
,
3
);
// for ADD
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
0
,
0
,
0
,
4
,
4
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
0
,
0
,
0
,
4
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
0
,
0
,
0
,
2
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
5
,
0
,
0
,
4
,
3
);
// for AVG
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
5
,
0
,
0
,
4
,
4
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
5
,
0
,
0
,
4
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
5
,
0
,
0
,
2
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
7
,
0
,
0
,
4
,
3
);
// for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
7
,
0
,
0
,
4
,
4
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
7
,
0
,
0
,
4
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
7
,
0
,
0
,
2
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
2
,
0
,
0
,
4
,
3
);
// for MIN
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
2
,
0
,
0
,
4
,
4
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
2
,
0
,
0
,
4
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
2
,
0
,
0
,
2
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
3
,
0
,
0
,
4
,
3
);
// for MAX
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
3
,
0
,
0
,
4
,
4
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
3
,
0
,
0
,
4
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
3
,
0
,
0
,
2
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
4
,
0
,
0
,
4
,
3
);
// for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
4
,
0
,
0
,
4
,
4
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
4
,
0
,
0
,
4
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
4
,
0
,
0
,
2
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
2
,
0
,
1
,
4
,
3
);
// for MIN
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
2
,
0
,
1
,
4
,
4
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
2
,
0
,
1
,
4
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
2
,
0
,
1
,
2
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
3
,
0
,
1
,
4
,
3
);
// for MAX
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
3
,
0
,
1
,
4
,
4
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
3
,
0
,
1
,
4
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
3
,
0
,
1
,
2
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
4
,
0
,
1
,
4
,
3
);
// for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
4
,
0
,
1
,
4
,
4
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
4
,
0
,
1
,
4
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
4
,
0
,
1
,
2
,
1
);
// clang-format on
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern
template
void
add_device_reduce_instance_blockwise
<
BF16
,
F32
,
BF16
,
4
,
3
,
ReduceAdd
,
PassThrough
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
4
,
3
,
PassThrough
,
PassThrough
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
BF16
,
F32
,
BF16
,
4
,
4
,
ReduceAdd
,
PassThrough
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
4
,
4
,
PassThrough
,
PassThrough
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
BF16
,
F32
,
BF16
,
4
,
1
,
ReduceAdd
,
PassThrough
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
4
,
1
,
PassThrough
,
PassThrough
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
BF16
,
F32
,
BF16
,
2
,
1
,
ReduceAdd
,
PassThrough
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
2
,
1
,
PassThrough
,
PassThrough
>>&
);
// clang-format on
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern
template
void
add_device_reduce_instance_blockwise
<
BF16
,
F32
,
BF16
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
4
,
3
,
UnaryAbs
,
PassThrough
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
BF16
,
F32
,
BF16
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
4
,
4
,
UnaryAbs
,
PassThrough
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
BF16
,
F32
,
BF16
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
4
,
1
,
UnaryAbs
,
PassThrough
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
BF16
,
F32
,
BF16
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
2
,
1
,
UnaryAbs
,
PassThrough
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
BF16
,
F32
,
BF16
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
4
,
3
,
UnaryAbs
,
PassThrough
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
BF16
,
F32
,
BF16
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
4
,
4
,
UnaryAbs
,
PassThrough
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
BF16
,
F32
,
BF16
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
4
,
1
,
UnaryAbs
,
PassThrough
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
BF16
,
F32
,
BF16
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
2
,
1
,
UnaryAbs
,
PassThrough
>>&
);
// clang-format on
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
2
3
4
5
6
7
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment