Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2a51cf49
Commit
2a51cf49
authored
Sep 08, 2023
by
Rostyslav Geyyer
Browse files
Decouple fp8 and bf8 flags
parent
c028e416
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
94 additions
and
34 deletions
+94
-34
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+2
-2
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+3
-3
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+8
-8
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+1
-1
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+13
-5
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+6
-2
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+4
-1
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_add.hpp
...brary/tensor_operation_instance/gpu/gemm_multiply_add.hpp
+2
-2
library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
.../ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
+2
-2
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+51
-6
profiler/include/profiler/profile_gemm_splitk_impl.hpp
profiler/include/profiler/profile_gemm_splitk_impl.hpp
+2
-2
No files found.
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
2a51cf49
...
@@ -89,7 +89,7 @@ struct PassThrough
...
@@ -89,7 +89,7 @@ struct PassThrough
}
}
#endif
#endif
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
f8_t
,
f8_t
>
(
f8_t
&
y
,
const
f8_t
&
x
)
const
__host__
__device__
void
operator
()
<
f8_t
,
f8_t
>
(
f8_t
&
y
,
const
f8_t
&
x
)
const
{
{
...
@@ -148,7 +148,7 @@ struct ConvertBF16RTN
...
@@ -148,7 +148,7 @@ struct ConvertBF16RTN
}
}
};
};
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
struct
ConvertF8SR
struct
ConvertF8SR
{
{
// convert to fp8 using stochastic rounding (SR)
// convert to fp8 using stochastic rounding (SR)
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
2a51cf49
...
@@ -456,7 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
...
@@ -456,7 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
}
}
};
};
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8f8
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8f8
>
{
{
...
@@ -642,7 +642,7 @@ struct MfmaSelector
...
@@ -642,7 +642,7 @@ struct MfmaSelector
}
}
#endif
#endif
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
{
{
...
@@ -857,7 +857,7 @@ struct XdlopsGemm
...
@@ -857,7 +857,7 @@ struct XdlopsGemm
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
int8_t
>::
value
is_same
<
base_type
,
int8_t
>::
value
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
||
is_same
<
base_type
,
f8_t
>::
value
||
is_same
<
base_type
,
f8_t
>::
value
#endif
#endif
,
,
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
2a51cf49
...
@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
...
@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t
src_addr_shift
=
src_thread_element_valid
?
0
:
0x80000000
;
uint32_t
src_addr_shift
=
src_thread_element_valid
?
0
:
0x80000000
;
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
...
@@ -1139,11 +1139,11 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
...
@@ -1139,11 +1139,11 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#endif
#endif
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
}
}
#endif
#endif
#else
#else
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
...
@@ -1156,7 +1156,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
...
@@ -1156,7 +1156,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
}
}
#endif
#endif
#endif
#endif
...
@@ -1216,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
...
@@ -1216,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x80000000
;
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x80000000
;
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
{
auto
tmp
=
auto
tmp
=
...
@@ -1229,13 +1229,13 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
...
@@ -1229,13 +1229,13 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#endif
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
}
}
#endif
#endif
#else
#else
if
(
dst_thread_element_valid
)
if
(
dst_thread_element_valid
)
{
{
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
...
@@ -1248,7 +1248,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
...
@@ -1248,7 +1248,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#endif
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
}
}
#endif
#endif
}
}
...
...
include/ck/utility/amd_xdlops.hpp
View file @
2a51cf49
...
@@ -355,7 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
...
@@ -355,7 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
}
};
};
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16f8f8
;
struct
intrin_mfma_f32_32x32x16f8f8
;
...
...
include/ck/utility/data_type.hpp
View file @
2a51cf49
...
@@ -12,8 +12,10 @@ using half_t = _Float16;
...
@@ -12,8 +12,10 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using
int4_t
=
_BitInt
(
4
);
using
int4_t
=
_BitInt
(
4
);
#endif
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
using
f8_t
=
_BitInt
(
8
);
using
f8_t
=
_BitInt
(
8
);
#endif
#if defined CK_ENABLE_BF8
using
bf8_t
=
unsigned
_BitInt
(
8
);
using
bf8_t
=
unsigned
_BitInt
(
8
);
#endif
#endif
...
@@ -146,14 +148,16 @@ struct scalar_type<int4_t>
...
@@ -146,14 +148,16 @@ struct scalar_type<int4_t>
};
};
#endif
#endif
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
scalar_type
<
f8_t
>
struct
scalar_type
<
f8_t
>
{
{
using
type
=
f8_t
;
using
type
=
f8_t
;
static
constexpr
index_t
vector_size
=
1
;
static
constexpr
index_t
vector_size
=
1
;
};
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
scalar_type
<
bf8_t
>
struct
scalar_type
<
bf8_t
>
{
{
...
@@ -963,16 +967,18 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
...
@@ -963,16 +967,18 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
// f8
// f8
#if defined CK_ENABLE_FP8
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
#endif
// bf8
// bf8
#if defined CK_ENABLE_BF8
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
...
@@ -1027,7 +1033,7 @@ struct NumericLimits<int4_t>
...
@@ -1027,7 +1033,7 @@ struct NumericLimits<int4_t>
};
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
struct
NumericLimits
<
f8_t
>
struct
NumericLimits
<
f8_t
>
{
{
...
@@ -1050,7 +1056,9 @@ struct NumericLimits<f8_t>
...
@@ -1050,7 +1056,9 @@ struct NumericLimits<f8_t>
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
};
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
struct
NumericLimits
<
bf8_t
>
struct
NumericLimits
<
bf8_t
>
{
{
...
...
include/ck/utility/type_convert.hpp
View file @
2a51cf49
...
@@ -80,7 +80,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
...
@@ -80,7 +80,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return
type_convert
<
bhalf_t
>
(
x_fp32
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
}
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
// convert fp32 to fp8
// convert fp32 to fp8
template
<
>
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
float
>
(
float
x
)
...
@@ -122,7 +122,9 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
...
@@ -122,7 +122,9 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
}
}
#endif
#if defined CK_ENABLE_BF8
// convert fp32 to bf8
// convert fp32 to bf8
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
float
>
(
float
x
)
...
@@ -223,11 +225,11 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
...
@@ -223,11 +225,11 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
return
bf16_convert_rtn
<
bhalf_t
>
(
x_fp32
);
return
bf16_convert_rtn
<
bhalf_t
>
(
x_fp32
);
}
}
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
// Declare a template function for fp8 conversion using SR
// Declare a template function for fp8 conversion using SR
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
f8_convert_sr
(
X
x
);
__host__
__device__
constexpr
Y
f8_convert_sr
(
X
x
);
#if defined CK_ENABLE_FP8
// convert fp32 to fp8 with stochastic rounding
// convert fp32 to fp8 with stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
...
@@ -257,7 +259,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
...
@@ -257,7 +259,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
}
}
#endif
#if defined CK_ENABLE_BF8
// convert fp32 to bf8 with stochastic rounding
// convert fp32 to bf8 with stochastic rounding
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
...
...
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
2a51cf49
...
@@ -20,9 +20,12 @@ using F16 = ck::half_t;
...
@@ -20,9 +20,12 @@ using F16 = ck::half_t;
using
BF16
=
ck
::
bhalf_t
;
using
BF16
=
ck
::
bhalf_t
;
using
I8
=
int8_t
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
using
I32
=
int32_t
;
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
using
F8
=
ck
::
f8_t
;
using
F8
=
ck
::
f8_t
;
#endif
#endif
#if defined CK_ENABLE_BF8
using
BF8
=
ck
::
bf8_t
;
#endif
using
Empty_Tuple
=
ck
::
Tuple
<>
;
using
Empty_Tuple
=
ck
::
Tuple
<>
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_add.hpp
View file @
2a51cf49
...
@@ -45,7 +45,7 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_
...
@@ -45,7 +45,7 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_
PassThrough
,
PassThrough
,
MultiplyAdd
>>>&
);
MultiplyAdd
>>>&
);
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
void
add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances
(
void
add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Row
,
Row
,
...
@@ -133,7 +133,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
...
@@ -133,7 +133,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
}
}
}
}
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
is_same_v
<
D0DataType
,
float
>
&&
is_same_v
<
D1DataType
,
float
>
&&
is_same_v
<
D0DataType
,
float
>
&&
is_same_v
<
D1DataType
,
float
>
&&
is_same_v
<
EDataType
,
half_t
>
)
is_same_v
<
EDataType
,
half_t
>
)
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
View file @
2a51cf49
...
@@ -57,7 +57,7 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
...
@@ -57,7 +57,7 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
DeviceGemmSplitK
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmSplitK
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
void
add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances
(
void
add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Col
,
Row
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmSplitK
<
Col
,
Row
,
Row
,
F8
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
@@ -178,7 +178,7 @@ struct DeviceOperationInstanceFactory<
...
@@ -178,7 +178,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
}
}
}
}
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
else
if
constexpr
(
is_same_v
<
ADataType
,
f8_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
else
if
constexpr
(
is_same_v
<
ADataType
,
f8_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
is_same_v
<
CDataType
,
half_t
>
)
{
{
...
...
library/include/ck/library/utility/check_err.hpp
View file @
2a51cf49
...
@@ -230,11 +230,10 @@ check_err(const Range& out,
...
@@ -230,11 +230,10 @@ check_err(const Range& out,
return
res
;
return
res
;
}
}
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
template
<
typename
Range
,
typename
RefRange
>
template
<
typename
Range
,
typename
RefRange
>
std
::
enable_if_t
<
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
enable_if_t
<
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
f8_t
>
||
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
f8_t
>
),
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
bf8_t
>
)),
bool
>
bool
>
check_err
(
const
Range
&
out
,
check_err
(
const
Range
&
out
,
const
RefRange
&
ref
,
const
RefRange
&
ref
,
...
@@ -250,9 +249,55 @@ check_err(const Range& out,
...
@@ -250,9 +249,55 @@ check_err(const Range& out,
}
}
bool
res
{
true
};
bool
res
{
true
};
int
err_count
=
0
;
int
err_count
=
0
;
double
err
=
0
;
double
err
=
0
;
// TODO: This is a hack. We should have proper specialization for bhalf_t data type.
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
const
double
o
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
out
),
i
));
const
double
r
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
ref
),
i
));
err
=
std
::
abs
(
o
-
r
);
if
(
err
>
atol
+
rtol
*
std
::
abs
(
r
)
||
!
std
::
isfinite
(
o
)
||
!
std
::
isfinite
(
r
))
{
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
if
(
err_count
<
5
)
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
}
res
=
false
;
}
}
if
(
!
res
)
{
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
}
return
res
;
}
#endif
#if defined CK_ENABLE_BF8
template
<
typename
Range
,
typename
RefRange
>
std
::
enable_if_t
<
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
bf8_t
>
),
bool
>
check_err
(
const
Range
&
out
,
const
RefRange
&
ref
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
double
rtol
=
1e-3
,
double
atol
=
1e-3
)
{
if
(
out
.
size
()
!=
ref
.
size
())
{
std
::
cerr
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
<<
std
::
endl
;
return
false
;
}
bool
res
{
true
};
int
err_count
=
0
;
double
err
=
0
;
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
{
...
...
profiler/include/profiler/profile_gemm_splitk_impl.hpp
View file @
2a51cf49
...
@@ -214,7 +214,7 @@ bool profile_gemm_splitk_impl(int do_verification,
...
@@ -214,7 +214,7 @@ bool profile_gemm_splitk_impl(int do_verification,
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
", KBatch "
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
", KBatch "
<<
kbatch_curr
<<
std
::
endl
;
<<
kbatch_curr
<<
std
::
endl
;
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
// set softer tolerances for fp8
// set softer tolerances for fp8
if
constexpr
(
is_same_v
<
ADataType
,
f8_t
>
||
is_same_v
<
BDataType
,
f8_t
>
||
if
constexpr
(
is_same_v
<
ADataType
,
f8_t
>
||
is_same_v
<
BDataType
,
f8_t
>
||
is_same_v
<
CDataType
,
f8_t
>
)
is_same_v
<
CDataType
,
f8_t
>
)
...
@@ -229,7 +229,7 @@ bool profile_gemm_splitk_impl(int do_verification,
...
@@ -229,7 +229,7 @@ bool profile_gemm_splitk_impl(int do_verification,
{
{
#endif
#endif
pass
=
pass
&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
pass
=
pass
&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP8
}
}
#endif
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment