gaoqiong / composable_kernel · Commits

Commit de1afb7b
Authored Oct 19, 2023 by Rostyslav Geyyer

    Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into lwpck-977

Parents: ce562aa6, f7331c60
Changes: 107

Showing 20 changed files with 388 additions and 229 deletions (+388 -229)
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                                              (+4 -28)
include/ck/utility/amd_xdlops.hpp                                                                 (+2 -13)
include/ck/utility/data_type.hpp                                                                  (+3 -6)
include/ck/utility/inner_product.hpp                                                              (+2 -0)
include/ck/utility/math.hpp                                                                       (+0 -22)
include/ck/utility/math_v2.hpp                                                                    (+184 -8)
include/ck/utility/statically_indexed_array_multi_index.hpp                                       (+1 -0)
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp                  (+9 -15)
library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp                 (+48 -27)
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp                 (+41 -20)
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp        (+2 -6)
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp            (+16 -6)
library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp                  (+8 -7)
library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp                          (+7 -7)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp  (+36 -43)
library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp                        (+10 -7)
library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp                  (+11 -7)
library/include/ck/library/utility/check_err.hpp                                                  (+0 -4)
library/include/ck/library/utility/host_common_util.hpp                                           (+1 -1)
library/include/ck/library/utility/host_tensor_generator.hpp                                      (+3 -2)

include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp (+4 -28)

@@ -462,7 +462,6 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
     }
 };
-#if defined CK_ENABLE_FP8
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
 {
@@ -506,9 +505,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
         intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
-#endif
-#if defined CK_ENABLE_BF8
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8bf8>
 {
@@ -552,9 +549,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8>
         intrin_mfma_f32_16x16x32bf8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
-#endif
-#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8bf8>
 {
@@ -598,9 +593,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
         intrin_mfma_f32_16x16x32f8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
-#endif
-#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8f8>
 {
@@ -644,7 +637,6 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
         intrin_mfma_f32_16x16x32bf8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
-#endif
 template <typename base_type,
           index_t MPerXdlops,
@@ -792,7 +784,6 @@ struct MfmaSelector
     }
 #endif
-#if defined CK_ENABLE_FP8
     template <>
     static constexpr auto GetMfma<f8_t, 32, 32>()
     {
@@ -804,9 +795,7 @@ struct MfmaSelector
     {
         return MfmaInstr::mfma_f32_16x16x32f8f8;
     }
-#endif
-#if defined CK_ENABLE_BF8
     template <>
     static constexpr auto GetMfma<bf8_t, 32, 32>()
     {
@@ -818,9 +807,7 @@ struct MfmaSelector
     {
         return MfmaInstr::mfma_f32_16x16x32bf8bf8;
     }
-#endif
-#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
     template <>
     static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
     {
@@ -832,9 +819,7 @@ struct MfmaSelector
     {
         return MfmaInstr::mfma_f32_16x16x32f8bf8;
    }
-#endif
-#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
     template <>
     static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
     {
@@ -846,7 +831,6 @@ struct MfmaSelector
     {
         return MfmaInstr::mfma_f32_16x16x32bf8f8;
     }
-#endif
     static constexpr auto selected_mfma =
         mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
@@ -1051,18 +1035,10 @@ struct XdlopsGemm
         static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
                           is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
-                          is_same<base_type, int8_t>::value
-#if defined CK_ENABLE_FP8
-                          || is_same<base_type, f8_t>::value
-#endif
-#if defined CK_ENABLE_BF8
-                          || is_same<base_type, bf8_t>::value
-#endif
-#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
-                          || (is_same<base_type, f8_t>::value &&
-                              is_same<additional_type, bf8_t>::value)
-                          || (is_same<base_type, bf8_t>::value &&
-                              is_same<additional_type, f8_t>::value)
-#endif
-                          ,
+                          is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value ||
+                          is_same<base_type, bf8_t>::value ||
+                          (is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value) ||
+                          (is_same<base_type, bf8_t>::value && is_same<additional_type, f8_t>::value),
                       "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
         static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
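
The pattern running through this file: mfma_type is specialized once per MfmaInstr enumerator, and MfmaSelector::GetMfma maps a data type and tile size to an enumerator at compile time, so selected_mfma resolves to the right instruction traits with no runtime dispatch. A reduced, self-contained sketch of that mechanism (two instructions and simplified fields only; these are not CK's actual definitions):

// Compile-time instruction dispatch via an enum + per-enumerator traits
// specializations + a constexpr selector (simplified stand-in for CK's code).
enum class MfmaInstr
{
    mfma_f32_32x32x16f8f8,
    mfma_f32_16x16x32f8f8
};

template <MfmaInstr instr>
struct mfma_type; // primary template intentionally left undefined

template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{
    static constexpr int m_per_blk = 32, n_per_blk = 32, k_per_blk = 16;
};

template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
{
    static constexpr int m_per_blk = 16, n_per_blk = 16, k_per_blk = 32;
};

template <int MPerXdlops, int NPerXdlops>
struct MfmaSelector
{
    // maps the tile shape to an instruction at compile time
    static constexpr MfmaInstr GetMfma()
    {
        return (MPerXdlops == 32 && NPerXdlops == 32)
                   ? MfmaInstr::mfma_f32_32x32x16f8f8
                   : MfmaInstr::mfma_f32_16x16x32f8f8;
    }

    static constexpr auto selected_mfma = mfma_type<GetMfma()>{};
};

static_assert(MfmaSelector<32, 32>::selected_mfma.k_per_blk == 16);
static_assert(MfmaSelector<16, 16>::selected_mfma.k_per_blk == 32);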

include/ck/utility/amd_xdlops.hpp (+2 -13)

 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
-#ifndef CK_AMD_XDLOPS_HPP
-#define CK_AMD_XDLOPS_HPP
+#pragma once
+
+#include "data_type.hpp"
 namespace ck {
@@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
     }
 };
-#if defined CK_ENABLE_FP8
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x16f8f8;
@@ -418,9 +414,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
 #endif
     }
 };
-#endif
-#if defined CK_ENABLE_BF8
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x16bf8bf8;
@@ -483,9 +477,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
 #endif
     }
 };
-#endif
-#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x16f8bf8;
@@ -548,9 +540,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
 #endif
     }
 };
-#endif
-#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x16bf8f8;
@@ -613,6 +603,5 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
 #endif
     }
 };
-#endif
 } // namespace ck
-#endif

include/ck/utility/data_type.hpp (+3 -6)

@@ -9,11 +9,9 @@ namespace ck {
 using bhalf_t = ushort;
 using half_t  = _Float16;
-#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-using int4_t = _BitInt(4);
-#endif
-using f8_t  = _BitInt(8);
-using bf8_t = unsigned _BitInt(8);
+using int4_t = _BitInt(4);
+using f8_t   = _BitInt(8);
+using bf8_t  = unsigned _BitInt(8);
 // vector_type
 template <typename T, index_t N>
@@ -1123,5 +1121,4 @@ struct NumericUtils<bf8_t>
     static constexpr int bias = 16; // negative zero nan mode
     // static constexpr int bias = 15; // ieee mode
 };
-
 } // namespace ck

include/ck/utility/inner_product.hpp (+2 -0)

@@ -192,6 +192,8 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
 #else
     c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
 #endif
+#elif defined(CK_USE_AMD_V_DOT4_I32_I8_GFX11)
+    c = __builtin_amdgcn_sudot4(true, bit_cast<int32_t>(a), true, bit_cast<int32_t>(b), c, false);
 #else
     const vector_type<int8_t, 4> a_vector{a};
     const vector_type<int8_t, 4> b_vector{b};
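
The new CK_USE_AMD_V_DOT4_I32_I8_GFX11 branch routes gfx11 targets to __builtin_amdgcn_sudot4; both builtins accumulate a packed 4-element int8 dot product into a 32-bit integer in a single instruction per lane. For reference, the arithmetic they perform, written out in plain C++ (a sketch of the semantics, not the intrinsic itself):

#include <cstdint>

// Reference semantics of the v_dot4 path: multiply four int8 pairs,
// sum the products, and accumulate into a 32-bit integer.
inline int32_t dot4_i8_reference(const int8_t a[4], const int8_t b[4], int32_t c)
{
    for(int i = 0; i < 4; ++i)
    {
        c += int32_t(a[i]) * int32_t(b[i]); // widen before multiply to avoid overflow
    }
    return c;
}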

include/ck/utility/math.hpp (+0 -22)

@@ -150,28 +150,6 @@ __host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T&
     return min(max(x, lowerbound), upperbound);
 }
-// disallow implicit type casting
-template <typename T>
-__device__ T exp(T x);
-
-// TODO: add f16 support using v_exp_f16
-template <>
-__device__ float exp<float>(float x)
-{
-    return __expf(x);
-}
-
-template <>
-__device__ double exp<double>(double x)
-{
-    return exp(x);
-}
-
-static inline __host__ float exp(float x) { return std::expf(x); }
-
-static inline __host__ double exp(double x) { return std::exp(x); }
-
 // greatest common divisor, aka highest common factor
 __host__ __device__ constexpr index_t gcd(index_t x, index_t y)
 {

include/ck/utility/math_v2.hpp (+184 -8)

@@ -9,6 +9,7 @@
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/type.hpp"
+#include "ck/utility/type_convert.hpp"
 namespace ck {
 namespace math {
@@ -92,14 +93,96 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); };
 static inline __host__ double sqrt(double x) { return std::sqrt(x); };
-static inline __host__ half_t tanh(half_t x)
-{
-    return static_cast<half_t>(std::tanh(static_cast<float>(x)));
-};
-static inline __host__ float tanh(float x) { return std::tanh(x); };
-static inline __host__ double tanh(double x) { return std::tanh(x); };
+template <typename T>
+inline __host__ T tanh(T x)
+{
+    return ck::type_convert<T>(std::tanhf(ck::type_convert<float>(x)));
+};
+
+template <>
+inline __host__ float tanh<float>(float x)
+{
+    return std::tanhf(x);
+};
+
+template <>
+inline __host__ double tanh<double>(double x)
+{
+    return std::tanh(x);
+};
+
+template <typename T>
+inline __host__ T exp(T x)
+{
+    return ck::type_convert<T>(std::expf(ck::type_convert<float>(x)));
+}
+
+template <>
+inline __host__ float exp<float>(float x)
+{
+    return std::expf(x);
+}
+
+template <>
+inline __host__ double exp<double>(double x)
+{
+    return std::exp(x);
+}
+
+template <typename T>
+inline __host__ T log(T x)
+{
+    return ck::type_convert<T>(std::logf(ck::type_convert<float>(x)));
+}
+
+template <>
+inline __host__ float log<float>(float x)
+{
+    return std::logf(x);
+}
+
+template <>
+inline __host__ double log<double>(double x)
+{
+    return std::log(x);
+}
+
+template <typename T>
+inline __host__ T pow(T x, T gamma)
+{
+    return ck::type_convert<T>(
+        std::powf(ck::type_convert<float>(x), ck::type_convert<float>(gamma)));
+}
+
+template <>
+inline __host__ float pow<float>(float x, float gamma)
+{
+    return std::powf(x, gamma);
+}
+
+template <>
+inline __host__ double pow<double>(double x, double gamma)
+{
+    return std::pow(x, gamma);
+}
+
+template <typename T>
+inline __host__ T expm1(T x)
+{
+    return ck::type_convert<T>(std::expm1f(ck::type_convert<float>(x)));
+}
+
+template <>
+inline __host__ float expm1<float>(float x)
+{
+    return std::expm1f(x);
+}
+
+template <>
+inline __host__ double expm1<double>(double x)
+{
+    return std::expm1(x);
+}
 // math functions for the HIP kernel, some are implemented by calling hip builtin functions
@@ -181,14 +264,107 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x);
 static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
-static inline __device__ half_t tanh(half_t x)
-{
-    return static_cast<half_t>(::tanhf(static_cast<float>(x)));
-};
-static inline __device__ float tanh(float x) { return ::tanhf(x); };
-static inline __device__ double tanh(double x) { return ::tanh(x); };
+template <typename T>
+inline __device__ T tanh(T x)
+{
+    return ck::type_convert<T>(::tanhf(ck::type_convert<float>(x)));
+};
+
+template <>
+inline __device__ float tanh<float>(float x)
+{
+    return ::tanhf(x);
+};
+
+template <>
+inline __device__ double tanh<double>(double x)
+{
+    return ::tanh(x);
+};
+
+template <typename T>
+inline __device__ T exp(T x)
+{
+    return ck::type_convert<T>(__expf(ck::type_convert<float>(x)));
+};
+
+template <>
+inline __device__ half_t exp<half_t>(half_t x)
+{
+    return hexp(x);
+};
+
+template <>
+inline __device__ float exp<float>(float x)
+{
+    return __expf(x);
+};
+
+template <>
+inline __device__ double exp<double>(double x)
+{
+    return exp(x);
+};
+
+template <typename T>
+inline __device__ T log(T x)
+{
+    return ck::type_convert<T>(__logf(ck::type_convert<float>(x)));
+};
+
+template <>
+inline __device__ half_t log<half_t>(half_t x)
+{
+    return hlog(x);
+};
+
+template <>
+inline __device__ float log<float>(float x)
+{
+    return __logf(x);
+};
+
+template <>
+inline __device__ double log<double>(double x)
+{
+    return log(x);
+};
+
+template <typename T>
+inline __device__ T pow(T x, T gamma)
+{
+    return ck::type_convert<T>(powf(ck::type_convert<float>(x), ck::type_convert<float>(gamma)));
+};
+
+template <>
+inline __device__ float pow<float>(float x, float gamma)
+{
+    return powf(x, gamma);
+};
+
+template <>
+inline __device__ double pow<double>(double x, double gamma)
+{
+    return pow(x, gamma);
+};
+
+template <typename T>
+inline __device__ T expm1(T x)
+{
+    return ck::type_convert<T>(expm1f(ck::type_convert<float>(x)));
+};
+
+template <>
+inline __device__ float expm1<float>(float x)
+{
+    return expm1f(x);
+};
+
+template <>
+inline __device__ double expm1<double>(double x)
+{
+    return expm1(x);
+};
 } // namespace math
 } // namespace ck
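
Taken together with the math.hpp hunk above, the net effect is that the ad-hoc exp overloads move out of math.hpp and math_v2.hpp gains one uniform template family: any type without native <cmath>/HIP overloads is computed through float via ck::type_convert, while float, double, and (on device) half_t keep their native or builtin paths as explicit specializations. A self-contained sketch of that shape, where half_like stands in for ck::half_t and plain member access stands in for ck::type_convert:

#include <cmath>
#include <cstdio>

namespace sketch {

struct half_like
{
    float v; // pretend this is 16-bit storage
};

// generic path: convert to float, use the float routine, convert back
template <typename T>
T tanh(T x)
{
    return T{std::tanh(x.v)};
}

// float keeps its native <cmath> call, as in the real header
template <>
float tanh<float>(float x)
{
    return std::tanh(x);
}

} // namespace sketch

int main()
{
    sketch::half_like h{0.5f};
    std::printf("%f %f\n", sketch::tanh(h).v, sketch::tanh(0.5f));
    return 0;
}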

include/ck/utility/statically_indexed_array_multi_index.hpp (+1 -0)

@@ -5,6 +5,7 @@
 #define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
 #include "common_header.hpp"
+#include "ck/utility/math_v2.hpp"
 namespace ck {

library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp (+9 -15)

@@ -128,11 +128,9 @@ struct ReferenceConvFwd : public device::BaseOperator
                     }
                 }
-                float v_out;
-                arg.out_element_op_(v_out, ck::type_convert<OutDataType>(v_acc));
-                arg.output_(g, n, k, wo) = v_out;
+                OutDataType v_out;
+                arg.out_element_op_(v_out, v_acc);
+                arg.output_(g, n, k, wo) = ck::type_convert<OutDataType>(v_out);
             };
             make_ParallelTensorFunctor(func,
@@ -184,11 +182,9 @@ struct ReferenceConvFwd : public device::BaseOperator
                     }
                 }
-                float v_out;
-                arg.out_element_op_(v_out, ck::type_convert<OutDataType>(v_acc));
-                arg.output_(g, n, k, ho, wo) = v_out;
+                OutDataType v_out;
+                arg.out_element_op_(v_out, v_acc);
+                arg.output_(g, n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
             };
             make_ParallelTensorFunctor(func,
@@ -253,11 +249,9 @@ struct ReferenceConvFwd : public device::BaseOperator
                     }
                 }
-                float v_out;
-                arg.out_element_op_(v_out, ck::type_convert<OutDataType>(v_acc));
-                arg.output_(g, n, k, d_o, ho, wo) = v_out;
+                OutDataType v_out;
+                arg.out_element_op_(v_out, v_acc);
+                arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
             };
             make_ParallelTensorFunctor(func,

library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp (+48 -27)

@@ -20,8 +20,9 @@ template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
           typename YDataType,
-          typename AccDataType,
-          typename AccElementwiseOperation>
+          typename SaveMeanInvStdDataType,
+          typename ComputeDataType,
+          typename YElementwiseOperation>
 struct ReferenceGroupnorm : public device::BaseOperator
 {
     // x = [N, H, W, G, C]
@@ -35,14 +36,18 @@ struct ReferenceGroupnorm : public device::BaseOperator
                  const Tensor<GammaDataType>& gamma,
                  const Tensor<BetaDataType>& beta,
                  Tensor<YDataType>& y,
-                 AccElementwiseOperation acc_elementwise_op,
+                 Tensor<SaveMeanInvStdDataType>& save_mean,
+                 Tensor<SaveMeanInvStdDataType>& save_inv_std,
+                 YElementwiseOperation y_elementwise_op,
                  const std::vector<index_t> lengths,
-                 AccDataType epsilon)
+                 ComputeDataType epsilon)
             : x_(x),
               gamma_(gamma),
              beta_(beta),
               y_(y),
-              acc_elementwise_op_(acc_elementwise_op),
+              save_mean_(save_mean),
+              save_inv_std_(save_inv_std),
+              y_elementwise_op_(y_elementwise_op),
               lengths_(lengths),
               epsilon_(epsilon)
         {
@@ -52,9 +57,11 @@ struct ReferenceGroupnorm : public device::BaseOperator
         const Tensor<XDataType> gamma_;
         const Tensor<XDataType> beta_;
         Tensor<YDataType>& y_;
-        AccElementwiseOperation acc_elementwise_op_;
+        Tensor<SaveMeanInvStdDataType>& save_mean_;
+        Tensor<SaveMeanInvStdDataType>& save_inv_std_;
+        YElementwiseOperation y_elementwise_op_;
         std::vector<index_t> lengths_;
-        AccDataType epsilon_;
+        ComputeDataType epsilon_;
     };
 // Invoker
@@ -68,8 +75,8 @@ struct ReferenceGroupnorm : public device::BaseOperator
             int G = arg.lengths_[3];
             int C = arg.lengths_[4];
-            Tensor<AccDataType> mean({N, G});
-            Tensor<AccDataType> var({N, G});
+            Tensor<ComputeDataType> mean({N, G});
+            Tensor<ComputeDataType> var({N, G});
             // Compute mean & var in [H, W, C] by Welford Algorithm
             // TODO - parallel for each HWC
@@ -78,9 +85,9 @@ struct ReferenceGroupnorm : public device::BaseOperator
             {
                 for(int g = 0; g < G; ++g)
                 {
-                    AccDataType mean_val = type_convert<AccDataType>(0.0f);
-                    AccDataType var_val  = type_convert<AccDataType>(0.0f);
+                    ComputeDataType mean_val = type_convert<ComputeDataType>(0.0f);
+                    ComputeDataType var_val  = type_convert<ComputeDataType>(0.0f);
                     int32_t curr_count = 0;
                     for(int h = 0; h < H; ++h)
                     {
@@ -89,10 +96,11 @@ struct ReferenceGroupnorm : public device::BaseOperator
                         for(int c = 0; c < C; ++c)
                         {
                             curr_count++;
-                            AccDataType x = type_convert<AccDataType>(arg.x_(n, h, w, g, c));
-                            AccDataType delta = x - mean_val;
+                            ComputeDataType x =
+                                type_convert<ComputeDataType>(arg.x_(n, h, w, g, c));
+                            ComputeDataType delta = x - mean_val;
                             mean_val += delta / curr_count;
-                            AccDataType delta2 = x - mean_val;
+                            ComputeDataType delta2 = x - mean_val;
                             var_val += delta * delta2;
                         }
                     }
@@ -100,6 +108,12 @@ struct ReferenceGroupnorm : public device::BaseOperator
                     mean(n, g) = mean_val;
                     var(n, g)  = var_val / curr_count;
+
+                    arg.save_mean_(n, g) = ck::type_convert<SaveMeanInvStdDataType>(mean(n, g));
+
+                    ComputeDataType divisor =
+                        static_cast<ComputeDataType>(1) / ck::math::sqrt(var(n, g) + arg.epsilon_);
+
+                    arg.save_inv_std_(n, g) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
                 }
             }
@@ -114,15 +128,19 @@ struct ReferenceGroupnorm : public device::BaseOperator
                         {
                             for(int c = 0; c < C; ++c)
                             {
-                                AccDataType x = type_convert<AccDataType>(arg.x_(n, h, w, g, c));
-                                AccDataType gamma = type_convert<AccDataType>(arg.gamma_(g, c));
-                                AccDataType beta  = type_convert<AccDataType>(arg.beta_(g, c));
-                                AccDataType mean_val = type_convert<AccDataType>(mean(n, g));
-                                AccDataType var_val  = type_convert<AccDataType>(var(n, g));
-                                AccDataType y = gamma * (x - mean_val) /
-                                                ck::math::sqrt(arg.epsilon_ + var_val) + beta;
-                                arg.acc_elementwise_op_(y, y);
+                                ComputeDataType x =
+                                    type_convert<ComputeDataType>(arg.x_(n, h, w, g, c));
+                                ComputeDataType gamma =
+                                    type_convert<ComputeDataType>(arg.gamma_(g, c));
+                                ComputeDataType beta =
+                                    type_convert<ComputeDataType>(arg.beta_(g, c));
+                                ComputeDataType mean_val =
+                                    type_convert<ComputeDataType>(mean(n, g));
+                                ComputeDataType var_val =
+                                    type_convert<ComputeDataType>(var(n, g));
+                                ComputeDataType y = gamma * (x - mean_val) /
+                                                        ck::math::sqrt(arg.epsilon_ + var_val) +
+                                                    beta;
+                                arg.y_elementwise_op_(y, y);
                                 arg.y_(n, h, w, g, c) = type_convert<YDataType>(y);
                             }
                         }
@@ -159,11 +177,14 @@ struct ReferenceGroupnorm : public device::BaseOperator
                  const Tensor<GammaDataType>& gamma,
                  const Tensor<BetaDataType>& beta,
                  Tensor<YDataType>& y,
-                 AccElementwiseOperation acc_elementwise_op,
+                 Tensor<SaveMeanInvStdDataType>& save_mean,
+                 Tensor<SaveMeanInvStdDataType>& save_inv_std,
+                 YElementwiseOperation y_elementwise_op,
                  const std::vector<index_t> lengths,
-                 AccDataType epsilon)
+                 ComputeDataType epsilon)
     {
-        return Argument{x, gamma, beta, y, acc_elementwise_op, lengths, epsilon};
+        return Argument{
+            x, gamma, beta, y, save_mean, save_inv_std, y_elementwise_op, lengths, epsilon};
     }
     static auto MakeInvoker() { return Invoker{}; }
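
The mean/variance loop in this reference is Welford's single-pass algorithm: each sample moves the running mean by delta / count and accumulates delta * delta2 into the running sum of squared deviations, which avoids the catastrophic cancellation of the naive sum-of-squares approach. The update in isolation, with plain float in place of the templated ComputeDataType (illustrative only):

#include <cstdio>

int main()
{
    const float xs[] = {1.0f, 2.0f, 3.0f, 4.0f};
    float mean = 0.0f, m2 = 0.0f;
    int count = 0;
    for(float x : xs)
    {
        ++count;
        float delta = x - mean; // distance from the old mean
        mean += delta / count;  // running mean
        float delta2 = x - mean; // distance from the updated mean
        m2 += delta * delta2;    // running sum of squared deviations
    }
    std::printf("mean=%f var=%f\n", mean, m2 / count); // prints 2.5 and 1.25
    return 0;
}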

library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp (+41 -20)

@@ -20,8 +20,9 @@ template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
           typename YDataType,
-          typename AccDataType,
-          typename AccElementwiseOperation,
+          typename SaveMeanInvStdDataType,
+          typename ComputeDataType,
+          typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim>
 struct ReferenceLayernorm : public device::BaseOperator
@@ -36,15 +37,19 @@ struct ReferenceLayernorm : public device::BaseOperator
                  const Tensor<GammaDataType>& gamma_n,
                  const Tensor<BetaDataType>& beta_n,
                  Tensor<YDataType>& y_m_n,
-                 AccElementwiseOperation acc_elementwise_op,
+                 Tensor<SaveMeanInvStdDataType>& save_mean_m,
+                 Tensor<SaveMeanInvStdDataType>& save_inv_std_m,
+                 YElementwiseOperation y_elementwise_op,
                  const std::vector<index_t> lengths,
                  const std::vector<index_t> reduceDims,
-                 AccDataType epsilon)
+                 ComputeDataType epsilon)
             : x_m_n_(x_m_n),
               gamma_n_(gamma_n),
               beta_n_(beta_n),
               y_m_n_(y_m_n),
-              acc_elementwise_op_(acc_elementwise_op),
+              save_mean_m_(save_mean_m),
+              save_inv_std_m_(save_inv_std_m),
+              y_elementwise_op_(y_elementwise_op),
               lengths_(lengths),
               reduceDims_(reduceDims),
               epsilon_(epsilon)
@@ -55,10 +60,12 @@ struct ReferenceLayernorm : public device::BaseOperator
         const Tensor<XDataType> gamma_n_;
         const Tensor<XDataType> beta_n_;
         Tensor<YDataType>& y_m_n_;
-        AccElementwiseOperation acc_elementwise_op_;
+        Tensor<SaveMeanInvStdDataType>& save_mean_m_;
+        Tensor<SaveMeanInvStdDataType>& save_inv_std_m_;
+        YElementwiseOperation y_elementwise_op_;
         std::vector<index_t> lengths_;
        std::vector<index_t> reduceDims_;
-        AccDataType epsilon_;
+        ComputeDataType epsilon_;
     };
 // Invoker
@@ -69,8 +76,8 @@ struct ReferenceLayernorm : public device::BaseOperator
             int M = arg.lengths_[0];
            int N = arg.lengths_[1];
-            Tensor<AccDataType> mean({M});
-            Tensor<AccDataType> var({M});
+            Tensor<ComputeDataType> mean({M});
+            Tensor<ComputeDataType> var({M});
             for(int m = 0; m < M; ++m)
             {
@@ -79,7 +86,7 @@ struct ReferenceLayernorm : public device::BaseOperator
                 for(int n = 0; n < N; ++n)
                 {
-                    auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
+                    auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
                     mean(m) += x_val;
                     var(m) += x_val * x_val;
                 }
@@ -90,17 +97,21 @@ struct ReferenceLayernorm : public device::BaseOperator
             for(int m = 0; m < M; ++m)
             {
-                AccDataType divisor =
-                    static_cast<AccDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_);
+                ComputeDataType divisor =
+                    static_cast<ComputeDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_);
                 for(int n = 0; n < N; ++n)
                 {
-                    auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
-                    auto y_val = (x_val - mean(m)) * divisor;
-                    y_val      = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);
-                    arg.acc_elementwise_op_(y_val, y_val);
+                    auto x_val     = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
+                    auto gamma_val = ck::type_convert<ComputeDataType>(arg.gamma_n_(n));
+                    auto beta_val  = ck::type_convert<ComputeDataType>(arg.beta_n_(n));
+                    auto y_val     = (x_val - mean(m)) * divisor;
+                    y_val          = (y_val * gamma_val) + beta_val;
+                    arg.y_elementwise_op_(y_val, y_val);
                     arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
                 }
+                arg.save_mean_m_(m)    = ck::type_convert<SaveMeanInvStdDataType>(mean(m));
+                arg.save_inv_std_m_(m) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
             }
             return 0;
@@ -140,13 +151,23 @@ struct ReferenceLayernorm : public device::BaseOperator
                  const Tensor<GammaDataType>& gamma_n,
                  const Tensor<BetaDataType>& beta_n,
                  Tensor<YDataType>& y_m_n,
-                 AccElementwiseOperation acc_elementwise_op,
+                 Tensor<SaveMeanInvStdDataType>& save_mean_m,
+                 Tensor<SaveMeanInvStdDataType>& save_inv_std_m,
+                 YElementwiseOperation y_elementwise_op,
                  const std::vector<index_t> lengths,
                  const std::vector<index_t> reduceDims,
-                 AccDataType epsilon)
+                 ComputeDataType epsilon)
     {
-        return Argument{
-            x_m_n, gamma_n, beta_n, y_m_n, acc_elementwise_op, lengths, reduceDims, epsilon};
+        return Argument{x_m_n,
+                        gamma_n,
+                        beta_n,
+                        y_m_n,
+                        save_mean_m,
+                        save_inv_std_m,
+                        y_elementwise_op,
+                        lengths,
+                        reduceDims,
+                        epsilon};
     }
     static auto MakeInvoker() { return Invoker{}; }
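
Unlike the groupnorm reference, this one accumulates raw moments per row (sum of x and sum of x squared); the collapsed region of the diff elides the step that turns those sums into mean and variance, which by the standard identity is var = E[x^2] - E[x]^2. A standalone sketch of one row under that assumption (plain float for the templated types, made-up inputs):

#include <cmath>
#include <cstdio>

int main()
{
    const int   N    = 4;
    const float x[N] = {1.0f, 2.0f, 3.0f, 4.0f};
    const float gamma = 1.0f, beta = 0.0f, epsilon = 1e-5f;

    float sum = 0.0f, sum_sq = 0.0f;
    for(int n = 0; n < N; ++n)
    {
        sum += x[n];
        sum_sq += x[n] * x[n];
    }
    const float mean    = sum / N;
    const float var     = sum_sq / N - mean * mean; // E[x^2] - E[x]^2
    const float inv_std = 1.0f / std::sqrt(var + epsilon);

    // normalize, then scale and shift, as in the reference's inner loop
    for(int n = 0; n < N; ++n)
        std::printf("%f ", (x[n] - mean) * inv_std * gamma + beta);
    std::printf("\n");
    return 0;
}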

library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp (+2 -6)

@@ -20,12 +20,8 @@ using F16 = ck::half_t;
 using BF16 = ck::bhalf_t;
 using I8   = int8_t;
 using I32  = int32_t;
-#if defined CK_ENABLE_FP8
-using F8 = ck::f8_t;
-#endif
-#if defined CK_ENABLE_BF8
-using BF8 = ck::bf8_t;
-#endif
+using F8  = ck::f8_t;
+using BF8 = ck::bf8_t;
 using Empty_Tuple = ck::Tuple<>;

library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp (+16 -6)

@@ -240,11 +240,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
         if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> && is_same_v<WeiLayout, KXC> &&
                      is_same_v<OutLayout, NWK>)
         {
+#ifdef CK_ENABLE_FP32
             if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                          is_same_v<OutDataType, float>)
             {
                 add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
             }
+#endif
 #ifdef CK_ENABLE_FP16
             if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                          is_same_v<OutDataType, half_t>)
@@ -267,17 +269,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
             }
 #endif
         }
-        else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
-                          is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
+        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
+                     is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
         {
+#ifdef CK_ENABLE_FP32
             if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                          is_same_v<OutDataType, float>)
             {
                 add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
-#ifdef DL_KERNELS
-                add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
-#endif
             }
+#endif
+#if defined(DL_KERNELS) && defined(CK_ENABLE_FP32)
+            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
+                         is_same_v<OutDataType, float>)
+            {
+                add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
+            }
+#endif
 #ifdef CK_ENABLE_FP16
             if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                          is_same_v<OutDataType, half_t>)
@@ -306,14 +314,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
             }
 #endif
         }
-        else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
-                          is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
+        if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
+                     is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
         {
+#ifdef CK_ENABLE_FP32
             if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                          is_same_v<OutDataType, float>)
             {
                 add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
             }
+#endif
 #ifdef CK_ENABLE_FP16
             if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                          is_same_v<OutDataType, half_t>)
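
All of the factory headers in this commit share one dispatch shape, and the recurring edit is the same: because each data-type branch now sits under its own #ifdef, the chains were changed from else if constexpr to independent if constexpr, so compiling out one data type cannot orphan an else. A reduced sketch of the shape (hypothetical add_* registrars and stand-in types, not CK's real instance functions):

#include <memory>
#include <type_traits>
#include <vector>

struct DeviceOp
{
    virtual ~DeviceOp() = default;
};

// hypothetical registrars standing in for CK's add_device_*_instances
void add_f32_instances(std::vector<std::unique_ptr<DeviceOp>>& ops)
{
    ops.push_back(std::make_unique<DeviceOp>());
}
void add_f16_instances(std::vector<std::unique_ptr<DeviceOp>>& ops)
{
    ops.push_back(std::make_unique<DeviceOp>());
}

template <typename ADataType, typename ComputeType>
std::vector<std::unique_ptr<DeviceOp>> GetInstances()
{
    std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
    // independent "if constexpr" branches, as in the headers after this merge;
    // a branch whose type check fails simply contributes nothing
    if constexpr(std::is_same_v<ADataType, float> && std::is_same_v<ComputeType, float>)
    {
        add_f32_instances(op_ptrs);
    }
    if constexpr(std::is_same_v<ADataType, short> && std::is_same_v<ComputeType, short>)
    {
        add_f16_instances(op_ptrs); // "short" stands in for half_t here
    }
    return op_ptrs;
}

int main()
{
    auto ops = GetInstances<float, float>();
    return ops.size() == 1 ? 0 : 1;
}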

library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp (+8 -7)

@@ -98,30 +98,31 @@ struct DeviceOperationInstanceFactory<
         if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
                      is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
         {
+#ifdef CK_ENABLE_FP32
             if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                          is_same_v<OutDataType, float>)
             {
                 add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
             }
+#endif
 #ifdef CK_ENABLE_FP16
-            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
-                              is_same_v<OutDataType, half_t>)
+            if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
+                         is_same_v<OutDataType, half_t>)
             {
                 add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
                 add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
             }
 #endif
 #ifdef CK_ENABLE_BF16
-            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
-                              is_same_v<WeiDataType, ck::bhalf_t> &&
-                              is_same_v<OutDataType, ck::bhalf_t>)
+            if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
+                         is_same_v<WeiDataType, ck::bhalf_t> &&
+                         is_same_v<OutDataType, ck::bhalf_t>)
             {
                 add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
             }
 #endif
 #ifdef CK_ENABLE_INT8
-            else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
-                              is_same_v<OutDataType, int8_t>)
+            if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
+                         is_same_v<OutDataType, int8_t>)
             {
                 add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
             }

library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp (+7 -7)

@@ -155,7 +155,7 @@ struct DeviceOperationInstanceFactory<
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
 #ifdef CK_ENABLE_FP32
         if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
-                     is_same_v<CDataType, float>)
+                     is_same_v<CDataType, float> && is_same_v<ComputeType, float>)
         {
             if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                          is_same_v<CLayout, Row>)
@@ -180,8 +180,8 @@ struct DeviceOperationInstanceFactory<
         }
 #endif
 #ifdef CK_ENABLE_FP16
-        else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
-                          is_same_v<CDataType, half_t> && is_same_v<ComputeType, half_t>)
+        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
+                     is_same_v<CDataType, half_t> && is_same_v<ComputeType, half_t>)
         {
             if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                          is_same_v<CLayout, Row>)
@@ -206,8 +206,8 @@ struct DeviceOperationInstanceFactory<
         }
 #endif
 #if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
-        else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
-                          is_same_v<CDataType, half_t>)
+        if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
+                     is_same_v<CDataType, half_t> && is_same_v<ComputeType, half_t>)
         {
             if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                          is_same_v<CLayout, Row>)
@@ -230,8 +230,8 @@ struct DeviceOperationInstanceFactory<
                 add_device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
-                          is_same_v<CDataType, half_t>)
+        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
+                     is_same_v<CDataType, half_t> && is_same_v<ComputeType, half_t>)
         {
             if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                          is_same_v<CLayout, Row>)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
View file @
de1afb7b
...
@@ -627,8 +627,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -627,8 +627,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
#endif
#endif
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
is_same_v
<
OutDataType
,
half_t
>
)
{
{
#ifdef DL_KERNELS
#ifdef DL_KERNELS
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances
(
op_ptrs
);
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances
(
op_ptrs
);
...
@@ -637,9 +637,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -637,9 +637,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
#endif
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
else
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
{
{
#ifdef DL_KERNELS
#ifdef DL_KERNELS
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances
(
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances
(
...
@@ -650,8 +649,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -650,8 +649,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
#endif
#endif
}
}
else
if
constexpr
(
is_same_v
<
InLayout
,
NWGC
>
&&
is_same_v
<
WeiLayout
,
GKXC
>
&&
if
constexpr
(
is_same_v
<
InLayout
,
NWGC
>
&&
is_same_v
<
WeiLayout
,
GKXC
>
&&
is_same_v
<
OutLayout
,
NWGK
>
)
is_same_v
<
OutLayout
,
NWGK
>
)
{
{
#ifdef DL_KERNELS
#ifdef DL_KERNELS
#ifdef CK_ENABLE_FP32
#ifdef CK_ENABLE_FP32
...
@@ -662,16 +661,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -662,16 +661,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
#endif
#endif
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
is_same_v
<
OutDataType
,
half_t
>
)
{
{
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances
(
op_ptrs
);
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances
(
op_ptrs
);
}
}
#endif
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
else
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
{
{
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances
(
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances
(
op_ptrs
);
op_ptrs
);
...
@@ -680,7 +678,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -680,7 +678,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
#endif
}
}
}
}
else
if
constexpr
(
NumDimSpatial
==
2
)
if
constexpr
(
NumDimSpatial
==
2
)
{
{
if
constexpr
(
is_same_v
<
InLayout
,
GNHWC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
if
constexpr
(
is_same_v
<
InLayout
,
GNHWC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
GNHWK
>
)
is_same_v
<
OutLayout
,
GNHWK
>
)
...
@@ -698,8 +696,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -698,8 +696,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
#endif
#endif
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
is_same_v
<
OutDataType
,
half_t
>
)
{
{
#ifdef DL_KERNELS
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances
(
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances
(
...
@@ -710,9 +708,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -710,9 +708,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
#endif
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
else
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
{
{
#ifdef DL_KERNELS
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances
(
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances
(
...
@@ -723,8 +720,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -723,8 +720,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
#endif
#endif
}
}
else
if
constexpr
(
is_same_v
<
InLayout
,
NHWGC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
if
constexpr
(
is_same_v
<
InLayout
,
NHWGC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
NHWGK
>
)
is_same_v
<
OutLayout
,
NHWGK
>
)
{
{
#ifdef CK_ENABLE_FP32
#ifdef CK_ENABLE_FP32
if
constexpr
(
is_same_v
<
InDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
...
@@ -739,8 +736,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -739,8 +736,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
#endif
#endif
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
is_same_v
<
OutDataType
,
half_t
>
)
{
{
#ifdef DL_KERNELS
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances
(
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances
(
...
@@ -751,9 +748,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -751,9 +748,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
#endif
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
else
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
WeiDataType
,
float
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
{
{
#ifdef DL_KERNELS
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances
(
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances
(
...
@@ -765,7 +761,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -765,7 +761,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
 #endif
         }
     }
-    else if constexpr(NumDimSpatial == 3)
+    if constexpr(NumDimSpatial == 3)
     {
         if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
                      is_same_v<OutLayout, GNDHWK>)
...
@@ -783,8 +779,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         }
 #endif
 #ifdef CK_ENABLE_FP16
-        else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
-                          is_same_v<OutDataType, half_t>)
+        if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
+                     is_same_v<OutDataType, half_t>)
         {
 #ifdef DL_KERNELS
             add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances(
...
@@ -799,9 +795,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         }
 #endif
 #ifdef CK_ENABLE_BF16
-        else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
-                          is_same_v<WeiDataType, float> &&
-                          is_same_v<OutDataType, ck::bhalf_t>)
+        if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
+                     is_same_v<OutDataType, ck::bhalf_t>)
         {
 #ifdef DL_KERNELS
             add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
...
@@ -822,8 +817,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         }
 #endif
     }
-    else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
-                      is_same_v<OutLayout, NDHWGK>)
+    if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
+                 is_same_v<OutLayout, NDHWGK>)
     {
 #ifdef CK_ENABLE_FP32
         if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
...
@@ -838,10 +833,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         }
 #endif
 #ifdef CK_ENABLE_FP16
-        else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
-                          is_same_v<OutDataType, half_t> &&
-                          is_same_v<ComputeTypeA, half_t> &&
-                          is_same_v<ComputeTypeB, half_t>)
+        if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
+                     is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
+                     is_same_v<ComputeTypeB, half_t>)
         {
 #ifdef DL_KERNELS
             add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
...
@@ -856,9 +850,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         }
 #endif
 #ifdef CK_ENABLE_BF16
-        else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
-                          is_same_v<WeiDataType, float> &&
-                          is_same_v<OutDataType, ck::bhalf_t>)
+        if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
+                     is_same_v<OutDataType, ck::bhalf_t>)
         {
 #ifdef DL_KERNELS
             add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
...
@@ -879,9 +872,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         }
 #endif
 #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
-        else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
-                          is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, bf8_t> &&
-                          is_same_v<ComputeTypeB, f8_t>)
+        if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
+                     is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, bf8_t> &&
+                     is_same_v<ComputeTypeB, f8_t>)
         {
             add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
                 op_ptrs);
...
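The else if constexpr removals in this file are not cosmetic: each data-type branch sits inside its own #ifdef, so an else that chains onto a branch from a differently guarded block becomes a dangling else, a compile error, whenever that earlier guard is disabled. A minimal sketch of the failure mode, with hypothetical guard and function names:

    template <int NumDimSpatial>
    void dispatch()
    {
    #ifdef ENABLE_2D // hypothetical guard, standing in for the CK_ENABLE_* macros above
        if constexpr(NumDimSpatial == 2) { /* add 2D instances */ }
    #endif
        // An "else if constexpr(NumDimSpatial == 3)" here would fail to compile
        // whenever ENABLE_2D is undefined: the "else" would have no matching "if".
        // Independent branches stay valid under any combination of guards:
        if constexpr(NumDimSpatial == 3) { /* add 3D instances */ }
    }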
library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp View file @ de1afb7b
...
@@ -19,13 +19,13 @@ namespace instance {
 #ifdef CK_ENABLE_FP16
 // FP16
 void add_device_normalization_rank_2_1_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 2, 1>>>&);
+    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 2, 1>>>&);
 void add_device_normalization_rank_4_3_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 4, 3>>>&);
+    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 4, 3>>>&);
 void add_device_normalization_rank_5_3_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 5, 3>>>&);
+    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 5, 3>>>&);
 #endif
 #ifdef CK_ENABLE_FP32
 // FP32
...
@@ -42,14 +42,15 @@ template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
           typename YDataType,
+          typename SaveMeanInvStdDataType,
           index_t Rank,
           index_t NumReduceDim>
 struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormalization<
     XDataType,
     GammaDataType,
     BetaDataType,
-    F32,
     YDataType,
+    SaveMeanInvStdDataType,
     ck::tensor_operation::element_wise::PassThrough,
     Rank,
     NumReduceDim>>
...
@@ -57,8 +58,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
     using DeviceOp = DeviceNormalization<XDataType,
                                          GammaDataType,
                                          BetaDataType,
-                                         F32,
                                          YDataType,
+                                         SaveMeanInvStdDataType,
                                          ck::tensor_operation::element_wise::PassThrough,
                                          Rank,
                                          NumReduceDim>;
...
@@ -68,7 +69,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
     std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
 #ifdef CK_ENABLE_FP16
-    if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
-                 is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
+    if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
+                 is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16> &&
+                 is_same_v<SaveMeanInvStdDataType, F32>)
     {
         if constexpr(Rank == 2 && NumReduceDim == 1)
         {
...
@@ -86,7 +88,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
 #endif
 #ifdef CK_ENABLE_FP32
-    if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
-                 is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
+    if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
+                 is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32> &&
+                 is_same_v<SaveMeanInvStdDataType, F32>)
     {
         if constexpr(Rank == 2 && NumReduceDim == 1)
        {
...
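The signature change running through this file replaces the fixed F32 compute-type slot in DeviceNormalization with an explicit SaveMeanInvStdDataType, so callers now name the type used to archive the per-row mean and inverse standard deviation. A minimal caller-side sketch, assuming the factory's usual GetInstances() entry point (not shown in this diff):

    #include "ck/library/tensor_operation_instance/gpu/normalization.hpp"

    int main()
    {
        using F16  = ck::half_t;
        using F32  = float;
        using Pass = ck::tensor_operation::element_wise::PassThrough;

        // 4th argument is now YDataType, 5th the new SaveMeanInvStdDataType;
        // the shipped f16 instances archive mean/inv-std in f32.
        using DeviceOp = ck::tensor_operation::device::
            DeviceNormalization<F16, F16, F16, F16, F32, Pass, 2, 1>;

        const auto op_ptrs = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
        return op_ptrs.empty() ? 1 : 0;
    }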
library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp View file @ de1afb7b
...
@@ -19,7 +19,7 @@ namespace instance {
 // FP16
 void add_device_normalization_rank_5_3_swish_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Swish, 5, 3>>>&);
+    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, Swish, 5, 3>>>&);
 // FP32
 void add_device_normalization_rank_5_3_swish_f32_instances(
...
@@ -27,20 +27,21 @@ void add_device_normalization_rank_5_3_swish_f32_instances(
 // [x, gamma, beta, y] = [f16, f32, f32, f16]
 void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>>&);
+    std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F16, F32, Swish, 5, 3>>>&);
 
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
           typename YDataType,
+          typename SaveMeanInvStdDataType,
           index_t Rank,
           index_t NumReduceDim>
 struct DeviceOperationInstanceFactory<
     ck::tensor_operation::device::DeviceNormalization<XDataType,
                                                       GammaDataType,
                                                       BetaDataType,
-                                                      F32,
                                                       YDataType,
+                                                      SaveMeanInvStdDataType,
                                                       ck::tensor_operation::element_wise::Swish,
                                                       Rank,
                                                       NumReduceDim>>
...
@@ -48,8 +49,8 @@ struct DeviceOperationInstanceFactory<
     using DeviceOp = DeviceNormalization<XDataType,
                                          GammaDataType,
                                          BetaDataType,
-                                         F32,
                                          YDataType,
+                                         SaveMeanInvStdDataType,
                                          ck::tensor_operation::element_wise::Swish,
                                          Rank,
                                          NumReduceDim>;
...
@@ -59,7 +60,8 @@ struct DeviceOperationInstanceFactory<
     std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
-    if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
-                 is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
+    if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
+                 is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16> &&
+                 is_same_v<SaveMeanInvStdDataType, F32>)
     {
         if constexpr(Rank == 5 && NumReduceDim == 3)
         {
...
@@ -67,7 +69,8 @@ struct DeviceOperationInstanceFactory<
     }
-    else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
-                      is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
+    else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
+                      is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32> &&
+                      is_same_v<SaveMeanInvStdDataType, F32>)
     {
         if constexpr(Rank == 5 && NumReduceDim == 3)
         {
...
@@ -75,7 +78,8 @@ struct DeviceOperationInstanceFactory<
     }
-    else if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F32> &&
-                      is_same_v<BetaDataType, F32> && is_same_v<YDataType, F16>)
+    else if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F32> &&
+                      is_same_v<BetaDataType, F32> && is_same_v<YDataType, F16> &&
+                      is_same_v<SaveMeanInvStdDataType, F32>)
     {
         if constexpr(Rank == 5 && NumReduceDim == 3)
         {
...
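The Swish-fused factory gets the same SaveMeanInvStdDataType plumbing, including the mixed-precision branch (x and y in f16, gamma and beta in f32). A hedged sketch of the matching rank-5 groupnorm lookup, again assuming GetInstances():

    #include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp"

    auto get_groupnorm_swish_instances()
    {
        using Swish = ck::tensor_operation::element_wise::Swish;
        // Rank-5 (N, H, W, G, C) groupnorm reducing H, W, C; x/y in f16,
        // gamma/beta in f32, saved mean/inv-std in f32: the third branch above.
        using DeviceOp = ck::tensor_operation::device::DeviceNormalization<
            ck::half_t, float, float, ck::half_t, float, Swish, 5, 3>;
        return ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    }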
library/include/ck/library/utility/check_err.hpp View file @ de1afb7b
...
@@ -230,7 +230,6 @@ check_err(const Range& out,
     return res;
 }
-#if defined CK_ENABLE_FP8
 template <typename Range, typename RefRange>
 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
                   std::is_same_v<ranges::range_value_t<Range>, f8_t>),
...
@@ -275,9 +274,7 @@ check_err(const Range& out,
     }
     return res;
 }
-#endif
-#if defined CK_ENABLE_BF8
 template <typename Range, typename RefRange>
 std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
                   std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
...
@@ -322,7 +319,6 @@ check_err(const Range& out,
     }
     return res;
 }
-#endif
 } // namespace utils
 } // namespace ck
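Dropping the CK_ENABLE_FP8/CK_ENABLE_BF8 guards means the f8_t and bf8_t overloads of check_err are compiled unconditionally, so host verification code no longer needs a matching guard of its own. A minimal sketch; the message argument is assumed to follow the other check_err overloads:

    #include <vector>
    #include "ck/library/utility/check_err.hpp"

    bool verify(const std::vector<ck::f8_t>& out, const std::vector<ck::f8_t>& ref)
    {
        // Compares device results against the host reference with the
        // overload's default tolerances.
        return ck::utils::check_err(out, ref, "Error: incorrect results!");
    }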
library/include/ck/library/utility/host_common_util.hpp View file @ de1afb7b
...
@@ -22,7 +22,7 @@ static inline void dumpBufferToFile(const char* fileName, T* data, size_t dataNu
     std::ofstream outFile(fileName, std::ios::binary);
     if(outFile)
     {
-        outFile.write(reinterpret_cast<char*>(data), dataNumItems * sizeof(T));
+        outFile.write(reinterpret_cast<const char*>(data), dataNumItems * sizeof(T));
         outFile.close();
         std::cout << "Write output to file " << fileName << std::endl;
     }
...
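The one-character change above is a const-correctness fix: std::ofstream::write takes a const char* anyway, and casting through const char* also lets T deduce as a const type, which reinterpret_cast<char*> rejected at compile time because it cannot cast away constness. A hedged usage sketch (qualify the call with this header's namespace as needed):

    #include <vector>
    #include "ck/library/utility/host_common_util.hpp"

    void save_result(const std::vector<float>& result)
    {
        // T deduces as const float here; writes result.size() * sizeof(float)
        // raw bytes to the file.
        dumpBufferToFile("layer0_out.bin", result.data(), result.size());
    }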
library/include/ck/library/utility/host_tensor_generator.hpp View file @ de1afb7b
...
@@ -200,10 +200,11 @@ struct GeneratorTensor_3<ck::bf8_t>
 template <typename T>
 struct GeneratorTensor_4
 {
-    std::default_random_engine generator;
+    std::mt19937 generator;
     std::normal_distribution<float> distribution;
 
-    GeneratorTensor_4(float mean, float stddev) : generator(1), distribution(mean, stddev){};
+    GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1)
+        : generator(seed), distribution(mean, stddev){};
 
     template <typename... Is>
     T operator()(Is...)
...
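Swapping std::default_random_engine for std::mt19937 makes GeneratorTensor_4 reproducible across standard libraries (default_random_engine is implementation-defined), and the new seed parameter defaults to 1, so existing two-argument call sites still compile unchanged, though the drawn values differ because the engine changed. A hedged usage sketch:

    #include "ck/library/utility/host_tensor_generator.hpp"

    void fill_example()
    {
        GeneratorTensor_4<float> gen_legacy{0.f, 1.f};     // default seed = 1
        GeneratorTensor_4<float> gen_seeded{0.f, 1.f, 42}; // independent reproducible stream
        float v = gen_seeded(0, 0); // index arguments are accepted but unused
        (void)v;
    }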