gaoqiong / composable_kernel_ROCM / Commits

Commit f23a2e2a, authored Feb 11, 2025 by Jakub Piasecki

    resolved conflicts

Parents: f3eb5a18, c0adab48

Changes: 340
Showing 20 changed files with 3135 additions and 442 deletions (+3135 -442)
include/ck/utility/tuple_helper.hpp                                         +10    -4
include/ck/utility/type.hpp                                                 +316   -49
include/ck/utility/type_convert.hpp                                         +1372  -18
include/ck_tile/core.hpp                                                    +1     -0
include/ck_tile/core/arch/generic_memory_space_atomic.hpp                   +293   -10
include/ck_tile/core/config.hpp                                             +19    -3
include/ck_tile/core/numeric/float8.hpp                                     +593   -340
include/ck_tile/core/numeric/half.hpp                                       +6     -5
include/ck_tile/core/numeric/numeric.hpp                                    +2     -1
include/ck_tile/core/numeric/pk_int4.hpp                                    +140   -0
include/ck_tile/core/numeric/vector_type.hpp                                +18    -1
include/ck_tile/host.hpp                                                    +2     -0
include/ck_tile/host/check_err.hpp                                          +13    -7
include/ck_tile/host/concat.hpp                                             +122   -0
include/ck_tile/host/reference/reference_batched_transpose.hpp              +59    -0
include/ck_tile/host/reference/reference_gemm.hpp                           +3     -2
include/ck_tile/host/reference/reference_moe_sorting.hpp                    +24    -2
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp                               +1     -0
include/ck_tile/ops/batched_transpose.hpp                                   +12    -0
include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp   +129   -0
include/ck/utility/tuple_helper.hpp  (View file @ f23a2e2a)
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "functional4.hpp"
 #include "tuple.hpp"
+#ifndef CK_CODE_GEN_RTC
 #include "is_detected.hpp"
+#endif

 namespace ck {

...
@@ -29,7 +31,7 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
                                                               const Tuple<Y&...>& ty)
 {
     return unpack2(
-        [&](auto&&... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
+        [&](auto&&... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
         tx,
         ty);
 }
...
@@ -38,7 +40,7 @@ template <typename... X, typename... Y>
 __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
 {
     return unpack2(
-        [&](auto... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
+        [&](auto... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
         tx,
         ty);
 }
...
@@ -157,13 +159,17 @@ __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
     }
 }

+#ifndef CK_CODE_GEN_RTC
 template <typename T>
-using is_tuple = decltype(std::declval<T&>().IsTuple());
+using is_tuple = decltype(ck::declval<T&>().IsTuple());
+#endif

 template <typename... Ts>
 __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
 {
+#ifndef CK_CODE_GEN_RTC
     return (is_detected<is_tuple, Ts>::value || ...);
+#endif
 }

 template <index_t depth = 0, typename T>
...
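The change from std::forward to ck::forward keeps these tuple helpers usable when CK_CODE_GEN_RTC is defined and the standard headers are unavailable. A minimal usage sketch of concat_tuple, assuming ck::make_tuple from tuple.hpp behaves like std::make_tuple (illustrative only, not part of this commit):

    // Sketch: concatenating two ck::Tuple objects with the helper touched above.
    #include "ck/utility/tuple_helper.hpp"

    __host__ __device__ inline void concat_tuple_sketch()
    {
        auto a = ck::make_tuple(1, 2);   // Tuple<int, int>
        auto b = ck::make_tuple(3.0f);   // Tuple<float>
        auto c = ck::concat_tuple(a, b); // Tuple<int, int, float> holding {1, 2, 3.0f}
        (void)c;
    }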
include/ck/utility/type.hpp  (View file @ f23a2e2a)
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck/ck.hpp"
-#include "ck/utility/integral_constant.hpp"
 #include "ck/utility/enable_if.hpp"
+#include "ck/utility/integral_constant.hpp"

 namespace ck {

-template <typename X, typename Y>
-struct is_same : public integral_constant<bool, false>
-{
-};
-
-template <typename X>
-struct is_same<X, X> : public integral_constant<bool, true>
-{
-};
-
-template <typename X, typename Y>
-inline constexpr bool is_same_v = is_same<X, Y>::value;
-
-template <typename T>
-using remove_reference_t = typename std::remove_reference<T>::type;
-
-template <typename T>
-using remove_cv_t = typename std::remove_cv<T>::type;
-
-template <typename T>
-using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
-
-template <typename T>
-using remove_pointer_t = typename std::remove_pointer<T>::type;
-
-template <typename T>
-inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
-
+#ifdef CK_CODE_GEN_RTC
+// NOLINTNEXTLINE
+#define CK_BUILTIN_TYPE_TRAIT1(name)         \
+    template <class T>                       \
+    struct name : bool_constant<__##name(T)> \
+    {                                        \
+    }
+
+// NOLINTNEXTLINE
+#define CK_BUILTIN_TYPE_TRAIT2(name)            \
+    template <class T, class U>                 \
+    struct name : bool_constant<__##name(T, U)> \
+    {                                           \
+    }
+
+// NOLINTNEXTLINE
+#define CK_BUILTIN_TYPE_TRAITN(name)             \
+    template <class... Ts>                       \
+    struct name : bool_constant<__##name(Ts...)> \
+    {                                            \
+    }
+
+CK_BUILTIN_TYPE_TRAIT1(is_class);
+CK_BUILTIN_TYPE_TRAIT1(is_pointer);
+CK_BUILTIN_TYPE_TRAIT1(is_reference);
+CK_BUILTIN_TYPE_TRAIT1(is_trivially_copyable);
+CK_BUILTIN_TYPE_TRAIT1(is_unsigned);
+CK_BUILTIN_TYPE_TRAIT2(is_base_of);
+
+template <class T>
+struct remove_cv
+{
+    using type = T;
+};
+
+template <class T>
+struct remove_cv<const T> : remove_cv<T>
+{
+};
+
+template <class T>
+struct remove_cv<volatile T> : remove_cv<T>
+{
+};
+
+template <class T>
+struct remove_reference
+{
+    typedef T type;
+};
+
+template <class T>
+struct remove_reference<T&>
+{
+    typedef T type;
+};
+
+template <class T>
+struct remove_reference<T&&>
+{
+    typedef T type;
+};
+
+template <class T>
+struct remove_pointer
+{
+    typedef T type;
+};
+
+template <class T>
+struct remove_pointer<T*>
+{
+    typedef T type;
+};
+
+template <class T>
+struct remove_pointer<T* const>
+{
+    typedef T type;
+};
+
+template <class T>
+struct remove_pointer<T* volatile>
+{
+    typedef T type;
+};
+
+template <class T>
+struct remove_pointer<T* const volatile>
+{
+    typedef T type;
+};
+
+template <typename T>
+constexpr T&& forward(typename remove_reference<T>::type& t_) noexcept
+{
+    return static_cast<T&&>(t_);
+}
+
+template <typename T>
+constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
+{
+    return static_cast<T&&>(t_);
+}
+
+template <class T>
+struct is_const : public integral_constant<bool, false>
+{
+};
+
+template <class T>
+struct is_const<const T> : public integral_constant<bool, true>
+{
+};
+
+template <class T>
+inline constexpr bool is_const_v = is_const<T>::value;
+
+template <typename T>
+inline constexpr bool is_reference_v = is_reference<T>::value;
+
+template <class T>
+struct remove_const
+{
+    typedef T type;
+};
+
+template <class T>
+struct remove_const<const T>
+{
+    typedef T type;
+};
+
+template <class T>
+using remove_const_t = typename remove_const<T>::type;
+
+template <class T>
+inline constexpr bool is_class_v = is_class<T>::value;
+
+template <class T>
+inline constexpr bool is_trivially_copyable_v = is_trivially_copyable<T>::value;
+
+// template <typename T>
+// T&& declval() noexcept;
+
+template <class T, class U = T&&>
+U private_declval(int);
+
+template <class T>
+T private_declval(long);
+
+template <class T>
+auto declval() noexcept -> decltype(private_declval<T>(0));
+
+template <class...>
+using void_t = void;
+#else
+#include <utility>
+#include <type_traits>
+
+using std::declval;
+using std::forward;
+using std::is_base_of;
+using std::is_class;
+using std::is_class_v;
+using std::is_const_v;
+using std::is_pointer;
+using std::is_reference;
+using std::is_reference_v;
+using std::is_trivially_copyable;
+using std::is_trivially_copyable_v;
+using std::is_unsigned;
+using std::remove_const_t;
+using std::remove_cv;
+using std::remove_pointer;
+using std::remove_reference;
+using std::void_t;
+#endif
+
+template <typename X, typename Y>
+struct is_same : public integral_constant<bool, false>
+{
+};
+
+template <typename X>
+struct is_same<X, X> : public integral_constant<bool, true>
+{
+};
+
+template <typename X>
+struct is_floating_point : public integral_constant<bool, false>
+{
+};
+
+template <>
+struct is_floating_point<float> : public integral_constant<bool, true>
+{
+};
+
+template <>
+struct is_floating_point<double> : public integral_constant<bool, true>
+{
+};
+
+template <>
+struct is_floating_point<long double> : public integral_constant<bool, true>
+{
+};
+
+template <typename X>
+struct is_integral : public integral_constant<bool, false>
+{
+};
+
+template <>
+struct is_integral<int> : public integral_constant<bool, true>
+{
+};
+
... (analogous is_integral<...> : public integral_constant<bool, true> specializations are added
     for unsigned int, long, unsigned long, short, unsigned short, long long, unsigned long long,
     char, signed char, unsigned char, wchar_t, char16_t, char32_t and bool)
+
+template <typename X, typename Y>
+inline constexpr bool is_same_v = is_same<X, Y>::value;
+
+template <typename X, typename Y>
+inline constexpr bool is_base_of_v = is_base_of<X, Y>::value;
+
+template <typename T>
+inline constexpr bool is_unsigned_v = is_unsigned<T>::value;
+
+template <typename T>
+using remove_reference_t = typename remove_reference<T>::type;
+
+template <typename T>
+using remove_cv_t = typename remove_cv<T>::type;
+
+template <typename T>
+using remove_cvref_t = remove_cv_t<remove_reference_t<T>>;
+
+template <typename T>
+using remove_pointer_t = typename remove_pointer<T>::type;
+
+template <typename T>
+inline constexpr bool is_pointer_v = is_pointer<T>::value;
+
 template <typename Y,
           typename X,
           typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
 __host__ __device__ constexpr Y bit_cast(const X& x)
 {
     static_assert(__has_builtin(__builtin_bit_cast), "");
     static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");

     return __builtin_bit_cast(Y, x);
 }

 } // namespace ck
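Under CK_CODE_GEN_RTC the trait and utility names above are CK-native definitions; otherwise they are pulled in from <type_traits> and <utility>, so call sites spell ck:: either way. A small sketch of the intended usage, with hypothetical values that are not part of the commit:

    // Sketch: the ck:: aliases resolve on both the RTC and the standard-library paths.
    #include "ck/utility/type.hpp"

    __host__ __device__ inline void type_utils_sketch()
    {
        float f           = 1.0f;
        unsigned int bits = ck::bit_cast<unsigned int>(f); // assumes 32-bit unsigned int; 0x3F800000 for 1.0f
        static_assert(ck::is_same_v<ck::remove_cvref_t<const float&>, float>, "");
        static_assert(!ck::is_pointer_v<float>, "");
        (void)bits;
    }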
include/ck/utility/type_convert.hpp  (View file @ f23a2e2a)
...
@@ -5,15 +5,39 @@
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/f8_utils.hpp"
+#include "ck/utility/mxf4_utils.hpp"
+#include "ck/utility/mxf6_utils.hpp"
 #include "ck/utility/random_gen.hpp"
 #include "ck/utility/array.hpp"
+#include "ck/utility/amd_inline_asm.hpp"
+#include "ck/utility/type.hpp"

 namespace ck {

 // Define the common macro for MI300 models
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
 #define __gfx94__
 #endif

+namespace {
+namespace details {
+
+[[maybe_unused]] __host__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
+{
+    half2_t vector_res;
+    vector_res.x = x.x + y.x;
+    vector_res.y = x.y + y.y;
+    return vector_res;
+}
+
+[[maybe_unused]] __device__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
+{
+    return amd_assembly_pk_add_f16(x, y);
+}
+
+} // namespace details
+} // namespace
+
 // Declare a template function for bf16 conversion using RTN
 template <typename Y, typename X>
 __host__ __device__ constexpr Y bf16_convert_rtn(X x);
...
@@ -52,10 +76,10 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
 // Convert X to Y, both X and Y are non-const data types.
 template <typename Y,
           typename X,
-          std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
+          ck::enable_if_t<!(ck::is_const_v<Y> || ck::is_const_v<X>), bool> = false>
 __host__ __device__ constexpr Y type_convert(X x)
 {
-    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
+    static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);

     return static_cast<Y>(x);
 }
...
@@ -63,13 +87,13 @@ __host__ __device__ constexpr Y type_convert(X x)
 // Convert X to Y, either X or Y is a const data type.
 template <typename Y,
           typename X,
-          std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
+          ck::enable_if_t<ck::is_const_v<Y> || ck::is_const_v<X>, bool> = false>
 __host__ __device__ constexpr Y type_convert(X x)
 {
-    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
+    static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);

-    using NonConstY = std::remove_const_t<Y>;
-    using NonConstX = std::remove_const_t<X>;
+    using NonConstY = ck::remove_const_t<Y>;
+    using NonConstX = ck::remove_const_t<X>;
     return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
 }
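The second overload exists so that a const-qualified source or destination type still funnels into the non-const conversion: it strips const with remove_const_t, converts, and casts back. Moving the traits from std:: to ck:: keeps that dispatch compiling when <type_traits> is not available. An illustrative call with hypothetical values:

    // Sketch: both overloads reach the same underlying static_cast-based conversion.
    float        x  = 3.5f;
    double       y1 = ck::type_convert<double>(x);        // non-const overload
    const double y2 = ck::type_convert<const double>(x);  // const overload, via remove_const_t, same path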
...
@@ -149,7 +173,7 @@ inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int
 template <typename Y, typename X>
 __host__ __device__ constexpr Y type_convert_sp(X x)
 {
-    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
+    static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);

     return static_cast<Y>(x);
 }
...
@@ -211,7 +235,11 @@ template <>
 inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
 {
     constexpr int seed = 1254739;
-    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#ifndef CK_CODE_GEN_RTC
+    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#else
+    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
+#endif
 #if defined(__gfx94__)
     union
     {
...
@@ -251,7 +279,12 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
     constexpr bool clip           = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
     constexpr int seed            = 1254739;
+#ifndef CK_CODE_GEN_RTC
     uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#else
+    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
+#endif
     return utils::cast_to_f8<half_t,
                              f8_fnuz_t,
                              negative_zero_nan,
...
@@ -265,7 +298,11 @@ template <>
 inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
 {
     constexpr int seed = 1254739;
-    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#ifndef CK_CODE_GEN_RTC
+    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#else
+    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
+#endif
 #if defined(__gfx94__)
     union
     {
...
@@ -307,7 +344,12 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x
     constexpr bool clip           = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
     constexpr int seed            = 1254739;
+#ifndef CK_CODE_GEN_RTC
     uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#else
+    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
+#endif
     return utils::cast_to_f8<half_t,
                              bf8_fnuz_t,
                              negative_zero_nan,
...
@@ -502,13 +544,51 @@ template <>
 inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
 {
     uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
-    uint8_t x_l  = (x_u8 & 0x0f) >> 0;
-    uint8_t x_h  = (x_u8 & 0xf0) >> 4;
-    auto l_f32   = ck::type_convert<float>(x_l);
-    auto h_f32   = ck::type_convert<float>(x_h);
-    return {l_f32, h_f32};
+    float x_l    = ((x_u8 & 0x0f) >> 0) - 8.f;
+    float x_h    = ((x_u8 & 0xf0) >> 4) - 8.f;
+
+#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
+    float2_t res = {x_h, x_l};
+#elif
+    float2_t res = {x_l, x_h};
+#endif
+    return res;
+}
+
+template <>
+inline __host__ __device__ half2_t type_convert<half2_t, pk_i4_t>(pk_i4_t x)
+{
+    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
+#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
+    uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
+#else
+    uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
+#endif
+    const int EX  = 0x64006400;
+    const int SUB = 0xE408E408; //-8
+
+    int lo = i4s | EX;
+
+    return details::pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
+}
+
+template <>
+inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, pk_i4_t>(pk_i4_t x)
+{
+    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
+    float x_l    = ((x_u8 & 0x0f) >> 0) - 8.f;
+    float x_h    = ((x_u8 & 0xf0) >> 4) - 8.f;
+
+#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
+    bhalf2_t res = {type_convert<bhalf_t>(x_h), type_convert<bhalf_t>(x_l)};
+#else
+    bhalf2_t res = {type_convert<bhalf_t>(x_l), type_convert<bhalf_t>(x_h)};
+#endif
+    return res;
 }

 template <>
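The half2_t specialization avoids a per-lane integer-to-float conversion. Each 4-bit code is spliced into the low mantissa bits of the fp16 pattern 0x6400 (which encodes 1024.0, so the spliced value decodes to 1024 + code), and 0xE408 encodes -1032, so a single packed fp16 add yields code - 8 per lane, matching the explicit "- 8.f" in the float and bfloat16 paths. A host-side check of that arithmetic on one lane, written in plain C++ purely as an illustration:

    // Illustration only: verifies the 0x6400 / 0xE408 trick using float arithmetic.
    #include <cassert>

    inline void pk_i4_magic_constant_sketch()
    {
        for(int code = 0; code < 16; ++code)
        {
            float spliced = 1024.0f + static_cast<float>(code); // what (nibble | 0x6400) decodes to
            float summed  = spliced + (-1032.0f);                // what the packed add of 0xE408 does
            assert(summed == static_cast<float>(code) - 8.0f);   // offset-8 int4 decoding
        }
    }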
...
@@ -629,20 +709,1294 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
 #endif
 }

-template <typename Y, typename X, std::size_t NumElems>
+// convert fp32 to fp4 with rounding to nearest even
+inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
+{
+#if defined(__gfx950__)
+    union
+    {
+        uint32_t bitwise;
+        f4_t f4_array[4];
+    } value{0};
+    value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x, x, scale, 0);
+    return value.f4_array[0];
+#else
+    return utils::sat_convert_to_type<f4_t>(x / scale);
+#endif
+}
+
+// convert vector of 2 fp32 to vector of 2 fp4 with rne
+inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
+{
+#if defined(__gfx950__)
+    union
+    {
+        uint32_t bitwise;
+        f4x2_t f4x2_array[4];
+    } value{0};
+    value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0);
+    return value.f4x2_array[0];
+#else
+    union
+    {
+        uint32_t bitwise;
+        f4x2_t f4x2_array[4];
+    } value{0};
+    uint8_t l     = utils::sat_convert_to_type<f4_t>(x[1] / scale);
+    uint8_t h     = utils::sat_convert_to_type<f4_t>(x[0] / scale);
+    value.bitwise = (h << 4) | l;
+    return value.f4x2_array[0];
+#endif
+}
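In the non-gfx950 fallback the two quantized nibbles are packed with x[0] in the high nibble and x[1] in the low nibble, and the unpack paths further down mirror that order. A standalone sketch of the byte layout, using hypothetical helper names and plain C++:

    // Sketch of the fallback packing order used above.
    #include <cstdint>

    inline uint8_t pack_f4x2(uint8_t code_for_x0, uint8_t code_for_x1)
    {
        return static_cast<uint8_t>((code_for_x0 << 4) | (code_for_x1 & 0x0F)); // x[0] -> high nibble
    }

    inline void unpack_f4x2(uint8_t packed, uint8_t& code_for_x0, uint8_t& code_for_x1)
    {
        code_for_x0 = packed >> 4;   // high nibble
        code_for_x1 = packed & 0x0F; // low nibble
    }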
+// convert vector of 32 fp32 to vector of 32 fp4 with rne
+inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0f)
+{
+#if defined(__gfx950__)
+    union
+    {
+        __uint128_t bitwise;
+        f4x2_t f4x2_array[16];
+        f4x32_t f4x32_array;
+    } f4_values{}, tmp_values{};
+    // TODO: pack in a loop
+    tmp_values.bitwise =
+        __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[0], x[1], scale, 0);
+    f4_values.f4x2_array[0] = tmp_values.f4x2_array[0];
+    tmp_values.bitwise =
+        __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[2], x[3], scale, 0);
+    f4_values.f4x2_array[1] = tmp_values.f4x2_array[0];
    ... (the same two statements repeat for the pairs x[4],x[5] through x[30],x[31], filling
        f4_values.f4x2_array[2] through f4_values.f4x2_array[15])
+    return f4_values.f4x32_array;
+#else
+    union
+    {
+        __uint128_t bitwise;
+        f4x2_t f4x2_array[16];
+        f4x32_t f4x32_array;
+    } f4_values{};
+    // TODO: pack in a loop
+    auto tmp = utils::sat_convert_to_type<f4_t>(x[0] / scale);
+    f4_values.bitwise <<= 4;
+    f4_values.bitwise |= tmp;
+    tmp = utils::sat_convert_to_type<f4_t>(x[1] / scale);
+    f4_values.bitwise <<= 4;
+    f4_values.bitwise |= tmp;
    ... (the same three statements repeat for x[2] through x[31])
+    return f4_values.f4x32_array;
+#endif
+}
+// convert fp32 to fp4 with stochastic rounding
+inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
+{
+    constexpr int seed = 1254739;
+    uint32_t rng       = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#if defined(__gfx950__)
+    union
+    {
+        uint32_t bitwise;
+        f4_t f4_array[4];
+    } value{0};
+    union
+    {
+        float float_array[2];
+        float2_t float2_array;
+    } float_values{{x}};
+    value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
+        value.bitwise, float_values.float2_array, rng, scale, 0);
+    return value.f4_array[0];
+#else
+    return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
+#endif
+}
+
+// convert vector of 2 fp32 to vector of 2 fp4 with sr
+inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
+{
+    constexpr int seed = 1254739;
+    uint32_t rng       = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
+#if defined(__gfx950__)
+    union
+    {
+        uint32_t bitwise;
+        f4x2_t f4x2_array[4];
+    } value{0};
+    value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0);
+    return value.f4x2_array[0];
+#else
+    union
+    {
+        uint32_t bitwise;
+        f4x2_t f4x2_array[4];
+    } value{0};
+    uint8_t l     = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
+    uint8_t h     = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
+    value.bitwise = (h << 4) | l;
+    return value.f4x2_array[0];
+#endif
+}
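The stochastic-rounding variants feed the prand_generator output into sat_convert_to_type_sr (or the gfx950 builtin), so repeated conversions of the same value are unbiased on average instead of always rounding the same way. A generic sketch of the idea, deliberately not the CK fp4 code path:

    // Generic stochastic rounding to an integer grid; illustrates the concept only.
    #include <cmath>
    #include <cstdint>

    inline int stochastic_round_sketch(float x, uint32_t rng)
    {
        float lo   = std::floor(x);
        float frac = x - lo;                                   // fractional part in [0, 1)
        float u    = static_cast<float>(rng) / 4294967296.0f;  // rng mapped to [0, 1)
        return static_cast<int>(lo) + (u < frac ? 1 : 0);      // round up with probability frac
    }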
+// convert vector of 32 fp32 to vector of 32 fp4 with sr
+inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
+{
+    constexpr int seed = 1254739;
+    uint32_t rng       = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
+#if defined(__gfx950__)
+    union
+    {
+        __uint128_t bitwise;
+        f4x2_t f4x2_array[16];
+        f4x32_t f4x32_array;
+    } f4_values{0}, tmp_values{0};
+    union
+    {
+        float2_t floatx2_array[16];
+        float32_t floatx32_array;
+    } float_values{{0}};
+    // TODO: pack in a loop
+    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
+        tmp_values.bitwise, float_values.floatx2_array[0], rng, scale, 0);
+    f4_values.f4x2_array[0] = tmp_values.f4x2_array[0];
    ... (the same two statements repeat for float_values.floatx2_array[1] through [15], filling
        f4_values.f4x2_array[1] through f4_values.f4x2_array[15])
+    return f4_values.f4x32_array;
+#else
+    union
+    {
+        __uint128_t bitwise;
+        f4x2_t f4x2_array[16];
+        f4x32_t f4x32_array;
+    } f4_values{0};
+    // TODO: pack in a loop
+    auto tmp = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
+    f4_values.bitwise <<= 4;
+    f4_values.bitwise |= tmp;
    ... (the same three statements repeat for x[1] through x[31])
+    return f4_values.f4x32_array;
+#endif
+}
+// convert fp32 to fp4
+template <>
+inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
+{
+#if CK_USE_SR_F4_CONVERSION
+    return f4_convert_sr(x);
+#else
+    return f4_convert_rne(x);
+#endif
+}
+
+// convert vector of 2 fp32 to vector of 2 fp4
+template <>
+inline __host__ __device__ f4x2_t type_convert<f4x2_t, float2_t>(float2_t x)
+{
+#if CK_USE_SR_F4_CONVERSION
+    return f4_convert_sr(x);
+#else
+    return f4_convert_rne(x);
+#endif
+}
+
+// convert vector of 32 fp32 to vector of 32 fp4
+template <>
+inline __host__ __device__ f4x32_t type_convert<f4x32_t, float32_t>(float32_t x)
+{
+#if CK_USE_SR_F4_CONVERSION
+    return f4_convert_sr(x);
+#else
+    return f4_convert_rne(x);
+#endif
+}
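These specializations make fp4 quantization available through the generic type_convert interface, with CK_USE_SR_F4_CONVERSION selecting stochastic rounding over round-to-nearest-even at compile time. A hedged usage sketch, assuming f4_t and float2_t are the CK types used throughout this file:

    // Sketch: quantize and dequantize through the specializations added above.
    float    x    = 0.75f;
    ck::f4_t q    = ck::type_convert<ck::f4_t>(x); // RNE or SR depending on CK_USE_SR_F4_CONVERSION
    float    back = ck::type_convert<float>(q);    // dequantize with the implicit scale of 1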
+// convert fp4 to fp32
+template <>
+inline __host__ __device__ float type_convert<float, f4_t>(f4_t x)
+{
+#if defined(__gfx950__)
+    union
+    {
+        float float_array[2];
+        float2_t float2_array;
+    } float_values{};
+    float scale               = 1.0f;
+    float_values.float2_array = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0);
+    return float_values.float_array[0];
+#else
+    return utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
+#endif
+}
+
+// convert vector of 2 fp4 to vector of 2 fp32
+template <>
+inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
+{
+#if defined(__gfx950__)
+    union
+    {
+        uint32_t bitwise;
+        f4x2_t f4x2_array[4];
+    } value{};
+    value.f4x2_array[0] = x;
+    float scale         = 1.0f;
+    return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
+#else
+    float2_t ret{
+        utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
+                              x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
+        utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
+                              x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
+    return ret;
+#endif
+}
+// convert vector of 32 fp4 to vector of 32 fp32
+template <>
+inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
+{
+#if defined(__gfx950__)
+    union
+    {
+        f4x32_t f4x32_array;
+        f4x2_t fp4x2[16];
+    } value{x};
+    union
+    {
+        uint32_t bitwise;
+        f4x2_t f4x2_array[4];
+    } bitwise_value{};
+    float2_t op;
+    float32_t ret;
+    float scale = 1.0f;
+    // TODO: pack in a loop
+    bitwise_value.f4x2_array[0] = value.fp4x2[0];
+    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
+    ret[0] = op[0];
+    ret[1] = op[1];
+    bitwise_value.f4x2_array[0] = value.fp4x2[1];
+    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
+    ret[2] = op[0];
+    ret[3] = op[1];
    ... (the same four statements repeat for value.fp4x2[2] through value.fp4x2[15], filling
        ret[4] through ret[31])
+    return ret;
+#else
+    union
+    {
+        float32_t float32_array;
+        float float_array[32];
+    } float_values{};
+    union
+    {
+        __uint128_t bitwise;
+        f4x2_t f4x2_array[16];
+        f4x32_t f4x32_array;
+    } f4_values{bit_cast<__uint128_t>(x)};
+    // TODO: pack in a loop
+    float_values.float_array[0] = utils::to_float<f4_t>(
+        NumericLimits<e8m0_bexp_t>::Binary_1(),
+        f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
+    float_values.float_array[1] = utils::to_float<f4_t>(
+        NumericLimits<e8m0_bexp_t>::Binary_1(),
+        f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    ... (the same pair of statements repeats for f4_values.f4x2_array[1] through [15]; in the
        extracted source the destination float_array index cycles through 0..7 rather than
        advancing beyond 7)
+    return float_values.float32_array;
+#endif
+}
+/**
+ * @brief Converts a float to a 6-bit float type (f6_t) using round-to-nearest-even.
+ *
+ * Divides the input by the specified scale, then saturates and converts it
+ * to the 6-bit floating-point format (f6_t).
+ *
+ * @param x The input float value.
+ * @param scale A scaling factor applied to `x` before conversion.
+ * @return The converted f6_t value.
+ */
+inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f)
+{
+#if defined(__gfx950__)
+    float16_t in1{x};
+    float16_t in2{};
+    union
+    {
+        f6x32_t f6_vector;
+        f6_t f6_array[32];
+    } out{};
+    out.f6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale);
+    return out.f6_array[0];
+#else
+    return utils::sat_convert_to_type<f6_t>(x / scale);
+#endif
+}
+
+/**
+ * @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
+ *
+ * This function divides each input float by the provided scale value, then performs conversion
+ * with rounding to nearest / even to pack each element into 6 bits of precision.
+ *
+ * @param x A vector of 32 floats stored in float32_t.
+ * @param scale A scaling factor for each float before conversion.
+ * @return An f6x32_t object storing the compressed 6-bit representation.
+ */
+inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0f)
+{
+#if defined(__gfx950__)
+    float16_t* in1 = reinterpret_cast<float16_t*>(&x);
+    float16_t* in2 = reinterpret_cast<float16_t*>(&x + 16);
+    return __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale);
+#else
+    union
+    {
+        float32_t float_vector;
+        float float_array[32];
+    } in{x};
+    union
+    {
+        f6x32_t f6_vector;
+        f6_t f6_array[32];
+    } out{};
+    ck::static_for<0, 32, 1>{}(
+        [&](auto i) { out.f6_array[i] = utils::sat_convert_to_type<f6_t>(in.float_array[i] / scale); });
+    return out.f6_vector;
+#endif
+}
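The 32-wide overload is what the type_convert<f6x32_t, float32_t> specialization below forwards to; away from gfx950 it simply loops the scalar saturating conversion with ck::static_for. A hedged usage sketch, assuming float32_t is the 32-wide float vector type from the CK headers:

    // Sketch: quantizing a 32-wide float vector to packed fp6 with an explicit scale.
    ck::float32_t v{};                                  // 32 floats (assumed CK vector type)
    ck::f6x32_t   packed = ck::f6_convert_rne(v, 2.0f); // each element divided by 2, then packed to 6 bits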
/**
* @brief Converts a float to the 6-bit floating-point type (f6_t) using stochastic rounding.
*
* Divides the input by the specified scale, then performs saturation and conversion
* to f6_t based on a pseudo-randomly generated seed.
*
* @param x The input float value.
* @param scale A scaling factor applied to `x` before conversion.
* @return The converted f6_t value.
*/
inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng       = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    out.f6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale);
    return out.f6_array[0];
#else
    return utils::sat_convert_to_type_sr<f6_t>(x / scale, rng);
#endif
}
/**
* @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
*
* This function divides each input float by the provided scale value, then performs conversion with
* stochastic rounding to pack each element into 6 bits of precision.
*
* @param x A vector of 32 floats stored in float32_t.
* @param scale A scaling factor for each float before conversion.
* @return An f6x32_t object storing the compressed 6-bit representation.
*/
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    union
    {
        float32_t float_vector;
        float float_array[32];
    } float_values{x};
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
                                                float_values.float_array[0]);
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.f6_array[i] = utils::sat_convert_to_type_sr<f6_t>(in.float_array[i] / scale, rng);
    });
    return out.f6_vector;
#endif
}
/**
* @brief Specializes the type conversion template for converting a float into the 6-bit float type
* (f6_t).
*
* Depending on the CK_USE_SR_F6_CONVERSION flag,
* the conversion uses stochastic rounding
* or round-to-nearest-even.
*
* @param x Input float value to be converted.
* @return The converted f6_t value.
*/
template <>
inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x);
#else
    return f6_convert_rne(x);
#endif
}
/**
* @brief Specializes the type conversion template for converting a vector of 32 floats into the
* vector of 32 6-bit float types (f6x32_t).
*
* Depending on the CK_USE_SR_F6_CONVERSION flag,
* the conversion uses stochastic rounding
* or round-to-nearest-even.
*
* @param x Input float value to be converted.
* @return The converted f6x32_t vector.
*/
template <>
inline __host__ __device__ f6x32_t type_convert<f6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x);
#else
    return f6_convert_rne(x);
#endif
}
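
// Usage sketch (added for illustration; `quantize_f6` is a hypothetical helper, not part of
// this header): the generic type_convert entry point dispatches to stochastic rounding or
// round-to-nearest-even depending on CK_USE_SR_F6_CONVERSION.
inline __host__ __device__ f6_t quantize_f6(float x) { return type_convert<f6_t>(x); }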
/**
* @brief Specializes the type conversion template for converting the 6-bit float type (f6_t) to
* float.
*
* Interprets an f6_t value as a float using the default scale factor of 1.
*
* @param x The 6-bit float (f6_t) value to be converted.
* @return The corresponding float representation.
*/
template <>
inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
{
#if defined(__gfx950__)
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
        in.f6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
    return out.float_array[0];
#else
    return utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
/**
* @brief Specializes the type conversion template for converting the vector of 32 6-bit float types
* (f6x32_t) to vector of 32 floats.
*
 * Interprets the f6_t values as floats using the default scale factor of 1.
*
* @param x The vector of 32 6-bit float (f6x32_t) values to be converted.
* @return The corresponding float representation.
*/
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
        x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.float_array[i] =
            utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.f6_array[i]);
    });
    return out.float_vector;
#endif
}
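
// Round-trip sketch (added for illustration; `roundtrip_f6x32` is a hypothetical helper, not
// part of this header): packing 32 floats into f6x32_t and converting back yields the
// quantized values, since each element keeps only 6 bits at the implicit scale of 1.
inline __host__ __device__ float32_t roundtrip_f6x32(const float32_t& v)
{
    return type_convert<float32_t>(type_convert<f6x32_t>(v));
}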
/**
* @brief Converts a float to the 6-bit BF6 type using round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts
* it to a 6-bit BF6 floating-point format.
*
* @param x The float value to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6_t value.
*/
inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    float16_t in1{x};
    float16_t in2{};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    out.bf6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale);
    return out.bf6_array[0];
#else
    return utils::sat_convert_to_type<bf6_t>(x / scale);
#endif
}
/**
* @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using
* round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts
* it to a 6-bit BF6 floating-point format.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    float16_t* in1 = reinterpret_cast<float16_t*>(&x);
    float16_t* in2 = reinterpret_cast<float16_t*>(&x + 16);
    return __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.bf6_array[i] = utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
    });
    return out.bf6_vector;
#endif
}
/**
* @brief Converts a float to the 6-bit BF6 type using stochastic rounding.
*
* Divides the input by the specified scale,
* and converts the result to a 6-bit BF6 floating-point
* format with stochastic rounding.
*
* @param x The float value to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6_t value.
*/
inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng       = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    out.bf6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale);
    return out.bf6_array[0];
#else
    return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng);
#endif
}
/**
* @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using stochastic
* rounding.
*
* Divides the input by the specified scale,
* and converts the result to a 6-bit BF6 floating-point
* format with stochastic rounding.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    union
    {
        float32_t float_vector;
        float float_array[32];
    } float_values{x};
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
                                                float_values.float_array[0]);
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.bf6_array[i] = utils::sat_convert_to_type_sr<bf6_t>(in.float_array[i] / scale, rng);
    });
    return out.bf6_vector;
#endif
}
/**
* @brief Specializes float-to-bf6_t conversion.
*
* Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
*
* @param x Input float value to convert.
* @return Converted bf6_t value.
*/
template <>
inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
{
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x);
#else
    return bf6_convert_rne(x);
#endif
}
/**
* @brief Specializes vector of 32 float-to-bf6_t conversion.
*
* Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
*
* @param x Input float vector to convert.
* @return Converted bf6x32_t vector.
*/
template <>
inline __host__ __device__ bf6x32_t type_convert<bf6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x);
#else
    return bf6_convert_rne(x);
#endif
}
/**
* @brief Specializes the type conversion template for converting a bf6_t value to float.
*
* Interprets the bf6_t value using the default scale factor of 1 and returns
* its floating-point representation.
*
* @param x The bf6_t value to convert.
* @return The float representation of the given bf6_t value.
*/
template <>
inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
{
#if defined(__gfx950__)
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
        in.bf6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
    return out.float_array[0];
#else
    return utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
/**
* @brief Specializes the type conversion template for converting a vector of 32 bf6_t values to
* vector of 32 floats.
*
* Interprets the bf6x32_t value using the default scale factor of 1 and returns
* its floating-point representation.
*
* @param x The bf6x32_t value to convert.
* @return The float representation of the given vector.
*/
template <>
inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
        x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.float_array[i] =
            utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.bf6_array[i]);
    });
    return out.float_vector;
#endif
}
#ifndef CK_CODE_GEN_RTC
template <typename Y, typename X, size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
                                              const std::array<X, NumElems>& x)
{
    for(size_t i = 0; i < NumElems; i++)
    {
        y[i] = type_convert<Y>(x[i]);
    }
}
#endif
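
// Usage sketch (added for illustration; `to_half4` is a hypothetical helper, not part of this
// header, and it assumes host compilation where std::array is available): array_convert
// applies type_convert element-wise across the whole array.
inline std::array<half_t, 4> to_half4(const std::array<float, 4>& src)
{
    std::array<half_t, 4> dst{};
    array_convert(dst, src);
    return dst;
}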
template <typename Y, typename X, index_t NumElems>
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
{
    for(size_t i = 0; i < NumElems; i++)
    {
        y[i] = type_convert<Y>(x[i]);
    }
...
include/ck_tile/core.hpp View file @ f23a2e2a
...
@@ -32,6 +32,7 @@
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/null_type.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp"
...
include/ck_tile/core/arch/generic_memory_space_atomic.hpp View file @ f23a2e2a

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/numeric/vector_type.hpp"
...
@@ -8,16 +8,75 @@
namespace ck_tile {

template <typename T, typename ComputeType>
CK_TILE_HOST_DEVICE T add(const T& a, const T& b)
{
    return type_convert<T>(type_convert<ComputeType>(a) + type_convert<ComputeType>(b));
}

CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
{
    bf16x2_t rtn;
    rtn[0] = add<bf16_t, float>(a[0], b[0]);
    rtn[1] = add<bf16_t, float>(a[1], b[1]);
    return rtn;
}

CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
{
    bf16x4_t rtn;
    rtn[0] = add<bf16_t, float>(a[0], b[0]);
    rtn[1] = add<bf16_t, float>(a[1], b[1]);
    rtn[2] = add<bf16_t, float>(a[2], b[2]);
    rtn[3] = add<bf16_t, float>(a[3], b[3]);
    return rtn;
}

CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
    fp8x4_t rtn;
    rtn[0] = add<fp8_t, float>(a[0], b[0]);
    rtn[1] = add<fp8_t, float>(a[1], b[1]);
    rtn[2] = add<fp8_t, float>(a[2], b[2]);
    rtn[3] = add<fp8_t, float>(a[3], b[3]);
    return rtn;
}

CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b)
{
    fp8x8_t rtn;
    rtn[0] = add<fp8_t, float>(a[0], b[0]);
    rtn[1] = add<fp8_t, float>(a[1], b[1]);
    rtn[2] = add<fp8_t, float>(a[2], b[2]);
    rtn[3] = add<fp8_t, float>(a[3], b[3]);
    rtn[4] = add<fp8_t, float>(a[4], b[4]);
    rtn[5] = add<fp8_t, float>(a[5], b[5]);
    rtn[6] = add<fp8_t, float>(a[6], b[6]);
    rtn[7] = add<fp8_t, float>(a[7], b[7]);
    return rtn;
}

CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b)
{
    bf8x4_t rtn;
    rtn[0] = add<bf8_t, float>(a[0], b[0]);
    rtn[1] = add<bf8_t, float>(a[1], b[1]);
    rtn[2] = add<bf8_t, float>(a[2], b[2]);
    rtn[3] = add<bf8_t, float>(a[3], b[3]);
    return rtn;
}

CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
{
    bf8x8_t rtn;
    rtn[0] = add<bf8_t, float>(a[0], b[0]);
    rtn[1] = add<bf8_t, float>(a[1], b[1]);
    rtn[2] = add<bf8_t, float>(a[2], b[2]);
    rtn[3] = add<bf8_t, float>(a[3], b[3]);
    rtn[4] = add<bf8_t, float>(a[4], b[4]);
    rtn[5] = add<bf8_t, float>(a[5], b[5]);
    rtn[6] = add<bf8_t, float>(a[6], b[6]);
    rtn[7] = add<bf8_t, float>(a[7], b[7]);
    return rtn;
}
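
// Usage sketch (added for illustration; `accumulate_lane` is a hypothetical helper, not part
// of this header): both operands are widened to the ComputeType before the sum is narrowed
// back, so a bf16 accumulation rounds exactly once per update.
CK_TILE_HOST_DEVICE bf16_t accumulate_lane(bf16_t acc, bf16_t v)
{
    return add<bf16_t, float>(acc, v);
}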
...
@@ -59,6 +118,192 @@ CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
    } while(cur_v.u32 != old_v);
}

template <>
CK_TILE_DEVICE void atomic_add<bf16x4_t>(bf16x4_t* p_dst, bf16x4_t const& x)
{
    // Union to treat the pointer as either bf16x4_t* or uint64_t*:
    union U64BF164_ADDR
    {
        uint64_t* u64_a;
        bf16x4_t* bf164_a;
    };
    // Union to treat the data as either bf16x4_t or 64-bit integer
    union U64BF164
    {
        uint64_t u64;
        bf16x4_t bf164;
    };

    U64BF164_ADDR addr;
    addr.bf164_a = p_dst; // interpret p_dst as a 64-bit location

    // First read (non-atomic) of the old value
    U64BF164 cur_v;
    cur_v.u64 = *addr.u64_a;

    U64BF164 new_v_union;
    uint64_t old_v, new_v;

    do
    {
        // old 64 bits
        old_v = cur_v.u64;
        // Add elementwise in bf16
        new_v_union.bf164 = add_bf16x4_t(cur_v.bf164, x);
        new_v             = new_v_union.u64;
        // Attempt the 64-bit CAS
        cur_v.u64 = atomicCAS(addr.u64_a, old_v, new_v);
    } while(cur_v.u64 != old_v);
}

template <>
CK_TILE_DEVICE void atomic_add<fp8x4_t>(fp8x4_t* p_dst, const fp8x4_t& x)
{
    union U32FP84_ADDR
    {
        uint32_t* u32_a;
        fp8x4_t* fp84_a;
    };
    union U32FP84
    {
        uint32_t u32;
        fp8x4_t fp84;
    };

    U32FP84_ADDR dword_addr;
    U32FP84 cur_v;
    U32FP84 new_;
    uint32_t old_v, new_v;

    dword_addr.fp84_a = p_dst;
    cur_v.u32         = *dword_addr.u32_a;

    do
    {
        old_v     = cur_v.u32;
        new_.fp84 = add_fp8x4_t(cur_v.fp84, x);
        new_v     = new_.u32;
        cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
    } while(cur_v.u32 != old_v);
}

template <>
CK_TILE_DEVICE void atomic_add<bf8x4_t>(bf8x4_t* p_dst, const bf8x4_t& x)
{
    union U32BF84_ADDR
    {
        uint32_t* u32_a;
        bf8x4_t* bf84_a;
    };
    union U32BF84
    {
        uint32_t u32;
        bf8x4_t bf84;
    };

    U32BF84_ADDR dword_addr;
    U32BF84 cur_v;
    U32BF84 new_;
    uint32_t old_v, new_v;

    dword_addr.bf84_a = p_dst;
    cur_v.u32         = *dword_addr.u32_a;

    do
    {
        old_v     = cur_v.u32;
        new_.bf84 = add_bf8x4_t(cur_v.bf84, x);
        new_v     = new_.u32;
        cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
    } while(cur_v.u32 != old_v);
}

//
// Atomic add for fp8x8_t
//
template <>
CK_TILE_DEVICE void atomic_add<fp8x8_t>(fp8x8_t* p_dst, fp8x8_t const& x)
{
    // Union for addressing 64 bits as either "fp8x8_t" or a 64-bit integer.
    union U64FP88_ADDR
    {
        uint64_t* u64_a; // pointer to 64-bit integer
        fp8x8_t* fp88_a; // pointer to fp8x8_t
    };
    union U64FP88
    {
        uint64_t u64;
        fp8x8_t fp88;
    };

    U64FP88_ADDR dword_addr;
    U64FP88 cur_v;
    U64FP88 new_v_union;
    uint64_t old_v, new_v;

    // Point to the destination as both fp8x8_t* and uint64_t*.
    dword_addr.fp88_a = p_dst;
    // Initial read of 64 bits from memory
    cur_v.u64 = *dword_addr.u64_a;

    do
    {
        old_v = cur_v.u64;
        // Add each fp8 element using your add_fp8x8_t(...) routine
        new_v_union.fp88 = add_fp8x8_t(cur_v.fp88, x);
        new_v            = new_v_union.u64;
        // Attempt 64-bit CAS
        cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
    } while(cur_v.u64 != old_v);
}

//
// Atomic add for bf8x8_t
//
template <>
CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
{
    union U64BF88_ADDR
    {
        uint64_t* u64_a;
        bf8x8_t* bf88_a;
    };
    union U64BF88
    {
        uint64_t u64;
        bf8x8_t bf88;
    };

    U64BF88_ADDR dword_addr;
    U64BF88 cur_v;
    U64BF88 new_v_union;
    uint64_t old_v, new_v;

    dword_addr.bf88_a = p_dst;
    // Read the original 64 bits
    cur_v.u64 = *dword_addr.u64_a;

    do
    {
        old_v = cur_v.u64;
        // Add each bf8 element using your add_bf8x8_t(...) routine
        new_v_union.bf88 = add_bf8x8_t(cur_v.bf88, x);
        new_v            = new_v_union.u64;
        // 64-bit CAS loop
        cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
    } while(cur_v.u64 != old_v);
}
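
// Usage sketch (added for illustration; `accumulate_bf16x4` is a hypothetical kernel, not part
// of this header): many threads may accumulate packed bf16x4_t values into the same location,
// and the compare-and-swap loops above guarantee that no update is lost even though the
// hardware has no native packed bf16 atomic add at this width.
__global__ void accumulate_bf16x4(bf16x4_t* dst, const bf16x4_t* src)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    atomic_add<bf16x4_t>(dst, src[tid]);
}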
template <typename T, index_t N>
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
{
...
@@ -66,8 +311,10 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
                      (std::is_same<T, uint32_t>::value && (N == 1)) ||
                      (std::is_same<T, float>::value && (N == 1 || N == 2)) ||
                      (std::is_same<T, double>::value && (N == 1 || N == 2)) ||
                      (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
                      (std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
                      (std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
                  "The granularity of the thread buffer is unsupported on the hardware!");

    constexpr auto I0 = number<0>{};
    constexpr auto I1 = number<1>{};
...
@@ -118,9 +365,45 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
        }
        else if constexpr(N == 4)
        {
            atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
        }
        else if constexpr(N == 8)
        {
            atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
            atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst) + 1,
                       x.template get_as<bf16x4_t>()[I1]);
        }
    }
    else if constexpr(std::is_same<T, fp8_t>::value)
    {
        if constexpr(N == 4)
        {
            atomic_add(c_style_pointer_cast<fp8x4_t*>(p_dst), x.template get_as<fp8x4_t>()[I0]);
        }
        if constexpr(N == 8)
        {
            atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
        }
        if constexpr(N == 16)
        {
            atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
            atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst) + 1,
                       x.template get_as<fp8x8_t>()[I1]);
        }
    }
    else if constexpr(std::is_same<T, bf8_t>::value)
    {
        if constexpr(N == 4)
        {
            atomic_add(c_style_pointer_cast<bf8x4_t*>(p_dst), x.template get_as<bf8x4_t>()[I0]);
        }
        if constexpr(N == 8)
        {
            atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
        }
        if constexpr(N == 16)
        {
            atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
            atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1,
                       x.template get_as<bf8x8_t>()[I1]);
        }
    }
}
...
include/ck_tile/core/config.hpp View file @ f23a2e2a

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
    defined(__gfx942__) || defined(__gfx950__)
#define __gfx9__
#endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
...
@@ -144,6 +144,10 @@
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
#endif

#ifndef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
#define CK_TILE_USE_PK4_LAYOUT_SHUFFLE 1
#endif

// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
...
@@ -230,3 +234,15 @@
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#endif

#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifdef CK_TILE_USE_OCP_FP8
#define CK_TILE_USE_OCP_FP8 1
#else
#define CK_TILE_USE_OCP_FP8 0
#endif
#elif defined(__gfx950__) || defined(__gfx12__) // for GPU code
#define CK_TILE_USE_OCP_FP8 1
#else // for GPU code
#define CK_TILE_USE_OCP_FP8 0
#endif
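
// Example (added for illustration, not part of this header): downstream code can select its
// fp8 saturation bound from the flag resolved above, e.g.
//   constexpr float fp8_max = CK_TILE_USE_OCP_FP8 ? 448.f /*OCP e4m3*/ : 240.f /*FNUZ e4m3*/;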
include/ck_tile/core/numeric/float8.hpp View file @ f23a2e2a

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
...
@@ -14,6 +14,12 @@
@@ -14,6 +14,12 @@
#pragma once
#pragma once
#if(defined(__gfx94__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
#define CK_TILE_FP8_CVT_DEVICE 1
#else
#define CK_TILE_FP8_CVT_DEVICE 0
#endif
namespace
ck_tile
{
namespace
ck_tile
{
// fp8 rounding modes
// fp8 rounding modes
...
@@ -25,15 +31,26 @@ enum class fp8_rounding_mode
...
@@ -25,15 +31,26 @@ enum class fp8_rounding_mode
stochastic
stochastic
};
};
/**
* \brief FP8 interpretation used in conversion algorithms
*/
enum
class
fp8_interpretation
{
E4M3_OCP
=
0
,
// OCP FP8 E4M3
E5M2_OCP
=
1
,
// OCP BF8 E5M2
E4M3_FNUZ
=
2
,
// FNUZ FP8 E4M3
E5M2_FNUZ
=
3
,
// FNUZ BF8 E5M2
};
/*
/*
* ______________
NANOO
_________________ | ______________
IEEE
________________
* ______________
FNUZ
_________________ | ______________
OCP
________________
* e4m3 e5m2 | e4m3 e5m2
* e4m3 e5m2 | e4m3 e5m2
* bias : 8 16 | 7 15
* bias : 8 16 | 7 15
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111
(448)
s.00000.11
(57344)
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111
s.00000.11
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
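 *
 * Worked example (added for illustration): the OCP e4m3 maximum normal s.1111.110 decodes as
 * 2^(15-7) * (1 + 6/8) = 256 * 1.75 = 448, while the FNUZ e4m3 maximum s.1111.111 decodes as
 * 2^(15-8) * (1 + 7/8) = 128 * 1.875 = 240, matching the Max(norm) row above.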
...
@@ -55,10 +72,10 @@ struct alignas(1) float8_e4m3_t
{
    static constexpr int exponent = 4;
    static constexpr int mantissa = 3;
#if CK_TILE_USE_OCP_FP8
    static constexpr int bias = 7; // OCP
#else
    static constexpr int bias = 8; // FNUZ
#endif

    using raw_type = uint8_t;
    raw_type data;
...
@@ -113,10 +130,10 @@ struct alignas(1) float8_e5m2_t
{
    static constexpr int exponent = 5;
    static constexpr int mantissa = 2;
#if CK_TILE_USE_OCP_FP8
    static constexpr int bias = 15; // OCP
#else
    static constexpr int bias = 16; // FNUZ
#endif

    using raw_type = uint8_t;
    raw_type data;
...
@@ -183,501 +200,727 @@ struct native_t<bf8_t>
};
#else
using fp8_t     = _BitInt(8);
using fp8_raw_t = uint8_t;
using bf8_t     = unsigned _BitInt(8);
using bf8_raw_t = uint8_t;
#endif

template <typename T>
struct numeric_traits;

template <>
struct numeric_traits<fp8_t>
{
    using bitwise_type = fp8_raw_t;

    static constexpr int exp  = 4;
    static constexpr int mant = 3;
#if CK_TILE_USE_OCP_FP8
    static constexpr int bias                        = 7;
    static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_OCP;
#else
    static constexpr int bias                        = 8;
    static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_FNUZ;
#endif
    static constexpr uint8_t abs_mask = 0x7F;
};

template <>
struct numeric_traits<bf8_t>
{
    using bitwise_type = bf8_raw_t;

    static constexpr int exp  = 5;
    static constexpr int mant = 2;
#if CK_TILE_USE_OCP_FP8
    static constexpr int bias                        = 15;
    static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_OCP;
#else
    static constexpr int bias                        = 16;
    static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_FNUZ;
#endif
    static constexpr uint8_t abs_mask = 0x7F;
};

// below is sw fp8 conversion, not utilizing hw instruction
namespace impl {

template <typename SrcT, typename DstT, bool clip = true, bool stoch = false>
CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
{
    static_assert(std::is_same<DstT, fp8_t>::value || std::is_same<DstT, bf8_t>::value,
                  "DstT type must be fp8 or bf8.");
    constexpr bool is_half  = std::is_same<SrcT, half_t>::value;
    constexpr bool is_float = std::is_same<SrcT, float>::value;
    static_assert(is_half || is_float, "Only half and float can be cast to f8");

    // fp8/bf8 type exponent/mantissa layout
    constexpr int DstT_exp  = numeric_traits<DstT>::exp;  // exponent width of the destination type
    constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
    constexpr bool is_fnuz =
        (numeric_traits<DstT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
        (numeric_traits<DstT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);

    constexpr int SrcT_exp  = numeric_traits<SrcT>::exp;
    constexpr int SrcT_mant = numeric_traits<SrcT>::mant;

    using SrcT_bitwise       = typename numeric_traits<SrcT>::bitwise_type;
    SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);

    unsigned long long head, mantissa;
    int exponent, bias;
    unsigned int sign;
    unsigned long long fInf, abs_mask;

    head     = src_bitwise & numeric_traits<SrcT>::head_mask;
    mantissa = src_bitwise & numeric_traits<SrcT>::mant_mask;
    exponent = (head >> SrcT_mant) & numeric_traits<SrcT>::exp_mask;
    sign     = head >> (SrcT_exp + SrcT_mant);
    bias     = numeric_traits<SrcT>::bias;
    fInf     = numeric_traits<SrcT>::Inf;
    abs_mask = numeric_traits<SrcT>::abs_mask;

    unsigned int signed_inf = 0;
    unsigned int nan        = 0;
    if constexpr(is_fnuz)
    {
        signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
        nan        = 0x80;
    }
    else
    {
        if constexpr(DstT_exp == 4)
        { // e4m3
            signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
        }
        else
        { // e5m2
            signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
        }
        nan = (sign << 7) + 0x7f;
    }

    // Max values
    unsigned long long ifmax = 0;
    if constexpr(is_float)
    {
        if constexpr(DstT_exp == 5)
        {
            ifmax = 0x47600000;
        }
        else
        {
            if constexpr(is_fnuz)
            {
                ifmax = 0x43700000;
            }
            else
            {
                ifmax = 0x43E00000;
            }
        }
    }
    else if constexpr(is_half)
    {
        if constexpr(DstT_exp == 5)
        {
            ifmax = 0x7B00;
        }
        else
        {
            if constexpr(is_fnuz)
            {
                ifmax = 0x5B80;
            }
            else
            {
                ifmax = 0x5F00;
            }
        }
    }

    // Deal with inf and NaNs
    if((src_bitwise & fInf) == fInf)
    {
        if constexpr(is_fnuz)
            return signed_inf;

        return mantissa != 0 ? nan : signed_inf;
    }

    if((src_bitwise & abs_mask) > ifmax)
    {
        return signed_inf;
    }

    if(src_bitwise == 0)
    {
        return 0;
    }

    // First need to check if it is normal or denorm as there is a difference of
    // implicit 1. Then need to adjust the exponent to align with the F8 exponent,
    // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
    // to mantissa and truncate. And for RNE, no need to add rng. Then probably
    // need to check whether there is carry and adjust exponent and mantissa again

    // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
    // bits
    const int f8_bias                  = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 1 : 0);
    const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
    // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
    // f8_exponent is the converted f8 exponent with bias encoding
    // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
    // the difference needs to be adjusted and mantissa shifted
    int act_exponent, f8_exponent, exponent_diff;

    if(exponent == 0)
    { // fp32/fp16 is in denormal.
        /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
        mostly concern fp16 here. In this case, f8 is usually in denormal. But there
        could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
        exponent bias 16. It means that there are some numbers in fp16 denormal but they
        are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
        where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
        (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
        act_exponent  = exponent - bias + 1;
        exponent_diff = f8_denormal_act_exponent -
                        act_exponent; // actual exponent is exponent-bias+1 as it is denormal
    }
    else
    { // fp32/fp16 is normal with implicit 1
        act_exponent = exponent - bias;
        if(act_exponent <= f8_denormal_act_exponent)
        {
            /* This is the case where fp32/fp16 is normal but it is in f8 denormal
            range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
            actual exponent is -7, it is actually larger due to the implicit 1,
            Therefore it needs to be adjust to -6 and mantissa shift right by 1.
            So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
            exponent_diff = f8_denormal_act_exponent - act_exponent;
        }
        else
        { // both fp32/fp16 and f8 are in normal range
            exponent_diff = 0; // exponent_diff=0 does not mean there is no difference for this case,
                               // act_exponent could be larger. Just that it does not need shift mantissa
        }
        mantissa += (1ull << SrcT_mant); // Add the implicit 1 into mantissa
    }

    bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
                    (1ull << (SrcT_mant - DstT_mant + exponent_diff - 1));
    /* This part is a bit tricky. The judgment of whether it is a tie needs to be
    done before we shift right as shift right could rip off some residual part and
    make something not midpoint look like midpoint. For example, the fp16 number
    0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right
    by 4 bits, it would look like midpoint.
    */

    if(exponent_diff > 0)
        mantissa >>= exponent_diff;
    else if(exponent_diff == -1)
        mantissa <<= -exponent_diff;
    bool implicit_one = mantissa & (1ull << SrcT_mant);
    // if there is no implicit 1, it means the f8 is denormal and need to adjust
    // to denorm exponent
    f8_exponent =
        (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);

    // Now we have the exponent and mantissa adjusted
    unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1;
    bool odd = mantissa & (1ull << (SrcT_mant - DstT_mant)); // if the least significant bit that
                                                             // is not truncated is 1
    mantissa +=
        (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;

    // Now we deal with overflow
    if(f8_exponent == 0)
    {
        if((1ull << SrcT_mant) & mantissa)
        {
            f8_exponent = 1; // denormal overflow to become normal, promote exponent
            // No need to make 1 implicit now as it will be addressed later
        }
    }
    else
    {
        if((1ull << (SrcT_mant + 1)) & mantissa)
        {
            mantissa >>= 1;
            f8_exponent++;
            // No need to make 1 implicit now as it will be addressed later
        }
    }

    mantissa >>= (SrcT_mant - DstT_mant);

    // above range: quantize to maximum possible float of the same sign
    const int max_exp = (1 << DstT_exp) - 1;
    if(f8_exponent > max_exp)
    {
        if constexpr(clip)
        {
            mantissa    = (1 << DstT_mant) - 1;
            f8_exponent = max_exp;
        }
        else
        {
            return signed_inf;
        }
    }

    // check if x is 0.0 or -0.0
    if(f8_exponent == 0 && mantissa == 0)
        return is_fnuz ? 0 : (sign << 7);
    mantissa &= (1 << DstT_mant) - 1;
    return (sign << 7) | (f8_exponent << DstT_mant) | mantissa;
}
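
// Worked example (added for illustration, assuming CK_TILE_USE_OCP_FP8 == 1 so fp8_t is OCP
// e4m3 with bias 7): run_cast_to_f8<float, fp8_t, true, false>(1.0f) unpacks sign=0,
// exponent=127, mantissa=0, computes act_exponent = 0 and f8_exponent = 0 + 7 = 7, and
// returns (0 << 7) | (7 << 3) | 0 = 0x38, the e4m3 encoding of 1.0.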
template <typename SrcT, typename DstT, bool clip = true>
CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
{
    static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
                  "SrcT type must be fp8 or bf8.");
    constexpr int SrcT_exp  = numeric_traits<SrcT>::exp;
    constexpr int SrcT_mant = numeric_traits<SrcT>::mant;

    constexpr bool is_fnuz =
        (numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
        (numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);

    constexpr bool is_half  = std::is_same<DstT, half_t>::value;
    constexpr bool is_float = std::is_same<DstT, float>::value;
    static_assert(is_half || is_float, "DstT type must be half_t or float.");

    // destination type exponent/mantissa layout
    constexpr int DstT_exp  = numeric_traits<DstT>::exp;  // exponent width of the destination type
    constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type

    constexpr DstT fInf    = bit_cast<DstT>(numeric_traits<DstT>::Inf);
    constexpr DstT fNegInf = bit_cast<DstT>(numeric_traits<DstT>::NegInf);
    constexpr DstT fNaN    = bit_cast<DstT>(numeric_traits<DstT>::NaN);
    constexpr DstT fNeg0   = bit_cast<DstT>(numeric_traits<DstT>::Neg0);

    DstT fmax{0}, fmin{0};
    // Max number in e5m2 57344
    if constexpr(is_half)
    {
        fmax = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0x7B00));
        fmin = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0xFB00));
    }
    else if constexpr(is_float)
    {
        fmax = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0x47600000));
        fmin = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0xC7600000));
    }

    if(x == 0)
    {
        return 0;
    }

    // unpack the input
    unsigned long long sign     = x >> 7;
    unsigned long long mantissa = x & ((1 << SrcT_mant) - 1);
    int exponent                = (x & 0x7F) >> SrcT_mant;
    if constexpr(is_fnuz)
    {
        if(x == 0x80)
        {
            return fNaN;
        }
    }
    else
    {
        if(x == 0x80)
        {
            return fNeg0;
        }
        if constexpr(SrcT_exp == 4)
        { // e4m3
            if((x & 0x7F) == 0x7F)
            {
                return fNaN;
            }
        }
        else if((x & 0x7C) == 0x7C)
        { // e5m2
            if((x & 0x3) == 0)
            {
                if constexpr(clip)
                {
                    return sign ? fmin : fmax;
                }
                return sign ? fNegInf : fInf;
            }
            return fNaN;
        }
    }

    typename numeric_traits<DstT>::bitwise_type retval;
    if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
    {
        retval = x << 8;
        return bit_cast<DstT>(retval);
    }

    const int exp_low_cutoff =
        (1 << (DstT_exp - 1)) - (1 << (SrcT_exp - 1)) + 1 - (is_fnuz ? 1 : 0);

    // subnormal input
    if(exponent == 0)
    {
        // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
        int sh = 1 + clz(mantissa) - (32 - SrcT_mant);
        mantissa <<= sh;
        exponent += 1 - sh;
        mantissa &= ((1ull << SrcT_mant) - 1);
    }
    exponent += exp_low_cutoff - 1;
    mantissa <<= DstT_mant - SrcT_mant;

    // subnormal output (occurs when DstT is half_t, we=5, is_fnuz=true)
    if(exponent <= 0)
    {
        mantissa |= 1 << DstT_mant;
        mantissa >>= 1 - exponent;
        exponent = 0;
    }

    retval = (sign << (DstT_exp + DstT_mant)) | (exponent << DstT_mant) | mantissa;
    return bit_cast<DstT>(retval);
}

template <typename X, typename Y, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
{
    return bit_cast<Y>(run_cast_to_f8<X, Y, clip, stoch>(x, rng));
}
#if CK_TILE_FP8_CVT_DEVICE
/**
 * @brief Cast float to fp8/bf8 using device conversion instructions
 */
template <fp8_interpretation interpret, bool saturate, bool stochastic_rounding = false>
CK_TILE_DEVICE uint8_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
{
    uint8_t i8data;
    union
    {
        float fval;
        unsigned int i32val;
        unsigned char i8val[4]; // NOTE: not endian independent
    } val;

    unsigned int ival = 0;
    val.fval          = v;

    if constexpr(saturate)
    {
        if constexpr(interpret == fp8_interpretation::E4M3_FNUZ)
        {
            if((val.i32val & 0x7F800000) != 0x7F800000)
            { /// propagate NAN/INF, no clipping
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
            }
        }
        else if constexpr(interpret == fp8_interpretation::E4M3_OCP)
        { // OCP type
            if((val.i32val & 0x7F800000) != 0x7F800000)
            { /// propagate NAN/INF, no clipping
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
            }
        }
        else
        {
            if((val.i32val & 0x7F800000) != 0x7F800000)
            { /// propagate NAN/INF, no clipping
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
            }
        }
    }

    if constexpr(stochastic_rounding)
    {
        ival = (interpret == fp8_interpretation::E4M3_FNUZ) ||
                       (interpret == fp8_interpretation::E4M3_OCP)
                   ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
                   : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
        val.i32val = ival;
        i8data     = val.i8val[0]; // little endian
    }
    else
    { // RNE CVT
        ival = (interpret == fp8_interpretation::E4M3_FNUZ) ||
                       (interpret == fp8_interpretation::E4M3_OCP)
                   ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
                   : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
        val.i32val = ival;
        i8data     = val.i8val[0];
    }

    return i8data;
}
#endif // CK_TILE_FP8_CVT_DEVICE

} // namespace impl
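
// Minimal host-side sketch (added for illustration; `quantize_rne` is a hypothetical helper,
// not part of this header): the software path converts via impl::cast_to_f8 with clip=true
// and stoch=false, i.e. round-to-nearest-even with saturation to the fp8 range.
CK_TILE_HOST_DEVICE fp8_t quantize_rne(float x)
{
    return impl::cast_to_f8<float, fp8_t, true, false>(x, 0);
}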
/**
* @brief Converts a floating-point value to an 8-bit floating-point representation with stochastic
* rounding.
*
* This function converts a floating-point value (float or half_t) to an 8-bit floating-point
* representation of type fp8_t or bf8_t. The conversion process may
* involve clipping and uses a pseudo-random number generator for the stochastic rounding.
*
* @tparam DstT The destination type (fp8_t or bf8_t).
* @tparam SrcT The source type (float or half_t) to be converted.
* @param x The floating-point value to be converted.
* @return The 8-bit floating-point representation of the input value.
*/
template <typename SrcT, typename DstT>
CK_TILE_HOST_DEVICE typename numeric_traits<DstT>::bitwise_type float_to_fp8_sr_raw(SrcT x)
{
    constexpr bool clip = true;
    constexpr int seed  = 42;
    uint32_t rng        = prand_generator_t<SrcT, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if CK_TILE_FP8_CVT_DEVICE
    return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip, true>(x, rng);
#else
    return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
        impl::cast_to_f8<SrcT, DstT, clip, true>(x, rng));
#endif
}
/**
* @brief Converts a floating-point value to an 8-bit floating-point representation with rounding to
* nearest even.
*
* This function converts a floating-point value (float or half_t) to an 8-bit floating-point
* representation of type fp8_t or bf8_t. The conversion process may involve clipping.
*
* @tparam DstT The destination type (fp8_t or bf8_t).
* @tparam SrcT The source type (float or half_t) to be converted.
* @param x The floating-point value to be converted.
* @return The 8-bit floating-point representation of the input value.
*/
template <typename SrcT, typename DstT>
CK_TILE_HOST_DEVICE typename numeric_traits<DstT>::bitwise_type float_to_fp8_rtn_raw(SrcT x)
{
    constexpr bool clip = true;
#if CK_TILE_FP8_CVT_DEVICE
    return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip, false>(x, 0);
#else
    return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
        impl::cast_to_f8<SrcT, DstT, clip, false>(x, 0));
#endif
}
// clang-format off
template <fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
{
    if constexpr(rounding == fp8_rounding_mode::standard)
    {
        return float_to_fp8_rtn_raw<float, fp8_t>(x);
    }
    else if constexpr(rounding == fp8_rounding_mode::stochastic)
    {
        return float_to_fp8_sr_raw<float, fp8_t>(x);
    }
    else
    {
        return fp8_raw_t{0};
    }
}

template <fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
{
    if constexpr(rounding == fp8_rounding_mode::standard)
    {
        return float_to_fp8_rtn_raw<float, bf8_t>(x);
    }
    else if constexpr(rounding == fp8_rounding_mode::stochastic)
    {
        return float_to_fp8_sr_raw<float, bf8_t>(x);
    }
    else
    {
        return bf8_raw_t{0};
    }
}

CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
{
#if CK_TILE_FP8_CVT_DEVICE
    float fval;
    uint32_t i32val = static_cast<uint32_t>(x);
    fval            = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
    // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
    return fval;
#else
    return impl::run_cast_from_f8<fp8_t, float>(bit_cast<fp8_t>(x));
#endif
}

CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
{
#if CK_TILE_FP8_CVT_DEVICE
    float fval;
    uint32_t i32val = static_cast<uint32_t>(x);
    fval            = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
    // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
    return fval;
#else
    return impl::run_cast_from_f8<bf8_t, float>(bit_cast<bf8_t>(x));
#endif
}

template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant<rounding> = {})
{
    return bit_cast<fp8_t>(float_to_fp8_raw(x, constant<rounding>{}));
}

template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant<rounding> = {})
{
    return bit_cast<bf8_t>(float_to_bf8_raw(x, constant<rounding>{}));
}

CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x) { return fp8_to_float_raw(bit_cast<fp8_raw_t>(x)); }

CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x) { return bf8_to_float_raw(bit_cast<bf8_raw_t>(x)); }
// clang-format on
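A minimal host-side round trip through the wrappers above, assuming a HIP-enabled build in which ck_tile/core.hpp pulls in this float8 header; e4m3 spaces values in [2, 4) by 0.25, so 3.7 quantizes to 3.75:

// Usage sketch (assumption: compiled with hipcc so the ck_tile headers are available).
#include "ck_tile/core.hpp"
#include <cstdio>

int main()
{
    float x = 3.7f;
    // The rounding mode is a compile-time choice: round-to-nearest-even here.
    ck_tile::fp8_t q =
        ck_tile::float_to_fp8(x, ck_tile::constant<ck_tile::fp8_rounding_mode::standard>{});
    float back = ck_tile::fp8_to_float(q); // expected 3.75 for e4m3
    std::printf("%f -> %f\n", x, back);
    return 0;
}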
template <class T>
struct numeric;

template <typename T>
struct numeric_traits;
#if CK_TILE_USE_OCP_FP8
template <>
struct numeric<fp8_t>
{
    // minimum positive normal value
    CK_TILE_HOST_DEVICE static constexpr fp8_t min()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08)); // 0b00001000 = 2^-6
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xfe)); // 0b11111110 = -448
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr fp8_t max()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7e)); // 0b01111110 = 448
    }

    // difference between 1.0 and next representable f8 value (1.125)
    // returns fp8_t(0.125)
    CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20)); // 0.125
    }

    // rounding error (0.0625), half of epsilon
    CK_TILE_HOST_DEVICE static constexpr fp8_t round_error()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x18)); // 0.0625
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7F)); // 0b01111111
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xFF)); // 0b11111111
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
    }

    CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
    }
};

template <>
struct numeric<bf8_t>
{
    // minimum positive normal value
    CK_TILE_HOST_DEVICE static constexpr bf8_t min()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04)); // 0b00000100 = 2^-14
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xfb)); // 0b11111011 = -57344
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr bf8_t max()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7b)); // 0b01111011 = 57344
    }

    // difference between 1.0 and next representable bf8 value (1.25)
    CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34)); // 0.25
    }

    // rounding error (0.125), half of epsilon
    CK_TILE_HOST_DEVICE static constexpr bf8_t round_error()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x30)); // 0.125
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr bf8_t infinity()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7c)); // 0b01111100
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7F)); // 0b01111111
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xFF));
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
    }

    CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
    }
};
#else
template <>
struct numeric<fp8_t>
{
...
@@ -811,6 +1054,7 @@ struct numeric<bf8_t>
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
    }
};
#endif
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t)
...
@@ -818,19 +1062,26 @@ CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
#endif

// math
template <typename T>
CK_TILE_HOST_DEVICE T abs(const T& x)
{
    static_assert(std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>,
                  "Only fp8_t and bf8_t are supported");
    return bit_cast<T>(static_cast<uint8_t>(bit_cast<uint8_t>(x) & numeric_traits<T>::abs_mask));
}

CK_TILE_HOST_DEVICE bool isnan(const fp8_t& x)
{
    uint8_t xx = bit_cast<fp8_raw_t>(x);
#if CK_TILE_USE_OCP_FP8
    return (xx & 0x7f) == 0x7f;
#else
    return xx == 0x80;
#endif
}

#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_DEVICE fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
...
@@ -842,20 +1093,21 @@ fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); }
CK_TILE_DEVICE fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };
#endif

CK_TILE_HOST_DEVICE bool isnan(const bf8_t& x)
{
    uint8_t xx = bit_cast<bf8_raw_t>(x);
#if CK_TILE_USE_OCP_FP8
    return (xx & 0x7f) > 0x7c;
#else
    return xx == 0x80;
#endif
}

#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_DEVICE bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
...
@@ -867,5 +1119,6 @@ bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); }
CK_TILE_DEVICE bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); };
#endif

} // namespace ck_tile
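The isnan() change above is where the OCP and FNUZ encodings visibly differ: OCP e4m3 treats 0x7f/0xff as NaN (and OCP bf8 everything strictly above 0x7c in the low 7 bits), while the FNUZ types reserve the single pattern 0x80. A small standalone sketch of the OCP e4m3 test, assuming CK_TILE_USE_OCP_FP8 is in effect:

// Standalone illustration of the OCP e4m3 NaN predicate used above.
#include <cstdint>
#include <cassert>

int main()
{
    auto is_ocp_fp8_nan = [](uint8_t bits) { return (bits & 0x7f) == 0x7f; };
    assert(is_ocp_fp8_nan(0x7f));  // +NaN
    assert(is_ocp_fp8_nan(0xff));  // -NaN
    assert(!is_ocp_fp8_nan(0x7e)); // 448, largest finite e4m3 value
    return 0;
}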
include/ck_tile/core/numeric/half.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
...
@@ -236,10 +236,11 @@ struct numeric_traits<half_t>
    static constexpr uint16_t head_mask = 0xFC00;
    static constexpr uint16_t mant_mask = 0x3FF;
    static constexpr uint16_t exp_mask  = 0x1F;
    static constexpr uint16_t abs_mask  = 0x7FFF;
    static constexpr uint16_t Inf       = 0x7C00;
    static constexpr uint16_t NegInf    = 0xFC00;
    static constexpr uint16_t NaN       = 0x7C01;
    static constexpr uint16_t Neg0      = 0x8000;

    using bitwise_type = uint16_t;
};
...
include/ck_tile/core/numeric/numeric.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
@@ -89,6 +89,7 @@ struct numeric_traits<float>
    static constexpr uint32_t head_mask = 0xFF800000;
    static constexpr uint32_t mant_mask = 0x7FFFFF;
    static constexpr uint32_t exp_mask  = 0xFF;
    static constexpr uint32_t abs_mask  = 0x7FFFFFFF;
    static constexpr uint32_t Inf       = 0x7F800000;
    static constexpr uint32_t NegInf    = 0xFF800000;
    static constexpr uint32_t NaN       = 0x7F800001;
...
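The new abs_mask members let a generic abs() clear the sign bit with a single AND instead of a per-type constant; an illustration of the same operation for float, using the 0x7FFFFFFF mask added above:

// Illustration only: clearing the sign bit of a float via numeric_traits<float>::abs_mask.
#include <cstdint>
#include <cstring>
#include <cassert>

int main()
{
    float x = -1.5f;
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits &= 0x7FFFFFFFu; // the abs_mask value added in this hunk
    float ax;
    std::memcpy(&ax, &bits, sizeof(ax));
    assert(ax == 1.5f);
    return 0;
}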
include/ck_tile/core/numeric/pk_int4.hpp (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
#include <type_traits>
#include "ck_tile/core/numeric/int8.hpp"

#pragma once

namespace ck_tile {

// Packed 2x int4
struct pk_int4_t
{
    using type = int8_t;
    type data;
    __host__ __device__ constexpr pk_int4_t() : data{type{}} {}
    __host__ __device__ constexpr pk_int4_t(type init) : data{init} {}
};

// limits
template <class T>
struct numeric;

template <>
struct numeric<pk_int4_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t min()
    {
        constexpr uint8_t val = 0b10001000;
        return pk_int4_t(bit_cast<int8_t>(val));
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t lowest()
    {
        constexpr uint8_t val = 0b10001000;
        return pk_int4_t(bit_cast<int8_t>(val));
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t max()
    {
        constexpr uint8_t val = 0b01110111;
        return pk_int4_t(bit_cast<int8_t>(val));
    }

    CK_TILE_HOST_DEVICE static constexpr pk_int4_t epsilon() { return 1; }       // not used
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t round_error() { return 1; }   // not used
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t infinity() { return 1; }      // not used
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t quiet_NaN() { return 1; }     // not used
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t signaling_NaN() { return 1; } // not used
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t denorm_min() { return 1; }    // not used
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; }
};

CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
    float x_l    = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h    = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    fp32x2_t res = {x_h, x_l};
#elif
    fp32x2_t res = {x_l, x_h};
#endif
    return res;
}

CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#elif
    uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif
    const int EX  = 0x64006400;
    const int SUB = 0xE408E408; // -8
    int lo        = i4s | EX;
    return pk_add_f16(bit_cast<fp16x2_t>(lo), bit_cast<fp16x2_t>(SUB));
}

CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
    float x_l    = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h    = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    bf16x2_t res = {type_convert<bf16_t>(x_h), type_convert<bf16_t>(x_l)};
#elif
    bf16x2_t res = {type_convert<bf16_t>(x_l), type_convert<bf16_t>(x_h)};
#endif
    return res;
}

} // namespace ck_tile
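The converters above turn each nibble n of the packed byte into the value n - 8, and swap the two lanes when CK_TILE_USE_PK4_LAYOUT_SHUFFLE is defined. A standalone sketch of that arithmetic (no lane shuffle):

// Standalone sketch of the unpack math used by pk_int4_t_to_fp32x2_t.
#include <cstdint>
#include <cstdio>

int main()
{
    uint8_t packed = 0x2B; // high nibble 0x2, low nibble 0xB
    float lo = static_cast<float>(packed & 0x0f) - 8.f;        // 11 - 8 = 3
    float hi = static_cast<float>((packed & 0xf0) >> 4) - 8.f; // 2 - 8 = -6
    std::printf("lo=%.1f hi=%.1f\n", lo, hi);
    return 0;
}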
include/ck_tile/core/numeric/vector_type.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
@@ -200,4 +200,21 @@ using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
#endif

CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
    fp16x2_t vector_res;
    vector_res.x = x.x + y.x;
    vector_res.y = x.y + y.y;
    return vector_res;
}

CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
    fp16x2_t c;
    asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
    return c;
}

} // namespace ck_tile
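pk_add_f16 comes in a host and a device flavour so the packed-int4 to half conversion compiles in both contexts: the host overload adds the two lanes with scalar arithmetic, while the device overload emits a single packed v_pk_add_f16 instruction. A host-only illustration of the lane-wise add, with a plain struct standing in for fp16x2_t:

// Host-side illustration of the lane-wise add performed by pk_add_f16 (stand-in type, not the header's).
#include <cstdio>

struct half2_like { float x, y; }; // stand-in for fp16x2_t, illustration only

int main()
{
    half2_like a{1.0f, 2.0f}, b{0.5f, -1.0f};
    half2_like c{a.x + b.x, a.y + b.y};
    std::printf("%.1f %.1f\n", c.x, c.y); // 1.5 1.0
    return 0;
}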
include/ck_tile/host.hpp
...
@@ -5,6 +5,7 @@
#include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/device_memory.hpp"
...
@@ -20,6 +21,7 @@
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
...
include/ck_tile/host/check_err.hpp
...
@@ -22,13 +22,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
double get_relative_threshold(const int number_of_accumulations = 1)
{
    using F8   = ck_tile::fp8_t;
    using BF8  = ck_tile::bf8_t;
    using F16  = ck_tile::half_t;
    using BF16 = ck_tile::bf16_t;
    using F32  = float;
    using I8   = int8_t;
    using I32  = int32_t;

    static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                  "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
    double compute_error = 0;
...
@@ -41,7 +42,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
        compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
    }

    static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                  "Warning: Unhandled OutDataType for setting up the relative threshold!");
    double output_error = 0;
...
@@ -55,7 +56,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
    }
    double midway_error = std::max(compute_error, output_error);

    static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                  "Warning: Unhandled AccDataType for setting up the relative threshold!");
    double acc_error = 0;
...
@@ -74,13 +75,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
{
    using F8   = ck_tile::fp8_t;
    using BF8  = ck_tile::bf8_t;
    using F16  = ck_tile::half_t;
    using BF16 = ck_tile::bf16_t;
    using F32  = float;
    using I8   = int8_t;
    using I32  = int32_t;

    static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                  "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
    auto expo = std::log2(std::abs(max_possible_num));
...
@@ -94,7 +96,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
        compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
    }

    static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                  "Warning: Unhandled OutDataType for setting up the absolute threshold!");
    double output_error = 0;
...
@@ -108,7 +110,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
    }
    double midway_error = std::max(compute_error, output_error);

    static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                  "Warning: Unhandled AccDataType for setting up the absolute threshold!");
    double acc_error = 0;
...
@@ -501,7 +503,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
    }
    if(!res)
    {
        const float error_percent =
            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
        std::cerr << "max err: " << max_err;
        std::cerr << ", number of errors: " << err_count;
        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
    }
    return res;
}
...
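With bf8_t now accepted by the threshold helpers, a bf8 verification tolerance can be derived the same way as for fp8. A usage sketch, assuming these helpers are reachable through ck_tile/host.hpp, live in the ck_tile namespace, and that 57344 (the bf8 maximum) is a reasonable max_possible_num for the run being checked:

// Hypothetical usage: derive check_err tolerances for a bf8-compute / fp16-output GEMM.
#include "ck_tile/host.hpp"
#include <cstdio>

int main()
{
    const int k = 128; // number of accumulations along the reduced dimension
    double rtol = ck_tile::get_relative_threshold<ck_tile::bf8_t, ck_tile::half_t, float>(k);
    double atol = ck_tile::get_absolute_threshold<ck_tile::bf8_t, ck_tile::half_t, float>(57344.0, k);
    std::printf("rtol=%g atol=%g\n", rtol, atol);
    return 0;
}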
include/ck_tile/host/concat.hpp (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"

namespace ck_tile {

template <typename T>
struct IsCharArray : std::false_type
{
};

template <std::size_t N>
struct IsCharArray<char[N]> : std::true_type
{
};

template <std::size_t N>
struct IsCharArray<const char[N]> : std::true_type
{
};

template <std::size_t N>
struct IsCharArray<char (&)[N]> : std::true_type
{
};

template <std::size_t N>
struct IsCharArray<const char (&)[N]> : std::true_type
{
};

template <typename... Ts>
inline constexpr bool AllConvertibleToStringView =
    ((std::is_convertible_v<Ts, std::string_view> || IsCharArray<Ts>::value ||
      std::is_same_v<Ts, char>)&&...);

template <typename... Ts>
[[nodiscard]] auto concat(const Ts&... xs)
    -> std::enable_if_t<!AllConvertibleToStringView<Ts...>, std::string>
{
    using ::operator<<;
    thread_local std::ostringstream oss;
    oss.str("");
    (oss << ... << xs);
    return oss.str();
}

template <std::size_t N>
[[nodiscard]] constexpr inline std::size_t getSize(char (&)[N]) noexcept
{
    return N;
}

template <std::size_t N>
[[nodiscard]] constexpr inline std::size_t getSize(const char (&)[N]) noexcept
{
    return N;
}

[[nodiscard]] constexpr inline std::size_t getSize(const char* s) noexcept
{
    const char* end = s;
    while(*end++ != 0) {}
    return end - s - 1;
}

[[nodiscard]] constexpr inline std::size_t getSize(const char&) noexcept { return 1; }

[[nodiscard]] inline std::size_t getSize(const std::string& s) noexcept { return s.size(); }

[[nodiscard]] constexpr inline std::size_t getSize(const std::string_view& s) noexcept
{
    return s.size();
}

template <typename... Ts>
auto concatInto(std::string& result, const Ts&... xs)
    -> std::enable_if_t<AllConvertibleToStringView<Ts...>, void>
{
    const std::size_t space = (1 + ... + getSize(xs));
    result.reserve(result.size() + space);
    ((result += xs), ...);
}

template <typename... Ts>
[[nodiscard]] auto concat(const Ts&... xs)
    -> std::enable_if_t<AllConvertibleToStringView<Ts...>, std::string>
{
    std::string result;
    concatInto(result, xs...);
    return result;
}

// Function for types convertible to std::string_view
template <typename Sep, typename First, typename... Rest>
[[nodiscard]] auto concat(Sep sep, const First& first, const Rest&... rest)
    -> std::enable_if_t<AllConvertibleToStringView<First, Rest...>, std::string>
{
    std::string result;
    result += first;
    ((result += sep, result += rest), ...);
    return result;
}

// Function for other types
template <typename Sep, typename First, typename... Rest>
[[nodiscard]] auto concat(Sep sep, const First& first, const Rest&... rest)
    -> std::enable_if_t<!AllConvertibleToStringView<First, Rest...>, std::string>
{
    using ::operator<<;
    thread_local std::ostringstream oss;
    oss.str("");
    oss << first;
    ((oss << sep << rest), ...);
    return oss.str();
}

} // namespace ck_tile
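The string-like overloads size the output once with a fold over getSize and then append, which avoids the ostringstream fallback for the common kernel-name case. The pattern in isolation (a standalone illustration, not calling the header):

// Illustration of the reserve-then-append fold used by concatInto (standalone code, not the header's API).
#include <string>
#include <string_view>
#include <iostream>

template <typename... Ts>
std::string fold_concat(const Ts&... xs)
{
    std::string result;
    const std::size_t space = (0 + ... + std::string_view(xs).size());
    result.reserve(space);
    ((result += xs), ...);
    return result;
}

int main()
{
    std::cout << fold_concat("gemm_", "fp16", "_256x128") << "\n"; // gemm_fp16_256x128
    return 0;
}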
include/ck_tile/host/reference/reference_batched_transpose.hpp (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <thread>

namespace ck_tile {

template <typename Type>
CK_TILE_HOST void reference_batched_transpose(const HostTensor<Type>& x,
                                              HostTensor<Type>& y,
                                              std::string layout_in  = "NCHW",
                                              std::string layout_out = "NHWC")
{
    const int N = x.mDesc.get_lengths()[0];

    auto f = [&](auto batch) {
        if(layout_in == "NCHW" && layout_out == "NHWC")
        {
            const int C = x.mDesc.get_lengths()[1];
            const int H = x.mDesc.get_lengths()[2];
            const int W = x.mDesc.get_lengths()[3];

            for(int c = 0; c < C; ++c)
            {
                for(int h = 0; h < H; ++h)
                {
                    for(int w = 0; w < W; ++w)
                    {
                        Type v_x          = x(batch, c, h, w);
                        y(batch, h, w, c) = v_x;
                    }
                }
            }
        }
        else if(layout_in == "NHWC" && layout_out == "NCHW")
        {
            const int H = x.mDesc.get_lengths()[1];
            const int W = x.mDesc.get_lengths()[2];
            const int C = x.mDesc.get_lengths()[3];

            for(int h = 0; h < H; ++h)
            {
                for(int w = 0; w < W; ++w)
                {
                    for(int c = 0; c < C; ++c)
                    {
                        Type v_x          = x(batch, h, w, c);
                        y(batch, c, h, w) = v_x;
                    }
                }
            }
        }
    };

    make_ParallelTensorFunctor(f, N)(std::thread::hardware_concurrency());
}
} // namespace ck_tile
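A host-side sketch of driving the reference transpose, assuming HostTensor can be constructed from a brace list of lengths as in the other host references:

// Hypothetical usage of the NCHW -> NHWC reference path.
#include "ck_tile/host.hpp"

int main()
{
    const int N = 2, C = 3, H = 4, W = 5;
    ck_tile::HostTensor<float> x({N, C, H, W}); // input, NCHW
    ck_tile::HostTensor<float> y({N, H, W, C}); // output, NHWC
    // ... fill x ...
    ck_tile::reference_batched_transpose(x, y, "NCHW", "NHWC");
    // Afterwards y(n, h, w, c) equals x(n, c, h, w) for every index.
    return 0;
}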
include/ck_tile/host/reference/reference_gemm.hpp
...
@@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
            int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                              ? col * strideB + k
                              : k * strideB + col;
            acc += ck_tile::type_convert<AccDataType>(A[a_index]) *
                   ck_tile::type_convert<AccDataType>(B[b_index]);
        }

        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? row * strideC + col
                          : col * strideC + row;
        C[c_index] = ck_tile::type_convert<CDataType>(acc);
    }
}
...
include/ck_tile/host/reference/reference_moe_sorting.hpp
...
@@ -14,12 +14,15 @@ namespace ck_tile {
template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
                                        const HostTensor<WeightType>& weights,
                                        const HostTensor<IndexType>& local_expert_mask,
                                        HostTensor<IndexType>& p_sorted_token_ids,
                                        HostTensor<WeightType>& sorted_weight,
                                        HostTensor<IndexType>& sorted_expert_ids,
                                        index_t& unit_cnt,
                                        const index_t experts,
                                        const index_t unit_size,
                                        bool local_expert_masking,
                                        bool skip_experts_with_zero_token = true)
{
    const index_t num_token = topk_ids.mDesc.get_lengths()[0];
    const index_t topk      = topk_ids.mDesc.get_lengths()[1];
...
@@ -33,8 +36,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
#endif
    std::vector<std::vector<WeightType>> expert_token_weights(
        experts, std::vector<WeightType>(unit_size, 0));
    // count number of unit-size slices in this expert
    std::vector<IndexType> expert_slices(experts, 1);
    // count the tokens used in this expert
    std::vector<IndexType> expert_slice_idxs(experts, 0);
    // TODO: above 2 buffer seems duplicated
    for(index_t t = 0; t < num_token; t++)
    {
...
@@ -72,8 +78,23 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
    IndexType* out_tokens    = p_sorted_token_ids.data();
    WeightType* out_weights  = sorted_weight.data();
    IndexType* out_expert_id = sorted_expert_ids.data();
    int curr_expert_id       = 0;
    for(index_t e = 0; e < experts; e++)
    {
        if(local_expert_masking)
        {
            if(local_expert_mask(e) == 0)
                continue;
        }
        if(skip_experts_with_zero_token)
        {
            if(expert_slice_idxs[e] == 0)
            {
                curr_expert_id++;
                continue;
            }
        }
        memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
        out_tokens += expert_slices[e] * unit_size;
        memcpy(out_weights,
...
@@ -83,10 +104,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
        for(index_t s = 0; s < expert_slices[e]; s++)
        {
            out_expert_id[s] = curr_expert_id;
            unit_cnt++;
        }
        out_expert_id += expert_slices[e];
        curr_expert_id++;
    }
    unit_cnt *= unit_size;
    return;
...
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
...
@@ -10,3 +10,4 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/batched_transpose.hpp (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {

struct BatchedTransposeHostArgs
{
    const void* p_input;
    void* p_output;
    index_t batch;
    index_t height;
    index_t width;
    // index_t dim_blocks;
    index_t dim_stride;
    index_t dim_block_h;
    index_t dim_block_w;
};

template <typename Pipeline_>
struct BatchedTransposeKernel
{
    using Pipeline = remove_cvref_t<Pipeline_>;
    using Problem  = remove_cvref_t<typename Pipeline::Problem>;
    using Type     = typename Problem::InputType;

    struct BatchedTransposeKargs
    {
        const void* p_input;
        void* p_output;
        index_t batch;
        index_t height;
        index_t width;
        index_t dim_stride;
    };

    using Kargs = BatchedTransposeKargs;
    using Hargs = BatchedTransposeHostArgs;

    CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
    {
        size_t grid_size_x = (h.width + h.dim_block_w - 1) / h.dim_block_w;
        size_t grid_size_y = (h.height + h.dim_block_h - 1) / h.dim_block_h;
        size_t grid_size_z = h.batch;
        return dim3(grid_size_x, grid_size_y, grid_size_z);
    }

    CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
    {
        Kargs k;
        k.p_input    = h.p_input;
        k.p_output   = h.p_output;
        k.batch      = h.batch;
        k.height     = h.height;
        k.width      = h.width;
        k.dim_stride = h.dim_stride;
        return k;
    }

    CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
        static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
        static constexpr bool kPadM                  = Problem::kPadM;
        static constexpr bool kPadN                  = Problem::kPadN;

        static constexpr ck_tile::index_t kMPerThread = Problem::kMPerThread;
        static constexpr ck_tile::index_t kNPerThread = Problem::kNPerThread;
        static_assert(kMPerThread == 1 && kNPerThread == 1);

        const auto iDim = blockIdx.z;

        const auto x_m_n = [&]() {
            const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const Type*>(kargs.p_input) + iDim * kargs.dim_stride,
                make_tuple(kargs.height, kargs.width),
                make_tuple(kargs.width, 1),
                number<kNPerThread>{}, // TODO thread load value
                number<1>{});

            return pad_tensor_view(x_dram_naive,
                                   make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
                                   sequence<kPadM, kPadN>{});
        }();

        const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
        const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock);

        const auto y_n_m = [&]() {
            const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                static_cast<Type*>(kargs.p_output) + iDim * kargs.dim_stride,
                make_tuple(kargs.width, kargs.height),
                make_tuple(kargs.height, 1),
                number<kMPerThread>{},
                number<1>{});

            return pad_tensor_view(y_dram_naive,
                                   make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
                                   sequence<kPadN, kPadM>{});
        }();

        auto x_block_window = make_tile_window(
            x_m_n,
            make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
            {static_cast<ck_tile::index_t>(iM * kMPerBlock),
             static_cast<ck_tile::index_t>(iN * kNPerBlock)});

        auto y_block_window = make_tile_window(
            y_n_m,
            make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
            {static_cast<ck_tile::index_t>(iN * kNPerBlock),
             static_cast<ck_tile::index_t>(iM * kMPerBlock)});

        Pipeline{}(x_block_window, y_block_window);
    }
};
} // namespace ck_tile
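On the host side the kernel is configured through BatchedTransposeHostArgs; GridSize() launches one block per dim_block_w x dim_block_h tile per batch. A standalone sketch that mirrors the host-args fields and the grid-size arithmetic (values here are illustrative assumptions):

// Standalone sketch of the host-args setup and the GridSize() arithmetic above.
#include <cstddef>
#include <cstdio>

int main()
{
    // Mirrors BatchedTransposeHostArgs for an 8 x 64 x 128 transpose with 32x32 tiles.
    std::size_t batch = 8, height = 64, width = 128;
    std::size_t dim_block_h = 32, dim_block_w = 32;

    std::size_t grid_x = (width + dim_block_w - 1) / dim_block_w;  // 4
    std::size_t grid_y = (height + dim_block_h - 1) / dim_block_h; // 2
    std::size_t grid_z = batch;                                    // 8
    std::printf("grid = (%zu, %zu, %zu)\n", grid_x, grid_y, grid_z);
    return 0;
}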