Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7bcaf2a7
Unverified
Commit
7bcaf2a7
authored
Dec 19, 2022
by
Adam Osewski
Committed by
GitHub
Dec 19, 2022
Browse files
Merge branch 'develop' into wavelet_model
parents
e59daa22
0345963e
Changes
188
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2006 additions
and
890 deletions
+2006
-890
include/ck/tensor_operation/gpu/element/quantization_operation.hpp
...k/tensor_operation/gpu/element/quantization_operation.hpp
+57
-19
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp
...ultiblock_reduce_second_half_batchnorm_backward_final.hpp
+45
-81
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp
...lock_welford_second_half_multiblock_reduce_first_half.hpp
+19
-38
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+5
-0
include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp
...pu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp
+31
-49
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
.../ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
+230
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp
...tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp
+678
-0
include/ck/tensor_operation/gpu/grid/gridwise_multiblock_welford_first_half.hpp
...ation/gpu/grid/gridwise_multiblock_welford_first_half.hpp
+0
-258
include/ck/utility/amd_wmma.hpp
include/ck/utility/amd_wmma.hpp
+102
-0
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+16
-2
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp
...nce_tensor_operation/cpu/reference_batchnorm_backward.hpp
+412
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward_nhwc_c.hpp
...sor_operation/cpu/reference_batchnorm_backward_nhwc_c.hpp
+0
-319
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+13
-5
library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp
...rary/tensor_operation_instance/gpu/batchnorm_backward.hpp
+124
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
...or_operation_instance/gpu/grouped_convolution_forward.hpp
+44
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dl.hpp
...operation_instance/gpu/grouped_convolution_forward_dl.hpp
+0
-116
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp
...uped_convolution_bias_forward_perchannel_quantization.hpp
+114
-0
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp
...rouped_convolution_bias_forward_perlayer_quantization.hpp
+3
-3
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp
...n/grouped_convolution_forward_perchannel_quantization.hpp
+113
-0
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp
...ion/grouped_convolution_forward_perlayer_quantization.hpp
+0
-0
No files found.
include/ck/tensor_operation/gpu/element/quantization_operation.hpp
View file @
7bcaf2a7
...
...
@@ -10,8 +10,8 @@ namespace element_wise {
template
<
typename
Activation
>
struct
Activation_Mul_Clamp
{
Activation_Mul_Clamp
(
float
multiplier
,
Activation
activationOp
)
:
multiplier_
(
multiplier
),
activationOp_
(
activationOp
)
Activation_Mul_Clamp
(
float
requantScale
,
Activation
activationOp
)
:
requantScale_
(
requantScale
),
activationOp_
(
activationOp
)
{
}
...
...
@@ -19,7 +19,7 @@ struct Activation_Mul_Clamp
{
float
x_fp32
=
ck
::
type_convert
<
float
>
(
x
);
activationOp_
(
x_fp32
,
x_fp32
);
float
y_fp32
=
math
::
clamp
(
multiplier
_
*
x_fp32
,
-
128.
f
,
127.
f
);
float
y_fp32
=
math
::
clamp
(
requantScale
_
*
x_fp32
,
-
128.
f
,
127.
f
);
y
=
ck
::
type_convert
<
int8_t
>
(
y_fp32
);
}
...
...
@@ -28,10 +28,29 @@ struct Activation_Mul_Clamp
// We might type_convert to int8 after lambda in someplace
float
x_fp32
=
ck
::
type_convert
<
float
>
(
x
);
activationOp_
(
x_fp32
,
x_fp32
);
y
=
math
::
clamp
(
multiplier_
*
x_fp32
,
-
128.
f
,
127.
f
);
y
=
math
::
clamp
(
requantScale_
*
x_fp32
,
-
128.
f
,
127.
f
);
}
float
requantScale_
;
Activation
activationOp_
;
};
// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
// relu, leaky relu ...etc
template
<
typename
Activation
>
struct
Activation_Mul2_Clamp
{
Activation_Mul2_Clamp
(
Activation
activationOp
)
:
activationOp_
(
activationOp
)
{}
__host__
__device__
constexpr
void
operator
()(
int8_t
&
y
,
const
int32_t
&
x
,
const
float
&
requantScale
)
const
{
float
y_fp32
=
ck
::
type_convert
<
float
>
(
x
);
activationOp_
(
y_fp32
,
y_fp32
);
y_fp32
=
math
::
clamp
(
requantScale
*
y_fp32
,
-
128.
f
,
127.
f
);
y
=
ck
::
type_convert
<
int8_t
>
(
y_fp32
);
}
float
multiplier_
;
Activation
activationOp_
;
};
...
...
@@ -39,21 +58,40 @@ struct Activation_Mul_Clamp
template
<
typename
Activation
>
struct
Add_Activation_Mul_Clamp
{
Add_Activation_Mul_Clamp
(
float
multiplier
,
Activation
activationOp
)
:
multiplier_
(
multiplier
),
activationOp_
(
activationOp
)
Add_Activation_Mul_Clamp
(
float
requantScale
,
Activation
activationOp
)
:
requantScale_
(
requantScale
),
activationOp_
(
activationOp
)
{
}
__host__
__device__
constexpr
void
operator
()(
int8_t
&
y
,
const
int32_t
&
x1
,
const
int32_t
&
x2
)
const
operator
()(
int8_t
&
y
,
const
int32_t
&
x
,
const
int32_t
&
bias
)
const
{
float
y_fp32
=
ck
::
type_convert
<
float
>
(
x
+
bias
);
activationOp_
(
y_fp32
,
y_fp32
);
y_fp32
=
math
::
clamp
(
requantScale_
*
y_fp32
,
-
128.
f
,
127.
f
);
y
=
ck
::
type_convert
<
int8_t
>
(
y_fp32
);
}
float
requantScale_
;
Activation
activationOp_
;
};
// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
// relu, leaky relu ...etc
template
<
typename
Activation
>
struct
Add_Activation_Mul2_Clamp
{
Add_Activation_Mul2_Clamp
(
Activation
activationOp
)
:
activationOp_
(
activationOp
)
{}
__host__
__device__
constexpr
void
operator
()(
int8_t
&
y
,
const
int32_t
&
x
,
const
int32_t
&
bias
,
const
float
&
requantScale
)
const
{
float
y_fp32
=
ck
::
type_convert
<
float
>
(
x
1
+
x2
);
float
y_fp32
=
ck
::
type_convert
<
float
>
(
x
+
bias
);
activationOp_
(
y_fp32
,
y_fp32
);
y_fp32
=
math
::
clamp
(
multiplier_
*
y_fp32
,
-
128.
f
,
127.
f
);
y_fp32
=
math
::
clamp
(
requantScale
*
y_fp32
,
-
128.
f
,
127.
f
);
y
=
ck
::
type_convert
<
int8_t
>
(
y_fp32
);
}
float
multiplier_
;
Activation
activationOp_
;
};
...
...
@@ -61,23 +99,23 @@ struct Add_Activation_Mul_Clamp
template
<
typename
Activation
>
struct
Add_Mul_Activation_Mul_Clamp
{
Add_Mul_Activation_Mul_Clamp
(
float
multiplier1
,
float
multiplier
2
,
Activation
activationOp
)
:
multiplier1_
(
multiplier1
),
multiplier2_
(
multiplier
2
),
activationOp_
(
activationOp
)
Add_Mul_Activation_Mul_Clamp
(
float
requantScale1
,
float
requantScale
2
,
Activation
activationOp
)
:
requantScale1_
(
requantScale1
),
requantScale2_
(
requantScale
2
),
activationOp_
(
activationOp
)
{
}
__host__
__device__
constexpr
void
operator
()(
int8_t
&
y
,
const
int32_t
&
x
1
,
const
int32_t
&
x2
)
const
operator
()(
int8_t
&
y
,
const
int32_t
&
x
,
const
int32_t
&
bias
)
const
{
float
y_fp32
=
ck
::
type_convert
<
float
>
(
x
1
+
x2
);
y_fp32
=
multiplier
1_
*
y_fp32
;
float
y_fp32
=
ck
::
type_convert
<
float
>
(
x
+
bias
);
y_fp32
=
requantScale
1_
*
y_fp32
;
activationOp_
(
y_fp32
,
y_fp32
);
y_fp32
=
math
::
clamp
(
multiplier
2_
*
y_fp32
,
-
128.
f
,
127.
f
);
y_fp32
=
math
::
clamp
(
requantScale
2_
*
y_fp32
,
-
128.
f
,
127.
f
);
y
=
ck
::
type_convert
<
int8_t
>
(
y_fp32
);
}
float
multiplier
1_
;
float
multiplier
2_
;
float
requantScale
1_
;
float
requantScale
2_
;
Activation
activationOp_
;
};
...
...
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp
View file @
7bcaf2a7
...
...
@@ -16,7 +16,7 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
typename
DyDataType
,
typename
DxDataType
,
typename
ScaleDataType
,
typename
B
iasDataType
,
typename
DscaleDb
iasDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
typename
XYGridDesc_M_K
,
...
...
@@ -35,8 +35,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
long_index_t
reduce_size
,
index_t
num_xy_k_block_tile_iteration
,
index_t
num_dscale_dbias_k_block_tile_iteration
,
const
S
caleDataType
*
const
__restrict__
p_reduce_dscale
,
const
B
iasDataType
*
const
__restrict__
p_reduce_dbias
,
const
Ds
caleD
biasD
ataType
*
const
__restrict__
p_reduce_dscale
,
const
DscaleDb
iasDataType
*
const
__restrict__
p_reduce_dbias
,
const
MeanVarDataType
*
const
__restrict__
p_mean
,
const
MeanVarDataType
*
const
__restrict__
p_inv_var
,
const
XDataType
*
const
__restrict__
p_x
,
...
...
@@ -44,8 +44,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
const
ScaleDataType
*
const
__restrict__
p_scale
,
const
DyElementwiseOp
dy_elementwise_op
,
DxDataType
*
const
__restrict__
p_dx
,
S
caleDataType
*
const
__restrict__
p_dscale
,
B
iasDataType
*
const
__restrict__
p_dbias
)
Ds
caleD
biasD
ataType
*
const
__restrict__
p_dscale
,
DscaleDb
iasDataType
*
const
__restrict__
p_dbias
)
{
GridwiseReduceSecondHalfBatchNormBackwardFinal_
::
Run
(
x_grid_desc_m_k
,
dy_grid_desc_m_k
,
...
...
@@ -76,7 +76,7 @@ template <typename XDataType,
typename
DxDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
B
iasDataType
,
typename
DscaleDb
iasDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
typename
XYGridDesc_M_K
,
...
...
@@ -92,8 +92,8 @@ template <typename XDataType,
index_t
XSrcVectorSize
,
index_t
DySrcVectorSize
,
index_t
DxDstVectorSize
,
index_t
ScaleSrc
Dst
VectorSize
,
index_t
B
iasDstVectorSize
,
index_t
ScaleSrcVectorSize
,
index_t
DscaleDb
iasDstVectorSize
,
index_t
MeanVarSrcVectorSize
>
struct
GridwiseReduceSecondHalfBatchNormBackwardFinal
{
...
...
@@ -155,13 +155,13 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
const
DscaleDbiasGridDesc_M_K
&
dscale_dbias_grid_desc_m_k
,
const
MeanVarGridDesc_M
&
mean_var_grid_desc_m
,
const
ScaleBiasGridDesc_M
&
scale_grid_desc_m
,
const
ScaleBiasGridDesc_M
&
bias_grid_desc_m
,
const
ScaleBiasGridDesc_M
&
dscale_d
bias_grid_desc_m
,
index_t
blkgroup_size
,
long_index_t
reduce_size
,
index_t
num_xy_k_block_tile_iteration
,
index_t
num_dscale_dbias_k_block_tile_iteration
,
const
S
caleDataType
*
const
__restrict__
p_reduce_dscale
,
const
B
iasDataType
*
const
__restrict__
p_reduce_dbias
,
const
Ds
caleD
biasD
ataType
*
const
__restrict__
p_reduce_dscale
,
const
DscaleDb
iasDataType
*
const
__restrict__
p_reduce_dbias
,
const
MeanVarDataType
*
const
__restrict__
p_mean
,
const
MeanVarDataType
*
const
__restrict__
p_inv_var
,
const
XDataType
*
const
__restrict__
p_x
,
...
...
@@ -169,8 +169,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
const
ScaleDataType
*
const
__restrict__
p_scale
,
const
DyElementwiseOp
dy_elementwise_op
,
DxDataType
*
const
__restrict__
p_dx
,
S
caleDataType
*
const
__restrict__
p_dscale
,
B
iasDataType
*
const
__restrict__
p_dbias
)
Ds
caleD
biasD
ataType
*
const
__restrict__
p_dscale
,
DscaleDb
iasDataType
*
const
__restrict__
p_dbias
)
{
__shared__
AccDataType
p_reduce_work_buffer
[
BlockSize
];
...
...
@@ -222,24 +222,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
// Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on
auto
threadwise_dscale_load_m_k
=
ThreadwiseTensorSliceTransfer_v2
<
ScaleDataType
,
AccDataType
,
DscaleDbiasGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
1
,
1
,
true
>
(
dscale_dbias_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
1
));
auto
threadwise_dbias_load_m_k
=
ThreadwiseTensorSliceTransfer_v2
<
BiasDataType
,
auto
threadwise_dscale_dbias_load_m_k
=
ThreadwiseTensorSliceTransfer_v2
<
DscaleDbiasDataType
,
AccDataType
,
DscaleDbiasGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_1
),
...
...
@@ -254,38 +238,20 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
1
));
auto
threadwise_dscale_store_m
=
auto
threadwise_dscale_
dbias_
store_m
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
S
caleDataType
,
Ds
caleD
biasD
ataType
,
decltype
(
thread_buffer_desc_m
),
ScaleBiasGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
S
cale
Src
DstVectorSize
,
Ds
cale
Dbias
DstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
scale_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
auto
threadwise_dbias_store_m
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
BiasDataType
,
decltype
(
thread_buffer_desc_m
),
ScaleBiasGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
BiasDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
bias_grid_desc_m
,
dscale_dbias_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
...
...
@@ -297,10 +263,10 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
p_reduce_dbias
,
dscale_dbias_grid_desc_m_k
.
GetElementSpaceSize
());
auto
dscale_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_dscale
,
scale_grid_desc_m
.
GetElementSpaceSize
());
p_dscale
,
d
scale_
dbias_
grid_desc_m
.
GetElementSpaceSize
());
auto
dbias_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_dbias
,
bias_grid_desc_m
.
GetElementSpaceSize
());
p_dbias
,
dscale_d
bias_grid_desc_m
.
GetElementSpaceSize
());
constexpr
auto
dscale_dbias_thread_copy_step_m_k
=
make_multi_index
(
0
,
KThreadClusterSize
*
1
);
...
...
@@ -313,25 +279,23 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_dscale_dbias_k_block_tile_iteration
;
++
reducedTiles
)
{
threadwise_dscale_load_m_k
.
Run
(
dscale_dbias_grid_desc_m_k
,
reduce_dscale_global_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
reduce_dscale_thread_buf
);
threadwise_dbias_load_m_k
.
Run
(
dscale_dbias_grid_desc_m_k
,
reduce_dbias_global_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
reduce_dbias_thread_buf
);
threadwise_dscale_
dbias_
load_m_k
.
Run
(
dscale_dbias_grid_desc_m_k
,
reduce_dscale_global_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
reduce_dscale_thread_buf
);
threadwise_
dscale_
dbias_load_m_k
.
Run
(
dscale_dbias_grid_desc_m_k
,
reduce_dbias_global_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
reduce_dbias_thread_buf
);
ThreadwiseReduce
::
Reduce
(
reduce_dscale_thread_buf
,
dscale_thread_buf
);
ThreadwiseReduce
::
Reduce
(
reduce_dbias_thread_buf
,
dbias_thread_buf
);
threadwise_dscale_load_m_k
.
MoveSrcSliceWindow
(
dscale_dbias_grid_desc_m_k
,
dscale_dbias_thread_copy_step_m_k
);
threadwise_dbias_load_m_k
.
MoveSrcSliceWindow
(
dscale_dbias_grid_desc_m_k
,
dscale_dbias_thread_copy_step_m_k
);
threadwise_dscale_dbias_load_m_k
.
MoveSrcSliceWindow
(
dscale_dbias_grid_desc_m_k
,
dscale_dbias_thread_copy_step_m_k
);
}
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
...
...
@@ -343,17 +307,17 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
BlockwiseReduce
::
Reduce
(
reduce_work_buf
,
dbias_thread_buf
(
I
));
});
threadwise_dscale_store_m
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
dscale_thread_buf
,
scale
_grid_desc_m
,
dscale_global_buf
);
threadwise_dscale_
dbias_
store_m
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
dscale_thread_buf
,
dscale_dbias
_grid_desc_m
,
dscale_global_buf
);
threadwise_dbias_store_m
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
dbias_thread_buf
,
bias_grid_desc_m
,
dbias_global_buf
);
threadwise_
dscale_
dbias_store_m
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
dbias_thread_buf
,
dscale_d
bias_grid_desc_m
,
dbias_global_buf
);
// clang-format off
// Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance)
...
...
@@ -418,7 +382,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
ScaleSrc
Dst
VectorSize
,
ScaleSrcVectorSize
,
1
,
true
>
(
scale_grid_desc_m
,
...
...
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp
View file @
7bcaf2a7
...
...
@@ -17,7 +17,7 @@ template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
typename
DyDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
B
iasDataType
,
typename
DscaleDb
iasDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
typename
XYGridDesc_M_K
,
...
...
@@ -45,8 +45,8 @@ __global__ void kernel_welford_second_half_reduce_first_half(
MeanVarDataType
*
const
__restrict__
p_out_welford_inv_variance
,
const
XDataType
*
const
__restrict__
p_x
,
const
DyDataType
*
const
__restrict__
p_dy
,
S
caleDataType
*
const
__restrict__
p_reduce_dscale
,
B
iasDataType
*
const
__restrict__
p_reduce_dbias
)
Ds
caleD
biasD
ataType
*
const
__restrict__
p_reduce_dscale
,
DscaleDb
iasDataType
*
const
__restrict__
p_reduce_dbias
)
{
GridwiseWelfordSecondHalfReduceFirstHalf_
::
Run
(
x_grid_desc_m_k
,
dy_grid_desc_m_k
,
...
...
@@ -76,7 +76,7 @@ template <typename XDataType,
typename
DyDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
B
iasDataType
,
typename
DscaleDb
iasDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
typename
XYGridDesc_M_K
,
...
...
@@ -174,8 +174,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
MeanVarDataType
*
const
__restrict__
p_out_welford_inv_variance
,
const
XDataType
*
const
__restrict__
p_x
,
const
DyDataType
*
const
__restrict__
p_dy
,
S
caleDataType
*
const
__restrict__
p_reduce_dscale
,
B
iasDataType
*
const
__restrict__
p_reduce_dbias
)
Ds
caleD
biasD
ataType
*
const
__restrict__
p_reduce_dscale
,
DscaleDb
iasDataType
*
const
__restrict__
p_reduce_dbias
)
{
__shared__
AccDataType
p_reduce_work_buffer
[
BlockSize
];
...
...
@@ -511,28 +511,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
BlockwiseReduce
::
Reduce
(
reduce_work_buf
,
reduce_dbias_thread_buf
(
I
));
});
auto
threadwise_dscale_store
=
auto
threadwise_dscale_
dbias_
store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
ScaleDataType
,
decltype
(
thread_buffer_desc_m_1
),
DscaleDbiasGridDesc_M_G
,
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
dscale_dbias_grid_desc_m_g
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
),
PassThroughOp
{});
auto
threadwise_dbias_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
BiasDataType
,
DscaleDbiasDataType
,
decltype
(
thread_buffer_desc_m_1
),
DscaleDbiasGridDesc_M_G
,
PassThroughOp
,
...
...
@@ -557,17 +538,17 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
if
(
thread_k_cluster_id
==
0
)
{
threadwise_dscale_store
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
reduce_dscale_thread_buf
,
dscale_dbias_grid_desc_m_g
,
reduce_dscale_global_buf
);
threadwise_dbias_store
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
reduce_dbias_thread_buf
,
dscale_dbias_grid_desc_m_g
,
reduce_dbias_global_buf
);
threadwise_dscale_
dbias_
store
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
reduce_dscale_thread_buf
,
dscale_dbias_grid_desc_m_g
,
reduce_dscale_global_buf
);
threadwise_
dscale_
dbias_store
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
reduce_dbias_thread_buf
,
dscale_dbias_grid_desc_m_g
,
reduce_dbias_global_buf
);
};
};
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
7bcaf2a7
...
...
@@ -796,6 +796,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
});
}
else
{
static_for
<
0
,
acc_thread_buf
.
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
acc_element_op
(
acc_thread_buf
(
i
),
acc_thread_buf
[
i
]);
});
}
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp
View file @
7bcaf2a7
...
...
@@ -21,7 +21,7 @@ template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_,
typename
DxDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
B
iasDataType
,
typename
DscaleDb
iasDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
typename
XYGridDesc_M_K
,
...
...
@@ -33,7 +33,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
const
XYGridDesc_M_K
dy_grid_desc_m_k
,
const
XYGridDesc_M_K
dx_grid_desc_m_k
,
const
ScaleBiasGridDesc_M
scale_grid_desc_m
,
const
ScaleBiasGridDesc_M
bias_grid_desc_m
,
const
ScaleBiasGridDesc_M
dscale_d
bias_grid_desc_m
,
const
MeanVarGridDesc_M
mean_var_grid_desc_m
,
const
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
,
long_index_t
reduce_size
,
...
...
@@ -47,14 +47,14 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
const
MeanVarDataType
*
const
__restrict__
p_savedInvVar
,
const
DyElementwiseOp
dy_elementwise_op
,
DxDataType
*
const
__restrict__
p_dx
,
S
caleDataType
*
const
__restrict__
p_dscale
,
B
iasDataType
*
const
__restrict__
p_dbias
)
Ds
caleD
biasD
ataType
*
const
__restrict__
p_dscale
,
DscaleDb
iasDataType
*
const
__restrict__
p_dbias
)
{
GridwiseBatchrNormBackwardWithBlockwiseWelford_
::
Run
(
x_grid_desc_m_k
,
dy_grid_desc_m_k
,
dx_grid_desc_m_k
,
scale_grid_desc_m
,
bias_grid_desc_m
,
dscale_d
bias_grid_desc_m
,
mean_var_grid_desc_m
,
get_reduce_count_per_thread
,
reduce_size
,
...
...
@@ -77,7 +77,7 @@ template <typename XDataType,
typename
DxDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
B
iasDataType
,
typename
DscaleDb
iasDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
typename
XYGridDesc_M_K
,
...
...
@@ -93,8 +93,8 @@ template <typename XDataType,
index_t
XSrcVectorSize
,
index_t
DySrcVectorSize
,
index_t
DxDstVectorSize
,
index_t
ScaleSrc
Dst
VectorSize
,
index_t
B
iasDstVectorSize
,
index_t
ScaleSrcVectorSize
,
index_t
DscaleDb
iasDstVectorSize
,
index_t
MeanVarSrcVectorSize
>
struct
GridwiseBatchNormBackwardWithBlockwiseWelford
{
...
...
@@ -165,7 +165,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
const
XYGridDesc_M_K
dy_grid_desc_m_k
,
const
XYGridDesc_M_K
dx_grid_desc_m_k
,
const
ScaleBiasGridDesc_M
scale_grid_desc_m
,
const
ScaleBiasGridDesc_M
bias_grid_desc_m
,
const
ScaleBiasGridDesc_M
dscale_d
bias_grid_desc_m
,
const
MeanVarGridDesc_M
mean_var_grid_desc_m
,
const
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
,
long_index_t
reduce_size
,
...
...
@@ -179,8 +179,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
const
MeanVarDataType
*
const
__restrict__
p_savedInvVar
,
const
DyElementwiseOp
dy_elementwise_op
,
DxDataType
*
const
__restrict__
p_dx
,
S
caleDataType
*
const
__restrict__
p_dscale
,
B
iasDataType
*
const
__restrict__
p_dbias
)
Ds
caleD
biasD
ataType
*
const
__restrict__
p_dscale
,
DscaleDb
iasDataType
*
const
__restrict__
p_dbias
)
{
using
ck
::
math
::
sqrt
;
...
...
@@ -253,7 +253,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
XSrcVectorSize
,
1
,
true
>
(
x
_grid_desc_m_k
,
dy
_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
...
...
@@ -271,7 +271,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
d
y
_grid_desc_m_k
,
d
x
_grid_desc_m_k
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
),
...
...
@@ -285,45 +285,27 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
ScaleSrc
Dst
VectorSize
,
ScaleSrcVectorSize
,
1
,
true
>
(
scale_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
auto
threadwise_dscale_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
ScaleDataType
,
decltype
(
thread_buffer_desc_m
),
ScaleBiasGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
ScaleSrcDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
scale_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
auto
threadwise_dbias_store
=
auto
threadwise_dscale_dbias_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
B
iasDataType
,
DscaleDb
iasDataType
,
decltype
(
thread_buffer_desc_m
),
ScaleBiasGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
B
iasDstVectorSize
,
DscaleDb
iasDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
bias_grid_desc_m
,
dscale_d
bias_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
...
...
@@ -344,10 +326,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
p_scale
,
scale_grid_desc_m
.
GetElementSpaceSize
());
auto
dscale_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_dscale
,
scale_grid_desc_m
.
GetElementSpaceSize
());
p_dscale
,
d
scale_
dbias_
grid_desc_m
.
GetElementSpaceSize
());
auto
dbias_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_dbias
,
bias_grid_desc_m
.
GetElementSpaceSize
());
p_dbias
,
dscale_d
bias_grid_desc_m
.
GetElementSpaceSize
());
// clang-format off
// Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance)
...
...
@@ -487,17 +469,17 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
if
(
thread_k_cluster_id
==
0
)
{
threadwise_dscale_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
dscale_thread_buf
,
scale
_grid_desc_m
,
dscale_global_buf
);
threadwise_dbias_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
dbias_thread_buf
,
bias_grid_desc_m
,
dbias_global_buf
);
threadwise_dscale_
dbias_
store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
dscale_thread_buf
,
dscale_dbias
_grid_desc_m
,
dscale_global_buf
);
threadwise_
dscale_
dbias_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
dbias_thread_buf
,
dscale_d
bias_grid_desc_m
,
dbias_global_buf
);
};
// clang-format off
...
...
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
0 → 100644
View file @
7bcaf2a7
// SPDX-License-Identifier: MIT
// // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
//
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
template
<
typename
GridwiseElementwise2dFunctor
,
typename
InGrid2dDescTuple
,
typename
OutGrid2dDescTuple
,
typename
InDataTypePointerTuple
,
typename
OutDataTypePointerTuple
,
typename
ElementwiseOperation
>
__global__
void
kernel_elementwise_2d
(
const
InGrid2dDescTuple
in_grid_2d_desc_tuple
,
const
OutGrid2dDescTuple
out_grid_2d_desc_tuple
,
const
InDataTypePointerTuple
p_in_global_tuple
,
const
OutDataTypePointerTuple
p_out_global_tuple
,
const
ElementwiseOperation
elementwise_op
,
const
index_t
num_threads_m
,
const
index_t
num_threads_n
)
{
GridwiseElementwise2dFunctor
::
Run
(
in_grid_2d_desc_tuple
,
out_grid_2d_desc_tuple
,
p_in_global_tuple
,
p_out_global_tuple
,
elementwise_op
,
num_threads_m
,
num_threads_n
);
}
template
<
typename
InGrid2dDescTuple
,
typename
OutGrid2dDescTuple
,
typename
InDataTypePointerTuple
,
typename
OutDataTypePointerTuple
,
typename
ElementwiseOperation
,
index_t
MPerThread
,
index_t
NPerThread
,
typename
InScalarPerVectorSeq
,
typename
OutScalarPerVectorSeq
>
struct
GridwiseElementwise_2D
{
static
constexpr
index_t
NumInput
=
InDataTypePointerTuple
::
Size
();
static
constexpr
index_t
NumOutput
=
OutDataTypePointerTuple
::
Size
();
static_assert
(
NumInput
==
InScalarPerVectorSeq
::
Size
()
&&
NumOutput
==
OutScalarPerVectorSeq
::
Size
()
&&
NumInput
==
InGrid2dDescTuple
::
Size
()
&&
NumOutput
==
OutGrid2dDescTuple
::
Size
(),
"Tuple size is inconsistent with the number of in/out!"
);
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
thread_buffer_desc_mn
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MPerThread
>
{},
Number
<
NPerThread
>
{}));
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
__device__
static
void
Run
(
const
InGrid2dDescTuple
in_grid_2d_desc_tuple
,
const
OutGrid2dDescTuple
out_grid_2d_desc_tuple
,
const
InDataTypePointerTuple
p_in_global_tuple
,
const
OutDataTypePointerTuple
p_out_global_tuple
,
const
ElementwiseOperation
elementwise_op
,
const
index_t
num_threads_m
,
const
index_t
num_threads_n
)
{
auto
in_thread_buf_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
using
DataTypePointer
=
remove_cvref_t
<
decltype
(
InDataTypePointerTuple
{}[
I
])
>
;
using
DataType
=
remove_cv_t
<
remove_pointer_t
<
DataTypePointer
>>
;
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
DataType
,
MPerThread
*
NPerThread
,
true
>
{};
},
Number
<
NumInput
>
{});
auto
out_thread_buf_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
using
DataTypePointer
=
remove_cvref_t
<
decltype
(
OutDataTypePointerTuple
{}[
I
])
>
;
using
DataType
=
remove_pointer_t
<
DataTypePointer
>
;
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
DataType
,
MPerThread
*
NPerThread
,
true
>
{};
},
Number
<
NumOutput
>
{});
auto
in_global_buf_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_global_tuple
[
I
],
in_grid_2d_desc_tuple
[
I
].
GetElementSpaceSize
());
},
Number
<
NumInput
>
{});
auto
out_global_buf_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_global_tuple
[
I
],
out_grid_2d_desc_tuple
[
I
].
GetElementSpaceSize
());
},
Number
<
NumOutput
>
{});
const
auto
M
=
in_grid_2d_desc_tuple
[
I0
].
GetLength
(
I0
);
const
auto
N
=
in_grid_2d_desc_tuple
[
I0
].
GetLength
(
I1
);
const
index_t
loop_step_m
=
num_threads_m
*
MPerThread
;
const
index_t
loop_step_n
=
num_threads_n
*
NPerThread
;
const
index_t
thread_1d_id
=
get_thread_global_1d_id
();
index_t
tid_m
=
thread_1d_id
/
num_threads_n
;
index_t
tid_n
=
thread_1d_id
%
num_threads_n
;
const
auto
thread_global_offset
=
make_multi_index
(
tid_m
*
MPerThread
,
tid_n
*
NPerThread
);
auto
in_global_load_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
using
DataTypePointer
=
remove_cvref_t
<
decltype
(
InDataTypePointerTuple
{}[
I
])
>
;
using
DataType
=
remove_cv_t
<
remove_pointer_t
<
DataTypePointer
>>
;
return
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
DataType
,
decltype
(
in_grid_2d_desc_tuple
[
I
]),
decltype
(
thread_buffer_desc_mn
),
Sequence
<
MPerThread
,
NPerThread
>
,
// SliceLengths
Sequence
<
0
,
1
>
,
// DimAccessOrder
0
,
// SrcVectorDim
InScalarPerVectorSeq
::
At
(
I
),
// ScalarPerVector
1
,
// SrcScalarStrideInVector
true
>
{
in_grid_2d_desc_tuple
[
I
],
thread_global_offset
};
},
Number
<
NumInput
>
{});
auto
out_global_store_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
using
DataTypePointer
=
remove_cvref_t
<
decltype
(
OutDataTypePointerTuple
{}[
I
])
>
;
using
DataType
=
remove_pointer_t
<
DataTypePointer
>
;
return
ThreadwiseTensorSliceTransfer_v1r3
<
DataType
,
DataType
,
decltype
(
thread_buffer_desc_mn
),
decltype
(
out_grid_2d_desc_tuple
[
I
]),
PassThroughOp
,
Sequence
<
MPerThread
,
NPerThread
>
,
// SliceLengths
Sequence
<
0
,
1
>
,
// DimAccessOrder
1
,
// SrcVectorDim
1
,
// OutScalarPerVectorSeq::At(I),
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
out_grid_2d_desc_tuple
[
I
],
thread_global_offset
,
PassThroughOp
{});
},
Number
<
NumOutput
>
{});
index_t
num_iter_m
=
M
/
(
loop_step_m
);
do
{
index_t
num_iter_n
=
N
/
(
loop_step_n
);
do
{
static_for
<
0
,
NumInput
,
1
>
{}([
&
](
auto
I
)
{
in_global_load_tuple
(
I
).
Run
(
in_grid_2d_desc_tuple
[
I
],
in_global_buf_tuple
[
I
],
thread_buffer_desc_mn
,
make_tuple
(
I0
,
I0
),
in_thread_buf_tuple
(
I
));
in_global_load_tuple
(
I
).
MoveSrcSliceWindow
(
in_grid_2d_desc_tuple
[
I
],
make_multi_index
(
0
,
loop_step_n
));
});
static_for
<
0
,
MPerThread
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
NPerThread
,
1
>
{}([
&
](
auto
iN
)
{
constexpr
auto
offset
=
thread_buffer_desc_mn
.
CalculateOffset
(
make_tuple
(
iM
,
iN
));
// get reference to in data
const
auto
in_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
I
)
->
const
auto
&
{
return
in_thread_buf_tuple
(
I
)(
Number
<
offset
>
{});
},
Number
<
NumInput
>
{});
// get referenec to dst data
auto
out_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
I
)
->
auto
&
{
return
out_thread_buf_tuple
(
I
)(
Number
<
offset
>
{});
},
Number
<
NumOutput
>
{});
unpack2
(
elementwise_op
,
out_data_refs
,
in_data_refs
);
});
});
static_for
<
0
,
NumOutput
,
1
>
{}([
&
](
auto
I
)
{
out_global_store_tuple
(
I
).
Run
(
thread_buffer_desc_mn
,
make_tuple
(
I0
,
I0
),
out_thread_buf_tuple
[
I
],
out_grid_2d_desc_tuple
[
I
],
out_global_buf_tuple
(
I
));
out_global_store_tuple
(
I
).
MoveDstSliceWindow
(
out_grid_2d_desc_tuple
[
I
],
make_multi_index
(
0
,
loop_step_n
));
});
}
while
(
--
num_iter_n
);
static_for
<
0
,
NumInput
,
1
>
{}([
&
](
auto
I
)
{
in_global_load_tuple
(
I
).
MoveSrcSliceWindow
(
in_grid_2d_desc_tuple
[
I
],
make_multi_index
(
loop_step_m
,
-
(
N
/
loop_step_n
)
*
loop_step_n
));
});
static_for
<
0
,
NumOutput
,
1
>
{}([
&
](
auto
I
)
{
out_global_store_tuple
(
I
).
MoveDstSliceWindow
(
out_grid_2d_desc_tuple
[
I
],
make_multi_index
(
loop_step_m
,
-
(
N
/
loop_step_n
)
*
loop_step_n
));
});
}
while
(
--
num_iter_m
);
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp
0 → 100644
View file @
7bcaf2a7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
DsDataType
,
typename
FloatC
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M_N
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
K0PerBlock
,
index_t
K1Value
,
index_t
M1PerThreadM111
,
index_t
N1PerThreadN111
,
index_t
KPerThread
,
typename
M11N11ThreadClusterM110Xs
,
typename
M11N11ThreadClusterN110Xs
,
typename
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
,
typename
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
typename
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
,
typename
ABlockTransferSrcVectorTensorContiguousDimOrder
,
typename
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
,
typename
BBlockTransferThreadSliceLengths_K0_N0_N1_K1
,
typename
BBlockTransferThreadClusterLengths_K0_N0_N1_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
typename
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
,
typename
BBlockTransferSrcVectorTensorContiguousDimOrder
,
typename
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
>
struct
GridwiseGemmDlMultipleD_km_kn_mn
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static
constexpr
auto
MakeDsGridPointer
()
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
return
static_cast
<
const
DDataType
*>
(
nullptr
);
},
Number
<
NumDTensor
>
{});
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr
auto
max_lds_align
=
K1
;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_k_m
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_k_n
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
// TODO: check alignment
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_aligned_space_size
=
math
::
integer_least_multiple
(
a_block_desc_k_m
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_aligned_space_size
=
math
::
integer_least_multiple
(
b_block_desc_k_n
.
GetElementSpaceSize
(),
max_lds_align
);
return
2
*
(
a_block_aligned_space_size
+
b_block_aligned_space_size
)
*
sizeof
(
FloatAB
);
}
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
(
M
==
c_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
c_grid_desc_m_n
.
GetLength
(
I1
)
&&
K0
==
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
&&
K1
==
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
&&
K1
==
b_grid_desc_k0_n_k1
.
GetLength
(
I2
))
&&
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
);
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K0
)
{
const
bool
has_main_k_block_loop
=
(
K0
+
K0PerBlock
)
/
(
2
*
K0PerBlock
)
>
1
;
return
has_main_k_block_loop
;
}
__host__
__device__
static
constexpr
bool
CalculateHasDoubleTailKBlockLoop
(
index_t
K0
)
{
const
bool
has_double_tail_k_block_loop
=
(
K0
/
K0PerBlock
)
%
2
==
0
;
return
has_double_tail_k_block_loop
;
}
__host__
__device__
static
constexpr
auto
MakeAGridDescriptor_K0_M0_M1_K1
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
)
{
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
M1
=
Number
<
MPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
a_grid_desc_k0_m0_m1_k1
=
transform_tensor_descriptor
(
a_grid_desc_k0_m_k1
,
make_tuple
(
make_pass_through_transform
(
K0
),
make_unmerge_transform
(
make_tuple
(
M0
,
M1
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
return
a_grid_desc_k0_m0_m1_k1
;
}
__host__
__device__
static
constexpr
auto
MakeBGridDescriptor_K0_N0_N1_K1
(
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
)
{
const
auto
K0
=
b_grid_desc_k0_n_k1
.
GetLength
(
I0
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
N0
=
N
/
N1
;
const
auto
b_grid_desc_k0_n0_n1_k1
=
transform_tensor_descriptor
(
b_grid_desc_k0_n_k1
,
make_tuple
(
make_pass_through_transform
(
K0
),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
return
b_grid_desc_k0_n0_n1_k1
;
}
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
constexpr
auto
M11
=
Number
<
container_reduce
(
M11N11ThreadClusterM110Xs
{},
math
::
multiplies
{},
I1
)
*
M1PerThreadM111
>
{};
constexpr
auto
N11
=
Number
<
container_reduce
(
M11N11ThreadClusterN110Xs
{},
math
::
multiplies
{},
I1
)
*
N1PerThreadN111
>
{};
constexpr
auto
M10
=
M1
/
M11
;
constexpr
auto
N10
=
N1
/
N11
;
const
auto
c_grid_desc_m0_m10_m11_n0_n10_n11
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M10
,
M11
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N10
,
N11
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
return
c_grid_desc_m0_m10_m11_n0_n10_n11
;
}
// Ds desc for source in blockwise copy
template
<
typename
DsGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11
(
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
ds_grid_desc_m_n
[
i
]);
},
Number
<
NumDTensor
>
{});
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
return
BlockToCTileMap_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
c_grid_desc_m_n
);
}
using
AGridDesc_K0_M0_M1_K1
=
decltype
(
MakeAGridDescriptor_K0_M0_M1_K1
(
AGridDesc_K0_M_K1
{}));
using
BGridDesc_K0_N0_N1_K1
=
decltype
(
MakeBGridDescriptor_K0_N0_N1_K1
(
BGridDesc_K0_N_K1
{}));
using
CGridDesc_M0_M10_M11_N0_N10_N11
=
decltype
(
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
CGridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}));
using
DsGridPointer
=
decltype
(
MakeDsGridPointer
());
template
<
typename
DsGridDesc_M0_M10_M11_N0_N10_N11
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
DsGridPointer
p_ds_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatAB
*
__restrict__
p_shared_block
,
const
AElementwiseOperation
&
,
const
BElementwiseOperation
&
,
const
CDEElementwiseOperation
&
cde_element_op
,
const
AGridDesc_K0_M0_M1_K1
&
a_grid_desc_k0_m0_m1_k1
,
const
BGridDesc_K0_N0_N1_K1
&
b_grid_desc_k0_n0_n1_k1
,
const
DsGridDesc_M0_M10_M11_N0_N10_N11
&
ds_grid_desc_m0_m10_m11_n0_n10_n11
,
const
CGridDesc_M0_M10_M11_N0_N10_N11
&
c_grid_desc_m0_m10_m11_n0_n10_n11
,
const
Block2CTileMap
&
block_2_ctile_map
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
{
const
auto
a_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_k0_m0_m1_k1
.
GetElementSpaceSize
());
const
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_k0_n0_n1_k1
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetElementSpaceSize
());
// divide block work by [M, N]
const
auto
c_m0_n0_block_cluster_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
// HACK: this force index data into SGPR
const
index_t
im0
=
__builtin_amdgcn_readfirstlane
(
c_m0_n0_block_cluster_idx
[
I0
]);
const
index_t
in0
=
__builtin_amdgcn_readfirstlane
(
c_m0_n0_block_cluster_idx
[
I1
]);
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
make_tuple
(
im0
,
in0
),
make_tuple
(
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetLength
(
I0
),
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetLength
(
I3
))))
{
return
;
}
// TODO: change this. I think it needs multi-dimensional alignment
constexpr
auto
max_lds_align
=
K1
;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
a_block_desc_k0_m0_m1_k1
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
I1
,
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
b_block_desc_k0_n0_n1_k1
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
I1
,
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
// TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM
constexpr
auto
a_k0_m_k1_block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
// TODO: check alignment
// B matrix in LDS memory, for blockwise GEMM
constexpr
auto
b_k0_n_k1_block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
static_assert
(
a_block_desc_k0_m0_m1_k1
.
GetElementSpaceSize
()
==
a_k0_m_k1_block_desc
.
GetElementSpaceSize
()
&&
b_block_desc_k0_n0_n1_k1
.
GetElementSpaceSize
()
==
b_k0_n_k1_block_desc
.
GetElementSpaceSize
()
&&
"wrong!"
);
// A matrix blockwise copy
auto
a_blockwise_copy
=
BlockwiseTensorSliceTransfer_v5r1
<
BlockSize
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
K0PerBlock
,
1
,
MPerBlock
,
K1
.
value
>
,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
remove_reference_t
<
decltype
(
a_grid_desc_k0_m0_m1_k1
)
>
,
decltype
(
a_block_desc_k0_m0_m1_k1
),
ABlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
,
3
>
,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
,
// SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
,
// DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder
,
// SrcVectorTensorContiguousDimOrder
Sequence
<
0
,
1
,
2
,
3
>
,
// DstVectorTensorContiguousDimOrder
false
,
true
>
(
a_grid_desc_k0_m0_m1_k1
,
make_multi_index
(
0
,
im0
,
0
,
0
),
a_block_desc_k0_m0_m1_k1
,
make_multi_index
(
0
,
0
,
0
,
0
));
// B matrix blockwise copy
auto
b_blockwise_copy
=
BlockwiseTensorSliceTransfer_v5r1
<
BlockSize
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
K0PerBlock
,
1
,
NPerBlock
,
K1
.
value
>
,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
remove_reference_t
<
decltype
(
b_grid_desc_k0_n0_n1_k1
)
>
,
decltype
(
b_block_desc_k0_n0_n1_k1
),
BBlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
,
3
>
,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
,
// SrcVectorTensorLengths
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
,
// DstVectorTensorLengths
BBlockTransferSrcVectorTensorContiguousDimOrder
,
// SrcVectorTensorContiguousDimOrder
Sequence
<
0
,
1
,
2
,
3
>
,
// DstVectorTensorContiguousDimOrder
false
,
true
>
(
b_grid_desc_k0_n0_n1_k1
,
make_multi_index
(
0
,
in0
,
0
,
0
),
b_block_desc_k0_n0_n1_k1
,
make_multi_index
(
0
,
0
,
0
,
0
));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlocl, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
const
auto
blockwise_gemm
=
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
M1PerThreadM111
,
N1PerThreadN111
,
KPerThread
,
M11N11ThreadClusterM110Xs
,
M11N11ThreadClusterN110Xs
,
M1PerThreadM111
,
N1PerThreadN111
>
{};
constexpr
auto
c_m10_m11_n10_n11_thread_tensor_lengths
=
decltype
(
blockwise_gemm
)
::
GetCThreadTensorLengths_BM0_BM1_BN0_BN1
();
constexpr
auto
c_thread_desc_m10_m11_n10_n11
=
make_naive_tensor_descriptor_packed
(
sequence_to_tuple_of_number
(
c_m10_m11_n10_n11_thread_tensor_lengths
));
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_aligned_space_size
=
math
::
integer_least_multiple
(
a_block_desc_k0_m0_m1_k1
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_aligned_space_size
=
math
::
integer_least_multiple
(
b_block_desc_k0_n0_n1_k1
.
GetElementSpaceSize
(),
max_lds_align
);
FloatAB
*
p_a_block_double
=
p_shared_block
;
FloatAB
*
p_b_block_double
=
p_shared_block
+
2
*
a_block_aligned_space_size
;
// register allocation for output
auto
c_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc
>
(
c_thread_desc_m10_m11_n10_n11
.
GetElementSpaceSize
());
// Initialize C
c_thread_buf
.
Clear
();
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
,
0
);
auto
a_block_even_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_a_block_double
,
a_block_desc_k0_m0_m1_k1
.
GetElementSpaceSize
());
auto
b_block_even_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_b_block_double
,
b_block_desc_k0_n0_n1_k1
.
GetElementSpaceSize
());
auto
a_block_odd_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_a_block_double
+
a_block_aligned_space_size
,
a_block_desc_k0_m0_m1_k1
.
GetElementSpaceSize
());
auto
b_block_odd_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_b_block_double
+
b_block_aligned_space_size
,
b_block_desc_k0_n0_n1_k1
.
GetElementSpaceSize
());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m0_m1_k1
,
a_global_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc_k0_n0_n1_k1
,
b_global_buf
);
a_blockwise_copy
.
RunWrite
(
a_block_desc_k0_m0_m1_k1
,
a_block_even_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc_k0_n0_n1_k1
,
b_block_even_buf
);
}
if
constexpr
(
HasMainKBlockLoop
)
{
const
auto
K0
=
a_grid_desc_k0_m0_m1_k1
.
GetLength
(
I0
);
index_t
k_block_data_begin
=
0
;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m0_m1_k1
,
a_block_slice_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_n0_n1_k1
,
b_block_slice_copy_step
);
// LDS doubel buffer: load next data from device mem
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m0_m1_k1
,
a_global_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc_k0_n0_n1_k1
,
b_global_buf
);
block_sync_lds
();
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
c_thread_desc_m10_m11_n10_n11
,
a_block_even_buf
,
b_block_even_buf
,
c_thread_buf
);
// LDS double buffer: store next data to LDS
a_blockwise_copy
.
RunWrite
(
a_block_desc_k0_m0_m1_k1
,
a_block_odd_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc_k0_n0_n1_k1
,
b_block_odd_buf
);
// odd iteration
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m0_m1_k1
,
a_block_slice_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_n0_n1_k1
,
b_block_slice_copy_step
);
// LDS doubel buffer: load next data from device mem
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m0_m1_k1
,
a_global_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc_k0_n0_n1_k1
,
b_global_buf
);
block_sync_lds
();
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
c_thread_desc_m10_m11_n10_n11
,
a_block_odd_buf
,
b_block_odd_buf
,
c_thread_buf
);
// LDS double buffer: store next data to LDS
a_blockwise_copy
.
RunWrite
(
a_block_desc_k0_m0_m1_k1
,
a_block_even_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc_k0_n0_n1_k1
,
b_block_even_buf
);
k_block_data_begin
+=
2
*
K0PerBlock
;
}
while
(
k_block_data_begin
<
K0
-
2
*
K0PerBlock
);
}
// LDS double buffer: tail
if
constexpr
(
HasDoubleTailKBlockLoop
)
// if has 2 iteration left
{
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m0_m1_k1
,
a_block_slice_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_n0_n1_k1
,
b_block_slice_copy_step
);
block_sync_lds
();
// LDS double buffer: load last data from device mem
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m0_m1_k1
,
a_global_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc_k0_n0_n1_k1
,
b_global_buf
);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
c_thread_desc_m10_m11_n10_n11
,
a_block_even_buf
,
b_block_even_buf
,
c_thread_buf
);
// LDS double buffer: store last data to LDS
a_blockwise_copy
.
RunWrite
(
a_block_desc_k0_m0_m1_k1
,
a_block_odd_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc_k0_n0_n1_k1
,
b_block_odd_buf
);
block_sync_lds
();
// LDS double buffer: GEMM on last data
blockwise_gemm
.
Run
(
c_thread_desc_m10_m11_n10_n11
,
a_block_odd_buf
,
b_block_odd_buf
,
c_thread_buf
);
}
else
// if has 1 iteration left
{
__syncthreads
();
// LDS double buffer: GEMM on last data
blockwise_gemm
.
Run
(
c_thread_desc_m10_m11_n10_n11
,
a_block_even_buf
,
b_block_even_buf
,
c_thread_buf
);
}
// output: register to global memory
{
constexpr
auto
c_thread_desc_m0_m10_m11_n0_n10_n11
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
c_m10_m11_n10_n11_thread_tensor_lengths
[
I0
]
>
{},
Number
<
c_m10_m11_n10_n11_thread_tensor_lengths
[
I1
]
>
{},
I1
,
Number
<
c_m10_m11_n10_n11_thread_tensor_lengths
[
I2
]
>
{},
Number
<
c_m10_m11_n10_n11_thread_tensor_lengths
[
I3
]
>
{}));
const
auto
c_m10_m11_n10_n11_thread_origin_idx_on_block
=
blockwise_gemm
.
CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1
(
get_thread_local_1d_id
());
const
auto
ds_grid_buf
=
generate_tuple
(
[
&
](
auto
i
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ds_grid
[
i
],
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
].
GetElementSpaceSize
());
},
Number
<
NumDTensor
>
{});
auto
ds_thread_buf
=
generate_tuple
(
[
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
DDataType
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I3
],
true
>
{};
},
Number
<
NumDTensor
>
{});
auto
ds_threadwise_copy
=
generate_tuple
(
[
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
return
ThreadwiseTensorSliceTransfer_v2
<
DDataType
,
DDataType
,
decltype
(
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
]),
decltype
(
c_thread_desc_m0_m10_m11_n0_n10_n11
),
Sequence
<
I1
,
I1
,
I1
,
I1
,
I1
,
Number
<
c_m10_m11_n10_n11_thread_tensor_lengths
[
I3
]
>
{}
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
1
,
false
>
(
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
],
make_multi_index
(
im0
,
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I0
],
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I1
],
in0
,
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I2
],
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I3
]));
},
Number
<
NumDTensor
>
{});
static_for
<
0
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I0
],
1
>
{}([
&
](
auto
m10
)
{
static_for
<
0
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I1
],
1
>
{}([
&
](
auto
m11
)
{
static_for
<
0
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I2
],
1
>
{}([
&
](
auto
n10
)
{
// load d matrix data
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
ds_threadwise_copy
(
i
).
Run
(
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
],
ds_grid_buf
[
i
],
c_thread_desc_m0_m10_m11_n0_n10_n11
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
ds_thread_buf
(
i
));
});
// cal element op
static_for
<
0
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I3
],
1
>
{}(
[
&
](
auto
i
)
{
// get reference to src data
const
auto
src_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iSrc
)
->
const
auto
&
{
return
ds_thread_buf
[
iSrc
][
i
];
},
Number
<
NumDTensor
>
{});
// get reference to dst data
constexpr
index_t
c_offset
=
c_thread_desc_m0_m10_m11_n0_n10_n11
.
CalculateOffset
(
make_tuple
(
0
,
m10
,
m11
,
0
,
n10
,
i
));
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
)
->
auto
&
{
return
c_thread_buf
(
Number
<
c_offset
>
{});
},
Number
<
2
>
{});
unpack2
(
cde_element_op
,
dst_data_refs
,
src_data_refs
);
});
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
ds_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
],
make_multi_index
(
0
,
0
,
0
,
0
,
1
,
0
));
});
});
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
ds_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
],
make_multi_index
(
0
,
0
,
1
,
0
,
-
c_m10_m11_n10_n11_thread_tensor_lengths
[
I2
],
0
));
});
});
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
ds_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
ds_grid_desc_m0_m10_m11_n0_n10_n11
[
i
],
make_multi_index
(
0
,
1
,
-
c_m10_m11_n10_n11_thread_tensor_lengths
[
I1
],
0
,
0
,
0
));
});
});
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
decltype
(
c_thread_desc_m0_m10_m11_n0_n10_n11
),
decltype
(
c_grid_desc_m0_m10_m11_n0_n10_n11
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
1
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I0
],
c_m10_m11_n10_n11_thread_tensor_lengths
[
I1
],
1
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I2
],
c_m10_m11_n10_n11_thread_tensor_lengths
[
I3
]
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CGlobalMemoryDataOperation
,
1
,
true
>
{
c_grid_desc_m0_m10_m11_n0_n10_n11
,
make_multi_index
(
im0
,
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I0
],
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I1
],
in0
,
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I2
],
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I3
]),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}}
.
Run
(
c_thread_desc_m0_m10_m11_n0_n10_n11
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
c_grid_buf
);
}
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_multiblock_welford_first_half.hpp
deleted
100644 → 0
View file @
e59daa22
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
template
<
typename
GridwiseMultiblockWelfordFirstHalf_
,
typename
XDataType
,
typename
MeanVarDataType
,
typename
XGridDesc_M_K
,
typename
MeanVarCountGridDesc_M_G
,
typename
GetReduceCountPerThreadFunctor
>
__global__
void
kernel_multiblock_welford_first_half
(
const
XGridDesc_M_K
x_grid_desc_m_k
,
const
MeanVarCountGridDesc_M_G
mean_var_count_grid_desc_m_g
,
const
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
,
index_t
num_k_block_tile_iteration
,
const
XDataType
*
const
__restrict__
p_x
,
MeanVarDataType
*
const
p_welford_mean
,
MeanVarDataType
*
const
p_welford_variance
,
int32_t
*
const
p_welford_count
)
{
GridwiseMultiblockWelfordFirstHalf_
::
Run
(
x_grid_desc_m_k
,
mean_var_count_grid_desc_m_g
,
get_reduce_count_per_thread
,
num_k_block_tile_iteration
,
p_x
,
p_welford_mean
,
p_welford_variance
,
p_welford_count
);
};
template
<
typename
XDataType
,
typename
AccDataType
,
typename
MeanVarDataType
,
typename
XGridDesc_M_K
,
typename
MeanVarCountGridDesc_M_G
,
typename
GetReduceCountPerThreadFunctor
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XSrcCountSrcVectorDim
,
index_t
XSrcCountSrcVectorSize
>
struct
GridwiseMultiblockWelfordFirstHalf
{
static_assert
((
XSrcCountSrcVectorDim
==
0
&&
MThreadSliceSize
%
XSrcCountSrcVectorSize
==
0
)
||
(
XSrcCountSrcVectorDim
==
1
&&
KThreadSliceSize
%
XSrcCountSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
bool
reorder_thread_cluster
=
(
XSrcCountSrcVectorDim
==
0
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadBufferDimAccessOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
ThreadClusterArrangeOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
using
ThreadwiseWelford
=
ThreadwiseWelford
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
>
;
using
BlockwiseWelford
=
BlockwiseWelford
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
false
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
__device__
static
void
Run
(
const
XGridDesc_M_K
&
x_grid_desc_m_k
,
const
MeanVarCountGridDesc_M_G
&
mean_var_count_grid_desc_m_g
,
const
GetReduceCountPerThreadFunctor
&
get_reduce_count_per_thread
,
index_t
num_k_block_tile_iteration
,
const
XDataType
*
const
__restrict__
p_x
,
MeanVarDataType
*
const
p_welford_mean
,
MeanVarDataType
*
const
p_welford_variance
,
int32_t
*
const
p_welford_count
)
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
x_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
welford_mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
welford_var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
,
true
>
welford_count_thread_buf
;
const
index_t
blkgroup_size
=
mean_var_count_grid_desc_m_g
.
GetLength
(
I1
);
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
blkgroup_id
=
block_global_id
/
blkgroup_size
;
const
index_t
block_local_id
=
block_global_id
%
blkgroup_size
;
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
using
ThreadBufferLengths_M_K
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
using
ThreadBufferLengths_M_1
=
Sequence
<
MThreadSliceSize
,
1
>
;
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_m_1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
1
>
{}));
const
index_t
reduceSizePerBlock
=
K_BlockTileSize
*
num_k_block_tile_iteration
;
auto
threadwise_x_load
=
ThreadwiseTensorSliceTransfer_v2
<
XDataType
,
AccDataType
,
XGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
XSrcCountSrcVectorDim
,
XSrcCountSrcVectorSize
,
1
,
true
>
(
x_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_welford_mean_var_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanVarDataType
,
decltype
(
thread_buffer_desc_m_1
),
MeanVarCountGridDesc_M_G
,
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_count_grid_desc_m_g
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
),
PassThroughOp
{});
auto
threadwise_welford_count_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
int32_t
,
int32_t
,
decltype
(
thread_buffer_desc_m_1
),
MeanVarCountGridDesc_M_G
,
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_count_grid_desc_m_g
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
),
PassThroughOp
{});
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileSize
);
const
auto
x_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_x
,
x_grid_desc_m_k
.
GetElementSpaceSize
());
auto
welford_mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_welford_mean
,
mean_var_count_grid_desc_m_g
.
GetElementSpaceSize
());
auto
welford_var_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_welford_variance
,
mean_var_count_grid_desc_m_g
.
GetElementSpaceSize
());
auto
welford_count_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_welford_count
,
mean_var_count_grid_desc_m_g
.
GetElementSpaceSize
());
auto
threadwise_welford
=
ThreadwiseWelford
();
threadwise_welford
.
max_count_
=
get_reduce_count_per_thread
(
block_local_id
,
thread_k_cluster_id
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
welford_mean_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
welford_var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
});
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
threadwise_welford
.
Run
(
x_thread_buf
,
welford_mean_thread_buf
,
welford_var_thread_buf
);
}
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
welford_count_thread_buf
(
I
)
=
threadwise_welford
.
cur_count_
;
BlockwiseWelford
::
Run
(
welford_mean_thread_buf
(
I
),
welford_var_thread_buf
(
I
),
welford_count_thread_buf
(
I
));
});
if
(
thread_k_cluster_id
==
0
)
{
threadwise_welford_mean_var_store
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
welford_mean_thread_buf
,
mean_var_count_grid_desc_m_g
,
welford_mean_global_val_buf
);
threadwise_welford_mean_var_store
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
welford_var_thread_buf
,
mean_var_count_grid_desc_m_g
,
welford_var_global_val_buf
);
threadwise_welford_count_store
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
welford_count_thread_buf
,
mean_var_count_grid_desc_m_g
,
welford_count_global_val_buf
);
};
}
};
}
// namespace ck
include/ck/utility/amd_wmma.hpp
0 → 100644
View file @
7bcaf2a7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_WMMA_HPP
#define CK_AMD_WMMA_HPP
#include "data_type.hpp"
// TODO: Add arch limitation
namespace
ck
{
// wave32 only
// src: fp16, dst: fp32
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_wmma_f32_16x16x16_f16_w32
;
template
<
>
struct
intrin_wmma_f32_16x16x16_f16_w32
<
16
,
16
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half16_t
&
reg_a
,
const
half16_t
&
reg_b
,
FloatC
&
reg_c
)
{
reg_c
.
template
AsType
<
float8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float8_t
>()[
Number
<
0
>
{}]);
}
};
// src: bf16, dst: fp32
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_wmma_f32_16x16x16_bf16_w32
;
template
<
>
struct
intrin_wmma_f32_16x16x16_bf16_w32
<
16
,
16
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
bhalf16_t
&
reg_a
,
const
bhalf16_t
&
reg_b
,
FloatC
&
reg_c
)
{
reg_c
.
template
AsType
<
float8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float8_t
>()[
Number
<
0
>
{}]);
}
};
// src: fp16, dst: fp16
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
Opsel
>
struct
intrin_wmma_f16_16x16x16_f16_w32
;
template
<
index_t
Opsel
>
struct
intrin_wmma_f16_16x16x16_f16_w32
<
16
,
16
,
Opsel
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half16_t
&
reg_a
,
const
half16_t
&
reg_b
,
FloatC
&
reg_c
)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c
.
template
AsType
<
half16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f16_16x16x16_f16_w32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
half16_t
>()[
Number
<
0
>
{}],
Opsel
);
}
};
// src: bf16, dst: bf16
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
Opsel
>
struct
intrin_wmma_bf16_16x16x16_bf16_w32
;
template
<
index_t
Opsel
>
struct
intrin_wmma_bf16_16x16x16_bf16_w32
<
16
,
16
,
Opsel
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
bhalf16_t
&
reg_a
,
const
bhalf16_t
&
reg_b
,
FloatC
&
reg_c
)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c
.
template
AsType
<
bhalf16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
bhalf16_t
>()[
Number
<
0
>
{}],
Opsel
);
}
};
// src: iu8, dst: i32
template
<
index_t
MPerWave
,
index_t
NPerWave
,
bool
neg_a
,
bool
neg_b
,
bool
clamp
>
struct
intrin_wmma_i32_16x16x16_iu8_w32
;
template
<
bool
neg_a
,
bool
neg_b
,
bool
clamp
>
struct
intrin_wmma_i32_16x16x16_iu8_w32
<
16
,
16
,
neg_a
,
neg_b
,
clamp
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
int8x16_t
&
reg_a
,
const
int8x16_t
&
reg_b
,
FloatC
&
reg_c
)
{
reg_c
.
template
AsType
<
int32x8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32
(
neg_a
,
bit_cast
<
int32x4_t
>
(
reg_a
),
neg_b
,
bit_cast
<
int32x4_t
>
(
reg_b
),
reg_c
.
template
AsType
<
int32x8_t
>()[
Number
<
0
>
{}],
clamp
);
}
};
}
// namespace ck
#endif
include/ck/utility/math_v2.hpp
View file @
7bcaf2a7
...
...
@@ -114,7 +114,16 @@ static inline __device__ int4_t abs(int4_t x)
};
#endif
static
inline
__device__
half_t
abs
(
half_t
x
)
{
return
::
__habs
(
x
);
};
static
inline
__device__
half_t
abs
(
half_t
x
)
{
uint16_t
xx
=
ck
::
bit_cast
<
uint16_t
>
(
x
);
uint16_t
abs_xx
=
xx
&
0x7fff
;
half_t
abs_x
=
ck
::
bit_cast
<
half_t
>
(
abs_xx
);
return
abs_x
;
};
static
inline
__device__
bool
isnan
(
float
x
)
{
return
::
isnan
(
x
);
};
...
...
@@ -140,7 +149,12 @@ static inline __device__ bool isnan(int4_t x)
};
#endif
static
inline
__device__
bool
isnan
(
half_t
x
)
{
return
::
__hisnan
(
x
);
};
static
inline
__device__
bool
isnan
(
half_t
x
)
{
uint16_t
xx
=
ck
::
bit_cast
<
uint16_t
>
(
x
);
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
static
inline
__device__
float
sqrt
(
float
x
)
{
return
::
sqrtf
(
x
);
};
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp
0 → 100644
View file @
7bcaf2a7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <array>
#include <algorithm>
#include <thread>
#include "ck/utility/math_v2.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
host
{
template
<
typename
XDataType
,
typename
DxDataType
,
typename
DyDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
DscaleDbiasDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
index_t
Rank
,
index_t
NumBatchNormReduceDim
>
struct
ReferenceBatchNormBwd
:
public
device
::
DeviceBatchNormBwd
<
XDataType
,
DxDataType
,
DyDataType
,
AccDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
Rank
,
NumBatchNormReduceDim
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumBatchNormReduceDim
;
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
dxStrides
,
const
std
::
array
<
index_t
,
Rank
>
dyStrides
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnDscaleDbiasStrides
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnMeanVarStrides
,
const
XDataType
*
p_x
,
const
DyDataType
*
p_dy
,
const
ScaleDataType
*
p_scale
,
const
MeanVarDataType
*
p_savedMean
,
const
MeanVarDataType
*
p_savedInvVar
,
double
epsilon
,
const
DyElementwiseOp
dy_elementwise_op
,
DxDataType
*
p_dx
,
DscaleDbiasDataType
*
p_dscale
,
DscaleDbiasDataType
*
p_dbias
)
:
reduceDims_
(
reduceDims
),
bnScaleBiasMeanVarLengths_
(
bnScaleBiasMeanVarLengths
),
bnScaleStrides_
(
bnScaleStrides
),
bnDscaleDbiasStrides_
(
bnDscaleDbiasStrides
),
bnMeanVarStrides_
(
bnMeanVarStrides
),
p_x_
(
p_x
),
p_dy_
(
p_dy
),
p_scale_
(
p_scale
),
p_savedMean_
(
p_savedMean
),
p_savedInvVar_
(
p_savedInvVar
),
dy_elementwise_op_
(
dy_elementwise_op
),
p_dx_
(
p_dx
),
p_dscale_
(
p_dscale
),
p_dbias_
(
p_dbias
)
{
using
ck
::
host_common
::
get_index_set
;
if
(
std
::
any_of
(
reduceDims
.
begin
(),
reduceDims
.
end
(),
[](
int
d
)
{
return
d
<
0
||
d
>=
Rank
;
}))
throw
std
::
runtime_error
(
"Invalid reduce dimensions!"
);
// get invariant_dims[] and invariant_lengths[]
for
(
int
dim
=
0
,
i
=
0
;
dim
<
Rank
;
dim
++
)
if
(
std
::
none_of
(
reduceDims
.
begin
(),
reduceDims
.
end
(),
[
&
](
int
d
)
{
return
d
==
dim
;
}))
{
invariantDims_
[
i
]
=
dim
;
invariant_lengths_
[
i
]
=
xyLengths
[
dim
];
i
++
;
};
// get reduce_lengths_[]
for
(
int
j
=
0
,
i
=
0
;
j
<
NumBatchNormReduceDim
;
j
++
)
{
int
dim
=
reduceDims
[
j
];
reduce_lengths_
[
i
++
]
=
xyLengths
[
dim
];
};
for
(
int
i
=
0
;
i
<
NumInvariantDim
;
i
++
)
if
(
invariant_lengths_
[
i
]
!=
bnScaleBiasMeanVarLengths_
[
i
])
throw
std
::
runtime_error
(
"Invalid lengths parameters!"
);
for
(
int
j
=
0
,
i
=
0
;
j
<
NumInvariantDim
;
j
++
)
{
int
dim
=
invariantDims_
[
j
];
x_invariant_strides_
[
i
]
=
xStrides
[
dim
];
dy_invariant_strides_
[
i
]
=
dyStrides
[
dim
];
dx_invariant_strides_
[
i
]
=
dxStrides
[
dim
];
i
++
;
};
for
(
int
j
=
0
,
i
=
0
;
j
<
NumBatchNormReduceDim
;
j
++
)
{
int
dim
=
reduceDims_
[
j
];
x_reduce_strides_
[
i
]
=
xStrides
[
dim
];
dy_reduce_strides_
[
i
]
=
dyStrides
[
dim
];
dx_reduce_strides_
[
i
]
=
dxStrides
[
dim
];
i
++
;
};
reduceSize_
=
std
::
accumulate
(
reduce_lengths_
.
begin
(),
reduce_lengths_
.
end
(),
1
,
std
::
multiplies
<
size_t
>
{});
invariant_index_set_
=
get_index_set
<
NumInvariantDim
>
(
invariant_lengths_
);
reduce_index_set_
=
get_index_set
<
NumBatchNormReduceDim
>
(
reduce_lengths_
);
epsilon_
=
type_convert
<
AccDataType
>
(
epsilon
);
haveSavedMeanInvVar_
=
(
p_savedMean
!=
nullptr
&&
p_savedInvVar
!=
nullptr
);
}
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims_
;
std
::
array
<
int
,
NumInvariantDim
>
invariantDims_
;
std
::
array
<
index_t
,
NumInvariantDim
>
invariant_lengths_
;
std
::
array
<
index_t
,
NumBatchNormReduceDim
>
reduce_lengths_
;
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleBiasMeanVarLengths_
;
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleStrides_
;
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnDscaleDbiasStrides_
;
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnMeanVarStrides_
;
std
::
array
<
index_t
,
NumInvariantDim
>
x_invariant_strides_
;
std
::
array
<
index_t
,
NumInvariantDim
>
dy_invariant_strides_
;
std
::
array
<
index_t
,
NumInvariantDim
>
dx_invariant_strides_
;
std
::
array
<
index_t
,
NumBatchNormReduceDim
>
x_reduce_strides_
;
std
::
array
<
index_t
,
NumBatchNormReduceDim
>
dy_reduce_strides_
;
std
::
array
<
index_t
,
NumBatchNormReduceDim
>
dx_reduce_strides_
;
const
XDataType
*
p_x_
;
const
DyDataType
*
p_dy_
;
const
ScaleDataType
*
p_scale_
;
const
MeanVarDataType
*
p_savedMean_
;
const
MeanVarDataType
*
p_savedInvVar_
;
const
DyElementwiseOp
dy_elementwise_op_
;
DxDataType
*
p_dx_
;
DscaleDbiasDataType
*
p_dscale_
;
DscaleDbiasDataType
*
p_dbias_
;
bool
haveSavedMeanInvVar_
;
std
::
vector
<
std
::
array
<
index_t
,
NumInvariantDim
>>
invariant_index_set_
;
std
::
vector
<
std
::
array
<
index_t
,
NumBatchNormReduceDim
>>
reduce_index_set_
;
AccDataType
epsilon_
;
size_t
reduceSize_
;
};
struct
Invoker
:
public
device
::
BaseInvoker
{
float
Run
(
const
Argument
&
arg
)
{
using
ck
::
host_common
::
get_offset_from_index
;
auto
thread_reduce_func
=
[
&
](
auto
invariant_index
)
{
size_t
x_invariant_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
x_invariant_strides_
,
invariant_index
);
size_t
dy_invariant_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
dy_invariant_strides_
,
invariant_index
);
size_t
dx_invariant_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
dx_invariant_strides_
,
invariant_index
);
AccDataType
mean
=
type_convert
<
AccDataType
>
(
0.0
f
);
AccDataType
variance
=
type_convert
<
AccDataType
>
(
0.0
f
);
AccDataType
invVar
;
int32_t
curr_count
=
0
;
if
(
arg
.
haveSavedMeanInvVar_
)
{
size_t
mean_invVar_invariant_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
bnMeanVarStrides_
,
invariant_index
);
mean
=
type_convert
<
AccDataType
>
(
arg
.
p_savedMean_
[
mean_invVar_invariant_offset
]);
invVar
=
type_convert
<
AccDataType
>
(
arg
.
p_savedInvVar_
[
mean_invVar_invariant_offset
]);
}
else
{
// compute mean, variance using welford method
for
(
const
auto
&
reduce_index
:
arg
.
reduce_index_set_
)
{
size_t
x_reduce_offset
=
get_offset_from_index
<
NumBatchNormReduceDim
>
(
arg
.
x_reduce_strides_
,
reduce_index
);
auto
x_offset
=
x_invariant_offset
+
x_reduce_offset
;
curr_count
++
;
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
p_x_
[
x_offset
]);
AccDataType
delta
=
x
-
mean
;
mean
+=
delta
/
curr_count
;
AccDataType
delta2
=
x
-
mean
;
variance
+=
delta
*
delta2
;
};
// actual variance
variance
=
variance
/
curr_count
;
// inv-variance defined as 1/sqrt(epsilon+variance)
invVar
=
type_convert
<
AccDataType
>
(
1.0
f
)
/
ck
::
math
::
sqrt
(
arg
.
epsilon_
+
variance
);
};
AccDataType
dbias
=
type_convert
<
AccDataType
>
(
0.0
f
);
// Sum on reduced dimensions of dy
AccDataType
dscale
=
type_convert
<
AccDataType
>
(
0.0
f
);
// Sum on reduced dimensions of dy * norm_x
// 1) calculate dy * (x - mean) * inv-variance
// 2) calculate sum(dy) on reduced dimensions
// 3) calculate sum(dy * norm_x) on reduced dimensions
for
(
const
auto
&
reduce_index
:
arg
.
reduce_index_set_
)
{
size_t
x_reduce_offset
=
get_offset_from_index
<
NumBatchNormReduceDim
>
(
arg
.
x_reduce_strides_
,
reduce_index
);
size_t
dy_reduce_offset
=
get_offset_from_index
<
NumBatchNormReduceDim
>
(
arg
.
dy_reduce_strides_
,
reduce_index
);
auto
x_offset
=
x_invariant_offset
+
x_reduce_offset
;
auto
dy_offset
=
dy_invariant_offset
+
dy_reduce_offset
;
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
p_x_
[
x_offset
]);
AccDataType
norm_x
=
(
x
-
mean
)
*
invVar
;
AccDataType
dy
=
type_convert
<
AccDataType
>
(
arg
.
p_dy_
[
dy_offset
]);
arg
.
dy_elementwise_op_
(
dy
,
dy
);
dbias
+=
dy
;
dscale
+=
norm_x
*
dy
;
};
size_t
dscale_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
bnDscaleDbiasStrides_
,
invariant_index
);
size_t
dbias_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
bnDscaleDbiasStrides_
,
invariant_index
);
arg
.
p_dscale_
[
dscale_offset
]
=
type_convert
<
DscaleDbiasDataType
>
(
dscale
);
arg
.
p_dbias_
[
dbias_offset
]
=
type_convert
<
DscaleDbiasDataType
>
(
dbias
);
size_t
scale_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
bnScaleStrides_
,
invariant_index
);
AccDataType
scale
=
type_convert
<
AccDataType
>
(
arg
.
p_scale_
[
scale_offset
]);
AccDataType
multiplier
=
type_convert
<
AccDataType
>
(
1.0
f
)
/
type_convert
<
AccDataType
>
(
arg
.
reduceSize_
)
*
invVar
*
scale
;
// 1) calculate tmp = dscale * (x - mean) * inv-variance
// 2) calculate dx = 1/reduceSize * inv-variance * scale * (reduceSize * dy - dbias
// - tmp)
for
(
const
auto
&
reduce_index
:
arg
.
reduce_index_set_
)
{
size_t
x_reduce_offset
=
get_offset_from_index
<
NumBatchNormReduceDim
>
(
arg
.
x_reduce_strides_
,
reduce_index
);
size_t
dy_reduce_offset
=
get_offset_from_index
<
NumBatchNormReduceDim
>
(
arg
.
dy_reduce_strides_
,
reduce_index
);
size_t
dx_reduce_offset
=
get_offset_from_index
<
NumBatchNormReduceDim
>
(
arg
.
dx_reduce_strides_
,
reduce_index
);
auto
x_offset
=
x_invariant_offset
+
x_reduce_offset
;
auto
dy_offset
=
dy_invariant_offset
+
dy_reduce_offset
;
auto
dx_offset
=
dx_invariant_offset
+
dx_reduce_offset
;
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
p_x_
[
x_offset
]);
AccDataType
norm_x
=
(
x
-
mean
)
*
invVar
;
AccDataType
dy
=
type_convert
<
AccDataType
>
(
arg
.
p_dy_
[
dy_offset
]);
arg
.
dy_elementwise_op_
(
dy
,
dy
);
AccDataType
tmpVal
=
norm_x
*
dscale
;
AccDataType
dx
=
multiplier
*
(
type_convert
<
AccDataType
>
(
arg
.
reduceSize_
)
*
dy
-
dbias
-
tmpVal
);
arg
.
p_dx_
[
dx_offset
]
=
type_convert
<
DxDataType
>
(
dx
);
};
};
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
std
::
size_t
work_per_thread
=
(
arg
.
invariant_index_set_
.
size
()
+
num_thread
-
1
)
/
num_thread
;
std
::
vector
<
joinable_thread
>
threads
(
num_thread
);
for
(
std
::
size_t
it
=
0
;
it
<
num_thread
;
++
it
)
{
std
::
size_t
i_begin
=
it
*
work_per_thread
;
std
::
size_t
i_end
=
std
::
min
(
static_cast
<
size_t
>
((
it
+
1
)
*
work_per_thread
),
arg
.
invariant_index_set_
.
size
());
auto
f
=
[
=
]
{
for
(
std
::
size_t
i
=
i_begin
;
i
<
i_end
;
++
i
)
{
thread_reduce_func
(
arg
.
invariant_index_set_
[
i
]);
}
};
threads
[
it
]
=
joinable_thread
(
f
);
}
return
(
0.0
f
);
};
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/*stream_config*/
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
};
};
bool
IsSupportedArgument
(
const
device
::
BaseArgument
*
p_arg
)
override
{
(
void
)
p_arg
;
return
(
true
);
};
std
::
unique_ptr
<
device
::
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
dxStrides
,
const
std
::
array
<
index_t
,
Rank
>
dyStrides
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnDscaleDbiasStrides
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnMeanVarStrides
,
const
void
*
p_x
,
const
void
*
p_dy
,
const
void
*
p_scale
,
const
void
*
p_savedMean
,
const
void
*
p_savedInvVar
,
double
epsilon
,
const
DyElementwiseOp
dy_elementwise_op
,
void
*
p_dx
,
void
*
p_dscale
,
void
*
p_dbias
)
override
{
return
std
::
make_unique
<
Argument
>
(
xyLengths
,
xStrides
,
dxStrides
,
dyStrides
,
reduceDims
,
bnScaleBiasMeanVarLengths
,
bnScaleStrides
,
bnDscaleDbiasStrides
,
bnMeanVarStrides
,
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
DyDataType
*>
(
p_dy
),
static_cast
<
const
ScaleDataType
*>
(
p_scale
),
static_cast
<
const
MeanVarDataType
*>
(
p_savedMean
),
static_cast
<
const
MeanVarDataType
*>
(
p_savedInvVar
),
epsilon
,
dy_elementwise_op
,
static_cast
<
DxDataType
*>
(
p_dx
),
static_cast
<
DscaleDbiasDataType
*>
(
p_dscale
),
static_cast
<
DscaleDbiasDataType
*>
(
p_dbias
));
};
std
::
unique_ptr
<
device
::
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
};
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"Reference_BatchNorm_Backward"
<<
std
::
endl
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace host
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward_nhwc_c.hpp
deleted
100644 → 0
View file @
e59daa22
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
host
{
template
<
typename
XDataType
,
typename
DyDataType
,
typename
DxDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
>
struct
ReferenceBatchNormBwd_Input_N_H_W_C_Output_C
:
public
device
::
DeviceBatchNormBwd
<
4
,
3
,
DyElementwiseOp
>
{
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
4
>
xyLengths
,
const
std
::
array
<
index_t
,
4
>
xStrides
,
const
std
::
array
<
index_t
,
4
>
dyStrides
,
const
std
::
array
<
index_t
,
4
>
dxStrides
,
const
std
::
array
<
int
,
3
>
reduceDims
,
const
std
::
array
<
ck
::
index_t
,
1
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
ck
::
index_t
,
1
>
bnScaleStrides
,
const
std
::
array
<
ck
::
index_t
,
1
>
bnBiasStrides
,
const
std
::
array
<
ck
::
index_t
,
1
>
bnMeanVarStrides
,
const
XDataType
*
p_x
,
const
DyDataType
*
p_dy
,
const
ScaleDataType
*
p_scale
,
const
MeanVarDataType
*
p_savedMean
,
const
MeanVarDataType
*
p_savedInvVar
,
double
epsilon
,
const
DyElementwiseOp
dy_elementwise_op
,
DxDataType
*
p_dx
,
ScaleDataType
*
p_dscale
,
BiasDataType
*
p_dbias
)
:
p_x_
(
p_x
),
p_dy_
(
p_dy
),
p_scale_
(
p_scale
),
p_savedMean_
(
p_savedMean
),
p_savedInvVar_
(
p_savedInvVar
),
epsilon_
(
epsilon
),
dy_elementwise_op_
(
dy_elementwise_op
),
p_dx_
(
p_dx
),
p_dscale_
(
p_dscale
),
p_dbias_
(
p_dbias
)
{
ignore
=
xStrides
;
ignore
=
dyStrides
;
ignore
=
dxStrides
;
ignore
=
bnScaleStrides
;
ignore
=
bnBiasStrides
;
ignore
=
bnMeanVarStrides
;
if
(
xyLengths
.
size
()
!=
4
||
bnScaleBiasMeanVarLengths
.
size
()
!=
1
||
bnScaleBiasMeanVarLengths
[
0
]
!=
xyLengths
[
3
])
throw
std
::
runtime_error
(
"Invalid tensor dimensions!"
);
if
(
reduceDims
[
0
]
!=
0
||
reduceDims
[
1
]
!=
1
||
reduceDims
[
2
]
!=
2
)
throw
std
::
runtime_error
(
"Invalid reduce dimensions!"
);
n_
=
xyLengths
[
0
];
h_
=
xyLengths
[
1
];
w_
=
xyLengths
[
2
];
c_
=
xyLengths
[
3
];
haveSavedMeanInvVar_
=
(
p_savedMean
!=
nullptr
&&
p_savedInvVar
!=
nullptr
);
}
const
XDataType
*
p_x_
;
const
DyDataType
*
p_dy_
;
const
ScaleDataType
*
p_scale_
;
const
MeanVarDataType
*
p_savedMean_
;
const
MeanVarDataType
*
p_savedInvVar_
;
double
epsilon_
;
const
DyElementwiseOp
dy_elementwise_op_
;
DxDataType
*
p_dx_
;
ScaleDataType
*
p_dscale_
;
BiasDataType
*
p_dbias_
;
bool
haveSavedMeanInvVar_
;
index_t
n_
,
h_
,
w_
,
c_
;
};
struct
Invoker
:
public
device
::
BaseInvoker
{
float
Run
(
const
Argument
&
arg
)
{
auto
thread_reduce_func
=
[
&
](
auto
iC
)
{
AccDataType
reduceSize
=
type_convert
<
AccDataType
>
(
arg
.
n_
)
*
type_convert
<
AccDataType
>
(
arg
.
h_
)
*
type_convert
<
AccDataType
>
(
arg
.
w_
);
index_t
offset_C
=
iC
;
AccDataType
mean
;
AccDataType
invVar
;
if
(
arg
.
haveSavedMeanInvVar_
)
{
mean
=
arg
.
p_savedMean_
[
offset_C
];
invVar
=
arg
.
p_savedInvVar_
[
offset_C
];
}
else
{
AccDataType
meansquare
;
meansquare
=
type_convert
<
AccDataType
>
(
0.0
f
);
mean
=
type_convert
<
AccDataType
>
(
0.0
f
);
// compute mean, meanquare, variance, inv-variance
for
(
index_t
iN
=
0
;
iN
<
arg
.
n_
;
iN
++
)
{
index_t
offset_N
=
iN
*
arg
.
h_
*
arg
.
w_
*
arg
.
c_
;
for
(
index_t
iH
=
0
;
iH
<
arg
.
h_
;
iH
++
)
{
index_t
offset_H
=
iH
*
arg
.
w_
*
arg
.
c_
;
for
(
index_t
iW
=
0
;
iW
<
arg
.
w_
;
iW
++
)
{
index_t
offset_W
=
iW
*
arg
.
c_
;
auto
offset
=
offset_N
+
offset_H
+
offset_W
+
offset_C
;
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
p_x_
[
offset
]);
mean
+=
x
;
meansquare
+=
x
*
x
;
};
}
};
mean
=
mean
/
reduceSize
;
meansquare
=
meansquare
/
reduceSize
;
AccDataType
variance
=
meansquare
-
mean
*
mean
;
invVar
=
type_convert
<
AccDataType
>
(
1.0
f
)
/
std
::
sqrt
(
type_convert
<
AccDataType
>
(
arg
.
epsilon_
)
+
variance
);
};
AccDataType
dbias
=
type_convert
<
AccDataType
>
(
0.0
f
);
// Sum on NHW of dy
AccDataType
dscale
=
type_convert
<
AccDataType
>
(
0.0
f
);
// Sum on NHW of dy * norm_x
// 1) calculate dy * (x - mean) * inv-variance
// 2) calculate sum(dy) on NHW dimensions
// 3) calculate sum(dy * norm_x) on NHW dimensions
for
(
index_t
iN
=
0
;
iN
<
arg
.
n_
;
iN
++
)
{
index_t
offset_N
=
iN
*
arg
.
h_
*
arg
.
w_
*
arg
.
c_
;
for
(
index_t
iH
=
0
;
iH
<
arg
.
h_
;
iH
++
)
{
index_t
offset_H
=
iH
*
arg
.
w_
*
arg
.
c_
;
for
(
index_t
iW
=
0
;
iW
<
arg
.
w_
;
iW
++
)
{
index_t
offset_W
=
iW
*
arg
.
c_
;
auto
offset
=
offset_N
+
offset_H
+
offset_W
+
offset_C
;
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
p_x_
[
offset
]);
AccDataType
norm_x
=
(
x
-
mean
)
*
invVar
;
AccDataType
dy
=
type_convert
<
AccDataType
>
(
arg
.
p_dy_
[
offset
]);
arg
.
dy_elementwise_op_
(
dy
,
dy
);
dbias
+=
dy
;
dscale
+=
norm_x
*
dy
;
};
}
};
arg
.
p_dscale_
[
offset_C
]
=
type_convert
<
ScaleDataType
>
(
dscale
);
arg
.
p_dbias_
[
offset_C
]
=
type_convert
<
BiasDataType
>
(
dbias
);
AccDataType
scale
=
type_convert
<
AccDataType
>
(
arg
.
p_scale_
[
offset_C
]);
AccDataType
multiplier
=
type_convert
<
AccDataType
>
(
1.0
f
)
/
reduceSize
*
invVar
*
scale
;
// 1) calculate tmp = dscale * (x - mean) * inv-variance
// 2) calculate dx = 1/nhw * inv-variance * scale * (nhw * dy - dbias - tmp)
for
(
index_t
iN
=
0
;
iN
<
arg
.
n_
;
iN
++
)
{
index_t
offset_N
=
iN
*
arg
.
h_
*
arg
.
w_
*
arg
.
c_
;
for
(
index_t
iH
=
0
;
iH
<
arg
.
h_
;
iH
++
)
{
index_t
offset_H
=
iH
*
arg
.
w_
*
arg
.
c_
;
for
(
index_t
iW
=
0
;
iW
<
arg
.
w_
;
iW
++
)
{
index_t
offset_W
=
iW
*
arg
.
c_
;
auto
offset
=
offset_N
+
offset_H
+
offset_W
+
offset_C
;
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
p_x_
[
offset
]);
AccDataType
norm_x
=
(
x
-
mean
)
*
invVar
;
AccDataType
dy
=
type_convert
<
AccDataType
>
(
arg
.
p_dy_
[
offset
]);
arg
.
dy_elementwise_op_
(
dy
,
dy
);
AccDataType
tmpVal
=
norm_x
*
dscale
;
AccDataType
dx
=
multiplier
*
(
reduceSize
*
dy
-
dbias
-
tmpVal
);
arg
.
p_dx_
[
offset
]
=
type_convert
<
XDataType
>
(
dx
);
};
}
};
};
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
std
::
size_t
work_per_thread
=
(
arg
.
c_
+
num_thread
-
1
)
/
num_thread
;
std
::
vector
<
joinable_thread
>
threads
(
num_thread
);
for
(
std
::
size_t
it
=
0
;
it
<
num_thread
;
++
it
)
{
std
::
size_t
ic_begin
=
it
*
work_per_thread
;
std
::
size_t
ic_end
=
std
::
min
(
static_cast
<
int
>
((
it
+
1
)
*
work_per_thread
),
arg
.
c_
);
auto
f
=
[
=
]
{
for
(
std
::
size_t
ic
=
ic_begin
;
ic
<
ic_end
;
++
ic
)
{
thread_reduce_func
(
ic
);
}
};
threads
[
it
]
=
joinable_thread
(
f
);
}
return
(
0.0
f
);
};
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/*stream_config*/
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
};
};
bool
IsSupportedArgument
(
const
device
::
BaseArgument
*
p_arg
)
override
{
(
void
)
p_arg
;
return
(
true
);
};
std
::
unique_ptr
<
device
::
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
4
>
xyLengths
,
const
std
::
array
<
index_t
,
4
>
xStrides
,
const
std
::
array
<
index_t
,
4
>
dyStrides
,
const
std
::
array
<
index_t
,
4
>
dxStrides
,
const
std
::
array
<
int
,
3
>
reduceDims
,
const
std
::
array
<
ck
::
index_t
,
1
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
ck
::
index_t
,
1
>
bnScaleStrides
,
const
std
::
array
<
ck
::
index_t
,
1
>
bnBiasStrides
,
const
std
::
array
<
ck
::
index_t
,
1
>
bnMeanVarStrides
,
const
void
*
p_x
,
const
void
*
p_dy
,
const
void
*
p_scale
,
const
void
*
p_savedMean
,
const
void
*
p_savedInvVar
,
double
epsilon
,
const
DyElementwiseOp
dy_elementwise_op
,
void
*
p_dx
,
void
*
p_dscale
,
void
*
p_dbias
)
override
{
return
std
::
make_unique
<
Argument
>
(
xyLengths
,
xStrides
,
dyStrides
,
dxStrides
,
reduceDims
,
bnScaleBiasMeanVarLengths
,
bnScaleStrides
,
bnBiasStrides
,
bnMeanVarStrides
,
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
DyDataType
*>
(
p_dy
),
static_cast
<
const
ScaleDataType
*>
(
p_scale
),
static_cast
<
const
MeanVarDataType
*>
(
p_savedMean
),
static_cast
<
const
MeanVarDataType
*>
(
p_savedInvVar
),
epsilon
,
dy_elementwise_op
,
static_cast
<
DxDataType
*>
(
p_dx
),
static_cast
<
ScaleDataType
*>
(
p_dscale
),
static_cast
<
BiasDataType
*>
(
p_dbias
));
};
std
::
unique_ptr
<
device
::
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
};
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"Reference_BatchNorm_Backward_NHWC_C<"
<<
std
::
endl
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace host
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
7bcaf2a7
...
...
@@ -26,9 +26,9 @@ using Empty_Tuple = ck::Tuple<>;
using
F16_Tuple
=
ck
::
Tuple
<
F16
>
;
using
F16_F16_Tuple
=
ck
::
Tuple
<
F16
,
F16
>
;
using
F32_Tuple
=
ck
::
Tuple
<
F32
>
;
using
I32_Tuple
=
ck
::
Tuple
<
I32
>
;
using
F32_Tuple
=
ck
::
Tuple
<
F32
>
;
using
I32_Tuple
=
ck
::
Tuple
<
I32
>
;
using
I32_
F32_
Tuple
=
ck
::
Tuple
<
I32
,
F32
>
;
// GEMM layout
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
...
...
@@ -78,8 +78,9 @@ using NHWGK = ck::tensor_layout::convolution::NHWGK;
using
NDHWGK
=
ck
::
tensor_layout
::
convolution
::
NDHWGK
;
//
using
GK
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
GK_TUPLE
=
ck
::
Tuple
<
GK
>
;
using
GK
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
GK_Tuple
=
ck
::
Tuple
<
GK
>
;
using
GK_GK_Tuple
=
ck
::
Tuple
<
GK
,
GK
>
;
// pointwise functor
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
...
...
@@ -97,6 +98,13 @@ template <typename Activation>
using
Add_Activation_Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Activation_Mul_Clamp
<
Activation
>
;
template
<
typename
Activation
>
using
Activation_Mul2_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Activation_Mul2_Clamp
<
Activation
>
;
template
<
typename
Activation
>
using
Add_Activation_Mul2_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Add_Activation_Mul2_Clamp
<
Activation
>
;
template
<
typename
DeviceOp
,
typename
Tag
=
void
>
struct
DeviceOperationInstanceFactory
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp
0 → 100644
View file @
7bcaf2a7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// FP16
void
add_device_batchnorm_backward_rank_4_3_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormBwd
<
F16
,
F32
,
F32
,
F32
,
F16
,
F32
,
F32
,
PassThrough
,
4
,
3
>>>&
);
// FP32
void
add_device_batchnorm_backward_rank_4_3_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormBwd
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
4
,
3
>>>&
);
// BF16
void
add_device_batchnorm_backward_rank_4_3_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormBwd
<
BF16
,
F32
,
F32
,
F32
,
BF16
,
F32
,
F32
,
PassThrough
,
4
,
3
>>>&
);
// FP64
void
add_device_batchnorm_backward_rank_4_3_f64_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormBwd
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
PassThrough
,
4
,
3
>>>&
);
template
<
typename
XDataType
,
typename
DxDataType
,
typename
DyDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
DscaleDbiasDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
index_t
Rank
,
index_t
NumReduceDim
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceBatchNormBwd
<
XDataType
,
DxDataType
,
DyDataType
,
AccDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
Rank
,
NumReduceDim
>>
{
using
DeviceOp
=
DeviceBatchNormBwd
<
XDataType
,
DxDataType
,
DyDataType
,
AccDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
Rank
,
NumReduceDim
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
XDataType
,
F16
>
&&
is_same_v
<
DxDataType
,
F32
>
&&
is_same_v
<
DyDataType
,
F32
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
F16
>
&&
is_same_v
<
DscaleDbiasDataType
,
F32
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
{
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
&&
is_same_v
<
DyElementwiseOp
,
PassThrough
>
)
{
add_device_batchnorm_backward_rank_4_3_f16_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
F32
>
&&
is_same_v
<
DxDataType
,
F32
>
&&
is_same_v
<
DyDataType
,
F32
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
F32
>
&&
is_same_v
<
DscaleDbiasDataType
,
F32
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
{
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
&&
is_same_v
<
DyElementwiseOp
,
PassThrough
>
)
{
add_device_batchnorm_backward_rank_4_3_f32_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
BF16
>
&&
is_same_v
<
DxDataType
,
F32
>
&&
is_same_v
<
DyDataType
,
F32
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
BF16
>
&&
is_same_v
<
DscaleDbiasDataType
,
F32
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
{
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
&&
is_same_v
<
DyElementwiseOp
,
PassThrough
>
)
{
add_device_batchnorm_backward_rank_4_3_bf16_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
F64
>
&&
is_same_v
<
DxDataType
,
F64
>
&&
is_same_v
<
DyDataType
,
F64
>
&&
is_same_v
<
AccDataType
,
F64
>
&&
is_same_v
<
ScaleDataType
,
F64
>
&&
is_same_v
<
DscaleDbiasDataType
,
F64
>
&&
is_same_v
<
MeanVarDataType
,
F64
>
)
{
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
&&
is_same_v
<
DyElementwiseOp
,
PassThrough
>
)
{
add_device_batchnorm_backward_rank_4_3_f64_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
View file @
7bcaf2a7
...
...
@@ -131,6 +131,47 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
GNHWC
,
GKYXC
,
Empty_Tuple
,
GNHWK
,
F16
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
GNHWC
,
GKYXC
,
Empty_Tuple
,
GNHWK
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
GNHWC
,
GKYXC
,
Empty_Tuple
,
GNHWK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
...
...
@@ -273,11 +314,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v
<
OutDataType
,
float
>
)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
...
...
@@ -289,6 +332,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v
<
OutDataType
,
int8_t
>
)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHWGC
>
&&
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dl.hpp
deleted
100644 → 0
View file @
e59daa22
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwd
<
2
,
GNHWC
,
GKYXC
,
GNHWK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwd
<
2
,
GNHWC
,
GKYXC
,
GNHWK
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwd
<
2
,
GNHWC
,
GKYXC
,
GNHWK
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwd
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
OutLayout
,
InDataType
,
WeiDataType
,
OutDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>
{
using
DeviceOp
=
DeviceGroupedConvFwd
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
OutLayout
,
InDataType
,
WeiDataType
,
OutDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
GNHWC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
GNHWK
>
)
{
if
constexpr
(
is_same_v
<
InDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
OutDataType
,
float
>
)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp
0 → 100644
View file @
7bcaf2a7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void
add_device_conv2d_bias_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
GNHWC
,
GKYXC
,
GK_GK_Tuple
,
GNHWK
,
int8_t
,
int8_t
,
I32_F32_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
Add_Activation_Mul2_Clamp
<
PassThrough
>>>>&
instances
);
void
add_device_conv2d_bias_relu_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
GNHWC
,
GKYXC
,
GK_GK_Tuple
,
GNHWK
,
int8_t
,
int8_t
,
I32_F32_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
Add_Activation_Mul2_Clamp
<
Relu
>>>>&
instances
);
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
DsLayout
,
typename
OutLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
DsDataType
,
typename
OutDataType
,
typename
Activation
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DsLayout
,
OutLayout
,
InDataType
,
WeiDataType
,
DsDataType
,
OutDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Add_Activation_Mul2_Clamp
<
Activation
>>>
{
using
DeviceOp
=
DeviceGroupedConvFwdMultipleD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DsLayout
,
OutLayout
,
InDataType
,
WeiDataType
,
DsDataType
,
OutDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Add_Activation_Mul2_Clamp
<
Activation
>>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
GNHWC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
DsLayout
,
GK_GK_Tuple
>
&&
is_same_v
<
OutLayout
,
GNHWK
>
)
{
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
DsDataType
,
I32_F32_Tuple
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
if
constexpr
(
is_same_v
<
Activation
,
PassThrough
>
)
add_device_conv2d_bias_perchannel_quantization_int8_instances
(
op_ptrs
);
else
if
constexpr
(
is_same_v
<
Activation
,
Relu
>
)
add_device_conv2d_bias_relu_perchannel_quantization_int8_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_bias_forward_perlayer_quantization.hpp
→
library/include/ck/library/tensor_operation_instance/gpu/
quantization/
grouped_convolution_bias_forward_perlayer_quantization.hpp
View file @
7bcaf2a7
...
...
@@ -23,7 +23,7 @@ void add_device_conv2d_bias_perlayer_quantization_int8_instances(
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
GNHWC
,
GKYXC
,
GK_T
UPLE
,
GK_T
uple
,
GNHWK
,
int8_t
,
int8_t
,
...
...
@@ -38,7 +38,7 @@ void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
GNHWC
,
GKYXC
,
GK_T
UPLE
,
GK_T
uple
,
GNHWK
,
int8_t
,
int8_t
,
...
...
@@ -91,7 +91,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
GNHWC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
DsLayout
,
GK_T
UPLE
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
DsLayout
,
GK_T
uple
>
&&
is_same_v
<
OutLayout
,
GNHWK
>
)
{
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
...
...
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp
0 → 100644
View file @
7bcaf2a7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void
add_device_conv2d_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
GNHWC
,
GKYXC
,
GK_Tuple
,
GNHWK
,
int8_t
,
int8_t
,
F32_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
Activation_Mul2_Clamp
<
PassThrough
>>>>&
instances
);
void
add_device_conv2d_relu_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
GNHWC
,
GKYXC
,
GK_Tuple
,
GNHWK
,
int8_t
,
int8_t
,
F32_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
Activation_Mul2_Clamp
<
Relu
>>>>&
instances
);
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
DsLayout
,
typename
OutLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
DsDataType
,
typename
OutDataType
,
typename
Activation
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DsLayout
,
OutLayout
,
InDataType
,
WeiDataType
,
DsDataType
,
OutDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Activation_Mul2_Clamp
<
Activation
>>>
{
using
DeviceOp
=
DeviceGroupedConvFwdMultipleD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
GK_Tuple
,
OutLayout
,
InDataType
,
WeiDataType
,
F32_Tuple
,
OutDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Activation_Mul2_Clamp
<
Activation
>>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
GNHWC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
DsLayout
,
GK_Tuple
>
&&
is_same_v
<
OutLayout
,
GNHWK
>
)
{
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
if
constexpr
(
is_same_v
<
Activation
,
PassThrough
>
)
add_device_conv2d_perchannel_quantization_int8_instances
(
op_ptrs
);
else
if
constexpr
(
is_same_v
<
Activation
,
Relu
>
)
add_device_conv2d_relu_perchannel_quantization_int8_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_perlayer_quantization.hpp
→
library/include/ck/library/tensor_operation_instance/gpu/
quantization/
grouped_convolution_forward_perlayer_quantization.hpp
View file @
7bcaf2a7
File moved
Prev
1
2
3
4
5
6
7
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment