Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
53dba87a
Commit
53dba87a
authored
Aug 31, 2023
by
Rostyslav Geyyer
Browse files
Add macros to enable build with disabled fp8/bf8
parent
59954f5a
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
50 additions
and
4 deletions
+50
-4
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+4
-0
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+9
-1
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+16
-2
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+2
-0
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+8
-0
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+2
-0
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+4
-0
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+3
-1
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+2
-0
No files found.
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
53dba87a
...
...
@@ -89,6 +89,7 @@ struct PassThrough
}
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template
<
>
__host__
__device__
void
operator
()
<
f8_t
,
f8_t
>
(
f8_t
&
y
,
const
f8_t
&
x
)
const
{
...
...
@@ -118,6 +119,7 @@ struct PassThrough
{
y
=
type_convert
<
f8_t
>
(
x
);
}
#endif
};
struct
UnaryConvert
...
...
@@ -146,6 +148,7 @@ struct ConvertBF16RTN
}
};
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
struct
ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
...
...
@@ -162,6 +165,7 @@ struct ConvertF8SR
y
=
f8_convert_sr
<
Y
>
(
x
);
}
};
#endif
struct
Scale
{
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
53dba87a
...
...
@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
}
};
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8f8
>
{
...
...
@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
};
#endif
template
<
typename
base_type
,
index_t
MPerXdlops
,
index_t
NPerXdlops
>
struct
MfmaSelector
...
...
@@ -640,6 +642,7 @@ struct MfmaSelector
}
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
{
...
...
@@ -651,6 +654,7 @@ struct MfmaSelector
{
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
}
#endif
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
>
()
>
{};
...
...
@@ -852,7 +856,11 @@ struct XdlopsGemm
{
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
int8_t
>::
value
||
is_same
<
base_type
,
f8_t
>::
value
,
is_same
<
base_type
,
int8_t
>::
value
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
||
is_same
<
base_type
,
f8_t
>::
value
#endif
,
"base base_type must be double, float, half, bfloat16, and int8_t!"
);
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
53dba87a
...
...
@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t
src_addr_shift
=
src_thread_element_valid
?
0
:
0x80000000
;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
...
...
@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
}
else
{
#endif
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
#else
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
...
...
@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
}
else
{
#endif
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
#endif
}
// buffer_load requires:
...
...
@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x80000000
;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
...
...
@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
}
else
{
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
#else
if
(
dst_thread_element_valid
)
{
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
...
...
@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
}
else
{
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
}
#endif
}
...
...
include/ck/utility/amd_xdlops.hpp
View file @
53dba87a
...
...
@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
};
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16f8f8
;
...
...
@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif
}
};
#endif
}
// namespace ck
#endif
include/ck/utility/data_type.hpp
View file @
53dba87a
...
...
@@ -12,8 +12,10 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using
int4_t
=
_BitInt
(
4
);
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
using
f8_t
=
_BitInt
(
8
);
using
bf8_t
=
unsigned
_BitInt
(
8
);
#endif
template
<
typename
T
>
inline
__host__
__device__
constexpr
auto
is_native
()
...
...
@@ -152,6 +154,7 @@ struct scalar_type<int4_t>
};
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template
<
>
struct
scalar_type
<
f8_t
>
{
...
...
@@ -165,6 +168,7 @@ struct scalar_type<bf8_t>
using
type
=
bf8_t
;
static
constexpr
index_t
vector_size
=
1
;
};
#endif
template
<
typename
T
>
struct
vector_type
<
T
,
1
>
...
...
@@ -967,6 +971,7 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
// f8
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
...
...
@@ -982,6 +987,7 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
#endif
template
<
typename
T
>
struct
NumericLimits
...
...
@@ -1029,6 +1035,7 @@ struct NumericLimits<int4_t>
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template
<
>
struct
NumericLimits
<
f8_t
>
{
...
...
@@ -1074,5 +1081,6 @@ struct NumericLimits<bf8_t>
__host__
__device__
static
constexpr
bf8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
};
#endif
}
// namespace ck
include/ck/utility/f8_utils.hpp
View file @
53dba87a
...
...
@@ -5,6 +5,7 @@
#include "ck/utility/data_type.hpp"
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace
ck
{
// fp8 rounding modes
...
...
@@ -283,3 +284,4 @@ __host__ __device__ Y cast_from_f8(X x)
}
}
// namespace ck::utils
#endif
include/ck/utility/type_convert.hpp
View file @
53dba87a
...
...
@@ -80,6 +80,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
// convert fp32 to fp8
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
float
>
(
float
x
)
...
...
@@ -163,6 +164,7 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
bf8_t
,
half_t
,
negative_zero_nan
>
(
x
);
}
#endif
// Declare a template function for bf16 conversion using RTN
template
<
typename
Y
,
typename
X
>
...
...
@@ -221,6 +223,7 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
return
bf16_convert_rtn
<
bhalf_t
>
(
x_fp32
);
}
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
// Declare a template function for fp8 conversion using SR
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
f8_convert_sr
(
X
x
);
...
...
@@ -284,5 +287,6 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
}
#endif
}
// namespace ck
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
53dba87a
...
...
@@ -17,10 +17,12 @@ namespace instance {
using
F64
=
double
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F8
=
ck
::
f8_t
;
using
BF16
=
ck
::
bhalf_t
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
using
F8
=
ck
::
f8_t
;
#endif
using
Empty_Tuple
=
ck
::
Tuple
<>
;
...
...
library/include/ck/library/utility/check_err.hpp
View file @
53dba87a
...
...
@@ -230,6 +230,7 @@ check_err(const Range& out,
return
res
;
}
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template
<
typename
Range
,
typename
RefRange
>
std
::
enable_if_t
<
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
f8_t
>
||
...
...
@@ -276,6 +277,7 @@ check_err(const Range& out,
}
return
res
;
}
#endif
}
// namespace utils
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment