Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9f8ab221
Unverified
Commit
9f8ab221
authored
Oct 19, 2023
by
zjing14
Committed by
GitHub
Oct 19, 2023
Browse files
Merge branch 'develop' into add_int8_wmma_example_instance
parents
755ace59
b4fc4d0b
Changes
490
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
864 additions
and
226 deletions
+864
-226
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+4
-26
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+30
-4
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+0
-4
include/ck/utility/is_detected.hpp
include/ck/utility/is_detected.hpp
+43
-0
include/ck/utility/math.hpp
include/ck/utility/math.hpp
+0
-22
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+184
-8
include/ck/utility/statically_indexed_array_multi_index.hpp
include/ck/utility/statically_indexed_array_multi_index.hpp
+1
-0
include/ck/utility/tuple.hpp
include/ck/utility/tuple.hpp
+2
-0
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+36
-17
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp
...erence_tensor_operation/cpu/reference_column_to_image.hpp
+363
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp
...erence_tensor_operation/cpu/reference_conv_bwd_weight.hpp
+12
-9
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
...ary/reference_tensor_operation/cpu/reference_conv_fwd.hpp
+9
-15
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+4
-3
library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp
...ry/reference_tensor_operation/cpu/reference_groupnorm.hpp
+48
-27
library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp
...erence_tensor_operation/cpu/reference_image_to_column.hpp
+7
-5
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp
...ry/reference_tensor_operation/cpu/reference_layernorm.hpp
+41
-20
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+2
-6
library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp
...rary/tensor_operation_instance/gpu/batchnorm_backward.hpp
+28
-22
library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp
...brary/tensor_operation_instance/gpu/batchnorm_forward.hpp
+25
-19
library/include/ck/library/tensor_operation_instance/gpu/batchnorm_infer.hpp
...library/tensor_operation_instance/gpu/batchnorm_infer.hpp
+25
-19
No files found.
include/ck/utility/data_type.hpp
View file @
9f8ab221
...
...
@@ -9,15 +9,9 @@ namespace ck {
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using
int4_t
=
_BitInt
(
4
);
#endif
#if defined CK_ENABLE_FP8
using
f8_t
=
_BitInt
(
8
);
#endif
#if defined CK_ENABLE_BF8
using
bf8_t
=
unsigned
_BitInt
(
8
);
#endif
using
int4_t
=
_BitInt
(
4
);
using
f8_t
=
_BitInt
(
8
);
using
bf8_t
=
unsigned
_BitInt
(
8
);
// vector_type
template
<
typename
T
,
index_t
N
>
...
...
@@ -148,23 +142,19 @@ struct scalar_type<int4_t>
};
#endif
#if defined CK_ENABLE_FP8
template
<
>
struct
scalar_type
<
f8_t
>
{
using
type
=
f8_t
;
static
constexpr
index_t
vector_size
=
1
;
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
struct
scalar_type
<
bf8_t
>
{
using
type
=
bf8_t
;
static
constexpr
index_t
vector_size
=
1
;
};
#endif
template
<
typename
T
>
struct
vector_type
<
T
,
1
>
...
...
@@ -968,24 +958,20 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
// f8
#if defined CK_ENABLE_FP8
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
#endif
// bf8
#if defined CK_ENABLE_BF8
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
#endif
template
<
typename
T
>
struct
NumericLimits
...
...
@@ -1033,7 +1019,6 @@ struct NumericLimits<int4_t>
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
template
<
>
struct
NumericLimits
<
f8_t
>
{
...
...
@@ -1056,9 +1041,7 @@ struct NumericLimits<f8_t>
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
struct
NumericLimits
<
bf8_t
>
{
...
...
@@ -1081,7 +1064,6 @@ struct NumericLimits<bf8_t>
__host__
__device__
static
constexpr
bf8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
};
#endif
template
<
typename
T
>
struct
NumericUtils
...
...
@@ -1120,22 +1102,18 @@ struct NumericUtils<half_t>
using
bitwise_type
=
uint16_t
;
};
#if defined CK_ENABLE_FP8
template
<
>
struct
NumericUtils
<
f8_t
>
{
static
constexpr
int
exp
=
4
;
static
constexpr
int
mant
=
3
;
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
struct
NumericUtils
<
bf8_t
>
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
2
;
};
#endif
//
}
// namespace ck
include/ck/utility/dynamic_buffer.hpp
View file @
9f8ab221
...
...
@@ -140,10 +140,36 @@ struct DynamicBuffer
}
else
if
constexpr
(
Op
==
InMemoryDataOperationEnum
::
Add
)
{
auto
tmp
=
this
->
template
Get
<
X
>(
i
,
is_valid_element
);
this
->
template
Set
<
X
>(
i
,
is_valid_element
,
x
+
tmp
);
// tmp += x;
// this->template Set<X>(i, is_valid_element, tmp);
auto
tmp
=
this
->
template
Get
<
X
>(
i
,
is_valid_element
);
using
scalar_t
=
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
;
// handle bfloat addition
if
constexpr
(
is_same_v
<
scalar_t
,
bhalf_t
>
)
{
if
constexpr
(
is_scalar_type
<
X
>::
value
)
{
// Scalar type
auto
result
=
type_convert
<
X
>
(
type_convert
<
float
>
(
x
)
+
type_convert
<
float
>
(
tmp
));
this
->
template
Set
<
X
>(
i
,
is_valid_element
,
result
);
}
else
{
// Vector type
constexpr
auto
vector_size
=
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
;
const
vector_type
<
scalar_t
,
vector_size
>
a_vector
{
tmp
};
const
vector_type
<
scalar_t
,
vector_size
>
b_vector
{
x
};
static_for
<
0
,
vector_size
,
1
>
{}([
&
](
auto
idx
)
{
auto
result
=
type_convert
<
scalar_t
>
(
type_convert
<
float
>
(
a_vector
.
template
AsType
<
scalar_t
>()[
idx
])
+
type_convert
<
float
>
(
b_vector
.
template
AsType
<
scalar_t
>()[
idx
]));
this
->
template
Set
<
scalar_t
>(
i
+
idx
,
is_valid_element
,
result
);
});
}
}
else
{
this
->
template
Set
<
X
>(
i
,
is_valid_element
,
x
+
tmp
);
}
}
}
...
...
include/ck/utility/f8_utils.hpp
View file @
9f8ab221
...
...
@@ -6,8 +6,6 @@
#include "ck/utility/data_type.hpp"
// these conversions are disabled if native conversions available
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace
ck
{
// fp8 rounding modes
...
...
@@ -244,5 +242,3 @@ __host__ __device__ Y cast_from_f8(X x)
}
}
// namespace ck::utils
#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
include/ck/utility/is_detected.hpp
0 → 100644
View file @
9f8ab221
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace
ck
{
namespace
detail
{
template
<
class
Default
,
class
AlwaysVoid
,
template
<
class
...
>
class
Op
,
class
...
Args
>
struct
detector
{
using
value_t
=
std
::
false_type
;
using
type
=
Default
;
};
template
<
class
Default
,
template
<
class
...
>
class
Op
,
class
...
Args
>
struct
detector
<
Default
,
std
::
void_t
<
Op
<
Args
...
>>
,
Op
,
Args
...
>
{
using
value_t
=
std
::
true_type
;
using
type
=
Op
<
Args
...
>
;
};
}
// namespace detail
struct
nonesuch
{
~
nonesuch
()
=
delete
;
nonesuch
(
nonesuch
const
&
)
=
delete
;
void
operator
=
(
nonesuch
const
&
)
=
delete
;
};
template
<
template
<
class
...
>
class
Op
,
class
...
Args
>
using
is_detected
=
typename
detail
::
detector
<
nonesuch
,
void
,
Op
,
Args
...
>::
value_t
;
template
<
typename
T
>
using
is_pack2_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack2_invocable
);
template
<
typename
T
>
using
is_pack4_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack4_invocable
);
template
<
typename
T
>
using
is_pack8_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack8_invocable
);
}
// namespace ck
include/ck/utility/math.hpp
View file @
9f8ab221
...
...
@@ -150,28 +150,6 @@ __host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T&
return
min
(
max
(
x
,
lowerbound
),
upperbound
);
}
// disallow implicit type casting
template
<
typename
T
>
__device__
T
exp
(
T
x
);
// TODO: add f16 support using v_exp_f16
template
<
>
__device__
float
exp
<
float
>
(
float
x
)
{
return
__expf
(
x
);
}
template
<
>
__device__
double
exp
<
double
>
(
double
x
)
{
return
exp
(
x
);
}
static
inline
__host__
float
exp
(
float
x
)
{
return
std
::
expf
(
x
);
}
static
inline
__host__
double
exp
(
double
x
)
{
return
std
::
exp
(
x
);
}
// greatest common divisor, aka highest common factor
__host__
__device__
constexpr
index_t
gcd
(
index_t
x
,
index_t
y
)
{
...
...
include/ck/utility/math_v2.hpp
View file @
9f8ab221
...
...
@@ -9,6 +9,7 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/type_convert.hpp"
namespace
ck
{
namespace
math
{
...
...
@@ -92,14 +93,96 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static
inline
__host__
double
sqrt
(
double
x
)
{
return
std
::
sqrt
(
x
);
};
static
inline
__host__
half_t
tanh
(
half_t
x
)
template
<
typename
T
>
inline
__host__
T
tanh
(
T
x
)
{
return
static_cast
<
half_t
>
(
std
::
tanh
(
static_cas
t
<
float
>
(
x
)));
return
ck
::
type_convert
<
T
>
(
std
::
tanhf
(
ck
::
type_conver
t
<
float
>
(
x
)));
};
static
inline
__host__
float
tanh
(
float
x
)
{
return
std
::
tanh
(
x
);
};
template
<
>
inline
__host__
float
tanh
<
float
>
(
float
x
)
{
return
std
::
tanhf
(
x
);
};
template
<
>
inline
__host__
double
tanh
<
double
>
(
double
x
)
{
return
std
::
tanh
(
x
);
};
template
<
typename
T
>
inline
__host__
T
exp
(
T
x
)
{
return
ck
::
type_convert
<
T
>
(
std
::
expf
(
ck
::
type_convert
<
float
>
(
x
)));
}
template
<
>
inline
__host__
float
exp
<
float
>
(
float
x
)
{
return
std
::
expf
(
x
);
}
static
inline
__host__
double
tanh
(
double
x
)
{
return
std
::
tanh
(
x
);
};
template
<
>
inline
__host__
double
exp
<
double
>
(
double
x
)
{
return
std
::
exp
(
x
);
}
template
<
typename
T
>
inline
__host__
T
log
(
T
x
)
{
return
ck
::
type_convert
<
T
>
(
std
::
logf
(
ck
::
type_convert
<
float
>
(
x
)));
}
template
<
>
inline
__host__
float
log
<
float
>
(
float
x
)
{
return
std
::
logf
(
x
);
}
template
<
>
inline
__host__
double
log
<
double
>
(
double
x
)
{
return
std
::
log
(
x
);
}
template
<
typename
T
>
inline
__host__
T
pow
(
T
x
,
T
gamma
)
{
return
ck
::
type_convert
<
T
>
(
std
::
powf
(
ck
::
type_convert
<
float
>
(
x
),
ck
::
type_convert
<
float
>
(
gamma
)));
}
template
<
>
inline
__host__
float
pow
<
float
>
(
float
x
,
float
gamma
)
{
return
std
::
powf
(
x
,
gamma
);
}
template
<
>
inline
__host__
double
pow
<
double
>
(
double
x
,
double
gamma
)
{
return
std
::
pow
(
x
,
gamma
);
}
template
<
typename
T
>
inline
__host__
T
expm1
(
T
x
)
{
return
ck
::
type_convert
<
T
>
(
std
::
expm1f
(
ck
::
type_convert
<
float
>
(
x
)));
}
template
<
>
inline
__host__
float
expm1
<
float
>
(
float
x
)
{
return
std
::
expm1f
(
x
);
}
template
<
>
inline
__host__
double
expm1
<
double
>
(
double
x
)
{
return
std
::
expm1
(
x
);
}
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
...
...
@@ -181,14 +264,107 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x);
static
inline
__device__
double
sqrt
(
double
x
)
{
return
__builtin_amdgcn_sqrt
(
x
);
};
static
inline
__device__
half_t
tanh
(
half_t
x
)
template
<
typename
T
>
inline
__device__
T
tanh
(
T
x
)
{
return
ck
::
type_convert
<
T
>
(
::
tanhf
(
ck
::
type_convert
<
float
>
(
x
)));
};
template
<
>
inline
__device__
float
tanh
<
float
>
(
float
x
)
{
return
static_cast
<
half_t
>
(
::
tanhf
(
static_cast
<
float
>
(
x
))
);
return
::
tanhf
(
x
);
};
static
inline
__device__
float
tanh
(
float
x
)
{
return
::
tanhf
(
x
);
};
template
<
>
inline
__device__
double
tanh
<
double
>
(
double
x
)
{
return
::
tanh
(
x
);
};
template
<
typename
T
>
inline
__device__
T
exp
(
T
x
)
{
return
ck
::
type_convert
<
T
>
(
__expf
(
ck
::
type_convert
<
float
>
(
x
)));
};
template
<
>
inline
__device__
half_t
exp
<
half_t
>
(
half_t
x
)
{
return
hexp
(
x
);
};
template
<
>
inline
__device__
float
exp
<
float
>
(
float
x
)
{
return
__expf
(
x
);
};
static
inline
__device__
double
tanh
(
double
x
)
{
return
::
tanh
(
x
);
};
template
<
>
inline
__device__
double
exp
<
double
>
(
double
x
)
{
return
exp
(
x
);
};
template
<
typename
T
>
inline
__device__
T
log
(
T
x
)
{
return
ck
::
type_convert
<
T
>
(
__logf
(
ck
::
type_convert
<
float
>
(
x
)));
};
template
<
>
inline
__device__
half_t
log
<
half_t
>
(
half_t
x
)
{
return
hlog
(
x
);
};
template
<
>
inline
__device__
float
log
<
float
>
(
float
x
)
{
return
__logf
(
x
);
};
template
<
>
inline
__device__
double
log
<
double
>
(
double
x
)
{
return
log
(
x
);
};
template
<
typename
T
>
inline
__device__
T
pow
(
T
x
,
T
gamma
)
{
return
ck
::
type_convert
<
T
>
(
powf
(
ck
::
type_convert
<
float
>
(
x
),
ck
::
type_convert
<
float
>
(
gamma
)));
};
template
<
>
inline
__device__
float
pow
<
float
>
(
float
x
,
float
gamma
)
{
return
powf
(
x
,
gamma
);
};
template
<
>
inline
__device__
double
pow
<
double
>
(
double
x
,
double
gamma
)
{
return
pow
(
x
,
gamma
);
};
template
<
typename
T
>
inline
__device__
T
expm1
(
T
x
)
{
return
ck
::
type_convert
<
T
>
(
expm1f
(
ck
::
type_convert
<
float
>
(
x
)));
};
template
<
>
inline
__device__
float
expm1
<
float
>
(
float
x
)
{
return
expm1f
(
x
);
};
template
<
>
inline
__device__
double
expm1
<
double
>
(
double
x
)
{
return
expm1
(
x
);
};
}
// namespace math
}
// namespace ck
include/ck/utility/statically_indexed_array_multi_index.hpp
View file @
9f8ab221
...
...
@@ -5,6 +5,7 @@
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#include "common_header.hpp"
#include "ck/utility/math_v2.hpp"
namespace
ck
{
...
...
include/ck/utility/tuple.hpp
View file @
9f8ab221
...
...
@@ -177,6 +177,8 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsTuple
()
{
return
true
;
}
};
template
<
>
...
...
include/ck/utility/type_convert.hpp
View file @
9f8ab221
...
...
@@ -9,8 +9,10 @@
namespace
ck
{
// Convert X to Y
template
<
typename
Y
,
typename
X
>
// Convert X to Y, both X and Y are non-const data types.
template
<
typename
Y
,
typename
X
,
std
::
enable_if_t
<!
(
std
::
is_const_v
<
Y
>
||
std
::
is_const_v
<
X
>
),
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
...
...
@@ -18,6 +20,19 @@ __host__ __device__ constexpr Y type_convert(X x)
return
static_cast
<
Y
>
(
x
);
}
// Convert X to Y, either X or Y is a const data type.
template
<
typename
Y
,
typename
X
,
std
::
enable_if_t
<
std
::
is_const_v
<
Y
>
||
std
::
is_const_v
<
X
>
,
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
using
NonConstY
=
std
::
remove_const_t
<
Y
>
;
using
NonConstX
=
std
::
remove_const_t
<
X
>
;
return
static_cast
<
Y
>
(
type_convert
<
NonConstY
,
NonConstX
>
(
x
));
}
// convert bfp16 to fp32
template
<
>
inline
__host__
__device__
constexpr
float
type_convert
<
float
,
bhalf_t
>
(
bhalf_t
x
)
...
...
@@ -80,7 +95,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
#if defined CK_ENABLE_FP8
// convert fp32 to fp8
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
float
>
(
float
x
)
...
...
@@ -131,7 +145,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return
type_convert
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#el
se
#el
if 0
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
...
...
@@ -139,6 +153,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#else
return
type_convert
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#endif
}
...
...
@@ -149,14 +165,14 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#el
se
#el
if 0
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
#else
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#endif
}
#endif
#if defined CK_ENABLE_BF8
// convert fp32 to bf8
template
<
>
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
float
>
(
float
x
)
...
...
@@ -206,8 +222,8 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return
type_convert
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#el
se
return
type_convert
<
b
f8_t
>
(
type_convert
<
float
>
(
x
));
#el
if 0
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
...
...
@@ -215,6 +231,8 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#else
return
type_convert
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
#endif
}
...
...
@@ -225,12 +243,13 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#el
se
#el
if 0
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
bf8_t
,
half_t
,
negative_zero_nan
>
(
x
);
#else
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#endif
}
#endif
// Declare a template function for bf16 conversion using RTN
template
<
typename
Y
,
typename
X
>
...
...
@@ -293,7 +312,6 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
f8_convert_sr
(
X
x
);
#if defined CK_ENABLE_FP8
// convert fp32 to fp8 with stochastic rounding
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
...
...
@@ -329,7 +347,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#el
se
#el
if 0
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
...
...
@@ -338,11 +356,11 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#else
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#endif
}
#endif
#if defined CK_ENABLE_BF8
// convert fp32 to bf8 with stochastic rounding
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
...
...
@@ -378,7 +396,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#el
se
#el
if 0
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
...
...
@@ -388,8 +406,9 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#else
return
f8_convert_sr
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
#endif
}
#endif
}
// namespace ck
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp
0 → 100644
View file @
9f8ab221
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <type_traits>
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
host
{
/**
* \brief Reference implementation for column to image.
*
* Input tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
* Memory layout is the same.
* Output tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C].
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ImageLayout Image Layout.
* \tparam InDataType Input Data Type.
* \tparam OutDataType Output Data Type.
*/
template
<
ck
::
index_t
NDimSpatial
,
typename
ImageLayout
,
typename
InDataType
,
typename
OutDataType
,
typename
std
::
enable_if
<
NDimSpatial
>
=
1
&&
NDimSpatial
<=
3
,
bool
>::
type
=
false
>
struct
ReferenceColumnToImage
:
public
device
::
BaseOperator
{
// Argument
struct
Argument
:
public
device
::
BaseArgument
{
public:
Argument
(
const
Tensor
<
InDataType
>&
input
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
:
input_
{
input
},
output_
{
output
},
conv_strides_
{
conv_filter_strides
},
conv_dilations_
{
conv_filter_dilations
},
in_left_pads_
{
input_left_pads
},
in_right_pads_
{
input_right_pads
},
filter_spatial_lengths_
{
filter_spatial_lengths
}
{
initOutputSpatialLengths
();
}
const
Tensor
<
InDataType
>&
input_
;
Tensor
<
OutDataType
>&
output_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
std
::
vector
<
index_t
>
in_left_pads_
;
std
::
vector
<
index_t
>
in_right_pads_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
output_spatial_lengths_
;
private:
void
initOutputSpatialLengths
()
{
constexpr
auto
input_offset_to_spatial
=
3
;
for
(
ck
::
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck
::
index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_dilations_
[
i
]
+
1
;
output_spatial_lengths_
.
push_back
(
(
output_
.
GetLengths
()[
i
+
input_offset_to_spatial
]
+
in_left_pads_
[
i
]
+
in_right_pads_
[
i
]
-
x_eff
)
/
conv_strides_
[
i
]
+
1
);
}
}
};
struct
Invoker
:
public
device
::
BaseInvoker
{
using
Argument
=
ReferenceColumnToImage
::
Argument
;
float
Run
(
const
Argument
&
arg
)
{
if
(
!
(
arg
.
output_
.
GetNumOfDimension
()
==
NDimSpatial
+
3
&&
arg
.
input_
.
GetNumOfDimension
()
==
2
))
{
throw
std
::
runtime_error
(
"wrong! inconsistent dimension"
);
}
const
index_t
N
=
arg
.
output_
.
GetLengths
()[
1
];
const
index_t
C
=
arg
.
output_
.
GetLengths
()[
2
];
if
constexpr
(
NDimSpatial
==
1
)
{
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
0
];
auto
func
=
[
&
](
auto
n
)
{
for
(
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
{
index_t
row
=
n
*
Wo
+
wo
;
index_t
column
=
0
;
for
(
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
0
];
++
x
)
{
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
output_
.
GetLengths
()[
3
])
{
float
v_in
=
ck
::
type_convert
<
float
>
(
arg
.
input_
(
row
,
column
));
float
v_out
=
ck
::
type_convert
<
float
>
(
arg
.
output_
(
0
,
n
,
c
,
wi
));
arg
.
output_
(
0
,
n
,
c
,
wi
)
=
ck
::
type_convert
<
OutDataType
>
(
v_in
+
v_out
);
}
column
++
;
}
}
}
};
make_ParallelTensorFunctor
(
func
,
N
)(
std
::
thread
::
hardware_concurrency
());
return
0
;
}
else
if
constexpr
(
NDimSpatial
==
2
)
{
const
index_t
Ho
=
arg
.
output_spatial_lengths_
[
0
];
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
1
];
auto
func
=
[
&
](
auto
n
)
{
for
(
index_t
ho
=
0
;
ho
<
Ho
;
++
ho
)
{
for
(
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
{
index_t
row
=
n
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
index_t
column
=
0
;
for
(
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
0
];
++
y
)
{
auto
hi
=
static_cast
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
1
];
++
x
)
{
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
1
])
+
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
for
(
index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
hi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
arg
.
output_
.
GetLengths
()[
3
]
&&
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
output_
.
GetLengths
()[
4
])
{
float
v_in
=
ck
::
type_convert
<
float
>
(
arg
.
input_
(
row
,
column
));
float
v_out
=
ck
::
type_convert
<
float
>
(
arg
.
output_
(
0
,
n
,
c
,
hi
,
wi
));
arg
.
output_
(
0
,
n
,
c
,
hi
,
wi
)
=
ck
::
type_convert
<
OutDataType
>
(
v_in
+
v_out
);
}
column
++
;
}
}
}
}
}
};
make_ParallelTensorFunctor
(
func
,
N
)(
std
::
thread
::
hardware_concurrency
());
return
0
;
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
const
index_t
Do
=
arg
.
output_spatial_lengths_
[
0
];
const
index_t
Ho
=
arg
.
output_spatial_lengths_
[
1
];
const
index_t
Wo
=
arg
.
output_spatial_lengths_
[
2
];
auto
func
=
[
&
](
auto
n
)
{
for
(
index_t
d_o
=
0
;
d_o
<
Do
;
++
d_o
)
{
for
(
index_t
ho
=
0
;
ho
<
Ho
;
++
ho
)
{
for
(
index_t
wo
=
0
;
wo
<
Wo
;
++
wo
)
{
index_t
row
=
n
*
Do
*
Ho
*
Wo
+
d_o
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
index_t
column
=
0
;
for
(
index_t
z
=
0
;
z
<
arg
.
filter_spatial_lengths_
[
0
];
++
z
)
{
auto
di
=
static_cast
<
ck
::
long_index_t
>
(
d_o
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
z
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
index_t
y
=
0
;
y
<
arg
.
filter_spatial_lengths_
[
1
];
++
y
)
{
auto
hi
=
static_cast
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
1
])
+
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
for
(
index_t
x
=
0
;
x
<
arg
.
filter_spatial_lengths_
[
2
];
++
x
)
{
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
2
])
+
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
2
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
2
]);
for
(
index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
di
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
di
)
<
arg
.
output_
.
GetLengths
()[
3
]
&&
hi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
arg
.
output_
.
GetLengths
()[
4
]
&&
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
output_
.
GetLengths
()[
5
])
{
float
v_in
=
ck
::
type_convert
<
float
>
(
arg
.
input_
(
row
,
column
));
float
v_out
=
ck
::
type_convert
<
float
>
(
arg
.
output_
(
0
,
n
,
c
,
di
,
hi
,
wi
));
arg
.
output_
(
0
,
n
,
c
,
di
,
hi
,
wi
)
=
ck
::
type_convert
<
OutDataType
>
(
v_in
+
v_out
);
}
column
++
;
}
}
}
}
}
}
}
};
make_ParallelTensorFunctor
(
func
,
N
)(
std
::
thread
::
hardware_concurrency
());
return
0
;
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/*stream_config*/
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
using
namespace
tensor_layout
::
convolution
;
if
constexpr
(
!
(
std
::
is_same_v
<
ImageLayout
,
GNWC
>
||
std
::
is_same_v
<
ImageLayout
,
GNHWC
>
||
std
::
is_same_v
<
ImageLayout
,
GNDHWC
>
))
{
return
false
;
}
if
constexpr
(
!
(
NDimSpatial
>=
1
&&
NDimSpatial
<=
3
))
{
return
false
;
}
return
true
;
}
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
const
ck
::
index_t
G
=
arg
.
output_
.
GetLengths
()[
0
];
const
ck
::
index_t
N
=
arg
.
output_
.
GetLengths
()[
1
];
const
ck
::
index_t
C
=
arg
.
output_
.
GetLengths
()[
2
];
const
index_t
NDoHoWo
=
N
*
ck
::
accumulate_n
<
index_t
>
(
arg
.
output_spatial_lengths_
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
index_t
CZYX
=
C
*
ck
::
accumulate_n
<
index_t
>
(
arg
.
filter_spatial_lengths_
.
begin
(),
NDimSpatial
,
1
,
std
::
multiplies
<>
());
if
(
!
(
arg
.
input_
.
GetLengths
()[
0
]
==
static_cast
<
std
::
size_t
>
(
NDoHoWo
)
&&
arg
.
input_
.
GetLengths
()[
1
]
==
static_cast
<
std
::
size_t
>
(
CZYX
)))
{
return
false
;
}
if
(
G
!=
1
)
{
return
false
;
}
return
true
;
}
bool
IsSupportedArgument
(
const
device
::
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
Tensor
<
InDataType
>&
input
,
Tensor
<
OutDataType
>&
output
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
{
return
Argument
{
input
,
output
,
filter_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
virtual
std
::
unique_ptr
<
device
::
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"ReferenceColumnToImage"
<<
std
::
endl
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace host
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp
View file @
9f8ab221
...
...
@@ -25,6 +25,8 @@ template <ck::index_t NDimSpatial,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
typename
ComputeTypeA
=
OutDataType
,
typename
ComputeTypeB
=
InDataType
,
typename
std
::
enable_if
<
NDimSpatial
>
=
1
&&
NDimSpatial
<=
3
,
bool
>::
type
=
false
>
struct
ReferenceConvBwdWeight
:
public
device
::
BaseOperator
{
...
...
@@ -98,8 +100,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
if
(
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
GetLengths
()[
3
])
{
float
v_out
;
float
v_in
;
ComputeTypeA
v_out
;
ComputeTypeB
v_in
;
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
float
>
(
arg
.
output_
(
g
,
n
,
k
,
wo
)));
...
...
@@ -107,7 +109,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
arg
.
in_element_op_
(
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
input_
(
g
,
n
,
c
,
wi
)));
v_acc
+=
v_out
*
v_in
;
v_acc
+=
type_convert
<
float
>
(
v_out
)
*
type_convert
<
float
>
(
v_in
)
;
}
}
}
...
...
@@ -158,8 +160,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
GetLengths
()[
4
])
{
float
v_out
;
float
v_in
;
ComputeTypeA
v_out
;
ComputeTypeB
v_in
;
arg
.
out_element_op_
(
v_out
,
...
...
@@ -168,7 +170,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
arg
.
in_element_op_
(
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
input_
(
g
,
n
,
c
,
hi
,
wi
)));
v_acc
+=
v_out
*
v_in
;
v_acc
+=
type_convert
<
float
>
(
v_out
)
*
type_convert
<
float
>
(
v_in
)
;
}
}
}
...
...
@@ -226,8 +228,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
GetLengths
()[
5
])
{
float
v_out
;
float
v_in
;
ComputeTypeA
v_out
;
ComputeTypeB
v_in
;
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
float
>
(
...
...
@@ -237,7 +239,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
ck
::
type_convert
<
float
>
(
arg
.
input_
(
g
,
n
,
c
,
di
,
hi
,
wi
)));
v_acc
+=
v_out
*
v_in
;
v_acc
+=
type_convert
<
float
>
(
v_out
)
*
type_convert
<
float
>
(
v_in
);
}
}
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
View file @
9f8ab221
...
...
@@ -128,11 +128,9 @@ struct ReferenceConvFwd : public device::BaseOperator
}
}
float
v_out
;
arg
.
out_element_op_
(
v_out
,
v_acc
);
arg
.
output_
(
g
,
n
,
k
,
wo
)
=
ck
::
type_convert
<
OutDataType
>
(
v_out
);
OutDataType
v_out
;
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
OutDataType
>
(
v_acc
));
arg
.
output_
(
g
,
n
,
k
,
wo
)
=
v_out
;
};
make_ParallelTensorFunctor
(
func
,
...
...
@@ -184,11 +182,9 @@ struct ReferenceConvFwd : public device::BaseOperator
}
}
float
v_out
;
arg
.
out_element_op_
(
v_out
,
v_acc
);
arg
.
output_
(
g
,
n
,
k
,
ho
,
wo
)
=
ck
::
type_convert
<
OutDataType
>
(
v_out
);
OutDataType
v_out
;
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
OutDataType
>
(
v_acc
));
arg
.
output_
(
g
,
n
,
k
,
ho
,
wo
)
=
v_out
;
};
make_ParallelTensorFunctor
(
func
,
...
...
@@ -253,11 +249,9 @@ struct ReferenceConvFwd : public device::BaseOperator
}
}
float
v_out
;
arg
.
out_element_op_
(
v_out
,
v_acc
);
arg
.
output_
(
g
,
n
,
k
,
d_o
,
ho
,
wo
)
=
ck
::
type_convert
<
OutDataType
>
(
v_out
);
OutDataType
v_out
;
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
OutDataType
>
(
v_acc
));
arg
.
output_
(
g
,
n
,
k
,
d_o
,
ho
,
wo
)
=
v_out
;
};
make_ParallelTensorFunctor
(
func
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
9f8ab221
...
...
@@ -21,7 +21,8 @@ template <typename ADataType,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
ComputType
=
ADataType
>
typename
ComputeTypeA
=
ADataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
ReferenceGemm
:
public
device
::
BaseOperator
{
// Argument
...
...
@@ -65,8 +66,8 @@ struct ReferenceGemm : public device::BaseOperator
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
ComputType
v_a
;
ComputType
v_b
;
Comput
e
Type
A
v_a
;
Comput
e
Type
B
v_b
;
// use PassThrough instead of ConvertBF16RTN for reference calculation
if
constexpr
(
is_same_v
<
AElementwiseOperation
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp
View file @
9f8ab221
...
...
@@ -20,8 +20,9 @@ template <typename XDataType,
typename
GammaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
AccElementwiseOperation
>
typename
SaveMeanInvStdDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
>
struct
ReferenceGroupnorm
:
public
device
::
BaseOperator
{
// x = [N, H, W, G, C]
...
...
@@ -35,14 +36,18 @@ struct ReferenceGroupnorm : public device::BaseOperator
const
Tensor
<
GammaDataType
>&
gamma
,
const
Tensor
<
BetaDataType
>&
beta
,
Tensor
<
YDataType
>&
y
,
AccElementwiseOperation
acc_elementwise_op
,
Tensor
<
SaveMeanInvStdDataType
>&
save_mean
,
Tensor
<
SaveMeanInvStdDataType
>&
save_inv_std
,
YElementwiseOperation
y_elementwise_op
,
const
std
::
vector
<
index_t
>
lengths
,
Acc
DataType
epsilon
)
Compute
DataType
epsilon
)
:
x_
(
x
),
gamma_
(
gamma
),
beta_
(
beta
),
y_
(
y
),
acc_elementwise_op_
(
acc_elementwise_op
),
save_mean_
(
save_mean
),
save_inv_std_
(
save_inv_std
),
y_elementwise_op_
(
y_elementwise_op
),
lengths_
(
lengths
),
epsilon_
(
epsilon
)
{
...
...
@@ -52,9 +57,11 @@ struct ReferenceGroupnorm : public device::BaseOperator
const
Tensor
<
XDataType
>
gamma_
;
const
Tensor
<
XDataType
>
beta_
;
Tensor
<
YDataType
>&
y_
;
AccElementwiseOperation
acc_elementwise_op_
;
Tensor
<
SaveMeanInvStdDataType
>&
save_mean_
;
Tensor
<
SaveMeanInvStdDataType
>&
save_inv_std_
;
YElementwiseOperation
y_elementwise_op_
;
std
::
vector
<
index_t
>
lengths_
;
Acc
DataType
epsilon_
;
Compute
DataType
epsilon_
;
};
// Invoker
...
...
@@ -68,8 +75,8 @@ struct ReferenceGroupnorm : public device::BaseOperator
int
G
=
arg
.
lengths_
[
3
];
int
C
=
arg
.
lengths_
[
4
];
Tensor
<
Acc
DataType
>
mean
({
N
,
G
});
Tensor
<
Acc
DataType
>
var
({
N
,
G
});
Tensor
<
Compute
DataType
>
mean
({
N
,
G
});
Tensor
<
Compute
DataType
>
var
({
N
,
G
});
// Compute mean & var in [H, W, C] by Welford Algorithm
// TODO - parallel for each HWC
...
...
@@ -78,9 +85,9 @@ struct ReferenceGroupnorm : public device::BaseOperator
{
for
(
int
g
=
0
;
g
<
G
;
++
g
)
{
Acc
DataType
mean_val
=
type_convert
<
Acc
DataType
>
(
0.0
f
);
Acc
DataType
var_val
=
type_convert
<
Acc
DataType
>
(
0.0
f
);
int32_t
curr_count
=
0
;
Compute
DataType
mean_val
=
type_convert
<
Compute
DataType
>
(
0.0
f
);
Compute
DataType
var_val
=
type_convert
<
Compute
DataType
>
(
0.0
f
);
int32_t
curr_count
=
0
;
for
(
int
h
=
0
;
h
<
H
;
++
h
)
{
...
...
@@ -89,10 +96,11 @@ struct ReferenceGroupnorm : public device::BaseOperator
for
(
int
c
=
0
;
c
<
C
;
++
c
)
{
curr_count
++
;
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
x_
(
n
,
h
,
w
,
g
,
c
));
AccDataType
delta
=
x
-
mean_val
;
ComputeDataType
x
=
type_convert
<
ComputeDataType
>
(
arg
.
x_
(
n
,
h
,
w
,
g
,
c
));
ComputeDataType
delta
=
x
-
mean_val
;
mean_val
+=
delta
/
curr_count
;
Acc
DataType
delta2
=
x
-
mean_val
;
Compute
DataType
delta2
=
x
-
mean_val
;
var_val
+=
delta
*
delta2
;
}
}
...
...
@@ -100,6 +108,12 @@ struct ReferenceGroupnorm : public device::BaseOperator
mean
(
n
,
g
)
=
mean_val
;
var
(
n
,
g
)
=
var_val
/
curr_count
;
arg
.
save_mean_
(
n
,
g
)
=
ck
::
type_convert
<
SaveMeanInvStdDataType
>
(
mean
(
n
,
g
));
ComputeDataType
divisor
=
static_cast
<
ComputeDataType
>
(
1
)
/
ck
::
math
::
sqrt
(
var
(
n
,
g
)
+
arg
.
epsilon_
);
arg
.
save_inv_std_
(
n
,
g
)
=
ck
::
type_convert
<
SaveMeanInvStdDataType
>
(
divisor
);
}
}
...
...
@@ -114,15 +128,19 @@ struct ReferenceGroupnorm : public device::BaseOperator
{
for
(
int
c
=
0
;
c
<
C
;
++
c
)
{
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
x_
(
n
,
h
,
w
,
g
,
c
));
AccDataType
gamma
=
type_convert
<
AccDataType
>
(
arg
.
gamma_
(
g
,
c
));
AccDataType
beta
=
type_convert
<
AccDataType
>
(
arg
.
beta_
(
g
,
c
));
AccDataType
mean_val
=
type_convert
<
AccDataType
>
(
mean
(
n
,
g
));
AccDataType
var_val
=
type_convert
<
AccDataType
>
(
var
(
n
,
g
));
AccDataType
y
=
gamma
*
(
x
-
mean_val
)
/
ck
::
math
::
sqrt
(
arg
.
epsilon_
+
var_val
)
+
beta
;
arg
.
acc_elementwise_op_
(
y
,
y
);
ComputeDataType
x
=
type_convert
<
ComputeDataType
>
(
arg
.
x_
(
n
,
h
,
w
,
g
,
c
));
ComputeDataType
gamma
=
type_convert
<
ComputeDataType
>
(
arg
.
gamma_
(
g
,
c
));
ComputeDataType
beta
=
type_convert
<
ComputeDataType
>
(
arg
.
beta_
(
g
,
c
));
ComputeDataType
mean_val
=
type_convert
<
ComputeDataType
>
(
mean
(
n
,
g
));
ComputeDataType
var_val
=
type_convert
<
ComputeDataType
>
(
var
(
n
,
g
));
ComputeDataType
y
=
gamma
*
(
x
-
mean_val
)
/
ck
::
math
::
sqrt
(
arg
.
epsilon_
+
var_val
)
+
beta
;
arg
.
y_elementwise_op_
(
y
,
y
);
arg
.
y_
(
n
,
h
,
w
,
g
,
c
)
=
type_convert
<
YDataType
>
(
y
);
}
}
...
...
@@ -159,11 +177,14 @@ struct ReferenceGroupnorm : public device::BaseOperator
const
Tensor
<
GammaDataType
>&
gamma
,
const
Tensor
<
BetaDataType
>&
beta
,
Tensor
<
YDataType
>&
y
,
AccElementwiseOperation
acc_elementwise_op
,
Tensor
<
SaveMeanInvStdDataType
>&
save_mean
,
Tensor
<
SaveMeanInvStdDataType
>&
save_inv_std
,
YElementwiseOperation
y_elementwise_op
,
const
std
::
vector
<
index_t
>
lengths
,
Acc
DataType
epsilon
)
Compute
DataType
epsilon
)
{
return
Argument
{
x
,
gamma
,
beta
,
y
,
acc_elementwise_op
,
lengths
,
epsilon
};
return
Argument
{
x
,
gamma
,
beta
,
y
,
save_mean
,
save_inv_std
,
y_elementwise_op
,
lengths
,
epsilon
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp
View file @
9f8ab221
...
...
@@ -18,16 +18,18 @@ namespace host {
/**
* \brief Reference implementation for image to column.
*
*
T
ensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
*
Input t
ensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C].
* Output tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
* Memory layout is the same.
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam I
nput
Layout I
nput
Layout.
* \tparam I
mage
Layout I
mage
Layout.
* \tparam InDataType Input Data Type.
* \tparam OutDataType Output Data Type.
*/
template
<
ck
::
index_t
NDimSpatial
,
typename
I
nput
Layout
,
typename
I
mage
Layout
,
typename
InDataType
,
typename
OutDataType
,
typename
std
::
enable_if
<
NDimSpatial
>
=
1
&&
NDimSpatial
<=
3
,
bool
>::
type
=
false
>
...
...
@@ -240,8 +242,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
{
using
namespace
tensor_layout
::
convolution
;
if
constexpr
(
!
(
std
::
is_same_v
<
I
nput
Layout
,
GNWC
>
||
std
::
is_same_v
<
I
nput
Layout
,
GNHWC
>
||
std
::
is_same_v
<
I
nput
Layout
,
GNDHWC
>
))
if
constexpr
(
!
(
std
::
is_same_v
<
I
mage
Layout
,
GNWC
>
||
std
::
is_same_v
<
I
mage
Layout
,
GNHWC
>
||
std
::
is_same_v
<
I
mage
Layout
,
GNDHWC
>
))
{
return
false
;
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp
View file @
9f8ab221
...
...
@@ -20,8 +20,9 @@ template <typename XDataType,
typename
GammaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
AccElementwiseOperation
,
typename
SaveMeanInvStdDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
index_t
Rank
,
index_t
NumReduceDim
>
struct
ReferenceLayernorm
:
public
device
::
BaseOperator
...
...
@@ -36,15 +37,19 @@ struct ReferenceLayernorm : public device::BaseOperator
const
Tensor
<
GammaDataType
>&
gamma_n
,
const
Tensor
<
BetaDataType
>&
beta_n
,
Tensor
<
YDataType
>&
y_m_n
,
AccElementwiseOperation
acc_elementwise_op
,
Tensor
<
SaveMeanInvStdDataType
>&
save_mean_m
,
Tensor
<
SaveMeanInvStdDataType
>&
save_inv_std_m
,
YElementwiseOperation
y_elementwise_op
,
const
std
::
vector
<
index_t
>
lengths
,
const
std
::
vector
<
index_t
>
reduceDims
,
Acc
DataType
epsilon
)
Compute
DataType
epsilon
)
:
x_m_n_
(
x_m_n
),
gamma_n_
(
gamma_n
),
beta_n_
(
beta_n
),
y_m_n_
(
y_m_n
),
acc_elementwise_op_
(
acc_elementwise_op
),
save_mean_m_
(
save_mean_m
),
save_inv_std_m_
(
save_inv_std_m
),
y_elementwise_op_
(
y_elementwise_op
),
lengths_
(
lengths
),
reduceDims_
(
reduceDims
),
epsilon_
(
epsilon
)
...
...
@@ -55,10 +60,12 @@ struct ReferenceLayernorm : public device::BaseOperator
const
Tensor
<
XDataType
>
gamma_n_
;
const
Tensor
<
XDataType
>
beta_n_
;
Tensor
<
YDataType
>&
y_m_n_
;
AccElementwiseOperation
acc_elementwise_op_
;
Tensor
<
SaveMeanInvStdDataType
>&
save_mean_m_
;
Tensor
<
SaveMeanInvStdDataType
>&
save_inv_std_m_
;
YElementwiseOperation
y_elementwise_op_
;
std
::
vector
<
index_t
>
lengths_
;
std
::
vector
<
index_t
>
reduceDims_
;
Acc
DataType
epsilon_
;
Compute
DataType
epsilon_
;
};
// Invoker
...
...
@@ -69,8 +76,8 @@ struct ReferenceLayernorm : public device::BaseOperator
int
M
=
arg
.
lengths_
[
0
];
int
N
=
arg
.
lengths_
[
1
];
Tensor
<
Acc
DataType
>
mean
({
M
});
Tensor
<
Acc
DataType
>
var
({
M
});
Tensor
<
Compute
DataType
>
mean
({
M
});
Tensor
<
Compute
DataType
>
var
({
M
});
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
...
...
@@ -79,7 +86,7 @@ struct ReferenceLayernorm : public device::BaseOperator
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
auto
x_val
=
ck
::
type_convert
<
Acc
DataType
>
(
arg
.
x_m_n_
(
m
,
n
));
auto
x_val
=
ck
::
type_convert
<
Compute
DataType
>
(
arg
.
x_m_n_
(
m
,
n
));
mean
(
m
)
+=
x_val
;
var
(
m
)
+=
x_val
*
x_val
;
}
...
...
@@ -90,17 +97,21 @@ struct ReferenceLayernorm : public device::BaseOperator
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
Acc
DataType
divisor
=
static_cast
<
Acc
DataType
>
(
1
)
/
ck
::
math
::
sqrt
(
var
(
m
)
+
arg
.
epsilon_
);
Compute
DataType
divisor
=
static_cast
<
Compute
DataType
>
(
1
)
/
ck
::
math
::
sqrt
(
var
(
m
)
+
arg
.
epsilon_
);
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
auto
x_val
=
ck
::
type_convert
<
AccDataType
>
(
arg
.
x_m_n_
(
m
,
n
));
auto
y_val
=
(
x_val
-
mean
(
m
))
*
divisor
;
y_val
=
(
y_val
*
arg
.
gamma_n_
(
n
))
+
arg
.
beta_n_
(
n
);
arg
.
acc_elementwise_op_
(
y_val
,
y_val
);
auto
x_val
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
x_m_n_
(
m
,
n
));
auto
gamma_val
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
gamma_n_
(
n
));
auto
beta_val
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
beta_n_
(
n
));
auto
y_val
=
(
x_val
-
mean
(
m
))
*
divisor
;
y_val
=
(
y_val
*
gamma_val
)
+
beta_val
;
arg
.
y_elementwise_op_
(
y_val
,
y_val
);
arg
.
y_m_n_
(
m
,
n
)
=
ck
::
type_convert
<
YDataType
>
(
y_val
);
}
arg
.
save_mean_m_
(
m
)
=
ck
::
type_convert
<
SaveMeanInvStdDataType
>
(
mean
(
m
));
arg
.
save_inv_std_m_
(
m
)
=
ck
::
type_convert
<
SaveMeanInvStdDataType
>
(
divisor
);
}
return
0
;
...
...
@@ -140,13 +151,23 @@ struct ReferenceLayernorm : public device::BaseOperator
const
Tensor
<
GammaDataType
>&
gamma_n
,
const
Tensor
<
BetaDataType
>&
beta_n
,
Tensor
<
YDataType
>&
y_m_n
,
AccElementwiseOperation
acc_elementwise_op
,
Tensor
<
SaveMeanInvStdDataType
>&
save_mean_m
,
Tensor
<
SaveMeanInvStdDataType
>&
save_inv_std_m
,
YElementwiseOperation
y_elementwise_op
,
const
std
::
vector
<
index_t
>
lengths
,
const
std
::
vector
<
index_t
>
reduceDims
,
Acc
DataType
epsilon
)
Compute
DataType
epsilon
)
{
return
Argument
{
x_m_n
,
gamma_n
,
beta_n
,
y_m_n
,
acc_elementwise_op
,
lengths
,
reduceDims
,
epsilon
};
return
Argument
{
x_m_n
,
gamma_n
,
beta_n
,
y_m_n
,
save_mean_m
,
save_inv_std_m
,
y_elementwise_op
,
lengths
,
reduceDims
,
epsilon
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
9f8ab221
...
...
@@ -20,12 +20,8 @@ using F16 = ck::half_t;
using
BF16
=
ck
::
bhalf_t
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
#if defined CK_ENABLE_FP8
using
F8
=
ck
::
f8_t
;
#endif
#if defined CK_ENABLE_BF8
using
BF8
=
ck
::
bf8_t
;
#endif
using
F8
=
ck
::
f8_t
;
using
BF8
=
ck
::
bf8_t
;
using
Empty_Tuple
=
ck
::
Tuple
<>
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp
View file @
9f8ab221
...
...
@@ -16,26 +16,26 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
//
FP16
#ifdef CK_ENABLE_
FP16
void
add_device_batchnorm_backward_rank_4_3_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormBwd
<
F16
,
F32
,
F32
,
F32
,
F16
,
F32
,
F32
,
PassThrough
,
4
,
3
>>>&
);
//
FP32
#endif
#ifdef CK_ENABLE_
FP32
void
add_device_batchnorm_backward_rank_4_3_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormBwd
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
4
,
3
>>>&
);
//
BF16
#endif
#ifdef CK_ENABLE_
BF16
void
add_device_batchnorm_backward_rank_4_3_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormBwd
<
BF16
,
F32
,
F32
,
F32
,
BF16
,
F32
,
F32
,
PassThrough
,
4
,
3
>>>&
);
//
FP64
#endif
#ifdef CK_ENABLE_
FP64
void
add_device_batchnorm_backward_rank_4_3_f64_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormBwd
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
PassThrough
,
4
,
3
>>>&
);
#endif
template
<
typename
XDataType
,
typename
DxDataType
,
typename
DyDataType
,
...
...
@@ -72,7 +72,7 @@ struct DeviceOperationInstanceFactory<
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
XDataType
,
F16
>
&&
is_same_v
<
DxDataType
,
F32
>
&&
is_same_v
<
DyDataType
,
F32
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
F16
>
&&
is_same_v
<
DscaleDbiasDataType
,
F32
>
&&
...
...
@@ -83,37 +83,43 @@ struct DeviceOperationInstanceFactory<
add_device_batchnorm_backward_rank_4_3_f16_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
F32
>
&&
is_same_v
<
DxDataType
,
F32
>
&&
is_same_v
<
DyDataType
,
F32
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
F32
>
&&
is_same_v
<
DscaleDbiasDataType
,
F32
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
#endif
#ifdef CK_ENABLE_FP32
if
constexpr
(
is_same_v
<
XDataType
,
F32
>
&&
is_same_v
<
DxDataType
,
F32
>
&&
is_same_v
<
DyDataType
,
F32
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
F32
>
&&
is_same_v
<
DscaleDbiasDataType
,
F32
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
{
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
&&
is_same_v
<
DyElementwiseOp
,
PassThrough
>
)
{
add_device_batchnorm_backward_rank_4_3_f32_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
BF16
>
&&
is_same_v
<
DxDataType
,
F32
>
&&
is_same_v
<
DyDataType
,
F32
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
BF16
>
&&
is_same_v
<
DscaleDbiasDataType
,
F32
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
#endif
#ifdef CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
XDataType
,
BF16
>
&&
is_same_v
<
DxDataType
,
F32
>
&&
is_same_v
<
DyDataType
,
F32
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
BF16
>
&&
is_same_v
<
DscaleDbiasDataType
,
F32
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
{
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
&&
is_same_v
<
DyElementwiseOp
,
PassThrough
>
)
{
add_device_batchnorm_backward_rank_4_3_bf16_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
F64
>
&&
is_same_v
<
DxDataType
,
F64
>
&&
is_same_v
<
DyDataType
,
F64
>
&&
is_same_v
<
AccDataType
,
F64
>
&&
is_same_v
<
ScaleDataType
,
F64
>
&&
is_same_v
<
DscaleDbiasDataType
,
F64
>
&&
is_same_v
<
MeanVarDataType
,
F64
>
)
#endif
#ifdef CK_ENABLE_FP64
if
constexpr
(
is_same_v
<
XDataType
,
F64
>
&&
is_same_v
<
DxDataType
,
F64
>
&&
is_same_v
<
DyDataType
,
F64
>
&&
is_same_v
<
AccDataType
,
F64
>
&&
is_same_v
<
ScaleDataType
,
F64
>
&&
is_same_v
<
DscaleDbiasDataType
,
F64
>
&&
is_same_v
<
MeanVarDataType
,
F64
>
)
{
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
&&
is_same_v
<
DyElementwiseOp
,
PassThrough
>
)
{
add_device_batchnorm_backward_rank_4_3_f64_instances
(
op_ptrs
);
}
}
#endif
return
op_ptrs
;
}
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp
View file @
9f8ab221
...
...
@@ -16,26 +16,26 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
//
FP16
#ifdef CK_ENABLE_
FP16
void
add_device_batchnorm_forward_rank_4_3_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormFwd
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
PassThrough
,
4
,
3
>>>&
);
//
FP32
#endif
#ifdef CK_ENABLE_
FP32
void
add_device_batchnorm_forward_rank_4_3_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormFwd
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
4
,
3
>>>&
);
//
BF16
#endif
#ifdef CK_ENABLE_
BF16
void
add_device_batchnorm_forward_rank_4_3_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormFwd
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
PassThrough
,
4
,
3
>>>&
);
//
FP64
#endif
#ifdef CK_ENABLE_
FP64
void
add_device_batchnorm_forward_rank_4_3_f64_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormFwd
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
PassThrough
,
4
,
3
>>>&
);
#endif
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
...
...
@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
XDataType
,
F16
>
&&
is_same_v
<
YDataType
,
F16
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
F16
>
&&
is_same_v
<
BiasDataType
,
F16
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
...
...
@@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory<
add_device_batchnorm_forward_rank_4_3_f16_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
F32
>
&&
is_same_v
<
YDataType
,
F32
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
F32
>
&&
is_same_v
<
BiasDataType
,
F32
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
#endif
#ifdef CK_ENABLE_FP32
if
constexpr
(
is_same_v
<
XDataType
,
F32
>
&&
is_same_v
<
YDataType
,
F32
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
F32
>
&&
is_same_v
<
BiasDataType
,
F32
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
{
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
&&
is_same_v
<
YElementwiseOp
,
PassThrough
>
)
{
add_device_batchnorm_forward_rank_4_3_f32_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
BF16
>
&&
is_same_v
<
YDataType
,
BF16
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
BF16
>
&&
is_same_v
<
BiasDataType
,
BF16
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
#endif
#ifdef CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
XDataType
,
BF16
>
&&
is_same_v
<
YDataType
,
BF16
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
BF16
>
&&
is_same_v
<
BiasDataType
,
BF16
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
{
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
&&
is_same_v
<
YElementwiseOp
,
PassThrough
>
)
{
add_device_batchnorm_forward_rank_4_3_bf16_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
F64
>
&&
is_same_v
<
YDataType
,
F64
>
&&
is_same_v
<
AccDataType
,
F64
>
&&
is_same_v
<
ScaleDataType
,
F64
>
&&
is_same_v
<
BiasDataType
,
F64
>
&&
is_same_v
<
MeanVarDataType
,
F64
>
)
#endif
#ifdef CK_ENABLE_FP64
if
constexpr
(
is_same_v
<
XDataType
,
F64
>
&&
is_same_v
<
YDataType
,
F64
>
&&
is_same_v
<
AccDataType
,
F64
>
&&
is_same_v
<
ScaleDataType
,
F64
>
&&
is_same_v
<
BiasDataType
,
F64
>
&&
is_same_v
<
MeanVarDataType
,
F64
>
)
{
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
&&
is_same_v
<
YElementwiseOp
,
PassThrough
>
)
{
add_device_batchnorm_forward_rank_4_3_f64_instances
(
op_ptrs
);
}
}
#endif
return
op_ptrs
;
}
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/batchnorm_infer.hpp
View file @
9f8ab221
...
...
@@ -16,38 +16,38 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
//
FP16
#ifdef CK_ENABLE_
FP16
void
add_device_batchnorm_infer_rank_4_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
ck
::
tensor_operation
::
device
::
DeviceElementwise
<
ck
::
Tuple
<
F16
,
F32
,
F32
,
F16
,
F16
>
,
ck
::
Tuple
<
F16
>
,
ck
::
tensor_operation
::
element_wise
::
NormalizeInInfer
,
4
>>>&
);
//
FP32
#endif
#ifdef CK_ENABLE_
FP32
void
add_device_batchnorm_infer_rank_4_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
ck
::
tensor_operation
::
device
::
DeviceElementwise
<
ck
::
Tuple
<
F32
,
F32
,
F32
,
F32
,
F32
>
,
ck
::
Tuple
<
F32
>
,
ck
::
tensor_operation
::
element_wise
::
NormalizeInInfer
,
4
>>>&
);
//
BF16
#endif
#ifdef CK_ENABLE_
BF16
void
add_device_batchnorm_infer_rank_4_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
ck
::
tensor_operation
::
device
::
DeviceElementwise
<
ck
::
Tuple
<
BF16
,
F32
,
F32
,
BF16
,
BF16
>
,
ck
::
Tuple
<
BF16
>
,
ck
::
tensor_operation
::
element_wise
::
NormalizeInInfer
,
4
>>>&
);
//
FP64
#endif
#ifdef CK_ENABLE_
FP64
void
add_device_batchnorm_infer_rank_4_f64_instances
(
std
::
vector
<
std
::
unique_ptr
<
ck
::
tensor_operation
::
device
::
DeviceElementwise
<
ck
::
Tuple
<
F64
,
F64
,
F64
,
F64
,
F64
>
,
ck
::
Tuple
<
F64
>
,
ck
::
tensor_operation
::
element_wise
::
NormalizeInInfer
,
4
>>>&
);
#endif
template
<
typename
XDataType
,
typename
YDataType
,
typename
ScaleDataType
,
...
...
@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
XDataType
,
F16
>
&&
is_same_v
<
YDataType
,
F16
>
&&
is_same_v
<
ScaleDataType
,
F16
>
&&
is_same_v
<
BiasDataType
,
F16
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
...
...
@@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
add_device_batchnorm_infer_rank_4_f16_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
F32
>
&&
is_same_v
<
YDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
F32
>
&&
is_same_v
<
BiasDataType
,
F32
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
#endif
#ifdef CK_ENABLE_FP32
if
constexpr
(
is_same_v
<
XDataType
,
F32
>
&&
is_same_v
<
YDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
F32
>
&&
is_same_v
<
BiasDataType
,
F32
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
{
if
constexpr
(
Rank
==
4
)
{
add_device_batchnorm_infer_rank_4_f32_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
BF16
>
&&
is_same_v
<
YDataType
,
BF16
>
&&
is_same_v
<
ScaleDataType
,
BF16
>
&&
is_same_v
<
BiasDataType
,
BF16
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
#endif
#ifdef CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
XDataType
,
BF16
>
&&
is_same_v
<
YDataType
,
BF16
>
&&
is_same_v
<
ScaleDataType
,
BF16
>
&&
is_same_v
<
BiasDataType
,
BF16
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
{
if
constexpr
(
Rank
==
4
)
{
add_device_batchnorm_infer_rank_4_bf16_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
F64
>
&&
is_same_v
<
YDataType
,
F64
>
&&
is_same_v
<
ScaleDataType
,
F64
>
&&
is_same_v
<
BiasDataType
,
F64
>
&&
is_same_v
<
MeanVarDataType
,
F64
>
)
#endif
#ifdef CK_ENABLE_FP64
if
constexpr
(
is_same_v
<
XDataType
,
F64
>
&&
is_same_v
<
YDataType
,
F64
>
&&
is_same_v
<
ScaleDataType
,
F64
>
&&
is_same_v
<
BiasDataType
,
F64
>
&&
is_same_v
<
MeanVarDataType
,
F64
>
)
{
if
constexpr
(
Rank
==
4
)
{
add_device_batchnorm_infer_rank_4_f64_instances
(
op_ptrs
);
}
}
#endif
return
op_ptrs
;
}
};
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment