Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8dd7156d
Commit
8dd7156d
authored
Jul 25, 2023
by
ltqin
Browse files
Merge branch 'mha-train-develop' into attn-train-develop-qloop-mask
parents
d5f629e7
b5a3ea2d
Changes
533
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
1412 additions
and
136 deletions
+1412
-136
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
+1
-1
include/ck/tensor_operation/gpu/device/tensor_specialization.hpp
.../ck/tensor_operation/gpu/device/tensor_specialization.hpp
+1
-1
include/ck/tensor_operation/gpu/device/welford_helper.hpp
include/ck/tensor_operation/gpu/device/welford_helper.hpp
+1
-1
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+66
-27
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
...k/tensor_operation/gpu/element/element_wise_operation.hpp
+129
-39
include/ck/tensor_operation/gpu/element/quantization_operation.hpp
...k/tensor_operation/gpu/element/quantization_operation.hpp
+177
-15
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+176
-20
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp
...norm_multiblock/gridwise_multiblock_batchnorm_forward.hpp
+704
-0
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp
...ultiblock_reduce_second_half_batchnorm_backward_final.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp
...orm_multiblock/gridwise_multiblock_welford_first_half.hpp
+3
-3
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp
..._welford_second_half_batchnorm_forward_final_obsolete.hpp
+10
-11
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp
...lock_welford_second_half_multiblock_reduce_first_half.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+142
-16
No files found.
Too many changes to show.
To preserve performance only
533 of 533+
files are displayed.
Plain diff
Email patch
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/tensor_operation/gpu/device/tensor_specialization.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/tensor_operation/gpu/device/welford_helper.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -49,6 +50,14 @@ struct Add
y
=
x0
+
x1
;
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
float
>
(
float
&
y
,
const
float
&
x0
,
const
bhalf_t
&
x1
)
const
{
const
float
x1_tmp
=
ck
::
type_convert
<
float
>
(
x1
);
y
=
x0
+
x1_tmp
;
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
bhalf_t
>
(
bhalf_t
&
y
,
const
bhalf_t
&
x0
,
const
bhalf_t
&
x1
)
const
...
...
@@ -67,6 +76,30 @@ struct Add
};
};
struct
ScaleAdd
{
__host__
__device__
ScaleAdd
(
float
scale
)
:
scale_
(
scale
)
{}
template
<
typename
Y
,
typename
X0
,
typename
X1
>
__host__
__device__
constexpr
void
operator
()(
Y
&
y
,
const
X0
&
x0
,
const
X1
&
x1
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
,
half_t
>
(
float
&
y
,
const
float
&
x0
,
const
half_t
&
x1
)
const
{
y
=
scale_
*
x0
+
ck
::
type_convert
<
float
>
(
x1
);
};
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
,
bhalf_t
>
(
float
&
y
,
const
float
&
x0
,
const
bhalf_t
&
x1
)
const
{
y
=
scale_
*
x0
+
ck
::
type_convert
<
float
>
(
x1
);
};
float
scale_
;
};
struct
Subtract
{
template
<
typename
T
>
...
...
@@ -118,6 +151,13 @@ struct Bilinear
template
<
typename
Y
,
typename
X0
,
typename
X1
>
__host__
__device__
constexpr
void
operator
()(
Y
&
,
const
X0
&
,
const
X1
&
)
const
;
template
<
>
__host__
__device__
constexpr
void
operator
()
<
double
,
double
,
double
>
(
double
&
y
,
const
double
&
x0
,
const
double
&
x1
)
const
{
y
=
alpha_
*
x0
+
beta_
*
x1
;
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
float
,
float
,
float
>
(
float
&
y
,
const
float
&
x0
,
const
float
&
x1
)
const
...
...
@@ -241,43 +281,42 @@ struct AddHardswish
};
};
// C = A * B
// E = FastGelu(C + D)
struct
AddFastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__
__device__
static
constexpr
float
GetFastGeLU
(
float
x
)
{
const
float
u
=
2.
f
*
x
*
(
0.035677
f
*
x
*
x
+
0.797885
f
);
const
float
emu
=
exp
(
-
u
);
const
float
cdf
=
0.5
f
+
0.5
f
*
(
2.
f
/
(
1.
f
+
emu
)
-
1.
f
);
return
x
*
cdf
;
}
template
<
typename
T
>
static
inline
constexpr
bool
is_valid_param_type_v
=
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
half_t
>
||
std
::
is_same_v
<
T
,
bhalf_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
;
template
<
typename
E
,
typename
C
,
typename
D
>
__host__
__device__
constexpr
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D
&
d
)
const
__host__
__device__
constexpr
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D
&
d
)
const
;
template
<
>
__host__
__device__
constexpr
void
operator
()
<
float
,
float
,
float
>
(
float
&
e
,
const
float
&
c
,
const
float
&
d
)
const
{
static_assert
(
is_valid_param_type_v
<
E
>
&&
is_valid_param_type_v
<
C
>
&&
is_valid_param_type_v
<
D
>
);
const
float
x
=
c
+
d
;
const
float
y
=
GetFastGeLU
(
type_convert
<
float
>
(
c
)
+
type_convert
<
float
>
(
d
));
FastGelu
{}.
template
operator
()
<
float
,
float
>(
e
,
x
);
}
e
=
type_convert
<
E
>
(
y
);
template
<
>
__host__
__device__
constexpr
void
operator
()
<
half_t
,
half_t
,
half_t
>
(
half_t
&
e
,
const
half_t
&
c
,
const
half_t
&
d
)
const
{
const
half_t
x
=
c
+
d
;
ck
::
tensor_operation
::
element_wise
::
FastGelu
{}.
template
operator
()
<
half_t
,
half_t
>(
e
,
x
);
}
template
<
typename
D
>
__host__
__device__
constexpr
void
operator
()(
float
&
e
,
const
float
&
c
,
const
D
&
d
)
const
template
<
>
__host__
__device__
constexpr
void
operator
()
<
half_t
,
float
,
half_t
>
(
half_t
&
e
,
const
float
&
c
,
const
half_t
&
d
)
const
{
static_assert
(
is_valid_param_type_v
<
D
>
);
const
float
x0_f
=
c
+
d
;
float
x1_f
=
0
;
ck
::
tensor_operation
::
element_wise
::
FastGelu
{}.
template
operator
()
<
float
,
float
>(
x1_f
,
x0_f
);
e
=
GetFastGeLU
(
c
+
type_convert
<
float
>
(
d
)
);
e
=
type_convert
<
half_t
>
(
x1_f
);
}
};
...
...
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -16,7 +16,7 @@ namespace element_wise {
// Need to ensure compiler will fail if there is no matching candidate, instead of compiler
// siliently do implicit type conversion
//
//
Method 1
:
//
Example
:
//
// struct ExampleElementwiseOp
// {
...
...
@@ -30,19 +30,6 @@ namespace element_wise {
// {
// }
// };
//
// Method 2:
//
// template <typename Y, typename X>
// struct ExampleElementwiseOp;
//
// template <>
// struct ExampleElementwiseOp<float, ck::bhalf_t>
// {
// __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
// {
// }
// };
struct
AddReluAdd
{
...
...
@@ -173,40 +160,109 @@ struct AddAdd
};
// C = A * B
// E = (C + D0) x D1
struct
AddMultiply
{
template
<
typename
E
,
typename
C
,
typename
D0
,
typename
D1
>
__host__
__device__
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
,
const
D1
&
d1
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
half_t
,
half_t
,
half_t
>
(
half_t
&
e
,
const
half_t
&
c
,
const
half_t
&
d0
,
const
half_t
&
d1
)
const
{
const
half_t
y
=
(
c
+
d0
)
*
d1
;
e
=
y
;
}
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
float
,
half_t
,
half_t
>
(
half_t
&
e
,
const
float
&
c
,
const
half_t
&
d0
,
const
half_t
&
d1
)
const
{
const
half_t
y
=
(
type_convert
<
half_t
>
(
c
)
+
d0
)
*
d1
;
e
=
y
;
}
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
,
half_t
,
half_t
>
(
float
&
e
,
const
float
&
c
,
const
half_t
&
d0
,
const
half_t
&
d1
)
const
{
const
float
y
=
(
c
+
d0
)
*
d1
;
e
=
y
;
}
};
// E = FastGelu(C + D0 + D1)
struct
AddAddFastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__
__device__
static
constexpr
float
GetFastGeLU
(
float
x
)
template
<
typename
E
,
typename
C
,
typename
D0
,
typename
D1
>
__host__
__device__
constexpr
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
,
const
D1
&
d1
)
const
;
template
<
>
__host__
__device__
constexpr
void
operator
()
<
float
,
float
,
float
,
float
>
(
float
&
e
,
const
float
&
c
,
const
float
&
d0
,
const
float
&
d1
)
const
{
const
float
u
=
2.
f
*
x
*
(
0.035677
f
*
x
*
x
+
0.797885
f
);
const
float
emu
=
exp
(
-
u
);
const
float
cdf
=
0.5
f
+
0.5
f
*
(
2.
f
/
(
1.
f
+
emu
)
-
1.
f
);
return
x
*
cdf
;
const
float
x
=
c
+
d0
+
d1
;
FastGelu
{}.
template
operator
()
<
float
,
float
>(
e
,
x
);
}
template
<
typename
T
>
static
inline
constexpr
bool
is_valid_param_type_v
=
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
half_t
>
||
std
::
is_same_v
<
T
,
bhalf_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
||
std
::
is_same_v
<
T
,
ck
::
int4_t
>
#endif
;
template
<
>
__host__
__device__
constexpr
void
operator
()
<
half_t
,
half_t
,
half_t
,
half_t
>
(
half_t
&
e
,
const
half_t
&
c
,
const
half_t
&
d0
,
const
half_t
&
d1
)
const
{
const
half_t
x
=
c
+
d0
+
d1
;
template
<
typename
E
,
typename
C
,
typename
D0
,
typename
D1
>
__host__
__device__
constexpr
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
,
const
D1
&
d1
)
const
ck
::
tensor_operation
::
element_wise
::
FastGelu
{}.
template
operator
()
<
half_t
,
half_t
>(
e
,
x
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
half_t
,
float
,
half_t
,
half_t
>
(
half_t
&
e
,
const
float
&
c
,
const
half_t
&
d0
,
const
half_t
&
d1
)
const
{
const
float
x0_f
=
c
+
d0
+
d1
;
float
x1_f
=
0
;
ck
::
tensor_operation
::
element_wise
::
FastGelu
{}.
template
operator
()
<
float
,
float
>(
x1_f
,
x0_f
);
e
=
type_convert
<
half_t
>
(
x1_f
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
bhalf_t
,
float
,
bhalf_t
,
bhalf_t
>
(
bhalf_t
&
e
,
const
float
&
c
,
const
bhalf_t
&
d0
,
const
bhalf_t
&
d1
)
const
{
static_assert
(
is_valid_param_type_v
<
E
>
&&
is_valid_param_type_v
<
C
>
&&
is_valid_param_type_v
<
D0
>
&&
is_valid_param_type_v
<
D1
>
);
const
float
x0_f
=
c
+
type_convert
<
float
>
(
d0
)
+
type_convert
<
float
>
(
d1
);
const
float
y
=
GetFastGeLU
(
type_convert
<
float
>
(
c
)
+
type_convert
<
float
>
(
d0
)
+
type_convert
<
float
>
(
d1
));
float
x1_f
=
0
;
e
=
type_convert
<
E
>
(
y
);
ck
::
tensor_operation
::
element_wise
::
FastGelu
{}.
template
operator
()
<
float
,
float
>(
x1_f
,
x0_f
);
e
=
type_convert
<
bhalf_t
>
(
x1_f
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
int8_t
,
int32_t
,
int8_t
,
int8_t
>
(
int8_t
&
e
,
const
int32_t
&
c
,
const
int8_t
&
d0
,
const
int8_t
&
d1
)
const
{
const
float
x0_f
=
type_convert
<
float
>
(
c
)
+
type_convert
<
float
>
(
d0
)
+
type_convert
<
float
>
(
d1
);
float
x1_f
=
0
;
ck
::
tensor_operation
::
element_wise
::
FastGelu
{}.
template
operator
()
<
float
,
float
>(
x1_f
,
x0_f
);
e
=
type_convert
<
int8_t
>
(
x1_f
);
}
};
...
...
@@ -278,6 +334,40 @@ struct Normalize
double
epsilon_
;
};
// used by BatchNorm inference
// y = gamma * (x-mean) / sqrt(epsilon+variance) + beta
// The data type of mean and variance is used as AccDataType
struct
NormalizeInInfer
{
NormalizeInInfer
(
double
epsilon
=
1e-4
)
:
epsilon_
(
epsilon
)
{}
template
<
typename
T1
,
typename
T2
,
typename
T3
,
typename
T4
>
__host__
__device__
constexpr
void
operator
()(
T1
&
y
,
const
T1
&
x
,
const
T2
&
mean
,
const
T2
&
variance
,
const
T3
&
gamma
,
const
T4
&
beta
)
const
{
static_assert
(
std
::
is_same
<
T2
,
float
>::
value
||
std
::
is_same
<
T2
,
double
>::
value
,
"Data type is not supported by this operation!"
);
using
ck
::
type_convert
;
using
ck
::
math
::
sqrt
;
T2
tmp_x
,
tmp_y
;
tmp_x
=
type_convert
<
T2
>
(
x
);
tmp_y
=
((
tmp_x
-
mean
)
/
sqrt
(
variance
+
type_convert
<
T2
>
(
epsilon_
)))
*
type_convert
<
T2
>
(
gamma
)
+
type_convert
<
T2
>
(
beta
);
y
=
type_convert
<
T1
>
(
tmp_y
);
};
double
epsilon_
;
};
template
<
typename
Y
,
typename
X
>
struct
UnaryTypeConvert
;
...
...
include/ck/tensor_operation/gpu/element/quantization_operation.hpp
View file @
8dd7156d
#pragma once
#include "ck/utility/data_type.hpp"
// #include "ck/utility/get_id.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
element_wise
{
// Y = Sy * Qy
// W = Sw * Qw
// X = Sx * Qx
// B = Sb * Qb = Sw * Sx * Qb
// Where X, W, Y are float32, Qx, Qw, Qy are int8
// Sx, Sw, Sy are scale of x, w, y (float32), which is calculated from quantization range
// Qb is int32, scale of B is Sw * Sx for convenient
// Y = W @ X, where @ is convolution or matrix multiplication
// Sy * Qy = Sw * Qw @ Sx * Qx
// Qy = [(Sw*Sx)/Sy] * Qw @ Qx
// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template
<
typename
Activation
>
struct
Activation_Mul_Clamp
{
// Convolution + Activation (piecewise linear function)
// If an activation is piecewise linear function, then Activation(Sy * Qy) = Sy * Activation(Qy)
// Z = Activation(Y) = Activation(W @ X)
// Sz * Qz = Activation(Sy * Qy)
// Qz = Sy / Sz * Activation(Qy) = (Sw * Sx / Sz) * Activation(Qw @ Qx)
// requantScale_ = Sw * Sx / Sz
Activation_Mul_Clamp
(
float
requantScale
,
Activation
activationOp
)
:
requantScale_
(
requantScale
),
activationOp_
(
activationOp
)
{
...
...
@@ -17,26 +38,66 @@ struct Activation_Mul_Clamp
__host__
__device__
constexpr
void
operator
()(
int8_t
&
y
,
const
int32_t
&
x
)
const
{
float
x_fp32
=
ck
::
type_convert
<
float
>
(
x
);
activationOp_
(
x_fp32
,
x_fp32
);
float
y_fp32
=
math
::
clamp
(
requantScale_
*
x_fp32
,
-
128.
f
,
127.
f
);
y
=
ck
::
type_convert
<
int8_t
>
(
y_fp32
);
float
y_fp32
=
ck
::
type_convert
<
float
>
(
x
);
activationOp_
(
y_fp32
,
y_fp32
);
y_fp32
=
math
::
clamp
(
requantScale_
*
y_fp32
,
-
128.
f
,
127.
f
);
y
=
ck
::
type_convert
<
int8_t
>
(
y_fp32
);
}
__device__
constexpr
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float
y_fp32
=
ck
::
type_convert
<
float
>
(
x
);
activationOp_
(
y_fp32
,
y_fp32
);
y_fp32
=
math
::
clamp
(
requantScale_
*
y_fp32
,
-
128.
f
,
127.
f
);
y
=
ck
::
type_convert
<
int32_t
>
(
y_fp32
);
}
__host__
__device__
constexpr
void
operator
()(
float
&
y
,
const
int32_
t
&
x
)
const
__host__
constexpr
void
operator
()(
float
&
y
,
const
floa
t
&
x
)
const
{
// We might type_convert to int8 after lambda in someplace
float
x_fp32
=
ck
::
type_convert
<
float
>
(
x
);
activationOp_
(
x_fp32
,
x_fp32
);
y
=
math
::
clamp
(
requantScale_
*
x_fp32
,
-
128.
f
,
127.
f
);
// CAUSION - We might float in & float out in reference code
activationOp_
(
y
,
x
);
y
=
math
::
clamp
(
requantScale_
*
y
,
-
128.
f
,
127.
f
);
}
float
requantScale_
;
Activation
activationOp_
;
};
// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc
// If an activation is not piecewise linear function
// then Activation(Sy * Qy) != Sy * Activation(Qy)
template
<
typename
Activation
>
struct
Mul_Activation_Mul_Clamp
{
// Convolution + Activation (non piecewise linear function)
// Z = Activation(Y) = Activation(W @ X)
// Sz * Qz = Activation(Sy * Qy)
// Qz = S1 * Activation[Sacc * (Qw @ Qx)]
// Where S1 = 1 / Sz, Sacc = Sw * Sx
Mul_Activation_Mul_Clamp
(
float
scale_z_inv
,
float
scaleAcc
,
Activation
activationOp
)
:
scale_z_inv_
(
scale_z_inv
),
scaleAcc_
(
scaleAcc
),
activationOp_
(
activationOp
)
{
}
__host__
__device__
constexpr
void
operator
()(
int8_t
&
y
,
const
int32_t
&
x
)
const
{
float
y_fp32
=
ck
::
type_convert
<
float
>
(
x
);
y_fp32
=
scaleAcc_
*
y_fp32
;
activationOp_
(
y_fp32
,
y_fp32
);
y_fp32
=
math
::
clamp
(
scale_z_inv_
*
y_fp32
,
-
128.
f
,
127.
f
);
y
=
ck
::
type_convert
<
int8_t
>
(
y_fp32
);
}
float
scale_z_inv_
;
float
scaleAcc_
;
Activation
activationOp_
;
};
// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
// relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template
<
typename
Activation
>
struct
Activation_Mul2_Clamp
{
...
...
@@ -51,13 +112,35 @@ struct Activation_Mul2_Clamp
y
=
ck
::
type_convert
<
int8_t
>
(
y_fp32
);
}
__device__
constexpr
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
,
const
float
&
requantScale
)
const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float
y_fp32
=
ck
::
type_convert
<
float
>
(
x
);
activationOp_
(
y_fp32
,
y_fp32
);
y_fp32
=
math
::
clamp
(
requantScale
*
y_fp32
,
-
128.
f
,
127.
f
);
y
=
ck
::
type_convert
<
int32_t
>
(
y_fp32
);
}
Activation
activationOp_
;
};
// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
// Activation(Sy * Qy) = Sy * Activation(Qy)
template
<
typename
Activation
>
struct
Add_Activation_Mul_Clamp
{
// Convolution + bias
// Let Bias = B = Sw * Sx * Qb
// Where Qb is int32
// Y = W @ X + B
// Sy * Qy = Sw * Qw @ Sx * Qx + Sw * Sx * Qb
// Qy = [(Sw*Sx)/Sy] * (Qw @ Qx + Qb)
// For activation, Z = Activaiton(Y)
// Sz * Qz = Activation(Sy * Qy)
// Qz = Sy / Sz * Activation(Qy) = [(Sw*Sx)/Sz] * Activation(Qw @ Qx + Qb)
Add_Activation_Mul_Clamp
(
float
requantScale
,
Activation
activationOp
)
:
requantScale_
(
requantScale
),
activationOp_
(
activationOp
)
{
...
...
@@ -72,6 +155,17 @@ struct Add_Activation_Mul_Clamp
y
=
ck
::
type_convert
<
int8_t
>
(
y_fp32
);
}
__host__
__device__
constexpr
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
,
const
int32_t
&
bias
)
const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float
y_fp32
=
ck
::
type_convert
<
float
>
(
x
+
bias
);
activationOp_
(
y_fp32
,
y_fp32
);
y_fp32
=
math
::
clamp
(
requantScale_
*
y_fp32
,
-
128.
f
,
127.
f
);
y
=
ck
::
type_convert
<
int32_t
>
(
y_fp32
);
}
float
requantScale_
;
Activation
activationOp_
;
};
...
...
@@ -92,15 +186,33 @@ struct Add_Activation_Mul2_Clamp
y
=
ck
::
type_convert
<
int8_t
>
(
y_fp32
);
}
__host__
__device__
constexpr
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
,
const
int32_t
&
bias
,
const
float
&
requantScale
)
const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float
y_fp32
=
ck
::
type_convert
<
float
>
(
x
+
bias
);
activationOp_
(
y_fp32
,
y_fp32
);
y_fp32
=
math
::
clamp
(
requantScale
*
y_fp32
,
-
128.
f
,
127.
f
);
y
=
ck
::
type_convert
<
int32_t
>
(
y_fp32
);
}
Activation
activationOp_
;
};
// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc
// If an activation is not piecewise linear function
// then Activation(Sy * Qy) != Sy * Activation(Qy)
template
<
typename
Activation
>
struct
Add_Mul_Activation_Mul_Clamp
{
Add_Mul_Activation_Mul_Clamp
(
float
requantScale1
,
float
requantScale2
,
Activation
activationOp
)
:
requantScale1_
(
requantScale1
),
requantScale2_
(
requantScale2
),
activationOp_
(
activationOp
)
// Convolution + Activation (non piecewise linear function)
// Z = Activation(Y) = Activation(W @ X + B)
// Sz * Qz = Activation(Sy * Qy)
// Qz = S1 * Activation[Sacc * (Qw @ Qx + Qb)]
// Where S1 = 1 / Sz, Sacc = Sw * Sx
Add_Mul_Activation_Mul_Clamp
(
float
scale_z_inv
,
float
scaleAcc
,
Activation
activationOp
)
:
scale_z_inv_
(
scale_z_inv
),
scaleAcc_
(
scaleAcc
),
activationOp_
(
activationOp
)
{
}
...
...
@@ -108,14 +220,64 @@ struct Add_Mul_Activation_Mul_Clamp
operator
()(
int8_t
&
y
,
const
int32_t
&
x
,
const
int32_t
&
bias
)
const
{
float
y_fp32
=
ck
::
type_convert
<
float
>
(
x
+
bias
);
y_fp32
=
requantScale1_
*
y_fp32
;
y_fp32
=
scaleAcc_
*
y_fp32
;
activationOp_
(
y_fp32
,
y_fp32
);
y_fp32
=
math
::
clamp
(
scale_z_inv_
*
y_fp32
,
-
128.
f
,
127.
f
);
y
=
ck
::
type_convert
<
int8_t
>
(
y_fp32
);
}
__host__
__device__
constexpr
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
,
const
int32_t
&
bias
)
const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float
y_fp32
=
ck
::
type_convert
<
float
>
(
x
+
bias
);
y_fp32
=
scaleAcc_
*
y_fp32
;
activationOp_
(
y_fp32
,
y_fp32
);
y_fp32
=
math
::
clamp
(
scale_z_inv_
*
y_fp32
,
-
128.
f
,
127.
f
);
y
=
ck
::
type_convert
<
int32_t
>
(
y_fp32
);
}
float
scale_z_inv_
;
float
scaleAcc_
;
Activation
activationOp_
;
};
// Conv Perchannel quantization + Activation function which is non piecewise linear function,
// such as TanH, Sigmoid ...etc
// If an activation is not piecewise linear function
// then Activation(Sy *Qy) != Sy * Activation(Qy)
template
<
typename
Activation
>
struct
Add_Mul2_Activation_Mul_Clamp
{
Add_Mul2_Activation_Mul_Clamp
(
float
scale_z_inv
,
Activation
activationOp
)
:
scale_z_inv_
(
scale_z_inv
),
activationOp_
(
activationOp
)
{
}
__host__
__device__
constexpr
void
operator
()(
int8_t
&
y
,
const
int32_t
&
x
,
const
int32_t
&
bias
,
const
float
&
scaleAcc
)
const
{
float
y_fp32
=
ck
::
type_convert
<
float
>
(
x
+
bias
);
y_fp32
=
scaleAcc
*
y_fp32
;
activationOp_
(
y_fp32
,
y_fp32
);
y_fp32
=
math
::
clamp
(
requantScale2
_
*
y_fp32
,
-
128.
f
,
127.
f
);
y_fp32
=
math
::
clamp
(
scale_z_inv
_
*
y_fp32
,
-
128.
f
,
127.
f
);
y
=
ck
::
type_convert
<
int8_t
>
(
y_fp32
);
}
float
requantScale1_
;
float
requantScale2_
;
__host__
__device__
constexpr
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
,
const
int32_t
&
bias
,
const
float
&
scaleAcc
)
const
{
// CAUSION - We might type_convert to int8 in threadwise copy
// eg. GridwiseGemmDlMultipleD_km_kn_mn
float
y_fp32
=
ck
::
type_convert
<
float
>
(
x
+
bias
);
y_fp32
=
scaleAcc
*
y_fp32
;
activationOp_
(
y_fp32
,
y_fp32
);
y_fp32
=
math
::
clamp
(
scale_z_inv_
*
y_fp32
,
-
128.
f
,
127.
f
);
y
=
ck
::
type_convert
<
int32_t
>
(
y_fp32
);
}
float
scale_z_inv_
;
Activation
activationOp_
;
};
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/type_convert.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
element_wise
{
#if CK_WORKAROUND_SWDEV_383542
extern
"C"
__device__
float
__ocml_native_recip_f32
(
float
);
#endif
struct
PassThrough
{
template
<
typename
Y
,
typename
X
>
...
...
@@ -52,6 +57,12 @@ struct PassThrough
y
=
type_convert
<
bhalf_t
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
half_t
>
(
bhalf_t
&
y
,
const
half_t
&
x
)
const
{
y
=
type_convert
<
bhalf_t
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
int8_t
,
int8_t
>
(
int8_t
&
y
,
const
int8_t
&
x
)
const
{
...
...
@@ -71,6 +82,36 @@ struct PassThrough
y
=
x
;
}
#endif
template
<
>
__host__
__device__
void
operator
()
<
f8_t
,
f8_t
>
(
f8_t
&
y
,
const
f8_t
&
x
)
const
{
y
=
x
;
}
template
<
>
__host__
__device__
void
operator
()
<
float
,
f8_t
>
(
float
&
y
,
const
f8_t
&
x
)
const
{
y
=
type_convert
<
float
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
f8_t
,
float
>
(
f8_t
&
y
,
const
float
&
x
)
const
{
y
=
type_convert
<
f8_t
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
f8_t
>
(
half_t
&
y
,
const
f8_t
&
x
)
const
{
y
=
type_convert
<
half_t
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
f8_t
,
half_t
>
(
f8_t
&
y
,
const
half_t
&
x
)
const
{
y
=
type_convert
<
f8_t
>
(
x
);
}
};
struct
UnaryConvert
...
...
@@ -82,6 +123,40 @@ struct UnaryConvert
}
};
struct
ConvertBF16RTN
{
// convert to bf16 using round to nearest (rtn)
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
// check Y datatype
static_assert
(
is_same
<
Y
,
bhalf_t
>::
value
,
"Data type is not supported by this operation!"
);
// check X datatype
static_assert
(
is_same
<
X
,
float
>::
value
||
is_same
<
X
,
half_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
bf16_convert_rtn
<
Y
>
(
x
);
}
};
struct
ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
// check Y datatype
static_assert
(
is_same
<
Y
,
f8_t
>::
value
,
"Data type is not supported by this operation!"
);
// check X datatype
static_assert
(
is_same
<
X
,
float
>::
value
||
is_same
<
X
,
half_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
f8_convert_sr
<
Y
>
(
x
);
}
};
struct
Scale
{
__host__
__device__
Scale
(
float
scale
)
:
scale_
(
scale
)
{}
...
...
@@ -95,6 +170,12 @@ struct Scale
y
=
scale_
*
x
;
};
template
<
>
__host__
__device__
void
operator
()
<
double
,
double
>
(
double
&
y
,
const
double
&
x
)
const
{
y
=
scale_
*
x
;
};
__host__
__device__
auto
Value
()
const
{
return
scale_
;
}
float
scale_
;
...
...
@@ -196,36 +277,83 @@ struct Relu
}
};
// Y = FastGelu(X)
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code use higher accuracy "exp" and "div"
// gpu code use lower accuracy "__expf" and "rcp" function
struct
FastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__
__device__
static
constexpr
float
GetFastGeLU
(
float
x
)
template
<
typename
Y
,
typename
X
>
__host__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
template
<
typename
Y
,
typename
X
>
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
template
<
>
__host__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
const
float
u
=
2.
f
*
x
*
(
0.035677
f
*
x
*
x
+
0.797885
f
);
const
float
emu
=
exp
(
-
u
);
const
float
cdf
=
0.5
f
+
0.5
f
*
(
2.
f
/
(
1.
f
+
emu
)
-
1.
f
);
return
x
*
cdf
;
y
=
x
*
cdf
;
}
template
<
typename
T
>
static
inline
constexpr
bool
is_valid_param_type_v
=
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
half_t
>
||
std
::
is_same_v
<
T
,
bhalf_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
||
std
::
is_same_v
<
T
,
ck
::
int4_t
>
// device code, use lower precision "__expf" and "rcp"
template
<
>
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
const
float
u
=
2.
f
*
x
*
(
0.035677
f
*
x
*
x
+
0.797885
f
);
const
float
emu
=
__expf
(
-
u
);
#if !CK_WORKAROUND_SWDEV_383542
const
float
cdf
=
0.5
f
+
0.5
f
*
(
2.
f
*
__frcp_rn
(
1.
f
+
emu
)
-
1.
f
);
#else
const
float
cdf
=
0.5
f
+
0.5
f
*
(
2.
f
*
__ocml_native_recip_f32
(
1.
f
+
emu
)
-
1.
f
);
#endif
;
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
y
=
x
*
cdf
;
}
template
<
>
__host__
void
operator
()
<
half_t
,
half_t
>
(
half_t
&
y
,
const
half_t
&
x
)
const
{
static_assert
(
is_valid_param_type_v
<
Y
>
&&
is_valid_param_type_v
<
X
>
);
float
y_f
;
this
->
operator
()
<
float
,
float
>
(
y_f
,
type_convert
<
float
>
(
x
));
const
float
tmp_y
=
GetFastGeLU
(
type_convert
<
float
>
(
x
));
y
=
type_convert
<
Y
>
(
tmp_y
);
y
=
type_convert
<
half_t
>
(
y_f
);
}
template
<
>
__device__
void
operator
()
<
half_t
,
half_t
>
(
half_t
&
y
,
const
half_t
&
x
)
const
{
float
y_f
;
this
->
operator
()
<
float
,
float
>
(
y_f
,
type_convert
<
float
>
(
x
));
y
=
type_convert
<
half_t
>
(
y_f
);
}
template
<
>
__host__
void
operator
()
<
half_t
,
float
>
(
half_t
&
y
,
const
float
&
x
)
const
{
float
y_f
;
this
->
operator
()
<
float
,
float
>
(
y_f
,
x
);
y
=
type_convert
<
half_t
>
(
y_f
);
}
template
<
>
__device__
void
operator
()
<
half_t
,
float
>
(
half_t
&
y
,
const
float
&
x
)
const
{
float
y_f
;
this
->
operator
()
<
float
,
float
>
(
y_f
,
x
);
y
=
type_convert
<
half_t
>
(
y_f
);
}
};
...
...
@@ -261,8 +389,36 @@ struct Sigmoid
y
=
1
/
(
ck
::
type_convert
<
T
>
(
1
)
+
exp
(
-
x
));
};
};
int32_t
divider_
=
1
;
struct
TanH
{
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
tanh
(
x
);
};
};
struct
Swish
{
Swish
(
float
beta
=
1.0
f
)
:
beta_
(
beta
)
{}
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
x
/
(
ck
::
type_convert
<
T
>
(
1
)
+
ck
::
math
::
exp
(
-
beta_
*
x
));
};
float
beta_
=
1.0
f
;
};
}
// namespace element_wise
...
...
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp
0 → 100644
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/workgroup_synchronization.hpp"
namespace
ck
{
template
<
typename
GridwiseMultiblockBatchNormForward_
,
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
typename
XYGridDesc_M_K
,
typename
MeanVarCountGridDesc_M_G
,
typename
MeanVarCountGridDesc_M_K
,
typename
ScaleBiasGridDesc_M
,
typename
MeanVarGridDesc_M
,
typename
GetReduceCountPerThreadFunctor
>
__global__
void
kernel_multiblock_batchnorm_forward
(
const
XYGridDesc_M_K
x_grid_desc_m_k
,
const
XYGridDesc_M_K
y_grid_desc_m_k
,
const
MeanVarCountGridDesc_M_G
mean_var_count_grid_desc_m_g
,
const
MeanVarCountGridDesc_M_K
mean_var_count_grid_desc_m_k
,
const
ScaleBiasGridDesc_M
scale_grid_desc_m
,
const
ScaleBiasGridDesc_M
bias_grid_desc_m
,
const
MeanVarGridDesc_M
mean_var_grid_desc_m
,
const
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
const
XDataType
*
const
__restrict__
p_x
,
MeanVarDataType
*
const
__restrict__
p_welford_mean
,
MeanVarDataType
*
const
__restrict__
p_welford_variance
,
int32_t
*
const
__restrict__
p_welford_count
,
int32_t
*
const
__restrict__
p_control
,
const
ScaleDataType
*
const
__restrict__
p_scale
,
const
BiasDataType
*
const
__restrict__
p_bias
,
const
YElementwiseOp
y_elementwise_op
,
YDataType
*
const
__restrict__
p_y
,
bool
updateMovingAverage
,
AccDataType
averageFactor
,
MeanVarDataType
*
const
__restrict__
resultRunningMean
,
MeanVarDataType
*
const
__restrict__
resultRunningVariance
,
bool
saveMeanInvVariance
,
MeanVarDataType
*
const
__restrict__
resultSaveMean
,
MeanVarDataType
*
const
__restrict__
resultSaveInvVariance
)
{
GridwiseMultiblockBatchNormForward_
::
Run
(
x_grid_desc_m_k
,
y_grid_desc_m_k
,
mean_var_count_grid_desc_m_g
,
mean_var_count_grid_desc_m_k
,
scale_grid_desc_m
,
bias_grid_desc_m
,
mean_var_grid_desc_m
,
get_reduce_count_per_thread
,
num_k_block_tile_iteration
,
epsilon
,
p_x
,
p_welford_mean
,
p_welford_variance
,
p_welford_count
,
p_control
,
p_scale
,
p_bias
,
y_elementwise_op
,
p_y
,
updateMovingAverage
,
averageFactor
,
resultRunningMean
,
resultRunningVariance
,
saveMeanInvVariance
,
resultSaveMean
,
resultSaveInvVariance
);
};
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
typename
XYGridDesc_M_K
,
typename
MeanVarCountGridDesc_M_G
,
typename
MeanVarCountGridDesc_M_K
,
typename
ScaleBiasGridDesc_M
,
typename
MeanVarGridDesc_M
,
typename
GetReduceCountPerThreadFunctor
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XSrcYDstVectorDim
,
index_t
XSrcVectorSize
,
index_t
YDstVectorSize
,
index_t
ScaleSrcVectorSize
,
index_t
BiasSrcVectorSize
,
index_t
MeanVarSrcDstVectorSize
>
struct
GridwiseMultiblockBatchNormForward
{
static_assert
((
XSrcYDstVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
)
||
(
XSrcYDstVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static_assert
((
XSrcYDstVectorDim
==
0
&&
MThreadSliceSize
%
YDstVectorSize
==
0
)
||
(
XSrcYDstVectorDim
==
1
&&
KThreadSliceSize
%
YDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
bool
reorder_thread_cluster
=
(
XSrcYDstVectorDim
==
0
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadBufferDimAccessOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
ThreadClusterArrangeOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
using
ThreadReduceSrcDesc_M_1
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
1
>
{})));
using
ThreadwiseWelford1
=
ThreadwiseWelford
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
>
;
using
ThreadwiseWelford2
=
ThreadwiseWelfordMerge
<
AccDataType
,
ThreadReduceSrcDesc_M_1
,
ThreadReduceDstDesc_M
>
;
using
BlockwiseWelford1
=
BlockwiseWelford
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
false
>
;
using
BlockwiseWelford2
=
BlockwiseWelford
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
true
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
__device__
static
void
Run
(
const
XYGridDesc_M_K
&
x_grid_desc_m_k
,
const
XYGridDesc_M_K
&
y_grid_desc_m_k
,
const
MeanVarCountGridDesc_M_G
&
mean_var_count_grid_desc_m_g
,
const
MeanVarCountGridDesc_M_K
&
mean_var_count_grid_desc_m_k
,
const
ScaleBiasGridDesc_M
&
scale_grid_desc_m
,
const
ScaleBiasGridDesc_M
&
bias_grid_desc_m
,
const
MeanVarGridDesc_M
&
mean_var_grid_desc_m
,
const
GetReduceCountPerThreadFunctor
&
get_reduce_count_per_thread
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
const
XDataType
*
const
__restrict__
p_x
,
MeanVarDataType
*
const
__restrict__
p_welford_mean
,
MeanVarDataType
*
const
__restrict__
p_welford_variance
,
int32_t
*
const
__restrict__
p_welford_count
,
int32_t
*
const
__restrict__
p_control
,
const
ScaleDataType
*
const
__restrict__
p_scale
,
const
BiasDataType
*
const
__restrict__
p_bias
,
const
YElementwiseOp
y_elementwise_op
,
YDataType
*
const
__restrict__
p_y
,
bool
updateMovingAverage
,
AccDataType
averageFactor
,
MeanVarDataType
*
const
__restrict__
resultRunningMean
,
MeanVarDataType
*
const
__restrict__
resultRunningVariance
,
bool
saveMeanInvVariance
,
MeanVarDataType
*
const
__restrict__
resultSaveMean
,
MeanVarDataType
*
const
__restrict__
resultSaveInvVariance
)
{
using
ck
::
math
::
sqrt
;
const
index_t
blkgroup_size
=
mean_var_count_grid_desc_m_g
.
GetLength
(
I1
);
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
blkgroup_id
=
block_global_id
/
blkgroup_size
;
const
index_t
block_local_id
=
block_global_id
%
blkgroup_size
;
if
(
block_local_id
==
0
)
gms_init
(
BlockSize
/
warpSize
*
blkgroup_size
,
&
p_control
[
blkgroup_id
*
2
]);
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
using
ThreadBufferLengths_M_K
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
using
ThreadBufferLengths_M
=
Sequence
<
MThreadSliceSize
>
;
using
ThreadBufferLengths_M_1
=
Sequence
<
MThreadSliceSize
,
1
>
;
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_m_1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
1
>
{}));
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
x_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
,
true
>
count_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
tmp_mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
tmp_var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
,
true
>
tmp_count_thread_buf
;
const
index_t
reduceSizePerBlock
=
K_BlockTileSize
*
num_k_block_tile_iteration
;
auto
threadwise_x_load
=
ThreadwiseTensorSliceTransfer_v2
<
XDataType
,
AccDataType
,
XYGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
XSrcYDstVectorDim
,
XSrcVectorSize
,
1
,
true
>
(
x_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
));
constexpr
auto
xy_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileSize
);
const
auto
x_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_x
,
x_grid_desc_m_k
.
GetElementSpaceSize
());
// Step 1: each workgroup does local welford reduction
auto
threadwise_welford_1
=
ThreadwiseWelford1
();
threadwise_welford_1
.
max_count_
=
get_reduce_count_per_thread
(
block_local_id
,
thread_k_cluster_id
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
mean_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
});
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
xy_copy_fwd_step_m_k
);
threadwise_welford_1
.
Run
(
x_thread_buf
,
mean_thread_buf
,
var_thread_buf
);
}
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
count_thread_buf
(
I
)
=
threadwise_welford_1
.
cur_count_
;
BlockwiseWelford1
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count_thread_buf
(
I
));
});
// Step 2: each workgroup writes its local welford result to workspace memory
auto
mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
,
AmdBufferCoherenceEnum
::
GLC
>
(
p_welford_mean
,
mean_var_count_grid_desc_m_g
.
GetElementSpaceSize
());
auto
var_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
,
AmdBufferCoherenceEnum
::
GLC
>
(
p_welford_variance
,
mean_var_count_grid_desc_m_g
.
GetElementSpaceSize
());
auto
count_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
,
AmdBufferCoherenceEnum
::
GLC
>
(
p_welford_count
,
mean_var_count_grid_desc_m_g
.
GetElementSpaceSize
());
auto
threadwise_mean_var_store_m_g
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanVarDataType
,
decltype
(
thread_buffer_desc_m_1
),
MeanVarCountGridDesc_M_G
,
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
0
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_count_grid_desc_m_g
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
),
PassThroughOp
{});
auto
threadwise_count_store_m_g
=
ThreadwiseTensorSliceTransfer_v1r3
<
int32_t
,
int32_t
,
decltype
(
thread_buffer_desc_m_1
),
MeanVarCountGridDesc_M_G
,
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
0
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_count_grid_desc_m_g
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
),
PassThroughOp
{});
if
(
thread_k_cluster_id
==
0
)
{
threadwise_mean_var_store_m_g
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
mean_thread_buf
,
mean_var_count_grid_desc_m_g
,
mean_global_val_buf
);
threadwise_mean_var_store_m_g
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
var_thread_buf
,
mean_var_count_grid_desc_m_g
,
var_global_val_buf
);
threadwise_count_store_m_g
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
count_thread_buf
,
mean_var_count_grid_desc_m_g
,
count_global_val_buf
);
};
gms_barrier
(
&
p_control
[
blkgroup_id
*
2
]);
if
(
block_local_id
==
0
)
gms_reset
(
&
p_control
[
blkgroup_id
*
2
]);
// Step 3: each workgroup reads welford results from workspace memory and does final welford
// reduction
auto
threadwise_mean_var_load_m_k
=
ThreadwiseTensorSliceTransfer_v2
<
MeanVarDataType
,
AccDataType
,
MeanVarCountGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
0
,
1
,
1
,
true
>
(
mean_var_count_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
1
));
auto
threadwise_count_load_m_k
=
ThreadwiseTensorSliceTransfer_v2
<
int32_t
,
int32_t
,
MeanVarCountGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
0
,
1
,
1
,
true
>
(
mean_var_count_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
1
));
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
mean_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
count_thread_buf
(
I
)
=
0
;
});
constexpr
auto
mean_var_count_read_fwd_step_m_k
=
make_multi_index
(
0
,
KThreadClusterSize
);
int32_t
reducedSize
=
0
;
while
(
reducedSize
<
blkgroup_size
)
{
threadwise_mean_var_load_m_k
.
Run
(
mean_var_count_grid_desc_m_k
,
mean_global_val_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
tmp_mean_thread_buf
);
threadwise_mean_var_load_m_k
.
Run
(
mean_var_count_grid_desc_m_k
,
var_global_val_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
tmp_var_thread_buf
);
threadwise_count_load_m_k
.
Run
(
mean_var_count_grid_desc_m_k
,
count_global_val_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
tmp_count_thread_buf
);
ThreadwiseWelford2
::
Run
(
tmp_mean_thread_buf
,
tmp_var_thread_buf
,
tmp_count_thread_buf
,
mean_thread_buf
,
var_thread_buf
,
count_thread_buf
);
reducedSize
+=
KThreadClusterSize
;
threadwise_mean_var_load_m_k
.
MoveSrcSliceWindow
(
mean_var_count_grid_desc_m_k
,
mean_var_count_read_fwd_step_m_k
);
threadwise_count_load_m_k
.
MoveSrcSliceWindow
(
mean_var_count_grid_desc_m_k
,
mean_var_count_read_fwd_step_m_k
);
};
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
BlockwiseWelford2
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count_thread_buf
(
I
));
});
// Step 4: do normalization using the mean/variance
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
scale_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
bias_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
y_thread_buf
;
auto
threadwise_y_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
YDataType
,
decltype
(
thread_buffer_desc_m_k
),
XYGridDesc_M_K
,
YElementwiseOp
,
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
XSrcYDstVectorDim
,
YDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
y_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
),
y_elementwise_op
);
auto
threadwise_scale_load
=
ThreadwiseTensorSliceTransfer_v2
<
ScaleDataType
,
AccDataType
,
ScaleBiasGridDesc_M
,
decltype
(
thread_buffer_desc_m
),
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
ScaleSrcVectorSize
,
1
,
true
>
(
scale_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
auto
threadwise_bias_load
=
ThreadwiseTensorSliceTransfer_v2
<
BiasDataType
,
AccDataType
,
ScaleBiasGridDesc_M
,
decltype
(
thread_buffer_desc_m
),
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
BiasSrcVectorSize
,
1
,
true
>
(
bias_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
const
auto
scale_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_scale
,
scale_grid_desc_m
.
GetElementSpaceSize
());
const
auto
bias_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_bias
,
bias_grid_desc_m
.
GetElementSpaceSize
());
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
threadwise_scale_load
.
Run
(
scale_grid_desc_m
,
scale_global_val_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
scale_thread_buf
);
threadwise_bias_load
.
Run
(
bias_grid_desc_m
,
bias_global_val_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
bias_thread_buf
);
threadwise_x_load
.
SetSrcSliceOrigin
(
x_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
));
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
AccDataType
multiplier
=
scale_thread_buf
[
Number
<
iM
>
{}]
/
sqrt
(
var_thread_buf
[
iM
]
+
epsilon
);
AccDataType
fused_mean_bias
=
bias_thread_buf
[
Number
<
iM
>
{}]
-
mean_thread_buf
[
iM
]
*
multiplier
;
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
// normalize
y_thread_buf
(
Number
<
offset
>
{})
=
x_thread_buf
[
Number
<
offset
>
{}]
*
multiplier
+
fused_mean_bias
;
});
});
threadwise_y_store
.
Run
(
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
y_thread_buf
,
y_grid_desc_m_k
,
y_global_val_buf
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
xy_copy_fwd_step_m_k
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
xy_copy_fwd_step_m_k
);
}
// Step 5: update the moving average of mean and variance (optional)
if
(
updateMovingAverage
&&
block_local_id
==
0
&&
thread_k_cluster_id
==
0
)
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
running_mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
running_var_thread_buf
;
auto
running_mean_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultRunningMean
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
auto
running_var_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultRunningVariance
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
auto
threadwise_mean_var_load
=
ThreadwiseTensorSliceTransfer_v2
<
MeanVarDataType
,
AccDataType
,
MeanVarGridDesc_M
,
decltype
(
thread_buffer_desc_m
),
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
MeanVarSrcDstVectorSize
,
1
,
true
>
(
mean_var_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
threadwise_mean_var_load
.
Run
(
mean_var_grid_desc_m
,
running_mean_global_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_mean_thread_buf
);
threadwise_mean_var_load
.
Run
(
mean_var_grid_desc_m
,
running_var_global_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_var_thread_buf
);
AccDataType
oneMinusAverageFactor
=
type_convert
<
AccDataType
>
(
1.0
)
-
averageFactor
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
running_mean_thread_buf
(
I
)
=
running_mean_thread_buf
[
I
]
*
oneMinusAverageFactor
+
mean_thread_buf
[
I
]
*
averageFactor
;
running_var_thread_buf
(
I
)
=
running_var_thread_buf
[
I
]
*
oneMinusAverageFactor
+
var_thread_buf
[
I
]
*
averageFactor
;
});
auto
threadwise_mean_var_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanVarDataType
,
decltype
(
thread_buffer_desc_m
),
MeanVarGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
MeanVarSrcDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
threadwise_mean_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_mean_thread_buf
,
mean_var_grid_desc_m
,
running_mean_global_buf
);
threadwise_mean_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_var_thread_buf
,
mean_var_grid_desc_m
,
running_var_global_buf
);
};
// Step 6: save mean and inv-variance (optional)
if
(
saveMeanInvVariance
&&
block_local_id
==
0
&&
thread_k_cluster_id
==
0
)
{
auto
result_mean_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultSaveMean
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
auto
result_inv_var_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultSaveInvVariance
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
// calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
1.0
f
)
/
sqrt
(
epsilon
+
var_thread_buf
[
I
]);
});
auto
threadwise_mean_inv_var_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanVarDataType
,
decltype
(
thread_buffer_desc_m
),
MeanVarGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
MeanVarSrcDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
threadwise_mean_inv_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
mean_thread_buf
,
mean_var_grid_desc_m
,
result_mean_global_buf
);
threadwise_mean_inv_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
var_thread_buf
,
mean_var_grid_desc_m
,
result_inv_var_global_buf
);
};
}
};
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -161,7 +161,7 @@ struct GridwiseMultiblockWelfordFirstHalf
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
0
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
...
...
@@ -180,7 +180,7 @@ struct GridwiseMultiblockWelfordFirstHalf
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
0
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
...
...
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp
→
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final
_obsolete
.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -33,7 +33,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
const
MeanVarGridDesc_M
mean_var_grid_desc_m
,
index_t
blkgroup_size
,
index_t
num_xy_k_block_tile_iteration
,
index_t
num_mean_var_count_k_block_tile_iteration
,
AccDataType
epsilon
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_mean
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_variance
,
...
...
@@ -59,7 +58,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
mean_var_grid_desc_m
,
blkgroup_size
,
num_xy_k_block_tile_iteration
,
num_mean_var_count_k_block_tile_iteration
,
epsilon
,
p_in_welford_mean
,
p_in_welford_variance
,
...
...
@@ -152,7 +150,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
const
MeanVarGridDesc_M
&
mean_var_grid_desc_m
,
index_t
blkgroup_size
,
index_t
num_xy_k_block_tile_iteration
,
index_t
num_mean_var_count_k_block_tile_iteration
,
AccDataType
epsilon
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_mean
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_variance
,
...
...
@@ -223,7 +220,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
0
,
1
,
1
,
true
>
(
...
...
@@ -239,7 +236,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
0
,
1
,
1
,
true
>
(
...
...
@@ -257,9 +254,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
const
auto
welford_count_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_welford_count
,
mean_var_count_grid_desc_m_k
.
GetElementSpaceSize
());
constexpr
auto
mean_var_count_thread_copy_step_m_k
=
make_multi_index
(
0
,
KThreadClusterSize
*
1
);
// Step 1: do final welford reduction to get mean and variance
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
...
...
@@ -268,8 +262,11 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
welford_count_thread_buf
(
I
)
=
0
;
});
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_mean_var_count_k_block_tile_iteration
;
++
reducedTiles
)
constexpr
auto
mean_var_count_thread_copy_step_m_k
=
make_multi_index
(
0
,
KThreadClusterSize
);
int32_t
reducedSize
=
0
;
while
(
reducedSize
<
blkgroup_size
)
{
threadwise_mean_var_load_m_k
.
Run
(
mean_var_count_grid_desc_m_k
,
welford_mean_global_val_buf
,
...
...
@@ -296,6 +293,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
welford_var_thread_buf
,
welford_count_thread_buf
);
reducedSize
+=
KThreadClusterSize
;
threadwise_mean_var_load_m_k
.
MoveSrcSliceWindow
(
mean_var_count_grid_desc_m_k
,
mean_var_count_thread_copy_step_m_k
);
threadwise_count_load_m_k
.
MoveSrcSliceWindow
(
mean_var_count_grid_desc_m_k
,
...
...
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -109,30 +109,57 @@ struct BlockToCTileMap_M00_N0_M01
// Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template
<
index_t
MPerBlock
,
index_t
NPerBlock
,
typename
CGridDesc_M_N
>
struct
BlockToCTileMap_M00_N0_M01Adapt
template
<
index_t
MPerBlock
,
index_t
NPerBlock
,
typename
CGridDesc_M_N
=
void
>
struct
BlockToCTileMap_M00_N0_M01Adapt
;
template
<
index_t
MPerBlock
,
index_t
NPerBlock
>
struct
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
void
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
__host__
__device__
BlockToCTileMap_M00_N0_M01Adapt
()
=
default
;
__host__
__device__
BlockToCTileMap_M00_N0_M01Adapt
(
const
BlockToCTileMap_M00_N0_M01Adapt
&
)
=
default
;
__host__
__device__
BlockToCTileMap_M00_N0_M01Adapt
(
BlockToCTileMap_M00_N0_M01Adapt
&&
)
=
default
;
__host__
__device__
BlockToCTileMap_M00_N0_M01Adapt
&
operator
=
(
const
BlockToCTileMap_M00_N0_M01Adapt
&
)
=
default
;
__host__
__device__
BlockToCTileMap_M00_N0_M01Adapt
&
operator
=
(
BlockToCTileMap_M00_N0_M01Adapt
&&
)
=
default
;
__host__
__device__
BlockToCTileMap_M00_N0_M01Adapt
(
index_t
M
,
index_t
N
,
index_t
M01
=
8
)
:
M_
(
M
),
N_
(
N
),
M01_
(
M01
)
{
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
BlockToCTileMap_M00_N0_M01Adapt
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
=
8
)
:
M01_
(
M01
),
c_grid_desc_m_n_
(
c_grid_desc_m_n
)
:
BlockToCTileMap_M00_N0_M01Adapt
(
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
),
M01
)
{
}
__host__
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
__host__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
const
auto
M0
=
math
::
integer_divide_ceil
(
c_grid_desc_m_n
.
GetLength
(
I0
)
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
c_grid_desc_m_n
.
GetLength
(
I1
)
,
NPerBlock
);
const
auto
M0
=
math
::
integer_divide_ceil
(
M
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
N
,
NPerBlock
);
const
index_t
grid_size
=
M0
*
N0
;
return
M0
*
N0
;
}
return
grid_size
;
template
<
typename
CGridDesc_M_N
>
__host__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
return
CalculateGridSize
(
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
));
}
template
<
typename
CGridDesc_M_N
>
__host__
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
return
true
;
}
template
<
typename
TopIdx
>
...
...
@@ -140,8 +167,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt
{
auto
block_1d_id
=
idx_top
[
I0
];
const
auto
M0
=
math
::
integer_divide_ceil
(
c_grid_desc_m_n_
.
GetLength
(
I0
)
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
c_grid_desc_m_n_
.
GetLength
(
I1
)
,
NPerBlock
);
const
auto
M0
=
math
::
integer_divide_ceil
(
M_
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
N_
,
NPerBlock
);
block_1d_id
=
block_1d_id
%
(
M0
*
N0
);
// swallow batch index
...
...
@@ -154,6 +181,50 @@ struct BlockToCTileMap_M00_N0_M01Adapt
index_t
idx_M01
=
idx_M0
%
M01_
;
index_t
idx_N0_M01_local
=
idx_N0
+
idx_M01
*
N0
;
/**
* idxN0
*
* |< mtx N >|
*
* NPerBlock NPerBlock NPerBlock NPerBlock
* N_0 N_1 N_2 N_3
* - |-----------|-----------|-----------|-----|-----|-
* ^ | - - 0 |/----> 2 | | | |
* | | | / | | | | | M_0 MPerBlock
* | M | /| | | | | |
* |-0---|---/-|-----|-----|-----------|-----|-----|-
* | 1 | / | | | blockid | | |
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
* | - V 1 | - 3 | | | |
* |-----------|-----------|-----------|-----|-----|-
* mtx M | | | | | |
* | | | | | | M_2 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* | | | | | |
* | | | | | | M_3 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* V | | | | | |
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* Example:
* assume:
* M0 = 5
* N0 = 4
* block_1d_id = 5
* M01 = 2
*
* idx_N0 = 1
* idx_M0 = 1
* M01_adapt = 2
* idx_M00 = 0
* idx_M01 = 1
* idx_N0_M01_local = 5
* output {1, 2}
*/
return
make_tuple
(
idx_N0_M01_local
%
M01_adapt
+
idx_M00
*
M01_
,
idx_N0_M01_local
/
M01_adapt
);
}
...
...
@@ -165,11 +236,18 @@ struct BlockToCTileMap_M00_N0_M01Adapt
return
true
;
// always valid provided that user gets grid size from CalculateGridSize()
}
__host__
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
return
true
;
}
private:
index_t
M_
;
index_t
N_
;
index_t
M01_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
};
// keep the redundant type argument for backward compatibility
template
<
index_t
MPerBlock
,
index_t
NPerBlock
,
typename
CGridDesc_M_N
>
struct
BlockToCTileMap_M00_N0_M01Adapt
:
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
void
>
{
using
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
void
>::
BlockToCTileMap_M00_N0_M01Adapt
;
};
// 2D slices of column-vectors in 3D space
...
...
@@ -543,4 +621,52 @@ struct OffsettedBlockToCTileMap
index_t
block_start_
;
};
/**
* @brief Simple tile mapping which creates 3D grid of block of threads.
*
* @paragraph Description
* This Block-to-C-tile-map creates a 3D grid (n_blocks, m_blocks, z_blocks) of thread
* blocks. The first 2D are regular 2D tiles created by division of output GEMM
* dimenions by corresponding tile size. The third dimension (Z) is a k-split dimension,
* which denotes the number of blocks we use to divide work on GEMM K dimension onto.
*
* @tparam MPerBlock Output block tile size in M dimension.
* @tparam NPerBlock Output block tile size in N dimension.
*/
template
<
index_t
MPerBlock
,
index_t
NPerBlock
>
struct
BlockToCTileMap_3DGrid_KSplit
{
__host__
__device__
BlockToCTileMap_3DGrid_KSplit
()
=
default
;
__host__
__device__
constexpr
auto
CalculateGridSize
(
index_t
M
,
index_t
N
,
index_t
k_split
)
const
{
// Create 3D grid
const
auto
M0
=
math
::
integer_divide_ceil
(
M
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
N
,
NPerBlock
);
return
std
::
make_tuple
(
N0
,
M0
,
k_split
);
}
template
<
typename
TopIdx
>
__device__
constexpr
auto
CalculateBottomIndex
(
const
TopIdx
&
)
const
{
return
make_tuple
(
blockIdx
.
z
,
blockIdx
.
y
,
blockIdx
.
x
);
}
template
<
typename
CTileIdx
,
typename
CTileDim
>
__host__
__device__
bool
ValidCTileIndex
(
const
CTileIdx
&
/* c_tile_idx */
,
const
CTileDim
&
/* c_tile_dim */
)
const
{
return
true
;
// always valid provided that user gets grid size from CalculateGridSize()
}
template
<
typename
CGridDesc_M_N
>
__host__
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
return
true
;
}
};
}
// namespace ck
Prev
1
…
23
24
25
26
27
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment