gaoqiong / composable_kernel_ROCM / Commits / defa2071

Commit defa2071, authored Nov 15, 2023 by Adam Osewski

    Merge branch 'develop' into aosewski/ggemm_multi_d2

Parents: 28a68428, f2398f61
Changes: 438
Showing 20 changed files with 1621 additions and 219 deletions (+1621 -219)
- include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp (+10 -10)
- include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp (+10 -10)
- include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp (+24 -2)
- include/ck/tensor_operation/gpu/element/element_wise_operation.hpp (+65 -0)
- include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp (+182 -8)
- include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp (+224 -0)
- include/ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp (+264 -0)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp (+15 -9)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp (+91 -46)
- include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp (+32 -11)
- include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp (+343 -0)
- include/ck/utility/data_type.hpp (+6 -1)
- include/ck/utility/f8_utils.hpp (+98 -49)
- include/ck/utility/inner_product.hpp (+2 -0)
- include/ck/utility/math.hpp (+0 -22)
- include/ck/utility/math_v2.hpp (+184 -8)
- include/ck/utility/statically_indexed_array_multi_index.hpp (+1 -0)
- include/ck/utility/type_convert.hpp (+38 -18)
- library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp (+21 -20)
- library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp (+11 -5)
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp → include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp

@@ -7,7 +7,7 @@
 #include <sstream>

 #include "ck/utility/reduction_operator.hpp"
-#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
+#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
 #include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp"

@@ -46,14 +46,14 @@ template <typename XDataType,
           index_t YDstVectorSize,
           index_t SaveMeanInvStdDstVectorSize,
           bool UseWelford = true>
-struct DeviceNormalizationImpl
-    : public DeviceNormalization<XDataType, GammaDataType, BetaDataType, YDataType,
-                                 SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim>
+struct DeviceNormalizationFwdImpl
+    : public DeviceNormalizationFwd<XDataType, GammaDataType, BetaDataType, YDataType,
+                                    SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim>
 {
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
     static_assert(

@@ -461,7 +461,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceNormalizationImpl<" << BlockSize << ",";
+        str << "DeviceNormalizationFwdImpl<" << BlockSize << ",";
         str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
         str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
         str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp → include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp

@@ -8,7 +8,7 @@
 #include "ck/utility/reduction_operator.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
-#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
+#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
 #include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp"

@@ -134,14 +134,14 @@ template <typename XDataType,
           index_t BetaSrcVectorSize,
           index_t YDstVectorSize,
           index_t SaveMeanInvStdDstVectorSize>
-struct DeviceNormalizationSplitKImpl
-    : public DeviceNormalization<XDataType, GammaDataType, BetaDataType, YDataType,
-                                 SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim>
+struct DeviceNormalizationFwdSplitKImpl
+    : public DeviceNormalizationFwd<XDataType, GammaDataType, BetaDataType, YDataType,
+                                    SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim>
 {
     using WorkspaceMeanVarDataType = SaveMeanInvStdDataType;

@@ -732,7 +732,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceNormalizationSplitKImpl<" << BlockSize << ",";
+        str << "DeviceNormalizationFwdSplitKImpl<" << BlockSize << ",";
         str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
         str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
         str << "XYSrcVectorDim_" << XYVectorDim << ",";
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp

@@ -85,10 +85,13 @@ struct Add
 struct ScaleAdd
 {
-    __host__ __device__ ScaleAdd(float scale) : scale_(scale) {}
+    __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}

     template <typename Y, typename X0, typename X1>
-    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
+    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
+    {
+        y = ck::type_convert<Y>(scale_ * ck::type_convert<float>(x0) +
+                                ck::type_convert<float>(x1));
+    }

     template <>
     __host__ __device__ void
 ...

@@ -186,6 +189,25 @@ struct Bilinear
         y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
     };

+    template <>
+    __host__ __device__ constexpr void operator()<bhalf_t, bhalf_t, bhalf_t>(
+        bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
+    {
+        const float x0_tmp = type_convert<float>(x0);
+        const float x1_tmp = type_convert<float>(x1);
+        const float y_tmp  = alpha_ * x0_tmp + beta_ * x1_tmp;
+        y                  = type_convert<bhalf_t>(y_tmp);
+    };
+
+    template <>
+    __host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t>(
+        bhalf_t& y, const float& x0, const bhalf_t& x1) const
+    {
+        const float x1_tmp = ck::type_convert<float>(x1);
+        const float y_tmp  = alpha_ * x0 + beta_ * x1_tmp;
+        y                  = y_tmp;
+    };
+
     template <>
     __host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
         std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
 ...
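The new Bilinear specializations above follow the usual mixed-precision pattern in these functors: widen the inputs to float, blend with alpha/beta, then narrow to the destination type. Below is a minimal standalone sketch of that pattern in plain C++ (no CK headers). The bf16 round-trip here is a simple truncation stand-in for ck::type_convert and the struct only mirrors the shape of the in-tree functor; it is an illustration, not the library implementation.

#include <cstdint>
#include <cstring>
#include <iostream>

// Stand-in for ck::bhalf_t: bf16 stored as the top 16 bits of a float (truncation rounding).
using bhalf = std::uint16_t;

static bhalf float_to_bhalf(float f)
{
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<bhalf>(bits >> 16);
}

static float bhalf_to_float(bhalf h)
{
    const std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// Same shape as the new Bilinear bf16 overload: widen, blend, narrow.
struct Bilinear
{
    float alpha_;
    float beta_;

    void operator()(bhalf& y, const bhalf& x0, const bhalf& x1) const
    {
        const float x0_tmp = bhalf_to_float(x0);
        const float x1_tmp = bhalf_to_float(x1);
        const float y_tmp  = alpha_ * x0_tmp + beta_ * x1_tmp;
        y                  = float_to_bhalf(y_tmp);
    }
};

int main()
{
    const Bilinear op{2.f, 0.5f};
    bhalf y = 0;
    op(y, float_to_bhalf(1.5f), float_to_bhalf(4.0f)); // 2*1.5 + 0.5*4 = 5
    std::cout << bhalf_to_float(y) << "\n";            // prints 5
    return 0;
}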
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp

@@ -311,6 +311,71 @@ struct AddAddFastGelu
     }
 };

+// E = Relu(alpha1 * C + alpha2 * D0 + D1)
+struct ScaleAddScaleAddRelu
+{
+    ScaleAddScaleAddRelu(const float alpha1 = 1.f, const float alpha2 = 1.f)
+        : alpha1_(alpha1), alpha2_(alpha2)
+    {
+    }
+
+    template <typename E, typename C, typename D0, typename D1>
+    __host__ __device__ constexpr void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
+
+    template <>
+    __host__ __device__ constexpr void operator()<float, float, float, float>(
+        float& e, const float& c, const float& d0, const float& d1) const
+    {
+        const float x = c * alpha1_ + alpha2_ * d0 + d1;
+        Relu{}.template operator()<float>(e, x);
+    }
+
+    template <>
+    __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
+        half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
+    {
+        const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
+                        type_convert<float>(d1);
+        float result = 0;
+        Relu{}.template operator()<float>(result, x);
+        e = type_convert<half_t>(result);
+    }
+
+    template <>
+    __host__ __device__ constexpr void operator()<bhalf_t, bhalf_t, bhalf_t, bhalf_t>(
+        bhalf_t& e, const bhalf_t& c, const bhalf_t& d0, const bhalf_t& d1) const
+    {
+        const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
+                        type_convert<float>(d1);
+        float result = 0;
+        Relu{}.template operator()<float>(result, x);
+        e = type_convert<bhalf_t>(result);
+    }
+
+    template <>
+    __host__ __device__ constexpr void operator()<int8_t, int8_t, float, float>(
+        int8_t& e, const int8_t& c, const float& d0, const float& d1) const
+    {
+        const float x = type_convert<float>(c) * alpha1_ + alpha2_ * d0 + d1;
+        float result  = 0;
+        Relu{}.template operator()<float>(result, x);
+        e = type_convert<int8_t>(result);
+    }
+
+    const float alpha1_;
+    const float alpha2_;
+};
+
 struct Normalize
 {
     // FIXME: is double absolutely necessary?
 ...
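The new ScaleAddScaleAddRelu epilogue computes E = Relu(alpha1 * C + alpha2 * D0 + D1), widening to float for the half/bf16/int8 paths. A minimal standalone sketch of the float path in plain C++ follows; it mirrors the functor's math only and is not the CK API (no ck:: types, Relu replaced by std::max).

#include <algorithm>
#include <iostream>

// Mirrors the float specialization above: E = Relu(alpha1 * C + alpha2 * D0 + D1)
struct ScaleAddScaleAddRelu
{
    float alpha1_;
    float alpha2_;

    void operator()(float& e, float c, float d0, float d1) const
    {
        const float x = c * alpha1_ + alpha2_ * d0 + d1;
        e             = std::max(x, 0.f); // Relu
    }
};

int main()
{
    const ScaleAddScaleAddRelu op{1.f, 1.f};
    float e = 0.f;
    op(e, 2.f, -1.f, 0.5f);  // Relu(2 - 1 + 0.5) = 1.5
    std::cout << e << "\n";
    op(e, -2.f, 0.f, 0.5f);  // Relu(-1.5) = 0
    std::cout << e << "\n";
    return 0;
}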
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -16,6 +16,57 @@ namespace element_wise {
 extern "C" __device__ float __ocml_native_recip_f32(float);
 #endif

+struct PassThroughPack2
+{
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const;
+
+    __host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::half2_t& x) const
+    {
+        // fake conversion
+        uint16_t t = ck::bit_cast<uint32_t>(x);
+        y          = ck::bit_cast<ck::f8x2_t>(t);
+    }
+
+    __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const
+    {
+        auto t = type_convert<float2_t>(x);
+        y      = type_convert<half2_t>(t);
+    }
+
+    __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::half2_t& x) const { y = x; }
+    __host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::f8x2_t& x) const { y = x; }
+    __host__ __device__ constexpr void operator()(ck::float2_t& y, const ck::float2_t& x) const { y = x; }
+    __host__ __device__ constexpr void operator()(ck::int8x2_t& y, const ck::int8x2_t& x) const { y = x; }
+    __host__ __device__ constexpr void operator()(ck::bhalf2_t& y, const ck::bhalf2_t& x) const { y = x; }
+    __host__ __device__ constexpr void operator()(ck::double2_t& y, const ck::double2_t& x) const { y = x; }
+
+    constexpr const static bool is_pack2_invocable = true;
+};
+
 struct PassThrough
 {
     template <typename Y, typename X>
 ...

@@ -33,6 +84,12 @@ struct PassThrough
         y = type_convert<float>(x);
     }

+    template <>
+    __host__ __device__ void operator()<double, float>(double& y, const float& x) const
+    {
+        y = type_convert<double>(x);
+    }
+
     template <>
     __host__ __device__ void operator()<float, float>(float& y, const float& x) const
     {
 ...

@@ -69,6 +126,12 @@ struct PassThrough
         y = type_convert<bhalf_t>(x);
     }

+    template <>
+    __host__ __device__ void operator()<float, bhalf_t>(float& y, const bhalf_t& x) const
+    {
+        y = type_convert<float>(x);
+    }
+
     template <>
     __host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
     {
 ...

@@ -207,7 +270,8 @@ struct ConvertF8SR
     __host__ __device__ void operator()(Y& y, const X& x) const
     {
         // check Y datatype
-        static_assert(is_same<Y, f8_t>::value, "Data type is not supported by this operation!");
+        static_assert(is_same<Y, f8_t>::value || is_same<Y, bf8_t>::value,
+                      "Data type is not supported by this operation!");

         // check X datatype
         static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
 ...

@@ -224,6 +288,20 @@ struct Scale
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const;

+    template <>
+    __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
+    {
+        y = ck::type_convert<half_t>(scale_) * x;
+    };
+
+    template <>
+    __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
+    {
+        const float x_tmp = ck::type_convert<float>(x);
+        const float y_tmp = scale_ * x_tmp;
+        y                 = ck::type_convert<bhalf_t>(y_tmp);
+    };
+
     template <>
     __host__ __device__ void operator()<float, float>(float& y, const float& x) const
     {
 ...

@@ -277,8 +355,8 @@ struct UnarySquare
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
-        static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, int32_t> ||
-                          is_same_v<T, int8_t>
+        static_assert(is_same_v<T, float> || is_same_v<T, half_t> || is_same_v<T, double> ||
+                          is_same_v<T, int32_t> || is_same_v<T, int8_t>
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
                           || is_same_v<T, int4_t>
 #endif
 ...

@@ -442,10 +520,11 @@ struct Sigmoid
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value,
+                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");
-        y = 1 / (ck::type_convert<T>(1) + exp(-x));
+        constexpr T one = type_convert<T>(1);
+        y               = one / (one + ck::math::exp(-x));
     };
 };
 ...

@@ -455,7 +534,8 @@ struct TanH
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value,
+                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

         y = ck::math::tanh(x);
 ...

@@ -481,7 +561,101 @@ struct Swish
         y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
     };

-    float beta_ = 1.0f;
+    const float beta_;
 };

+struct SoftRelu
+{
+    SoftRelu(float alpha = 1.f) : alpha_(alpha){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha  = type_convert<T>(alpha_);
+        constexpr T one = type_convert<T>(1);
+        y               = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
+    }
+    const float alpha_;
+};
+
+struct Power
+{
+    Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
+        : alpha_(alpha), beta_(beta), gamma_(gamma){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha     = type_convert<T>(alpha_);
+        T casted_beta      = type_convert<T>(beta_);
+        T casted_gamma     = type_convert<T>(gamma_);
+        T shifted_scaled_x = casted_alpha + casted_beta * x;
+        y                  = ck::math::pow(shifted_scaled_x, casted_gamma);
+    }
+    const float alpha_;
+    const float beta_;
+    const float gamma_;
+};
+
+struct ClippedRelu
+{
+    ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha = type_convert<T>(alpha_);
+        T casted_beta  = type_convert<T>(beta_);
+        y              = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
+    }
+    const float alpha_;
+    const float beta_;
+};
+
+struct LeakyRelu
+{
+    LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha = type_convert<T>(alpha_);
+        y              = x >= 0 ? x : x * casted_alpha;
+    }
+    const float alpha_;
+};
+
+struct Elu
+{
+    Elu(float alpha = 1.f) : alpha_(alpha){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha = type_convert<T>(alpha_);
+        y              = x > 0 ? x : casted_alpha * ck::math::expm1(x);
+    }
+    const float alpha_;
+};
+
 } // namespace element_wise
 ...
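For reference, the math behind three of the activations introduced above, written as plain float functions in standalone C++. This is a sketch only; the in-tree versions are templated over T and use ck::math.

#include <cmath>
#include <iostream>

// SoftRelu: y = log(1 + exp(alpha * x)) / alpha
float soft_relu(float x, float alpha = 1.f) { return std::log(1.f + std::exp(x * alpha)) / alpha; }

// Elu: y = x for x > 0, alpha * (exp(x) - 1) otherwise (expm1 keeps precision near 0)
float elu(float x, float alpha = 1.f) { return x > 0.f ? x : alpha * std::expm1(x); }

// ClippedRelu: y = min(beta, max(alpha, x))
float clipped_relu(float x, float alpha = 0.f, float beta = 1.f)
{
    return std::fmin(beta, std::fmax(alpha, x));
}

int main()
{
    std::cout << soft_relu(0.f) << "\n";    // log(2) ~ 0.693
    std::cout << elu(-1.f) << "\n";         // ~ -0.632
    std::cout << clipped_relu(2.f) << "\n"; // 1
    return 0;
}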
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseElementwise1dFunctor,
          typename InGrid1dDescTuple,
          typename OutGrid1dDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename ElementwiseOperation,
          typename UnaryOperation,
          typename Scale>
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple,
                                      const OutGrid1dDescTuple out_grid_1d_desc_tuple,
                                      const InDataTypePointerTuple p_in_global_tuple,
                                      const OutDataTypePointerTuple p_out_global_tuple,
                                      const ElementwiseOperation elementwise_op,
                                      const UnaryOperation unary_op,
                                      const Scale scale_op)
{
    GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple,
                                      out_grid_1d_desc_tuple,
                                      p_in_global_tuple,
                                      p_out_global_tuple,
                                      elementwise_op,
                                      unary_op,
                                      scale_op);
}

template <typename InGrid1dDescTuple,
          typename OutGrid1dDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename ElementwiseOperation,
          typename UnaryOperation,
          typename Scale,
          index_t MPerThread,
          typename InScalarPerVectorSeq,
          typename OutScalarPerVectorSeq>
struct GridwiseElementwise_1D
{
    static constexpr index_t NumInput  = InDataTypePointerTuple::Size();
    static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();

    static_assert(NumInput == InScalarPerVectorSeq::Size() &&
                      NumOutput == OutScalarPerVectorSeq::Size() &&
                      NumInput == InGrid1dDescTuple::Size() &&
                      NumOutput == OutGrid1dDescTuple::Size(),
                  "Tuple size is inconsistent with the number of in/out!");

    static constexpr auto I0 = Number<0>{};

    static constexpr auto thread_buffer_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    __device__ static void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple,
                               const OutGrid1dDescTuple out_grid_1d_desc_tuple,
                               const InDataTypePointerTuple p_in_global_tuple,
                               const OutDataTypePointerTuple p_out_global_tuple,
                               const ElementwiseOperation elementwise_op,
                               const UnaryOperation unary_op,
                               const Scale scale_op)
    {
        const index_t thread_global_id = get_thread_global_1d_id();

        auto in_thread_buf_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;

                return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
            },
            Number<NumInput>{});

        auto out_thread_buf_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
                using DataType        = remove_pointer_t<DataTypePointer>;

                return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
            },
            Number<NumOutput>{});

        auto in_global_buf_tuple = generate_tuple(
            [&](auto I) {
                static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);

                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumInput>{});

        auto out_global_buf_tuple = generate_tuple(
            [&](auto I) {
                static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);

                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumOutput>{});

        const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread);

        const index_t blockSize    = get_block_size();
        const index_t blockPerGrid = get_grid_size();
        const auto M               = in_grid_1d_desc_tuple[I0].GetLength(I0);
        const index_t loop_step    = blockPerGrid * blockSize * MPerThread;
        const auto loop_step_index = make_multi_index(loop_step);

        auto in_global_load_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;

                return ThreadwiseTensorSliceTransfer_v2<DataType,
                                                        DataType,
                                                        decltype(in_grid_1d_desc_tuple[I]),
                                                        decltype(thread_buffer_desc_m),
                                                        Sequence<MPerThread>,        // SliceLengths
                                                        Sequence<0>,                 // DimAccessOrder
                                                        0,                           // SrcVectorDim
                                                        InScalarPerVectorSeq::At(I), // ScalarPerVector
                                                        1, // SrcScalarStrideInVector
                                                        false>{in_grid_1d_desc_tuple[I],
                                                               thread_global_offset};
            },
            Number<NumInput>{});

        auto out_global_store_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
                using DataType        = remove_pointer_t<DataTypePointer>;

                return ThreadwiseTensorSliceTransfer_v1r3<DataType,
                                                          DataType,
                                                          decltype(thread_buffer_desc_m),
                                                          decltype(out_grid_1d_desc_tuple[I]),
                                                          PassThroughOp,
                                                          Sequence<MPerThread>, // SliceLengths
                                                          Sequence<0>,          // DimAccessOrder
                                                          0,                    // SrcVectorDim
                                                          OutScalarPerVectorSeq::At(I),
                                                          InMemoryDataOperationEnum::Set,
                                                          1,
                                                          false>(out_grid_1d_desc_tuple[I],
                                                                 thread_global_offset,
                                                                 PassThroughOp{});
            },
            Number<NumOutput>{});

        index_t num_iter = M / (loop_step);
        do
        {
            static_for<0, NumInput, 1>{}([&](auto I) {
                in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I],
                                            in_global_buf_tuple[I],
                                            thread_buffer_desc_m,
                                            make_tuple(I0),
                                            in_thread_buf_tuple(I));

                in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I],
                                                           loop_step_index);
            });

            static_for<0, MPerThread, 1>{}([&](auto iM) {
                // get reference to in data
                auto uop_data_refs = generate_tie(
                    // return type should be lvalue
                    [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
                    Number<NumInput>{});

                // get reference to dst data
                auto out_data_refs = generate_tie(
                    // return type should be lvalue
                    [&](auto I) -> auto& { return out_thread_buf_tuple(I)(iM); },
                    Number<NumOutput>{});

                unpack2(unary_op, uop_data_refs, uop_data_refs);

                auto sop_in_data_refs = generate_tie(
                    // return type should be lvalue
                    [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
                    Number<NumInput>{});

                auto sop_out_data_refs = generate_tie(
                    // return type should be lvalue
                    [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
                    Number<NumInput>{});

                unpack2(scale_op, sop_out_data_refs, sop_in_data_refs);

                const auto in_data_refs = generate_tie(
                    // return type should be lvalue
                    [&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); },
                    Number<NumInput>{});

                unpack2(elementwise_op, out_data_refs, in_data_refs);
            });

            static_for<0, NumOutput, 1>{}([&](auto I) {
                out_global_store_tuple(I).Run(thread_buffer_desc_m,
                                              make_tuple(I0),
                                              out_thread_buf_tuple[I],
                                              out_grid_1d_desc_tuple[I],
                                              out_global_buf_tuple(I));

                out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I],
                                                             loop_step_index);
            });
        } while(--num_iter);
    }
};

} // namespace ck
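The 1D kernel above walks the M elements grid-stride style: each thread owns MPerThread contiguous elements and the whole grid advances by blockPerGrid * blockSize * MPerThread per iteration. A small host-side sketch of just that index arithmetic follows; the sizes are assumed example values, not taken from the file.

#include <cstdio>

int main()
{
    // Assumed example sizes for illustration only (M must be a multiple of loop_step here).
    const int M            = 4096;
    const int MPerThread   = 4;
    const int blockSize    = 64;
    const int blockPerGrid = 4;

    const int loop_step = blockPerGrid * blockSize * MPerThread; // 1024

    // Mirror of: thread_global_offset = thread_global_id * MPerThread, advanced by loop_step.
    const int thread_global_id = 70; // some thread in the grid
    int offset                 = thread_global_id * MPerThread;

    int num_iter = M / loop_step; // same do/while trip count as the kernel
    do
    {
        std::printf("iteration covers elements [%d, %d)\n", offset, offset + MPerThread);
        offset += loop_step;
    } while(--num_iter);

    return 0;
}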
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
//

#pragma once

#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseElementwise3dFunctor,
          typename InGrid3dDescTuple,
          typename OutGrid3dDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename ElementwiseOperation>
__global__ void kernel_elementwise_3d(const InGrid3dDescTuple in_grid_3d_desc_tuple,
                                      const OutGrid3dDescTuple out_grid_3d_desc_tuple,
                                      const InDataTypePointerTuple p_in_global_tuple,
                                      const OutDataTypePointerTuple p_out_global_tuple,
                                      const ElementwiseOperation elementwise_op,
                                      const index_t num_threads_m,
                                      const index_t num_threads_n,
                                      const index_t num_threads_k)
{
    GridwiseElementwise3dFunctor::Run(in_grid_3d_desc_tuple,
                                      out_grid_3d_desc_tuple,
                                      p_in_global_tuple,
                                      p_out_global_tuple,
                                      elementwise_op,
                                      num_threads_m,
                                      num_threads_n,
                                      num_threads_k);
}

template <typename InGrid3dDescTuple,
          typename OutGrid3dDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename ElementwiseOperation,
          index_t MPerThread,
          index_t NPerThread,
          index_t KPerThread,
          typename InScalarPerVectorSeq,
          typename OutScalarPerVectorSeq>
struct GridwiseElementwise_3D
{
    static constexpr index_t NumInput  = InDataTypePointerTuple::Size();
    static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();

    static_assert(NumInput == InScalarPerVectorSeq::Size() &&
                      NumOutput == OutScalarPerVectorSeq::Size() &&
                      NumInput == InGrid3dDescTuple::Size() &&
                      NumOutput == OutGrid3dDescTuple::Size(),
                  "Tuple size is inconsistent with the number of in/out!");

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr auto thread_buffer_desc_mnk = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MPerThread>{}, Number<NPerThread>{}, Number<KPerThread>{}));

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    __device__ static void Run(const InGrid3dDescTuple in_grid_3d_desc_tuple,
                               const OutGrid3dDescTuple out_grid_3d_desc_tuple,
                               const InDataTypePointerTuple p_in_global_tuple,
                               const OutDataTypePointerTuple p_out_global_tuple,
                               const ElementwiseOperation elementwise_op,
                               const index_t num_threads_m,
                               const index_t num_threads_n,
                               const index_t num_threads_k)
    {
        auto in_thread_buf_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;

                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    DataType,
                                    MPerThread * NPerThread * KPerThread,
                                    true>{};
            },
            Number<NumInput>{});

        auto out_thread_buf_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
                using DataType        = remove_pointer_t<DataTypePointer>;

                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    DataType,
                                    MPerThread * NPerThread * KPerThread,
                                    true>{};
            },
            Number<NumOutput>{});

        auto in_global_buf_tuple = generate_tuple(
            [&](auto I) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_in_global_tuple[I], in_grid_3d_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumInput>{});

        auto out_global_buf_tuple = generate_tuple(
            [&](auto I) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_global_tuple[I], out_grid_3d_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumOutput>{});

        const auto M = in_grid_3d_desc_tuple[I0].GetLength(I0);
        const auto N = in_grid_3d_desc_tuple[I0].GetLength(I1);
        const auto K = in_grid_3d_desc_tuple[I0].GetLength(I2);

        const index_t loop_step_m = num_threads_m * MPerThread;
        const index_t loop_step_n = num_threads_n * NPerThread;
        const index_t loop_step_k = num_threads_k * KPerThread;

        const index_t thread_1d_id = get_thread_global_1d_id();
        const index_t tid_m        = thread_1d_id / (num_threads_n * num_threads_k);
        const index_t tid_nk       = thread_1d_id % (num_threads_n * num_threads_k);
        const index_t tid_n        = tid_nk / num_threads_k;
        const index_t tid_k        = tid_nk % num_threads_k;

        const auto thread_global_offset =
            make_multi_index(tid_m * MPerThread, tid_n * NPerThread, tid_k * KPerThread);

        auto in_global_load_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;

                return ThreadwiseTensorSliceTransfer_v2<
                    DataType,
                    DataType,
                    decltype(in_grid_3d_desc_tuple[I]),
                    decltype(thread_buffer_desc_mnk),
                    Sequence<MPerThread, NPerThread, KPerThread>, // SliceLengths
                    Sequence<0, 1, 2>,                            // DimAccessOrder
                    01,                                           // SrcVectorDim
                    InScalarPerVectorSeq::At(I), // InScalarPerVectorSeq::At(I), // ScalarPerVector
                    1,                           // SrcScalarStrideInVector
                    true>{in_grid_3d_desc_tuple[I], thread_global_offset};
            },
            Number<NumInput>{});

        auto out_global_store_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
                using DataType        = remove_pointer_t<DataTypePointer>;

                return ThreadwiseTensorSliceTransfer_v1r3<
                    DataType,
                    DataType,
                    decltype(thread_buffer_desc_mnk),
                    decltype(out_grid_3d_desc_tuple[I]),
                    PassThroughOp,
                    Sequence<MPerThread, NPerThread, KPerThread>, // SliceLengths
                    Sequence<0, 1, 2>,                            // DimAccessOrder
                    2,                                            // SrcVectorDim
                    OutScalarPerVectorSeq::At(I),                 // OutScalarPerVectorSeq::At(I),
                    InMemoryDataOperationEnum::Set,
                    1,
                    true>(out_grid_3d_desc_tuple[I], thread_global_offset, PassThroughOp{});
            },
            Number<NumOutput>{});

        index_t num_iter_m = M / (loop_step_m);
        do
        {
            index_t num_iter_n = N / (loop_step_n);
            do
            {
                index_t num_iter_k = K / (loop_step_k);
                do
                {
                    static_for<0, NumInput, 1>{}([&](auto I) {
                        in_global_load_tuple(I).Run(in_grid_3d_desc_tuple[I],
                                                    in_global_buf_tuple[I],
                                                    thread_buffer_desc_mnk,
                                                    make_tuple(I0, I0, I0),
                                                    in_thread_buf_tuple(I));

                        in_global_load_tuple(I).MoveSrcSliceWindow(
                            in_grid_3d_desc_tuple[I], make_multi_index(0, 0, loop_step_k));
                    });

                    static_for<0, MPerThread, 1>{}([&](auto iM) {
                        static_for<0, NPerThread, 1>{}([&](auto iN) {
                            static_for<0, KPerThread, 1>{}([&](auto iK) {
                                constexpr auto offset =
                                    thread_buffer_desc_mnk.CalculateOffset(make_tuple(iM, iN, iK));

                                // get reference to in data
                                const auto in_data_refs = generate_tie(
                                    // return type should be lvalue
                                    [&](auto I) -> const auto& {
                                        return in_thread_buf_tuple(I)(Number<offset>{});
                                    },
                                    Number<NumInput>{});

                                // get referenec to dst data
                                auto out_data_refs = generate_tie(
                                    // return type should be lvalue
                                    [&](auto I) -> auto& {
                                        return out_thread_buf_tuple(I)(Number<offset>{});
                                    },
                                    Number<NumOutput>{});

                                unpack2(elementwise_op, out_data_refs, in_data_refs);
                            });
                        });
                    });

                    static_for<0, NumOutput, 1>{}([&](auto I) {
                        out_global_store_tuple(I).Run(thread_buffer_desc_mnk,
                                                      make_tuple(I0, I0, I0),
                                                      out_thread_buf_tuple[I],
                                                      out_grid_3d_desc_tuple[I],
                                                      out_global_buf_tuple(I));

                        out_global_store_tuple(I).MoveDstSliceWindow(
                            out_grid_3d_desc_tuple[I], make_multi_index(0, 0, loop_step_k));
                    });
                } while(--num_iter_k);

                static_for<0, NumInput, 1>{}([&](auto I) {
                    in_global_load_tuple(I).MoveSrcSliceWindow(
                        in_grid_3d_desc_tuple[I],
                        make_multi_index(0, loop_step_n, -(K / loop_step_k) * loop_step_k));
                });

                static_for<0, NumOutput, 1>{}([&](auto I) {
                    out_global_store_tuple(I).MoveDstSliceWindow(
                        out_grid_3d_desc_tuple[I],
                        make_multi_index(0, loop_step_n, -(K / loop_step_k) * loop_step_k));
                });
            } while(--num_iter_n);

            static_for<0, NumInput, 1>{}([&](auto I) {
                in_global_load_tuple(I).MoveSrcSliceWindow(
                    in_grid_3d_desc_tuple[I],
                    make_multi_index(loop_step_m,
                                     -(N / loop_step_n) * loop_step_n,
                                     -(K / loop_step_k) * loop_step_k));
            });

            static_for<0, NumOutput, 1>{}([&](auto I) {
                out_global_store_tuple(I).MoveDstSliceWindow(
                    out_grid_3d_desc_tuple[I],
                    make_multi_index(loop_step_m,
                                     -(N / loop_step_n) * loop_step_n,
                                     -(K / loop_step_k) * loop_step_k));
            });
        } while(--num_iter_m);
    }
};

} // namespace ck
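GridwiseElementwise_3D splits the flat thread id into (tid_m, tid_n, tid_k) coordinates exactly as in the Run() body above, with m the slowest-varying dimension and k the fastest. Below is a standalone sketch of that decomposition only; the thread counts are assumed example values.

#include <cstdio>

int main()
{
    // Assumed example launch configuration.
    const int num_threads_m = 2;
    const int num_threads_n = 4;
    const int num_threads_k = 8;

    for(int thread_1d_id = 0; thread_1d_id < num_threads_m * num_threads_n * num_threads_k;
        ++thread_1d_id)
    {
        // Same arithmetic as the kernel.
        const int tid_m  = thread_1d_id / (num_threads_n * num_threads_k);
        const int tid_nk = thread_1d_id % (num_threads_n * num_threads_k);
        const int tid_n  = tid_nk / num_threads_k;
        const int tid_k  = tid_nk % num_threads_k;

        std::printf("thread %2d -> (m=%d, n=%d, k=%d)\n", thread_1d_id, tid_m, tid_n, tid_k);
    }
    return 0;
}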
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp

@@ -203,7 +203,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
     // A desc for source in blockwise copy
     template <typename AGridDesc_M_K>
     __host__ __device__ static constexpr auto
-    MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
+    MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
     {
         const auto M = a_grid_desc_m_k.GetLength(I0);
         const auto K = a_grid_desc_m_k.GetLength(I1);
 ...

@@ -219,17 +219,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
     template <typename AsGridDesc_M_K>
     __host__ __device__ static constexpr auto
-    MakeAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
+    MakeDefaultAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
     {
         return generate_tuple(
-            [&](auto i) { return MakeAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); },
+            [&](auto i) { return MakeDefaultAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); },
             Number<NumATensor>{});
     }

     // B desc for source in blockwise copy
     template <typename BGridDesc_N_K>
     __host__ __device__ static constexpr auto
-    MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
+    MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
     {
         const auto N = b_grid_desc_n_k.GetLength(I0);
         const auto K = b_grid_desc_n_k.GetLength(I1);
 ...

@@ -245,10 +245,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
     template <typename BsGridDesc_N_K>
     __host__ __device__ static constexpr auto
-    MakeBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
+    MakeDefaultBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
     {
         return generate_tuple(
-            [&](auto i) { return MakeBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); },
+            [&](auto i) { return MakeDefaultBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); },
             Number<NumBTensor>{});
     }
 ...

@@ -288,7 +288,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
     // return block_id to E matrix tile idx (m0, n0) mapping
     template <typename EGridDesc_M_N>
     __host__ __device__ static constexpr auto
-    MakeBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
+    MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
     {
         return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
             e_grid_desc_m_n);
 ...

@@ -591,6 +591,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
             generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
                            Number<NumATensor>{});

+        static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
+                      "Src and Dst ScalarPerVector must be the same");
+
         auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<ThisThreadBlock,
                                                                     AsDataType,
 ...

@@ -619,6 +622,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
             generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
                            Number<NumBTensor>{});

+        static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
+                      "Src and Dst ScalarPerVector must be the same");
+
         auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<ThisThreadBlock,
                                                                     BsDataType,
 ...

@@ -1005,9 +1011,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
         const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);

         // tensor descriptors for block/thread-wise copy
-        const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k);
+        const auto as_grid_desc_ak0_m_ak1 = MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k);

-        const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k);
+        const auto bs_grid_desc_bk0_n_bk1 = MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k);

         const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
             MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n);
 ...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp

@@ -136,7 +136,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         index_t MPadded;
         index_t NPadded;
         index_t KPadded;
-        index_t K0;
+        index_t K0Padded;
         index_t k_batch;

         Argument(const FloatA* p_a_grid_,
 ...

@@ -151,7 +151,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                  index_t MPadded_,
                  index_t NPadded_,
                  index_t KPadded_,
-                 index_t K0_,
+                 index_t K0Padded_,
                  index_t k_batch_)
             : p_a_grid(p_a_grid_),
               p_b_grid(p_b_grid_),
 ...

@@ -165,7 +165,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
               MPadded(MPadded_),
               NPadded(NPadded_),
               KPadded(KPadded_),
-              K0(K0_),
+              K0Padded(K0Padded_),
               k_batch(k_batch_)
         {
         }
 ...

@@ -182,7 +182,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                       << "MP:" << MPadded << ", "
                       << "NP:" << NPadded << ", "
                       << "KP:" << KPadded << ", "
-                      << "K0:" << K0 << ", "
+                      << "K0Padded:" << K0Padded << ", "
                       << "KB:" << k_batch << "}" << std::endl;
         }
     };
 ...

@@ -205,7 +205,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         return math::integer_least_multiple(N, NPerBlock);
     }

-    __host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
+    __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1)
     {
         // k_batch * k0 * k0_per_block * k1
         auto K_t = K_Batch * K0PerBlock * K1;
 ...

@@ -214,8 +214,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
     {
-        auto K0 = CalculateK0(K, K_Batch);
-        return K_Batch * K0 * K1;
+        auto K0Padded = CalculateK0Padded(K, K_Batch);
+        return K_Batch * K0Padded * K1;
     }

     __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M,
 ...

@@ -223,7 +223,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                                                                        index_t K,
                                                                        index_t StrideA,
                                                                        index_t KBatch,
-                                                                       index_t K0,
+                                                                       index_t K0Padded,
                                                                        index_t KPad)
     {
         const auto a_grid_desc_m_k = [&]() {
 ...

@@ -237,21 +237,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             }
         }();

-        const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
-            a_grid_desc_m_k,
-            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
+        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
+        {
+            const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
+                a_grid_desc_m_k,
+                make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));

             // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             return transform_tensor_descriptor(
                 a_grid_desc_m_kpad,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
+                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
                            make_right_pad_transform(M, MPad - M)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+        }
+        else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
+                          GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
+        {
+            // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
+            return transform_tensor_descriptor(
+                a_grid_desc_m_k,
+                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
+                           make_right_pad_transform(M, MPad - M)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
 ...

@@ -259,8 +271,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         else
         {
             return transform_tensor_descriptor(
-                a_grid_desc_m_kpad,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
+                a_grid_desc_m_k,
+                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
                            make_pass_through_transform(M)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
 ...

@@ -272,7 +284,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                                                                        index_t N,
                                                                        index_t StrideB,
                                                                        index_t KBatch,
-                                                                       index_t K0,
+                                                                       index_t K0Padded,
                                                                        index_t KPad)
     {
         const auto b_grid_desc_k_n = [&]() {
 ...

@@ -286,21 +298,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             }
         }();

-        const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
-            b_grid_desc_k_n,
-            make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
+        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
+        {
+            const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
+                b_grid_desc_k_n,
+                make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));

             // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
             return transform_tensor_descriptor(
                 b_grid_desc_kpad_n,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
+                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
                            make_right_pad_transform(N, NPad - N)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+        }
+        else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
+                          GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
+        {
+            // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
+            return transform_tensor_descriptor(
+                b_grid_desc_k_n,
+                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
+                           make_right_pad_transform(N, NPad - N)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
 ...

@@ -308,8 +332,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         else
        {
             return transform_tensor_descriptor(
-                b_grid_desc_kpad_n,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
+                b_grid_desc_k_n,
+                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
 ...

@@ -398,6 +422,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                 return false;
             }
         }
         if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                        GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                        GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
 ...

@@ -410,6 +435,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                           << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                           << std::endl;
 #endif // DEBUG_LOG
                 return false;
             }
         }
+
+        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+        {
+            auto K_t = karg.k_batch * K0PerBlock * K1;
+            if(!(karg.K % K_t == 0))
+            {
+#if DEBUG_LOG
+                std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
+                          << karg.K << " " << __FILE__ << ":" << __LINE__
+                          << ", in function: " << __func__ << std::endl;
+#endif // DEBUG_LOG
+                return false;
+            }
 ...

@@ -478,11 +522,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
         {
 #if DEBUG_LOG
-            std::cout << "Arg N (" << karg.N
-                      << ") value is not a multiple of
-                         CBlockTransferScalarPerVector_NWaveNPerXDL ("
-                      << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
-                      << __LINE__ << ", in function: " << __func__ << std::endl;
+            std::cout << "Arg N (" << karg.N
+                      << ") value is not a multiple of "
+                         "CBlockTransferScalarPerVector_NWaveNPerXDL ("
+                      << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
+                      << __LINE__ << ", in function: " << __func__ << std::endl;
 #endif // DEBUG_LOG
             return false;
 ...

@@ -493,25 +537,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
         {
 #if DEBUG_LOG
-            std::cout << "Arg M (" << karg.M
-                      << ") value is not a multiple of
-                         CBlockTransferScalarPerVector_NWaveNPerXDL ("
-                      << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
-                      << __LINE__ << ", in function: " << __func__ << std::endl;
+            std::cout << "Arg M (" << karg.M
+                      << ") value is not a multiple of "
+                         "CBlockTransferScalarPerVector_NWaveNPerXDL ("
+                      << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
+                      << __LINE__ << ", in function: " << __func__ << std::endl;
 #endif // DEBUG_LOG
             return false;
         }
     }

-        const auto num_k_loop = karg.K0 / K0PerBlock;
+        const auto num_k_loop = karg.K0Padded / K0PerBlock;

         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
 #if DEBUG_LOG
             std::cout << "The number of k loops (" << num_k_loop
                       << ") value is not supported by GridwiseGemm Pipeline."
-                      << " K0: " << karg.K0
+                      << " K0Padded: " << karg.K0Padded
                       << ", K0PerBlock: " << K0PerBlock << " " << __FILE__ << ":" << __LINE__
                       << ", in function: " << __func__ << std::endl;
 #endif // DEBUG_LOG
             return false;
         }
 ...

@@ -521,14 +565,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     __host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
     {
-        const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
-        const index_t KPad = KBatch * K0 * K1;
+        const index_t K0Padded =
+            math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
+        const index_t KPad = KBatch * K0Padded * K1;
         return KPad;
     }

-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded)
     {
-        const index_t num_loop = K0 / K0PerBlock;
+        const index_t num_loop = K0Padded / K0PerBlock;
         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }
 ...

@@ -595,9 +640,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         const FloatB* p_b_grid = karg.p_b_grid;
         FloatC* p_c_grid       = karg.p_c_grid;
         const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
-            karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded);
+            karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded);
         const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
-            karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded);
+            karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded);
         const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
 ...
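The v2r4r2 rename makes explicit that K0 here is a padded quantity: K is rounded up so that KPad = KBatch * K0Padded * K1 with K0Padded a multiple of K0PerBlock, as in GetKPad()/CalculateK0Padded() above. A host-side sketch of that arithmetic follows; the tile sizes are assumed example values and integer_divide_ceil is re-implemented locally.

#include <cstdio>

// Same rounding as math::integer_divide_ceil in the kernel source.
static int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    // Assumed example tile configuration.
    const int K1         = 8;
    const int K0PerBlock = 4;
    const int KBatch     = 2;
    const int K          = 1000; // raw GEMM K

    // Mirrors GetKPad(): K0Padded is a multiple of K0PerBlock, KPad = KBatch * K0Padded * K1.
    const int K0Padded = integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
    const int KPad     = KBatch * K0Padded * K1;

    std::printf("K=%d -> K0Padded=%d, KPad=%d, num_k_loop=%d\n",
                K, K0Padded, KPad, K0Padded / K0PerBlock);
    return 0;
}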
include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp

@@ -21,6 +21,7 @@ template <typename InputGridDesc,
           typename OutputGridDesc,
           typename OutputDataType,
           typename Block2ETileMap,
+          typename ComputePtrOffsetOfStridedBatch,
           typename GridwiseTensorRearrangeKernel>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 ...

@@ -30,13 +31,20 @@ __global__ void
                         const InputDataType* __restrict__ p_in_global,
                         const OutputGridDesc out_grid_desc,
                         OutputDataType* __restrict__ p_out_global,
-                        const Block2ETileMap block_2_tile_map)
+                        const index_t batch_count,
+                        const Block2ETileMap block_2_tile_map,
+                        const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
     defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
     defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
-    GridwiseTensorRearrangeKernel::Run(
-        in_grid_desc, p_in_global, out_grid_desc, p_out_global, block_2_tile_map);
+    GridwiseTensorRearrangeKernel::Run(in_grid_desc,
+                                       p_in_global,
+                                       out_grid_desc,
+                                       p_out_global,
+                                       batch_count,
+                                       block_2_tile_map,
+                                       compute_ptr_offset_of_batch);
 #else
     ignore = in_grid_desc;
     ignore = p_in_global;
 ...

@@ -56,7 +64,8 @@ template <typename InputGridDesc,
           typename ThreadClusterLengths,
           index_t ScalarPerVector,
           InMemoryDataOperationEnum DstInMemOp,
-          typename Block2ETileMap>
+          typename Block2ETileMap,
+          typename ComputePtrOffsetOfStridedBatch>
 struct GridwiseTensorRearrange
 {
 ...

@@ -69,7 +78,9 @@ struct GridwiseTensorRearrange
                                const InputDataType* __restrict__ p_in_global,
                                const OutputGridDesc& out_grid_desc,
                                OutputDataType* __restrict__ p_out_global,
-                               const Block2ETileMap& block_2_tile_map)
+                               const index_t batch_count,
+                               const Block2ETileMap& block_2_tile_map,
+                               const ComputePtrOffsetOfStridedBatch& compute_ptr_offset_of_batch)
     {
         const auto block_work_idx =
             block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
 ...

@@ -80,12 +91,6 @@ struct GridwiseTensorRearrange
         const index_t k_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock);

-        // Global Memory
-        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_in_global, in_grid_desc.GetElementSpaceSize());
-        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_out_global, out_grid_desc.GetElementSpaceSize());
-
         auto copy_global_to_global =
             ThreadGroupTensorSliceTransfer_v7<ThisThreadBlock,
                                               Tuple<InputDataType>,
 ...

@@ -108,6 +113,22 @@ struct GridwiseTensorRearrange
                 make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
                 tensor_operation::element_wise::PassThrough{}};

+        const index_t num_blocks_per_batch =
+            __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
+        const index_t g_idx =
+            __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
+
+        // Global Memory
+        const index_t a_batch_offset =
+            __builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+        const index_t c_batch_offset =
+            __builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
+
+        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize());
+        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_out_global + c_batch_offset, out_grid_desc.GetElementSpaceSize());
+
         copy_global_to_global.Run(
             tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf));
     }
 ...
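The change above makes the rearrange copy batched: the grid is split evenly across batch_count batches, each block derives its batch index g_idx from its block id, and the global pointers are shifted by per-batch offsets. A standalone sketch of that mapping follows; the ComputePtrOffsetOfStridedBatch here is a simplified stand-in (plain stride multiplies), not the library type, and the launch sizes are assumed.

#include <cstdio>

// Stand-in for the ComputePtrOffsetOfStridedBatch argument: per-batch offsets are
// just batch_stride * g_idx for illustration.
struct ComputePtrOffsetOfStridedBatch
{
    long a_batch_stride;
    long c_batch_stride;

    long GetAPtrOffset(int g_idx) const { return a_batch_stride * g_idx; }
    long GetCPtrOffset(int g_idx) const { return c_batch_stride * g_idx; }
};

int main()
{
    // Assumed example launch: 12 blocks covering 3 batches.
    const int grid_size   = 12;
    const int batch_count = 3;

    const ComputePtrOffsetOfStridedBatch offsets{1024, 2048};

    const int num_blocks_per_batch = grid_size / batch_count; // 4, as in the kernel

    for(int block_id = 0; block_id < grid_size; ++block_id)
    {
        const int g_idx = block_id / num_blocks_per_batch;
        std::printf("block %2d -> batch %d, a_offset %ld, c_offset %ld\n",
                    block_id, g_idx, offsets.GetAPtrOffset(g_idx), offsets.GetCPtrOffset(g_idx));
    }
    return 0;
}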
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp
0 → 100644
View file @
defa2071
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
namespace ck {

// dgamma = reduce_sum(dy * (x - mean) * inv_std)
// dbeta  = reduce_sum(dy)
template <typename DYDataType,
          typename XDataType,
          typename MeanInvStdDataType,
          typename ComputeDataType,
          typename DGammaDataType,
          typename DBetaDataType,
          typename GridDesc_M_K,
          typename GridDesc_M,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t DYSrcVectorDim,
          index_t DYSrcVectorSize,
          index_t XSrcVectorDim,
          index_t XSrcVectorSize,
          index_t MeanInvStdSrcVectorDim,
          index_t MeanInvStdSrcVectorSize,
          index_t DGammaDstVectorSize,
          index_t DBetaDstVectorSize>
struct GridwiseNormalizationBwdGammaBeta_mk_to_k
{
    // if we just check ThreadSliceSize & VectorSize == 0, the performance may be poor
    static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
                   (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
                  "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");

    static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) ||
                   (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
                  "Invalid thread slice sizes and/or x vector sizes configuration, please check!");

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using DYThreadBufferDimAccessOrder =
        typename conditional<DYSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;

    using XThreadBufferDimAccessOrder =
        typename conditional<XSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;

    using MeanInvStdThreadBufferDimAccessOrder =
        typename conditional<MeanInvStdSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
    using ThreadBufferLengths_M   = Sequence<MThreadSliceSize>;

    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

    static constexpr auto thread_buffer_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    using BlockwiseSumReduce = PartitionedBlockwiseReduction<ComputeDataType,
                                                             BlockSize,
                                                             ThreadClusterLengths_M_K,
                                                             ThreadClusterArrangeOrder,
                                                             reduce::Add,
                                                             true>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k,
                               const GridDesc_M_K& x_grid_desc_m_k,
                               const GridDesc_M_K& mean_grid_desc_m_k,
                               const GridDesc_M_K& inv_std_grid_desc_m_k,
                               const GridDesc_M& dgamma_grid_desc_m,
                               const GridDesc_M& dbeta_grid_desc_m,
                               index_t num_k_block_tile_iteration,
                               const DYDataType* const __restrict__ p_dy_global,
                               const XDataType* const __restrict__ p_x_global,
                               const MeanInvStdDataType* const __restrict__ p_mean_global,
                               const MeanInvStdDataType* const __restrict__ p_inv_std_global,
                               DGammaDataType* const __restrict__ p_dgamma_global,
                               DBetaDataType* const __restrict__ p_dbeta_global)
    {
        // LDS
        __shared__ ComputeDataType p_reduce_work_buffer[BlockSize];

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        // Global
        const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize());

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x_global, x_grid_desc_m_k.GetElementSpaceSize());

        const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize());

        const auto inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize());

        auto dgamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dgamma_global, dgamma_grid_desc_m.GetElementSpaceSize());

        auto dbeta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dbeta_global, dbeta_grid_desc_m.GetElementSpaceSize());

        // VGPR
        auto dy_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
                                          ComputeDataType,
                                          MThreadSliceSize * KThreadSliceSize,
                                          true>{};

        auto x_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
                                         ComputeDataType,
                                         MThreadSliceSize * KThreadSliceSize,
                                         true>{};

        auto mean_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
                                            ComputeDataType,
                                            MThreadSliceSize * KThreadSliceSize,
                                            true>{};

        auto inv_std_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
                                               ComputeDataType,
                                               MThreadSliceSize * KThreadSliceSize,
                                               true>{};

        auto dgamma_thread_buf =
            StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};

        auto dbeta_thread_buf =
            StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        // IO
        auto threadwise_dy_load =
            ThreadwiseTensorSliceTransfer_v2<DYDataType,
                                             ComputeDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             DYThreadBufferDimAccessOrder,
                                             DYSrcVectorDim,
                                             DYSrcVectorSize,
                                             1,
                                             true>(
                dy_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_x_load =
            ThreadwiseTensorSliceTransfer_v2<XDataType,
                                             ComputeDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             XThreadBufferDimAccessOrder,
                                             XSrcVectorDim,
                                             XSrcVectorSize,
                                             1,
                                             true>(
                x_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_mean_load =
            ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
                                             ComputeDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             MeanInvStdThreadBufferDimAccessOrder,
                                             MeanInvStdSrcVectorDim,
                                             MeanInvStdSrcVectorSize,
                                             1,
                                             true>(
                mean_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_inv_std_load =
            ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
                                             ComputeDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             MeanInvStdThreadBufferDimAccessOrder,
                                             MeanInvStdSrcVectorDim,
                                             MeanInvStdSrcVectorSize,
                                             1,
                                             true>(
                inv_std_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dgamma_store =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               DGammaDataType,
                                               decltype(thread_buffer_desc_m),
                                               GridDesc_M,
                                               PassThroughOp,
                                               ThreadBufferLengths_M,
                                               Sequence<0>,
                                               0,
                                               DGammaDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                dgamma_grid_desc_m,
                make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize),
                PassThroughOp{});

        auto threadwise_dbeta_store =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               DBetaDataType,
                                               decltype(thread_buffer_desc_m),
                                               GridDesc_M,
                                               PassThroughOp,
                                               ThreadBufferLengths_M,
                                               Sequence<0>,
                                               0,
                                               DBetaDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                dbeta_grid_desc_m,
                make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize),
                PassThroughOp{});

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            dgamma_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
            dbeta_thread_buf(I)  = type_convert<ComputeDataType>(0.0f);
        });

        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_dy_load.Run(dy_grid_desc_m_k,
                                   dy_global_val_buf,
                                   thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   dy_thread_buf);

            threadwise_x_load.Run(x_grid_desc_m_k,
                                  x_global_val_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  x_thread_buf);

            threadwise_mean_load.Run(mean_grid_desc_m_k,
                                     mean_global_val_buf,
                                     thread_buffer_desc_m_k,
                                     make_tuple(I0, I0),
                                     mean_thread_buf);

            threadwise_inv_std_load.Run(inv_std_grid_desc_m_k,
                                        inv_std_global_val_buf,
                                        thread_buffer_desc_m_k,
                                        make_tuple(I0, I0),
                                        inv_std_thread_buf);

            threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
            threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_fwd_step_m_k);
            threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k,
                                                       thread_copy_fwd_step_m_k);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                constexpr auto offset_m =
                    Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};

                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset_m_k =
                        Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};

                    dgamma_thread_buf(offset_m) +=
                        dy_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
                        (x_thread_buf[offset_m_k] - mean_thread_buf[offset_m_k]);

                    dbeta_thread_buf(offset_m) += dy_thread_buf[offset_m_k];
                });
            });
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            BlockwiseSumReduce::Reduce(reduce_work_buf, dbeta_thread_buf(I));
            block_sync_lds();
            BlockwiseSumReduce::Reduce(reduce_work_buf, dgamma_thread_buf(I));
        });

        if(thread_k_cluster_id == 0)
        {
            threadwise_dgamma_store.Run(thread_buffer_desc_m,
                                        make_tuple(I0),
                                        dgamma_thread_buf,
                                        dgamma_grid_desc_m,
                                        dgamma_global_val_buf);

            threadwise_dbeta_store.Run(thread_buffer_desc_m,
                                       make_tuple(I0),
                                       dbeta_thread_buf,
                                       dbeta_grid_desc_m,
                                       dbeta_global_val_buf);
        }
    }
};

}
// namespace ck
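The kernel accumulates per-thread partial sums of dgamma and dbeta over the K tiles it visits, then block-reduces them and lets the k == 0 threads store the per-M result. A minimal host-side sketch of the math it implements is shown below; the function name, std::vector containers, and float types are illustrative only and are not part of this commit.

// Host-side reference for the reduction above (sketch, names and types are hypothetical):
//   dgamma[m] = sum_k dy[m][k] * (x[m][k] - mean[m][k]) * inv_std[m][k]
//   dbeta[m]  = sum_k dy[m][k]
#include <cstddef>
#include <vector>

void reference_bwd_gamma_beta(const std::vector<std::vector<float>>& dy,
                              const std::vector<std::vector<float>>& x,
                              const std::vector<std::vector<float>>& mean,
                              const std::vector<std::vector<float>>& inv_std,
                              std::vector<float>& dgamma,
                              std::vector<float>& dbeta)
{
    const std::size_t M = dy.size();
    const std::size_t K = M ? dy[0].size() : 0;
    dgamma.assign(M, 0.f);
    dbeta.assign(M, 0.f);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t k = 0; k < K; ++k)
        {
            // same accumulation as the per-thread loop in the kernel, just serial
            dgamma[m] += dy[m][k] * (x[m][k] - mean[m][k]) * inv_std[m][k];
            dbeta[m]  += dy[m][k];
        }
}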
include/ck/utility/data_type.hpp
View file @ defa2071
...
...
@@ -1075,6 +1075,7 @@ struct NumericUtils<float>
{
    static constexpr int exp  = 8;
    static constexpr int mant = 23;
    static constexpr int bias = 127;

    static constexpr uint32_t nan_mask  = 0x7F800000;
    static constexpr uint32_t head_mask = 0xFF800000;
    static constexpr uint32_t mant_mask = 0x7FFFFF;
...
...
@@ -1091,6 +1092,7 @@ struct NumericUtils<half_t>
{
    static constexpr int exp  = 5;
    static constexpr int mant = 10;
    static constexpr int bias = 15;

    static constexpr uint16_t nan_mask  = 0x7C00;
    static constexpr uint16_t head_mask = 0xFC00;
    static constexpr uint16_t mant_mask = 0x3FF;
...
...
@@ -1107,6 +1109,8 @@ struct NumericUtils<f8_t>
{
    static constexpr int exp  = 4;
    static constexpr int mant = 3;
    static constexpr int bias = 8; // negative zero nan mode
    // static constexpr int bias = 7; // ieee mode
};

template <>
...
...
@@ -1114,6 +1118,7 @@ struct NumericUtils<bf8_t>
{
    static constexpr int exp  = 5;
    static constexpr int mant = 2;
    static constexpr int bias = 16; // negative zero nan mode
    // static constexpr int bias = 15; // ieee mode
};
//
}
// namespace ck
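The new bias fields follow the NANOO (negative-zero-as-NaN) convention used by the fp8 code: the bias is 2^(exp-1), one larger than the IEEE bias 2^(exp-1)-1 kept in the commented-out lines. A few static_asserts make the relation explicit; they are a sketch for illustration only, not part of the commit.

// Sketch: relation between the exponent widths and the biases added above.
// NANOO bias = 2^(exp-1); IEEE bias = 2^(exp-1) - 1.
static_assert((1 << (4 - 1)) == 8, "f8_t  (e4m3): NANOO bias 8, IEEE bias 7");
static_assert((1 << (5 - 1)) == 16, "bf8_t (e5m2): NANOO bias 16, IEEE bias 15");
static_assert((1 << (8 - 1)) - 1 == 127, "float: IEEE bias 127");
static_assert((1 << (5 - 1)) - 1 == 15, "half_t: IEEE bias 15");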
include/ck/utility/f8_utils.hpp
View file @ defa2071
...
...
@@ -5,7 +5,6 @@
#include "ck/utility/data_type.hpp"
// these conversions are disabled if native conversions available
namespace ck {
// fp8 rounding modes
...
...
@@ -17,6 +16,9 @@ enum class f8_rounding_mode
    stochastic
};

__host__ inline int clz(uint32_t x) { return __builtin_clz(x); }
__device__ inline int clz(uint32_t x) { return __clz(x); }

} // namespace ck

namespace ck::utils {
...
...
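clz is added so that a nonzero subnormal mantissa can be normalized with a single shift instead of a bit-by-bit loop; run_cast_from_f8 further down uses exactly this pattern. The helper below is an illustration only (hypothetical name, host-only, assumes a 10-bit mantissa as for fp16 and a nonzero mantissa), not library code.

// Illustration: normalizing a nonzero subnormal mantissa with clz reproduces the old shift loop.
#include <cstdint>

inline void normalize_with_clz(uint32_t& mantissa, int& exponent, int in_mant = 10)
{
    // distance of the leading 1 below bit position in_mant
    const int sh = 1 + __builtin_clz(mantissa) - (32 - in_mant);
    mantissa <<= sh;
    exponent += 1 - sh;
    mantissa &= ((1u << in_mant) - 1); // drop the now-implicit leading 1 again
}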
@@ -34,7 +36,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
    constexpr int in_exp  = NumericUtils<X>::exp;
    constexpr int in_mant = NumericUtils<X>::mant;
-   int exponent;
+   int exponent, bias;
    uint32_t head, mantissa, sign;
    // nan code is same for float and half
    constexpr Y nan_code = 0x80;
...
...
@@ -49,12 +51,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
    mantissa = x_bitwise & NumericUtils<X>::mant_mask;
    exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
    sign     = head >> (in_exp + in_mant);
+   bias     = NumericUtils<X>::bias;

    uint32_t signed_inf   = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
    uint32_t drop_mask    = (1 << (in_mant - out_mant)) - 1;
    constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
-   constexpr int exp_low_cutoff =
-       (1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);

    if constexpr(negative_zero_nan)
    {
...
...
@@ -67,56 +68,107 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
            return signed_inf + (mantissa != 0 ? 1 : 0);
    }

-   // if input is half and output is bf8
-   if((NumericUtils<X>::mant == 10) && (NumericUtils<Y>::mant == 2) && negative_zero_nan &&
-      exponent == 0)
-   {
-       exponent += 1;
-       while(mantissa < (1 << in_mant))
-       {
-           mantissa <<= 1;
-           exponent -= 1;
-       }
-       mantissa &= ~(1 << in_mant);
-   }
    // check if x is 0.0
    if(x_bitwise == 0)
        return 0;

-   exponent -= exp_low_cutoff - 1;
-   if(exponent <= 0)
-       drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1;
-   mantissa += 1 << in_mant;
-   // apply random number if needed
-   mantissa += (stoch ? rng : mantissa) & drop_mask;
-   if(mantissa >= (2 << in_mant))
-   {
-       mantissa >>= 1;
-       exponent++;
+   // First need to check if it is normal or denorm as there is a difference of implict 1
+   // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
+   // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
+   // RNE, no need to add rng. Then probably need to check whether there is carry and adjust
+   // exponent and mantissa again3

+   // For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
+   const int out_bias                  = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0);
+   const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal

+   // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
+   // out_exponent is the converted f8 exponent with bias encoding
+   // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
+   // the difference needs to be adjusted and mantissa shifted
+   int act_exponent, out_exponent, exponent_diff;

+   if(exponent == 0)
+   {
+       // fp32/fp16 is in denormal.
+       /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
+          here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
+          exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
+          fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
+          where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
+          In this case, the fp16 mantissa should be shift left by 1 */
+       act_exponent  = exponent - bias + 1;
+       exponent_diff = out_denormal_act_exponent -
+                       act_exponent; // actual exponent is exponent-bias+1 as it is denormal
+   }
+   else
+   {
+       // fp32/fp16 is normal with implicit 1
+       act_exponent = exponent - bias;
+       if(act_exponent <= out_denormal_act_exponent)
+       {
+           /* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
+              For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
+              actual exponent is -7, it is actually larger due to the implict 1,
+              Therefore it needs to be adjust to -6 and mantissa shift right by 1.
+              So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
+           exponent_diff = out_denormal_act_exponent - act_exponent;
+       }
+       else
+       {
+           // both fp32/fp16 and f8 are in normal range
+           exponent_diff = 0; // exponent_diff=0 does not mean there is no difference for this case,
+                              // act_exponent could be larger. Just that it does not need shift mantissa
+       }
+       mantissa += (1 << in_mant); // Add the implicit 1 into mantissa
    }
-   mantissa >>= (in_mant - out_mant);
-   // check negative exponent
-   if(exponent <= 0)
+   bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) ==
+                   (1 << (in_mant - out_mant + exponent_diff - 1));
+   /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
+      shift right as shift right could rip off some residual part and make something not midpoint look
+      like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
+      midpoint, but after shift right by 4 bits, it would look like midpoint. */

+   if(exponent_diff > 0)
+       mantissa >>= exponent_diff;
+   else if(exponent_diff == -1)
+       mantissa <<= -exponent_diff;
+   bool implicit_one = mantissa & (1 << in_mant);
+   // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
+   out_exponent =
+       (act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1);

+   // Now we have the exponent and mantissa adjusted
+   bool odd =
+       mantissa & (1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1
+   mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;

+   // Now we deal with overflow
+   if(out_exponent == 0)
    {
-       if(x_bitwise == 0)
-           return 0;
-       else
+       if((1 << in_mant) & mantissa)
        {
-           // subnormal range; represented by a subnormal float8 (exponent 0)
-           // and involves loss of accuracy
-           mantissa >>= 1 - exponent;
-           exponent = 0;
+           out_exponent = 1; // denormal overflow to become normal, promote exponent
+           // No need to make 1 implicit now as it will be addressed later
        }
    }
-   // above range: quantize to maximum possible float of the same sign
-   else if(exponent > max_exp)
+   else
+   {
+       if((1 << (in_mant + 1)) & mantissa)
+       {
+           mantissa >>= 1;
+           out_exponent++;
+           // No need to make 1 implicit now as it will be addressed later
+       }
+   }

+   mantissa >>= (in_mant - out_mant);

+   if(out_exponent > max_exp)
    {
        if(clip)
        {
-           mantissa = (1 << out_mant) - 1;
-           exponent = max_exp;
+           mantissa     = (1 << out_mant) - 1;
+           out_exponent = max_exp;
        }
        else
        {
...
...
@@ -125,10 +177,10 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
        }
    }

    // check if x is 0.0 or -0.0
-   if(exponent == 0 && mantissa == 0)
+   if(out_exponent == 0 && mantissa == 0)
        return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
    mantissa &= (1 << out_mant) - 1;
-   return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
+   return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
}

template <typename X, typename Y, bool negative_zero_nan>
...
...
@@ -194,12 +246,9 @@ __host__ __device__ Y run_cast_from_f8(X x)
    if(exponent == 0)
    {
        // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
-       exponent++;
-       while(mantissa < (1 << in_mant))
-       {
-           mantissa <<= 1;
-           exponent--;
-       }
+       int sh = 1 + clz(mantissa) - (32 - in_mant);
+       mantissa <<= sh;
+       exponent += 1 - sh;
        mantissa &= ((1 << in_mant) - 1);
    }
    exponent += exp_low_cutoff - 1;
...
...
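The rewritten run_cast_to_f8 decides whether the dropped bits form a tie before any right shift, because shifting first can make a value that is strictly above the midpoint look like a tie (the 0x1002 example in the source comment). The block below is a worked example only, under the assumption of an fp16 to bf8 conversion (10 mantissa bits down to 2) with exponent_diff == 0; it is not library code.

// Worked example: 8 mantissa bits are dropped, so the tie pattern of the dropped bits is 0x80.
constexpr unsigned dropped_mask = (1u << (10 - 2)) - 1;           // 0xFF
constexpr bool is_midpoint_0x80 = (0x80u & dropped_mask) == 0x80; // genuine tie: round to even
constexpr bool is_midpoint_0x82 = (0x82u & dropped_mask) == 0x80; // not a tie: must round up,
                                                                  // even though >>8 would erase the 0x02
static_assert(is_midpoint_0x80 && !is_midpoint_0x82, "tie detection must happen before shifting");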
include/ck/utility/inner_product.hpp
View file @ defa2071
...
...
@@ -192,6 +192,8 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
#else
    c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
#endif
#elif defined(CK_USE_AMD_V_DOT4_I32_I8_GFX11)
    c = __builtin_amdgcn_sudot4(
        true, bit_cast<int32_t>(a), true, bit_cast<int32_t>(b), c, false);
#else
    const vector_type<int8_t, 4> a_vector{a};
    const vector_type<int8_t, 4> b_vector{b};
...
...
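Whichever path is taken (sdot4, the new gfx11 sudot4, or the element-wise fallback), the operation accumulates a 4-wide signed int8 dot product into a 32-bit accumulator. A scalar sketch of that contract, with a hypothetical function name and plain arrays instead of the library vector types:

// Scalar equivalent of the 4-wide int8 dot product with accumulator (sketch, not library code):
#include <cstdint>

inline int32_t dot4_i8_scalar(const int8_t a[4], const int8_t b[4], int32_t c)
{
    for(int i = 0; i < 4; ++i)
        c += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]); // widen before multiply
    return c;
}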
include/ck/utility/math.hpp
View file @ defa2071
...
...
@@ -150,28 +150,6 @@ __host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T&
    return min(max(x, lowerbound), upperbound);
}

-// disallow implicit type casting
-template <typename T>
-__device__ T exp(T x);
-
-// TODO: add f16 support using v_exp_f16
-template <>
-__device__ float exp<float>(float x) { return __expf(x); }
-
-template <>
-__device__ double exp<double>(double x) { return exp(x); }
-
-static inline __host__ float exp(float x) { return std::expf(x); }
-
-static inline __host__ double exp(double x) { return std::exp(x); }

// greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{
...
include/ck/utility/math_v2.hpp
View file @ defa2071
...
...
@@ -9,6 +9,7 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/type_convert.hpp"
namespace ck {
namespace math {
...
...
@@ -92,14 +93,96 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); };

-static inline __host__ half_t tanh(half_t x)
-{
-    return static_cast<half_t>(std::tanh(static_cast<float>(x)));
-};
-
-static inline __host__ float tanh(float x) { return std::tanh(x); };
-
-static inline __host__ double tanh(double x) { return std::tanh(x); };
+template <typename T>
+inline __host__ T tanh(T x)
+{
+    return ck::type_convert<T>(std::tanhf(ck::type_convert<float>(x)));
+};
+
+template <>
+inline __host__ float tanh<float>(float x) { return std::tanhf(x); };
+
+template <>
+inline __host__ double tanh<double>(double x) { return std::tanh(x); };
+
+template <typename T>
+inline __host__ T exp(T x) { return ck::type_convert<T>(std::expf(ck::type_convert<float>(x))); }
+
+template <>
+inline __host__ float exp<float>(float x) { return std::expf(x); }
+
+template <>
+inline __host__ double exp<double>(double x) { return std::exp(x); }
+
+template <typename T>
+inline __host__ T log(T x) { return ck::type_convert<T>(std::logf(ck::type_convert<float>(x))); }
+
+template <>
+inline __host__ float log<float>(float x) { return std::logf(x); }
+
+template <>
+inline __host__ double log<double>(double x) { return std::log(x); }
+
+template <typename T>
+inline __host__ T pow(T x, T gamma)
+{
+    return ck::type_convert<T>(std::powf(ck::type_convert<float>(x), ck::type_convert<float>(gamma)));
+}
+
+template <>
+inline __host__ float pow<float>(float x, float gamma) { return std::powf(x, gamma); }
+
+template <>
+inline __host__ double pow<double>(double x, double gamma) { return std::pow(x, gamma); }
+
+template <typename T>
+inline __host__ T expm1(T x) { return ck::type_convert<T>(std::expm1f(ck::type_convert<float>(x))); }
+
+template <>
+inline __host__ float expm1<float>(float x) { return std::expm1f(x); }
+
+template <>
+inline __host__ double expm1<double>(double x) { return std::expm1(x); }
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
...
...
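The host templates route every non-float type through float via ck::type_convert, so half_t and other reduced-precision types get a single generic path while float and double keep their exact std:: specializations. A usage sketch, assuming the usual ck typedefs and include paths from this repository:

// Host usage sketch (function name is illustrative only):
#include "ck/utility/math_v2.hpp"

float demo_host_math()
{
    const ck::half_t h = ck::type_convert<ck::half_t>(0.5f);
    const ck::half_t t = ck::math::tanh(h);        // generic path: std::tanhf on the float image of h
    const float e      = ck::math::exp(2.0f);      // float specialization: std::expf
    const double p     = ck::math::pow(2.0, 10.0); // double specialization: std::pow
    return ck::type_convert<float>(t) + e + static_cast<float>(p);
}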
@@ -181,14 +264,107 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x);
static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };

-static inline __device__ half_t tanh(half_t x)
-{
-    return static_cast<half_t>(::tanhf(static_cast<float>(x)));
-};
-
-static inline __device__ float tanh(float x) { return ::tanhf(x); };
-
-static inline __device__ double tanh(double x) { return ::tanh(x); };
+template <typename T>
+inline __device__ T tanh(T x)
+{
+    return ck::type_convert<T>(::tanhf(ck::type_convert<float>(x)));
+};
+
+template <>
+inline __device__ float tanh<float>(float x) { return ::tanhf(x); };
+
+template <>
+inline __device__ double tanh<double>(double x) { return ::tanh(x); };
+
+template <typename T>
+inline __device__ T exp(T x) { return ck::type_convert<T>(__expf(ck::type_convert<float>(x))); };
+
+template <>
+inline __device__ half_t exp<half_t>(half_t x) { return hexp(x); };
+
+template <>
+inline __device__ float exp<float>(float x) { return __expf(x); };
+
+template <>
+inline __device__ double exp<double>(double x) { return exp(x); };
+
+template <typename T>
+inline __device__ T log(T x) { return ck::type_convert<T>(__logf(ck::type_convert<float>(x))); };
+
+template <>
+inline __device__ half_t log<half_t>(half_t x) { return hlog(x); };
+
+template <>
+inline __device__ float log<float>(float x) { return __logf(x); };
+
+template <>
+inline __device__ double log<double>(double x) { return log(x); };
+
+template <typename T>
+inline __device__ T pow(T x, T gamma)
+{
+    return ck::type_convert<T>(powf(ck::type_convert<float>(x), ck::type_convert<float>(gamma)));
+};
+
+template <>
+inline __device__ float pow<float>(float x, float gamma) { return powf(x, gamma); };
+
+template <>
+inline __device__ double pow<double>(double x, double gamma) { return pow(x, gamma); };
+
+template <typename T>
+inline __device__ T expm1(T x) { return ck::type_convert<T>(expm1f(ck::type_convert<float>(x))); };
+
+template <>
+inline __device__ float expm1<float>(float x) { return expm1f(x); };
+
+template <>
+inline __device__ double expm1<double>(double x) { return expm1(x); };

} // namespace math
}
// namespace ck
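On the device side the half_t specializations dispatch to the hardware intrinsics hexp and hlog rather than going through float. A minimal device-side sketch of that dispatch follows; the kernel name and launch configuration are illustrative only and assume HIP compilation against this repository's headers.

// Device usage sketch (hypothetical kernel, not part of the commit):
#include <hip/hip_runtime.h>
#include "ck/utility/math_v2.hpp"

__global__ void exp_half_kernel(const ck::half_t* in, ck::half_t* out, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        out[i] = ck::math::exp(in[i]); // resolves to the half_t specialization, i.e. hexp
}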
include/ck/utility/statically_indexed_array_multi_index.hpp
View file @ defa2071
...
...
@@ -5,6 +5,7 @@
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#include "common_header.hpp"
#include "ck/utility/math_v2.hpp"
namespace ck {
...
...
include/ck/utility/type_convert.hpp
View file @ defa2071
...
...
@@ -100,6 +100,8 @@ template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float max_fp8 = 240.0f;
    x             = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
    union
    {
        float fval;
...
...
@@ -138,6 +140,36 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
#endif
}

template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    const auto i16val = bit_cast<uint16_t>(x);
    return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else
    constexpr bool negative_zero_nan = true;

    const auto f8x2_v = vector_type<f8_t, 2>(x);
    vector_type<float, 2> f32x2_v;
    f32x2_v.template AsType<float>()(Number<0>{}) = utils::cast_from_f8<f8_t, float, negative_zero_nan>(
        f8x2_v.template AsType<f8_t>()[Number<0>{}]);
    f32x2_v.template AsType<float>()(Number<1>{}) = utils::cast_from_f8<f8_t, float, negative_zero_nan>(
        f8x2_v.template AsType<f8_t>()[Number<1>{}]);
    return f32x2_v.template AsType<float2_t>()[Number<0>{}];
#endif
}

template <>
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
{
    const vector_type<float, 2> f32x2_v(x);
    const auto y = __builtin_amdgcn_cvt_pkrtz(f32x2_v.template AsType<float>()[Number<0>{}],
                                              f32x2_v.template AsType<float>()[Number<1>{}]);
    return bit_cast<half2_t>(y);
}

// convert fp16 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
...
...
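On gfx940/941/942 the new packed overload converts two fp8 lanes to two floats with a single v_cvt_pk instruction; on other targets both lanes go through utils::cast_from_f8 individually. A usage sketch, with an illustrative wrapper name that is not part of the library:

// Sketch: unpack two fp8 values to float in one call (assumes this repo's headers).
#include "ck/utility/type_convert.hpp"

inline ck::float2_t unpack_two_fp8(ck::f8x2_t packed)
{
    // One packed hardware convert on gfx94x, two scalar cast_from_f8 calls elsewhere.
    return ck::type_convert<ck::float2_t>(packed);
}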
@@ -145,7 +177,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // convert to float and use native converion
    return type_convert<f8_t>(type_convert<float>(x));
-#elif 0
+#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr f8_rounding_mode rm    = f8_rounding_mode::standard;
...
...
@@ -153,8 +185,6 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
    return utils::cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
        x, rng);
-#else
-    return type_convert<f8_t>(type_convert<float>(x));
#endif
}
...
...
@@ -165,11 +195,9 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // use native conversion to float and convert to fp16
    return type_convert<half_t>(type_convert<float>(x));
-#elif 0
+#else
    constexpr bool negative_zero_nan = true;
    return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
-#else
-    return type_convert<half_t>(type_convert<float>(x));
#endif
}
...
...
@@ -223,7 +251,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // convert to float and use native converion
    return type_convert<bf8_t>(type_convert<float>(x));
-#elif 0
+#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr f8_rounding_mode rm    = f8_rounding_mode::standard;
...
...
@@ -231,8 +259,6 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
    return utils::cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
        x, rng);
-#else
-    return type_convert<bf8_t>(type_convert<float>(x));
#endif
}
...
...
@@ -243,11 +269,9 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // use native conversion to float and convert to fp16
    return type_convert<half_t>(type_convert<float>(x));
-#elif 0
+#else
    constexpr bool negative_zero_nan = true;
    return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
-#else
-    return type_convert<half_t>(type_convert<float>(x));
#endif
}
...
...
@@ -347,7 +371,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // convert to float and use native converion
    return f8_convert_sr<f8_t>(type_convert<float>(x));
-#elif 0
+#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr f8_rounding_mode rm    = f8_rounding_mode::stochastic;
...
...
@@ -356,8 +380,6 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
    return utils::cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
        x, rng);
-#else
-    return f8_convert_sr<f8_t>(type_convert<float>(x));
#endif
}
...
...
@@ -396,7 +418,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // convert to float and use native converion
    return f8_convert_sr<f8_t>(type_convert<float>(x));
-#elif 0
+#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr f8_rounding_mode rm    = f8_rounding_mode::stochastic;
...
...
@@ -406,8 +428,6 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
    return utils::cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
        x, rng);
-#else
-    return f8_convert_sr<bf8_t>(type_convert<float>(x));
#endif
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp
View file @ defa2071
...
...
@@ -19,9 +19,7 @@ namespace host {
 * \brief Reference implementation for column to image.
 *
 * Input tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
 * Memory layout is the same.
 * Output tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
 * G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C].
 *
 * \tparam NDimSpatial Number of spatial dimensions.
 * \tparam ImageLayout Image Layout.
...
...
@@ -95,18 +93,19 @@ struct ReferenceColumnToImage : public device::BaseOperator
    float Run(const Argument& arg)
    {
        if(!(arg.output_.GetNumOfDimension() == NDimSpatial + 3 &&
-            arg.input_.GetNumOfDimension() == 2))
+            arg.input_.GetNumOfDimension() == 3))
        {
            throw std::runtime_error("wrong! inconsistent dimension");
        }

+       const index_t G = arg.output_.GetLengths()[0];
        const index_t N = arg.output_.GetLengths()[1];
        const index_t C = arg.output_.GetLengths()[2];

        if constexpr(NDimSpatial == 1)
        {
            const index_t Wo = arg.output_spatial_lengths_[0];

-           auto func = [&](auto n) {
+           auto func = [&](auto g, auto n) {
                for(index_t wo = 0; wo < Wo; ++wo)
                {
                    index_t row = n * Wo + wo;
...
...
@@ -123,9 +122,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
                        if(wi >= 0 &&
                           ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
                        {
-                           float v_in  = ck::type_convert<float>(arg.input_(row, column));
-                           float v_out = ck::type_convert<float>(arg.output_(0, n, c, wi));
-                           arg.output_(0, n, c, wi) =
+                           float v_in  = ck::type_convert<float>(arg.input_(g, row, column));
+                           float v_out = ck::type_convert<float>(arg.output_(g, n, c, wi));
+                           arg.output_(g, n, c, wi) =
                                ck::type_convert<OutDataType>(v_in + v_out);
                        }
                        column++;
...
@@ -134,7 +134,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
                }
            };

-           make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
+           make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());

            return 0;
        }
...
...
@@ -143,7 +143,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
            const index_t Ho = arg.output_spatial_lengths_[0];
            const index_t Wo = arg.output_spatial_lengths_[1];

-           auto func = [&](auto n) {
+           auto func = [&](auto g, auto n) {
                for(index_t ho = 0; ho < Ho; ++ho)
                {
                    for(index_t wo = 0; wo < Wo; ++wo)
...
...
@@ -176,10 +176,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
                               arg.output_.GetLengths()[4])
                            {
                                float v_in =
-                                   ck::type_convert<float>(arg.input_(row, column));
+                                   ck::type_convert<float>(arg.input_(g, row, column));
                                float v_out =
-                                   ck::type_convert<float>(arg.output_(0, n, c, hi, wi));
-                               arg.output_(0, n, c, hi, wi) =
+                                   ck::type_convert<float>(arg.output_(g, n, c, hi, wi));
+                               arg.output_(g, n, c, hi, wi) =
                                    ck::type_convert<OutDataType>(v_in + v_out);
                            }
                            column++;
...
...
@@ -190,7 +190,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
                }
            };

-           make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
+           make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());

            return 0;
        }
...
...
@@ -200,7 +200,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
            const index_t Ho = arg.output_spatial_lengths_[1];
            const index_t Wo = arg.output_spatial_lengths_[2];

-           auto func = [&](auto n) {
+           auto func = [&](auto g, auto n) {
                for(index_t d_o = 0; d_o < Do; ++d_o)
                {
                    for(index_t ho = 0; ho < Ho; ++ho)
...
...
@@ -245,10 +245,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
                                   arg.output_.GetLengths()[5])
                                {
                                    float v_in = ck::type_convert<float>(
-                                       arg.input_(row, column));
+                                       arg.input_(g, row, column));
                                    float v_out = ck::type_convert<float>(
-                                       arg.output_(0, n, c, di, hi, wi));
-                                   arg.output_(0, n, c, di, hi, wi) =
+                                       arg.output_(g, n, c, di, hi, wi));
+                                   arg.output_(g, n, c, di, hi, wi) =
                                        ck::type_convert<OutDataType>(v_in + v_out);
                                }
                                column++;
...
...
@@ -261,7 +261,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
                    }
                }
            };

-           make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
+           make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());

            return 0;
        }
...
...
@@ -303,8 +303,9 @@ struct ReferenceColumnToImage : public device::BaseOperator
            C * ck::accumulate_n<index_t>(
                    arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());

-       if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(NDoHoWo) &&
-            arg.input_.GetLengths()[1] == static_cast<std::size_t>(CZYX)))
+       if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) &&
+            arg.input_.GetLengths()[1] == static_cast<std::size_t>(NDoHoWo) &&
+            arg.input_.GetLengths()[2] == static_cast<std::size_t>(CZYX)))
        {
            return false;
        }
...
...
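With this change the column tensor carries an explicit group dimension, [G, N * Do * Ho * Wo, Z * Y * X * C], and the reference is parallelized over (g, n). A worked shape example with illustrative numbers (not taken from the commit) for the 2D case:

// Worked shape example: G = 2, N = 1, C = 8, Ho = Wo = 16, Y = X = 3.
constexpr int G = 2, N = 1, C = 8, Ho = 16, Wo = 16, Y = 3, X = 3;
static_assert(N * Ho * Wo == 256, "rows of the column tensor: N * Ho * Wo");
static_assert(Y * X * C == 72, "columns of the column tensor: Y * X * C");
// Column input is [G, 256, 72]; the image output keeps layout [G, N, Hi, Wi, C], and the
// reference now runs make_ParallelTensorFunctor(func, G, N) over both g and n.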
library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp
View file @ defa2071
...
...
@@ -23,6 +23,7 @@ template <ck::index_t NumDimM,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename ComputeDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
...
...
@@ -69,19 +70,24 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
                {
                    for(ck::index_t k1 = 0; k1 < K1; ++k1)
                    {
+                       // Simulate the possible casting when ComputeDataType is different than the
+                       // A/B data types
+                       ComputeDataType v_a_compute_input =
+                           ck::type_convert<ComputeDataType>(arg.a_ms_ks_(m0, m1, k0, k1));
+                       ComputeDataType v_b_compute_input =
+                           ck::type_convert<ComputeDataType>(arg.b_ns_ks_(n0, n1, k0, k1));
+
                        AccDataType v_a;
                        AccDataType v_b;

-                       arg.a_element_op_(
-                           v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
-                       arg.b_element_op_(
-                           v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
+                       arg.a_element_op_(v_a, ck::type_convert<AccDataType>(v_a_compute_input));
+                       arg.b_element_op_(v_b, ck::type_convert<AccDataType>(v_b_compute_input));

                        v_acc += v_a * v_b;
                    }
                }

-               arg.c_ms_ns_(m0, m1, n0, n1) = v_acc;
+               arg.c_ms_ns_(m0, m1, n0, n1) = ck::type_convert<CDataType>(v_acc);
            };

            make_ParallelTensorFunctor(f_ms_ns,
...
...
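The updated reference models the same casting chain a kernel with a separate compute type performs: A/B are first narrowed to ComputeDataType, widened to AccDataType for the elementwise op and the accumulation, and the final accumulator is converted to CDataType on store. A standalone sketch of that chain (template and function names are illustrative, not the library API):

// Sketch of the compute-type casting chain modeled by the reference above.
template <typename ADataType, typename ComputeDataType, typename AccDataType, typename CDataType>
CDataType contract_one_element(const ADataType* a, const ADataType* b, int K)
{
    AccDataType acc{0};
    for(int k = 0; k < K; ++k)
    {
        // simulate the on-device narrowing to ComputeDataType before the multiply-accumulate
        const auto va = static_cast<AccDataType>(static_cast<ComputeDataType>(a[k]));
        const auto vb = static_cast<AccDataType>(static_cast<ComputeDataType>(b[k]));
        acc += va * vb;
    }
    return static_cast<CDataType>(acc); // final conversion to the output type
}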